Callbacks Module
The callbacks module includes functionality for adjusting loss weights during training, unfreezing layers when a metric plateaus, and handling checkpoints and saving pretrained models.
- class ucs.utils.callbacks.LossAdjustmentCallback(max_epochs=50)
Bases: Callback
A callback to adjust the loss function weights dynamically during training.
The callback changes the weights of the two loss components, Cross-Entropy (CE) and Dice loss, over the course of training. It starts with pure Cross-Entropy loss during the warmup phase, then shifts linearly toward a balanced combination of Cross-Entropy and Dice loss, and finally uses only Dice loss for the rest of training.
The transition between loss components happens in three phases:
Warmup Phase (First warmup_epochs epochs): Uses only Cross-Entropy loss.
Transition Phase (From warmup_epochs to transition_end): Linearly adjusts from CE loss to a mix of CE and Dice loss.
Dice Loss Phase (After transition_end epochs): Uses only Dice loss.
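The three phases can be sketched as a pure function of the epoch index. This is a minimal sketch, not the module's implementation: the function name `loss_weights`, passing `warmup_epochs` and `transition_end` as plain arguments, and the assumption that the "balanced combination" reached at the end of the transition is alpha = beta = 0.5 (the `target_alpha` parameter) are all illustrative assumptions.

```python
def loss_weights(epoch: int, warmup_epochs: int, transition_end: int,
                 target_alpha: float = 0.5) -> tuple[float, float]:
    """Return (alpha, beta): hypothetical CE and Dice weights for a 0-indexed epoch.

    Assumes 0 <= warmup_epochs < transition_end, and that the transition
    moves toward alpha = target_alpha, beta = 1 - target_alpha (assumed 0.5/0.5).
    """
    if epoch < warmup_epochs:
        # Warmup phase: Cross-Entropy only.
        return 1.0, 0.0
    if epoch < transition_end:
        # Transition phase: linear interpolation from (1, 0)
        # toward (target_alpha, 1 - target_alpha).
        progress = (epoch - warmup_epochs) / (transition_end - warmup_epochs)
        alpha = 1.0 - progress * (1.0 - target_alpha)
        return alpha, 1.0 - alpha
    # Dice loss phase: Dice only.
    return 0.0, 1.0


for epoch in (0, 5, 10, 15):
    print(epoch, loss_weights(epoch, warmup_epochs=5, transition_end=15))
```

With `warmup_epochs=5` and `transition_end=15`, epoch 10 yields `(0.75, 0.25)`. Following the phase descriptions literally, the sketch switches abruptly to pure Dice loss at `transition_end` rather than continuing the interpolation.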
- Parameters:
- on_train_epoch_start(trainer, pl_module)
This method is called at the start of each epoch to adjust the loss weights based on the current epoch number.
- Parameters:
trainer (pl.Trainer) – The PyTorch Lightning Trainer instance.
pl_module (LightningModule) – The LightningModule to which the callback is applied.
The loss weights alpha (Cross-Entropy) and beta (Dice) are set from the current epoch:
- During the warmup phase (epochs 0 to warmup_epochs), only Cross-Entropy loss is used (alpha=1.0, beta=0.0).
- In the transition phase (epochs warmup_epochs to transition_end), the weights shift linearly from pure Cross-Entropy to a combination of Cross-Entropy and Dice loss.
- After transition_end epochs, only Dice loss is used (alpha=0.0, beta=1.0).
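If, as the weight values above suggest, alpha scales the Cross-Entropy term and beta scales the Dice term, the training objective is a weighted sum of the two. The NumPy sketch below is an illustration under stated assumptions, not the module's loss code: it assumes binary segmentation with predicted foreground probabilities, a soft Dice loss with a smoothing constant `eps`, and the hypothetical helper names `binary_cross_entropy`, `soft_dice_loss`, and `combined_loss`.

```python
import numpy as np


def binary_cross_entropy(probs: np.ndarray, targets: np.ndarray, eps: float = 1e-7) -> float:
    """Mean binary cross-entropy between predicted probabilities and 0/1 targets."""
    p = np.clip(probs, eps, 1.0 - eps)  # avoid log(0)
    return float(-np.mean(targets * np.log(p) + (1.0 - targets) * np.log(1.0 - p)))


def soft_dice_loss(probs: np.ndarray, targets: np.ndarray, eps: float = 1e-6) -> float:
    """1 minus the soft Dice coefficient, computed over all pixels."""
    intersection = np.sum(probs * targets)
    denom = np.sum(probs) + np.sum(targets)
    return float(1.0 - (2.0 * intersection + eps) / (denom + eps))


def combined_loss(probs: np.ndarray, targets: np.ndarray, alpha: float, beta: float) -> float:
    """Hypothetical weighted objective: alpha * CE + beta * Dice."""
    return alpha * binary_cross_entropy(probs, targets) + beta * soft_dice_loss(probs, targets)


targets = np.array([[0, 1], [1, 1]], dtype=float)
probs = np.array([[0.1, 0.8], [0.7, 0.9]])
print(combined_loss(probs, targets, alpha=1.0, beta=0.0))  # warmup phase: CE only
print(combined_loss(probs, targets, alpha=0.0, beta=1.0))  # final phase: Dice only
```

Because the weights only rescale the two terms, the callback can change the balance between pixel-wise accuracy (CE) and region overlap (Dice) without touching the loss functions themselves.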
- class ucs.utils.callbacks.SaveModel(pretrained_dir, *args, **kwargs)
Bases: ModelCheckpoint
A PyTorch Lightning callback to save the pretrained model after checkpointing.
This callback extends ModelCheckpoint so that, in addition to the training checkpoint, the pretrained model is saved to the directory given by pretrained_dir.
- class ucs.utils.callbacks.UnfreezeOnPlateau(monitor='val_loss', mode='min', patience=3, delta=0.0)
Bases: Callback
A PyTorch Lightning callback that unfreezes the model's layers when the monitored metric plateaus. It stops monitoring once the layers have been unfrozen.
- Parameters:
monitor (str) – The metric to monitor for improvements (e.g., “val_loss”).
mode (str) – One of {“min”, “max”}. In “min” mode, lower values are better; in “max” mode, higher values are better.
patience (int) – Number of epochs to wait for an improvement before unfreezing layers.
delta (float) – Minimum change in the monitored metric to qualify as an improvement.
- on_validation_end(trainer, pl_module)
Triggered at the end of validation. Checks whether the monitored metric has plateaued and, if so, unfreezes the specified layers.
- Parameters:
trainer (Trainer) – The Lightning trainer instance.
pl_module (LightningModule) – The Lightning module being trained.
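The plateau check described for on_validation_end follows the familiar early-stopping pattern. The framework-free sketch below isolates that logic under stated assumptions: the class name `PlateauTracker`, its `update` method, and the improvement rule (better than the best value so far by more than `delta`) are hypothetical. In a Lightning callback the monitored value would typically be read from `trainer.callback_metrics`, and a `True` result would trigger the unfreezing (for example, setting `requires_grad = True` on the frozen parameters); both details are assumptions about the real implementation.

```python
import math


class PlateauTracker:
    """Hypothetical helper mirroring the mode/patience/delta semantics above."""

    def __init__(self, mode: str = "min", patience: int = 3, delta: float = 0.0):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.mode = mode
        self.patience = patience
        self.delta = delta
        self.best = math.inf if mode == "min" else -math.inf
        self.wait = 0       # validation rounds without improvement
        self.done = False   # set once unfreezing has been triggered

    def _improved(self, value: float) -> bool:
        # An improvement must beat the best value by more than delta.
        if self.mode == "min":
            return value < self.best - self.delta
        return value > self.best + self.delta

    def update(self, value: float) -> bool:
        """Record one validation result; return True exactly once, when layers should be unfrozen."""
        if self.done:
            return False  # stop monitoring after unfreezing
        if self._improved(value):
            self.best = value
            self.wait = 0
            return False
        self.wait += 1
        if self.wait >= self.patience:
            self.done = True
            return True
        return False


tracker = PlateauTracker(mode="min", patience=2)
for epoch, val_loss in enumerate([0.9, 0.7, 0.71, 0.72, 0.5]):
    if tracker.update(val_loss):
        print(f"unfreeze at epoch {epoch}")  # prints "unfreeze at epoch 3"
```

The `done` flag models the documented behavior that monitoring stops after unfreezing, so a later improvement (0.5 above) has no further effect.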