from typing import Optional, List
from functools import partial
import torch
from torch.nn.modules.loss import _Loss
from ._functional import focal_loss_with_logits, to_tensor
from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE
__all__ = ["FocalLoss"]
# NOTE: a stray "[docs]" marker (Sphinx viewcode link residue) stood here; it is not part of the module.
class FocalLoss(_Loss):
    """Focal loss for binary, multiclass and multilabel targets.

    The per-element loss is delegated to ``focal_loss_with_logits``; this
    module converts inputs to logits when needed, filters ignored targets and
    combines the per-class losses.
    """

    def __init__(
        self,
        mode: str,
        alpha: Optional[float] = None,
        gamma: Optional[float] = 2.0,
        ignore_index: Optional[int] = None,
        from_logits: bool = True,
        eps: float = 1e-7,
        reduction: Optional[str] = "mean",
        normalized: bool = False,
        reduced_threshold: Optional[float] = None,
        class_weights: Optional[List[float]] = None,
    ):
        """Compute Focal loss

        Args:
            mode: Loss mode 'binary', 'multiclass' or 'multilabel'
            alpha: Prior probability of having positive value in target.
            gamma: Power factor for dampening weight (focal strength).
            ignore_index: If not None, targets may contain values to be ignored.
                Target values equal to ignore_index will be ignored from loss computation.
                In multilabel mode the filtering is applied per class channel.
            from_logits: If True, assumes input is raw logits
            eps: Small value used for numerical stability when converting
                probabilities to logits.
            reduction: Reduction forwarded to the underlying focal loss function.
                In multiclass and multilabel modes the per-class results are
                afterwards always combined into a single scalar.
            normalized: Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf).
            reduced_threshold: Switch to reduced focal loss. Note, when using this mode you
                should use `reduction="sum"`.
            class_weights: List of weights for each class. If not ``None``, the loss for each class
                is multiplied by the corresponding weight. Only supported for multiclass and
                multilabel modes. Weights do not need to be normalized.

        Raises:
            AssertionError: If ``mode`` is not one of the supported modes.
            ValueError: If ``class_weights`` is given together with binary mode.

        Shape
             - **y_pred** - torch.Tensor of shape (N, C, H, W)
             - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W)

        Reference
            https://github.com/BloodAxe/pytorch-toolbelt
        """
        # Kept as an assert (not ValueError) so the exception type seen by
        # existing callers does not change.
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}, (
            f"Unsupported mode: {mode!r}"
        )
        super().__init__()

        if class_weights is not None and mode == BINARY_MODE:
            raise ValueError("class_weights are not supported with mode=binary")

        self.mode = mode
        self.from_logits = from_logits
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.eps = eps

        # Bind the hyper-parameters once; forward() only supplies tensors.
        self.focal_loss_fn = partial(
            focal_loss_with_logits,
            alpha=alpha,
            gamma=gamma,
            reduced_threshold=reduced_threshold,
            reduction=reduction,
            normalized=normalized,
        )

        # Stored as a plain attribute (not a buffer); forward() moves it to the
        # device of the computed losses on every call.
        self.class_weights = (
            to_tensor(class_weights, dtype=torch.float)
            if class_weights is not None
            else None
        )

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Compute the focal loss between predictions and targets.

        Args:
            y_pred: Predictions; raw logits when ``from_logits`` is True,
                probabilities otherwise.
            y_true: Targets. In multiclass mode they are compared against the
                class index; in multilabel mode they are sliced per class
                along dimension 1.

        Returns:
            The loss tensor. Binary mode returns whatever the underlying focal
            loss function returns; multiclass and multilabel modes return a
            scalar (plain or weighted mean of the per-class losses).

        Raises:
            ValueError: If the number of ``class_weights`` does not match the
                number of classes in ``y_pred``.
        """
        if not self.from_logits:
            # Keep probabilities strictly inside (0, 1) so the logs stay finite.
            y_pred = torch.clamp(y_pred, self.eps, 1 - self.eps)
            if self.mode in {BINARY_MODE, MULTILABEL_MODE}:
                # inverse sigmoid
                y_pred = torch.log(y_pred / (1 - y_pred))
            elif self.mode == MULTICLASS_MODE:
                # convert softmax probabilities to log-space
                # NOTE(review): these log-probabilities are then passed to
                # `focal_loss_with_logits` as if they were logits -- confirm
                # this is the intended treatment.
                y_pred = torch.log(y_pred)

        if self.mode == BINARY_MODE:
            y_true = y_true.reshape(-1)
            y_pred = y_pred.reshape(-1)

            if self.ignore_index is not None:
                # Drop elements whose target equals ignore_index.
                not_ignored = y_true != self.ignore_index
                y_pred = y_pred[not_ignored]
                y_true = y_true[not_ignored]

            loss = self.focal_loss_fn(y_pred, y_true)

        elif self.mode in {MULTILABEL_MODE, MULTICLASS_MODE}:
            num_classes = y_pred.size(1)
            is_multiclass = self.mode == MULTICLASS_MODE

            if (
                self.class_weights is not None
                and self.class_weights.numel() != num_classes
            ):
                raise ValueError(
                    f"Expected {num_classes} class_weights, "
                    f"got {self.class_weights.numel()}"
                )

            # Mask of targets that take part in the loss (None = keep all).
            not_ignored = None
            if self.ignore_index is not None:
                not_ignored = y_true != self.ignore_index

            class_losses = []
            for cls in range(num_classes):
                cls_y_pred = y_pred[:, cls, ...]

                if is_multiclass:
                    # One-vs-rest target for this class index.
                    cls_y_true = (y_true == cls).long()
                    cls_mask = not_ignored
                else:
                    cls_y_true = y_true[:, cls, ...]
                    # The multilabel mask has a class dimension; take the
                    # channel matching this class so shapes agree with the
                    # per-class slices.
                    cls_mask = (
                        not_ignored[:, cls, ...] if not_ignored is not None else None
                    )

                if cls_mask is not None:
                    cls_y_true = cls_y_true[cls_mask]
                    cls_y_pred = cls_y_pred[cls_mask]

                class_losses.append(self.focal_loss_fn(cls_y_pred, cls_y_true))

            class_losses = torch.stack(class_losses)  # shape (C,)

            if self.class_weights is not None:
                weights = self.class_weights.to(class_losses.device)
                # Weighted mean; weights need not be normalized.
                loss = (class_losses * weights).sum() / weights.sum()
            else:
                loss = class_losses.mean()

        return loss