The PyTorch implementation of MetaCRL, described in the IJCAI 2024 paper "Hacking Task Confounder in Meta-Learning".
- Version 1.0: Runnable example; the miniImagenet dataset is supported.
- Version 2.0: Rewrote the meta learner and base learner; fixed several serious bugs from Version 1.0.
- Version 3.0: Implemented multiple model basic units in models that can be replaced as needed.
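One way to picture replaceable "basic units" is a small registry pattern. The sketch below is hypothetical and illustrative only; `MODEL_UNITS`, `register`, `build_unit`, and the encoder classes are not the repository's actual API — see the classes under `models` for the real units.

```python
# Hypothetical sketch of replaceable model units (illustrative only --
# not the repository's actual API): a registry maps names to classes so
# a backbone can be swapped without touching the training loop.
MODEL_UNITS = {}

def register(name):
    """Decorator that records a model unit under a lookup name."""
    def decorator(cls):
        MODEL_UNITS[name] = cls
        return cls
    return decorator

@register("conv4")
class Conv4Encoder:
    """Stand-in for a 4-layer conv backbone."""
    def describe(self):
        return "conv4 encoder"

@register("resnet12")
class ResNet12Encoder:
    """Stand-in for a ResNet-12 backbone."""
    def describe(self):
        return "resnet12 encoder"

def build_unit(name, **kwargs):
    """Instantiate a registered unit by name."""
    return MODEL_UNITS[name](**kwargs)
```

With a registry like this, swapping the backbone is a one-line change, e.g. `build_unit("resnet12")` instead of `build_unit("conv4")`.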
This method can be plugged into any meta-learning framework:
- For more frameworks and datasets, please visit HERE.
- For more sampling strategies and settings, please visit HERE.
Meta-learning enables rapid generalization to new tasks by learning knowledge from various tasks. It is intuitively assumed that as training progresses, a model will acquire richer knowledge, leading to better generalization performance. However, our experiments reveal an unexpected result: there is negative knowledge transfer between tasks, which harms generalization performance. To explain this phenomenon, we construct Structural Causal Models (SCMs) for causal analysis. Our investigation uncovers spurious correlations between task-specific causal factors and labels in meta-learning. Furthermore, the confounding factors differ across batches. We refer to these confounding factors as "Task Confounders". Based on these findings, we propose a plug-and-play Meta-learning Causal Representation Learner (MetaCRL) to eliminate task confounders. It encodes decoupled generating factors from multiple tasks and utilizes an invariant-based bi-level optimization mechanism to ensure their causality for meta-learning.
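To make the bi-level (inner/outer) meta-learning setting concrete, here is a toy, self-contained sketch of a first-order inner/outer loop on scalar tasks. This is generic MAML-style scaffolding, not MetaCRL itself: the decoupled causal factors and invariance-based regularization from the paper are deliberately omitted, and all names are illustrative.

```python
# Toy bi-level loop: the inner loop adapts to each task, the outer loop
# updates the shared initialization theta. NOT the MetaCRL objective --
# just the generic meta-learning scaffold such a method plugs into.
def inner_update(theta, task_target, lr=0.1, steps=5):
    """Adapt theta to one task by gradient descent on (theta - target)^2."""
    for _ in range(steps):
        theta = theta - lr * 2 * (theta - task_target)
    return theta

def meta_train(tasks, theta=0.0, meta_lr=0.05, epochs=200):
    """Outer loop: average post-adaptation (first-order) gradients."""
    for _ in range(epochs):
        meta_grad = 0.0
        for target in tasks:
            adapted = inner_update(theta, target)
            meta_grad += 2 * (adapted - target)  # query-loss gradient
        theta -= meta_lr * meta_grad / len(tasks)
    return theta

theta = meta_train([-1.0, 1.0, 3.0])
# theta converges near the task mean (1.0), a sensible shared init
```

The outer update here uses the first-order approximation (gradients are not backpropagated through the inner loop), which keeps the toy readable.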
Brief overview of the meta-learning process with MetaCRL:
- Python: 3.x
- PyTorch: 0.4+
For easier use and to avoid conflicts with your existing Python setup, it is recommended to use virtualenv
to work in a virtual environment. Now, let's start:
Step 1: Install virtualenv
pip install --upgrade virtualenv
or create an environment with conda.
Step 2: Create a virtual environment and activate it:
virtualenv venv
source venv/bin/activate
Step 3: Install the requirements from requirements.txt:
pip install -r requirements.txt
For the 5-way 1-shot experiment, it allocates nearly 6 GB of GPU memory.
- Download the MiniImagenet dataset from here; the split files train/val/test.csv are provided in data/split.
- For the image split, extract it like:

miniimagenet/
├── images/
│   ├── n0210891500001298.jpg
│   ├── n0287152500001298.jpg
│   └── ...
├── test.csv
├── val.csv
└── train.csv
data/data_generator provides the Python file for the data generator.
- Modify the path in example.py:
mini = MiniImagenet('miniimagenet/', mode='train', n_way=args.n_way, k_shot=args.k_spt,
k_query=args.k_qry,
batchsz=10000, resize=args.imgsz)
...
mini_test = MiniImagenet('miniimagenet/', mode='test', n_way=args.n_way, k_shot=args.k_spt,
k_query=args.k_qry,
batchsz=100, resize=args.imgsz)
to your actual data path.
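The loader above produces N-way K-shot episodes. As a self-contained illustration of what one episode contains, here is a hypothetical helper that is independent of the repository's MiniImagenet class (all names below are illustrative):

```python
# Illustrative episode sampling: pick n_way classes, then k_shot support
# and k_query query items per class. A stand-in for what an episodic
# loader yields, not the repository's actual MiniImagenet code.
import random

def sample_episode(class_to_items, n_way, k_shot, k_query, rng=random):
    """Sample one N-way K-shot episode: (support set, query set)."""
    classes = rng.sample(sorted(class_to_items), n_way)
    support, query = [], []
    for label, cls in enumerate(classes):
        items = rng.sample(class_to_items[cls], k_shot + k_query)
        support += [(x, label) for x in items[:k_shot]]
        query += [(x, label) for x in items[k_shot:]]
    return support, query

# 5 toy classes with 20 items each: a 5-way 1-shot, 15-query episode
data = {c: [f"{c}_{i}.jpg" for i in range(20)] for c in "abcde"}
support, query = sample_episode(data, n_way=5, k_shot=1, k_query=15)
# len(support) == 5, len(query) == 75
```

In the real loader, `batchsz` controls how many such episodes are pre-sampled per epoch, and `resize` sets the image resolution.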
We provide the example for training on miniImagenet:
python example.py
More examples will be provided after the paper is published.
If you find our work and code useful, please consider citing our paper and starring our repository (Thanks!):
@misc{wang2024hacking,
title={Hacking Task Confounder in Meta-Learning},
author={Jingyao Wang and Yi Ren and Zeen Song and Jianqi Zhang and Changwen Zheng and Wenwen Qiang},
year={2024},
eprint={2312.05771},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
(arXiv version; the final version will be updated after the paper is published.)