fishy.experiments.pre_training

Pre-training strategies for models on mass spectrometry data.

This module implements self-supervised pre-training tasks for mass spectrometry models, including masked spectra modelling, next spectra prediction, peak detection, denoising autoencoding, peak parameter regression, segment reordering, and contrastive invariance learning. The tasks learn representations from unlabeled or semi-labeled spectra, which can then be fine-tuned for downstream tasks.

Classes

class fishy.experiments.pre_training.PreTrainer(model: Module, config: PreTrainingConfig, optimizer: Optimizer | None = None, logger: Logger | None = None)

Bases: object

Handles the execution of various self-supervised pre-training tasks.

Examples

>>> import torch.nn as nn
>>> model = nn.Linear(10, 10)
>>> config = PreTrainingConfig(n_features=10, device='cpu')
>>> trainer = PreTrainer(model, config)
>>> isinstance(trainer.model, nn.Module)
True

Parameters:
  • model (nn.Module) – The model to be pre-trained.

  • config (PreTrainingConfig) – Configuration parameters for the tasks.

  • optimizer (Optional[torch.optim.Optimizer]) – Optimizer to use. If None, AdamW is initialized.

  • logger (Optional[logging.Logger]) – Logger instance.

__init__(model: Module, config: PreTrainingConfig, optimizer: Optimizer | None = None, logger: Logger | None = None) → None

pre_train_contrastive_invariance(train_loader: DataLoader, val_loader: DataLoader | None = None, temperature: float = 0.1, embedding_dim: int = 128) → float

Pre-trains the model using Contrastive Transformation Invariance Learning (CTIL). Uses NT-Xent loss to maximize similarity between differently augmented views of the same spectrum.
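
For reference, NT-Xent (normalized temperature-scaled cross-entropy) treats the two augmented views of each spectrum as a positive pair and every other embedding in the batch as a negative. The sketch below is a generic PyTorch formulation, not necessarily the exact loss used by this method; it assumes the encoder has already produced embeddings of shape (batch, embedding_dim), and temperature plays the same role as the parameter of that name.

```python
import torch
import torch.nn.functional as F


def nt_xent_loss(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    """NT-Xent loss for paired embeddings.

    z1[i] and z2[i] embed two augmented views of spectrum i; every other
    embedding in the combined 2N batch serves as a negative.
    """
    n = z1.shape[0]
    # L2-normalise so that dot products are cosine similarities.
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2N, D)
    sim = z @ z.T / temperature  # (2N, 2N) similarity logits
    # Exclude each embedding's similarity with itself from the softmax.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))
    # The positive for row i is row i + N, and vice versa.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)
```

Lowering the temperature sharpens the softmax, so hard negatives (unrelated spectra that still look similar) contribute more to the gradient.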

pre_train_denoising_autoencoder(train_loader: DataLoader, val_loader: DataLoader | None = None, noise_std_dev: float = 0.1, mask_point_prob: float = 0.05) → Module

Pre-trains the model using Spectrum Denoising Autoencoding (SDA). Model learns to reconstruct clean spectra from inputs corrupted by noise and point masking.
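
A minimal sketch of the corruption step, assuming noise_std_dev is the standard deviation of additive Gaussian noise and mask_point_prob is the independent probability of zeroing each point. The helper names corrupt_spectra and denoising_loss, and the choice of MSE as the reconstruction loss, are illustrative assumptions.

```python
import torch
import torch.nn.functional as F


def corrupt_spectra(x: torch.Tensor, noise_std_dev: float = 0.1,
                    mask_point_prob: float = 0.05) -> torch.Tensor:
    """Return a noisy, point-masked copy of x; x itself remains the clean target."""
    noisy = x + noise_std_dev * torch.randn_like(x)  # additive Gaussian noise
    drop = torch.rand_like(x) < mask_point_prob      # independent per-point mask
    return noisy.masked_fill(drop, 0.0)              # zero the masked points


def denoising_loss(model: torch.nn.Module, clean: torch.Tensor, **corruption) -> torch.Tensor:
    # The model sees the corrupted input but is scored against the clean spectrum.
    reconstruction = model(corrupt_spectra(clean, **corruption))
    return F.mse_loss(reconstruction, clean)
```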

pre_train_masked_spectra(train_loader: DataLoader) → Module

Pre-trains the model using Masked Spectra Modelling (MSM). The model attempts to reconstruct contiguous chunks of masked spectral data.
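
The sketch below shows one way to hide a contiguous chunk of chunk_size points per spectrum, assuming a batch of shape (batch, n_features) and zero as the mask value. mask_contiguous_chunk is a hypothetical helper; the module's own masking may place or count chunks differently.

```python
import torch


def mask_contiguous_chunk(x: torch.Tensor, chunk_size: int = 50):
    """Zero one random contiguous chunk per spectrum.

    Returns the masked input and a boolean mask marking the hidden points,
    so that the reconstruction loss can be restricted to them.
    """
    batch, n_features = x.shape
    if chunk_size > n_features:
        raise ValueError("chunk_size must not exceed the spectrum length")
    # One random start index per spectrum.
    starts = torch.randint(0, n_features - chunk_size + 1, (batch, 1), device=x.device)
    positions = torch.arange(n_features, device=x.device).unsqueeze(0)  # (1, n_features)
    mask = (positions >= starts) & (positions < starts + chunk_size)    # (batch, n_features)
    return x.masked_fill(mask, 0.0), mask
```

Computing the loss only on the hidden points, for example F.mse_loss(model(masked)[mask], x[mask]), forces the model to infer the missing chunk rather than copy the visible values.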

pre_train_next_spectra(train_loader: DataLoader, val_loader: DataLoader) → Module

Pre-trains the model using Next Spectra Prediction (NSP). Learns to distinguish between related (anchor-positive) and unrelated (anchor-negative) masked views.
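
One plausible way to build NSP pairs is sketched below: the anchor keeps the left half of a spectrum, the positive keeps the right half of the same spectrum, and the negative keeps the right half of a different spectrum in the batch. The pairing scheme, the helper names mask_half and make_nsp_pairs, and the 1/0 labels are assumptions for illustration, not a description of pre_train_next_spectra. A batch of at least two spectra is needed for the negatives to differ from the positives.

```python
import torch


def mask_half(x: torch.Tensor, side: str) -> torch.Tensor:
    # Hypothetical helper: zero the chosen half of the last dimension.
    half = x.shape[-1] // 2
    out = x.clone()
    if side == "left":
        out[..., :half] = 0
    else:
        out[..., half:] = 0
    return out


def make_nsp_pairs(x: torch.Tensor):
    """Build (first, second, label) examples from a batch of spectra.

    Label 1: both views come from the same spectrum (anchor-positive).
    Label 0: the second view comes from another spectrum (anchor-negative).
    """
    batch = x.shape[0]
    anchors = mask_half(x, "right")    # keep the left half
    positives = mask_half(x, "left")   # keep the right half of the same spectrum
    # Roll the batch by one so each negative comes from a different spectrum.
    negatives = mask_half(x.roll(shifts=1, dims=0), "left")
    first = torch.cat([anchors, anchors])
    second = torch.cat([positives, negatives])
    labels = torch.cat([torch.ones(batch, device=x.device), torch.zeros(batch, device=x.device)])
    return first, second, labels
```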

pre_train_peak_prediction(train_loader: DataLoader, val_loader: DataLoader | None = None, peak_threshold: float = 0.1, window_size: int = 5) → Module

Pre-trains the model for Peak Prediction. Point-wise binary classification task to identify spectral peaks.
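
A sketch of how point-wise peak labels could be derived from peak_threshold and window_size, assuming a point counts as a peak when it is the maximum of its centred window and its intensity exceeds the threshold. peak_labels is a hypothetical helper; the module may define peaks differently (for example with a relative threshold).

```python
import torch
import torch.nn.functional as F


def peak_labels(x: torch.Tensor, peak_threshold: float = 0.1, window_size: int = 5) -> torch.Tensor:
    """Return 1.0 where a point is its window's maximum and exceeds the threshold, else 0.0."""
    if window_size % 2 == 0:
        raise ValueError("window_size must be odd so the window is centred")
    # Sliding-window maximum along the feature axis; max_pool1d pads with -inf,
    # so the output keeps the input length.
    local_max = F.max_pool1d(x.unsqueeze(1), kernel_size=window_size,
                             stride=1, padding=window_size // 2).squeeze(1)
    is_peak = (x == local_max) & (x > peak_threshold)
    return is_peak.float()
```

The labels can then be paired with per-point logits from the model in a binary cross-entropy loss such as F.binary_cross_entropy_with_logits.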

class fishy.experiments.pre_training.PreTrainingConfig(num_epochs: int = 100, file_path: str = 'transformer_checkpoint.pth', n_features: int = 2080, chunk_size: int = 50, device: str | device = device(type='cpu'), learning_rate: float = 0.0001, noise_enabled: bool = False, shift_enabled: bool = False, scale_enabled: bool = False, crop_enabled: bool = False, flip_enabled: bool = False, permutation_enabled: bool = False, crop_size: float = 0.8)

Bases: object

Configuration for pre-training tasks.

Examples

>>> config = PreTrainingConfig(num_epochs=5, n_features=100)
>>> config.num_epochs == 5
True
>>> config.n_features == 100
True

num_epochs

Number of epochs to train for each task.

Type:

int

file_path

Path to save the model checkpoints.

Type:

str

n_features

Number of input features (spectrum length).

Type:

int

chunk_size

Length, in spectral points, of each contiguous masked chunk for MSM.

Type:

int

device

Device to run training on.

Type:

str | torch.device

learning_rate

Learning rate for the optimizer.

Type:

float

noise_enabled

Enable noise during augmentation for CTIL.

Type:

bool

shift_enabled

Enable shift during augmentation for CTIL.

Type:

bool

scale_enabled

Enable scale during augmentation for CTIL.

Type:

bool

crop_enabled

Enable crop during augmentation for CTIL.

Type:

bool

flip_enabled

Enable flip during augmentation for CTIL.

Type:

bool

permutation_enabled

Enable permutation during augmentation for CTIL.

Type:

bool

crop_size

Fraction of the spectrum to keep when cropping (e.g. 0.8 keeps 80% of the points).

Type:

float

__init__(num_epochs: int = 100, file_path: str = 'transformer_checkpoint.pth', n_features: int = 2080, chunk_size: int = 50, device: str | device = device(type='cpu'), learning_rate: float = 0.0001, noise_enabled: bool = False, shift_enabled: bool = False, scale_enabled: bool = False, crop_enabled: bool = False, flip_enabled: bool = False, permutation_enabled: bool = False, crop_size: float = 0.8) → None

chunk_size: int = 50
crop_enabled: bool = False
crop_size: float = 0.8
device: str | device = device(type='cpu')
file_path: str = 'transformer_checkpoint.pth'
flip_enabled: bool = False
learning_rate: float = 0.0001
n_features: int = 2080
noise_enabled: bool = False
num_epochs: int = 100
permutation_enabled: bool = False
scale_enabled: bool = False
shift_enabled: bool = False
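
The CTIL flags above act as independent switches for augmentations. The sketch below applies each enabled one in turn to a batch of shape (batch, n_features). AugmentFlags is a hypothetical stand-in for the relevant PreTrainingConfig fields; the magnitudes (noise level, shift range, scale range, number of permuted segments) and the choice to crop by zeroing outside a window are assumptions, not values taken from this module.

```python
from dataclasses import dataclass

import torch


@dataclass
class AugmentFlags:
    # Hypothetical stand-in for the CTIL fields of PreTrainingConfig.
    noise_enabled: bool = False
    shift_enabled: bool = False
    scale_enabled: bool = False
    crop_enabled: bool = False
    flip_enabled: bool = False
    permutation_enabled: bool = False
    crop_size: float = 0.8


def augment(x: torch.Tensor, cfg: AugmentFlags) -> torch.Tensor:
    """Apply every enabled augmentation to spectra of shape (batch, n_features)."""
    out = x.clone()
    n = out.shape[-1]
    if cfg.noise_enabled:
        # Additive Gaussian noise; the 0.01 scale is an assumed value.
        out = out + 0.01 * torch.randn_like(out)
    if cfg.shift_enabled:
        # Circular shift by up to 5 points in either direction (assumed range).
        shift = int(torch.randint(-5, 6, (1,)).item())
        out = torch.roll(out, shifts=shift, dims=-1)
    if cfg.scale_enabled:
        # Per-spectrum intensity scaling in [0.9, 1.1) (assumed range).
        factors = torch.empty(out.shape[0], 1, device=out.device, dtype=out.dtype).uniform_(0.9, 1.1)
        out = out * factors
    if cfg.crop_enabled:
        # Keep a random window of crop_size * n points and zero the rest,
        # so the spectrum length stays n_features.
        keep = max(1, int(cfg.crop_size * n))
        start = int(torch.randint(0, n - keep + 1, (1,)).item())
        cropped = torch.zeros_like(out)
        cropped[..., start:start + keep] = out[..., start:start + keep]
        out = cropped
    if cfg.flip_enabled:
        # Reverse the order of the points along the feature axis.
        out = torch.flip(out, dims=[-1])
    if cfg.permutation_enabled:
        # Split into 4 segments (assumed count) and concatenate them in random order.
        segments = torch.chunk(out, 4, dim=-1)
        order = torch.randperm(len(segments)).tolist()
        out = torch.cat([segments[i] for i in order], dim=-1)
    return out
```

In a CTIL setup, augment would typically be called twice on the same batch to produce the two views compared by the contrastive loss.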

class fishy.experiments.pre_training.PreTrainingOrchestrator(config: TrainingConfig, device: device, input_dim: int, ctx: RunContext, logger: Logger | None = None)

Bases: object

Handles the orchestration of multiple self-supervised pre-training tasks.

Reads the task definitions and their hyperparameters from an external configuration file (pre_training.yaml). Supports weight chaining between sequential tasks: weights learned by one task are carried into the next wherever layers match.

Examples

>>> import torch
>>> from fishy._core.config import TrainingConfig
>>> from fishy._core.utils import RunContext
>>> cfg = TrainingConfig()
>>> ctx = RunContext("ds", "method", "model")
INFO ... Initialized RunContext: model on ds...
>>> orch = PreTrainingOrchestrator(cfg, torch.device('cpu'), 10, ctx)
>>> orch.input_dim == 10
True

config

Global training configuration.

Type:

TrainingConfig

device

Computation device.

Type:

torch.device

input_dim

Dimensionality of input spectra.

Type:

int

ctx

Context for logging and checkpointing.

Type:

RunContext

logger

Logger instance.

Type:

logging.Logger

task_configs

List of task definitions from config.

Type:

List[Dict]

__init__(config: TrainingConfig, device: device, input_dim: int, ctx: RunContext, logger: Logger | None = None) → None

Initializes the PreTrainingOrchestrator.

Parameters:
  • config (TrainingConfig) – Configuration object.

  • device (torch.device) – Computing device.

  • input_dim (int) – Input feature dimension.

  • ctx (RunContext) – Experiment context.

  • logger (Optional[logging.Logger], optional) – Custom logger. Defaults to None.

adapt_for_finetuning(model: Module, pre_trained_model: Module) → None

Adapts a pre-trained model for fine-tuning by loading its compatible weights into the target model.

Parameters:
  • model (nn.Module) – The target model for fine-tuning.

  • pre_trained_model (nn.Module) – The model containing pre-trained weights.
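
A minimal sketch of compatible-weight loading, assuming "compatible" means an identical state_dict key and tensor shape. load_compatible_weights is a hypothetical helper, not the method's actual implementation; it only uses the standard nn.Module.state_dict and load_state_dict calls.

```python
import torch.nn as nn


def load_compatible_weights(model: nn.Module, pre_trained_model: nn.Module) -> list:
    """Copy every parameter/buffer whose name and shape match; return the copied names."""
    source = pre_trained_model.state_dict()
    target = model.state_dict()
    compatible = {
        name: tensor for name, tensor in source.items()
        if name in target and target[name].shape == tensor.shape
    }
    # strict=False leaves unmatched entries (e.g. a new task head) untouched.
    model.load_state_dict(compatible, strict=False)
    return sorted(compatible)
```

Because loading is non-strict, layers that exist only in the fine-tuning model, such as a new classification head with a different output size, keep their fresh initialization.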

run_all(train_loader: DataLoader, val_loader: DataLoader | None = None) → Module | None

Runs all enabled pre-training tasks sequentially.

Weights are chained from one task to the next if layers match.

Parameters:
  • train_loader (DataLoader) – Loader for the training data.

  • val_loader (Optional[DataLoader], optional) – Loader for validation. Defaults to None.

Returns:

The model after all enabled pre-training tasks have run, or None if no task is enabled.

Return type:

Optional[nn.Module]
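
The chaining behaviour can be sketched as a loop that skips disabled tasks and seeds each new model with the matching weights of the previous one. The Task tuple and run_all_sketch are hypothetical; the real orchestrator reads its tasks from pre_training.yaml and uses its own training routines.

```python
from typing import Callable, List, Optional, Tuple

import torch.nn as nn

# Hypothetical task description: (name, enabled, model factory, training function).
Task = Tuple[str, bool, Callable[[], nn.Module], Callable[[nn.Module], nn.Module]]


def run_all_sketch(tasks: List[Task]) -> Optional[nn.Module]:
    previous: Optional[nn.Module] = None
    for _name, enabled, build_model, train in tasks:
        if not enabled:
            continue
        model = build_model()
        if previous is not None:
            # Chain weights: copy every entry whose name and shape match.
            source = previous.state_dict()
            target = model.state_dict()
            matching = {k: v for k, v in source.items()
                        if k in target and target[k].shape == v.shape}
            model.load_state_dict(matching, strict=False)
        previous = train(model)
    return previous  # None when no task is enabled
```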

Functions

fishy.experiments.pre_training.mask_spectra_side(input_spectra: T_Tensor, side: str = 'left') → T_Tensor

Masks either the left or right side of the input spectra.

Examples

>>> import torch
>>> x = torch.ones(4)
>>> mask_spectra_side(x, "left").tolist()
[0.0, 0.0, 1.0, 1.0]
>>> mask_spectra_side(x, "right").tolist()
[1.0, 1.0, 0.0, 0.0]

Parameters:
  • input_spectra (T_Tensor) – The input spectra tensor to mask.

  • side (str) – The side to mask, either ‘left’ or ‘right’.

Returns:

The masked spectra tensor.

Return type:

T_Tensor
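
A reference sketch that reproduces the doctest above, assuming the split falls at shape[-1] // 2 along the last dimension, so that for odd lengths the middle point belongs to the right half. mask_side_sketch is a hypothetical name; it clones the input rather than masking in place, and whether the library function does the same is not stated here.

```python
import torch


def mask_side_sketch(input_spectra: torch.Tensor, side: str = "left") -> torch.Tensor:
    """Zero the left or right half of the last dimension."""
    if side not in ("left", "right"):
        raise ValueError("side must be 'left' or 'right'")
    half = input_spectra.shape[-1] // 2
    masked = input_spectra.clone()  # leave the caller's tensor unchanged
    if side == "left":
        masked[..., :half] = 0
    else:
        masked[..., half:] = 0
    return masked
```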
