Source code for segmentation_models_pytorch.losses.soft_bce

from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn, Tensor

__all__ = ["SoftBCEWithLogitsLoss"]


class SoftBCEWithLogitsLoss(nn.Module):
    """Drop-in replacement for ``torch.nn.BCEWithLogitsLoss`` with two additions:
    ``ignore_index`` and label smoothing.

    Args:
        weight: Optional rescaling weight forwarded to
            ``F.binary_cross_entropy_with_logits``.
        ignore_index: Target value whose positions are zeroed out of the loss
            and therefore do not contribute to the input gradient. ``None``
            disables masking.
        reduction: ``"mean"`` or ``"sum"``. Any other value returns the
            unreduced, element-wise loss tensor.
        smooth_factor: Factor to smooth targets
            (e.g. if ``smooth_factor=0.1`` then ``[1, 0, 1] -> [0.9, 0.1, 0.9]``).
            ``None`` disables smoothing.
        pos_weight: Optional weight of positive examples forwarded to
            ``F.binary_cross_entropy_with_logits``.

    Shape
         - **y_pred** - torch.Tensor of logits, e.g. NxCxHxW
         - **y_true** - torch.Tensor with the same shape as ``y_pred``
           (``F.binary_cross_entropy_with_logits`` requires matching sizes);
           any numeric dtype is accepted.

    Note:
        With ``reduction="mean"`` the masked loss is averaged over *all*
        elements, including the ignored ones (they count as zero loss).

    Reference
        https://github.com/BloodAxe/pytorch-toolbelt
    """

    __constants__ = [
        "weight",
        "pos_weight",
        "reduction",
        "ignore_index",
        "smooth_factor",
    ]

    def __init__(
        self,
        weight: Optional[torch.Tensor] = None,
        ignore_index: Optional[int] = -100,
        reduction: str = "mean",
        smooth_factor: Optional[float] = None,
        pos_weight: Optional[torch.Tensor] = None,
    ):
        """Store the loss configuration; see the class docstring for the arguments."""
        super().__init__()
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.smooth_factor = smooth_factor
        # Buffers (even when None) so they follow the module across
        # .to()/.cuda() calls and are stored in the state dict.
        self.register_buffer("weight", weight)
        self.register_buffer("pos_weight", pos_weight)

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Compute the (optionally smoothed and masked) BCE-with-logits loss.

        Args:
            y_pred: torch.Tensor of logits.
            y_true: torch.Tensor of targets with the same shape as ``y_pred``.
                Integer masks are accepted and converted to the dtype of
                ``y_pred``.

        Returns:
            loss: scalar torch.Tensor for ``"mean"``/``"sum"`` reduction,
            otherwise the element-wise loss with the shape of ``y_pred``.
        """
        # Integer label masks (the natural carrier of ignore_index) cannot be
        # fed to binary_cross_entropy_with_logits directly, so match the
        # targets to the logits' dtype/device first.
        targets = y_true.type_as(y_pred)

        if self.smooth_factor is not None:
            targets = (1 - targets) * self.smooth_factor + targets * (1 - self.smooth_factor)

        loss = F.binary_cross_entropy_with_logits(
            y_pred,
            targets,
            self.weight,
            pos_weight=self.pos_weight,
            reduction="none",
        )

        if self.ignore_index is not None:
            # Build the mask from the untouched labels, not the smoothed ones.
            not_ignored_mask = y_true != self.ignore_index
            loss = loss * not_ignored_mask.type_as(loss)

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss