diff --git a/tests/trainers/test_instance_segmentation.py b/tests/trainers/test_instance_segmentation.py new file mode 100644 index 00000000000..527d6960b99 --- /dev/null +++ b/tests/trainers/test_instance_segmentation.py @@ -0,0 +1,230 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +from pathlib import Path +from typing import Any, cast + +import pytest +import segmentation_models_pytorch as smp +import timm +import torch +import torch.nn as nn +from lightning.pytorch import Trainer +from pytest import MonkeyPatch +from torch.nn.modules import Module +from torchvision.models._api import WeightsEnum + +from torchgeo.datamodules import MisconfigurationException, SEN12MSDataModule +from torchgeo.datasets import LandCoverAI, RGBBandsMissingError +from torchgeo.main import main +from torchgeo.models import ResNet18_Weights +from torchgeo.trainers import InstanceSegmentationTask + + +class SegmentationTestModel(Module): + def __init__(self, in_channels: int = 3, classes: int = 3, **kwargs: Any) -> None: + super().__init__() + self.conv1 = nn.Conv2d( + in_channels=in_channels, out_channels=classes, kernel_size=1, padding=0 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return cast(torch.Tensor, self.conv1(x)) + + +def create_model(**kwargs: Any) -> Module: + return SegmentationTestModel(**kwargs) + + +def plot(*args: Any, **kwargs: Any) -> None: + return None + + +def plot_missing_bands(*args: Any, **kwargs: Any) -> None: + raise RGBBandsMissingError() + + +class TestSemanticSegmentationTask: + @pytest.mark.parametrize( + 'name', + [ + 'agrifieldnet', + 'cabuar', + 'chabud', + 'chesapeake_cvpr_5', + 'chesapeake_cvpr_7', + 'deepglobelandcover', + 'etci2021', + 'ftw', + 'geonrw', + 'gid15', + 'inria', + 'l7irish', + 'l8biome', + 'landcoverai', + 'landcoverai100', + 'loveda', + 'naipchesapeake', + 'potsdam2d', + 'sen12ms_all', + 'sen12ms_s1', + 'sen12ms_s2_all', + 'sen12ms_s2_reduced', + 'sentinel2_cdl', + 'sentinel2_eurocrops', + 'sentinel2_nccm', + 'sentinel2_south_america_soybean', + 'southafricacroptype', + 'spacenet1', + 'spacenet6', + 'ssl4eo_l_benchmark_cdl', + 'ssl4eo_l_benchmark_nlcd', + 'vaihingen2d', + ], + ) + def test_trainer( + self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool + ) -> None: + + config = os.path.join('tests', 'conf', name + '.yaml') + + args = [ + '--config', + config, + '--trainer.accelerator', + 'cpu', + '--trainer.fast_dev_run', + str(fast_dev_run), + '--trainer.max_epochs', + '1', + '--trainer.log_every_n_steps', + '1', + ] + + main(['fit', *args]) + try: + main(['test', *args]) + except MisconfigurationException: + pass + try: + main(['predict', *args]) + except MisconfigurationException: + pass + + @pytest.fixture + def weights(self) -> WeightsEnum: + return ResNet18_Weights.SENTINEL2_ALL_MOCO + + @pytest.fixture + def mocked_weights( + self, + tmp_path: Path, + monkeypatch: MonkeyPatch, + weights: WeightsEnum, + load_state_dict_from_url: None, + ) -> WeightsEnum: + path = tmp_path / f'{weights}.pth' + model = timm.create_model( + weights.meta['model'], in_chans=weights.meta['in_chans'] + ) + torch.save(model.state_dict(), path) + try: + monkeypatch.setattr(weights.value, 'url', str(path)) + except AttributeError: + monkeypatch.setattr(weights, 'url', str(path)) + return weights + + def test_weight_file(self, checkpoint: str) -> None: + InstanceSegmentationTask(backbone='resnet18', weights=checkpoint, num_classes=6) + + def test_weight_enum(self, mocked_weights: WeightsEnum) -> None: + InstanceSegmentationTask( + backbone=mocked_weights.meta['model'], + weights=mocked_weights, + in_channels=mocked_weights.meta['in_chans'], + ) + + def test_weight_str(self, mocked_weights: WeightsEnum) -> None: + InstanceSegmentationTask( + backbone=mocked_weights.meta['model'], + weights=str(mocked_weights), + in_channels=mocked_weights.meta['in_chans'], + ) + + @pytest.mark.slow + def test_weight_enum_download(self, weights: WeightsEnum) -> None: + InstanceSegmentationTask( + backbone=weights.meta['model'], + weights=weights, + in_channels=weights.meta['in_chans'], + ) + + @pytest.mark.slow + def test_weight_str_download(self, weights: WeightsEnum) -> None: + InstanceSegmentationTask( + backbone=weights.meta['model'], + weights=str(weights), + in_channels=weights.meta['in_chans'], + ) + + def test_invalid_model(self) -> None: + match = "Model type 'invalid_model' is not valid." + with pytest.raises(ValueError, match=match): + InstanceSegmentationTask(model='invalid_model') + + def test_invalid_loss(self) -> None: + match = "Loss type 'invalid_loss' is not valid." + with pytest.raises(ValueError, match=match): + InstanceSegmentationTask(loss='invalid_loss') + + def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: + monkeypatch.setattr(SEN12MSDataModule, 'plot', plot) + datamodule = SEN12MSDataModule( + root='tests/data/sen12ms', batch_size=1, num_workers=0 + ) + model = InstanceSegmentationTask( + backbone='resnet18', in_channels=15, num_classes=6 + ) + trainer = Trainer( + accelerator='cpu', + fast_dev_run=fast_dev_run, + log_every_n_steps=1, + max_epochs=1, + ) + trainer.validate(model=model, datamodule=datamodule) + + def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: + monkeypatch.setattr(SEN12MSDataModule, 'plot', plot_missing_bands) + datamodule = SEN12MSDataModule( + root='tests/data/sen12ms', batch_size=1, num_workers=0 + ) + model = InstanceSegmentationTask( + backbone='resnet18', in_channels=15, num_classes=6 + ) + trainer = Trainer( + accelerator='cpu', + fast_dev_run=fast_dev_run, + log_every_n_steps=1, + max_epochs=1, + ) + trainer.validate(model=model, datamodule=datamodule) + + @pytest.mark.parametrize('model_name', ['unet', 'deeplabv3+']) + @pytest.mark.parametrize( + 'backbone', ['resnet18', 'mobilenet_v2', 'efficientnet-b0'] + ) + def test_freeze_backbone(self, model_name: str, backbone: str) -> None: + model = InstanceSegmentationTask( + model=model_name, backbone=backbone, freeze_backbone=True + ) + assert all( + [param.requires_grad is False for param in model.model.encoder.parameters()] + ) + assert all([param.requires_grad for param in model.model.decoder.parameters()]) + assert all( + [ + param.requires_grad + for param in model.model.segmentation_head.parameters() + ] + ) diff --git a/torchgeo/trainers/__init__.py b/torchgeo/trainers/__init__.py index ee69bff0021..608ac21a00b 100644 --- a/torchgeo/trainers/__init__.py +++ b/torchgeo/trainers/__init__.py @@ -12,11 +12,13 @@ from .regression import PixelwiseRegressionTask, RegressionTask from .segmentation import SemanticSegmentationTask from .simclr import SimCLRTask +from .instance_segmentation import InstanceSegmentationTask __all__ = ( 'BYOLTask', 'BaseTask', 'ClassificationTask', + 'InstanceSegmentationTask' 'IOBenchTask', 'MoCoTask', 'MultiLabelClassificationTask', diff --git a/torchgeo/trainers/instance_segmentation.py b/torchgeo/trainers/instance_segmentation.py new file mode 100644 index 00000000000..496d48fb1ab --- /dev/null +++ b/torchgeo/trainers/instance_segmentation.py @@ -0,0 +1,248 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Trainers for instance segmentation.""" + +from typing import Any + +import matplotlib.pyplot as plt +import torch +from matplotlib.figure import Figure +from torch import Tensor +from torchmetrics import MetricCollection +from torchmetrics.detection.mean_ap import MeanAveragePrecision +from torchvision.models.detection import ( + MaskRCNN_ResNet50_FPN_Weights, + maskrcnn_resnet50_fpn, +) +from torchvision.models.detection.faster_rcnn import FastRCNNPredictor +from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor + +from torchgeo.datasets import RGBBandsMissingError, unbind_samples +from torchgeo.trainers.base import BaseTask + + +class InstanceSegmentationTask(BaseTask): + """Instance Segmentation.""" + + def __init__( + self, + model: str = 'mask_rcnn', + backbone: str = 'resnet50', + weights: str | bool | None = None, + num_classes: int = 2, + lr: float = 1e-3, + patience: int = 10, + freeze_backbone: bool = False, + ) -> None: + """Initialize a new SemanticSegmentationTask instance. + + Args: + model: Name of the model to use. + backbone: Name of the backbone to use. + weights: Initial model weights. Either a weight enum, the string + representation of a weight enum, True for ImageNet weights, False or + None for random weights, or the path to a saved model state dict. + in_channels: Number of input channels to model. + num_classes: Number of prediction classes (including the background). + lr: Learning rate for optimizer. + patience: Patience for learning rate scheduler. + freeze_backbone: Freeze the backbone network to fine-tune the + decoder and segmentation head. + + .. versionadded:: 0.7 + """ + self.weights = weights + super().__init__() + self.save_hyperparameters() + self.model = None + self.validation_outputs = [] + self.test_outputs = [] + self.configure_models() + self.configure_metrics() + + def configure_models(self) -> None: + """Initialize the model. + + Raises: + ValueError: If *model* is invalid. + """ + model = self.hparams['model'].lower() + num_classes = self.hparams['num_classes'] + + if model == 'mask_rcnn': + # Load the Mask R-CNN model with a ResNet50 backbone + self.model = maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT, rpn_nms_thresh=0.5, box_nms_thresh=0.3) + + # Update the classification head to predict `num_classes` + in_features = self.model.roi_heads.box_predictor.cls_score.in_features + # self.model.roi_heads.box_predictor = nn.Linear(in_features, num_classes) + self.model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) + + # Update the mask head for instance segmentation + in_features_mask = self.model.roi_heads.mask_predictor.conv5_mask.in_channels + + hidden_layer = 256 + self.model.roi_heads.mask_predictor = MaskRCNNPredictor( + in_features_mask, hidden_layer, num_classes) + + else: + raise ValueError( + f"Invalid model type '{model}'. Supported model: 'mask_rcnn'" + ) + + # Freeze backbone + if self.hparams['freeze_backbone']: + for param in self.model.backbone.parameters(): + param.requires_grad = False + + + def configure_metrics(self) -> None: + """Initialize the performance metrics. + + - Uses Mean Average Precision (mAP) for masks (IOU-based metric). + """ + self.metrics = MetricCollection([MeanAveragePrecision(iou_type="segm")]) + self.train_metrics = self.metrics.clone(prefix='train_') + self.val_metrics = self.metrics.clone(prefix='val_') + self.test_metrics = self.metrics.clone(prefix='test_') + + def training_step(self, batch: Any, batch_idx: int) -> Tensor: + """Compute the training loss. + + Args: + batch: A batch of data from the DataLoader. Includes images and ground truth targets. + batch_idx: Index of the current batch. + + Returns: + The total loss for the batch. + """ + images, targets = batch['image'], batch['target'] + loss_dict = self.model(images, targets) + loss = sum(loss for loss in loss_dict.values()) + + print(f"\nTRAINING STEP LOSS: {loss.item()}") + + self.log('train_loss', loss, batch_size=len(images)) + return loss + + def validation_step(self, batch: Any, batch_idx: int) -> None: + """Compute the validation loss. + + Args: + batch: A batch of data from the DataLoader. Includes images and targets. + batch_idx: Index of the current batch. + + Updates metrics and stores predictions/targets for further analysis. + """ + images, targets = batch['image'], batch['target'] + batch_size = images.shape[0] + + outputs = self.model(images) + loss_dict_list = self.model(images, targets) # list of dictionaries + total_loss = sum( + sum(loss_item for loss_item in loss_dict.values() if loss_item.ndim == 0) + for loss_dict in loss_dict_list + ) + + for target in targets: + target["masks"] = (target["masks"] > 0).to(torch.uint8) + target["boxes"] = target["boxes"].to(torch.float32) + target["labels"] = target["labels"].to(torch.int64) + + for output in outputs: + if "masks" in output: + output["masks"] = (output["masks"] > 0.5).squeeze(1).to(torch.uint8) + + self.log('val_loss', total_loss, batch_size=batch_size) + + metrics = self.val_metrics(outputs, targets) + # Log only scalar values from metrics + scalar_metrics = {} + for key, value in metrics.items(): + if isinstance(value, torch.Tensor) and value.numel() > 1: + # Cast to float if integer and compute mean + value = value.to(torch.float32).mean() + scalar_metrics[key] = value + + self.log_dict(scalar_metrics, batch_size=batch_size) + + # check + if ( + batch_idx < 10 + and hasattr(self.trainer, 'datamodule') + and hasattr(self.trainer.datamodule, 'plot') + and self.logger + and hasattr(self.logger, 'experiment') + and hasattr(self.logger.experiment, 'add_figure') + ): + datamodule = self.trainer.datamodule + + batch['prediction_masks'] = [output['masks'].cpu() for output in outputs] + batch['image'] = batch['image'].cpu() + + sample = unbind_samples(batch)[0] + + fig: Figure | None = None + try: + fig = datamodule.plot(sample) + except RGBBandsMissingError: + pass + + if fig: + summary_writer = self.logger.experiment + summary_writer.add_figure( + f'image/{batch_idx}', fig, global_step=self.global_step + ) + plt.close() + + def test_step(self, batch: Any, batch_idx: int) -> None: + """Compute the test loss and additional metrics.""" + images, targets = batch['image'], batch['target'] + batch_size = images.shape[0] + + outputs = self.model(images) + loss_dict_list = self.model(images, targets) # Compute all losses, list of dictonaries (one for every batch element) + total_loss = sum( + sum(loss_item for loss_item in loss_dict.values() if loss_item.ndim == 0) + for loss_dict in loss_dict_list + ) + + for target in targets: + target["masks"] = target["masks"].to(torch.uint8) + target["boxes"] = target["boxes"].to(torch.float32) + target["labels"] = target["labels"].to(torch.int64) + + for output in outputs: + if "masks" in output: + output["masks"] = (output["masks"] > 0.5).squeeze(1).to(torch.uint8) + + self.log('test_loss', total_loss, batch_size=batch_size) + + metrics = self.val_metrics(outputs, targets) + # Log only scalar values from metrics + scalar_metrics = {} + for key, value in metrics.items(): + if isinstance(value, torch.Tensor) and value.numel() > 1: + # Cast to float if integer and compute mean + value = value.to(torch.float32).mean() + scalar_metrics[key] = value + + self.log_dict(scalar_metrics, batch_size=batch_size) + + def predict_step(self, batch: Any, batch_idx: int) -> Any: + """Perform inference on a batch of images.""" + self.model.eval() + images = batch['image'] + + with torch.no_grad(): + outputs = self.model(images) + + for output in outputs: + keep = output["scores"] > 0.05 + output["boxes"] = output["boxes"][keep] + output["labels"] = output["labels"][keep] + output["scores"] = output["scores"][keep] + output["masks"] = (output["masks"] > 0.5).squeeze(1).to(torch.uint8)[keep] + + return outputs