fishy.models.deep.ordinal

Classes

class fishy.models.deep.ordinal.CumulativeLinkLoss(*args: Any, **kwargs: Any)

Bases: Module

forward(cum_logits, y_true)

Computes the cumulative link loss for the cumulative logits cum_logits against the target labels y_true.

Note

Although the forward pass is defined in this method, call the module instance (e.g. criterion(cum_logits, y_true)) rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() silently skips them.
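As a rough illustration of what a cumulative link loss computes, here is a framework-free sketch. It assumes that each row of cum_logits holds one logit per threshold, read as logit P(y ≤ k) for k = 0, …, K − 2, and that y_true holds integer class indices 0, …, K − 1. The conventions actually used by CumulativeLinkLoss may differ, and the name cumulative_link_nll is hypothetical.

```python
import math
from typing import Sequence


def _sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def cumulative_link_nll(cum_logits: Sequence[Sequence[float]],
                        y_true: Sequence[int],
                        eps: float = 1e-12) -> float:
    """Hypothetical sketch: mean negative log-likelihood of a cumulative link model.

    Assumes cum_logits[i][k] is the logit of P(y_i <= k) for k = 0..K-2,
    so a row with K - 1 logits describes K ordered classes.
    """
    total = 0.0
    for logits, y in zip(cum_logits, y_true):
        # Cumulative probabilities, padded with P(y <= -1) = 0 and P(y <= K-1) = 1.
        cdf = [0.0] + [_sigmoid(z) for z in logits] + [1.0]
        # P(y = k) = P(y <= k) - P(y <= k - 1); clamp to avoid log(0).
        p = max(cdf[y + 1] - cdf[y], eps)
        total += -math.log(p)
    return total / len(y_true)
```

Because each class probability is a difference of adjacent cumulative probabilities, the thresholds must be increasing for every P(y = k) to be non-negative; the eps clamp only guards against log(0) and does not enforce that ordering.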

class fishy.models.deep.ordinal.EarlyStopping(patience=5, verbose=False, delta=0, path='checkpoint.pt', trace_func=<built-in function print>)

Bases: object

Stops training early when the validation loss has not improved for patience consecutive checks.

__init__(patience=5, verbose=False, delta=0, path='checkpoint.pt', trace_func=<built-in function print>)
save_checkpoint(val_loss, model)
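The patience rule can be sketched without any framework. This minimal illustration assumes that an improvement means the loss falls more than delta below the best loss seen so far; the class name EarlyStoppingSketch and its step method are hypothetical, and checkpoint saving (save_checkpoint) is left out.

```python
class EarlyStoppingSketch:
    """Hypothetical sketch of patience-based early stopping (no checkpointing)."""

    def __init__(self, patience: int = 5, delta: float = 0.0) -> None:
        self.patience = patience
        self.delta = delta
        self.best_loss = float("inf")
        self.counter = 0
        self.early_stop = False

    def step(self, val_loss: float) -> bool:
        """Record one validation loss; return True once training should stop."""
        if val_loss < self.best_loss - self.delta:
            # Sufficient improvement: remember the new best and reset the counter.
            self.best_loss = val_loss
            self.counter = 0
        else:
            # No sufficient improvement: count towards patience.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop
```

With patience=2, the losses 1.0, 0.9, 0.95, 0.91 trigger a stop on the fourth check: the last two values never beat the best loss of 0.9.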
class fishy.models.deep.ordinal.PositionalEncoding(d_model: int, dropout: float = 0.1, max_len: int = 5000)

Bases: Module

Standard sinusoidal positional encoding for Transformers.

__init__(d_model: int, dropout: float = 0.1, max_len: int = 5000) -> None

Initializes the positional encoding.

Parameters:
  • d_model (int) – The model dimension.

  • dropout (float, optional) – Dropout probability. Defaults to 0.1.

  • max_len (int, optional) – Maximum sequence length. Defaults to 5000.

forward(x: Tensor) -> Tensor

Adds positional encoding to the input tensor.

Parameters:

x (torch.Tensor) – Input tensor.

Returns:

The input tensor with the positional encoding added.

Return type:

torch.Tensor
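For reference, the standard sinusoidal encoding gives position pos and dimension pair i the values PE(pos, 2i) = sin(pos / 10000^(2i / d_model)) and PE(pos, 2i + 1) = cos(pos / 10000^(2i / d_model)). The framework-free sketch below builds that table and adds it to a single sequence. It assumes an even d_model and leaves out dropout; the function names are hypothetical.

```python
import math
from typing import List


def sinusoidal_table(max_len: int, d_model: int) -> List[List[float]]:
    """Hypothetical sketch: sinusoidal positional-encoding table of shape (max_len, d_model)."""
    if d_model % 2 != 0:
        raise ValueError("this sketch assumes an even d_model")
    table = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for j in range(0, d_model, 2):
            # j = 2i is an even index: sin there, cos at j + 1, same frequency.
            angle = pos / (10000 ** (j / d_model))
            table[pos][j] = math.sin(angle)
            table[pos][j + 1] = math.cos(angle)
    return table


def add_positional_encoding(x: List[List[float]]) -> List[List[float]]:
    """Add the encoding to one sequence x of shape (seq_len, d_model); no dropout."""
    pe = sinusoidal_table(len(x), len(x[0]))
    return [[v + p for v, p in zip(row, pe_row)] for row, pe_row in zip(x, pe)]
```

Position 0 always encodes to alternating 0 and 1 (sin 0 and cos 0), and higher dimension pairs oscillate more slowly, which is what lets the model tell positions apart at several scales.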

class fishy.models.deep.ordinal.TransformerOrdinal(input_features: int, d_model: int, nhead: int, num_encoder_layers: int, dim_feedforward: int, dropout: float, method: str = 'regression', num_classes: int | None = None, batch_first: bool = True)

Bases: Module

Transformer model specialized for ordinal regression and classification.

method

Task type (‘regression’, ‘classification’, ‘coral’, ‘clm’).

Type:

str

input_embedding

Initial linear projection.

Type:

nn.Linear

pos_encoder

Positional embeddings.

Type:

PositionalEncoding

transformer_encoder

Transformer backbone.

Type:

nn.TransformerEncoder

output_layer

Task-specific output head.

Type:

nn.Linear

__init__(input_features: int, d_model: int, nhead: int, num_encoder_layers: int, dim_feedforward: int, dropout: float, method: str = 'regression', num_classes: int | None = None, batch_first: bool = True) -> None

Initializes the TransformerOrdinal model.

Parameters:
  • input_features (int) – Dimensionality of input features.

  • d_model (int) – Hidden dimension.

  • nhead (int) – Number of attention heads.

  • num_encoder_layers (int) – Number of transformer layers.

  • dim_feedforward (int) – Intermediate dimension of feed-forward layers.

  • dropout (float) – Dropout probability.

  • method (str, optional) – Training mode (‘regression’, ‘classification’, ‘coral’, ‘clm’). Defaults to ‘regression’.

  • num_classes (Optional[int], optional) – Number of target classes. Required for non-regression modes. Defaults to None.

  • batch_first (bool, optional) – If True, input tensors have shape (batch, sequence, features), i.e. (B, S, D). Defaults to True.

forward(src)

Runs the model on the input batch src and returns the predictions produced by the task-specific output head.

Note

Although the forward pass is defined in this method, call the model instance (e.g. model(src)) rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() silently skips them.

Functions

fishy.models.deep.ordinal.convert_class_labels_to_integers(df)
fishy.models.deep.ordinal.coral_logits_to_prediction(logits)
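In the CORAL approach to ordinal regression, a model emits K − 1 binary logits, one per question "is the label greater than k?", and the predicted rank is commonly the number of those logits whose sigmoid exceeds 0.5. The sketch below shows that common decoding rule for one sample. Whether coral_logits_to_prediction uses the same 0.5 threshold is an assumption, and coral_predict is a hypothetical name.

```python
import math
from typing import Sequence


def coral_predict(logits: Sequence[float], threshold: float = 0.5) -> int:
    """Hypothetical sketch: decode K - 1 CORAL logits into a rank in 0..K-1."""
    # sigmoid(z) = 0.5 * (1 + tanh(z / 2)), which cannot overflow for large |z|.
    probs = [0.5 * (1.0 + math.tanh(z / 2.0)) for z in logits]
    # Each probability estimates P(y > k); the rank is how many exceed the threshold.
    return sum(p > threshold for p in probs)
```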
fishy.models.deep.ordinal.evaluate(model, data_loader, device, method)

Computes the mean absolute error (MAE) and the balanced classification accuracy (BCA) of model on data_loader.
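MAE is the mean of |y_pred − y_true| over all samples, treating class indices as ordered integers; balanced accuracy is the unweighted mean of per-class recall, so rare classes count as much as common ones. The sketch below computes both from plain lists of integer labels. It assumes evaluate compares integer class predictions and does not show how each method's raw output is turned into them; ordinal_metrics is a hypothetical name.

```python
from collections import defaultdict
from typing import Dict, Sequence, Tuple


def ordinal_metrics(y_true: Sequence[int], y_pred: Sequence[int]) -> Tuple[float, float]:
    """Hypothetical sketch: return (MAE, balanced accuracy) for integer labels."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("y_true and y_pred must be non-empty and of equal length")
    mae = sum(abs(p - t) for t, p in zip(y_true, y_pred)) / len(y_true)

    # Per-class recall, averaged over the classes that occur in y_true.
    hits: Dict[int, int] = defaultdict(int)
    counts: Dict[int, int] = defaultdict(int)
    for t, p in zip(y_true, y_pred):
        counts[t] += 1
        hits[t] += int(p == t)
    bca = sum(hits[c] / counts[c] for c in counts) / len(counts)
    return mae, bca
```

For y_true = [0, 0, 0, 1] with every prediction 0, plain accuracy would be 0.75 but balanced accuracy is 0.5 (and MAE is 0.25), which is the point of reporting BCA on imbalanced classes.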

fishy.models.deep.ordinal.filter_data_for_oil(df)
fishy.models.deep.ordinal.get_dataloader(df, batch_size=32)
fishy.models.deep.ordinal.label_to_coral_levels(labels, num_classes)
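In CORAL, an integer label k out of K classes is usually expanded into K − 1 binary targets whose first k entries are 1 and the rest 0, so entry j answers "is the label greater than j?". A pure-Python sketch of that encoding follows; the name coral_levels and the list-based output are illustrative assumptions, since label_to_coral_levels presumably works on tensors.

```python
from typing import List, Sequence


def coral_levels(labels: Sequence[int], num_classes: int) -> List[List[int]]:
    """Hypothetical sketch: expand each label k into K - 1 binary CORAL targets."""
    levels = []
    for k in labels:
        if not 0 <= k < num_classes:
            raise ValueError(f"label {k} outside 0..{num_classes - 1}")
        # Entry j is 1 exactly when k > j, i.e. the first k entries are 1.
        levels.append([1 if k > j else 0 for j in range(num_classes - 1)])
    return levels
```

Summing a row recovers the original label, which is why counting the binary outputs that fire is the natural way to decode CORAL predictions.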
fishy.models.deep.ordinal.read_reims_excel_file(fp)
fishy.models.deep.ordinal.remove_first_two_characters(df)
fishy.models.deep.ordinal.train(model, optimizer, criterion, train_df, val_df, epochs, device, batch_size, method, num_classes, patience, checkpoint_path, verbose=False)
