# Source code for segmentation_models_pytorch.losses.focal

from typing import Optional, List
from functools import partial

import torch
from torch.nn.modules.loss import _Loss
from ._functional import focal_loss_with_logits, to_tensor
from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE

__all__ = ["FocalLoss"]


class FocalLoss(_Loss):
    """Focal loss for binary, multiclass and multilabel segmentation."""

    def __init__(
        self,
        mode: str,
        alpha: Optional[float] = None,
        gamma: Optional[float] = 2.0,
        ignore_index: Optional[int] = None,
        from_logits: bool = True,
        eps: float = 1e-7,
        reduction: Optional[str] = "mean",
        normalized: bool = False,
        reduced_threshold: Optional[float] = None,
        class_weights: Optional[List[float]] = None,
    ):
        """Compute Focal loss.

        Args:
            mode: Loss mode 'binary', 'multiclass' or 'multilabel'.
            alpha: Prior probability of having positive value in target.
            gamma: Power factor for dampening weight (focal strength).
            ignore_index: If not None, targets may contain values to be ignored.
                Target values equal to ignore_index will be ignored from loss
                computation.
            from_logits: If True, assumes input is raw logits. If False, the input
                is treated as per-class probabilities, clamped to
                ``[eps, 1 - eps]`` and mapped to logits with the inverse sigmoid
                ``log(p / (1 - p))`` in every mode.
            eps: Small value used for numerical stability when converting
                probabilities to logits.
            reduction: Forwarded to ``focal_loss_with_logits``. In multiclass and
                multilabel modes the per-class results are then combined by a
                (optionally weighted) average over the class axis.
            normalized: Compute normalized focal loss
                (https://arxiv.org/pdf/1909.07829.pdf).
            reduced_threshold: Switch to reduced focal loss. Note, when using this
                mode you should use `reduction="sum"`.
            class_weights: List of weights for each class. If not ``None``, the
                loss for each class is multiplied by the corresponding weight.
                Only supported for multiclass and multilabel modes. Weights do not
                need to be normalized, but there must be exactly one per class.

        Shape
             - **y_pred** - torch.Tensor of shape (N, C, H, W)
             - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W)

        Reference
            https://github.com/BloodAxe/pytorch-toolbelt

        Raises:
            AssertionError: If ``mode`` is not one of the supported modes.
            ValueError: If ``class_weights`` is given together with binary mode.
        """
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}, (
            f"Unsupported mode: {mode!r}"
        )
        super().__init__()

        if class_weights is not None and mode == BINARY_MODE:
            raise ValueError("class_weights are not supported with mode=binary")

        self.mode = mode
        self.from_logits = from_logits
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.eps = eps

        self.focal_loss_fn = partial(
            focal_loss_with_logits,
            alpha=alpha,
            gamma=gamma,
            reduced_threshold=reduced_threshold,
            reduction=reduction,
            normalized=normalized,
        )

        self.class_weights = (
            to_tensor(class_weights, dtype=torch.float)
            if class_weights is not None
            else None
        )

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Compute the focal loss of ``y_pred`` against ``y_true``.

        Args:
            y_pred: Predictions with the class axis at dim 1. They are logits, or
                probabilities when ``from_logits=False``.
            y_true: Targets. In multiclass mode they hold class indices without a
                class axis. In multilabel mode they have the class axis at dim 1.

        Returns:
            The loss as returned by ``focal_loss_with_logits`` (binary mode), or
            the per-class losses averaged over the class axis (other modes).

        Raises:
            ValueError: If ``class_weights`` does not have one entry per class.
        """
        if not self.from_logits:
            y_pred = self._probabilities_to_logits(y_pred)

        if self.mode == BINARY_MODE:
            return self._binary_loss(y_pred, y_true)
        return self._per_class_loss(y_pred, y_true)

    def _probabilities_to_logits(self, y_pred: torch.Tensor) -> torch.Tensor:
        """Map probabilities to logits with the clamped inverse sigmoid."""
        # The same focal_loss_fn is used in binary mode, where probabilities are
        # mapped through the inverse sigmoid. It is also used for the per-class
        # one-vs-rest binary targets of the other modes, so the same mapping
        # applies everywhere. Feeding log(p) instead would, assuming the function
        # interprets its input as sigmoid logits, turn p into p / (1 + p) <= 0.5.
        y_pred = torch.clamp(y_pred, self.eps, 1 - self.eps)
        return torch.log(y_pred / (1 - y_pred))

    def _binary_loss(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Flatten predictions and targets, drop ignored entries, compute the loss."""
        y_true = y_true.reshape(-1)
        y_pred = y_pred.reshape(-1)

        if self.ignore_index is not None:
            # Filter predictions with ignore label from loss computation.
            not_ignored = y_true != self.ignore_index
            y_pred = y_pred[not_ignored]
            y_true = y_true[not_ignored]

        return self.focal_loss_fn(y_pred, y_true)

    def _per_class_loss(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Compute one-vs-rest focal losses per class and combine them."""
        num_classes = y_pred.size(1)
        is_multiclass = self.mode == MULTICLASS_MODE

        # Filter entries equal to ignore_index from loss computation.
        not_ignored = None
        if self.ignore_index is not None:
            not_ignored = y_true != self.ignore_index

        class_losses = []
        for cls in range(num_classes):
            cls_y_pred = y_pred[:, cls, ...]
            if is_multiclass:
                # Multiclass targets have no class axis: one mask serves every class.
                cls_y_true = (y_true == cls).long()
                cls_not_ignored = not_ignored
            else:
                # Multilabel targets carry the class axis at dim 1, so the mask must
                # be sliced the same way to match cls_y_pred / cls_y_true.
                cls_y_true = y_true[:, cls, ...]
                cls_not_ignored = (
                    not_ignored[:, cls, ...] if not_ignored is not None else None
                )

            if cls_not_ignored is not None:
                cls_y_true = cls_y_true[cls_not_ignored]
                cls_y_pred = cls_y_pred[cls_not_ignored]

            class_losses.append(self.focal_loss_fn(cls_y_pred, cls_y_true))

        # NOTE(review): if focal_loss_fn returns unreduced losses
        # (reduction="none"), multilabel masks may keep different numbers of
        # elements per class, and torch.stack will reject the ragged results.
        return self._combine_class_losses(torch.stack(class_losses), num_classes)

    def _combine_class_losses(
        self, class_losses: torch.Tensor, num_classes: int
    ) -> torch.Tensor:
        """Average stacked per-class losses over the class axis (dim 0)."""
        if self.class_weights is None:
            return class_losses.mean(dim=0)

        weights = self.class_weights.to(class_losses.device)
        if weights.numel() != num_classes:
            raise ValueError(
                f"class_weights has {weights.numel()} entries, "
                f"but y_pred has {num_classes} classes"
            )
        # Shape (C, 1, ..., 1) so weights broadcast along the class axis only,
        # whether each class loss is a scalar or an unreduced tensor.
        weights = weights.reshape((-1,) + (1,) * (class_losses.dim() - 1))
        return (class_losses * weights).sum(dim=0) / weights.sum()