[BUG] wandb does not create a job during the PPO stage #1026

Closed
ticoAg opened this issue Sep 25, 2023 · 4 comments
Labels
solved This problem has been already solved

Comments


ticoAg commented Sep 25, 2023

The training script is as follows:

wandb online

CUDA_VISIBLE_DEVICES=$gpu_vis accelerate launch --config_file $acclerate_config src/train_bash.py \
    --stage ppo \
    --do_train \
    --finetuning_type lora \
    --lora_target W_pack \
    --lora_rank 64 \
    --resume_lora_training False \
    --model_name_or_path $proj_dir/$model_name_or_path \
    --reward_model $proj_dir/rm/$reward_model \
    --output_dir $root_dir/$exp_id \
    --overwrite_output_dir \
    --template $template \
    --dataset $dataset \
    --cutoff_len 4096 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 2 \
    --preprocessing_num_workers 128 \
    --num_train_epochs 3 \
    --save_strategy epoch \
    --warmup_ratio 0.1 \
    --learning_rate 1e-5 \
    --lr_scheduler_type cosine \
    --max_grad_norm 0.5 \
    --adam_epsilon 1e-7 \
    --logging_steps 5 \
    --flash_attn \
    --plot_loss \
    --bf16 \
    --report_to wandb \
    --run_name $exp_id

The training logs print normally, but wandb never creates a run or syncs the logs:

{'loss': 0.007, 'reward': -0.1181, 'learning_rate': 3.9279869067103113e-07, 'epoch': 0.0}
{'loss': 0.0044, 'reward': -0.122, 'learning_rate': 8.837970540098201e-07, 'epoch': 0.0}
{'loss': 0.0121, 'reward': -0.1249, 'learning_rate': 1.3747954173486089e-06, 'epoch': 0.01}
ticoAg (Author) commented Sep 25, 2023

If I call wandb.init in the workflow code and then log with wandb.log inside the custom trainer, a run is created and the logs are synced, but every GPU creates its own run. The logged metrics are supposed to be the values aggregated across devices, right?
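
A minimal sketch of what I mean, assuming the manual wandb.init route is kept (the project and run names below are placeholders): gating both init and log on the main process via Accelerate gives a single run instead of one per GPU.

import wandb
from accelerate import Accelerator

accelerator = Accelerator()

# Create the run only on the main process so a multi-GPU launch
# produces a single wandb run instead of one per device.
if accelerator.is_main_process:
    wandb.init(project="llm-ppo", name="ppo-exp")  # placeholder names

# ... inside the training loop ...
metrics = {"loss": 0.007, "reward": -0.1181}
if accelerator.is_main_process:
    wandb.log(metrics)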

mmbwf (Contributor) commented Sep 25, 2023

trl has its own logging parameters. You first need to enable wandb logging in PPOConfig (workflow.py):

ppo_config = PPOConfig(
    model_name=model_args.model_name_or_path,
    learning_rate=training_args.learning_rate,
    mini_batch_size=training_args.per_device_train_batch_size,
    batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
    gradient_accumulation_steps=training_args.gradient_accumulation_steps,
    ppo_epochs=4,
    max_grad_norm=training_args.max_grad_norm,
    seed=training_args.seed,
    optimize_cuda_cache=True,
    log_with="wandb"
)
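
For context, log_with="wandb" hands logging off to Accelerate's tracker API, which creates the wandb run on the main process only, so the one-run-per-GPU behaviour from the manual wandb.init approach also goes away. Roughly, it amounts to the sketch below (the project name is a placeholder):

from accelerate import Accelerator

accelerator = Accelerator(log_with="wandb")
accelerator.init_trackers(project_name="ppo-demo")  # placeholder project name; only the main process creates the run
accelerator.log({"ppo/loss/total": 0.01}, step=0)   # metrics go through the tracker instead of wandb.log directly
accelerator.end_training()                          # close the tracker cleanly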

Next, you should add some code in trainer.py to log the PPO stats:

# Run PPO step
stats = self.step(queries, responses, rewards)
# add code for wandb report
batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
self.log_stats(stats, batch, rewards)
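
If decoding every batch feels heavy, one option (just a sketch, not the exact change) is to push the stats only every few PPO steps, reusing the step counter from the logging code below:

# Sketch only: report to wandb every `logging_steps` instead of every PPO step.
# `logging_steps` stands for however the logging interval is exposed in your trainer.
if step % logging_steps == 0:
    batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
    batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
    self.log_stats(stats, batch, rewards)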

Finally, you should adjust a few details to avoid errors, because when wandb is enabled, trl converts the stats from scalars to numpy values.

logs = dict(
    loss=round(loss_meter.avg.item(), 4),  # call .item() so wandb receives a plain Python float
    reward=round(reward_meter.avg, 4),
    learning_rate=stats["ppo/learning_rate"],
    epoch=round(step / len_dataloader, 2)
)
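
To make that robust regardless of whether wandb is enabled, a small helper can normalize the values first. The helper name here is mine, not from the repo, and it assumes the same loss_meter / reward_meter / stats variables as the snippet above:

import numpy as np
import torch

def to_python_float(value):
    """Convert a torch tensor, numpy scalar, or plain number to a Python float."""
    if isinstance(value, torch.Tensor):
        return value.item()   # works for 0-dim / single-element tensors
    if isinstance(value, np.generic):
        return float(value)   # numpy scalar (what trl produces when wandb is on)
    return float(value)

logs = dict(
    loss=round(to_python_float(loss_meter.avg), 4),
    reward=round(to_python_float(reward_meter.avg), 4),
    learning_rate=to_python_float(stats["ppo/learning_rate"]),
    epoch=round(step / len_dataloader, 2),
)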

hiyouga added the pending (This problem is yet to be addressed) label on Sep 26, 2023

skepsun commented Sep 26, 2023

I raised this issue before, and the solution was the same as the one above, but it was never fixed.

hiyouga (Owner) commented Sep 27, 2023

Fixed in b0b0138, thanks @mmbwf

hiyouga added the solved (This problem has been already solved) label and removed the pending (This problem is yet to be addressed) label on Sep 27, 2023