From 5c5714668113c41c91713aea200f44e18eaa6ba1 Mon Sep 17 00:00:00 2001
From: Federico Galatolo
Date: Wed, 10 Jul 2019 18:45:20 +0200
Subject: [PATCH] Using indices instead of pop() for better performance and
 multiple-epoch support

---
 example.py | 6 ++++--
 sampler.py | 8 +++++---
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/example.py b/example.py
index 3f34307..564dd59 100644
--- a/example.py
+++ b/example.py
@@ -1,6 +1,7 @@
 import torch
 from sampler import BalancedBatchSampler
 
+epochs = 3
 size = 20
 features = 5
 classes_prob = torch.tensor([0.1, 0.4, 0.5])
@@ -12,5 +13,6 @@
 
 train_loader = torch.utils.data.DataLoader(dataset, sampler=BalancedBatchSampler(dataset, dataset_Y), batch_size=6)
 
-for batch_x, batch_y in train_loader:
-    print("labels: %s\ninputs: %s\n" % (batch_y, batch_x))
\ No newline at end of file
+for epoch in range(0, epochs):
+    for batch_x, batch_y in train_loader:
+        print("epoch: %d labels: %s\ninputs: %s\n" % (epoch, batch_y, batch_x))
\ No newline at end of file
diff --git a/sampler.py b/sampler.py
index 28851c3..bf3a543 100644
--- a/sampler.py
+++ b/sampler.py
@@ -23,12 +23,14 @@ def __init__(self, dataset, labels=None):
             self.dataset[label].append(random.choice(self.dataset[label]))
         self.keys = list(self.dataset.keys())
         self.currentkey = 0
+        self.indices = [-1]*len(self.keys)
 
     def __iter__(self):
-        while len(self.dataset[self.keys[self.currentkey]]) > 0:
-            yield self.dataset[self.keys[self.currentkey]].pop()
+        while self.indices[self.currentkey] < self.balanced_max - 1:
+            self.indices[self.currentkey] += 1
+            yield self.dataset[self.keys[self.currentkey]][self.indices[self.currentkey]]
             self.currentkey = (self.currentkey + 1) % len(self.keys)
-
+        self.indices = [-1]*len(self.keys)
 
     def _get_label(self, dataset, idx, labels = None):
         if self.labels is not None:
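
Why pop() broke multi-epoch training: the old __iter__ consumed self.dataset in place, so after one full pass every per-class list was empty and a second iteration over the DataLoader yielded nothing. The patched version walks the same lists with one cursor per class (self.indices) and resets the cursors once a pass completes, leaving self.dataset intact. Below is a minimal, self-contained sketch of that round-robin cursor logic; the class name RoundRobinSketch and its per_class_indices argument are illustrative stand-ins, and the oversampling and label-lookup machinery of the real BalancedBatchSampler is deliberately left out.

class RoundRobinSketch:
    def __init__(self, per_class_indices):
        # per_class_indices: dict mapping label -> list of dataset indices,
        # assumed already padded so every list has length balanced_max
        self.dataset = per_class_indices
        self.balanced_max = max(len(v) for v in self.dataset.values())
        self.keys = list(self.dataset.keys())
        self.currentkey = 0
        # One cursor per class; -1 means nothing consumed yet
        self.indices = [-1] * len(self.keys)

    def __iter__(self):
        # Checking only the current class's cursor suffices because classes
        # are visited strictly round-robin and all lists have equal length,
        # so every cursor hits balanced_max - 1 on the same lap.
        while self.indices[self.currentkey] < self.balanced_max - 1:
            self.indices[self.currentkey] += 1
            yield self.dataset[self.keys[self.currentkey]][self.indices[self.currentkey]]
            self.currentkey = (self.currentkey + 1) % len(self.keys)
        # Reset the cursors so the next epoch starts from the beginning
        self.indices = [-1] * len(self.keys)

sampler = RoundRobinSketch({"a": [0, 1], "b": [2, 3]})
print(list(sampler))  # [0, 2, 1, 3]
print(list(sampler))  # [0, 2, 1, 3] again; with pop() this pass was empty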