fishy.models.deep.gatedmlp

GatedMLP: A pure gated MLP baseline for REIMS spectral data.

The 1D Transformer applied to REIMS data uses seq_len=1 (the entire spectrum is a single token). With a single token, self-attention degenerates:

softmax(Q K^T / sqrt(d)) has shape [B, H, 1, 1], and a softmax over a single score is always exactly 1.0

The attention output is therefore just V = x W_v (followed by the usual output projection, if present), which is a plain linear transform. The Transformer is then functionally equivalent to:

[Linear → FFN] × num_layers → pool → fc_out
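As a quick numerical check of the degeneracy argument, the sketch below uses PyTorch's nn.MultiheadAttention as a stand-in for the project's Transformer attention (an assumption; the actual layer may be implemented differently). With one token, the attention weights are all 1, and the layer's output matches the value projection followed by the output projection, with no attention involved:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, E, H = 4, 64, 8  # batch, embedding dim, heads (illustrative values)
attn = nn.MultiheadAttention(E, H, batch_first=True)
x = torch.randn(B, 1, E)  # seq_len = 1: the whole spectrum is a single token

out, weights = attn(x, x, x)  # weights: [B, 1, 1], averaged over heads

# Rebuild the output without any attention: value projection, then output projection.
w_v = attn.in_proj_weight[2 * E:]  # rows for V in the packed Q/K/V weight
b_v = attn.in_proj_bias[2 * E:]
linear_out = attn.out_proj(F.linear(x, w_v, b_v))

print(torch.allclose(weights, torch.ones_like(weights)))  # True
print(torch.allclose(out, linear_out, atol=1e-5))         # True
```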

GatedMLP makes this explicit: it applies modern building blocks (RMSNorm, SwiGLU, residual connections, dropout) directly in input_dim space with no large embedding projection. This isolates the contribution of the architecture from any large fixed embedding.

Classes

class fishy.models.deep.gatedmlp.GatedMLP(input_dim: int, output_dim: int, hidden_dim: int = 128, num_layers: int = 4, dropout: float = 0.3, embed_dim: int = 512, **kwargs)

Bases: Module, TTTMixin

Gated MLP baseline.

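Without an input projection, it operates directly in input_dim space, matching the effective computation of the Transformer on single-token REIMS spectra (where attention reduces to a linear map and the FFN does all of the nonlinear work).

Set embed_dim > 0 to first project the input into a fixed embedding space of width embed_dim. Note that the default (embed_dim=512) already enables this projection, so running the blocks directly in input_dim space requires setting embed_dim=0.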
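The following is a minimal, hypothetical sketch of how the pieces could fit together; it is not the fishy implementation. It assumes a bias-free SwiGLU (SiLU-gated linear unit), a plain Linear input projection when embed_dim > 0, and no final normalization before fc_out. The class names GatedMLPSketch and Block and the spectrum length of 1000 are illustrative only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    # Minimal stand-in: divide by the root mean square of the last dim, then apply a gain.
    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


class Block(nn.Module):
    # Pre-norm SwiGLU residual block: x + Dropout(SwiGLU(RMSNorm(x))).
    def __init__(self, dim: int, hidden_dim: int, dropout: float) -> None:
        super().__init__()
        self.norm = RMSNorm(dim)
        self.gate = nn.Linear(dim, hidden_dim, bias=False)
        self.up = nn.Linear(dim, hidden_dim, bias=False)
        self.down = nn.Linear(hidden_dim, dim, bias=False)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm(x)
        return x + self.drop(self.down(F.silu(self.gate(h)) * self.up(h)))


class GatedMLPSketch(nn.Module):
    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int = 128,
                 num_layers: int = 4, dropout: float = 0.3, embed_dim: int = 512) -> None:
        super().__init__()
        # Assumption: embed_dim > 0 adds a Linear projection; otherwise blocks run at input_dim.
        if embed_dim > 0:
            self.proj = nn.Linear(input_dim, embed_dim)
            dim = embed_dim
        else:
            self.proj = nn.Identity()
            dim = input_dim
        self.blocks = nn.ModuleList([Block(dim, hidden_dim, dropout) for _ in range(num_layers)])
        self.fc_out = nn.Linear(dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.proj(x)       # [B, input_dim] -> [B, dim]
        for block in self.blocks:
            h = block(h)       # width stays at dim
        return self.fc_out(h)  # [B, output_dim]


torch.manual_seed(0)
x = torch.randn(8, 1000)  # 8 spectra with a hypothetical length of 1000
print(GatedMLPSketch(1000, 2, embed_dim=0)(x).shape)    # torch.Size([8, 2])
print(GatedMLPSketch(1000, 2, embed_dim=512)(x).shape)  # torch.Size([8, 2])
```

In this sketch, only the width of the residual stream changes with embed_dim; the block structure and the output layer are the same in both configurations.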

__init__(input_dim: int, output_dim: int, hidden_dim: int = 128, num_layers: int = 4, dropout: float = 0.3, embed_dim: int = 512, **kwargs) → None

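Parameters: input_dim is the number of features in each input spectrum; output_dim is the size of the output layer (e.g. the number of classes); hidden_dim is the SwiGLU hidden width inside each GatedMLPBlock; num_layers is the number of stacked blocks; dropout is the dropout rate used in each block; embed_dim is the width of the optional input projection described above.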

forward(x: Tensor, return_attention: bool = False, *args, **kwargs) → Tensor | Tuple[Tensor, List[Tensor]]

Run the model on a batch of spectra x and return the output, of width output_dim. If return_attention is True, a tuple (output, list of tensors) is returned instead; since GatedMLP has no attention layers, this flag presumably exists so that the model can be called in the same way as the Transformer.

Note

Call the module instance, model(x), rather than forward() directly: the instance call runs any registered hooks, while a direct forward() call silently skips them.

class fishy.models.deep.gatedmlp.GatedMLPBlock(dim: int, hidden_dim: int, dropout: float = 0.3)

Bases: Module

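RMSNorm → SwiGLU → Dropout → Residual, i.e. x + Dropout(SwiGLU(RMSNorm(x))), operating at width dim: input_dim, or embed_dim when GatedMLP's input projection is enabled.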
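A minimal sketch of one such block, assuming the common SwiGLU form w_down(SiLU(w_gate h) * w_up h) with bias-free linear layers and an RMSNorm with a learnable per-feature gain. The names GatedMLPBlockSketch, w_gate, w_up and w_down are hypothetical:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    # Minimal stand-in for the module's RMSNorm (assumed formulation).
    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


class GatedMLPBlockSketch(nn.Module):
    """Hypothetical pre-norm SwiGLU block: x + Dropout(SwiGLU(RMSNorm(x)))."""

    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.3) -> None:
        super().__init__()
        self.norm = RMSNorm(dim)
        self.w_gate = nn.Linear(dim, hidden_dim, bias=False)  # gate branch, passed through SiLU
        self.w_up = nn.Linear(dim, hidden_dim, bias=False)    # value branch
        self.w_down = nn.Linear(hidden_dim, dim, bias=False)  # project back to dim
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm(x)
        h = self.w_down(F.silu(self.w_gate(h)) * self.w_up(h))  # SwiGLU
        return x + self.dropout(h)  # residual connection keeps the width at dim


torch.manual_seed(0)
block = GatedMLPBlockSketch(dim=32, hidden_dim=128)
x = torch.randn(4, 32)
print(block(x).shape)  # torch.Size([4, 32])
```

Because the residual path passes x through unchanged, zeroing w_down turns this sketch into the identity map.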

__init__(dim: int, hidden_dim: int, dropout: float = 0.3)

Parameters: dim is the feature width of the block's input and output; hidden_dim is the SwiGLU hidden width; dropout is the rate applied before the residual addition.

forward(x: Tensor) → Tensor

Return x + Dropout(SwiGLU(RMSNorm(x))), a tensor with the same shape as x.

class fishy.models.deep.gatedmlp.RMSNorm(dim: int, eps: float = 1e-06)

Bases: Module

Root-mean-square (RMS) normalization, used as the first step of each GatedMLPBlock.
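The source does not show RMSNorm's body. The sketch below assumes the common formulation y = x / sqrt(mean(x^2) + eps) * g, where the mean runs over the last dimension and g is a learnable gain initialized to ones; RMSNormSketch is a hypothetical name. Unlike LayerNorm, it does not subtract the mean; dividing by the RMS makes the output (almost) invariant to rescaling the input, which the example checks:

```python
import torch
import torch.nn as nn


class RMSNormSketch(nn.Module):
    """Hypothetical RMSNorm: x / sqrt(mean(x^2) + eps) * weight."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learnable per-feature gain

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Root mean square over the last (feature) dimension; no mean subtraction.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight


torch.manual_seed(0)
norm = RMSNormSketch(dim=8)
x = torch.randn(2, 8)
y = norm(x)
print(y.pow(2).mean(dim=-1))  # close to 1 for every row while the gain is all ones
print(torch.allclose(norm(3 * x), y, atol=1e-4))  # True: scaling x barely changes y
```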

__init__(dim: int, eps: float = 1e-06)

Parameters: dim is the size of the normalized (feature) dimension; eps is a small constant added for numerical stability.

forward(x: Tensor) → Tensor

Return x divided by its root mean square over the feature dimension, with eps added for numerical stability. The output has the same shape as x.
