fishy.experiments.pre_training
Pre-training strategies for models on mass spectrometry data.
This module implements various pre-training tasks such as masked spectra modeling, next spectra prediction, peak detection, denoising autoencoding, peak parameter regression, segment reordering, and contrastive invariance learning. These tasks are designed to leverage unlabeled or semi-labeled mass spectrometry data to learn useful representations for downstream tasks.
Classes
- class fishy.experiments.pre_training.PreTrainer(model: Module, config: PreTrainingConfig, optimizer: Optimizer | None = None, logger: Logger | None = None)
Bases: object
Handles the execution of various self-supervised pre-training tasks.
Examples
>>> import torch.nn as nn
>>> model = nn.Linear(10, 10)
>>> config = PreTrainingConfig(n_features=10, device='cpu')
>>> trainer = PreTrainer(model, config)
>>> isinstance(trainer.model, nn.Module)
True
- Parameters:
model (nn.Module) – The model to be pre-trained.
config (PreTrainingConfig) – Configuration parameters for the tasks.
optimizer (Optional[torch.optim.Optimizer]) – Optimizer to use. If None, AdamW is initialized.
logger (Optional[logging.Logger]) – Logger instance.
- __init__(model: Module, config: PreTrainingConfig, optimizer: Optimizer | None = None, logger: Logger | None = None) None
- pre_train_contrastive_invariance(train_loader: DataLoader, val_loader: DataLoader | None = None, temperature: float = 0.1, embedding_dim: int = 128) float
Pre-trains the model using Contrastive Transformation Invariance Learning (CTIL). Uses NT-Xent loss to maximize similarity between differently augmented views of the same spectrum.
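The NT-Xent objective named above can be illustrated with a minimal, dependency-free sketch. This is not the library's implementation: the helper names `cosine` and `nt_xent` are hypothetical, and embeddings are plain Python lists rather than tensors. Only the `temperature` default of 0.1 is taken from the method signature.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def nt_xent(view_a, view_b, temperature=0.1):
    """Mean NT-Xent loss over a batch of paired embeddings.

    view_a[i] and view_b[i] are assumed to embed two augmentations
    of the same spectrum i (hypothetical helper, for illustration).
    """
    z = list(view_a) + list(view_b)  # 2N embeddings in one list
    n = len(view_a)
    total = 0.0
    for i in range(2 * n):
        positive = (i + n) % (2 * n)  # the other view of the same spectrum
        # Temperature-scaled similarity to every embedding except itself.
        logits = {
            j: cosine(z[i], z[j]) / temperature
            for j in range(2 * n)
            if j != i
        }
        peak = max(logits.values())  # subtract the max for numerical stability
        log_denom = peak + math.log(
            sum(math.exp(v - peak) for v in logits.values())
        )
        total += log_denom - logits[positive]  # cross-entropy for the positive pair
    return total / (2 * n)
```

The loss is small when each embedding is closest to its paired view and large when the pairs are mismatched, which is the behaviour the pre-training task relies on.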
- pre_train_denoising_autoencoder(train_loader: DataLoader, val_loader: DataLoader | None = None, noise_std_dev: float = 0.1, mask_point_prob: float = 0.05) Module
Pre-trains the model using Spectrum Denoising Autoencoding (SDA). Model learns to reconstruct clean spectra from inputs corrupted by noise and point masking.
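As a rough sketch of the corruption step described above, the following pure-Python example adds Gaussian noise and masks individual points. It rests on assumptions: masked points are set to zero, noise is additive, and the names `corrupt_spectrum` and `mse` are hypothetical. Only the parameter names and defaults (`noise_std_dev`, `mask_point_prob`) come from the signature.

```python
import random


def corrupt_spectrum(spectrum, noise_std_dev=0.1, mask_point_prob=0.05, seed=None):
    """Return a noisy, point-masked copy of a spectrum (list of floats)."""
    rng = random.Random(seed)
    corrupted = []
    for intensity in spectrum:
        if rng.random() < mask_point_prob:
            corrupted.append(0.0)  # assumption: a masked point is zeroed
        else:
            # assumption: additive zero-mean Gaussian noise
            corrupted.append(intensity + rng.gauss(0.0, noise_std_dev))
    return corrupted


def mse(prediction, target):
    """Mean squared error, a typical reconstruction loss against the clean input."""
    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(target)
```

In a denoising setup of this kind, the model receives the corrupted spectrum and the loss compares its output with the clean original.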
- pre_train_masked_spectra(train_loader: DataLoader) Module
Pre-trains the model using Masked Spectra Modelling (MSM). The model attempts to reconstruct contiguous chunks of masked spectral data.
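The contiguous masking used by MSM might look like the sketch below. It assumes the chunk position is drawn uniformly at random and that masked values are set to zero; neither detail is confirmed by this page. `mask_chunk` is a hypothetical name, and the `chunk_size` default of 50 mirrors `PreTrainingConfig.chunk_size`.

```python
import random


def mask_chunk(spectrum, chunk_size=50, seed=None):
    """Zero one contiguous chunk of a spectrum; return (masked_copy, start_index)."""
    rng = random.Random(seed)
    chunk_size = min(chunk_size, len(spectrum))  # never mask past the end
    # assumption: the chunk start is sampled uniformly over valid positions
    start = rng.randrange(len(spectrum) - chunk_size + 1)
    masked = list(spectrum)
    masked[start:start + chunk_size] = [0.0] * chunk_size
    return masked, start
```

The reconstruction target for the masked region is the original `spectrum[start:start + chunk_size]`.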
- class fishy.experiments.pre_training.PreTrainingConfig(num_epochs: int = 100, file_path: str = 'transformer_checkpoint.pth', n_features: int = 2080, chunk_size: int = 50, device: str | device = device(type='cpu'), learning_rate: float = 0.0001, noise_enabled: bool = False, shift_enabled: bool = False, scale_enabled: bool = False, crop_enabled: bool = False, flip_enabled: bool = False, permutation_enabled: bool = False, crop_size: float = 0.8)
Bases: object
Configuration for pre-training tasks.
Examples
>>> config = PreTrainingConfig(num_epochs=5, n_features=100)
>>> config.num_epochs == 5
True
>>> config.n_features == 100
True
- num_epochs
Number of epochs to train for each task.
- Type:
int
- file_path
Path to save the model checkpoints.
- Type:
str
- n_features
Number of input features (spectrum length).
- Type:
int
- chunk_size
Size of the contiguous mask for MSM.
- Type:
int
- device
Device to run training on.
- Type:
str | torch.device
- learning_rate
Learning rate for the optimizer.
- Type:
float
- noise_enabled
Enable noise during augmentation for CTIL.
- Type:
bool
- shift_enabled
Enable shift during augmentation for CTIL.
- Type:
bool
- scale_enabled
Enable scale during augmentation for CTIL.
- Type:
bool
- crop_enabled
Enable crop during augmentation for CTIL.
- Type:
bool
- flip_enabled
Enable flip during augmentation for CTIL.
- Type:
bool
- permutation_enabled
Enable permutation during augmentation for CTIL.
- Type:
bool
- crop_size
Portion of spectrum to keep when cropping.
- Type:
float
- __init__(num_epochs: int = 100, file_path: str = 'transformer_checkpoint.pth', n_features: int = 2080, chunk_size: int = 50, device: str | device = device(type='cpu'), learning_rate: float = 0.0001, noise_enabled: bool = False, shift_enabled: bool = False, scale_enabled: bool = False, crop_enabled: bool = False, flip_enabled: bool = False, permutation_enabled: bool = False, crop_size: float = 0.8) None
- chunk_size: int = 50
- crop_enabled: bool = False
- crop_size: float = 0.8
- device: str | device = device(type='cpu')
- file_path: str = 'transformer_checkpoint.pth'
- flip_enabled: bool = False
- learning_rate: float = 0.0001
- n_features: int = 2080
- noise_enabled: bool = False
- num_epochs: int = 100
- permutation_enabled: bool = False
- scale_enabled: bool = False
- shift_enabled: bool = False
- class fishy.experiments.pre_training.PreTrainingOrchestrator(config: TrainingConfig, device: device, input_dim: int, ctx: RunContext, logger: Logger | None = None)
Bases: object
Handles the orchestration of multiple self-supervised pre-training tasks.
Uses external configuration (pre_training.yaml) to define tasks and their hyperparameters. Supports weight chaining between sequential tasks.
Examples
>>> import torch
>>> from fishy._core.config import TrainingConfig
>>> from fishy._core.utils import RunContext
>>> cfg = TrainingConfig()
>>> ctx = RunContext("ds", "method", "model")
INFO ... Initialized RunContext: model on ds...
>>> orch = PreTrainingOrchestrator(cfg, torch.device('cpu'), 10, ctx)
>>> orch.input_dim == 10
True
- config
Global training configuration.
- Type:
TrainingConfig
- device
Computation device.
- Type:
torch.device
- input_dim
Dimensionality of input spectra.
- Type:
int
- ctx
Context for logging and checkpointing.
- Type:
RunContext
- logger
Logger instance.
- Type:
logging.Logger
- task_configs
List of task definitions from config.
- Type:
List[Dict]
- __init__(config: TrainingConfig, device: device, input_dim: int, ctx: RunContext, logger: Logger | None = None) None
Initializes the PreTrainingOrchestrator.
- Parameters:
config (TrainingConfig) – Configuration object.
device (torch.device) – Computing device.
input_dim (int) – Input feature dimension.
ctx (RunContext) – Experiment context.
logger (Optional[logging.Logger], optional) – Custom logger. Defaults to None.
- adapt_for_finetuning(model: Module, pre_trained_model: Module) None
Adapts a pre-trained model for fine-tuning by loading compatible weights.
- Parameters:
model (nn.Module) – The target model for fine-tuning.
pre_trained_model (nn.Module) – The model containing pre-trained weights.
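This page does not define what counts as a compatible weight. One common reading is "same parameter name and same shape", which the sketch below illustrates on plain dictionaries mapping parameter names to shape tuples. The function name `compatible_keys` and the example layer names are hypothetical and do not come from the library.

```python
def compatible_keys(target_shapes, source_shapes):
    """Names of source parameters that exist in the target with an identical shape."""
    return [
        name
        for name, shape in source_shapes.items()
        if target_shapes.get(name) == shape  # a missing key or differing shape is skipped
    ]


# Hypothetical shapes: the encoder matches, the task-specific head does not.
pre_trained = {
    "encoder.weight": (64, 10),
    "encoder.bias": (64,),
    "head.weight": (10, 64),
}
fine_tune = {
    "encoder.weight": (64, 10),
    "encoder.bias": (64,),
    "head.weight": (3, 64),
    "head.bias": (3,),
}
shared = compatible_keys(fine_tune, pre_trained)
```

With PyTorch modules, the same filter could be built from each model's `state_dict()` and the surviving entries passed to `load_state_dict(..., strict=False)`. Filtering by shape first matters because `strict=False` tolerates missing or unexpected keys but still raises on a size mismatch.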
- run_all(train_loader: DataLoader, val_loader: DataLoader | None = None) Module | None
Runs all enabled pre-training tasks sequentially.
Weights are chained from one task to the next if layers match.
- Parameters:
train_loader (DataLoader) – Loader for the training data.
val_loader (Optional[DataLoader], optional) – Loader for validation. Defaults to None.
- Returns:
The model after all pre-training tasks, or None if none enabled.
- Return type:
Optional[nn.Module]
Functions
- fishy.experiments.pre_training.mask_spectra_side(input_spectra: T_Tensor, side: str = 'left') T_Tensor
Masks either the left or right side of the input spectra.
Examples
>>> import torch
>>> x = torch.ones(4)
>>> mask_spectra_side(x, "left").tolist()
[0.0, 0.0, 1.0, 1.0]
>>> mask_spectra_side(x, "right").tolist()
[1.0, 1.0, 0.0, 0.0]
- Parameters:
input_spectra (T_Tensor) – The input spectra tensor to mask.
side (str) – The side to mask, either ‘left’ or ‘right’.
- Returns:
The masked spectra tensor.
- Return type:
T_Tensor
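For readers without the package installed, the behaviour shown in the example can be approximated on plain lists. This is a sketch under assumptions: the split point is taken to be the integer midpoint and masked values are zeros, which matches the length-4 example but is not confirmed for odd lengths or other inputs. `mask_side` is a hypothetical stand-in, not the library function.

```python
def mask_side(values, side="left"):
    """Zero the left or right half of a list of floats (midpoint split assumed)."""
    mid = len(values) // 2
    if side == "left":
        return [0.0] * mid + list(values[mid:])
    if side == "right":
        return list(values[:mid]) + [0.0] * (len(values) - mid)
    raise ValueError("side must be 'left' or 'right'")
```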