Callbacks Module

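The callbacks module provides PyTorch Lightning callbacks for adjusting loss weights during training, saving pretrained models alongside checkpoints, and unfreezing layers when a monitored metric plateaus.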

class ucs.utils.callbacks.LossAdjustmentCallback(max_epochs=50)

Bases: Callback

A callback to adjust the loss function weights dynamically during training.

The callback changes the weights of the loss components (Cross-Entropy and Dice loss) over the course of training. It uses pure Cross-Entropy (CE) loss during the warmup phase, then shifts smoothly to a balanced combination of Cross-Entropy and Dice loss, and finally focuses on Dice loss towards the end of training.

The transition between loss components happens in three phases:

  1. Warmup Phase (First warmup_epochs epochs): Uses only Cross-Entropy loss.

  2. Transition Phase (From warmup_epochs to transition_end): Linearly adjusts from CE loss to a mix of CE and Dice loss.

  3. Dice Loss Phase (After transition_end epochs): Uses only Dice loss.
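Taking alpha as the Cross-Entropy weight and beta as the Dice weight (as in on_train_epoch_start), the three phases can be written as the piecewise schedule below. This is a sketch under two assumptions not stated explicitly in the text: the combined loss is the weighted sum with β = 1 − α, and the transition ramps linearly from pure CE at w = warmup_epochs to pure Dice at T = transition_end.

```latex
\begin{aligned}
\mathcal{L}(e) &= \alpha(e)\,\mathcal{L}_{\mathrm{CE}} + \beta(e)\,\mathcal{L}_{\mathrm{Dice}}, \qquad \beta(e) = 1 - \alpha(e), \\
\alpha(e) &= \begin{cases}
1, & e < w, \\
1 - \dfrac{e - w}{T - w}, & w \le e < T, \\
0, & e \ge T.
\end{cases}
\end{aligned}
```

With the defaults w = 10 and T = 40, epoch 25 is halfway through the transition and gives α = β = 0.5, the balanced combination.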

Parameters:
  • warmup_epochs (int) – The number of epochs for the warmup phase where only Cross-Entropy loss is used. Default is 10.

  • transition_end (int) – The epoch at which the transition to pure Dice loss should end. Default is 40.

on_train_epoch_start(trainer, pl_module)

This method is called at the start of each epoch to adjust the loss weights based on the current epoch number.

Parameters:
  • trainer (pl.Trainer) – The PyTorch Lightning Trainer instance.

  • pl_module (LightningModule) – The LightningModule for which the callback is applied.

The loss weights (alpha for Cross-Entropy, beta for Dice) are set dynamically based on the epoch:

  • During the warmup phase (epochs 0 to warmup_epochs), only Cross-Entropy loss is used (alpha=1.0, beta=0.0).

  • In the transition phase (epochs warmup_epochs to transition_end), the weights shift linearly from Cross-Entropy toward a combination of Cross-Entropy and Dice loss.

  • After transition_end epochs, only Dice loss is used (alpha=0.0, beta=1.0).
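The epoch-based weight update can be sketched in plain Python. This is a minimal sketch, not the library's implementation: it assumes the LightningModule exposes alpha and beta attributes that its loss computation reads, and that the transition ramps linearly from pure Cross-Entropy to pure Dice. The class name LossWeightSchedule is hypothetical, and it is duck-typed (no Lightning import) so it runs standalone; a real callback would subclass Callback.

```python
class LossWeightSchedule:
    """Hypothetical sketch of an epoch-based CE/Dice weight schedule.

    Assumes `pl_module.alpha` weights Cross-Entropy and `pl_module.beta`
    weights Dice, and that the transition phase is a linear ramp.
    """

    def __init__(self, warmup_epochs: int = 10, transition_end: int = 40) -> None:
        if not 0 <= warmup_epochs < transition_end:
            raise ValueError("expected 0 <= warmup_epochs < transition_end")
        self.warmup_epochs = warmup_epochs
        self.transition_end = transition_end

    def weights(self, epoch: int) -> tuple[float, float]:
        """Return (alpha, beta) for a 0-based epoch index."""
        if epoch < self.warmup_epochs:
            return 1.0, 0.0  # warmup: Cross-Entropy only
        if epoch >= self.transition_end:
            return 0.0, 1.0  # final phase: Dice only
        # Transition: fraction of the way from warmup_epochs to transition_end
        span = self.transition_end - self.warmup_epochs
        progress = (epoch - self.warmup_epochs) / span
        return 1.0 - progress, progress

    def on_train_epoch_start(self, trainer, pl_module) -> None:
        # trainer.current_epoch is the 0-based epoch counter in PyTorch Lightning
        pl_module.alpha, pl_module.beta = self.weights(trainer.current_epoch)
```

With the defaults, weights(25) sits halfway through the transition and returns (0.5, 0.5); the module's training step would then compute alpha * ce_loss + beta * dice_loss.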

class ucs.utils.callbacks.SaveModel(pretrained_dir, *args, **kwargs)

Bases: ModelCheckpoint

A PyTorch Lightning callback to save the pretrained model after checkpointing.

This callback extends the functionality of ModelCheckpoint to save the pretrained model to a specified directory in addition to saving the training checkpoint.

Parameters:
  • pretrained_dir (str) – The directory where the pretrained model will be saved.

  • *args – Additional positional arguments passed to the base ModelCheckpoint class.

  • **kwargs – Additional keyword arguments passed to the base ModelCheckpoint class.

class ucs.utils.callbacks.UnfreezeOnPlateau(monitor='val_loss', mode='min', patience=3, delta=0.0)

Bases: Callback

A PyTorch Lightning callback to unfreeze the layers of the model when the monitored metric plateaus. Stops monitoring after unfreezing.

Parameters:
  • monitor (str) – The metric to monitor for improvements (e.g., "val_loss").

  • mode (str) – One of {"min", "max"}. In "min" mode, lower values are better; in "max" mode, higher values are better.

  • patience (int) – Number of epochs to wait for improvement before unfreezing layers.

  • delta (float) – Minimum change in the monitored metric to qualify as an improvement.
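How mode, patience, and delta interact can be illustrated with a small, hypothetical PlateauDetector helper. It is a sketch of standard patience-based plateau detection (the same idea as early stopping), assuming an improvement must beat the best value seen so far by more than delta; the callback's exact comparison may differ.

```python
from typing import Optional


class PlateauDetector:
    """Hypothetical helper: patience-based plateau detection."""

    def __init__(self, mode: str = "min", patience: int = 3, delta: float = 0.0) -> None:
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.mode = mode
        self.patience = patience
        self.delta = delta
        self.best: Optional[float] = None
        self.wait = 0  # validation rounds since the last improvement

    def _improved(self, value: float) -> bool:
        if self.best is None:
            return True  # the first observation always sets the baseline
        if self.mode == "min":
            return value < self.best - self.delta
        return value > self.best + self.delta

    def update(self, value: float) -> bool:
        """Record a metric value; return True once `patience` rounds pass without improvement."""
        if self._improved(value):
            self.best = value
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience
```

In "min" mode with patience=3, the sequence 1.0, 0.9, 0.95, 0.91, 0.92 reports a plateau only on the last value. Once update first returns True, the callback would unfreeze the layers and stop monitoring.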

on_validation_end(trainer, pl_module)

Triggered at the end of validation. Checks whether the monitored metric has plateaued and, if so, unfreezes the model's frozen layers.

Parameters:
  • trainer (Trainer) – The Lightning trainer instance.

  • pl_module (LightningModule) – The Lightning module being trained.
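In PyTorch, unfreezing usually means re-enabling gradients on parameters that were frozen with requires_grad=False. The hypothetical helper below sketches that step for an nn.Module; since the text does not say which layers the callback targets, this version unfreezes every frozen parameter.

```python
import torch.nn as nn


def unfreeze_all(module: nn.Module) -> int:
    """Hypothetical helper: re-enable gradients on all frozen parameters.

    Returns the number of parameter elements that were unfrozen.
    """
    unfrozen = 0
    for param in module.parameters():
        if not param.requires_grad:
            param.requires_grad_(True)
            unfrozen += param.numel()
    return unfrozen
```

Re-enabling requires_grad only trains those parameters if the optimizer already holds them; if they were left out when the optimizer was created, they also need to be added to a parameter group, which is what Lightning's BaseFinetuning.unfreeze_and_add_param_group does.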