
Preference alignment with SimPO on the ultrafeedback_binarized dataset #4085

Closed
1 task done
Meaquadddd opened this issue Jun 5, 2024 · 7 comments
Labels
solved This problem has been already solved

Comments

@Meaquadddd

Meaquadddd commented Jun 5, 2024

Reminder

  • I have read the README and searched the existing issues.

System Info

  • transformers version: 4.41.1
  • Platform: Linux-5.15.0-107-generic-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.23.2
  • Safetensors version: 0.4.3
  • Accelerate version: 0.29.3
  • Accelerate config: - compute_environment: LOCAL_MACHINE
    - distributed_type: DEEPSPEED
    - mixed_precision: bf16
    - use_cpu: False
    - debug: True
    - num_processes: 3
    - machine_rank: 0
    - num_machines: 1
    - rdzv_backend: static
    - same_network: True
    - main_training_function: main
    - enable_cpu_affinity: False
    - deepspeed_config: {'gradient_accumulation_steps': 2, 'zero3_init_flag': False, 'zero_stage': 0}
    - downcast_bf16: no
    - tpu_use_cluster: False
    - tpu_use_sudo: False
    - tpu_env: []
  • PyTorch version (GPU?): 2.2.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: True
  • Using distributed or parallel set-up in script?: False

Reproduction

I tried running the following command to perform preference alignment with SimPO on ultrafeedback_binarized:

CUDA_VISIBLE_DEVICES=0 llamafactory-cli train \
    --stage dpo \
    --pref_loss simpo \
    --simpo_gamma 1.0 \
    --do_train True \
    --model_name_or_path /home/ubuntu/date/llama_ckpts/llama_lx3_ckpts/BAdam_llama3_random_lr1e-6/checkpoint-9600  \
    --preprocessing_num_workers 16 \
    --finetuning_type full \
    --template default \
    --flash_attn auto \
    --dataset_dir data \
    --dataset ultrafeedback_binarized \
    --split train \
    --cutoff_len 2048 \
    --learning_rate 2e-7 \
    --num_train_epochs 1.0 \
    --max_samples 10000000 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type cosine \
    --max_grad_norm 1.0 \
    --logging_steps 1 \
    --save_steps 100 \
    --warmup_ratio 0.1 \
    --optim adamw_torch \
    --packing False \
    --report_to none \
    --use_badam True \
    --output_dir saves/LLaMA3-8B/full/train_2024-06-05 \
    --pure_bf16 True \
    --plot_loss True \
    --use_badam True \
    --badam_mode layer \
    --badam_switch_mode random \
    --badam_switch_interval 100 \
    --val_size 0.05 \
    --evaluation_strategy steps \
    --eval_steps 20

Since ultrafeedback_binarized is not defined in dataset_info.json, I added the following entry, following the relevant README:

"ultrafeedback_binarized": {
  "hf_hub_url": "HuggingFaceH4/ultrafeedback_binarized",
  "formatting": "sharegpt",
  "ranking": true,
  "columns": {
    "prompt": "prompt",
    "messages": "messages",
    "chosen": "chosen",
    "rejected": "rejected"
  },
  "tags": {
    "role_tag": "role",
    "system_tag": "system",
    "content_tag": "content",
    "user_tag": "user",
    "assistant_tag": "assistant"
  }
},

A record in the ultrafeedback_binarized dataset on Hugging Face looks like this:

{
  "prompt": "Which books are still banned in Canada?",
  "chosen": [
    {
      "content": "Which books are still banned in Canada?",
      "role": "user"
    },
    {
      "content": "As a responsible and ethical AI assistant, I must inform you that I cannot provide a list of banned books in Canada, as there are no books that are currently banned in the country. Canada has a strong commitment to freedom of expression and the right to access information, and there are no laws in place that restrict the publication or distribution of books.\n",
      "role": "assistant"
    }
  ],
  "messages": [
    {
      "content": "Which books are still banned in Canada?",
      "role": "user"
    },
    {
      "content": "As a responsible and ethical AI assistant, I must inform you that I cannot provide a list of banned books in Canada, as there are no books that are currently banned in the country. Canada has a strong commitment to freedom of expression and the right to access information, and there are no laws in place that restrict the publication or distribution of books.\n",
      "role": "assistant"
    }
  ],
  "prompt_id": "aeccf551d9ba42fdf5f2044de43b8ce6e360fb523ace428317b81d804594e090",
  "rejected": [
    {
      "content": "Which books are still banned in Canada?",
      "role": "user"
    },
    {
      "content": "According to the Canadian Government’s Ban Affront website, there are still several books that are banned in Canada. These include The Begum’s Millionaire, The Education of Little Tree, The Harry Potter series, Lolita, 1984, and Lady Chatterley’s Lover. Some of these books are considered inaccessible due to their age, while others are still legally banned in certain parts of the country.",
      "role": "assistant"
    }
  ],
  "score_chosen": 8.0,
  "score_rejected": 5.0
}
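One way to see why this record may not fit the entry above: as I understand LLaMA-Factory's data README, a sharegpt preference ("ranking") record keeps only the prompt turns in messages (an odd number, ending with a user turn) and a single assistant message in each of chosen and rejected, whereas this record stores full two-turn conversations in all three. The sketch below checks a record against that assumed layout; layout_problems and the tag constants are hypothetical names, not LLaMA-Factory code, and the long answers are abbreviated.

```python
from typing import Any

# ASSUMPTION: shape of LLaMA-Factory's sharegpt ranking records, as I read its
# data README; this is not taken from the issue itself.
ROLE_TAG = "role"
USER_TAG, ASSISTANT_TAG = "user", "assistant"


def layout_problems(record: dict[str, Any]) -> list[str]:
    """List the ways a record deviates from the assumed ranking layout."""
    problems: list[str] = []
    messages = record.get("messages")
    # Prompt turns only: an odd number of messages that ends with a user turn.
    if not isinstance(messages, list) or len(messages) % 2 == 0:
        problems.append("messages: expected an odd number of prompt turns")
    elif messages[-1].get(ROLE_TAG) != USER_TAG:
        problems.append("messages: last prompt turn should come from the user")
    # Each candidate answer: one assistant message dict, not a whole conversation.
    for key in ("chosen", "rejected"):
        value = record.get(key)
        if not isinstance(value, dict):
            problems.append(f"{key}: expected one message dict, got {type(value).__name__}")
        elif value.get(ROLE_TAG) != ASSISTANT_TAG:
            problems.append(f"{key}: expected role {ASSISTANT_TAG!r}")
    return problems


# The record shown above, with the long answers abbreviated.
user_turn = {"content": "Which books are still banned in Canada?", "role": "user"}
good_reply = {"content": "<chosen answer, abbreviated>", "role": "assistant"}
bad_reply = {"content": "<rejected answer, abbreviated>", "role": "assistant"}
hf_row = {
    "prompt": user_turn["content"],
    "chosen": [user_turn, good_reply],
    "messages": [user_turn, good_reply],
    "rejected": [user_turn, bad_reply],
    "score_chosen": 8.0,
    "score_rejected": 5.0,
}

for problem in layout_problems(hf_row):
    print(problem)
```

On the sample record all three checks fail: messages has two turns, and chosen and rejected are lists rather than single messages.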

Running the command above fails with the following error:

Traceback (most recent call last):
  File "/home/anaconda3/envs/xiliang/bin/llamafactory-cli", line 8, in <module>
    sys.exit(main())
  File "/home/xiliang/LLaMA-Factory/src/llamafactory/cli.py", line 65, in main
    run_exp()
  File "/home/xiliang/LLaMA-Factory/src/llamafactory/train/tuner.py", line 39, in run_exp
    run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
  File "/home/xiliang/LLaMA-Factory/src/llamafactory/train/dpo/workflow.py", line 29, in run_dpo
    dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
  File "/home/xiliang/LLaMA-Factory/src/llamafactory/data/loader.py", line 154, in get_dataset
    column_names = list(next(iter(dataset)).keys())
StopIteration
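The last frame is column_names = list(next(iter(dataset)).keys()). In Python, next() on an iterator that yields nothing raises StopIteration, so this traceback suggests the dataset was already empty when that line ran, which would be consistent with every record being discarded during format conversion (an inference, not confirmed in the thread). A minimal sketch with plain lists standing in for the Hugging Face dataset; first_column_names is a hypothetical helper, not LLaMA-Factory code:

```python
from collections.abc import Iterable
from typing import Any


def first_column_names(dataset: Iterable[dict[str, Any]]) -> list[str]:
    """Same lookup as list(next(iter(dataset)).keys()), with a readable error."""
    try:
        first_row = next(iter(dataset))
    except StopIteration:
        # next() on an empty iterator raises a bare StopIteration with no message.
        raise ValueError(
            "dataset is empty at this point; every record may have been dropped "
            "during format conversion"
        ) from None
    return list(first_row.keys())


print(first_column_names([{"prompt": "p", "chosen": "c", "rejected": "r"}]))
try:
    first_column_names([])
except ValueError as err:
    print(f"ValueError: {err}")
```

The first call prints ['prompt', 'chosen', 'rejected']; the second turns the opaque StopIteration into a message that points at the likely cause.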

Expected behavior

My guess is that the extra fields

"score_chosen": 8.0,
"score_rejected": 5.0

make the dataset non-conforming?
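On the guess about score_chosen and score_rejected: assuming LLaMA-Factory only reads the columns mapped under "columns" in dataset_info.json, unmapped keys should simply be ignored, and a more likely mismatch is that chosen and rejected hold full conversations (lists) rather than single replies. A sketch of reshaping one row into the sharegpt ranking layout as I understand it from the LLaMA-Factory data README; to_sharegpt_ranking is a hypothetical helper and the answers are abbreviated:

```python
import json
from typing import Any


def to_sharegpt_ranking(row: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical helper: reshape one ultrafeedback_binarized row.

    ASSUMPTION: the target is LLaMA-Factory's sharegpt ranking layout as
    described in its data README (prompt turns in messages, one reply each
    in chosen and rejected).
    """
    chosen, rejected = row["chosen"], row["rejected"]
    # Both conversations should share the same prompt turns; only the last reply differs.
    if chosen[:-1] != rejected[:-1]:
        raise ValueError(f"prompt turns differ for prompt_id={row.get('prompt_id')}")
    # score_chosen, score_rejected and other extra keys are not carried over.
    return {
        "messages": chosen[:-1],   # prompt turns only
        "chosen": chosen[-1],      # single assistant message
        "rejected": rejected[-1],  # single assistant message
    }


user_turn = {"content": "Which books are still banned in Canada?", "role": "user"}
example_row = {
    "prompt": user_turn["content"],
    "prompt_id": "aeccf551d9ba42fdf5f2044de43b8ce6e360fb523ace428317b81d804594e090",
    "chosen": [user_turn, {"content": "<chosen answer, abbreviated>", "role": "assistant"}],
    "rejected": [user_turn, {"content": "<rejected answer, abbreviated>", "role": "assistant"}],
    "score_chosen": 8.0,
    "score_rejected": 5.0,
}

print(json.dumps(to_sharegpt_ranking(example_row), indent=2))
```

Converted rows could then be saved to a local JSON file and registered with file_name instead of hf_hub_url. Per the maintainer's later reply, the built-in ultrafeedback entry avoids this conversion altogether.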

Others

The command above runs normally after changing the dataset to dpo_en_demo.


@hiyouga
Owner

hiyouga commented Jun 5, 2024

@hiyouga hiyouga added the solved This problem has been already solved label Jun 5, 2024
@hiyouga hiyouga closed this as completed Jun 5, 2024
@Baichenjia

@Meaquadddd How did you eventually solve this?

@Baichenjia

@hiyouga ultrafeedback_binarized is a widely used alignment dataset in research. Could you add support for it?

@hiyouga
Owner

hiyouga commented Jun 7, 2024

Use dataset: ultrafeedback to load the official dataset.

@Meaquadddd
Author

@Meaquadddd How did you eventually solve this?

In the end I modified the source code manually.

@yhy-2000

@Meaquadddd How did you eventually solve this?

In the end I modified the source code manually.

Could you share exactly what you changed? Thanks a lot!!

@Meaquadddd
Author

Meaquadddd commented Oct 23, 2024

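@Meaquadddd How did you eventually solve this?

In the end I modified the source code manually.

Could you share exactly what you changed? Thanks a lot!!

You should now be able to just set
--dataset ultrafeedback
in the script and it will work; see https://github.com/hiyouga/LLaMA-Factory/issues/4132.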
Sorry for the late reply!


4 participants