fishy.models.deep.nextformer
NextFormer model for spectral analysis.
The NextFormer architecture features:
1. RMSNorm: Replaces standard LayerNorm. It rescales activations by their root mean square and omits mean centering, which makes it cheaper to compute while maintaining (and in some settings improving) training stability.
2. Pre-Norm Architecture: Normalization is applied before the attention and feed-forward blocks rather than after them, leaving an unimpeded residual pathway (important for training deep networks).
3. Grouped Query Attention (GQA): Replaces Multi-Head Attention (MHA). Each group of query heads shares a single key/value head, which shrinks the key/value projections and, in autoregressive use, reduces KV cache memory and speeds up generation.
4. SwiGLU Feed-Forward Networks: Uses the gated feed-forward computation (SiLU(xW1) * xW3)W2, which has been reported to outperform conventional two-layer MLPs.
5. Bias-free Linear Layers: All internal linear layers omit bias terms, a choice made for training stability.
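For reference, the components listed above correspond to the following standard equations for one block with input $x$. This is a sketch: $d$ is the feature dimension, $g$ is the RMSNorm `weight`, and $\sigma$ is the logistic sigmoid; placing $\epsilon$ inside the square root, using two separate norms $\mathrm{RMSNorm}_1$ and $\mathrm{RMSNorm}_2$, and leaving out dropout are assumptions about this module rather than documented details.

```latex
\begin{aligned}
\mathrm{RMSNorm}(x) &= g \odot \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^{2} + \epsilon}} \\
h &= x + \mathrm{GQA}\big(\mathrm{RMSNorm}_1(x)\big) \\
y &= h + \mathrm{SwiGLU}\big(\mathrm{RMSNorm}_2(h)\big) \\
\mathrm{SwiGLU}(x) &= \big(\mathrm{SiLU}(xW_1) \odot xW_3\big)W_2, \qquad \mathrm{SiLU}(z) = z\,\sigma(z) \\
n_{\text{rep}} &= n_{\text{heads}} / n_{\text{kv\_heads}}
\end{aligned}
```

The residual terms $x$ and $h$ are added back without passing through any normalization, which is what keeps the residual pathway unimpeded in the pre-norm layout.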
Classes
- class fishy.models.deep.nextformer.GroupedQueryAttention(dim: int, n_heads: int, n_kv_heads: int)
Bases: Module
Grouped Query Attention (GQA) mechanism.
- n_heads (int) – Total number of query heads.
- n_kv_heads (int) – Total number of key/value heads.
- n_rep (int) – Number of times each KV head is repeated, so that n_heads = n_kv_heads * n_rep.
- head_dim (int) – Dimension of each head.
- wq (nn.Linear) – Query projection.
- wk (nn.Linear) – Key projection.
- wv (nn.Linear) – Value projection.
- wo (nn.Linear) – Output projection.
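A minimal sketch of a GQA layer with the attributes listed above is shown below. `GroupedQueryAttentionSketch` is a hypothetical name, not the fishy source; the tensor layout, the use of `torch.repeat_interleave` in the role of `repeat_kv`, and the absence of a causal mask are all assumptions.

```python
import math

import torch
import torch.nn as nn


class GroupedQueryAttentionSketch(nn.Module):
    """Hypothetical GQA sketch mirroring the documented attributes (not the fishy source)."""

    def __init__(self, dim: int, n_heads: int, n_kv_heads: int) -> None:
        super().__init__()
        if dim % n_heads != 0 or n_heads % n_kv_heads != 0:
            raise ValueError("need dim % n_heads == 0 and n_heads % n_kv_heads == 0")
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_rep = n_heads // n_kv_heads  # query heads sharing one K/V head
        self.head_dim = dim // n_heads
        # Bias-free projections; K and V produce only n_kv_heads heads.
        self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, _ = x.shape
        # (batch, seq_len, heads * head_dim) -> (batch, heads, seq_len, head_dim)
        q = self.wq(x).view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.wk(x).view(b, t, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.wv(x).view(b, t, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # Repeat each K/V head n_rep times so every query head has a K/V partner.
        k = torch.repeat_interleave(k, self.n_rep, dim=1)
        v = torch.repeat_interleave(v, self.n_rep, dim=1)
        # Non-causal scaled dot-product attention (no mask is an assumption).
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        out = torch.softmax(scores, dim=-1) @ v
        # (batch, heads, seq_len, head_dim) -> (batch, seq_len, heads * head_dim)
        out = out.transpose(1, 2).reshape(b, t, self.n_heads * self.head_dim)
        return self.wo(out)
```

With the illustrative values dim=128, n_heads=8 and n_kv_heads=2, head_dim is 16, so wk and wv map 128 features to 32 (2 heads × 16) instead of 128; this reduction in key/value size is where GQA's savings come from.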
- class fishy.models.deep.nextformer.NextFormer(input_dim: int, output_dim: int, hidden_dim: int = 128, num_layers: int = 4, dropout: float = 0.1, num_heads: int = 8, num_kv_heads: int = 2, **kwargs)
Bases: Module
NextFormer Architecture for 1D spectral data.
- blocks (nn.ModuleList) – Stack of NextFormer blocks.
- fc_out (nn.Linear) – Output classification/regression head.
- __init__(input_dim: int, output_dim: int, hidden_dim: int = 128, num_layers: int = 4, dropout: float = 0.1, num_heads: int = 8, num_kv_heads: int = 2, **kwargs) -> None
Initializes the NextFormer model.
- Parameters:
input_dim (int) – Number of input features.
output_dim (int) – Number of output classes/dimensions.
hidden_dim (int, optional) – Intermediate dimension of the feed-forward layer. Defaults to 128.
num_layers (int, optional) – Number of transformer blocks. Defaults to 4.
dropout (float, optional) – Dropout rate. Defaults to 0.1.
num_heads (int, optional) – Number of attention heads. Defaults to 8.
num_kv_heads (int, optional) – Number of key/value heads for GQA; num_heads must be a multiple of this value. Defaults to 2.
- class fishy.models.deep.nextformer.NextFormerBlock(dim: int, n_heads: int, n_kv_heads: int, hidden_dim: int, dropout: float)
Bases: Module
A single NextFormer block.
- attention (GroupedQueryAttention) – GQA layer.
- dropout (nn.Dropout) – Dropout for regularization.
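The pre-norm wiring described in the overview can be sketched as a block like the one below. `PreNormBlockSketch` and `RMSNormLite` are hypothetical names, not the fishy implementation: the attention and feed-forward sub-modules are injected so the wiring stands alone, the norm attribute names `attention_norm` and `ffn_norm` are assumptions, and applying `dropout` to each sub-layer output before the residual addition is a guess at where the documented dropout is used.

```python
import torch
import torch.nn as nn


class RMSNormLite(nn.Module):
    """Minimal RMSNorm so this sketch is self-contained (eps placement assumed)."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


class PreNormBlockSketch(nn.Module):
    """Hypothetical pre-norm block; attention/feed_forward are injected sub-modules."""

    def __init__(self, dim: int, attention: nn.Module, feed_forward: nn.Module, dropout: float) -> None:
        super().__init__()
        self.attention_norm = RMSNormLite(dim)
        self.ffn_norm = RMSNormLite(dim)
        self.attention = attention
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize *before* each sub-layer; the residual adds the raw input back.
        h = x + self.dropout(self.attention(self.attention_norm(x)))
        return h + self.dropout(self.feed_forward(self.ffn_norm(h)))
```

Because neither residual addition passes through a norm, a block whose sub-layers output zeros returns its input unchanged (in eval mode), which is the unimpeded residual pathway in its simplest form.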
- class fishy.models.deep.nextformer.RMSNorm(dim: int, eps: float = 1e-06)
Bases: Module
Root Mean Square Layer Normalization.
- eps (float) – Small value added to the denominator for numerical stability.
- weight (nn.Parameter) – Learnable scaling parameter.
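A minimal sketch of RMSNorm with the documented `eps` and `weight` attributes follows. `RMSNormSketch` is a hypothetical name, and adding `eps` inside the square root (rather than to the RMS itself) is an assumed convention.

```python
import torch
import torch.nn as nn


class RMSNormSketch(nn.Module):
    """Hypothetical RMSNorm sketch with the documented eps and weight attributes."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        # Learnable per-feature gain, initialised to 1 (no bias / shift term).
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Divide by the root mean square over the last dimension; unlike
        # LayerNorm, the mean is never subtracted.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight
```

The missing mean centering is easy to see on a constant input such as `torch.full((2, 8), 3.0)`: this sketch maps it to (approximately) all ones, whereas `nn.LayerNorm(8)` maps it to all zeros.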
- class fishy.models.deep.nextformer.SwiGLU(dim: int, hidden_dim: int)
Bases: Module
SwiGLU activation function / feed-forward network.
- w1 (nn.Linear) – Gate projection; its output is passed through SiLU.
- w2 (nn.Linear) – Output projection layer.
- w3 (nn.Linear) – Value projection, multiplied elementwise by the SiLU-activated gate.
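A minimal sketch of the SwiGLU feed-forward computation (SiLU(xW1) * xW3)W2 with bias-free layers is shown below. `SwiGLUSketch` is a hypothetical name; the roles of `w1`, `w2` and `w3` follow the attribute descriptions above, while everything else is an assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUSketch(nn.Module):
    """Hypothetical SwiGLU FFN computing (SiLU(x W1) * x W3) W2, bias-free."""

    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate branch
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # output projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # value branch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SiLU(z) = z * sigmoid(z) gates the value branch elementwise.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
```

Because SwiGLU uses three weight matrices instead of the two in a standard MLP, hidden_dim is commonly reduced relative to a standard MLP to keep the parameter count comparable.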
Functions
- fishy.models.deep.nextformer.repeat_kv(x: Tensor, n_rep: int) -> Tensor
Repeats the Key and Value heads for Grouped Query Attention.
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch, n_kv_heads, seq_len, head_dim).
n_rep (int) – Number of times to repeat each head.
- Returns:
Repeated tensor of shape (batch, n_heads, seq_len, head_dim), where n_heads = n_kv_heads * n_rep.
- Return type:
torch.Tensor
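One common way to implement a function with this contract is sketched below as a hypothetical `repeat_kv_sketch`: it inserts a repeat axis, broadcasts it with `expand` (no copy), and merges it into the head axis. It assumes that query heads sharing a KV head are consecutive (head j uses KV head j // n_rep), which the documentation above does not state.

```python
import torch


def repeat_kv_sketch(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Hypothetical sketch: (batch, n_kv_heads, seq_len, head_dim) -> (batch, n_kv_heads * n_rep, seq_len, head_dim)."""
    if n_rep == 1:
        return x  # plain multi-head attention: nothing to repeat
    b, n_kv, t, d = x.shape
    # Add a repeat axis after the head axis, broadcast it without copying,
    # then merge (n_kv, n_rep) into a single head axis of size n_kv * n_rep.
    return x[:, :, None, :, :].expand(b, n_kv, n_rep, t, d).reshape(b, n_kv * n_rep, t, d)
```

The result matches `torch.repeat_interleave(x, n_rep, dim=1)`; the expand-then-reshape form simply makes the grouping of heads explicit.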