This repository contains the official PyTorch implementation of the paper "Self-Supervised Learning Disentangled Group Representation as Feature".
Self-Supervised Learning Disentangled Group Representation as Feature
Tan Wang, Zhongqi Yue, Jianqiang Huang, Qianru Sun, Hanwang Zhang
Conference and Workshop on Neural Information Processing Systems (NeurIPS), 2021 (Spotlight)
[Paper] [Poster] [Slides] [Zhihu]
- Python 3.7
- PyTorch 1.6.0
- PIL
- OpenCV
- tqdm
- --maximize_iter: how often (in epochs) to perform the maximization step
- --env_num: the number of subsets (orbits)
- --constrain: whether to constrain the partition update, so that the numbers of samples in the 2 subsets do not differ too much
- --retain_group: whether to retain the previous partitions
- --penalty_weight: the penalty (IRM loss) weight
- --irm_weight_maxim: the IRM loss weight in partition maximization
- --keep_cont: keep the standard SSL loss as the first partition
- --offline: whether to update the partition offline (i.e., first extract the features, then optimize the partition)
- --mixup_max: whether to use mixup in the maximization step (we find this option usually gives slightly better results but takes more time)
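A minimal sketch of how --maximize_iter, --retain_group and --keep_cont could interact over a training run (e.g. 400 epochs with --maximize_iter 50, as in the STL command). It assumes that a maximization step runs after every maximize_iter epochs except the last one, that --retain_group keeps all previously learned partitions instead of replacing them, and that --keep_cont adds the standard SSL loss as an always-present first entry. partition_update_epochs and partition_pool are hypothetical helpers, not functions from this repository; --random_init and --constrain_relax are not modelled.

```python
from typing import List


def partition_update_epochs(epochs: int, maximize_iter: int) -> List[int]:
    # Assumption: maximize after every `maximize_iter` epochs, but not after
    # the final epoch, since no further training would use the new partition.
    return list(range(maximize_iter, epochs, maximize_iter))


def partition_pool(epochs: int, maximize_iter: int,
                   retain_group: bool = True, keep_cont: bool = True) -> List[str]:
    # Labels of the partitions the minimization step would use at the end.
    pool = ["standard SSL loss"] if keep_cont else []
    learned: List[str] = []
    for epoch in partition_update_epochs(epochs, maximize_iter):
        new = "partition after epoch {}".format(epoch)
        # Retaining keeps earlier partitions; otherwise the new one replaces them.
        learned = (learned + [new]) if retain_group else [new]
    return pool + learned


print(partition_update_epochs(400, 50))
# [50, 100, 150, 200, 250, 300, 350]
```

Under these assumptions, a 400-epoch run with retained groups ends with the standard SSL loss plus 7 learned partitions.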
- Minimization Step: train_env() in main.py; train_env_mixup_full_retaingp() in main_mixup.py
- Maximization Step: auto_split_offline() / auto_split() in utils.py; auto_split_online_mixup() / auto_split_offline_mixup() in utils_mixup.py
- Soft Contrastive Loss: to compute the contrastive loss while the partition is being updated in the maximization step, we also change the contrastive loss into a soft version: soft_contrastive_loss() in utils.py; soft_contrastive_loss_mixup_online() / soft_contrastive_loss_mixup_offline() in utils_mixup.py
- Partition: updated_split in the code (entries follow the order of the dataset)
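The pieces listed above fit together roughly as follows. This is a minimal PyTorch sketch under several assumptions: a soft partition is an (N, K) matrix of subset memberships whose rows sum to 1; the soft contrastive loss of a subset weights both anchors and negatives by their membership; the IRM penalty is the IRMv1-style squared gradient with respect to a dummy scale fixed at 1.0; and the maximization step ascends the subset losses plus irm_weight_maxim times the penalty, with a simple size-balance term standing in for --constrain. soft_info_nce, irm_penalty, minimization_loss, maximize_partition, balance_weight, steps, lr and seed are hypothetical names. This is not the code of soft_contrastive_loss(), train_env() or auto_split_offline(), whose exact formulation may differ; the sketch also assumes CPU tensors.

```python
import torch
import torch.nn.functional as F


def soft_info_nce(z1, z2, w, temperature=0.5, scale=None):
    """Soft-subset InfoNCE. z1, z2: (N, D) features of two augmented views;
    w: (N,) soft membership of each sample in one subset, values in [0, 1]."""
    z1, z2 = F.normalize(z1, dim=1), F.normalize(z2, dim=1)
    logits = z1 @ z2.t() / temperature
    if scale is not None:  # dummy classifier scale used by the IRM penalty
        logits = logits * scale
    n = logits.size(0)
    eye = torch.eye(n, dtype=torch.bool, device=logits.device)
    # Negative j enters the denominator with weight w_j; the positive pair
    # (diagonal) keeps weight 1. With w == 1 this is plain InfoNCE.
    log_w = w.clamp_min(1e-12).log().unsqueeze(0).expand(n, n)
    denom = torch.logsumexp(torch.where(eye, logits, logits + log_w), dim=1)
    per_sample = denom - logits.diag()
    # Anchors are also weighted by their membership in the subset.
    return (w * per_sample).sum() / w.sum().clamp_min(1e-12)


def irm_penalty(z1, z2, w, temperature=0.5):
    """IRMv1-style penalty: squared gradient of the subset loss w.r.t. a
    dummy scale fixed at 1.0 (create_graph keeps it differentiable)."""
    scale = torch.ones((), device=z1.device, requires_grad=True)
    loss = soft_info_nce(z1, z2, w, temperature, scale)
    grad, = torch.autograd.grad(loss, scale, create_graph=True)
    return grad.pow(2)


def minimization_loss(z1, z2, partitions, penalty_weight=0.2, temperature=0.5,
                      keep_cont=True):
    """Encoder objective over a list of (N, K) soft partitions."""
    total = z1.new_zeros(())
    if keep_cont:  # standard SSL loss on the whole batch
        total = total + soft_info_nce(z1, z2, z1.new_ones(z1.size(0)), temperature)
    for p in partitions:
        for k in range(p.size(1)):
            w = p[:, k]
            total = (total + soft_info_nce(z1, z2, w, temperature)
                     + penalty_weight * irm_penalty(z1, z2, w, temperature))
    return total


def maximize_partition(z1, z2, env_num=2, irm_weight_maxim=0.5, temperature=0.5,
                       constrain=True, balance_weight=1.0, steps=100, lr=0.1, seed=0):
    """Offline-style partition update: features are frozen (detached) and only
    the soft partition logits are optimised by gradient ascent."""
    z1, z2 = z1.detach(), z2.detach()
    n = z1.size(0)
    gen = torch.Generator().manual_seed(seed)
    part_logits = torch.randn(n, env_num, generator=gen).requires_grad_()
    opt = torch.optim.Adam([part_logits], lr=lr)
    for _ in range(steps):
        p = part_logits.softmax(dim=1)  # rows sum to 1
        objective = z1.new_zeros(())
        for k in range(env_num):
            w = p[:, k]
            objective = (objective + soft_info_nce(z1, z2, w, temperature)
                         + irm_weight_maxim * irm_penalty(z1, z2, w, temperature))
        if constrain:  # keep each subset's share of samples near 1 / env_num
            share = p.sum(dim=0) / n
            objective = objective - balance_weight * (share - 1.0 / env_num).pow(2).sum()
        opt.zero_grad()
        (-objective).backward()  # ascend the objective
        opt.step()
    soft = part_logits.detach().softmax(dim=1)
    return soft, soft.argmax(dim=1)
```

With all memberships equal to 1, soft_info_nce reduces to the ordinary InfoNCE loss, which is a quick sanity check. The hard assignment returned by maximize_partition is only a rough analogue of updated_split.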
- Training IP-IRM on the STL dataset for 400 epochs, updating the partition every 50 epochs
CUDA_VISIBLE_DEVICES=0,1 python main.py --penalty_weight 0.2 --irm_weight_maxim 0.5 --maximize_iter 50 --random_init --constrain --constrain_relax --dataset STL --epochs 400 --offline --keep_cont --retain_group --name IPIRM_STL_epoch400
- Linear evaluation
CUDA_VISIBLE_DEVICES=0,1 python linear.py --model_path results/STL/IPIRM_STL_epoch400/model_400.pth --dataset STL --txt --name IPIRM_STL_epoch400
- You can also directly run the .sh files in the runsh directory
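Linear evaluation conventionally freezes the pretrained encoder and trains only a linear classifier on its features. A minimal sketch of that protocol, assuming a generic encoder module that maps inputs to feature vectors. LinearProbe and train_step are hypothetical names, and this is not the code in linear.py: checkpoint loading via --model_path, data loading and the optimizer settings are omitted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearProbe(nn.Module):
    """Frozen pretrained encoder followed by a trainable linear classifier."""

    def __init__(self, encoder: nn.Module, feature_dim: int, num_classes: int):
        super().__init__()
        self.encoder = encoder
        for param in self.encoder.parameters():
            param.requires_grad = False  # only the classifier is trained
        self.fc = nn.Linear(feature_dim, num_classes)
        self.encoder.eval()

    def train(self, mode: bool = True):
        # Keep the encoder in eval mode (e.g. frozen BatchNorm statistics)
        # even when the probe as a whole is switched to training mode.
        super().train(mode)
        self.encoder.eval()
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            feats = self.encoder(x)
        return self.fc(feats)


def train_step(probe: LinearProbe, x: torch.Tensor, y: torch.Tensor, optimizer) -> float:
    # One supervised step on the classifier; the encoder receives no updates.
    probe.train()
    optimizer.zero_grad()
    loss = F.cross_entropy(probe(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()
```

The optimizer should receive only probe.fc.parameters(); the requires_grad flags and torch.no_grad() keep the encoder fixed in any case.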
Model | Epoch | penalty_weight | irm_weight_maxim | Temperature | Arch | Latent Dim | Batch Size | Accuracy (%) | Download
---|---|---|---|---|---|---|---|---|---
IP-IRM | 400 | 0.2 | 0.5 | 0.5 | ResNet50 | 128 | 256 | 84.44 | model
IP-IRM+MixUp | 400 | 0.2 | 0.5 | 0.2 | ResNet50 | 128 | 256 | 88.26 | model
IP-IRM+MixUp (1000 epochs) | 1000 | 0.2 | 0.5 | 0.2 | ResNet50 | 128 | 256 | 90.59 | model
Here we share some of our experience in improving IP-IRM, which may offer some insights (and future directions).
- Although we provide a theoretical proof for IP-IRM (see the Appendix), the optimization process is still tricky. For example: when should the maximization step be performed? For how many epochs should it be trained? How do we decide when a step has converged? Many such questions remain open for further exploration.
- In practice we made some compromises to save training time, and these can be improved. For example, offline training for the maximization step is one such compromise; in mixup training, controlling the length of the partition set is another.
- Reformulate the maximization process as a kind of reinforcement learning (RL)? (more intuitive)
- Adopt IP-IRM in other SSL methods.
- The spirit of IP-IRM (i.e., data partitioning) can also be applied to other tasks and even other domains (e.g., see our ICCV 2021 paper on OOD generalization).
If you find our code helpful, please cite our papers:
@inproceedings{wang2021self,
title={Self-Supervised Learning Disentangled Group Representation as Feature},
author={Wang, Tan and Yue, Zhongqi and Huang, Jianqiang and Sun, Qianru and Zhang, Hanwang},
booktitle={Conference and Workshop on Neural Information Processing Systems (NeurIPS)},
year={2021}
}
@inproceedings{wang2021causal,
title={Causal attention for unbiased visual recognition},
author={Wang, Tan and Zhou, Chang and Sun, Qianru and Zhang, Hanwang},
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
pages={3091--3100},
year={2021}
}
Part of this code is inspired by DCL.
If you have any questions, please feel free to email me ([email protected]).