fishy.engine.losses
Functions
- fishy.engine.losses.coral_loss(logits: Tensor, levels: Tensor, importance_weights: Tensor | None = None, reduction: str = 'mean') -> Tensor
Computes the CORAL (Consistent Rank Logits) loss for ordinal regression.
Source: https://github.com/Raschka-research-group/corn-ordinal-regression/blob/main/coral_pytorch/losses.py
- Parameters:
logits (torch.Tensor) – Outputs of the CORAL layer, shape (n_examples, n_classes-1).
levels (torch.Tensor) – True labels represented as extended binary vectors (via levels_from_labelbatch), shape (n_examples, n_classes-1).
importance_weights (Optional[torch.Tensor], optional) – Weights for the examples in the batch. Defaults to None.
reduction (str, optional) – Reduction to apply to the loss (‘mean’, ‘sum’, or None). Defaults to “mean”.
- Returns:
The computed loss: a scalar when reduction is ‘mean’ or ‘sum’, a vector when reduction is None.
- Return type:
torch.Tensor
- Raises:
ValueError – If logits and levels shapes do not match.
ValueError – If reduction is not one of ‘mean’, ‘sum’, or None.
Examples
>>> import torch
>>> from fishy.engine.losses import coral_loss
>>> logits = torch.tensor([[10.0, 5.0], [-5.0, -10.0]])
>>> levels = torch.tensor([[1.0, 1.0], [0.0, 0.0]])
>>> loss = coral_loss(logits, levels)
>>> float(loss) < 0.1
True
>>> loss_sum = coral_loss(logits, levels, reduction='sum')
>>> float(loss_sum) > float(loss)
True
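The arithmetic behind this loss can be sketched in plain Python, without torch, so that each term is visible. This is a sketch under assumptions: it follows the usual CORAL formulation (a sum of per-threshold binary log-likelihoods), the names `log_sigmoid` and `coral_loss_sketch` are hypothetical, importance weights are left out, and whether fishy's implementation matches it term for term is not confirmed by this page.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))


def coral_loss_sketch(logits, levels, reduction="mean"):
    """Hypothetical reference for the CORAL loss on nested lists.

    Not the fishy implementation; importance weights are omitted.
    """
    if len(logits) != len(levels) or any(
        len(z) != len(y) for z, y in zip(logits, levels)
    ):
        raise ValueError("logits and levels must have the same shape")
    if reduction not in ("mean", "sum", None):
        raise ValueError("reduction must be 'mean', 'sum', or None")

    per_example = []
    for row_logits, row_levels in zip(logits, levels):
        total = 0.0
        for z, y in zip(row_logits, row_levels):
            # y = 1 contributes log(sigmoid(z));
            # y = 0 contributes log(1 - sigmoid(z)) = log_sigmoid(z) - z.
            total += log_sigmoid(z) * y + (log_sigmoid(z) - z) * (1 - y)
        per_example.append(-total)

    if reduction == "mean":
        return sum(per_example) / len(per_example)
    if reduction == "sum":
        return sum(per_example)
    return per_example  # reduction is None: one loss per example


# Same inputs as the documented example, as nested lists.
logits = [[10.0, 5.0], [-5.0, -10.0]]
levels = [[1.0, 1.0], [0.0, 0.0]]
loss_mean = coral_loss_sketch(logits, levels)
loss_sum = coral_loss_sketch(logits, levels, reduction="sum")
```

With two examples in the batch, the summed loss in this sketch is exactly twice the mean, which is why the documented example expects the ‘sum’ result to exceed the ‘mean’ result.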
- fishy.engine.losses.cumulative_link_loss(logits: Tensor, labels: Tensor, num_classes: int, reduction: str = 'mean') -> Tensor
Computes the Cumulative Link loss for ordinal regression.
This loss treats ordinal regression as K-1 binary classification problems. For each class boundary, it predicts whether the true label is beyond that boundary.
- Parameters:
logits (torch.Tensor) – Model outputs, shape (n_examples, n_classes-1).
labels (torch.Tensor) – True integer labels, shape (n_examples,).
num_classes (int) – The total number of ordinal classes.
reduction (str, optional) – Type of reduction to apply (‘mean’, ‘sum’, ‘none’). Defaults to “mean”.
- Returns:
The computed loss value.
- Return type:
torch.Tensor
- Raises:
ValueError – If the shape of logits does not match the shape of the derived cumulative labels.
Examples
>>> import torch
>>> from fishy.engine.losses import cumulative_link_loss
>>> logits = torch.tensor([[10.0, 10.0], [-10.0, -10.0]])
>>> labels = torch.tensor([2, 0])
>>> loss = cumulative_link_loss(logits, labels, num_classes=3)
>>> float(loss) < 0.1
True
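The K-1 binary problems described for this loss can be sketched in plain Python. For a label y and a 0-based boundary k, the binary target is 1 when y > k, and each boundary gets a binary cross-entropy on its own logit. The helper names are hypothetical, and summing over boundaries before averaging over the batch is an assumption; the library may reduce differently.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def cumulative_targets(label: int, num_classes: int) -> list:
    """Target k is 1.0 if the label lies beyond boundary k, else 0.0."""
    return [1.0 if label > k else 0.0 for k in range(num_classes - 1)]


def cumulative_link_loss_sketch(logits, labels, num_classes):
    """Hypothetical reference: mean over examples of summed boundary losses."""
    per_example = []
    for row, label in zip(logits, labels):
        targets = cumulative_targets(label, num_classes)
        if len(row) != len(targets):
            raise ValueError("expected num_classes - 1 logits per example")
        # Binary cross-entropy with logits: softplus(z) - t * z.
        per_example.append(sum(softplus(z) - t * z for z, t in zip(row, targets)))
    return sum(per_example) / len(per_example)


# Same inputs as the documented example, as plain lists.
loss = cumulative_link_loss_sketch(
    [[10.0, 10.0], [-10.0, -10.0]], [2, 0], num_classes=3
)
```

In the documented example, label 2 lies beyond both boundaries and label 0 beyond neither, so strongly positive and strongly negative logits respectively give a loss close to zero.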
- fishy.engine.losses.levels_from_labelbatch(labels: Tensor, num_classes: int, dtype: dtype = None) -> Tensor
Converts a batch of integer labels to extended binary levels for CORAL.
For example, with 5 classes, label 2 becomes [1, 1, 0, 0]. The implementation is vectorized.
- Parameters:
labels (torch.Tensor) – Batch of integer labels.
num_classes (int) – Total number of ordinal classes.
dtype (torch.dtype, optional) – Desired dtype of the output tensor. Defaults to None.
- Returns:
Binary levels tensor of shape (batch_size, num_classes - 1).
- Return type:
torch.Tensor
Examples
>>> import torch
>>> from fishy.engine.losses import levels_from_labelbatch
>>> labels = torch.tensor([0, 1, 2])
>>> levels = levels_from_labelbatch(labels, num_classes=4, dtype=torch.float32)
>>> levels
tensor([[0., 0., 0.],
        [1., 0., 0.],
        [1., 1., 0.]])
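The encoding itself can be sketched without torch: position k of the output is 1 exactly when the label is greater than k. `levels_from_labels_sketch` is a hypothetical name, and the list comprehension only illustrates the mapping; the library function is documented as vectorized and returns a tensor.

```python
def levels_from_labels_sketch(labels, num_classes):
    """Hypothetical pure-Python version of the extended binary encoding."""
    return [
        [1.0 if label > k else 0.0 for k in range(num_classes - 1)]
        for label in labels
    ]


levels = levels_from_labels_sketch([0, 1, 2], num_classes=4)

# Summing a row recovers the label, because a label y sets exactly
# y leading ones in its row.
recovered = [int(sum(row)) for row in levels]
```

The row sums give back the original labels, which is a quick sanity check on any encoded batch.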