# Source code for segmentation_models_pytorch.losses.jaccard

from typing import Optional, List

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from ._functional import soft_jaccard_score, to_tensor
from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE

__all__ = ["JaccardLoss"]


class JaccardLoss(_Loss):
    """Soft Jaccard (IoU) loss for binary, multiclass and multilabel segmentation."""

    def __init__(
        self,
        mode: str,
        classes: Optional[List[int]] = None,
        log_loss: bool = False,
        from_logits: bool = True,
        smooth: float = 0.0,
        ignore_index: Optional[int] = None,
        eps: float = 1e-7,
        class_weights: Optional[List[float]] = None,
    ):
        """Jaccard loss for image segmentation task.
        It supports binary, multiclass and multilabel cases

        Args:
            mode: Loss mode 'binary', 'multiclass' or 'multilabel'
            classes: List of classes that contribute in loss computation.
                By default, all channels are included. Not supported with mode='binary'.
            log_loss: If True, loss computed as `- log(jaccard_coeff)`, otherwise `1 - jaccard_coeff`
            from_logits: If True, assumes input is raw logits and converts it to probabilities
                (softmax over the channel dim for 'multiclass', element-wise sigmoid otherwise)
            smooth: Smoothness constant for the Jaccard coefficient
            ignore_index: Target label value to ignore. Positions where ``y_true`` equals this
                value are zeroed in both the prediction and the target before scoring.
            eps: A small epsilon for numerical stability to avoid zero division error
                (denominator will be always greater or equal to eps)
            class_weights: List of weights for each class. If not ``None``, the loss for each
                class is multiplied by the corresponding weight and the result is divided by the
                sum of the (selected) weights. Only supported for multiclass and multilabel modes.
                Weights do not need to be normalized.

        Shape
             - **y_pred** - torch.Tensor of shape (N, C, H, W)
             - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W)

        Raises:
            ValueError: If ``class_weights`` is given with mode='binary'.
            AssertionError: If ``mode`` is not one of the supported modes, or if ``classes``
                is given with mode='binary'.

        Reference
            https://github.com/BloodAxe/pytorch-toolbelt
        """
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super().__init__()

        self.mode = mode
        if class_weights is not None and mode == BINARY_MODE:
            raise ValueError("class_weights are not supported with mode=binary")
        if classes is not None:
            assert mode != BINARY_MODE, "Masking classes is not supported with mode=binary"
            classes = to_tensor(classes, dtype=torch.long)

        self.classes = classes
        self.from_logits = from_logits
        self.smooth = smooth
        self.ignore_index = ignore_index
        self.eps = eps
        self.log_loss = log_loss
        self.class_weights = (
            to_tensor(class_weights, dtype=torch.float) if class_weights is not None else None
        )

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Compute the Jaccard loss for a batch.

        Args:
            y_pred: Predictions of shape (N, C, H, W): logits when ``from_logits`` is True,
                probabilities otherwise. Trailing spatial dims are flattened before scoring.
            y_true: Targets. 'binary': (N, H, W) or (N, 1, H, W); 'multiclass': (N, H, W)
                class indices (cast to int64 before one-hot encoding); 'multilabel': (N, C, H, W).

        Returns:
            Scalar tensor: the mean of the per-class losses, or their ``class_weights``-weighted
            mean when weights were given. Classes with no positive target pixels in the batch
            contribute a loss of zero.
        """
        assert y_true.size(0) == y_pred.size(0)

        if self.from_logits:
            # Apply activations to get [0..1] class probabilities
            # Using Log-Exp as this gives more numerically stable result and does not cause vanishing gradient on
            # extreme values 0 and 1
            if self.mode == MULTICLASS_MODE:
                y_pred = y_pred.log_softmax(dim=1).exp()
            else:
                y_pred = F.logsigmoid(y_pred).exp()

        bs = y_true.size(0)
        num_classes = y_pred.size(1)
        # After the reshapes below tensors are (N, C, H*W); reducing over (0, 2) leaves one
        # value per channel (class).
        dims = (0, 2)

        if self.mode == BINARY_MODE:
            y_true = y_true.reshape(bs, 1, -1)
            y_pred = y_pred.reshape(bs, 1, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask
                y_true = y_true * mask

        elif self.mode == MULTICLASS_MODE:
            y_true = y_true.reshape(bs, -1)
            y_pred = y_pred.reshape(bs, num_classes, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask.unsqueeze(1)

                # Ignored positions are temporarily mapped to class 0 so one_hot accepts them,
                # then the mask wipes their one-hot row back to all zeros.
                y_true = F.one_hot((y_true * mask).to(torch.long), num_classes)  # N,H*W -> N,H*W, C
                y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1)  # N, C, H*W
            else:
                # one_hot only accepts int64 indices; masks are often stored as uint8/int32.
                y_true = F.one_hot(y_true.to(torch.long), num_classes)  # N,H*W -> N,H*W, C
                y_true = y_true.permute(0, 2, 1)  # N, C, H*W

        elif self.mode == MULTILABEL_MODE:
            y_true = y_true.reshape(bs, num_classes, -1)
            y_pred = y_pred.reshape(bs, num_classes, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask
                y_true = y_true * mask

        scores = soft_jaccard_score(
            y_pred,
            y_true.type(y_pred.dtype),
            smooth=self.smooth,
            eps=self.eps,
            dims=dims,
        )

        if self.log_loss:
            loss = -torch.log(scores.clamp_min(self.eps))
        else:
            loss = 1.0 - scores

        # IoU loss is defined for non-empty classes
        # So we zero contribution of channel that does not have true pixels
        # NOTE: A better workaround would be to use loss term `mean(y_pred)`
        # for this case, however it will be a modified jaccard loss
        present = y_true.sum(dims) > 0
        loss = loss * present.to(loss.dtype)

        if self.classes is not None:
            loss = loss[self.classes]

        if self.class_weights is not None:
            weights = self.class_weights.to(loss.device)
            if self.classes is not None:
                weights = weights[self.classes]
            return (loss * weights).sum() / weights.sum()

        return loss.mean()