Official implementation of the FAMO optimizer for multitask learning (NeurIPS 2023). [Paper]
One of the grand enduring goals of AI is to create generalist agents that can learn multiple different tasks from diverse data via multitask learning (MTL). However, gradient descent (GD) on the average loss across all tasks may yield poor multitask performance due to severe under-optimization of certain tasks. Previous approaches that manipulate task gradients for a more balanced loss decrease require storing and computing all task gradients (
update 2023/11/11: FAMO has been used in MFTCoder to boost the fine-tuning performance of large language models!
Top left: The loss landscape, and individual task losses of a toy 2-task learning problem (★ represents the minimum of task losses). Top right: the runtime of different MTL methods for 50000 steps. Bottom: the loss trajectories of different MTL methods. ADAM fails in 1 out of 5 runs to reach the Pareto front due to CG. FAMO decreases task losses in a balanced way and is the only method matching the
For the convenience of potential users of FAMO, we provide a simple example in famo.py so that users can easily adapt FAMO to their applications. The code requires torch, which can be installed via the setup in the next section. Check the file and simply run:
python famo.py
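As orientation before reading famo.py, here is a minimal sketch of the per-task weighting idea, based on our reading of the update described in the paper: each task gets logits that pass through a softmax, and task i's gradient is scaled by c * z_i / loss_i, where c normalizes the weights to sum to 1. This is an illustration under assumptions, not the code in famo.py. The helper name `famo_style_weights` is hypothetical, losses are assumed strictly positive, and plain Python is used instead of torch so the snippet runs anywhere.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def famo_style_weights(task_logits, task_losses):
    """Hypothetical helper (not part of the repo's API).

    Returns weights w_i = c * z_i / loss_i with z = softmax(task_logits)
    and c chosen so that the weights sum to 1.
    """
    if any(loss <= 0 for loss in task_losses):
        raise ValueError("this sketch assumes strictly positive losses")
    z = softmax(task_logits)
    ratios = [z_i / loss_i for z_i, loss_i in zip(z, task_losses)]
    c = 1.0 / sum(ratios)
    return [c * r for r in ratios]


# Two tasks with equal logits: the task with the smaller loss gets the
# larger weight, which evens out the relative (log-scale) loss decrease.
weights = famo_style_weights([0.0, 0.0], [1.0, 2.0])
print(weights)
```

In this sketch, scaling a task's gradient by z_i / loss_i is the same as taking the gradient of z_i * log(loss_i), which is why tasks are compared on a relative rather than absolute scale. How the logits themselves are adapted during training is not shown here; see famo.py for the actual procedure.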
Create the conda environment and install torch
conda create -n mtl python=3.9.7
conda activate mtl
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
Install the repo:
git clone https://github.com/Cranial-XIX/FAMO.git
cd FAMO
pip install -e .
We follow the MTAN paper. The datasets can be downloaded from NYU-v2 and CityScapes. To download the CelebA dataset, please refer to this link. Each dataset should be placed under the experiments/EXP_NAME/dataset/ folder, where EXP_NAME is one of nyuv2, cityscapes, celeba. Note that quantum_chemistry will download the data automatically.
The file hierarchy should look like
FAMO
 └─ experiments
     └─ utils.py                  (for argument parsing)
     └─ nyuv2
         └─ dataset               (the dataset folder containing the MTL data)
         └─ trainer.py            (the main file to run the training)
         └─ run.sh                (the command to reproduce FAMO's results)
     └─ cityscapes
         └─ dataset               (the dataset folder containing the MTL data)
         └─ trainer.py            (the main file to run the training)
         └─ run.sh                (the command to reproduce FAMO's results)
     └─ quantum_chemistry
         └─ dataset               (the dataset folder containing the MTL data)
         └─ trainer.py            (the main file to run the training)
         └─ run.sh                (the command to reproduce FAMO's results)
     └─ celeba
         └─ dataset               (the dataset folder containing the MTL data)
         └─ trainer.py            (the main file to run the training)
         └─ run.sh                (the command to reproduce FAMO's results)
 └─ methods
     └─ weight_methods.py         (the different MTL optimizers)
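Before launching a run, it can help to confirm that the layout above is in place. The following is a small sketch of such a check using only the standard library. The function name `missing_entries` is hypothetical and not part of the repo, and the choice to skip the dataset check for quantum_chemistry is an assumption based on the note that its data is downloaded automatically.

```python
from pathlib import Path

# Experiment folders and per-experiment entries named in the hierarchy above.
EXPERIMENTS = ("nyuv2", "cityscapes", "quantum_chemistry", "celeba")
EXPECTED = ("dataset", "trainer.py", "run.sh")


def missing_entries(repo_root):
    """Return the expected paths (relative to repo_root) that do not exist."""
    missing = []
    for exp in EXPERIMENTS:
        for name in EXPECTED:
            # Assumption: quantum_chemistry fetches its data on first run,
            # so a missing dataset folder there is not reported.
            if exp == "quantum_chemistry" and name == "dataset":
                continue
            rel = Path("experiments") / exp / name
            if not (Path(repo_root) / rel).exists():
                missing.append(rel.as_posix())
    return missing


if __name__ == "__main__":
    # Run from the root of the cloned FAMO repo.
    for path in missing_entries("."):
        print("missing:", path)
```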
To run experiments, go to the relevant folder with name EXP_NAME:
cd experiments/EXP_NAME
bash run.sh
You can check run.sh for details about training with FAMO.
Following NashMTL, we also support experiment tracking with Weights & Biases with two additional parameters:
python trainer.py --method=famo --wandb_project=<project-name> --wandb_entity=<entity-name>
We support the following MTL methods with a unified API. To run an experiment with MTL method X, simply run:
python trainer.py --method=X
Method (code name) | Paper (notes) |
---|---|
FAMO (famo) | Fast Adaptive Multitask Optimization |
Nash-MTL (nashmtl) | Multi-Task Learning as a Bargaining Game |
CAGrad (cagrad) | Conflict-Averse Gradient Descent for Multi-task Learning |
PCGrad (pcgrad) | Gradient Surgery for Multi-Task Learning |
IMTL-G (imtl) | Towards Impartial Multi-task Learning |
MGDA (mgda) | Multi-Task Learning as Multi-Objective Optimization |
DWA (dwa) | End-to-End Multi-Task Learning with Attention |
Uncertainty weighting (uw) | Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics |
Linear scalarization (ls) | - (equal weighting) |
Scale-invariant baseline (scaleinvls) | - (see Nash-MTL paper for details) |
Random Loss Weighting (rlw) | A Closer Look at Loss Weighting in Multi-Task Learning |
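When sweeping over several methods, it may be convenient to build the trainer.py command line programmatically. The sketch below assembles the argument list from the code names in the table and the two Weights & Biases flags mentioned earlier. The function name `trainer_command` is hypothetical, and the sketch only builds the command; it does not run it.

```python
import shlex

# Code names taken from the table of supported methods.
METHODS = (
    "famo", "nashmtl", "cagrad", "pcgrad", "imtl", "mgda",
    "dwa", "uw", "ls", "scaleinvls", "rlw",
)


def trainer_command(method, wandb_project=None, wandb_entity=None):
    """Build the argument list for one trainer.py run (hypothetical helper)."""
    if method not in METHODS:
        raise ValueError(f"unknown method code name: {method!r}")
    args = ["python", "trainer.py", f"--method={method}"]
    # The W&B flags are optional and only added when a value is given.
    if wandb_project:
        args.append(f"--wandb_project={wandb_project}")
    if wandb_entity:
        args.append(f"--wandb_entity={wandb_entity}")
    return args


# Print one command per method, ready to paste into a shell
# from inside experiments/EXP_NAME.
for name in METHODS:
    print(shlex.join(trainer_command(name)))
```

The returned list can also be passed directly to `subprocess.run` with `cwd` set to the chosen experiment folder.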
Following CAGrad, the MTRL experiments are conducted on Metaworld benchmarks. In particular, we follow the mtrl codebase and the experiment setup in this paper.
- Install mtrl according to the instructions.
- Git clone Metaworld and change to the d9a75c451a15b0ba39d8b7a8b6d18d883b8655d8 commit (Feb 26, 2021). Install metaworld accordingly.
- Copy the mtrl_files folder under mtrl of this repo to the cloned repo of mtrl. Then
cd PATH_TO_MTRL/mtrl_files/ && chmod +x mv.sh && ./mv.sh
Then follow the run.sh script to run experiments (we are still testing the results, but the code should be runnable).
This repo is built upon CAGrad and NashMTL. If you find FAMO to be useful in your own research, please consider citing the following papers:
@misc{liu2023famo,
title={FAMO: Fast Adaptive Multitask Optimization},
author={Bo Liu and Yihao Feng and Peter Stone and Qiang Liu},
year={2023},
eprint={2306.03792},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
@article{liu2021conflict,
title={Conflict-Averse Gradient Descent for Multi-task Learning},
author={Liu, Bo and Liu, Xingchao and Jin, Xiaojie and Stone, Peter and Liu, Qiang},
journal={Advances in Neural Information Processing Systems},
volume={34},
year={2021}
}
@article{navon2022multi,
title={Multi-Task Learning as a Bargaining Game},
author={Navon, Aviv and Shamsian, Aviv and Achituve, Idan and Maron, Haggai and Kawaguchi, Kenji and Chechik, Gal and Fetaya, Ethan},
journal={arXiv preprint arXiv:2202.01017},
year={2022}
}