From f477aa8b760f9d6ab7d913dcd0dcc52129de9ac6 Mon Sep 17 00:00:00 2001 From: Federico Galatolo Date: Sat, 12 Oct 2019 11:37:46 +0200 Subject: [PATCH] Added check for torchvision --- sampler.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/sampler.py b/sampler.py index bf3a543..b324b5a 100644 --- a/sampler.py +++ b/sampler.py @@ -1,5 +1,9 @@ import torch -import torchvision +is_torchvision_installed = True +try: + import torchvision +except: + is_torchvision_installed = False import torch.utils.data import random @@ -38,9 +42,9 @@ def _get_label(self, dataset, idx, labels = None): else: # Trying guessing dataset_type = type(dataset) - if dataset_type is torchvision.datasets.MNIST: + if is_torchvision_installed and dataset_type is torchvision.datasets.MNIST: return dataset.train_labels[idx].item() - elif dataset_type is torchvision.datasets.ImageFolder: + elif is_torchvision_installed and dataset_type is torchvision.datasets.ImageFolder: return dataset.imgs[idx][1] else: raise Exception("You should pass the tensor of labels to the constructor as second argument")