from typing import Optional, List
import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from ._functional import soft_dice_score, to_tensor
from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE
# Public API of this module.
__all__ = ["DiceLoss"]
class DiceLoss(_Loss):
    """Dice loss for image segmentation (binary, multiclass and multilabel modes).

    Constructor arguments are documented on ``__init__``.
    """

    def __init__(
        self,
        mode: str,
        classes: Optional[List[int]] = None,
        log_loss: bool = False,
        from_logits: bool = True,
        smooth: float = 0.0,
        ignore_index: Optional[int] = None,
        eps: float = 1e-7,
        class_weights: Optional[List[float]] = None,
    ):
        """Dice loss for image segmentation task.

        It supports binary, multiclass and multilabel cases.

        Args:
            mode: Loss mode 'binary', 'multiclass' or 'multilabel'
            classes: List of classes that contribute to the loss computation. By default,
                all channels are included. Not supported with mode='binary'.
            log_loss: If True, loss is computed as ``-log(dice_coeff)``, otherwise
                ``1 - dice_coeff``
            from_logits: If True, assumes input is raw logits and converts it to
                probabilities (softmax over channels for multiclass, sigmoid otherwise)
            smooth: Smoothness constant passed to the dice score computation
            ignore_index: Label that indicates ignored pixels (does not contribute to loss)
            eps: A small epsilon for numerical stability to avoid zero division error
                (denominator will be always greater or equal to eps)
            class_weights: List of weights for each class. If not ``None``, the loss for
                each class is multiplied by the corresponding weight and the result is
                divided by the sum of the (selected) weights. Only supported for
                multiclass and multilabel modes. Weights do not need to be normalized.

        Shape
            - **y_pred** - torch.Tensor of shape (N, C, H, W)
            - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W)

        Raises:
            AssertionError: if ``mode`` is unknown, or ``classes`` is given with binary mode.
            ValueError: if ``class_weights`` is given with binary mode.

        Reference
            https://github.com/BloodAxe/pytorch-toolbelt
        """
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super().__init__()
        self.mode = mode

        if class_weights is not None and mode == BINARY_MODE:
            raise ValueError("class_weights are not supported with mode=binary")

        if classes is not None:
            assert mode != BINARY_MODE, (
                "Masking classes is not supported with mode=binary"
            )
            classes = to_tensor(classes, dtype=torch.long)

        self.classes = classes
        self.from_logits = from_logits
        self.smooth = smooth
        self.eps = eps
        self.log_loss = log_loss
        self.ignore_index = ignore_index
        self.class_weights = (
            to_tensor(class_weights, dtype=torch.float)
            if class_weights is not None
            else None
        )

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Compute the dice loss for a batch.

        Args:
            y_pred: Predictions, see the Shape section of ``__init__``. Raw logits
                when ``from_logits`` is True, probabilities otherwise.
            y_true: Targets. Class indices per pixel in multiclass mode, a per-channel
                mask in binary/multilabel mode.

        Returns:
            Scalar loss produced by ``aggregate_loss``.
        """
        assert y_true.size(0) == y_pred.size(0)

        if self.from_logits:
            # Apply activations to get [0..1] class probabilities.
            # Log-Exp gives a more numerically stable result and does not cause
            # vanishing gradients on the extreme values 0 and 1.
            if self.mode == MULTICLASS_MODE:
                y_pred = y_pred.log_softmax(dim=1).exp()
            else:
                y_pred = F.logsigmoid(y_pred).exp()

        bs = y_true.size(0)
        num_classes = y_pred.size(1)
        # Reduce over batch (0) and flattened spatial (2) axes, keeping one value per class.
        dims = (0, 2)

        if self.mode == BINARY_MODE:
            y_true = y_true.reshape(bs, 1, -1)
            y_pred = y_pred.reshape(bs, 1, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask
                y_true = y_true * mask

        elif self.mode == MULTICLASS_MODE:
            y_true = y_true.reshape(bs, -1)
            y_pred = y_pred.reshape(bs, num_classes, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask.unsqueeze(1)
                # Replace ignored labels by 0 so one_hot receives valid indices,
                # then zero those pixels again with the mask.
                y_true = F.one_hot(
                    (y_true * mask).to(torch.long), num_classes
                )  # N, H*W -> N, H*W, C
                y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1)  # N, C, H*W
            else:
                # one_hot only accepts int64 indices; cast so uint8/int32 label maps work too.
                y_true = F.one_hot(
                    y_true.to(torch.long), num_classes
                )  # N, H*W -> N, H*W, C
                y_true = y_true.permute(0, 2, 1)  # N, C, H*W

        elif self.mode == MULTILABEL_MODE:
            y_true = y_true.reshape(bs, num_classes, -1)
            y_pred = y_pred.reshape(bs, num_classes, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask
                y_true = y_true * mask

        scores = self.compute_score(
            y_pred, y_true.type_as(y_pred), smooth=self.smooth, eps=self.eps, dims=dims
        )

        if self.log_loss:
            loss = -torch.log(scores.clamp_min(self.eps))
        else:
            loss = 1.0 - scores

        # Dice loss is undefined for empty images with no classes,
        # so we set the contribution of any channel without true pixels to zero.
        # NOTE: A better workaround would be to use loss term `mean(y_pred)`
        # for this case, however it will be a modified jaccard loss.
        mask = y_true.sum(dims) > 0
        # Out-of-place multiply: avoids mutating a tensor that is part of the autograd graph.
        loss = loss * mask.to(loss.dtype)

        if self.classes is not None:
            loss = loss[self.classes]

        return self.aggregate_loss(loss)

    def aggregate_loss(self, loss: torch.Tensor) -> torch.Tensor:
        """Aggregate per-class losses into a single scalar.

        Args:
            loss: Per-class loss tensor of shape (C,), already restricted to
                ``self.classes`` when that filter is set.

        Returns:
            Scalar loss value: the weighted mean ``sum(loss * w) / sum(w)`` when
            ``class_weights`` is set, otherwise the plain mean.

        Raises:
            ValueError: if the (selected) class weights do not have the same shape
                as ``loss``. Without this check, broadcasting would silently produce
                a wrong value, e.g. a single weight turns the mean into a sum.
        """
        if self.class_weights is None:
            return loss.mean()

        weights = self.class_weights.to(loss.device)
        # If classes filter is applied, slice weights accordingly.
        if self.classes is not None:
            weights = weights[self.classes]

        if weights.shape != loss.shape:
            raise ValueError(
                f"class_weights has {weights.numel()} entries (after class selection) "
                f"but the loss has {loss.numel()} classes"
            )

        return (loss * weights).sum() / weights.sum()

    def compute_score(
        self,
        output: torch.Tensor,
        target: torch.Tensor,
        smooth: float = 0.0,
        eps: float = 1e-7,
        dims=None,
    ) -> torch.Tensor:
        """Compute the dice score between ``output`` and ``target``.

        Delegates to ``soft_dice_score``. ``forward`` calls this through ``self``,
        so a subclass can override it to use a different overlap score.
        The result is expected to hold one score per class, because ``forward``
        multiplies it by a per-class mask.
        """
        return soft_dice_score(output, target, smooth, eps, dims)