fishy.experiments.transfer
Transfer learning module for deep learning models, standardized to use the DataModule and Trainer patterns.
Functions
- fishy.experiments.transfer.run_sequential_transfer_learning(model_name: str, transfer_datasets: List[str], target_dataset: str, num_epochs_transfer: int = 10, num_epochs_finetune: int = 20, batch_size: int = 32, learning_rate: float = 0.001, finetune_lr: float = 0.0005, device: str = 'cpu', save_intermediate: bool = False, val_split: float = 0.2, file_path: str | None = None, wandb_project: str = 'fishy-business', wandb_entity: str = 'victoria-university-of-wellington', wandb_log: bool = False, run: int = 0, wandb_run: Any | None = None) → Tuple[Module, Dict[str, Any]]
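Performs sequential transfer learning using standardized DataModules: the model is pre-trained on each dataset in transfer_datasets in turn and then fine-tuned on target_dataset.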
Examples
>>> m_name = "transformer"
>>> isinstance(m_name, str)
True
- Parameters:
model_name (str) – Name of the model architecture.
transfer_datasets (List[str]) – Datasets to pre-train on, in order.
target_dataset (str) – Final dataset to fine-tune on.
num_epochs_transfer (int) – Number of epochs per transfer phase.
num_epochs_finetune (int) – Number of epochs for the final fine-tuning phase.
batch_size (int) – Batch size.
learning_rate (float) – Initial learning rate.
finetune_lr (float) – Learning rate for fine-tuning.
device (str) – Computation device (e.g. 'cpu').
save_intermediate (bool) – Save checkpoints after each phase.
val_split (float) – Fraction of the data held out for validation.
file_path (str | None) – Path to the data file.
wandb_project (str) – W&B project name.
wandb_entity (str) – W&B entity.
wandb_log (bool) – Enable W&B logging.
run (int) – Run identifier/seed.
wandb_run (Any | None) – Existing W&B run.
- Returns:
The trained model and its training history.
- Return type:
Tuple[nn.Module, Dict[str, Any]]
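The sequential schedule described above can be illustrated with a small, framework-free sketch. It is not the module's implementation: the real function trains an nn.Module through DataModules and a Trainer, whereas here a hypothetical toy LinearModel is trained with plain per-sample SGD, and the names LinearModel, mean_loss, train_phase, sequential_transfer and make_line are all illustrative. It assumes, from the parameter descriptions, that every transfer phase uses learning_rate and the final phase uses finetune_lr; batching, validation splits, checkpoints and W&B logging are omitted. The essential point is that a single model object is carried through every phase, so each phase starts from the weights left by the previous one.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple

# Hypothetical toy dataset: a list of (x, y) pairs, standing in for a DataModule.
Dataset = List[Tuple[float, float]]


@dataclass
class LinearModel:
    """Toy stand-in for an nn.Module: y_hat = w * x + b."""
    w: float = 0.0
    b: float = 0.0

    def predict(self, x: float) -> float:
        return self.w * x + self.b


def mean_loss(model: LinearModel, data: Dataset) -> float:
    """Mean squared error of the model on a dataset."""
    return sum((model.predict(x) - y) ** 2 for x, y in data) / len(data)


def train_phase(model: LinearModel, data: Dataset, epochs: int, lr: float) -> float:
    """Per-sample SGD on squared error, updating the model in place; returns the final loss."""
    for _ in range(epochs):
        for x, y in data:
            err = model.predict(x) - y
            # Gradient of 0.5 * err**2 with respect to w and b.
            model.w -= lr * err * x
            model.b -= lr * err
    return mean_loss(model, data)


def sequential_transfer(
    datasets: Dict[str, Dataset],
    transfer_datasets: List[str],
    target_dataset: str,
    num_epochs_transfer: int = 10,
    num_epochs_finetune: int = 20,
    learning_rate: float = 0.001,
    finetune_lr: float = 0.0005,
) -> Tuple[LinearModel, Dict[str, float]]:
    model = LinearModel()
    history: Dict[str, float] = {}
    # Pre-train on each source dataset in order; the weights carry over between phases.
    for name in transfer_datasets:
        history[f"transfer/{name}"] = train_phase(
            model, datasets[name], num_epochs_transfer, learning_rate
        )
    # Final phase: fine-tune the same model on the target with its own learning rate.
    history[f"finetune/{target_dataset}"] = train_phase(
        model, datasets[target_dataset], num_epochs_finetune, finetune_lr
    )
    return model, history


def make_line(w: float, b: float) -> Dataset:
    """Noise-free samples of y = w * x + b for x = 0.0, 0.1, ..., 1.0."""
    return [(i / 10, w * (i / 10) + b) for i in range(11)]


# Two related source tasks and a target task that differs only in its intercept.
datasets = {"a": make_line(2.0, 1.0), "b": make_line(2.0, 1.2), "target": make_line(2.0, 1.5)}
model, history = sequential_transfer(
    datasets, ["a", "b"], "target",
    num_epochs_transfer=200, num_epochs_finetune=200,
    learning_rate=0.1, finetune_lr=0.05,
)
```

The epoch counts and learning rates in the usage lines are chosen for this toy problem rather than taken from the documented defaults. The returned history mirrors the phase order: one entry per transfer dataset, followed by the fine-tuning entry.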