diff --git a/tests/trainers/test_instance_segmentation.py b/tests/trainers/test_instance_segmentation.py
new file mode 100644
index 00000000000..527d6960b99
--- /dev/null
+++ b/tests/trainers/test_instance_segmentation.py
@@ -0,0 +1,230 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import os
+from pathlib import Path
+from typing import Any, cast
+
+import pytest
+import segmentation_models_pytorch as smp
+import timm
+import torch
+import torch.nn as nn
+from lightning.pytorch import Trainer
+from pytest import MonkeyPatch
+from torch.nn.modules import Module
+from torchvision.models._api import WeightsEnum
+
+from torchgeo.datamodules import MisconfigurationException, SEN12MSDataModule
+from torchgeo.datasets import LandCoverAI, RGBBandsMissingError
+from torchgeo.main import main
+from torchgeo.models import ResNet18_Weights
+from torchgeo.trainers import InstanceSegmentationTask
+
+
+class SegmentationTestModel(Module):
+    """Tiny stand-in network made of a single 1x1 convolution.
+
+    Maps ``in_channels`` input channels to ``classes`` output channels. Any
+    additional keyword arguments are accepted and ignored.
+    """
+
+    def __init__(self, in_channels: int = 3, classes: int = 3, **kwargs: Any) -> None:
+        super().__init__()
+        self.conv1 = nn.Conv2d(in_channels, classes, kernel_size=1, padding=0)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply the 1x1 convolution to *x*."""
+        out = self.conv1(x)
+        return cast(torch.Tensor, out)
+
+
+def create_model(**kwargs: Any) -> Module:
+    """Build a ``SegmentationTestModel``, forwarding all keyword arguments."""
+    model = SegmentationTestModel(**kwargs)
+    return model
+
+
+def plot(*args: Any, **kwargs: Any) -> None:
+    """Stand-in ``plot`` that accepts anything and produces no figure."""
+
+
+def plot_missing_bands(*args: Any, **kwargs: Any) -> None:
+    """Stand-in ``plot`` that always raises ``RGBBandsMissingError``."""
+    error = RGBBandsMissingError()
+    raise error
+
+
+class TestSemanticSegmentationTask:
+    # NOTE(review): this list of config names looks like it was carried over from
+    # the semantic segmentation tests -- confirm that each YAML file actually
+    # instantiates InstanceSegmentationTask.
+    @pytest.mark.parametrize(
+        'name',
+        [
+            'agrifieldnet',
+            'cabuar',
+            'chabud',
+            'chesapeake_cvpr_5',
+            'chesapeake_cvpr_7',
+            'deepglobelandcover',
+            'etci2021',
+            'ftw',
+            'geonrw',
+            'gid15',
+            'inria',
+            'l7irish',
+            'l8biome',
+            'landcoverai',
+            'landcoverai100',
+            'loveda',
+            'naipchesapeake',
+            'potsdam2d',
+            'sen12ms_all',
+            'sen12ms_s1',
+            'sen12ms_s2_all',
+            'sen12ms_s2_reduced',
+            'sentinel2_cdl',
+            'sentinel2_eurocrops',
+            'sentinel2_nccm',
+            'sentinel2_south_america_soybean',
+            'southafricacroptype',
+            'spacenet1',
+            'spacenet6',
+            'ssl4eo_l_benchmark_cdl',
+            'ssl4eo_l_benchmark_nlcd',
+            'vaihingen2d',
+        ],
+    )
+    def test_trainer(
+        self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
+    ) -> None:
+        """Run the fit, test and predict CLI stages against one YAML config.
+
+        Args:
+            monkeypatch: pytest monkeypatch fixture (not used in the body).
+            name: Stem of the config file under ``tests/conf``.
+            fast_dev_run: Value forwarded to ``--trainer.fast_dev_run``.
+        """
+        config = os.path.join('tests', 'conf', f'{name}.yaml')
+        trainer_flags = {
+            '--config': config,
+            '--trainer.accelerator': 'cpu',
+            '--trainer.fast_dev_run': str(fast_dev_run),
+            '--trainer.max_epochs': '1',
+            '--trainer.log_every_n_steps': '1',
+        }
+        # Flatten to the [flag, value, flag, value, ...] layout the CLI expects.
+        args = [token for pair in trainer_flags.items() for token in pair]
+
+        main(['fit', *args])
+
+        # The test and predict stages are best-effort: a MisconfigurationException
+        # from either of them is deliberately ignored.
+        for stage in ('test', 'predict'):
+            try:
+                main([stage, *args])
+            except MisconfigurationException:
+                pass
+
+    @pytest.fixture
+    def weights(self) -> WeightsEnum:
+        """Return the weight enum shared by the weight-loading tests."""
+        selected = ResNet18_Weights.SENTINEL2_ALL_MOCO
+        return selected
+
+    @pytest.fixture
+    def mocked_weights(
+        self,
+        tmp_path: Path,
+        monkeypatch: MonkeyPatch,
+        weights: WeightsEnum,
+        load_state_dict_from_url: None,
+    ) -> WeightsEnum:
+        """Point the weight enum's ``url`` at a locally saved state dict.
+
+        A model matching ``weights.meta`` is created with timm, its state dict is
+        written below ``tmp_path``, and the enum's ``url`` is patched to that file.
+
+        ``load_state_dict_from_url`` is requested but never read here; it is
+        defined outside this file (presumably it patches the download helper).
+        """
+        meta = weights.meta
+        checkpoint_file = tmp_path / f'{weights}.pth'
+        backbone = timm.create_model(meta['model'], in_chans=meta['in_chans'])
+        torch.save(backbone.state_dict(), checkpoint_file)
+
+        local_url = str(checkpoint_file)
+        try:
+            monkeypatch.setattr(weights.value, 'url', local_url)
+        except AttributeError:
+            # Fall back to patching the attribute on the enum member itself.
+            monkeypatch.setattr(weights, 'url', local_url)
+        return weights
+
+    def test_weight_file(self, checkpoint: str) -> None:
+        """Construct the task with ``weights`` set to a checkpoint path."""
+        task_kwargs = {'backbone': 'resnet18', 'weights': checkpoint, 'num_classes': 6}
+        InstanceSegmentationTask(**task_kwargs)
+
+    def test_weight_enum(self, mocked_weights: WeightsEnum) -> None:
+        """Construct the task with ``weights`` given as a weight enum."""
+        # InstanceSegmentationTask.__init__ has no `in_channels` parameter, so
+        # passing one would raise TypeError before anything is tested.
+        InstanceSegmentationTask(
+            backbone=mocked_weights.meta['model'], weights=mocked_weights
+        )
+
+    def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
+        """Construct the task with ``weights`` given as the enum's string form."""
+        # InstanceSegmentationTask.__init__ has no `in_channels` parameter, so
+        # passing one would raise TypeError before anything is tested.
+        InstanceSegmentationTask(
+            backbone=mocked_weights.meta['model'], weights=str(mocked_weights)
+        )
+
+    @pytest.mark.slow
+    def test_weight_enum_download(self, weights: WeightsEnum) -> None:
+        """Construct the task from the un-mocked weight enum (marked slow)."""
+        # InstanceSegmentationTask.__init__ has no `in_channels` parameter, so
+        # passing one would raise TypeError before anything is tested.
+        InstanceSegmentationTask(backbone=weights.meta['model'], weights=weights)
+
+    @pytest.mark.slow
+    def test_weight_str_download(self, weights: WeightsEnum) -> None:
+        """Construct the task from the un-mocked enum's string form (marked slow)."""
+        # InstanceSegmentationTask.__init__ has no `in_channels` parameter, so
+        # passing one would raise TypeError before anything is tested.
+        InstanceSegmentationTask(
+            backbone=weights.meta['model'], weights=str(weights)
+        )
+
+    def test_invalid_model(self) -> None:
+        """An unsupported ``model`` name must raise ``ValueError``."""
+        # Must agree with the message raised by
+        # InstanceSegmentationTask.configure_models.
+        match = "Invalid model type 'invalid_model'. Supported model: 'mask_rcnn'"
+        with pytest.raises(ValueError, match=match):
+            InstanceSegmentationTask(model='invalid_model')
+
+    def test_invalid_loss(self) -> None:
+        """The task exposes no ``loss`` option, so passing one is rejected."""
+        # InstanceSegmentationTask.__init__ declares no `loss` parameter; Python
+        # therefore raises TypeError for the unknown keyword, not ValueError.
+        match = "unexpected keyword argument 'loss'"
+        with pytest.raises(TypeError, match=match):
+            InstanceSegmentationTask(loss='invalid_loss')  # type: ignore[call-arg]
+
+    def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
+        """Run validation with the datamodule's ``plot`` replaced by a no-op."""
+        monkeypatch.setattr(SEN12MSDataModule, 'plot', plot)
+        datamodule = SEN12MSDataModule(
+            root='tests/data/sen12ms', batch_size=1, num_workers=0
+        )
+        # `in_channels` is not a parameter of InstanceSegmentationTask.__init__ and
+        # is therefore not passed (it would raise TypeError).
+        # NOTE(review): validation_step reads batch['target']; SEN12MS presumably
+        # does not provide that key -- confirm, or switch to an instance
+        # segmentation datamodule.
+        model = InstanceSegmentationTask(backbone='resnet18', num_classes=6)
+        trainer = Trainer(
+            accelerator='cpu',
+            fast_dev_run=fast_dev_run,
+            log_every_n_steps=1,
+            max_epochs=1,
+        )
+        trainer.validate(model=model, datamodule=datamodule)
+
+    def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
+        """Run validation with a ``plot`` that raises ``RGBBandsMissingError``."""
+        monkeypatch.setattr(SEN12MSDataModule, 'plot', plot_missing_bands)
+        datamodule = SEN12MSDataModule(
+            root='tests/data/sen12ms', batch_size=1, num_workers=0
+        )
+        # `in_channels` is not a parameter of InstanceSegmentationTask.__init__ and
+        # is therefore not passed (it would raise TypeError).
+        # NOTE(review): validation_step reads batch['target']; SEN12MS presumably
+        # does not provide that key -- confirm, or switch to an instance
+        # segmentation datamodule.
+        model = InstanceSegmentationTask(backbone='resnet18', num_classes=6)
+        trainer = Trainer(
+            accelerator='cpu',
+            fast_dev_run=fast_dev_run,
+            log_every_n_steps=1,
+            max_epochs=1,
+        )
+        trainer.validate(model=model, datamodule=datamodule)
+
+    @pytest.mark.parametrize('model_name', ['mask_rcnn'])
+    @pytest.mark.parametrize('backbone', ['resnet50'])
+    def test_freeze_backbone(self, model_name: str, backbone: str) -> None:
+        """``freeze_backbone=True`` freezes the backbone and leaves the heads trainable.
+
+        Only ``'mask_rcnn'`` is accepted by ``configure_models``, and that method
+        freezes ``model.backbone``, so those are the names exercised here.
+        Constructing the task loads the default Mask R-CNN weights, which may
+        require network access.
+        """
+        task = InstanceSegmentationTask(
+            model=model_name, backbone=backbone, freeze_backbone=True
+        )
+        detector = task.model
+        assert all(
+            param.requires_grad is False for param in detector.backbone.parameters()
+        )
+        assert all(param.requires_grad for param in detector.rpn.parameters())
+        assert all(param.requires_grad for param in detector.roi_heads.parameters())
diff --git a/torchgeo/trainers/__init__.py b/torchgeo/trainers/__init__.py
index ee69bff0021..608ac21a00b 100644
--- a/torchgeo/trainers/__init__.py
+++ b/torchgeo/trainers/__init__.py
@@ -12,11 +12,13 @@
 from .regression import PixelwiseRegressionTask, RegressionTask
 from .segmentation import SemanticSegmentationTask
 from .simclr import SimCLRTask
+from .instance_segmentation import InstanceSegmentationTask
 
 __all__ = (
     'BYOLTask',
     'BaseTask',
     'ClassificationTask',
+    'InstanceSegmentationTask',
     'IOBenchTask',
     'MoCoTask',
     'MultiLabelClassificationTask',
diff --git a/torchgeo/trainers/instance_segmentation.py b/torchgeo/trainers/instance_segmentation.py
new file mode 100644
index 00000000000..496d48fb1ab
--- /dev/null
+++ b/torchgeo/trainers/instance_segmentation.py
@@ -0,0 +1,248 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Trainers for instance segmentation."""
+
+from typing import Any
+
+import matplotlib.pyplot as plt
+import torch
+from matplotlib.figure import Figure
+from torch import Tensor
+from torchmetrics import MetricCollection
+from torchmetrics.detection.mean_ap import MeanAveragePrecision
+from torchvision.models.detection import (
+    MaskRCNN_ResNet50_FPN_Weights,
+    maskrcnn_resnet50_fpn,
+)
+from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
+from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
+
+from torchgeo.datasets import RGBBandsMissingError, unbind_samples
+from torchgeo.trainers.base import BaseTask
+
+
+class InstanceSegmentationTask(BaseTask):
+    """Instance Segmentation."""
+
+    def __init__(
+        self,
+        model: str = 'mask_rcnn',
+        backbone: str = 'resnet50',
+        weights: str | bool | None = None,
+        num_classes: int = 2,
+        lr: float = 1e-3,
+        patience: int = 10,
+        freeze_backbone: bool = False,
+    ) -> None:
+        """Initialize a new InstanceSegmentationTask instance.
+
+        Args:
+            model: Name of the model to use. Only ``'mask_rcnn'`` is accepted
+                (case-insensitive) by :meth:`configure_models`.
+            backbone: Name of the backbone to use. Saved as a hyperparameter;
+                :meth:`configure_models` does not read it.
+            weights: Initial model weights. Stored on ``self.weights``.
+                NOTE(review): :meth:`configure_models` always loads
+                ``MaskRCNN_ResNet50_FPN_Weights.DEFAULT`` and does not read
+                this value.
+            num_classes: Number of prediction classes (including the background).
+            lr: Learning rate for optimizer (presumably consumed by the base
+                class; it is not read in this class).
+            patience: Patience for learning rate scheduler (presumably consumed
+                by the base class; it is not read in this class).
+            freeze_backbone: If True, set ``requires_grad = False`` on every
+                parameter of the model's backbone.
+
+        .. versionadded:: 0.7
+        """
+        self.weights = weights
+        # BaseTask.__init__ saves the hyperparameters and runs the configure_*
+        # hooks, so calling them again here would build the model a second time.
+        super().__init__()
+        self.validation_outputs: list[Any] = []
+        self.test_outputs: list[Any] = []
+
+    def configure_models(self) -> None:
+        """Initialize the model.
+
+        Builds a torchvision Mask R-CNN (ResNet-50 FPN) with its default weights
+        and replaces the box and mask predictors so that both emit
+        ``num_classes`` outputs. The ``backbone`` hyperparameter and
+        ``self.weights`` are not read here.
+
+        Raises:
+            ValueError: If *model* is invalid.
+        """
+        model_name = self.hparams['model'].lower()
+        class_count = self.hparams['num_classes']
+
+        if model_name != 'mask_rcnn':
+            raise ValueError(
+                f"Invalid model type '{model_name}'. Supported model: 'mask_rcnn'"
+            )
+
+        self.model = maskrcnn_resnet50_fpn(
+            weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT,
+            rpn_nms_thresh=0.5,
+            box_nms_thresh=0.3,
+        )
+        heads = self.model.roi_heads
+
+        # Swap the box classification/regression head for one sized to our classes.
+        box_in_features = heads.box_predictor.cls_score.in_features
+        heads.box_predictor = FastRCNNPredictor(box_in_features, class_count)
+
+        # Swap the mask head likewise, keeping a 256-channel hidden layer.
+        mask_in_channels = heads.mask_predictor.conv5_mask.in_channels
+        heads.mask_predictor = MaskRCNNPredictor(mask_in_channels, 256, class_count)
+
+        if self.hparams['freeze_backbone']:
+            self.model.backbone.requires_grad_(False)
+
+
+    def configure_metrics(self) -> None:
+        """Initialize the performance metrics.
+
+        A single ``MeanAveragePrecision`` with ``iou_type='segm'`` is wrapped in a
+        ``MetricCollection`` and cloned once per stage with the prefixes
+        ``train_``, ``val_`` and ``test_``.
+        """
+        base = MetricCollection([MeanAveragePrecision(iou_type='segm')])
+        self.metrics = base
+        self.train_metrics, self.val_metrics, self.test_metrics = (
+            base.clone(prefix=f'{stage}_') for stage in ('train', 'val', 'test')
+        )
+
+    def training_step(self, batch: Any, batch_idx: int) -> Tensor:
+        """Compute the training loss.
+
+        Args:
+            batch: A batch of data from the DataLoader. ``batch['image']`` and
+                ``batch['target']`` are passed to the model together.
+            batch_idx: Index of the current batch.
+
+        Returns:
+            The sum of all entries of the loss dictionary returned by the model.
+        """
+        images, targets = batch['image'], batch['target']
+        loss_dict = self.model(images, targets)
+        loss: Tensor = sum(loss_dict.values())
+
+        # Report through the logger only; printing loss.item() on every step
+        # forces a host synchronisation and floods stdout.
+        self.log('train_loss', loss, batch_size=len(images))
+        return loss
+
+    def validation_step(self, batch: Any, batch_idx: int) -> None:
+        """Compute the validation loss and metrics, and optionally log a figure.
+
+        Args:
+            batch: A batch of data from the DataLoader. Reads ``batch['image']``
+                and ``batch['target']``; each target's ``masks``, ``boxes`` and
+                ``labels`` entries are converted in place.
+            batch_idx: Index of the current batch.
+        """
+        images, targets = batch['image'], batch['target']
+        batch_size = images.shape[0]
+
+        # Bring the ground truth into the dtypes used by the model and the metric.
+        for target in targets:
+            target['masks'] = (target['masks'] > 0).to(torch.uint8)
+            target['boxes'] = target['boxes'].to(torch.float32)
+            target['labels'] = target['labels'].to(torch.int64)
+
+        outputs = self.model(images)
+
+        # torchvision detection models return the loss dictionary only in training
+        # mode; in eval mode a call with targets just returns detections again.
+        # Switch modes for one gradient-free pass and restore the previous mode.
+        was_training = self.model.training
+        self.model.train()
+        try:
+            with torch.no_grad():
+                loss_dict = self.model(images, targets)
+        finally:
+            self.model.train(was_training)
+        total_loss = sum(loss_dict.values())
+
+        # Binarise predicted masks at 0.5 and drop the singleton channel axis.
+        for output in outputs:
+            if 'masks' in output:
+                output['masks'] = (output['masks'] > 0.5).squeeze(1).to(torch.uint8)
+
+        self.log('val_loss', total_loss, batch_size=batch_size)
+
+        metrics = self.val_metrics(outputs, targets)
+        # Collapse multi-element tensors to their mean so every logged value is
+        # a scalar.
+        scalar_metrics = {}
+        for key, value in metrics.items():
+            if isinstance(value, torch.Tensor) and value.numel() > 1:
+                value = value.to(torch.float32).mean()
+            scalar_metrics[key] = value
+
+        self.log_dict(scalar_metrics, batch_size=batch_size)
+
+        # Log a figure for the first ten batches when both the datamodule and the
+        # logger expose the hooks needed to do so.
+        if (
+            batch_idx < 10
+            and hasattr(self.trainer, 'datamodule')
+            and hasattr(self.trainer.datamodule, 'plot')
+            and self.logger
+            and hasattr(self.logger, 'experiment')
+            and hasattr(self.logger.experiment, 'add_figure')
+        ):
+            datamodule = self.trainer.datamodule
+
+            batch['prediction_masks'] = [output['masks'].cpu() for output in outputs]
+            batch['image'] = batch['image'].cpu()
+
+            # Only the first sample of the batch is plotted.
+            sample = unbind_samples(batch)[0]
+
+            fig: Figure | None = None
+            try:
+                fig = datamodule.plot(sample)
+            except RGBBandsMissingError:
+                pass
+
+            if fig:
+                summary_writer = self.logger.experiment
+                summary_writer.add_figure(
+                    f'image/{batch_idx}', fig, global_step=self.global_step
+                )
+                plt.close()
+
+    
+    def test_step(self, batch: Any, batch_idx: int) -> None:
+        """Compute the test loss and metrics.
+
+        Args:
+            batch: A batch of data from the DataLoader. Reads ``batch['image']``
+                and ``batch['target']``; each target's ``masks``, ``boxes`` and
+                ``labels`` entries are converted in place.
+            batch_idx: Index of the current batch.
+        """
+        images, targets = batch['image'], batch['target']
+        batch_size = images.shape[0]
+
+        # Same ground-truth normalisation as validation_step, including the
+        # binarisation of the masks.
+        for target in targets:
+            target['masks'] = (target['masks'] > 0).to(torch.uint8)
+            target['boxes'] = target['boxes'].to(torch.float32)
+            target['labels'] = target['labels'].to(torch.int64)
+
+        outputs = self.model(images)
+
+        # torchvision detection models return the loss dictionary only in training
+        # mode; in eval mode a call with targets just returns detections again.
+        # Switch modes for one gradient-free pass and restore the previous mode.
+        was_training = self.model.training
+        self.model.train()
+        try:
+            with torch.no_grad():
+                loss_dict = self.model(images, targets)
+        finally:
+            self.model.train(was_training)
+        total_loss = sum(loss_dict.values())
+
+        # Binarise predicted masks at 0.5 and drop the singleton channel axis.
+        for output in outputs:
+            if 'masks' in output:
+                output['masks'] = (output['masks'] > 0.5).squeeze(1).to(torch.uint8)
+
+        self.log('test_loss', total_loss, batch_size=batch_size)
+
+        # Use the test collection so results are logged under the `test_` prefix
+        # and do not leak into the validation metric state.
+        metrics = self.test_metrics(outputs, targets)
+        # Collapse multi-element tensors to their mean so every logged value is
+        # a scalar.
+        scalar_metrics = {}
+        for key, value in metrics.items():
+            if isinstance(value, torch.Tensor) and value.numel() > 1:
+                value = value.to(torch.float32).mean()
+            scalar_metrics[key] = value
+
+        self.log_dict(scalar_metrics, batch_size=batch_size)
+
+
+    def predict_step(self, batch: Any, batch_idx: int) -> Any:
+        """Perform inference on a batch of images.
+
+        Args:
+            batch: A batch of data from the DataLoader; only ``batch['image']``
+                is read.
+            batch_idx: Index of the current batch.
+
+        Returns:
+            The model outputs, with every entry whose score is not above 0.05
+            removed and the masks binarised at 0.5.
+        """
+        self.model.eval()
+        with torch.no_grad():
+            outputs = self.model(batch['image'])
+
+        for prediction in outputs:
+            keep = prediction['scores'] > 0.05
+            for key in ('boxes', 'labels', 'scores'):
+                prediction[key] = prediction[key][keep]
+            binary_masks = (prediction['masks'] > 0.5).squeeze(1).to(torch.uint8)
+            prediction['masks'] = binary_masks[keep]
+
+        return outputs