Skip to content

Commit

Permalink
Added check for torchvision
Browse files Browse the repository at this point in the history
  • Loading branch information
galatolofederico committed Oct 12, 2019
1 parent 5c57146 commit f477aa8
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions sampler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import torch
import torchvision
is_torchvision_installed = True
try:
import torchvision
except:
is_torchvision_installed = False
import torch.utils.data
import random

Expand Down Expand Up @@ -38,9 +42,9 @@ def _get_label(self, dataset, idx, labels = None):
else:
# Trying guessing
dataset_type = type(dataset)
if dataset_type is torchvision.datasets.MNIST:
if is_torchvision_installed and dataset_type is torchvision.datasets.MNIST:
return dataset.train_labels[idx].item()
elif dataset_type is torchvision.datasets.ImageFolder:
elif is_torchvision_installed and dataset_type is torchvision.datasets.ImageFolder:
return dataset.imgs[idx][1]
else:
raise Exception("You should pass the tensor of labels to the constructor as second argument")
Expand Down

0 comments on commit f477aa8

Please sign in to comment.