Paper Link: https://arxiv.org/abs/2402.05808
We recommend using a Python 3.9 environment to run the experiments. Run the following commands to set up your environment:
git clone https://github.com/xxxxx.git
conda create -n R3_math python=3.9 -y
conda activate R3_math
cd R3_math/
pip install -r requirements.txt
conda create -n R3_others python=3.9 -y
conda activate R3_others
cd ../R3_others/
pip install -r requirements.txt
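To confirm that an activated environment uses the recommended interpreter, you can run a small check like the one below. This is a hypothetical helper, not part of the repository; the name `matches_recommended` is illustrative only.

```python
"""Hypothetical helper: check that the active interpreter is the recommended Python 3.9."""
import sys
from typing import Sequence, Tuple

RECOMMENDED: Tuple[int, int] = (3, 9)


def matches_recommended(version: Sequence[int], recommended: Tuple[int, int] = RECOMMENDED) -> bool:
    """Return True if the (major, minor) prefix of `version` equals `recommended`."""
    return tuple(version[:2]) == tuple(recommended)


if __name__ == "__main__":
    # Report the running interpreter version and whether it matches 3.9.
    current = ".".join(str(part) for part in sys.version_info[:3])
    if matches_recommended(sys.version_info):
        print(f"Python {current}: OK")
    else:
        print(f"Python {current}: recommended version is 3.9")
```

Run it once inside each environment (for example, save it as `check_env.py` and call `python check_env.py`); a mismatch usually means the wrong environment is active.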
To train an SFT model, first set the model path and output path in the script R3_others/scripts/step1_supervised_finetuning/R3_sft.sh. Then, run the following commands:
cd R3_others/scripts/step1_supervised_finetuning/
bash R3_sft.sh
To train a reinforced model using R^3, first set the model path and output path in R3_math/scripts/R3_cot_gsm8k.sh, and then run the following commands:
cd R3_math/scripts/
bash R3_cot_gsm8k.sh
Note: If you want to try R^3 on other (non-math) datasets, first set the model path and output path in R3_others/scripts/step3_rlhf_finetuning/R3_mix.sh. Then, run the following commands:
cd R3_others/scripts/step3_rlhf_finetuning/
bash R3_mix.sh
This step is not required for math datasets. Training results are logged to wandb.
To evaluate model performance, first run the evaluation script R3_others/scripts/eval/eval_single.sh. Then, compute the final metrics with output_{dataset_name}.py. Here is an example for the MNLI dataset:
cd R3_others/scripts/eval
bash eval_single.sh
# after evaluation
# you will get a result file like: eval_mnli/R3_test.txt
python output_mnli.py
# then you will get the accuracy result
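The internals of output_mnli.py are not documented here. As a rough illustration of the accuracy computation, the sketch below assumes a hypothetical result-file format in which each non-empty line holds a prediction and a gold label separated by a tab; the real format of eval_mnli/R3_test.txt may differ, and `parse_pairs` and `accuracy` are illustrative names, not functions from the repository.

```python
"""Hypothetical sketch of an accuracy script in the spirit of output_mnli.py.

Assumption (not confirmed by the repository): each non-empty line of the
result file is "<prediction>\t<gold label>".
"""
import sys
from typing import Iterable, Iterator, Tuple


def parse_pairs(lines: Iterable[str]) -> Iterator[Tuple[str, str]]:
    """Yield lowercased (prediction, label) pairs, skipping blank lines."""
    for line in lines:
        line = line.strip()
        if not line:
            continue
        pred, label = line.split("\t", 1)
        yield pred.strip().lower(), label.strip().lower()


def accuracy(pairs: Iterable[Tuple[str, str]]) -> float:
    """Fraction of pairs whose prediction equals the label (0.0 if empty)."""
    total = correct = 0
    for pred, label in pairs:
        total += 1
        correct += pred == label
    return correct / total if total else 0.0


if __name__ == "__main__" and len(sys.argv) > 1:
    # Usage: python acc_sketch.py <result_file>
    with open(sys.argv[1], encoding="utf-8") as f:
        print(f"acc: {accuracy(parse_pairs(f)):.4f}")
```

For example, `python acc_sketch.py eval_mnli/R3_test.txt` would print the accuracy, provided the file matched the assumed tab-separated layout.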
For security review, we provide some example data files, organized as follows:
Dataset: MNLI
---- mnli_train_example.json # for SFT
---- mnli_mix_example.json # for R^3
---- mnli_test.json
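The schema of these example files is not described here. To take a quick look at them, a minimal sketch like the following can help; it assumes each file is either a JSON array of objects or JSON Lines (one object per line), and the helper names `load_records` and `summarize` are hypothetical.

```python
"""Sketch: summarize the example data files.

Assumption: each file is a JSON array or JSON Lines; the actual schema
used by the repository is not documented here.
"""
import json
import sys
from typing import Any, Dict, List


def load_records(text: str) -> List[Any]:
    """Parse a JSON document, falling back to JSON Lines if that fails."""
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        # Multiple top-level objects: treat each non-blank line as one record.
        return [json.loads(line) for line in text.splitlines() if line.strip()]
    return data if isinstance(data, list) else [data]


def summarize(records: List[Any]) -> Dict[str, Any]:
    """Report the record count and the sorted union of top-level keys."""
    keys = set()
    for rec in records:
        if isinstance(rec, dict):
            keys.update(rec.keys())
    return {"count": len(records), "keys": sorted(keys)}


if __name__ == "__main__" and len(sys.argv) > 1:
    # Usage: python peek.py mnli_train_example.json mnli_mix_example.json ...
    for path in sys.argv[1:]:
        with open(path, encoding="utf-8") as f:
            print(path, summarize(load_records(f.read())))
```

Listing the top-level keys per file is usually enough to see how the SFT, R^3, and test examples differ before editing the training scripts.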
If you find R^3 useful in your research, please cite our paper:
@misc{xi2024training,
title={Training Large Language Models for Reasoning through Reverse Curriculum Reinforcement Learning},
author={Zhiheng Xi and Wenxiang Chen and Boyang Hong and Senjie Jin and Rui Zheng and Wei He and Yiwen Ding and Shichun Liu and Xin Guo and Junzhe Wang and Honglin Guo and Wei Shen and Xiaoran Fan and Yuhao Zhou and Shihan Dou and Xiao Wang and Xinbo Zhang and Peng Sun and Tao Gui and Qi Zhang and Xuanjing Huang},
year={2024},
eprint={2402.05808},
archivePrefix={arXiv},
primaryClass={cs.AI}
}