diff --git a/test/test_datasets.py b/test/test_datasets.py index 26064a11c71..aa100aa55c1 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -827,6 +827,11 @@ def test_transforms_v2_wrapper_spawn(self): with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _): datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) + def test_slice_error(self): + with self.create_dataset() as (dataset, _): + with pytest.raises(ValueError, match="Index must be of type integer"): + dataset[:2] + class CocoCaptionsTestCase(CocoDetectionTestCase): DATASET_CLASS = datasets.CocoCaptions diff --git a/torchvision/datasets/coco.py b/torchvision/datasets/coco.py index 7fda9667938..fd374a8ec24 100644 --- a/torchvision/datasets/coco.py +++ b/torchvision/datasets/coco.py @@ -44,6 +44,10 @@ def _load_target(self, id: int) -> List[Any]: return self.coco.loadAnns(self.coco.getAnnIds(id)) def __getitem__(self, index: int) -> Tuple[Any, Any]: + + if not isinstance(index, int): + raise ValueError(f"Index must be of type integer, got {type(index)} instead.") + id = self.ids[index] image = self._load_image(id) target = self._load_target(id)