Source code for ucs.utils.metrics

# Module: ucs.utils.metrics
from typing import Literal, Optional

import numpy as np
import torch
from torch import nn
from torchmetrics import (
    Accuracy,
    Dice,
    JaccardIndex,
    MetricCollection,
    Precision,
    Recall,
)
from tqdm import tqdm

from ucs.core.errors import LossWeightsSizeError, LossWeightsTypeError, NormalizeError


class DiceLoss(nn.Module):
    """
    Soft multi-class Dice loss computed from per-class probabilities.

    Args:
        ignore_index (int, optional): Target value whose pixels are excluded
            from both the intersection and the union. It may lie outside
            ``[0, C)`` (e.g. 255 or -100). Default is None.
        smooth (float, optional): Constant added to numerator and denominator
            so that classes absent from both prediction and target yield a
            Dice score of 1 instead of 0/0. Default is 1e-6.
    """

    def __init__(self, ignore_index=None, smooth=1e-6):
        super().__init__()
        self.ignore_index = ignore_index
        self.smooth = smooth

    def forward(self, inputs, targets):
        """
        Compute ``1 - mean Dice`` over classes and batch.

        Args:
            inputs: Tensor of shape (B, C, H, W), probabilities
            targets: Tensor of shape (B, H, W), class indices

        Returns:
            torch.Tensor: Scalar loss, averaged over classes and then batch.
        """
        num_classes = inputs.shape[1]

        if self.ignore_index is not None:
            valid_mask = targets != self.ignore_index
            # one_hot() rejects values outside [0, num_classes), so ignored
            # pixels are remapped to class 0 before encoding. They are zeroed
            # by the mask below and therefore never contribute.
            encodable_targets = targets.masked_fill(~valid_mask, 0)
        else:
            valid_mask = None
            encodable_targets = targets

        # One-hot encode the target to match the shape of inputs (B, C, H, W)
        target_one_hot = (
            torch.nn.functional.one_hot(
                encodable_targets.long(), num_classes=num_classes
            )
            .permute(0, 3, 1, 2)
            .float()
        )

        if valid_mask is not None:
            valid_mask = valid_mask.unsqueeze(1)  # Shape: (B, 1, H, W)
            inputs = inputs * valid_mask  # Mask probabilities
            target_one_hot = target_one_hot * valid_mask  # Mask target

        # Sum over H and W; "union" here is the sum of both masses, as used
        # by the Dice coefficient (not a set union).
        intersection = (inputs * target_one_hot).sum(dim=(2, 3))
        union = inputs.sum(dim=(2, 3)) + target_one_hot.sum(dim=(2, 3))

        # Dice coefficient per (sample, class), averaged across classes
        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
        mean_dice = dice.mean(dim=1)

        return 1 - mean_dice.mean()  # Average across batch
# CeDiceLoss class
class CeDiceLoss(nn.Module):
    """
    A combined loss function that incorporates Cross-Entropy Loss (CE) and
    Dice Loss for semantic segmentation tasks.

    The total loss is ``alpha * CE + beta * Dice`` where ``beta`` starts as
    ``1 - alpha`` and can be changed later through :meth:`set_stage`.

    Args:
        num_classes (int): The total number of classes in the segmentation task.
        alpha (float, optional): Weight for Cross-Entropy Loss in the combined
            loss. The Dice weight is initialised to ``1 - alpha``.
            Default is 0.7.
        weights (list, np.ndarray, torch.Tensor, optional): Class weights for
            the CE loss. If provided, it must have one entry per class.
        ignore_index (int, optional): Class index to ignore during loss
            computation. It is forwarded to both the CE and the Dice loss.
            Default is None.
        reduction (str, optional): Reduction applied to the output of the CE
            loss. Must be one of "none", "mean" (default), or "sum".
        label_smoothing (float, optional): Label smoothing passed to the CE
            loss. Default is 0.1.

    Attributes:
        alpha (float): Weight for the CE loss component.
        beta (float): Weight for the Dice loss component.
        num_classes (int): Number of classes in the segmentation task.
        ignore_index (int or None): Class index to ignore during loss computation.
        weights (torch.Tensor or None): Tensor containing class weights for CE loss.
        ce_loss (torch.nn.CrossEntropyLoss): Cross-Entropy loss module.
        dice_loss (DiceLoss): Dice loss for multi-class segmentation.
    """

    def __init__(
        self,
        num_classes,
        alpha=0.7,
        weights=None,
        ignore_index=None,
        reduction="mean",
        label_smoothing=0.1,
    ):
        super().__init__()
        self.alpha = alpha
        self.beta = 1 - alpha
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.weights = self.initialize_weights(weights, num_classes)
        self.ce_loss = nn.CrossEntropyLoss(
            label_smoothing=label_smoothing,
            # -100 is CrossEntropyLoss's own "nothing ignored" default.
            ignore_index=ignore_index if ignore_index is not None else -100,
            reduction=reduction,
            weight=self.weights,
        )
        self.dice_loss = DiceLoss(ignore_index=ignore_index)

    def set_stage(self, alpha, beta):
        """
        Replace the CE and Dice mixing weights.

        Args:
            alpha (float): New weight for the CE loss component.
            beta (float): New weight for the Dice loss component.
        """
        self.alpha = alpha
        self.beta = beta

    def forward(self, inputs, targets):
        """
        Compute the weighted sum of the Cross-Entropy loss and the Dice loss.

        The CE loss is computed on the raw logits; the Dice loss is computed
        on ``inputs.softmax(dim=1)``.

        Args:
            inputs (torch.Tensor): The model's raw output logits (before softmax).
                Shape: (batch_size, num_classes, height, width) for 2D inputs.
            targets (torch.Tensor): The ground truth labels for comparison.
                Shape: (batch_size, height, width) for 2D inputs.

        Returns:
            torch.Tensor: ``alpha * ce_loss + beta * dice_loss``.
        """
        ce_loss = self.ce_loss(inputs, targets)
        dice_loss = self.dice_loss(inputs.softmax(dim=1), targets)

        return self.alpha * ce_loss + self.beta * dice_loss

    def initialize_weights(self, weights, num_classes):
        """
        Initialize and validate weights for a loss function.

        Args:
            weights (list, np.ndarray, torch.Tensor, optional): Weighting factor(s) for each class.
            num_classes (int): The total number of classes for the loss function.

        Returns:
            torch.Tensor or None: A 1-D floating-point tensor of weights for
            all classes, or None when no weights were given.

        Raises:
            LossWeightsTypeError: If the `weights` argument is of an unsupported type.
            LossWeightsSizeError: If the size of `weights` does not match `num_classes`.
        """
        if weights is None:
            return None  # No weights provided

        if not isinstance(weights, torch.Tensor):
            try:
                weights = torch.tensor(weights, dtype=torch.float)
            except (TypeError, ValueError, RuntimeError) as exc:
                # torch.tensor raises TypeError/RuntimeError (not only
                # ValueError) for unsupported or non-numeric inputs.
                raise LossWeightsTypeError(type(weights)) from exc
        elif not weights.is_floating_point():
            # CrossEntropyLoss needs floating-point class weights.
            weights = weights.float()

        # Validate shape: exactly one weight per class
        if weights.dim() != 1:
            raise LossWeightsSizeError(weights.numel(), num_classes)
        if weights.size(0) != num_classes:
            raise LossWeightsSizeError(weights.size(0), num_classes)

        return weights
class FocalLoss(nn.CrossEntropyLoss):
    """
    Focal Loss for addressing class imbalance in classification tasks.

    The per-pixel loss is ``alpha[target] * (1 - pt) ** gamma * CE`` where
    ``pt = exp(-CE)``.

    Args:
        num_classes (int): Number of classes.
        gamma (float, optional): Focusing parameter. Defaults to 2.0.
        alpha (float, list, np.ndarray, torch.Tensor, optional): Weighting factor for each class.
            A scalar (or single-element sequence) is applied to every class.
            Defaults to None, which weights every class with 1.
        ignore_index (int, optional): Specifies a target value that is ignored.
            Defaults to None.
        reduction (str, optional): Specifies the reduction to apply to the output.
            Defaults to 'mean'.
    """

    def __init__(
        self, num_classes, gamma=2.0, alpha=None, ignore_index=None, reduction="mean"
    ):
        super().__init__(
            # -100 is CrossEntropyLoss's own "nothing ignored" default.
            ignore_index=ignore_index if ignore_index is not None else -100,
            reduction=reduction,
        )
        self.gamma = gamma
        self.num_classes = num_classes
        self.alpha = self._set_alpha(alpha)

    def forward(self, inputs, target):
        """
        Forward pass of the loss calculation.

        Args:
            inputs (torch.Tensor): Predicted logits of shape (N, C, H, W)
                where C is the number of classes
            target (torch.Tensor): Ground truth labels of shape (N, H, W)

        Returns:
            torch.Tensor: Computed loss. A scalar for "mean" and "sum";
            otherwise a 1-D tensor with one entry per non-ignored pixel.
        """
        # The focal modulation needs the *unreduced* cross entropy. Calling
        # super().forward() would apply self.reduction first and collapse the
        # per-pixel values into a single scalar.
        cross_entropy = torch.nn.functional.cross_entropy(
            inputs,
            target,
            weight=self.weight,
            ignore_index=self.ignore_index,
            reduction="none",
            label_smoothing=self.label_smoothing,
        )

        # Probability assigned to the true class
        pt = torch.exp(-cross_entropy)

        # Apply focal scaling
        focal_loss = (1 - pt) ** self.gamma * cross_entropy

        valid_mask = target != self.ignore_index

        if self.alpha is not None:
            self.alpha = self.alpha.to(inputs.device)
            # Ignored targets may lie outside [0, C); index with class 0 for
            # them instead. Those pixels are dropped by valid_mask below.
            lookup_target = target.masked_fill(~valid_mask, 0)
            focal_loss = self.alpha[lookup_target] * focal_loss

        # Keep only pixels that are not ignored
        focal_loss = focal_loss[valid_mask]

        # Apply reduction
        if self.reduction == "mean":
            if focal_loss.numel() == 0:
                # Every pixel was ignored: return 0 instead of mean([]) = NaN.
                return focal_loss.sum()
            return focal_loss.mean()
        if self.reduction == "sum":
            return focal_loss.sum()
        # 'none'
        return focal_loss

    def _set_alpha(self, alpha):
        """
        Set the alpha value for class weighting.

        Args:
            alpha (float, list, np.ndarray, torch.Tensor, optional): Weighting factor for each class.

        Returns:
            torch.Tensor: 1-D float tensor with ``num_classes`` entries.

        Raises:
            TypeError: If the provided alpha is not of type float, int, list,
                np.ndarray, or torch.Tensor.
            ValueError: If alpha has more than one entry but not exactly
                ``num_classes`` entries.
        """
        if alpha is None:
            return torch.ones(self.num_classes)
        if isinstance(alpha, (float, int)):
            # float() keeps an integer alpha from producing an integer tensor.
            return torch.full((self.num_classes,), float(alpha))
        if isinstance(alpha, (list, np.ndarray, torch.Tensor)):
            alpha = torch.as_tensor(alpha, dtype=torch.float).flatten()
            if alpha.numel() == 1:
                # A single value is shared by every class.
                return alpha.expand(self.num_classes).clone()
            if alpha.numel() != self.num_classes:
                raise ValueError(
                    f"alpha has {alpha.numel()} entries but num_classes is "
                    f"{self.num_classes}"
                )
            return alpha
        raise TypeError(f"Unsupported alpha type: {type(alpha)}")
class SegMetrics(MetricCollection):
    """
    A utility class to handle metrics for segmentation tasks.

    Bundles a macro-averaged multiclass Jaccard index (registered as
    ``"mean_iou"``) and a macro-averaged Dice score (registered as
    ``"mean_dice"``) in one metric collection.

    Args:
        num_classes (int): Number of classes in the segmentation task.
        device (str, optional): Device to run the metrics on. Default is "cpu".
        ignore_index (int, optional): Target value passed to every metric as
            its ``ignore_index``. Default is None.

    Attributes:
        num_classes (int): The number of classes.
        ignore_index (int or None): The index to ignore in the target.
    """

    def __init__(self, num_classes, device="cpu", ignore_index: Optional[int] = None):
        # Plain attributes are stored before the collection is initialised,
        # exactly as before; subclasses read them after super().__init__().
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        metrics = {
            "mean_iou": JaccardIndex(
                task="multiclass",
                average="macro",
                num_classes=num_classes,
                ignore_index=self.ignore_index,
            ).to(device),
            "mean_dice": Dice(
                average="macro", num_classes=num_classes, ignore_index=self.ignore_index
            ).to(device),
        }
        super().__init__(metrics)

    def update(self, predicted: torch.Tensor, targets: torch.Tensor) -> None:
        """
        Flatten predictions and targets to 1-D and update every metric.

        Args:
            predicted (torch.Tensor): Predictions; flattened before use.
            targets (torch.Tensor): Ground-truth labels; flattened before use.
        """
        # reshape() instead of view(): view(-1) raises on non-contiguous
        # tensors (e.g. after permute/transpose or slicing), reshape copies
        # only when it has to.
        predicted = predicted.reshape(-1)
        targets = targets.reshape(-1)
        super().update(predicted, targets)
class TestMetrics(SegMetrics):
    """
    Segmentation metrics for the test stage.

    Extends :class:`SegMetrics` with three additional macro-averaged
    multiclass metrics registered under the keys ``"accuracy"``,
    ``"precision"`` and ``"recall"``.

    Args:
        num_classes (int): Number of classes in the segmentation task.
        device (str, optional): Device to run the metrics on. Default is "cpu".
        ignore_index (int, optional): Target value passed to every metric as
            its ``ignore_index``. Default is None.

    Attributes:
        num_classes (int): The number of classes.
        ignore_index (int or None): The index to ignore in the target.
    """

    def __init__(self, num_classes, device="cpu", ignore_index: Optional[int] = None):
        super().__init__(num_classes, device, ignore_index)

        # All three metrics are configured identically.
        shared_kwargs = {
            "task": "multiclass",
            "num_classes": self.num_classes,
            "average": "macro",
            "ignore_index": self.ignore_index,
        }
        metric_factories = (
            ("accuracy", Accuracy),
            ("precision", Precision),
            ("recall", Recall),
        )
        extra_metrics = {
            key: factory(**shared_kwargs).to(device)
            for key, factory in metric_factories
        }
        self.add_metrics(extra_metrics)
def compute_class_weights(
    dataloader,
    num_classes,
    mask_key="labels",
    normalize: Literal["sum", "max", "raw", "balanced"] = "raw",
    ignore_index=None,
):
    """
    Compute class weights based on the frequencies of each class in the dataset.

    Every class that occurs at least once gets the weight
    ``1 / log(1.02 + frequency + 1e-6)``; classes that never occur get 0.

    Args:
        dataloader (DataLoader): PyTorch DataLoader containing the dataset.
        num_classes (int): Total number of classes.
        mask_key (str): Key to access masks/labels in the dataloader's batch.
        normalize (str): Method to normalize weights:
            - "sum": Scales weights to sum to 1.
            - "max": Scales weights so the maximum weight is 1.
            - "raw": Leaves weights unnormalized.
            - "balanced": Scales weights so they sum to the number of
              effective classes (``num_classes``, minus one when
              ``ignore_index`` is itself one of the classes).
        ignore_index (int, optional): Class index to ignore in weight computation.
            When None (default) every pixel is counted.

    Returns:
        torch.Tensor: Computed class weights of shape (num_classes,).

    Raises:
        NormalizeError: If the normalization method is unsupported.
    """
    # Fail fast: do not walk the whole dataset only to reject `normalize`.
    if normalize not in ("sum", "max", "raw", "balanced"):
        raise NormalizeError(normalize)

    class_counts = torch.zeros(num_classes, dtype=torch.int64)

    for batch in tqdm(dataloader, desc="Compute class weights"):
        labels = batch[mask_key].reshape(-1)  # Flatten masks into 1D

        # Only filter when there is something to ignore. `labels != None`
        # evaluates to a plain Python True, and indexing with it adds a
        # dimension that torch.bincount rejects.
        if ignore_index is not None:
            labels = labels[labels != ignore_index]

        counts = torch.bincount(labels.int(), minlength=num_classes)
        # Batches may live on another device than the CPU accumulator.
        class_counts += counts.to(class_counts.device)

    total_pixels = class_counts.sum()
    class_weights = torch.where(
        class_counts > 0,
        1 / torch.log(1.02 + (class_counts / total_pixels) + 1e-6),
        torch.tensor(0.0, dtype=torch.float32),
    )

    # Normalising by a zero sum/max would turn every weight into NaN; with no
    # counted pixels the all-zero weights are returned unchanged.
    has_weight = bool(class_weights.sum() > 0)

    if normalize == "sum" and has_weight:
        class_weights /= class_weights.sum()
    elif normalize == "max" and has_weight:
        class_weights /= class_weights.max()
    elif normalize == "balanced" and has_weight:
        # Only an ignore_index that is one of the classes removes a class; a
        # sentinel such as 255 or -100 leaves all num_classes classes active.
        ignored_is_a_class = (
            ignore_index is not None and 0 <= ignore_index < num_classes
        )
        effective_classes = num_classes - (1 if ignored_is_a_class else 0)
        class_weights = class_weights * (effective_classes / class_weights.sum())

    return class_weights.to(torch.float32)