rlhf PPO training reports an error #17

Open
Fenglly opened this issue Aug 14, 2024 · 11 comments
Labels: bug (Something isn't working)

Fenglly commented Aug 14, 2024

I am running RLHF PPO training with the qwen1.5-chat model and hit the following error:
[two screenshot attachments of the error traceback]

mst272 (Owner) commented Aug 14, 2024

Your code may not be up to date with the latest version. Please also provide a more detailed description of your configuration.

Fenglly (Author) commented Aug 14, 2024

The code was pulled from the main branch at 2 p.m. this afternoon, and the environment was installed following the requirements file.

Fenglly (Author) commented Aug 14, 2024

This is common_args.py:

```python
from dataclasses import dataclass, field
from typing import Optional
from enum import Enum


class TrainArgPath(Enum):
    PPO_ARGS = 'rlhf_args/ppo_config.py'
    RLOO_ARGS = 'rlhf_args/rloo_config.py'
    CPO_ARGS = 'rlhf_args/cpo_config.py'
    SimPO_ARGS = 'rlhf_args/simpo_config.py'
    CPOSimPO_ARGS = 'rlhf_args/cpo-simpo_config.py'


@dataclass
class CommonArgs:
    """
    Some commonly used custom arguments
    """
    train_args_path: TrainArgPath = field(default=TrainArgPath.PPO_ARGS.value,
                                          metadata={"help": "Training-argument file for the current mode; currently supports [PPO, RLOO, CPO, SimPO, CPOSimPO]"})
    # Fine-tuning method selection and configuration
    train_mode: str = field(default='full', metadata={"help": "Training mode to use: [qlora, lora, full]"})
    use_dora: bool = field(default=False,
                           metadata={"help": "Only usable when train_mode == lora. Whether to use DoRA (a LoRA-based variant)"})
    rlhf_type: str = field(default="PPO",
                           metadata={"help": "RLHF method to use; currently supports [PPO, RLOO, CPO, SimPO, CPOSimPO]"})

    # LoRA-related configuration
    lora_rank: Optional[int] = field(default=64, metadata={"help": "lora rank"})
    lora_alpha: Optional[int] = field(default=16, metadata={"help": "lora alpha"})
    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "lora dropout"})
```
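
For context, a minimal sketch of how a dataclass argument file like this is typically consumed, assuming the training script uses the common `transformers.HfArgumentParser` pattern (the repo's actual `rlhf_train.py` may differ):

```python
# Hypothetical usage sketch, not the repo's actual entry point:
# turn the CommonArgs dataclass above into command-line flags.
from transformers import HfArgumentParser
from common_args import CommonArgs  # assumes the file above is importable from the working directory

parser = HfArgumentParser(CommonArgs)
(args,) = parser.parse_args_into_dataclasses()
print(args.rlhf_type, args.train_mode, args.train_args_path)
```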

Fenglly (Author) commented Aug 14, 2024

This is the ppo_config.py file:

```python
import os
from dataclasses import dataclass, field
from typing import Optional, Union, List, Literal
from transformers import SchedulerType, IntervalStrategy
from transformers.training_args import OptimizerNames
from trl.trainer.ppov2_trainer import PPOv2Config


@dataclass
class PPOConfig(PPOv2Config):
    # common config
    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    run_name: Optional[str] = None
    """a unique name of this run"""
    sanity_check: bool = False
    """whether to run in debug mode"""

    # batch size related config
    num_mini_batches: int = 1
    """Number of minibatches to split a batch into"""
    total_episodes: Optional[int] = None
    """The total number of episodes in the dataset"""
    local_rollout_forward_batch_size: int = 64
    """per rank no grad forward pass in the rollout phase"""
    num_sample_generations: int = 10
    """the number of debugging samples generations (i.e., `generate_completions` calls) throughout training"""

    response_length: int = 53
    """the length of the response"""
    stop_token: Optional[Literal["eos"]] = None
    """the stop token"""
    stop_token_id: Optional[int] = None
    """the truncation token id"""
    temperature: float = 0.7
    """the sampling temperature"""
    penalty_reward_value: int = -1
    """the reward value for responses that do not contain `stop_token_id`"""
    non_eos_penalty: bool = False
    """whether to penalize responses that do not contain `stop_token_id`"""
    reward_model_path: str = "/data/LLM_Weight/qwen/Qwen1.5-0.5B-Chat/"
    """the path to the reward model"""
    sft_model_path: str = "/data/LLM_Weight/qwen/Qwen1.5-0.5B-Chat/"
    """the path to the sft model"""

    # ppo config
    num_ppo_epochs: int = 4
    """the number of epochs to train"""
    vf_coef: float = 0.1
    """the value function coefficient"""
    cliprange: float = 0.2
    """the clip range"""
    cliprange_value: float = 0.2
    """the clip range for the value function"""
    gamma: float = 1
    """the discount factor"""
    lam: float = 0.95
    """the lambda value for GAE"""
    whiten_rewards: bool = False
    """whether to whiten the rewards"""
    kl_coef: float = 0.05
    """the KL coefficient"""

    # TrainingArguments-related parameters
    train_data_path: Optional[str] = field(default='/data/LLM-Dojo-main/rlhf/data_example/data.jsonl', metadata={"help": "Path to the training set"})
    output_dir: str = field(default='./out', metadata={"help": "Directory where the model is saved after training"})
    num_train_epochs: int = field(default=1, metadata={"help": "Number of training epochs"})
    per_device_train_batch_size: int = field(default=2, metadata={"help": "Training batch size"})
    gradient_checkpointing: bool = field(default=True, metadata={"help": "Whether to use gradient checkpointing"})
    gradient_accumulation_steps: int = field(default=16, metadata={"help": "Gradient accumulation steps"})
    learning_rate: float = field(default=2e-4, metadata={"help": "Learning rate"})
    logging_steps: int = field(default=100, metadata={"help": "Logging interval in steps"})
    save_steps: int = field(default=500, metadata={"help": "Save a checkpoint every this many steps"})
    save_strategy: Union[IntervalStrategy, str] = field(default="epoch",
                                                        metadata={"help": "The checkpoint save strategy to use."})
    save_total_limit: Optional[int] = field(default=2, metadata={"help": "If a value is passed, will limit the total "
                                                                         "amount of checkpoints. Deletes the older "
                                                                         "checkpoints in"})
    lr_scheduler_type: Union[SchedulerType, str] = field(default="constant_with_warmup",
                                                         metadata={"help": "The scheduler type to use."})
    warmup_steps: int = field(default=10, metadata={"help": "Linear warmup over warmup_steps."})
    optim: Union[OptimizerNames, str] = field(default='adamw_torch', metadata={"help": "The optimizer to use."})
    seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
    report_to: Optional[List[str]] = field(default='wandb', metadata={
        "help": "The list of integrations to report the results and logs to."})
    weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
    max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."})
    remove_unused_columns: Optional[bool] = field(default=False, metadata={
        "help": "Remove columns not required by the model when using an nlp.Dataset."})
    bf16: bool = field(default=True, metadata={"help": "Whether to use bf16 precision"})

    # Deepspeed-related parameters; set default=None when not using Deepspeed
    deepspeed: Optional[str] = field(default=None, metadata={"help": "Config file required when enabling Deepspeed"})

    world_size: Optional[int] = 1
    """The number of processes (GPUs) to use"""
```

Fenglly (Author) commented Aug 14, 2024

[screenshot attachment of the error]
Further down there is also an error from a different environment's Python interpreter, which is not the one used to run this script, which is a bit strange.

mst272 (Owner) commented Aug 14, 2024

The train data included in the project is only an example; if there is not enough data it will throw an error. The dataset needs to be larger than batch size * gradient_accumulation_steps. Also, an earlier version was missing the eval_samples parameter, so please pull the latest code again, increase the amount of data, and try once more.
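
To make that concrete, a rough back-of-the-envelope check using the values from the ppo_config.py posted above (illustrative only; the data path is the example file referenced in that config):

```python
# Minimum data requirement implied by the config above (illustrative).
per_device_train_batch_size = 2    # from ppo_config.py
gradient_accumulation_steps = 16   # from ppo_config.py
num_processes = 1                  # assuming a single GPU / single process

min_samples = per_device_train_batch_size * gradient_accumulation_steps * num_processes
print(f"need more than {min_samples} training examples")  # 2 * 16 * 1 = 32

# Count the lines in the example dataset (path taken from ppo_config.py).
with open("/data/LLM-Dojo-main/rlhf/data_example/data.jsonl") as f:
    n_samples = sum(1 for _ in f)
print(f"data.jsonl currently has {n_samples} examples")
```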

mst272 (Owner) commented Aug 15, 2024

I tried Qwen1.5-0.5B PPO full locally and it starts up normally. What is your launch command? Does num_processes in the yaml file match the number of GPUs? If that still doesn't resolve it, I suggest launching with `python rlhf_train.py` (you can use qlora or lora to reduce GPU memory) to see the specific error.
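
As a quick way to verify the second point, a small illustrative check comparing the visible GPU count with num_processes in the accelerate yaml (the yaml path used here is the one quoted in the launch command below):

```python
# Illustrative check only: does num_processes in the accelerate yaml match
# the number of GPUs actually visible to this process?
import torch
import yaml  # pip install pyyaml

with open("./ds_config/deepspeed_zero3.yaml") as f:  # path from the launch command below
    accel_cfg = yaml.safe_load(f)

print("visible GPUs:        ", torch.cuda.device_count())
print("num_processes (yaml):", accel_cfg.get("num_processes"))
```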

mst272 added the bug (Something isn't working) label on Aug 15, 2024
Fenglly (Author) commented Aug 15, 2024

The launch command is `CUDA_VISIBLE_DEVICES=1 nohup accelerate launch --config_file ./ds_config/deepspeed_zero3.yaml rlhf_train.py`. What did you set eval_samples and per_device_train_batch_size to on your side? And is your data also the data.jsonl from the example data?

mst272 (Owner) commented Aug 15, 2024

First, confirm whether num_processes in the zero3.yaml file has been changed to 1. My settings are:

num_train_epochs: 2
per_device_train_batch_size: 2
gradient_accumulation_steps: 8
eval_samples: 30
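
For reference, a minimal sketch of what those settings would look like applied to the PPOConfig posted earlier. The import path and the eval_samples field are assumptions (per the comment above, eval_samples only exists after pulling the latest code):

```python
from dataclasses import dataclass, field

# Hypothetical import path; adjust to wherever ppo_config.py lives in your checkout.
from rlhf_args.ppo_config import PPOConfig


@dataclass
class DebugPPOConfig(PPOConfig):
    """PPOConfig with the maintainer's reported settings as defaults (illustrative sketch)."""
    num_train_epochs: int = field(default=2, metadata={"help": "Number of training epochs"})
    per_device_train_batch_size: int = field(default=2, metadata={"help": "Training batch size"})
    gradient_accumulation_steps: int = field(default=8, metadata={"help": "Gradient accumulation steps"})
    # eval_samples: int = 30  # assumed to exist only after pulling the latest code
```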

Fenglly (Author) commented Aug 15, 2024

Yes, the num_processes setting in the zero3.yaml file is correct.

mst272 (Owner) commented Aug 15, 2024

That's strange. Launch with `python rlhf_train.py` (you can use qlora or lora to reduce GPU memory) and check the specific error. The error above can't be pinpointed from what you've posted.
