SegformerFinetuner
The lightning_model module provides a PyTorch Lightning LightningModule for fine-tuning the Segformer model for semantic segmentation tasks.
- class ucs.model.lightning_model.SegformerFinetuner(config=None, class_weights=None, **kwargs)
Bases: LightningModule
A PyTorch Lightning module for fine-tuning the SegFormer model for semantic segmentation tasks.
- Variables:
model (SegformerForSemanticSegmentation) – The SegFormer model for semantic segmentation.
metrics (SegMetrics) – Metrics object for tracking performance.
criterion (CeDiceLoss) – The loss function used for training.
test_results (dict) – Stores ‘predictions’ and ‘ground_truths’ for test evaluation.
- Parameters:
config (TrainingConfig, optional) – Training configuration containing model hyperparameters, stored in self.hparams.
class_weights (Tensor, optional) – Class weights for loss balancing.
**kwargs (dict) – Additional hyperparameters that override config, such as:
  model_name (str): The SegFormer variant to use (e.g., “b0”).
  max_epochs (int): Maximum number of training epochs.
  learning_rate (float): Learning rate for the optimizer.
  weight_decay (float): Weight decay (L2 regularization).
  ignore_index (Optional[int]): Label index to ignore during training.
  weighting_strategy (str): Strategy for class weighting (‘none’, ‘balanced’, ‘max’, ‘sum’, or ‘raw’).
  alpha (float): The alpha parameter of the CeDiceLoss loss.
  id2label (Dict[int, str]): Mapping from class indices to class labels.
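This page does not define the weighting strategies, so the sketch below only illustrates one common reading of ‘balanced’: inverse-frequency weights using the same formula as scikit-learn’s "balanced" heuristic, n_samples / (num_classes * count_c). The helper name balanced_class_weights is hypothetical, and the formula ucs actually uses for weighting_strategy="balanced" (or for ‘max’, ‘sum’, and ‘raw’) may differ.

```python
from collections import Counter
from typing import Iterable, List, Optional


def balanced_class_weights(
    labels: Iterable[int],
    num_classes: int,
    ignore_index: Optional[int] = None,
) -> List[float]:
    """Inverse-frequency weights w_c = n_samples / (num_classes * count_c).

    Hypothetical helper that follows scikit-learn's "balanced" heuristic; the
    exact formula behind weighting_strategy="balanced" in ucs may differ.
    """
    # Count pixel labels, skipping the ignored index (e.g. 255 for "void").
    counts = Counter(y for y in labels if y != ignore_index)
    n_samples = sum(counts.values())
    weights = []
    for c in range(num_classes):
        # Classes absent from the data get 0.0 here, an arbitrary choice
        # that avoids dividing by zero.
        weights.append(n_samples / (num_classes * counts[c]) if counts[c] else 0.0)
    return weights


# Three pixels of class 0, one of class 1, one ignored pixel.
weights = balanced_class_weights([0, 0, 0, 1, 255], num_classes=2, ignore_index=255)
```

Rare classes receive larger weights, so their pixels contribute more to the loss. A list like this could be wrapped with torch.tensor(...) before being passed as class_weights.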
- calculate_confusion_matrix()
Calculate the confusion matrix from test predictions and ground truths.
- Returns:
The confusion matrix.
- Return type:
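As a rough illustration of what a segmentation confusion matrix counts, here is a stand-alone NumPy sketch that takes predictions and ground truths directly rather than reading them from test_results. The function name and the ignore_index handling are assumptions, not the ucs implementation; predictions are assumed to lie in [0, num_classes).

```python
import numpy as np


def confusion_matrix(predictions, ground_truths, num_classes, ignore_index=None):
    """Return a (num_classes, num_classes) count matrix.

    Rows index ground-truth classes, columns index predicted classes.
    Hypothetical stand-alone sketch, not the ucs implementation.
    """
    preds = np.asarray(predictions, dtype=np.int64).ravel()
    gts = np.asarray(ground_truths, dtype=np.int64).ravel()
    # Keep only pixels whose label is a real class (drops e.g. ignore_index=255).
    valid = (gts >= 0) & (gts < num_classes)
    if ignore_index is not None:
        valid &= gts != ignore_index
    # Encode each (true, predicted) pair as one integer, then count the pairs.
    pair_ids = gts[valid] * num_classes + preds[valid]
    counts = np.bincount(pair_ids, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)


# Two tiny 2x2 "masks"; the pixel labelled 255 is ignored.
cm = confusion_matrix([[0, 1], [1, 1]], [[0, 1], [0, 255]], num_classes=2, ignore_index=255)
```

Encoding each (true, predicted) pair as a single index lets one np.bincount call replace a per-pixel Python loop, which matters for full-resolution masks.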
- forward(images, masks=None)
Forward pass through the model.
- freeze_encoder_layers(blocks_to_freeze=None)
Freezes specified encoder layers to prevent weight updates.
- Parameters:
blocks_to_freeze (list[str], optional) – List of encoder blocks to freeze.
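Freezing usually means setting requires_grad to False on the relevant parameters. The sketch below does this by name prefix. It assumes parameter names follow the Hugging Face SegFormer layout (for example segformer.encoder.block.0...), and it treats each entry of blocks_to_freeze as such a prefix; both the helper name and that string format are hypothetical, as is any particular meaning of blocks_to_freeze=None in ucs.

```python
from types import SimpleNamespace
from typing import Any, Iterable, List, Tuple


def freeze_matching_params(
    named_parameters: Iterable[Tuple[str, Any]],
    blocks_to_freeze: Iterable[str],
) -> List[str]:
    """Disable gradients for parameters that live under any listed block prefix.

    Accepts torch's module.named_parameters() or any (name, obj) pairs whose
    obj has a requires_grad attribute. Returns the names that were frozen.
    """
    # Match on "<prefix>." so that "block.1" does not also catch "block.10".
    prefixes = tuple(f"{block}." for block in blocks_to_freeze)
    frozen = []
    for name, param in named_parameters:
        if name.startswith(prefixes):
            param.requires_grad = False
            frozen.append(name)
    return frozen


# Stand-ins for real parameters; names follow the assumed Hugging Face layout.
demo_params = [
    (name, SimpleNamespace(requires_grad=True))
    for name in (
        "segformer.encoder.block.0.0.mlp.dense1.weight",
        "segformer.encoder.block.1.0.mlp.dense1.weight",
        "decode_head.classifier.weight",
    )
]
frozen = freeze_matching_params(demo_params, ["segformer.encoder.block.0"])
```

With a real model you would pass model.named_parameters(); the optimizer can then be given only the parameters that still have requires_grad set, so frozen weights are neither updated nor tracked by optimizer state.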
- on_train_epoch_end()
Reset the metrics at the end of the train epoch.
This prevents the accumulation of metric values across epochs and ensures metrics are calculated independently for each epoch.
- on_validation_epoch_end()
Reset the metrics at the end of the validation epoch.
This prevents the accumulation of metric values across epochs and ensures metrics are calculated independently for each epoch.
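Why the reset in both epoch-end hooks matters can be seen with any stateful metric that accumulates across update calls. The class below is a hypothetical stand-in modelled on the update/compute/reset pattern used by libraries such as torchmetrics; it is not the SegMetrics API.

```python
class RunningAccuracy:
    """Minimal stateful metric (hypothetical; not SegMetrics)."""

    def __init__(self) -> None:
        self.reset()

    def update(self, correct: int, total: int) -> None:
        # Accumulate counts across batches.
        self.correct += correct
        self.total += total

    def compute(self) -> float:
        return self.correct / self.total if self.total else 0.0

    def reset(self) -> None:
        self.correct = 0
        self.total = 0


metric = RunningAccuracy()
epoch_scores = []
for epoch_batches in ([(9, 10)], [(1, 10)]):  # two epochs, one batch each
    for correct, total in epoch_batches:
        metric.update(correct, total)
    epoch_scores.append(metric.compute())
    metric.reset()  # the epoch-end reset; without it epoch 2 would report 10/20 = 0.5
```

With the reset, each epoch reports its own accuracy (0.9, then 0.1) instead of a running average over all epochs seen so far.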
- save_pretrained_model(pretrained_path)
Save the trained model in Hugging Face’s Transformers-compatible format.
Notes
This allows the saved model to be loaded later using the from_pretrained method.
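The underlying Hugging Face round trip can be sketched with the Transformers classes directly. This example builds a small, randomly initialised model from a default SegformerConfig (so nothing is downloaded), saves it, and reloads it with from_pretrained; it needs torch and transformers installed, and the label names are hypothetical. How save_pretrained_model wraps these calls (for example, whether it also saves an image processor) is not documented here.

```python
import tempfile

from transformers import SegformerConfig, SegformerForSemanticSegmentation

id2label = {0: "background", 1: "road", 2: "building"}  # hypothetical labels
config = SegformerConfig(
    id2label=id2label,
    label2id={label: idx for idx, label in id2label.items()},
)
model = SegformerForSemanticSegmentation(config)  # random weights, no download

with tempfile.TemporaryDirectory() as save_dir:
    # Writes config.json plus the model weights into save_dir.
    model.save_pretrained(save_dir)
    reloaded = SegformerForSemanticSegmentation.from_pretrained(save_dir)
    reloaded_labels = reloaded.config.id2label
```

Because id2label travels inside config.json, a model reloaded this way keeps its class names without any extra bookkeeping.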