Error when training for instance segmentation with a custom dataset #17

Closed
Robotatron opened this issue Dec 25, 2022 · 14 comments
Labels
question Further information is requested

Comments

@Robotatron

Robotatron commented Dec 25, 2022

I am using a custom dataset in COCO format for instance segmentation training.
I changed the config to:

cfg.MODEL.TEST.TASK = "instance"
cfg.INPUT.TASK_PROB.SEMANTIC = 0
cfg.INPUT.TASK_PROB.INSTANCE = 1

I still get the error: UnboundLocalError: local variable 'pan_seg_gt' referenced before assignment

From #5 and reading docs I understand I have to somehow prepare my dataset for instance segmentation training.

  1. Is it correct to say OneFormer expects COCO dataset in a panoptic format?
  2. If 1) is true, how do I convert my custom instance segmentation COCO dataset to a panoptic format?
  3. I found a script from panopticapi to convert from instance to panoptic format, but judging from the description it will merge every instance annotation in an image into a single annotation, which would defeat the purpose of training an instance segmentation model.
    I also get a KeyError when using that script; see "KeyError when using detection2panoptic_coco_format" (cocodataset/panopticapi#58).
  4. How do I prepare a detection COCO dataset to train for instance segmentation with OneFormer? Thanks.

image

@praeclarumjj3
Member

Hi @Robotatron, please find the answers to your questions below:

  1. Yes, the dataset_mappers in this repo expect the COCO dataset in a panoptic format.
  2. If you only want to train for instance segmentation, you do not need to convert the dataset into the panoptic format as there's an easier workaround:

You need to pay attention to two files here:

You encounter the UnboundLocalError: local variable 'pan_seg_gt' referenced before assignment error because your dataset's metadata does not have a pan_seg_file_name key:

So, because you are using a custom dataset for instance segmentation, you should ensure the following two steps:

  • Your dataset's metadata has all the necessary attributes. Create a new dataset registry file if necessary.
  • The attributes align with the ones used in the dataset mapper. Write a new dataset mapper.

It's pretty easy to ensure these, and it should take you only a short time. Remember, because you only care about instance segmentation, you need to define the thing_classes correctly in your dataset_registry and everything will work.
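
To make the "define the thing_classes correctly" point concrete, here is a minimal sketch in plain Python that derives the two metadata fields used in this thread (thing_classes and thing_dataset_id_to_contiguous_id) from the categories list of a COCO-style annotation dict. It does not call detectron2. The helper name build_thing_metadata and the sample categories are hypothetical, and ordering the categories by id is an assumption about how contiguous ids are usually assigned.

```python
def build_thing_metadata(coco):
    """Derive class names and a contiguous id mapping from COCO categories."""
    # Sort by category id so the contiguous ids are deterministic (assumption).
    categories = sorted(coco["categories"], key=lambda c: c["id"])
    thing_classes = [c["name"] for c in categories]
    id_map = {c["id"]: i for i, c in enumerate(categories)}
    return {
        "thing_classes": thing_classes,
        "thing_dataset_id_to_contiguous_id": id_map,
    }


# Hypothetical two-class dataset with non-contiguous category ids.
coco = {"categories": [{"id": 7, "name": "bolt"}, {"id": 3, "name": "nut"}]}
meta = build_thing_metadata(coco)
print(meta["thing_classes"])
print(meta["thing_dataset_id_to_contiguous_id"])
```

If the class names or the id mapping produced for your own annotation file look wrong here, the registered metadata will most likely be wrong as well, so this is a cheap check to run before training.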

@praeclarumjj3 praeclarumjj3 added the question Further information is requested label Dec 26, 2022
@Robotatron
Author

Robotatron commented Dec 27, 2022

@praeclarumjj3 Thanks for the answer. The dataset mapping is a new thing to me, please have patience with me :)

  • Your dataset's metadata has all the necessary attributes. Create a new dataset registry file if necessary.

1.

I am registering the dataset the same way I did with Mask2Former, using detectron2.data.datasets.register_coco_instances. Is that enough, or should I use your register_coco_panoptic_annos_sem_seg() from register_coco_panoptic_annos_semseg.py? Please see the image below under "details".


2.

  • The attributes align with the ones used in the dataset mapper. Write a new dataset mapper.

Are you talking about these attributes?

        dataset_dict["sem_seg"] = torch.from_numpy(sem_seg).long()
        dataset_dict["instances"] = instances
        dataset_dict["orig_shape"] = image_shape
        dataset_dict["task"] = task
        dataset_dict["text"] = text
        dataset_dict["thing_ids"] = self.things

If so, what do I put under "sem_seg", since I don't have semantic segmentation labels? Just an empty list?

@Robotatron
Author

Robotatron commented Dec 27, 2022

I think I am close. I've tried using the dataset mapper from Mask2Former.

I've modified the Mask2Former dataset mapper like this (mainly at the end, when setting the attributes of the dataset_dict):

# Copyright (c) Facebook, Inc. and its affiliates.

import copy
import logging

import numpy as np
import pycocotools.mask as mask_util
import torch
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.data import MetadataCatalog
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from detectron2.projects.point_rend import ColorAugSSDTransform
from detectron2.structures import BitMasks, Instances, polygons_to_bitmask

__all__ = ["MaskFormerInstanceDatasetMapper"]

class MaskFormerInstanceDatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and maps it into a format used by MaskFormer for instance segmentation.

    The callable currently does the following:

    1. Read the image from "file_name"
    2. Applies geometric transforms to the image and annotation
    3. Find and applies suitable cropping to the image and annotation
    4. Prepare image and annotation to Tensors
    """

    @configurable
    def __init__(
        self,
        is_train=True,
        *,
        augmentations,
        image_format,
        size_divisibility,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            is_train: for training or inference
            augmentations: a list of augmentations or deterministic transforms to apply
            image_format: an image format supported by :func:`detection_utils.read_image`.
            size_divisibility: pad image size to be divisible by this value
        """
        self.is_train = is_train
        self.tfm_gens = augmentations
        self.img_format = image_format
        self.size_divisibility = size_divisibility

        logger = logging.getLogger(__name__)
        mode = "training" if is_train else "inference"
        logger.info(f"[{self.__class__.__name__}] Augmentations used in {mode}: {augmentations}")

        # NOTE: `cfg` is not an argument of __init__; this relies on a global
        # `cfg` being defined (e.g. in the notebook that runs this code).
        self.things = []
        dataset_names = cfg.DATASETS.TRAIN
        self.meta = MetadataCatalog.get(dataset_names[0])
        for k, v in self.meta.thing_dataset_id_to_contiguous_id.items():
            self.things.append(v)

    @classmethod
    def from_config(cls, cfg, is_train=True):
        # Build augmentation
        augs = [
            T.ResizeShortestEdge(
                cfg.INPUT.MIN_SIZE_TRAIN,
                cfg.INPUT.MAX_SIZE_TRAIN,
                cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING,
            )
        ]
        if cfg.INPUT.CROP.ENABLED:
            augs.append(
                T.RandomCrop(
                    cfg.INPUT.CROP.TYPE,
                    cfg.INPUT.CROP.SIZE,
                )
            )
        if cfg.INPUT.COLOR_AUG_SSD:
            augs.append(ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT))
        augs.append(T.RandomFlip())

        ret = {
            "is_train": is_train,
            "augmentations": augs,
            "image_format": cfg.INPUT.FORMAT,
            "size_divisibility": cfg.INPUT.SIZE_DIVISIBILITY,
        }
        return ret

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        assert self.is_train, "MaskFormerInstanceDatasetMapper should only be used for training!"

        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        aug_input = T.AugInput(image)
        aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input)
        image = aug_input.image

        # transform instance masks
        assert "annotations" in dataset_dict
        for anno in dataset_dict["annotations"]:
            anno.pop("keypoints", None)

        annos = [
            utils.transform_instance_annotations(obj, transforms, image.shape[:2])
            for obj in dataset_dict.pop("annotations")
            if obj.get("iscrowd", 0) == 0
        ]

        if len(annos):
            assert "segmentation" in annos[0]
        segms = [obj["segmentation"] for obj in annos]
        masks = []
        for segm in segms:
            if isinstance(segm, list):
                # polygon
                masks.append(polygons_to_bitmask(segm, *image.shape[:2]))
            elif isinstance(segm, dict):
                # COCO RLE
                masks.append(mask_util.decode(segm))
            elif isinstance(segm, np.ndarray):
                assert segm.ndim == 2, "Expect segmentation of 2 dimensions, got {}.".format(
                    segm.ndim
                )
                # mask array
                masks.append(segm)
            else:
                raise ValueError(
                    "Cannot convert segmentation of type '{}' to BitMasks!"
                    "Supported types are: polygons as list[list[float] or ndarray],"
                    " COCO-style RLE as a dict, or a binary segmentation mask "
                    " in a 2D numpy array of shape HxW.".format(type(segm))
                )

        # Pad image and segmentation label here!
        image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
        masks = [torch.from_numpy(np.ascontiguousarray(x)) for x in masks]

        classes = [int(obj["category_id"]) for obj in annos]
        classes = torch.tensor(classes, dtype=torch.int64)

        if self.size_divisibility > 0:
            image_size = (image.shape[-2], image.shape[-1])
            padding_size = [
                0,
                self.size_divisibility - image_size[1],
                0,
                self.size_divisibility - image_size[0],
            ]
            # pad image
            image = F.pad(image, padding_size, value=128).contiguous()
            # pad mask
            masks = [F.pad(x, padding_size, value=0).contiguous() for x in masks]

        image_shape = (image.shape[-2], image.shape[-1])  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.

        # Prepare per-instance binary masks
        instances = Instances(image_shape)
        instances.gt_classes = classes
        if len(masks) == 0:
            # Some image does not have annotation (all ignored)
            instances.gt_masks = torch.zeros((0, image.shape[-2], image.shape[-1]))
        else:
            masks = BitMasks(torch.stack(masks))
            instances.gt_masks = masks.tensor

        # dataset_dict["sem_seg"] = torch.from_numpy(sem_seg).long()
        dataset_dict["image"] = image
        dataset_dict["sem_seg"] = segms  # I don't know what to put here?
        dataset_dict["instances"] = instances
        dataset_dict["task"] = "The task is instance"
        dataset_dict["orig_shape"] = image_shape
        dataset_dict["text"] = "what do I put here"
        dataset_dict["thing_ids"] = self.things

        return dataset_dict

It almost works, the model gets initialized, the dataloader provides images:
image

But when training starts, it breaks:
image

Full error:

/opt/conda/envs/oneformer2/lib/python3.8/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.
  warnings.warn(warning.format(ret))
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[20], line 1
----> 1 trainer.train()

File /opt/conda/envs/oneformer2/lib/python3.8/site-packages/detectron2/engine/defaults.py:484, in DefaultTrainer.train(self)
477 def train(self):
478 """
479 Run training.
480
481 Returns:
482 OrderedDict of results, if evaluation is enabled. Otherwise None.
483 """
--> 484 super().train(self.start_iter, self.max_iter)
485 if len(self.cfg.TEST.EXPECTED_RESULTS) and comm.is_main_process():
486 assert hasattr(
487 self, "_last_eval_results"
488 ), "No evaluation results obtained during training!"

File /opt/conda/envs/oneformer2/lib/python3.8/site-packages/detectron2/engine/train_loop.py:149, in TrainerBase.train(self, start_iter, max_iter)
147 for self.iter in range(start_iter, max_iter):
148 self.before_step()
--> 149 self.run_step()
150 self.after_step()
151 # self.iter == max_iter can be used by after_train to
152 # tell whether the training successfully finished or failed
153 # due to exceptions.

File /opt/conda/envs/oneformer2/lib/python3.8/site-packages/detectron2/engine/defaults.py:494, in DefaultTrainer.run_step(self)
492 def run_step(self):
493 self._trainer.iter = self.iter
--> 494 self._trainer.run_step()

File /opt/conda/envs/oneformer2/lib/python3.8/site-packages/detectron2/engine/train_loop.py:395, in AMPTrainer.run_step(self)
392 data_time = time.perf_counter() - start
394 with autocast():
--> 395 loss_dict = self.model(data)
396 if isinstance(loss_dict, torch.Tensor):
397 losses = loss_dict

File /opt/conda/envs/oneformer2/lib/python3.8/site-packages/torch/nn/modules/module.py:1102, in Module._call_impl(self, *input, **kwargs)
1098 # If we don't have any hooks, we want to skip the rest of the logic in
1099 # this function, and just call forward.
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []

File ~/OneFormer/oneformer/oneformer_model.py:296, in OneFormer.forward(self, batched_inputs)
293 targets = None
295 # bipartite matching-based loss
--> 296 losses = self.criterion(outputs, targets)
298 for k in list(losses.keys()):
299 if k in self.criterion.weight_dict:

File /opt/conda/envs/oneformer2/lib/python3.8/site-packages/torch/nn/modules/module.py:1102, in Module._call_impl(self, *input, **kwargs)
1098 # If we don't have any hooks, we want to skip the rest of the logic in
1099 # this function, and just call forward.
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []

File ~/OneFormer/oneformer/modeling/criterion.py:294, in SetCriterion.forward(self, outputs, targets)
292 losses = {}
293 for loss in self.losses:
--> 294 losses.update(self.get_loss(loss, outputs, targets, indices, num_masks))
296 # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
297 if "aux_outputs" in outputs:

File ~/OneFormer/oneformer/modeling/criterion.py:268, in SetCriterion.get_loss(self, loss, outputs, targets, indices, num_masks)
262 loss_map = {
263 'labels': self.loss_labels,
264 'masks': self.loss_masks,
265 'contrastive': self.loss_contrastive,
266 }
267 assert loss in loss_map, f"do you really want to compute {loss} loss?"
--> 268 return loss_map[loss](outputs, targets, indices, num_masks)

File ~/OneFormer/oneformer/modeling/criterion.py:152, in SetCriterion.loss_contrastive(self, outputs, targets, indices, num_masks)
150 batch_size = image_x.shape[0]
151 # get label globally
--> 152 labels = torch.arange(batch_size, dtype=torch.long, device=image_x.device) + batch_size * dist.get_rank()
154 text_x = outputs["texts"]
156 # [B, C]

File /opt/conda/envs/oneformer2/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py:822, in get_rank(group)
819 if _rank_not_in_group(group):
820 return -1
--> 822 default_pg = _get_default_group()
823 if group is None or group is GroupMember.WORLD:
824 return default_pg.rank()

File /opt/conda/envs/oneformer2/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py:410, in _get_default_group()
406 """
407 Getting the default process group created by init_process_group
408 """
409 if not is_initialized():
--> 410 raise RuntimeError(
411 "Default process group has not been initialized, "
412 "please make sure to call init_process_group."
413 )
414 return GroupMember.WORLD

RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.

@praeclarumjj3
Member

Hi @Robotatron, you can refer to the following code to understand how to set the text attribute inside the dataset dictionary:

            if not np.all(mask == False):
                cls_name = self.class_names[class_id]
                classes.append(class_id)
                masks.append(mask)
                num_class_obj[cls_name] += 1
                label[mask] = class_id

num = 0
for i, cls_name in enumerate(self.class_names):
    if num_class_obj[cls_name] > 0:
        for _ in range(num_class_obj[cls_name]):
            if num >= len(texts):
                break
            texts[num] = f"a photo with a {cls_name}"
            num += 1

classes = np.array(classes)
instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
if len(masks) == 0:
    # Some image does not have annotation (all ignored)
    instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
    instances.gt_bboxes = torch.zeros((0, 4))
else:
    masks = BitMasks(
        torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
    )
    instances.gt_masks = masks.tensor
    instances.gt_bboxes = masks_to_boxes(instances.gt_masks)
return instances, texts, label

All you need to do is to loop through the class ids and collect the corresponding class names.
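
As a worked illustration of that loop, here is a minimal sketch in plain Python that builds the text list from a list of class ids. The function name build_texts and the sample classes are hypothetical. The padding string "an instance photo" is the one used in the custom mapper script posted later in this thread; treat it as an assumption for any other setup.

```python
from collections import Counter


def build_texts(class_ids, class_names, num_queries, pad_text="an instance photo"):
    """Return num_queries strings: one per object (grouped by class), then padding."""
    counts = Counter(class_names[i] for i in class_ids)
    texts = []
    for name in class_names:  # iterate in class order, as in the snippet above
        texts.extend([f"a photo with a {name}"] * counts[name])
    texts = texts[:num_queries]  # never exceed the number of text queries
    texts += [pad_text] * (num_queries - len(texts))
    return texts


# Hypothetical image with one "nut" (id 0) and two "bolt" (id 1) instances.
texts = build_texts([1, 0, 1], ["nut", "bolt"], num_queries=5)
print(texts)
```

The result always has exactly num_queries entries: one sentence per annotated object, followed by padding entries.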

You do not need to worry about the sem_seg attribute as it's not required during the training. So you can ignore that.

You encounter the RuntimeError: Default process group has not been initialized, please make sure to call init_process_group. error because you are probably trying to start distributed training on a single GPU. This issue might be helpful.

@Robotatron
Author

Robotatron commented Dec 28, 2022

Thanks for the reply @praeclarumjj3!

I've looked at the issue you linked (facebookresearch/detectron2#3972); they recommend replacing the "SyncBN" batch norm with a plain "BN". However, it seems OneFormer does not use "SyncBN"; it uses group norm, "GN".
I tried changing cfg.MODEL.SEM_SEG_HEAD.NORM from "GN" to "BN", but I still get the same RuntimeError: Default process group has not been initialized

Have you ever tried training OneFormer on a single machine with a single GPU or do you know if it is supported at all?

@praeclarumjj3
Member

Hi @Robotatron, after taking a closer look at your error, I believe the error corresponds to the contrastive_loss definition. While experimenting, we never tried training with a single GPU, so we never encountered this error. The error corresponds to the fact that we use dist in our contrastive_loss definition, which will throw an error with a single GPU.

To train on a single GPU, please replace the contrastive loss method with the code below (which checks whether the distributed process group has been initialized). I have also updated the main branch code, so you can directly pull and run the code. Thanks for bringing this to my attention.

def loss_contrastive(self, outputs, targets, indices, num_masks):
      assert "contrastive_logits" in outputs
      assert "texts" in outputs
      image_x = outputs["contrastive_logits"].float()
      
      batch_size = image_x.shape[0]
      # get label globally
      if is_dist_avail_and_initialized():
          labels = torch.arange(batch_size, dtype=torch.long, device=image_x.device) + batch_size * dist.get_rank()
      else:
          labels = torch.arange(batch_size, dtype=torch.long, device=image_x.device)

      text_x = outputs["texts"]

      # [B, C]
      image_x = F.normalize(image_x.flatten(1), dim=-1)
      text_x = F.normalize(text_x.flatten(1), dim=-1)

      if is_dist_avail_and_initialized():
          logits_per_img = image_x @ dist_collect(text_x).t()
          logits_per_text = text_x @ dist_collect(image_x).t()
      else:
          logits_per_img = image_x @ text_x.t()
          logits_per_text = text_x @ image_x.t()

      logit_scale = torch.clamp(self.logit_scale.exp(), max=100)
      loss_img = self.cross_entropy(logits_per_img * logit_scale, labels)
      loss_text = self.cross_entropy(logits_per_text * logit_scale, labels)

      loss_contrastive = loss_img + loss_text

      losses = {"loss_contrastive": loss_contrastive}
      return losses
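
To illustrate the label arithmetic in the method above, here is a small sketch in plain Python (no torch). It assumes that dist_collect concatenates the features from all processes in rank order, so the matching text for local sample i sits at column batch_size * rank + i of the logits. Without a process group there is only one process, which behaves like rank 0. The helper name global_labels is hypothetical.

```python
def global_labels(batch_size, rank=0):
    """Index of the matching text for each local image in the gathered batch."""
    offset = batch_size * rank  # rank 0 (or no process group) means no offset
    return [offset + i for i in range(batch_size)]


print(global_labels(4))          # single process: plain 0..3
print(global_labels(4, rank=2))  # third process of a distributed run
```

With a single GPU the labels reduce to 0..batch_size-1, which is what the else branch of the patched method computes.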

@Robotatron
Author

@praeclarumjj3 Thanks for updating the repo, the training now works on a single GPU <3

Last question with regards to the "text" attribute, you said earlier:

Hi @Robotatron, you can refer to the following code to understand how to set the text attribute inside the dataset dictionary:

[code snippet quoted from the earlier comment]

All you need to do is to loop through the class ids and collect the corresponding class names.

1. Is dataset_dict["text"] used for training?

I am not really sure what the purpose of this is; it seems to be a list of strings of the form "a photo with a {cls_name}".

I took your code for setting up the "texts" attribute and tried to adjust it for the instance segmentation Mask2Former dataset mapper. I am stuck at the following two lines in your code (196 and 197): https://github.com/SHI-Labs/OneFormer/blob/main/oneformer/data/dataset_mappers/coco_unified_new_baseline_dataset_mapper.py#L196

image

A)

There is no segments_info in my instance segmentation COCO dataset_dict, so I've used annos from the Mask2Former dataset mapper:

        annos = [
            utils.transform_instance_annotations(obj, transforms, image.shape[:2])
            for obj in dataset_dict.pop("annotations")
            if obj.get("iscrowd", 0) == 0
        ]

B)

But then I was not sure what mask to use. I don't have a pan_seg_gt variable in the instance segmentation dataset mapper, so the if statement (if not np.all(mask == False):) would not work. I therefore deleted these two lines:

        mask = pan_seg_gt == segment_info["id"]
        if not np.all(mask == False):

and the code for setting up the "text" attribute looks like this (on the right, your original code on the left):
image

2.

Can the text attribute be populated like in the image above or should I still use the "mask" variable with the if statement?
Since I don't have the segments_info in the dataset_dict I am unsure what to use instead of pan_seg_gt and segment_info["id"]. For pan_seg_gt maybe I can use instances.gt_masks instead? Or is pan_seg_gt an ID of the ground truth annotation?
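
For context on the question above: in detectron2-style panoptic pipelines, pan_seg_gt is, as far as I can tell, an H x W map of segment ids decoded from the panoptic PNG rather than a single annotation id, so pan_seg_gt == segment_info["id"] yields one boolean mask per segment. With instance annotations each object already has its own mask, which would play the same role. The sketch below shows the comparison and the emptiness check on plain nested lists; the helper names and the tiny id map are hypothetical.

```python
def segment_mask(id_map, segment_id):
    """Boolean mask selecting the pixels that belong to one segment id."""
    return [[pixel == segment_id for pixel in row] for row in id_map]


def is_empty(mask):
    """True when no pixel is set, e.g. after a crop removed the object."""
    return not any(any(row) for row in mask)


# Hypothetical 2x3 id map with two segments (ids 5 and 9).
pan_seg_gt = [[5, 5, 9],
              [5, 9, 9]]
mask = segment_mask(pan_seg_gt, 9)
print(mask)
print(is_empty(segment_mask(pan_seg_gt, 42)))
```

The original if statement only skips segments whose mask ended up empty, so an equivalent check for an instance mapper would be to drop instances whose own mask has no pixels left after augmentation.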

@scotwilli

@Robotatron, were you able to train OneFormer for instance segmentation?

@Robotatron
Author

@Robotatron, were you able to train OneFormer for instance segmentation?

Yes it works and trains, maybe I should submit a PR with a modified Mask2Former data mapper for instance segmentation if @praeclarumjj3 is OK with it

@praeclarumjj3
Member

Hi @Robotatron, thanks for the offer. As many people have been opening issues regarding the custom training, I will push some custom dataset_mappers for everyone's reference with documentation.

@praeclarumjj3
Member

praeclarumjj3 commented Jan 3, 2023

Hi @Robotatron, yes, we use dataset_dict["text"] during training to obtain the text_queries, which are used when calculating the query-text contrastive loss. Please refer to Fig. 3 and Sec. 3.1 in our paper for more details.

Also, you may use the following script to train OneFormer on a custom instance segmentation dataset. Remember you will need to register your dataset's metadata with the class names.

import copy
import logging

import numpy as np
import torch

from detectron2.data import MetadataCatalog
from detectron2.config import configurable
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from oneformer.data.tokenizer import SimpleTokenizer, Tokenize
from pycocotools import mask as coco_mask

__all__ = ["InstanceCOCOCustomNewBaselineDatasetMapper"]


def convert_coco_poly_to_mask(segmentations, height, width):
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        mask = mask.any(dim=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, dim=0)
    else:
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks


def build_transform_gen(cfg, is_train):
    """
    Create a list of default :class:`Augmentation` from config.
    Now it includes resizing and flipping.
    Returns:
        list[Augmentation]
    """
    assert is_train, "Only support training augmentation"
    image_size = cfg.INPUT.IMAGE_SIZE
    min_scale = cfg.INPUT.MIN_SCALE
    max_scale = cfg.INPUT.MAX_SCALE

    augmentation = []

    if cfg.INPUT.RANDOM_FLIP != "none":
        augmentation.append(
            T.RandomFlip(
                horizontal=cfg.INPUT.RANDOM_FLIP == "horizontal",
                vertical=cfg.INPUT.RANDOM_FLIP == "vertical",
            )
        )

    augmentation.extend([
        T.ResizeScale(
            min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size
        ),
        T.FixedSizeCrop(crop_size=(image_size, image_size)),
    ])

    return augmentation


# This is specifically designed for the COCO Instance Segmentation dataset.
class InstanceCOCOCustomNewBaselineDatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and map it into a format used by OneFormer for custom instance segmentation using COCO format.

    The callable currently does the following:

    1. Read the image from "file_name"
    2. Applies geometric transforms to the image and annotation
    3. Find and applies suitable cropping to the image and annotation
    4. Prepare image and annotation to Tensors
    """

    @configurable
    def __init__(
        self,
        is_train=True,
        *,
        num_queries,
        tfm_gens,
        meta,
        image_format,
        max_seq_len,
        task_seq_len,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            is_train: for training or inference
            num_queries: number of text entries generated per image
            tfm_gens: data augmentation
            meta: dataset metadata, as returned by MetadataCatalog.get
            image_format: an image format supported by :func:`detection_utils.read_image`.
            max_seq_len: maximum sequence length for the text tokenizer
            task_seq_len: maximum sequence length for the task tokenizer
        """
        self.tfm_gens = tfm_gens
        logging.getLogger(__name__).info(
            "[InstanceCOCOCustomNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(
                str(self.tfm_gens)
            )
        )

        self.img_format = image_format
        self.is_train = is_train
        self.meta = meta
        self.ignore_label = self.meta.ignore_label
        self.num_queries = num_queries

        self.things = []
        for k,v in self.meta.thing_dataset_id_to_contiguous_id.items():
            self.things.append(v)
        self.class_names = self.meta.thing_classes
        self.text_tokenizer = Tokenize(SimpleTokenizer(), max_seq_len=max_seq_len)
        self.task_tokenizer = Tokenize(SimpleTokenizer(), max_seq_len=task_seq_len)

    @classmethod
    def from_config(cls, cfg, is_train=True):
        # Build augmentation
        tfm_gens = build_transform_gen(cfg, is_train)
        dataset_names = cfg.DATASETS.TRAIN
        meta = MetadataCatalog.get(dataset_names[0])

        ret = {
            "is_train": is_train,
            "meta": meta,
            "tfm_gens": tfm_gens,
            "image_format": cfg.INPUT.FORMAT,
            "num_queries": cfg.MODEL.ONE_FORMER.NUM_OBJECT_QUERIES - cfg.MODEL.TEXT_ENCODER.N_CTX,
            "task_seq_len": cfg.INPUT.TASK_SEQ_LEN,
            "max_seq_len": cfg.INPUT.MAX_SEQ_LEN,
        }
        return ret
    
    def _get_texts(self, classes, num_class_obj):
        
        classes = list(np.array(classes))
        texts = ["an instance photo"] * self.num_queries
        
        for class_id in classes:
            cls_name = self.class_names[class_id]
            num_class_obj[cls_name] += 1
        
        num = 0
        for i, cls_name in enumerate(self.class_names):
            if num_class_obj[cls_name] > 0:
                for _ in range(num_class_obj[cls_name]):
                    if num >= len(texts):
                        break
                    texts[num] = f"a photo with a {cls_name}"
                    num += 1

        return texts

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        # TODO: get padding mask
        # by feeding a "segmentation mask" to the same transforms
        padding_mask = np.ones(image.shape[:2])

        image, transforms = T.apply_transform_gens(self.tfm_gens, image)
        # the crop transformation has default padding value 0 for segmentation
        padding_mask = transforms.apply_segmentation(padding_mask)
        padding_mask = ~ padding_mask.astype(bool)

        image_shape = image.shape[:2]  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
        dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask))

        if not self.is_train:
            # USER: Modify this if you want to keep them for some reason.
            dataset_dict.pop("annotations", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            # USER: Modify this if you want to keep them for some reason.
            for anno in dataset_dict["annotations"]:
                anno.pop("keypoints", None)

            # USER: Implement additional transformations if you have other types of data
            annos = [
                utils.transform_instance_annotations(obj, transforms, image_shape)
                for obj in dataset_dict.pop("annotations")
                if obj.get("iscrowd", 0) == 0
            ]

            instances = utils.annotations_to_instances(annos, image_shape)
        
            instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
            # Need to filter empty instances first (due to augmentation)
            instances = utils.filter_empty_instances(instances)
            # Generate masks from polygon
            h, w = instances.image_size
            # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float)
            if hasattr(instances, 'gt_masks'):
                gt_masks = instances.gt_masks
                gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w)
                instances.gt_masks = gt_masks
            dataset_dict["instances"] = instances

        num_class_obj = {}
        for name in self.class_names:
            num_class_obj[name] = 0

        task = "The task is instance"
        text = self._get_texts(instances.gt_classes, num_class_obj)

        dataset_dict["instances"] = instances
        dataset_dict["orig_shape"] = image_shape
        dataset_dict["task"] = task
        dataset_dict["text"] = text
        dataset_dict["thing_ids"] = self.things

        return dataset_dict
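
Before starting a long run, it may help to check that a mapper output carries the keys set at the end of the script above, and that "text" is a list rather than a single string. The following is a minimal sketch in plain Python; the helper name check_sample and the placeholder values are hypothetical, and the expected length of "text" is assumed to equal the mapper's num_queries.

```python
REQUIRED_KEYS = ("image", "instances", "orig_shape", "task", "text", "thing_ids")


def check_sample(sample, num_queries=None):
    """Return a list of problems found in one mapper output (empty list = looks fine)."""
    problems = [f"missing key: {key}" for key in REQUIRED_KEYS if key not in sample]
    text = sample.get("text")
    if text is not None:
        if not isinstance(text, list):
            problems.append("text should be a list of strings, one per text query")
        elif num_queries is not None and len(text) != num_queries:
            problems.append(f"expected {num_queries} text entries, got {len(text)}")
    return problems


# Placeholder values stand in for the real tensors and Instances object.
good = {
    "image": None,
    "instances": None,
    "orig_shape": (4, 4),
    "task": "The task is instance",
    "text": ["an instance photo"] * 3,
    "thing_ids": [0, 1],
}
bad = dict(good, text="what do I put here")
del bad["thing_ids"]
print(check_sample(good, num_queries=3))
print(check_sample(bad))
```

Calling it on one sample drawn from the training data loader should return an empty list when the mapper is wired up as intended.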

@praeclarumjj3
Member

I pushed some instructions for training with custom datasets. You may take a look if you face any more issues: https://github.com/SHI-Labs/OneFormer/tree/main/datasets/custom_datasets.

@Robotatron
Author

What a legend, thanks! Will check it out when I have time later, thanks!

@sushilkhadkaanon

Hi @Robotatron, could you please help me train the model on a custom dataset? Any help will be appreciated. Thanks!
