fishy.engine.trainer

Trainer module for encapsulating the training loop and related logic.

Classes

class fishy.engine.trainer.DeepEngine[source]

Bases: object

High-level engine for deep learning experiments.

static evaluate_model(model, loader, criterion, device=device(type='cpu'), **kwargs)[source]
static train_model(model, train_loader, criterion, optimizer, num_epochs=100, patience=20, n_splits=5, n_runs=30, is_augmented=False, device=device(type='cpu'), val_loader=None, use_coral=False, use_cumulative_link=False, num_classes=None, regression=False, use_ttt=False, ttt_lr=0.001, ttt_steps=1, use_ema=False, ema_decay=0.999, scheduler=None, step_scheduler=False, ctx=None)[source]
static transfer_learning(dataset_name, model, file_path)[source]
class fishy.engine.trainer.EMA(model: Module, decay: float = 0.999)[source]

Bases: object

Exponential Moving Average of model weights with decay ramp-up.

__init__(model: Module, decay: float = 0.999)[source]
apply_shadow()[source]
register()[source]
restore()[source]
update()[source]
class fishy.engine.trainer.Trainer(model: Module, criterion: Module, optimizer: Optimizer, device: device, num_epochs: int, scheduler: Any | None = None, patience: int = 20, use_coral: bool = False, use_cumulative_link: bool = False, num_classes: int | None = None, regression: bool = False, use_ttt: bool = False, ttt_lr: float = 0.001, ttt_steps: int = 1, use_ema: bool = False, ema_decay: float = 0.999, step_scheduler: bool = False, ctx: RunContext | None = None, logger: Logger | None = None)[source]

Bases: object

Unified trainer for PyTorch models with Rich integration.

__init__(model: Module, criterion: Module, optimizer: Optimizer, device: device, num_epochs: int, scheduler: Any | None = None, patience: int = 20, use_coral: bool = False, use_cumulative_link: bool = False, num_classes: int | None = None, regression: bool = False, use_ttt: bool = False, ttt_lr: float = 0.001, ttt_steps: int = 1, use_ema: bool = False, ema_decay: float = 0.999, step_scheduler: bool = False, ctx: RunContext | None = None, logger: Logger | None = None)[source]
evaluate(loader: DataLoader) → Dict[str, Any][source]
train(train_loader: DataLoader, val_loader: DataLoader | None = None) → Dict[str, Any][source]

s