
Question about domain adaptation BatchNormalization #5

Open
OutstanderWang opened this issue Nov 24, 2018 · 9 comments

@OutstanderWang

The class DomainAdaptModule seems to maintain separate batch normalization parameters for the source and target domains. However, only the _init_bn_layers function is called (in __init__); the functions that save/restore the source/target BN statistics, such as bn_save_source, are never called anywhere else.
Could you please help me understand how your code maintains separate source and target BN statistics, and how they are used in the test phase?
Thank you in advance.

@Britefury
Owner

Yes, well, kind of.
I was trying out a variety of different approaches, so the facilities of DomainAdaptModule are not in fact used at all. The initialisation method is called, but after that it is never used; it was used in the past. Sorry about that.
Okay, here is how it actually works:
During training, a batch of samples from the source domain is passed through the network. The batch norm layers compute that batch's mean+std and normalize it. Then a batch of target domain samples is passed through; the batch norm layers compute a different mean+std and normalize the target domain samples separately from the source domain samples. That is how it's done.
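
For illustration, here is a minimal sketch of that pattern; the toy network and random batches are placeholders, not code from this repo:

import torch
import torch.nn as nn

# Toy network with a batch-norm layer, standing in for the real architecture.
net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
net.train()

src_batch = torch.randn(16, 3, 32, 32)   # source-domain mini-batch
tgt_batch = torch.randn(16, 3, 32, 32)   # target-domain mini-batch

# Separate forward passes: batch norm computes the mean/std of each batch on
# its own, so source and target samples are never normalized with each
# other's statistics.
src_out = net(src_batch)
tgt_out = net(tgt_batch)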

@YilinLiu97

Hi @Britefury, I'm wondering: you end up using only the teacher model for testing data in the new domain, right? Since only the teacher model stores the mean+std of the data that we want the model to generalize to.

@Britefury
Owner

Britefury commented Feb 18, 2019

Hi @YilinLiu97, that's a good question.

Let's see. The batch-norm layers in both models maintain a running mean and variance. You noted in Issue #6 that my WeightEMA update code has a bug that results in the teacher's running mean and variance not being the EMA of those of the student model. So the teacher keeps its own running mean and variance, learned during training. The teacher model is only applied to samples from the target domain, so its mean/var will be based on target samples only. Thanks for asking this; I hadn't spotted it until now. That's interesting. I wonder if fixing the weight EMA code would degrade performance?
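
For context on how running statistics can get missed: a batch-norm layer's running mean/variance are buffers, not parameters, so an EMA update that only walks model.parameters() never touches them, whereas they do appear in state_dict() (which is presumably why the new code below iterates over state_dict()). A quick illustrative check, not a quote of the old WeightEMA code:

import torch.nn as nn

bn = nn.BatchNorm2d(8)

# Learnable affine parameters only:
print([name for name, _ in bn.named_parameters()])   # ['weight', 'bias']

# state_dict() also includes the buffers holding the running statistics:
print(list(bn.state_dict().keys()))
# ['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']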

My new code that I use in my current experiments is now this:

class EMAWeightOptimizer (object):
    def __init__(self, target_net, source_net, ema_alpha):
        self.target_net = target_net
        self.source_net = source_net
        self.ema_alpha = ema_alpha
        # Track *all* state_dict entries, so batch-norm running statistics
        # (buffers) are included as well as the learnable parameters.
        self.target_params = list(target_net.state_dict().values())
        self.source_params = list(source_net.state_dict().values())

        # Initialise the target (teacher) network as an exact copy of the source (student).
        for tgt_p, src_p in zip(self.target_params, self.source_params):
            tgt_p[:] = src_p[:]

        target_keys = set(target_net.state_dict().keys())
        source_keys = set(source_net.state_dict().keys())
        if target_keys != source_keys:
            raise ValueError('Source and target networks do not have the same state dict keys; do they have different architectures?')

    def step(self):
        # In-place EMA update: target = alpha * target + (1 - alpha) * source
        one_minus_alpha = 1.0 - self.ema_alpha
        for tgt_p, src_p in zip(self.target_params, self.source_params):
            tgt_p.mul_(self.ema_alpha)
            tgt_p.add_(src_p * one_minus_alpha)
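
For reference, a minimal usage sketch of the class above; the toy network, optimizer and loss are placeholders rather than code from these experiments:

import copy
import torch
import torch.nn as nn

# Hypothetical student/teacher pair built from the same (toy) architecture.
student_net = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2))
teacher_net = copy.deepcopy(student_net)

teacher_ema = EMAWeightOptimizer(teacher_net, student_net, ema_alpha=0.99)
student_opt = torch.optim.Adam(student_net.parameters(), lr=1e-3)

x = torch.randn(32, 10)
loss = student_net(x).pow(2).mean()    # placeholder loss, just to drive one update
student_opt.zero_grad()
loss.backward()
student_opt.step()

teacher_ema.step()    # moves the teacher's state_dict towards the student's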

@YilinLiu97

@Britefury, I tried it in my own experiments (with deep-copied weights at the beginning) and the EMA model doesn't do well on target data (!?). It behaves like a freshly initialized model, which is really weird. This happens even with alpha=0, and I checked that the EMA model has the same weights as the student model; the only difference is the running_mean/std, which is expected. I expected the teacher model to do well at least in the target domain with its own running_mean/std, but it didn't.

I think that the initialization of the EMA shouldn't matter that much as long as alpha slowly ramps up, just like the consistency loss (so that the teacher model quickly forgets the early weights), and this is what I've done in my own experiments, but the results were as described above. Since alpha was also set to 0.99 in the original implementation of Mean Teacher (and they didn't seem to deep-copy the weights at the beginning), I do wonder how the teacher model could lead to good performance.
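
For reference, the ramp-up I have in mind is along these lines (a minimal sketch; to the best of my knowledge the Mean Teacher reference implementation caps alpha with a similar step-dependent term, and alpha_max = 0.99 here just matches the value mentioned above):

def ramped_alpha(step, alpha_max=0.99):
    # Early in training the cap 1 - 1/(step + 1) is small, so the teacher
    # closely tracks the student and quickly forgets its initial weights;
    # once the cap exceeds alpha_max, the EMA settles at the configured value.
    return min(1.0 - 1.0 / (step + 1), alpha_max)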

@YilinLiu97

Hi @Britefury, I'm wondering: is the EMA working well in your experiments? I mean, using the EMA model only for testing, on either source or target data. Thanks!

@Britefury
Owner

The only experiments I have tried use the buggy EMA implementation that is in this codebase, so it may be that fixing it breaks things. The new 'fixed' EMA code I pasted above is used in other, newer experiments, but not these. I suppose I need to do a comparison!

@Britefury
Owner

Okay. I've made a modified version of the code (not putting it online yet) and I'm going to compare the EMA implementations and see what difference I get. I should have an answer in a few days; it's running on an old GPU :)

@YilinLiu97

Thank you! Looking forward to the results! :)

@Britefury
Owner

@YilinLiu97 I have run the experiments, and using the fixed EMA makes no statistically significant difference as far as I can tell, at least on the MNIST -> SVHN experiment.

I re-ran the baseline just in case running it on a machine with a smaller GPU, less RAM and a smaller batch size made a difference.

- Baseline accuracy: 96.998% +/- 0.056%
- Baseline re-run: 96.941% +/- 0.103%; re-running the baseline is pretty much consistent
- With fixed EMA: 96.918% +/- 0.039%; a very slight drop, but not a significant one
