fishy.models.deep.nextformer

NextFormer model for spectral analysis.

The NextFormer architecture features:

  1. RMSNorm: Replaces standard LayerNorm. It rescales activations by their root mean square and omits mean centering, which makes it cheaper to compute while maintaining or improving training stability.

  2. Pre-Norm Architecture: Normalization is applied before the attention and feed-forward blocks, leaving an unimpeded residual pathway (crucial for deep networks).

  3. Grouped Query Attention (GQA): Replaces Multi-Head Attention (MHA). By sharing Keys and Values across groups of Query heads, it shrinks the key/value projections, reduces KV cache memory usage, and accelerates generation.

  4. SwiGLU Feed-Forward Networks: Introduces a gating mechanism, (SiLU(xW1) * xW3)W2, where * is element-wise multiplication. It is reported to outperform traditional ungated MLPs.

  5. Bias-free Linear Layers: All internal linear layers omit biases for training stability.
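To put a number on the saving from sharing Keys and Values, the arithmetic below counts the weights in the key and value projections. It is a rough sketch, not taken from the fishy source: the helper `kv_projection_params` is hypothetical, and it assumes `head_dim = dim // n_heads` with each of `wk` and `wv` mapping `dim` to `n_kv_heads * head_dim` without a bias. The head counts are the documented defaults; `dim = 128` is an assumed model width.

```python
def kv_projection_params(dim: int, n_heads: int, n_kv_heads: int) -> int:
    """Weights in the bias-free key and value projections combined (hypothetical helper)."""
    if dim % n_heads != 0 or n_heads % n_kv_heads != 0:
        raise ValueError("dim must be divisible by n_heads, and n_heads by n_kv_heads")
    head_dim = dim // n_heads  # assumption: per-head width derived from dim
    # wk and wv each map dim -> n_kv_heads * head_dim.
    return 2 * dim * n_kv_heads * head_dim


dim, n_heads, n_kv_heads = 128, 8, 2  # dim=128 is an assumed model width
mha = kv_projection_params(dim, n_heads, n_heads)     # MHA: one KV head per query head
gqa = kv_projection_params(dim, n_heads, n_kv_heads)  # GQA: 4 query heads share each KV head
print(mha, gqa, mha // gqa)  # 32768 8192 4
```

Under these assumptions the key/value projections (and any cached keys and values) shrink by the factor `n_heads / n_kv_heads`.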

Classes

class fishy.models.deep.nextformer.GroupedQueryAttention(dim: int, n_heads: int, n_kv_heads: int)[source]

Bases: Module

Grouped Query Attention (GQA) mechanism.

n_heads

Total number of query heads.

Type:

int

n_kv_heads

Total number of key/value heads.

Type:

int

n_rep

Number of times each KV head is repeated to match the query heads (n_heads / n_kv_heads).

Type:

int

head_dim

Dimension of each head.

Type:

int

wq

Query projection.

Type:

nn.Linear

wk

Key projection.

Type:

nn.Linear

wv

Value projection.

Type:

nn.Linear

wo

Output projection.

Type:

nn.Linear

__init__(dim: int, n_heads: int, n_kv_heads: int) None[source]

Initializes GQA.

Parameters:
  • dim (int) – Input dimension.

  • n_heads (int) – Number of query heads.

  • n_kv_heads (int) – Number of KV heads.

forward(x: Tensor, return_attention: bool = False) Tensor | Tuple[Tensor, Tensor][source]

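Forward pass for GQA. Returns the attention output, or an (output, attention weights) tuple when return_attention is True.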
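One way such a forward pass could look is sketched below. This is not the fishy implementation: `GQASketch` is a hypothetical class that reuses the documented attribute names, and it assumes an input layout of (batch, seq_len, dim), `head_dim = dim // n_heads`, unmasked scaled dot-product attention, and no dropout on the attention weights.

```python
import math

import torch
import torch.nn as nn


class GQASketch(nn.Module):
    """Illustrative grouped query attention; not the fishy implementation."""

    def __init__(self, dim: int, n_heads: int, n_kv_heads: int) -> None:
        super().__init__()
        assert dim % n_heads == 0 and n_heads % n_kv_heads == 0
        self.n_heads, self.n_kv_heads = n_heads, n_kv_heads
        self.n_rep = n_heads // n_kv_heads
        self.head_dim = dim // n_heads  # assumption
        self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)

    def forward(self, x: torch.Tensor, return_attention: bool = False):
        b, t, _ = x.shape  # assumed layout: (batch, seq_len, dim)
        q = self.wq(x).view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.wk(x).view(b, t, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.wv(x).view(b, t, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # Repeat each KV head n_rep times so it lines up with its query group.
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        attn = scores.softmax(dim=-1)  # (batch, n_heads, seq_len, seq_len)
        out = (attn @ v).transpose(1, 2).reshape(b, t, -1)
        out = self.wo(out)
        return (out, attn) if return_attention else out


torch.manual_seed(0)
layer = GQASketch(dim=64, n_heads=8, n_kv_heads=2)
out, attn = layer(torch.randn(3, 10, 64), return_attention=True)
print(out.shape, attn.shape)
```

In this sketch the output keeps the input shape, and the attention tensor has one (seq_len, seq_len) map per query head.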

class fishy.models.deep.nextformer.NextFormer(input_dim: int, output_dim: int, hidden_dim: int = 128, num_layers: int = 4, dropout: float = 0.1, num_heads: int = 8, num_kv_heads: int = 2, **kwargs)[source]

Bases: Module

NextFormer Architecture for 1D spectral data.

blocks

Stack of NextFormer blocks.

Type:

nn.ModuleList

norm

Final output normalization.

Type:

RMSNorm

fc_out

Output classification/regression head.

Type:

nn.Linear

__init__(input_dim: int, output_dim: int, hidden_dim: int = 128, num_layers: int = 4, dropout: float = 0.1, num_heads: int = 8, num_kv_heads: int = 2, **kwargs) None[source]

Initializes the NextFormer model.

Parameters:
  • input_dim (int) – Number of input features.

  • output_dim (int) – Number of output classes/dimensions.

  • hidden_dim (int, optional) – Intermediate dimension of the feed-forward layer. Defaults to 128.

  • num_layers (int, optional) – Number of transformer blocks. Defaults to 4.

  • dropout (float, optional) – Dropout rate. Defaults to 0.1.

  • num_heads (int, optional) – Number of attention heads. Defaults to 8.

  • num_kv_heads (int, optional) – Number of KV heads for GQA. Defaults to 2.

forward(x: Tensor, return_attention: bool = False, *args, **kwargs) Tensor | Tuple[Tensor, List[Tensor]][source]

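Forward pass. Returns the model output, or an (output, list of attention tensors) tuple when return_attention is True.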

class fishy.models.deep.nextformer.NextFormerBlock(dim: int, n_heads: int, n_kv_heads: int, hidden_dim: int, dropout: float)[source]

Bases: Module

A single NextFormer block.

attention

GQA layer.

Type:

GroupedQueryAttention

feed_forward

SwiGLU FFN layer.

Type:

SwiGLU

attention_norm

Norm before attention.

Type:

RMSNorm

ffn_norm

Norm before FFN.

Type:

RMSNorm

dropout

Dropout for regularization.

Type:

nn.Dropout

__init__(dim: int, n_heads: int, n_kv_heads: int, hidden_dim: int, dropout: float) None[source]

Initializes the block.

forward(x: Tensor, return_attention: bool = False) Tensor | Tuple[Tensor, Tensor][source]

Forward pass with pre-norm architecture and residual connections.
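The pre-norm wiring can be sketched as follows. `PreNormBlockSketch` is a hypothetical stand-in, not the fishy class: it accepts any two sublayers, uses `nn.LayerNorm` in place of RMSNorm to stay short, ignores `return_attention`, and assumes dropout is applied to each sublayer output before the residual addition.

```python
import torch
import torch.nn as nn


class PreNormBlockSketch(nn.Module):
    """Pre-norm residual wiring around two arbitrary sublayers (illustrative only)."""

    def __init__(self, dim: int, attention: nn.Module, feed_forward: nn.Module,
                 dropout: float = 0.1) -> None:
        super().__init__()
        self.attention = attention
        self.feed_forward = feed_forward
        # nn.LayerNorm stands in for RMSNorm in this sketch.
        self.attention_norm = nn.LayerNorm(dim)
        self.ffn_norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)  # assumed placement: on sublayer outputs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize first, transform, then add the result onto the untouched input.
        x = x + self.dropout(self.attention(self.attention_norm(x)))
        x = x + self.dropout(self.feed_forward(self.ffn_norm(x)))
        return x


torch.manual_seed(0)
block = PreNormBlockSketch(
    16,
    attention=nn.Linear(16, 16, bias=False),     # placeholder sublayer
    feed_forward=nn.Linear(16, 16, bias=False),  # placeholder sublayer
)
block.eval()  # disable dropout for a deterministic check
x = torch.randn(2, 7, 16)
y = block(x)
print(y.shape)
```

Because the input is never normalized on the residual path itself, zeroing both sublayers turns the block into the identity, which is the "unimpeded" pathway that pre-norm designs rely on.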

class fishy.models.deep.nextformer.RMSNorm(dim: int, eps: float = 1e-06)[source]

Bases: Module

Root Mean Square Layer Normalization.

eps

A small value added to the denominator for numerical stability.

Type:

float

weight

Learnable scaling parameter.

Type:

nn.Parameter

__init__(dim: int, eps: float = 1e-06) None[source]

Initializes the RMSNorm layer.

Parameters:
  • dim (int) – Input dimension.

  • eps (float, optional) – Epsilon for numerical stability. Defaults to 1e-6.

forward(x: Tensor) Tensor[source]

Forward pass.
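A minimal sketch of the computation, assuming normalization over the last dimension and a weight initialized to ones (neither detail is confirmed by this page). `RMSNormSketch` is a hypothetical name; only `eps` and `weight` come from the documented attributes.

```python
import torch
import torch.nn as nn


class RMSNormSketch(nn.Module):
    """Illustrative RMS normalization; not the fishy implementation."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # assumed init: all ones

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mean of squares over the feature dimension; no mean subtraction.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps) * self.weight


torch.manual_seed(0)
x = torch.randn(2, 5, 16) * 3.0 + 1.0  # deliberately off-center and scaled
with torch.no_grad():
    y = RMSNormSketch(16)(x)
    rms = y.pow(2).mean(dim=-1).sqrt()
print(rms)  # every entry is close to 1
```

Unlike LayerNorm, the output is not re-centered: only its scale is normalized.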

class fishy.models.deep.nextformer.SwiGLU(dim: int, hidden_dim: int)[source]

Bases: Module

SwiGLU gated feed-forward network, computing (SiLU(xW1) * xW3)W2.

w1

Gate projection; its output is passed through SiLU.

Type:

nn.Linear

w2

Output projection layer.

Type:

nn.Linear

w3

Value projection; its output is multiplied element-wise by the SiLU-activated gate.

Type:

nn.Linear

__init__(dim: int, hidden_dim: int) None[source]

Initializes the SwiGLU layer.

Parameters:
  • dim (int) – Input/output dimension.

  • hidden_dim (int) – Intermediate dimension.

forward(x: Tensor) Tensor[source]

Forward pass applying the SwiGLU gating mechanism.
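The gating formula (SiLU(xW1) * xW3)W2 maps onto three bias-free linear layers roughly as below. `SwiGLUSketch` is a hypothetical class written for illustration; the layer names follow the documented attributes, while the absence of dropout is an assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUSketch(nn.Module):
    """Illustrative SwiGLU feed-forward network; not the fishy implementation."""

    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # output projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # value projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The SiLU-activated gate scales the value branch element-wise.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


torch.manual_seed(0)
ffn = SwiGLUSketch(dim=16, hidden_dim=32)
x = torch.randn(4, 16)
with torch.no_grad():
    y = ffn(x)
print(y.shape)
```

The layer expands to `hidden_dim` internally and projects back, so the output has the same shape as the input.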

Functions

fishy.models.deep.nextformer.repeat_kv(x: Tensor, n_rep: int) Tensor[source]

Repeats the Key and Value heads for Grouped Query Attention.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (batch, n_kv_heads, seq_len, head_dim).

  • n_rep (int) – Number of times to repeat each head.

Returns:

Repeated tensor of shape (batch, n_heads, seq_len, head_dim).

Return type:

torch.Tensor
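The documented shapes can be reproduced with an expand-and-reshape, sketched below. `repeat_kv_sketch` is a hypothetical function, and it assumes the copies of each head are placed next to each other (the same ordering as `torch.repeat_interleave` along the head dimension); the actual ordering used by fishy is not stated here.

```python
import torch


def repeat_kv_sketch(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each KV head n_rep times along dim 1 (illustrative only)."""
    batch, n_kv_heads, seq_len, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, None, :, :]  # insert a repeat axis after the head axis
        .expand(batch, n_kv_heads, n_rep, seq_len, head_dim)
        .reshape(batch, n_kv_heads * n_rep, seq_len, head_dim)
    )


torch.manual_seed(0)
kv = torch.randn(2, 2, 5, 4)         # (batch, n_kv_heads, seq_len, head_dim)
out = repeat_kv_sketch(kv, n_rep=4)  # n_heads = 2 * 4 = 8
print(out.shape)
```

With this ordering, output heads 0 to 3 are copies of KV head 0 and heads 4 to 7 are copies of KV head 1.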
