fishy.models.deep.ensemble

Multi-Scale/Resolution Ensemble of Transformers.

This model ensembles three Transformer architectures with varying depths and attention heads:

  1. Small: 2 layers, 2 heads
  2. Medium: 4 layers, 4 heads
  3. Large: 8 layers, 8 heads

The outputs are concatenated and passed through a final classification/regression head.
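The description above does not say how a flat spectrum becomes a token sequence or how each expert's output is reduced before concatenation. The following is a minimal hypothetical PyTorch sketch, not the fishy implementation. It assumes that each of the input_dim features is one token embedded by a shared nn.Linear(1, hidden_dim) plus a learned positional embedding, that every expert is a standard nn.TransformerEncoder whose output is mean-pooled over tokens, and that the head is a two-layer MLP. The names EnsembleSketch and EXPERT_CONFIGS are illustrative only.

```python
import torch
from torch import nn

# Hypothetical (num_layers, num_heads) per expert, mirroring the small/medium/large list.
EXPERT_CONFIGS = [(2, 2), (4, 4), (8, 8)]


class EnsembleSketch(nn.Module):
    """Illustrative multi-scale Transformer ensemble (not the fishy source)."""

    def __init__(self, input_dim: int, output_dim: int,
                 hidden_dim: int = 128, dropout: float = 0.1) -> None:
        super().__init__()
        # Assumption: each scalar feature is a token, embedded to hidden_dim.
        self.embed = nn.Linear(1, hidden_dim)
        # Assumption: learned positional embedding, one vector per feature index.
        self.pos = nn.Parameter(torch.zeros(1, input_dim, hidden_dim))
        # One TransformerEncoder per (depth, heads) configuration.
        self.experts = nn.ModuleList([
            nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=heads,
                                           dropout=dropout, batch_first=True),
                num_layers=layers,
            )
            for layers, heads in EXPERT_CONFIGS
        ])
        # Assumption: a small MLP head over the concatenated expert outputs.
        self.classifier = nn.Sequential(
            nn.Linear(len(EXPERT_CONFIGS) * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, input_dim) -> (batch, input_dim, hidden_dim)
        tokens = self.embed(x.unsqueeze(-1)) + self.pos
        # Each expert keeps the sequence shape; mean-pool to (batch, hidden_dim).
        pooled = [expert(tokens).mean(dim=1) for expert in self.experts]
        # Concatenate the three summaries: (batch, 3 * hidden_dim) -> head.
        return self.classifier(torch.cat(pooled, dim=-1))


model = EnsembleSketch(input_dim=64, output_dim=3)
print(model(torch.randn(5, 64)).shape)  # torch.Size([5, 3])
```

Mean pooling is only one way to reduce each expert's token sequence to a fixed-size vector; a CLS token or attention pooling would fit the same concatenate-then-head pattern.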

Classes

class fishy.models.deep.ensemble.MultiScaleTransformerEnsemble(input_dim: int, output_dim: int, hidden_dim: int = 128, dropout: float = 0.1, **kwargs)

Bases: Module

An ensemble of Transformers with different scales/resolutions.

experts

List of Transformer models with different configurations.

Type:

nn.ModuleList

classifier

Final classification/regression head.

Type:

nn.Sequential

__init__(input_dim: int, output_dim: int, hidden_dim: int = 128, dropout: float = 0.1, **kwargs) -> None

Initializes the MultiScaleTransformerEnsemble model.

Parameters:
  • input_dim (int) – Number of input features.

  • output_dim (int) – Number of output classes/dimensions.

  • hidden_dim (int, optional) – Hidden dimension for transformers. Defaults to 128.

  • dropout (float, optional) – Dropout rate. Defaults to 0.1.
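One practical consequence of the head counts listed in the module description: if each expert uses hidden_dim as its attention width (an assumption; PyTorch's nn.MultiheadAttention requires the embedding dimension to be divisible by the number of heads), then hidden_dim must be a multiple of 8 for all three experts to build. The hypothetical helper below checks this before the model is constructed; the name check_hidden_dim is illustrative and not part of fishy.

```python
# Head counts of the small, medium and large experts, from the module description.
EXPERT_HEADS = {"small": 2, "medium": 4, "large": 8}


def check_hidden_dim(hidden_dim: int) -> None:
    """Raise ValueError if hidden_dim cannot be split evenly across every expert's heads.

    Assumption: each expert uses standard multi-head attention with hidden_dim
    as the model width, so hidden_dim % num_heads must be 0.
    """
    bad = [name for name, heads in EXPERT_HEADS.items() if hidden_dim % heads]
    if bad:
        raise ValueError(
            f"hidden_dim={hidden_dim} is not divisible by the head count of: "
            + ", ".join(bad)
        )


check_hidden_dim(128)  # default: 128 / 8 = 16 dims per head, passes
check_hidden_dim(96)   # 96 / 8 = 12 dims per head, passes
```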

forward(x: Tensor, *args, **kwargs) -> Tensor

Forward pass.

Parameters:

x (torch.Tensor) – Input spectrum of shape (batch_size, input_dim).

Returns:

Output of the final classification/regression head, of shape (batch_size, output_dim).

Return type:

torch.Tensor
