
torch.utils.data.DataLoader with BalancedBatchSampler results in a higher number of batches than intended? #3

Open · candasunal opened this issue May 29, 2022 · 3 comments

candasunal commented May 29, 2022

Hello Federico,

First of all thank you very much for this repo, it seems like your solution is just the one I needed.

I wanted my batches to have an equal number of samples from each class (in this case, 10 samples from each MNIST class). But when I use it as the sampler argument of torch.utils.data.DataLoader with a batch size of 100, the resulting loader has more batches than it should.

For example, the code below creates a trainloader with 675 batches instead of the expected 600 (60,000 samples / 100 per batch):

import torch
from torchvision import datasets, transforms

from sampler import BalancedBatchSampler  # the sampler from this repo

batch_size = 100

# Standard MNIST training split: 60,000 images across 10 classes.
train_MNIST = datasets.MNIST('./content/MNIST_DATA/train/',
                             train=True,
                             transform=transforms.ToTensor(),
                             download=True)

trainloader = torch.utils.data.DataLoader(train_MNIST,
                                          sampler=BalancedBatchSampler(train_MNIST),
                                          batch_size=batch_size)
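
A quick way to see the mismatch (a minimal check reusing the variables above):

import math

# Observed: iterate the balanced loader and count its batches.
print(sum(1 for _ in trainloader))               # 675

# Expected for a single pass over the dataset without oversampling.
print(math.ceil(len(train_MNIST) / batch_size))  # 600 = ceil(60000 / 100)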

(Screenshot attached: the batch count as shown in the Spyder IDE.)

Am I missing something? Shouldn't it be 600 instead of 675?

Thank you in advance.

@uzair789

Hi, did you manage to figure this out? I have the same issue. Thank you!

@uzair789

I figured it out. It's because of the imbalance in your dataset. The while loop in the sampler keeps running as long as each class's count is less than balanced_max (the number of samples in your largest class). So if balanced_max is large and the other classes have far fewer samples, the smaller classes keep being re-drawn until every class reaches balanced_max, and covering all the samples from the largest class creates the additional batches.
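
For what it's worth, the numbers line up for MNIST too: its training split is only roughly balanced (digit 1 has 6,742 samples, digit 5 only 5,421). A minimal sketch of that arithmetic, assuming the sampler draws every class up to balanced_max:

import math
from collections import Counter

from torchvision import datasets, transforms

train_MNIST = datasets.MNIST('./content/MNIST_DATA/train/', train=True,
                             transform=transforms.ToTensor(), download=True)

# Per-class sample counts in the MNIST training split.
counts = Counter(train_MNIST.targets.tolist())
balanced_max = max(counts.values())  # 6742 (digit 1)

# If every class is oversampled up to balanced_max, the sampler yields
# balanced_max * num_classes indices per epoch.
total = balanced_max * len(counts)   # 6742 * 10 = 67420
print(math.ceil(total / 100))        # 675 batches, matching the report above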

@candasunal (Author)

Hello Uzair,

I used the standard MNIST dataset for this, and that isn't unbalanced.

Since I couldn't find a solution, I moved on to something else, but I'll take another look at my implementation in light of your comment.
