Source code for ucs.utils.callbacks

from pytorch_lightning.callbacks import Callback, ModelCheckpoint


class SaveModel(ModelCheckpoint):
    """
    Checkpoint callback that also exports the pretrained model.

    Behaves like ``ModelCheckpoint`` and, every time a checkpoint is written,
    additionally asks the Lightning module to save its pretrained model into
    a fixed directory.

    Args:
        pretrained_dir (str): Directory handed to
            ``save_pretrained_model`` on the Lightning module.
        *args: Positional arguments forwarded to ``ModelCheckpoint``.
        **kwargs: Keyword arguments forwarded to ``ModelCheckpoint``.
    """

    def __init__(self, pretrained_dir, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pretrained_dir = pretrained_dir

    def _save_checkpoint(self, trainer, filepath):
        """
        Write the regular checkpoint, then export the pretrained model.

        Args:
            trainer (Trainer): The PyTorch Lightning trainer instance.
            filepath (str): Path of the checkpoint file being written.
        """
        # The regular checkpoint is always written first, on every process.
        super()._save_checkpoint(trainer, filepath)

        # Only the main process (global rank zero) exports the model.
        if not trainer.is_global_zero:
            return

        module = trainer.lightning_module
        module.save_pretrained_model(self.pretrained_dir)
class UnfreezeOnPlateau(Callback):
    """
    A PyTorch Lightning callback that unfreezes layers of the model once the
    monitored metric plateaus. Monitoring stops after unfreezing.

    Args:
        monitor (str): The metric to monitor for improvements (e.g., "val_loss").
        mode (str): One of {"min", "max"}. In "min" mode, lower values are
            better. In "max" mode, higher values are better.
        patience (int): Number of validation runs without improvement to wait
            before unfreezing layers.
        delta (float): Minimum change in the monitored metric to qualify as an
            improvement. An unchanged metric never counts as an improvement.

    Raises:
        ValueError: If ``mode`` is neither "min" nor "max".
    """

    def __init__(self, monitor="val_loss", mode="min", patience=3, delta=0.0):
        super().__init__()
        self.monitor = monitor
        self.mode = mode
        self.patience = patience
        self.delta = delta
        self.best_value = None  # best metric value seen so far
        self.epochs_without_improvement = 0
        self.unfreeze_done = False

        # Validate mode
        if self.mode not in {"min", "max"}:
            raise ValueError(f"Invalid mode: {self.mode}. Choose 'min' or 'max'.")

    def on_validation_end(self, trainer, pl_module):
        """
        Triggered at the end of validation. Checks whether the monitored
        metric has plateaued and unfreezes the layers if it has.

        Args:
            trainer (Trainer): The Lightning trainer instance.
            pl_module (LightningModule): The Lightning module being trained.
        """
        if self.unfreeze_done:
            return  # Stop monitoring after unfreezing

        # The pre-training sanity check only runs a few validation batches;
        # its metric must not seed ``best_value`` or consume patience.
        if getattr(trainer, "sanity_checking", False):
            return

        current_value = self._get_current_metric(trainer)
        if current_value is None:
            return  # Skip if the monitored metric is not logged

        if self._is_improvement(current_value):
            self._update_best_value(current_value)
        else:
            self.epochs_without_improvement += 1
            if self.epochs_without_improvement >= self.patience:
                self._unfreeze_layers(trainer, pl_module)

    def _get_current_metric(self, trainer):
        """
        Retrieve the current value of the monitored metric.

        Args:
            trainer (Trainer): The Lightning trainer instance.

        Returns:
            The value stored under ``self.monitor`` in
            ``trainer.callback_metrics``, or None if it is absent.
        """
        return trainer.callback_metrics.get(self.monitor)

    def _is_improvement(self, current_value):
        """
        Check if the current value represents a meaningful improvement.

        Args:
            current_value (float): The current value of the monitored metric.

        Returns:
            bool: True if the metric moved in the direction given by ``mode``
            by at least ``delta``. The very first value is always an
            improvement; a value equal to the best one never is.
        """
        if self.best_value is None:
            return True  # Treat the first value as an improvement

        change = current_value - self.best_value
        if self.mode == "max":
            # Normalise so that a negative change always means "better".
            change = -change

        # ``change != 0`` keeps a flat metric from resetting the patience
        # counter when delta is 0, which would hide a true plateau.
        return bool(change <= -self.delta and change != 0)

    def _update_best_value(self, current_value):
        """
        Record a new best value and reset the patience counter.

        Args:
            current_value (float): The current value of the monitored metric.
        """
        self.best_value = current_value
        self.epochs_without_improvement = 0

    def _unfreeze_layers(self, trainer, pl_module):
        """
        Unfreeze the encoder layers of the Lightning module and stop further
        monitoring.

        Args:
            trainer (Trainer): The Lightning trainer instance.
            pl_module (LightningModule): The Lightning module being trained;
                must provide ``unfreeze_encoder_layers()``.
        """
        pl_module.unfreeze_encoder_layers()
        self.unfreeze_done = True

        message = (
            f"UnfreezeOnPlateau: Layers unfrozen at epoch {trainer.current_epoch}. "
            f"Patience: {self.patience}, Best {self.monitor}: {self.best_value:.4f}"
        )
        self._log_event(trainer, message)

    def _log_event(self, trainer, message):
        """
        Log an event message to the console and, if one is attached, to the
        experiment logger. Only the global-zero process logs.

        Args:
            trainer (Trainer): The Lightning trainer instance.
            message (str): The message to log.
        """
        if not trainer.is_global_zero:
            return

        # The console message does not depend on a logger being configured.
        print(f"\n{message}\n")

        logger = trainer.logger
        if not logger:
            return

        experiment = logger.experiment
        # Experiment objects exposing ``log`` (presumably a wandb run — verify)
        if hasattr(experiment, "log"):
            experiment.log({"event_message": message, "epoch": trainer.current_epoch})
        # Handle TensorBoardLogger
        elif hasattr(experiment, "add_text"):
            experiment.add_text(
                "event_logs", message, global_step=trainer.current_epoch
            )
        # Handle MLFlowLogger
        elif hasattr(experiment, "set_tags"):
            experiment.set_tags({"event_message": message})
        # Fallback
        else:
            logger.log_metrics(
                {"event_message": message}, step=trainer.current_epoch
            )
class LossAdjustmentCallback(Callback):
    """
    A callback to adjust the loss function weights dynamically during training.

    The callback changes the weights of the loss components (Cross-Entropy and
    Dice loss) over the course of training by calling
    ``pl_module.criterion.set_stage(alpha, beta)`` at the start of every epoch,
    where ``alpha`` weights Cross-Entropy and ``beta`` weights Dice.

    The schedule has three phases:

    1. **Warmup Phase** (epochs before ``warmup_epochs``): only Cross-Entropy
       loss (``alpha=1.0``, ``beta=0.0``).
    2. **Transition Phase** (``warmup_epochs`` to ``transition_end``,
       inclusive): linear shift from pure Cross-Entropy to pure Dice loss.
    3. **Dice Loss Phase** (after ``transition_end``): only Dice loss
       (``alpha=0.0``, ``beta=1.0``).

    Args:
        max_epochs (int): Total number of training epochs the schedule is
            scaled to. Default is 50.
        warmup_fraction (float): Fraction of ``max_epochs`` spent in the
            warmup phase. Default is 0.2 (epoch 10 for 50 epochs).
        transition_fraction (float): Fraction of ``max_epochs`` at which the
            transition to pure Dice loss ends. Default is 0.8 (epoch 40 for
            50 epochs).

    Raises:
        ValueError: If the fractions do not satisfy
            ``0 <= warmup_fraction <= transition_fraction``.
    """

    def __init__(self, max_epochs=50, *, warmup_fraction=0.2, transition_fraction=0.8):
        super().__init__()
        if not 0 <= warmup_fraction <= transition_fraction:
            raise ValueError(
                "Expected 0 <= warmup_fraction <= transition_fraction, got "
                f"warmup_fraction={warmup_fraction}, "
                f"transition_fraction={transition_fraction}."
            )
        self.warmup_epochs = max_epochs * warmup_fraction
        self.transition_end = max_epochs * transition_fraction

    def on_train_epoch_start(self, trainer, pl_module):
        """
        Adjust the loss weights at the start of each training epoch based on
        the current epoch number, then log them.

        Args:
            trainer (pl.Trainer): The PyTorch Lightning Trainer instance.
            pl_module (pl.LightningModule): The LightningModule the callback
                is applied to; must expose ``criterion.set_stage(alpha, beta)``.
        """
        current_epoch = trainer.current_epoch
        transition_length = self.transition_end - self.warmup_epochs

        # Pure CE loss during warmup
        if current_epoch < self.warmup_epochs:
            alpha, beta = 1.0, 0.0
        # Linear transition phase; a zero-length transition is skipped so
        # the division below can never be by zero.
        elif current_epoch <= self.transition_end and transition_length > 0:
            progress = (current_epoch - self.warmup_epochs) / transition_length
            alpha = 1 - progress
            beta = progress
        # Pure Dice loss after transition
        else:
            alpha, beta = 0.0, 1.0

        # Update loss weights
        pl_module.criterion.set_stage(alpha, beta)

        # Log weights
        pl_module.log("alpha", alpha, on_epoch=True, prog_bar=True)
        pl_module.log("beta", beta, on_epoch=True, prog_bar=True)