fishy.engine.losses

Functions

fishy.engine.losses.coral_loss(logits: Tensor, levels: Tensor, importance_weights: Tensor | None = None, reduction: str = 'mean') → Tensor

Computes the CORAL (Consistent Rank Logits) loss for ordinal regression.

Source: https://github.com/Raschka-research-group/corn-ordinal-regression/blob/main/coral_pytorch/losses.py

Parameters:
  • logits (torch.Tensor) – Outputs of the CORAL layer, shape (n_examples, n_classes-1).

  • levels (torch.Tensor) – True labels represented as extended binary vectors (via levels_from_labelbatch), shape (n_examples, n_classes-1).

  • importance_weights (Optional[torch.Tensor], optional) – Weights for the examples in the batch. Defaults to None.

  • reduction (str, optional) – Reduction to apply to the per-example losses: 'mean', 'sum', or None (no reduction). Defaults to 'mean'.

Returns:

The computed loss value (scalar if reduction is not None, else vector).

Return type:

torch.Tensor

Raises:
  • ValueError – If logits and levels shapes do not match.

  • ValueError – If reduction is not one of ‘mean’, ‘sum’, or None.

Examples

>>> import torch
>>> logits = torch.tensor([[10.0, 5.0], [-5.0, -10.0]])
>>> levels = torch.tensor([[1.0, 1.0], [0.0, 0.0]])
>>> loss = coral_loss(logits, levels)
>>> float(loss) < 0.1
True
>>> loss_sum = coral_loss(logits, levels, reduction='sum')
>>> float(loss_sum) > float(loss)
True
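For reference, here is a minimal sketch of the computation. It assumes fishy follows the coral_pytorch formulation linked above: one sigmoid binary cross-entropy term per threshold, summed over the n_classes-1 tasks and then reduced over examples. The name coral_loss_sketch is hypothetical. importance_weights is omitted because how fishy applies it is not specified here.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def coral_loss_sketch(
    logits: torch.Tensor,
    levels: torch.Tensor,
    reduction: Optional[str] = "mean",
) -> torch.Tensor:
    """Hypothetical CORAL loss sketch (importance_weights omitted)."""
    if logits.shape != levels.shape:
        raise ValueError("logits and levels must have the same shape")
    log_p = F.logsigmoid(logits)  # log sigma(z_k)
    log_q = log_p - logits        # log(1 - sigma(z_k)), computed stably
    # Binary cross-entropy for each of the n_classes-1 tasks, summed per example.
    per_example = -torch.sum(log_p * levels + log_q * (1 - levels), dim=1)
    if reduction == "mean":
        return per_example.mean()
    if reduction == "sum":
        return per_example.sum()
    if reduction is None:
        return per_example  # shape (n_examples,)
    raise ValueError("reduction must be 'mean', 'sum', or None")


logits = torch.tensor([[10.0, 5.0], [-5.0, -10.0]])
levels = torch.tensor([[1.0, 1.0], [0.0, 0.0]])
print(coral_loss_sketch(logits, levels, reduction=None))
```

Per example, this is the same quantity as F.binary_cross_entropy_with_logits(logits, levels, reduction='none').sum(dim=1). Rewriting log(1 - σ(z)) as log σ(z) - z keeps the computation numerically stable for large |z|.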

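fishy.engine.losses.cumulative_link_loss(logits: Tensor, labels: Tensor, num_classes: int, reduction: str = 'mean') → Tensor

Computes the Cumulative Link loss for ordinal regression.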

This loss treats ordinal regression with K = num_classes ordered classes as K-1 binary classification problems. For each class boundary k, it predicts whether the true label lies beyond that boundary (label > k).
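The following is a minimal sketch of this idea. It assumes that boundary k (0-indexed) has the binary target label > k and that each boundary is scored with sigmoid binary cross-entropy. The name cumulative_link_loss_sketch is hypothetical. In this sketch, 'mean' averages over all n_examples × (n_classes-1) binary terms, which may differ from how fishy normalizes the loss.

```python
import torch
import torch.nn.functional as F


def cumulative_link_loss_sketch(
    logits: torch.Tensor,
    labels: torch.Tensor,
    num_classes: int,
    reduction: str = "mean",
) -> torch.Tensor:
    """Hypothetical sketch: one sigmoid BCE term per class boundary."""
    # Boundary k (0-indexed) asks "is the true label greater than k?"
    thresholds = torch.arange(num_classes - 1, device=labels.device)
    targets = (labels.unsqueeze(1) > thresholds).to(logits.dtype)  # (n, K-1)
    if logits.shape != targets.shape:
        raise ValueError("logits must have shape (n_examples, num_classes - 1)")
    # 'mean' averages over all n * (K-1) binary terms; 'none' keeps them all.
    return F.binary_cross_entropy_with_logits(logits, targets, reduction=reduction)


logits = torch.tensor([[10.0, 10.0], [-10.0, -10.0]])
print(cumulative_link_loss_sketch(logits, torch.tensor([2, 0]), num_classes=3))
```

With confident logits that agree with the labels, the loss is close to zero. Swapping the labels to [0, 2] makes every boundary term large.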

Parameters:
  • logits (torch.Tensor) – Model outputs, shape (n_examples, n_classes-1).

  • labels (torch.Tensor) – True integer labels, shape (n_examples,).

  • num_classes (int) – The total number of ordinal classes.

  • reduction (str, optional) – Type of reduction to apply (‘mean’, ‘sum’, ‘none’). Defaults to “mean”.

Returns:

The computed loss value.

Return type:

torch.Tensor

Raises:
  • ValueError – If the shape of logits does not match that of the cumulative labels derived from labels and num_classes.

Examples

>>> import torch
>>> logits = torch.tensor([[10.0, 10.0], [-10.0, -10.0]])
>>> labels = torch.tensor([2, 0])
>>> loss = cumulative_link_loss(logits, labels, num_classes=3)
>>> float(loss) < 0.1
True

fishy.engine.losses.levels_from_labelbatch(labels: Tensor, num_classes: int, dtype: dtype = None) → Tensor

Converts a batch of integer labels to extended binary levels for CORAL.

For example, with 5 classes, label 2 becomes [1, 1, 0, 0]. The conversion is vectorized over the batch.
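Below is a minimal sketch of such a vectorized conversion. It assumes position k of the output is 1 exactly when label > k, which reproduces the [1, 1, 0, 0] example. levels_from_labelbatch_sketch is a hypothetical name. Returning a bool tensor when dtype is None is also an assumption, since the default output dtype is not documented here.

```python
from typing import Optional

import torch


def levels_from_labelbatch_sketch(
    labels: torch.Tensor,
    num_classes: int,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Hypothetical vectorized label -> extended binary levels conversion."""
    labels = torch.as_tensor(labels)
    thresholds = torch.arange(num_classes - 1, device=labels.device)
    # Broadcast (batch, 1) against (num_classes - 1,): entry k is label > k.
    levels = labels.reshape(-1, 1) > thresholds
    return levels if dtype is None else levels.to(dtype)


print(levels_from_labelbatch_sketch(torch.tensor([2]), num_classes=5, dtype=torch.float32))
```

Comparing against a broadcast arange avoids a Python loop over the batch. For labels [0, 1, 2] and num_classes=4, it produces the same rows as the documented example.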

Parameters:
  • labels (torch.Tensor) – Batch of integer labels.

  • num_classes (int) – Total number of ordinal classes.

  • dtype (torch.dtype, optional) – Desired dtype of the output tensor. Defaults to None.

Returns:

Binary levels tensor of shape (batch_size, num_classes - 1).

Return type:

torch.Tensor

Examples

>>> import torch
>>> labels = torch.tensor([0, 1, 2])
>>> levels = levels_from_labelbatch(labels, num_classes=4, dtype=torch.float32)
>>> levels
tensor([[0., 0., 0.],
        [1., 0., 0.],
        [1., 1., 0.]])
