Source code for segmentation_models_pytorch.metrics.functional

"""Various metrics based on Type I and Type II errors.

References:
    https://en.wikipedia.org/wiki/Confusion_matrix


Example:

    .. code-block:: python

        import torch
        import segmentation_models_pytorch as smp

        # let's assume we have a multilabel prediction for 3 classes
        output = torch.rand([10, 3, 256, 256])
        target = torch.rand([10, 3, 256, 256]).round().long()

        # first compute statistics for true positive, false positive, false negative and
        # true negative "pixels"
        tp, fp, fn, tn = smp.metrics.get_stats(output, target, mode='multilabel', threshold=0.5)

        # then compute metrics with required reduction (see metric docs)
        iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
        f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")
        f2_score = smp.metrics.fbeta_score(tp, fp, fn, tn, beta=2, reduction="micro")
        accuracy = smp.metrics.accuracy(tp, fp, fn, tn, reduction="macro")
        recall = smp.metrics.recall(tp, fp, fn, tn, reduction="micro-imagewise")

"""
import torch
import warnings
from typing import Optional, List, Tuple, Union


__all__ = [
    "get_stats",
    "fbeta_score",
    "f1_score",
    "iou_score",
    "accuracy",
    "precision",
    "recall",
    "sensitivity",
    "specificity",
    "balanced_accuracy",
    "positive_predictive_value",
    "negative_predictive_value",
    "false_negative_rate",
    "false_positive_rate",
    "false_discovery_rate",
    "false_omission_rate",
    "positive_likelihood_ratio",
    "negative_likelihood_ratio",
]


###################################################################################################
# Statistics computation (true positives, false positives, false negatives, true negatives)
###################################################################################################


def get_stats(
    output: Union[torch.LongTensor, torch.FloatTensor],
    target: torch.LongTensor,
    mode: str,
    ignore_index: Optional[int] = None,
    threshold: Optional[Union[float, List[float]]] = None,
    num_classes: Optional[int] = None,
) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor]:
    """Compute true positive, false positive, false negative, true negative 'pixels'
    for each image and each class.

    Args:
        output (Union[torch.LongTensor, torch.FloatTensor]): Model output with following
            shapes and types depending on the specified ``mode``:

            'binary'
                shape (N, 1, ...) and ``torch.LongTensor`` or ``torch.FloatTensor``

            'multilabel'
                shape (N, C, ...) and ``torch.LongTensor`` or ``torch.FloatTensor``

            'multiclass'
                shape (N, ...) and ``torch.LongTensor``

        target (torch.LongTensor): Targets with following shapes depending on the specified ``mode``:

            'binary'
                shape (N, 1, ...)

            'multilabel'
                shape (N, C, ...)

            'multiclass'
                shape (N, ...)

        mode (str): One of ``'binary'`` | ``'multilabel'`` | ``'multiclass'``
        ignore_index (Optional[int]): Label to ignore for metric computation.
            **Not** supported for ``'binary'`` and ``'multilabel'`` modes. Defaults to None.
        threshold (Optional[float, List[float]]): Binarization threshold for ``output``
            in case of ``'binary'`` or ``'multilabel'`` modes. Defaults to None.
        num_classes (Optional[int]): Number of classes, necessary attribute
            only for ``'multiclass'`` mode. Class values should be in range 0..(num_classes - 1).
            If ``ignore_index`` is specified it should be outside the classes range, e.g. ``-1`` or
            ``255``.

    Raises:
        ValueError: in case of misconfiguration.

    Returns:
        Tuple[torch.LongTensor]: true_positive, false_positive, false_negative,
            true_negative tensors of shape (N, C) each.

    """
    if torch.is_floating_point(target):
        raise ValueError(f"Target should be one of the integer types, got {target.dtype}.")

    if torch.is_floating_point(output) and threshold is None:
        raise ValueError(
            f"Output should be one of the integer types if ``threshold`` is None, got {output.dtype}."
        )

    if torch.is_floating_point(output) and mode == "multiclass":
        raise ValueError(f"For ``multiclass`` mode ``output`` should be one of the integer types, got {output.dtype}.")

    if mode not in {"binary", "multiclass", "multilabel"}:
        raise ValueError(f"``mode`` should be in ['binary', 'multiclass', 'multilabel'], got mode={mode}.")

    if mode == "multiclass" and threshold is not None:
        raise ValueError("``threshold`` parameter is not supported for 'multiclass' mode.")

    if output.shape != target.shape:
        raise ValueError(
            "Dimensions should match, but ``output`` shape is not equal to ``target`` "
            + f"shape, {output.shape} != {target.shape}"
        )

    if mode != "multiclass" and ignore_index is not None:
        raise ValueError(f"``ignore_index`` parameter is not supported for '{mode}' mode")

    if mode == "multiclass" and num_classes is None:
        raise ValueError("``num_classes`` attribute should not be ``None`` for 'multiclass' mode.")

    if ignore_index is not None and 0 <= ignore_index <= num_classes - 1:
        raise ValueError(
            f"``ignore_index`` should be outside the class values range, but got class values in range "
            f"0..{num_classes - 1} and ``ignore_index={ignore_index}``. Hint: if you have ``ignore_index = 0`` "
            f"consider subtracting ``1`` from your target and model output to make ``ignore_index = -1`` "
            f"and relevant class values start from ``0``."
        )

    if mode == "multiclass":
        tp, fp, fn, tn = _get_stats_multiclass(output, target, num_classes, ignore_index)
    else:
        if threshold is not None:
            output = torch.where(output >= threshold, 1, 0)
            target = torch.where(target >= threshold, 1, 0)
        tp, fp, fn, tn = _get_stats_multilabel(output, target)

    return tp, fp, fn, tn
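
# Illustrative only (not part of the original module): a minimal sketch of calling
# ``get_stats`` in 'multiclass' mode, assuming 5 classes and ``-1`` as the ignore label,
# as described in the docstring above.
#
#     import torch
#     import segmentation_models_pytorch as smp
#
#     output = torch.randint(0, 5, (8, 128, 128))   # predicted class indices, shape (N, ...)
#     target = torch.randint(0, 5, (8, 128, 128))   # ground-truth class indices
#     target[:, :10, :10] = -1                      # pixels to exclude from the statistics
#     tp, fp, fn, tn = smp.metrics.get_stats(
#         output, target, mode="multiclass", num_classes=5, ignore_index=-1
#     )
#     # each of tp, fp, fn, tn has shape (N, num_classes) == (8, 5)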


@torch.no_grad()
def _get_stats_multiclass(
    output: torch.LongTensor,
    target: torch.LongTensor,
    num_classes: int,
    ignore_index: Optional[int],
) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor]:
    batch_size, *dims = output.shape
    num_elements = torch.prod(torch.tensor(dims)).long()

    if ignore_index is not None:
        # remap ignored pixels to -1 so they fall outside every class bin below
        ignore = target == ignore_index
        output = torch.where(ignore, -1, output)
        target = torch.where(ignore, -1, target)
        ignore_per_sample = ignore.view(batch_size, -1).sum(1)

    tp_count = torch.zeros(batch_size, num_classes, dtype=torch.long)
    fp_count = torch.zeros(batch_size, num_classes, dtype=torch.long)
    fn_count = torch.zeros(batch_size, num_classes, dtype=torch.long)
    tn_count = torch.zeros(batch_size, num_classes, dtype=torch.long)

    for i in range(batch_size):
        target_i = target[i]
        output_i = output[i]
        # pixels where prediction equals target keep their class index, the rest become -1
        # and are ignored by torch.histc, which only counts values in [0, num_classes - 1]
        mask = output_i == target_i
        matched = torch.where(mask, target_i, -1)
        tp = torch.histc(matched.float(), bins=num_classes, min=0, max=num_classes - 1)
        fp = torch.histc(output_i.float(), bins=num_classes, min=0, max=num_classes - 1) - tp
        fn = torch.histc(target_i.float(), bins=num_classes, min=0, max=num_classes - 1) - tp
        tn = num_elements - tp - fp - fn
        if ignore_index is not None:
            tn = tn - ignore_per_sample[i]
        tp_count[i] = tp.long()
        fp_count[i] = fp.long()
        fn_count[i] = fn.long()
        tn_count[i] = tn.long()

    return tp_count, fp_count, fn_count, tn_count


@torch.no_grad()
def _get_stats_multilabel(
    output: torch.LongTensor,
    target: torch.LongTensor,
) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor]:
    batch_size, num_classes, *dims = target.shape
    # flatten spatial dimensions and count per-class hits directly
    output = output.view(batch_size, num_classes, -1)
    target = target.view(batch_size, num_classes, -1)

    tp = (output * target).sum(2)
    fp = output.sum(2) - tp
    fn = target.sum(2) - tp
    tn = torch.prod(torch.tensor(dims)) - (tp + fp + fn)

    return tp, fp, fn, tn


###################################################################################################
# Metrics computation
###################################################################################################


def _handle_zero_division(x, zero_division):
    # replace NaNs produced by 0/0 with the requested ``zero_division`` value
    nans = torch.isnan(x)
    if torch.any(nans) and zero_division == "warn":
        warnings.warn("Zero division in metric calculation!")
    value = zero_division if zero_division != "warn" else 0
    value = torch.tensor(value, dtype=x.dtype).to(x.device)
    x = torch.where(nans, value, x)
    return x


def _compute_metric(
    metric_fn,
    tp,
    fp,
    fn,
    tn,
    reduction: Optional[str] = None,
    class_weights: Optional[List[float]] = None,
    zero_division="warn",
    **metric_kwargs,
) -> float:

    if class_weights is None and reduction is not None and "weighted" in reduction:
        raise ValueError(f"Class weights should be provided for `{reduction}` reduction")

    class_weights = class_weights if class_weights is not None else 1.0
    class_weights = torch.tensor(class_weights).to(tp.device)
    class_weights = class_weights / class_weights.sum()

    if reduction == "micro":
        tp = tp.sum()
        fp = fp.sum()
        fn = fn.sum()
        tn = tn.sum()
        score = metric_fn(tp, fp, fn, tn, **metric_kwargs)

    elif reduction == "macro":
        tp = tp.sum(0)
        fp = fp.sum(0)
        fn = fn.sum(0)
        tn = tn.sum(0)
        score = metric_fn(tp, fp, fn, tn, **metric_kwargs)
        score = _handle_zero_division(score, zero_division)
        score = (score * class_weights).mean()

    elif reduction == "weighted":
        tp = tp.sum(0)
        fp = fp.sum(0)
        fn = fn.sum(0)
        tn = tn.sum(0)
        score = metric_fn(tp, fp, fn, tn, **metric_kwargs)
        score = _handle_zero_division(score, zero_division)
        score = (score * class_weights).sum()

    elif reduction == "micro-imagewise":
        tp = tp.sum(1)
        fp = fp.sum(1)
        fn = fn.sum(1)
        tn = tn.sum(1)
        score = metric_fn(tp, fp, fn, tn, **metric_kwargs)
        score = _handle_zero_division(score, zero_division)
        score = score.mean()

    elif reduction == "macro-imagewise" or reduction == "weighted-imagewise":
        score = metric_fn(tp, fp, fn, tn, **metric_kwargs)
        score = _handle_zero_division(score, zero_division)
        score = (score.mean(0) * class_weights).mean()

    elif reduction == "none" or reduction is None:
        score = metric_fn(tp, fp, fn, tn, **metric_kwargs)
        score = _handle_zero_division(score, zero_division)

    else:
        raise ValueError(
            "`reduction` should be in [micro, macro, weighted, micro-imagewise, "
            + "macro-imagewise, weighted-imagewise, none, None]"
        )

    return score


# Logic for metric computation; all metrics share the same interface


def _fbeta_score(tp, fp, fn, tn, beta=1):
    beta_tp = (1 + beta**2) * tp
    beta_fn = (beta**2) * fn
    score = beta_tp / (beta_tp + beta_fn + fp)
    return score


def _iou_score(tp, fp, fn, tn):
    return tp / (tp + fp + fn)


def _accuracy(tp, fp, fn, tn):
    return (tp + tn) / (tp + fp + fn + tn)


def _sensitivity(tp, fp, fn, tn):
    return tp / (tp + fn)


def _specificity(tp, fp, fn, tn):
    return tn / (tn + fp)


def _balanced_accuracy(tp, fp, fn, tn):
    return (_sensitivity(tp, fp, fn, tn) + _specificity(tp, fp, fn, tn)) / 2


def _positive_predictive_value(tp, fp, fn, tn):
    return tp / (tp + fp)


def _negative_predictive_value(tp, fp, fn, tn):
    return tn / (tn + fn)


def _false_negative_rate(tp, fp, fn, tn):
    return fn / (fn + tp)


def _false_positive_rate(tp, fp, fn, tn):
    return fp / (fp + tn)


def _false_discovery_rate(tp, fp, fn, tn):
    return 1 - _positive_predictive_value(tp, fp, fn, tn)


def _false_omission_rate(tp, fp, fn, tn):
    return 1 - _negative_predictive_value(tp, fp, fn, tn)


def _positive_likelihood_ratio(tp, fp, fn, tn):
    return _sensitivity(tp, fp, fn, tn) / _false_positive_rate(tp, fp, fn, tn)


def _negative_likelihood_ratio(tp, fp, fn, tn):
    return _false_negative_rate(tp, fp, fn, tn) / _specificity(tp, fp, fn, tn)


def fbeta_score(
    tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor,
    beta: float = 1.0,
    reduction: Optional[str] = None,
    class_weights: Optional[List[float]] = None,
    zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
    """F beta score"""
    return _compute_metric(
        _fbeta_score, tp, fp, fn, tn,
        beta=beta, reduction=reduction, class_weights=class_weights, zero_division=zero_division,
    )


def f1_score(
    tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor,
    reduction: Optional[str] = None,
    class_weights: Optional[List[float]] = None,
    zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
    """F1 score"""
    return _compute_metric(
        _fbeta_score, tp, fp, fn, tn,
        beta=1.0, reduction=reduction, class_weights=class_weights, zero_division=zero_division,
    )


def iou_score(
    tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor,
    reduction: Optional[str] = None,
    class_weights: Optional[List[float]] = None,
    zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
    """IoU score or Jaccard index"""  # noqa
    return _compute_metric(
        _iou_score, tp, fp, fn, tn,
        reduction=reduction, class_weights=class_weights, zero_division=zero_division,
    )


def accuracy(
    tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor,
    reduction: Optional[str] = None,
    class_weights: Optional[List[float]] = None,
    zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
    """Accuracy"""
    return _compute_metric(
        _accuracy, tp, fp, fn, tn,
        reduction=reduction, class_weights=class_weights, zero_division=zero_division,
    )


def sensitivity(
    tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor,
    reduction: Optional[str] = None,
    class_weights: Optional[List[float]] = None,
    zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
    """Sensitivity, recall, hit rate, or true positive rate (TPR)"""
    return _compute_metric(
        _sensitivity, tp, fp, fn, tn,
        reduction=reduction, class_weights=class_weights, zero_division=zero_division,
    )


def specificity(
    tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor,
    reduction: Optional[str] = None,
    class_weights: Optional[List[float]] = None,
    zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
    """Specificity, selectivity or true negative rate (TNR)"""
    return _compute_metric(
        _specificity, tp, fp, fn, tn,
        reduction=reduction, class_weights=class_weights, zero_division=zero_division,
    )


def balanced_accuracy(
    tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor,
    reduction: Optional[str] = None,
    class_weights: Optional[List[float]] = None,
    zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
    """Balanced accuracy"""
    return _compute_metric(
        _balanced_accuracy, tp, fp, fn, tn,
        reduction=reduction, class_weights=class_weights, zero_division=zero_division,
    )


def positive_predictive_value(
    tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor,
    reduction: Optional[str] = None,
    class_weights: Optional[List[float]] = None,
    zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
    """Precision or positive predictive value (PPV)"""
    return _compute_metric(
        _positive_predictive_value, tp, fp, fn, tn,
        reduction=reduction, class_weights=class_weights, zero_division=zero_division,
    )


def negative_predictive_value(
    tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor,
    reduction: Optional[str] = None,
    class_weights: Optional[List[float]] = None,
    zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
    """Negative predictive value (NPV)"""
    return _compute_metric(
        _negative_predictive_value, tp, fp, fn, tn,
        reduction=reduction, class_weights=class_weights, zero_division=zero_division,
    )


def false_negative_rate(
    tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor,
    reduction: Optional[str] = None,
    class_weights: Optional[List[float]] = None,
    zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
    """Miss rate or false negative rate (FNR)"""
    return _compute_metric(
        _false_negative_rate, tp, fp, fn, tn,
        reduction=reduction, class_weights=class_weights, zero_division=zero_division,
    )


def false_positive_rate(
    tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor,
    reduction: Optional[str] = None,
    class_weights: Optional[List[float]] = None,
    zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
    """Fall-out or false positive rate (FPR)"""
    return _compute_metric(
        _false_positive_rate, tp, fp, fn, tn,
        reduction=reduction, class_weights=class_weights, zero_division=zero_division,
    )


def false_discovery_rate(
    tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor,
    reduction: Optional[str] = None,
    class_weights: Optional[List[float]] = None,
    zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
    """False discovery rate (FDR)"""  # noqa
    return _compute_metric(
        _false_discovery_rate, tp, fp, fn, tn,
        reduction=reduction, class_weights=class_weights, zero_division=zero_division,
    )


def false_omission_rate(
    tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor,
    reduction: Optional[str] = None,
    class_weights: Optional[List[float]] = None,
    zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
    """False omission rate (FOR)"""  # noqa
    return _compute_metric(
        _false_omission_rate, tp, fp, fn, tn,
        reduction=reduction, class_weights=class_weights, zero_division=zero_division,
    )


def positive_likelihood_ratio(
    tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor,
    reduction: Optional[str] = None,
    class_weights: Optional[List[float]] = None,
    zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
    """Positive likelihood ratio (LR+)"""
    return _compute_metric(
        _positive_likelihood_ratio, tp, fp, fn, tn,
        reduction=reduction, class_weights=class_weights, zero_division=zero_division,
    )


def negative_likelihood_ratio(
    tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor,
    reduction: Optional[str] = None,
    class_weights: Optional[List[float]] = None,
    zero_division: Union[str, float] = 1.0,
) -> torch.Tensor:
    """Negative likelihood ratio (LR-)"""
    return _compute_metric(
        _negative_likelihood_ratio, tp, fp, fn, tn,
        reduction=reduction, class_weights=class_weights, zero_division=zero_division,
    )
_doc = """ Args: tp (torch.LongTensor): tensor of shape (N, C), true positive cases fp (torch.LongTensor): tensor of shape (N, C), false positive cases fn (torch.LongTensor): tensor of shape (N, C), false negative cases tn (torch.LongTensor): tensor of shape (N, C), true negative cases reduction (Optional[str]): Define how to aggregate metric between classes and images: - 'micro' Sum true positive, false positive, false negative and true negative pixels over all images and all classes and then compute score. - 'macro' Sum true positive, false positive, false negative and true negative pixels over all images for each label, then compute score for each label separately and average labels scores. This does not take label imbalance into account. - 'weighted' Sum true positive, false positive, false negative and true negative pixels over all images for each label, then compute score for each label separately and average weighted labels scores. - 'micro-imagewise' Sum true positive, false positive, false negative and true negative pixels for **each image**, then compute score for **each image** and average scores over dataset. All images contribute equally to final score, however takes into accout class imbalance for each image. - 'macro-imagewise' Compute score for each image and for each class on that image separately, then compute average score on each image over labels and average image scores over dataset. Does not take into account label imbalance on each image. - 'weighted-imagewise' Compute score for each image and for each class on that image separately, then compute weighted average score on each image over labels and average image scores over dataset. - 'none' or ``None`` Same as ``'macro-imagewise'``, but without any reduction. For ``'binary'`` case ``'micro' = 'macro' = 'weighted'`` and ``'micro-imagewise' = 'macro-imagewise' = 'weighted-imagewise'``. Prefixes ``'micro'``, ``'macro'`` and ``'weighted'`` define how the scores for classes will be aggregated, while postfix ``'imagewise'`` defines how scores between the images will be aggregated. class_weights (Optional[List[float]]): list of class weights for metric aggregation, in case of `weighted*` reduction is chosen. Defaults to None. zero_division (Union[str, float]): Sets the value to return when there is a zero division, i.e. when all predictions and labels are negative. If set to β€œwarn”, this acts as 0, but warnings are also raised. Defaults to 1. Returns: torch.Tensor: if ``'reduction'`` is not ``None`` or ``'none'`` returns scalar metric, else returns tensor of shape (N, C) References: https://en.wikipedia.org/wiki/Confusion_matrix """ fbeta_score.__doc__ += _doc f1_score.__doc__ += _doc iou_score.__doc__ += _doc accuracy.__doc__ += _doc sensitivity.__doc__ += _doc specificity.__doc__ += _doc balanced_accuracy.__doc__ += _doc positive_predictive_value.__doc__ += _doc negative_predictive_value.__doc__ += _doc false_negative_rate.__doc__ += _doc false_positive_rate.__doc__ += _doc false_discovery_rate.__doc__ += _doc false_omission_rate.__doc__ += _doc positive_likelihood_ratio.__doc__ += _doc negative_likelihood_ratio.__doc__ += _doc precision = positive_predictive_value recall = sensitivity