
PPO Not Working with DeepSpeed ZeRO-3 #3108

Closed
markelausin opened this issue Apr 2, 2024 · 7 comments
Labels
solved This problem has been already solved

Comments

@markelausin

Reminder

  • I have read the README and searched the existing issues.

Reproduction

The generate() step is failing during PPO with LLaMA 70B + LoRA. I'm using DeepSpeed ZeRO-3 and I've tried with and without offloading, and with and without gradient accumulation. Is the model not being unwrapped correctly? When I print the state dict of the unwrapped model (and also unwrapped_model.pretrained_model.state_dict()), I see the following tensor, which is not 2-D: ('base_model.model.model.embed_tokens.weight', tensor([], device='cuda:7', dtype=torch.bfloat16)). This probably indicates that the embed_tokens weight is being partitioned across multiple GPUs by ZeRO stage 3. Is there a way to fix this?
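For reference, a quick check along these lines suggests the tensor really is a ZeRO-3 placeholder rather than a corrupted checkpoint. This is purely an illustrative snippet run inside the PPO trainer after unwrapping the model; the variable name unwrapped_model and the use of get_input_embeddings() are my own, not LLaMA-Factory code:

# Illustrative diagnostic only: inspect the metadata DeepSpeed attaches to
# ZeRO-3 partitioned parameters. `unwrapped_model` is assumed to be the
# value-head model returned by accelerator.unwrap_model() in the PPO trainer.
weight = unwrapped_model.pretrained_model.get_input_embeddings().weight
print(weight.shape)                        # torch.Size([0]) -> only a placeholder lives on this rank
print(getattr(weight, "ds_shape", None))   # full (vocab_size, hidden_size) shape tracked by ZeRO-3
print(getattr(weight, "ds_status", None))  # ZeroParamStatus.NOT_AVAILABLE outside a gather context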

Here's the model config:

exp_args="--stage ppo \
--model_name_or_path /path/to/Llama2-70b \
--adapter_name_or_path /path/to/sft/lora/adapter \
--create_new_adapter \
--reward_model /path/to/reward_model/adapter \
--output_dir /path/to/output/dir \
--overwrite_output_dir \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--save_steps 0.1 \
--eval_steps 0.1 \
--save_strategy steps \
--warmup_steps 1 \
--top_k 0 \
--top_p 0.9 \
--print_param_status \
--rope_scaling linear \
--evaluation_strategy steps \
--gradient_accumulation_steps 2 \
--learning_rate 1e-5 \
--num_train_epochs 1"

Expected behavior

No response

System Info

version v0.6.1

Others

Traceback (most recent call last):
  File "/lustre/pretraining-checkpoints/pre-training-model-checkpoint/markel/deepspeed/rlhf/LLaMA-Factory/src/train_bash.py", line 14, in <module>
    main()
  File "/lustre/pretraining-checkpoints/pre-training-model-checkpoint/markel/deepspeed/rlhf/LLaMA-Factory/src/train_bash.py", line 5, in main
    run_exp()
  File "/lustre/pretraining-checkpoints/pre-training-model-checkpoint/markel/deepspeed/rlhf/LLaMA-Factory/src/llmtuner/train/tuner.py", line 37, in run_exp
    run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
  File "/lustre/pretraining-checkpoints/pre-training-model-checkpoint/markel/deepspeed/rlhf/LLaMA-Factory/src/llmtuner/train/ppo/workflow.py", line 60, in run_ppo
    ppo_trainer.ppo_train(resume_from_checkpoint=training_args.resume_from_checkpoint)
  File "/lustre/pretraining-checkpoints/pre-training-model-checkpoint/markel/deepspeed/rlhf/LLaMA-Factory/src/llmtuner/train/ppo/trainer.py", line 194, in ppo_train
    mini_batch_queries, mini_batch_responses = self.get_inputs(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/lustre/pretraining-checkpoints/pre-training-model-checkpoint/markel/deepspeed/rlhf/LLaMA-Factory/src/llmtuner/train/ppo/trainer.py", line 311, in get_inputs
    generate_output: torch.Tensor = unwrapped_model.generate(
  File "/usr/local/lib/python3.10/dist-packages/trl/models/modeling_value_head.py", line 203, in generate
    return self.pretrained_model.generate(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/peft/peft_model.py", line 1190, in generate
    outputs = self.base_model.generate(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 1575, in generate
    result = self._sample(
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 2697, in _sample
    outputs = self(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 1196, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 972, in forward
    inputs_embeds = self.embed_tokens(input_ids)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py", line 2210, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: 'weight' must be 2-D
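The error comes from F.embedding receiving the empty ZeRO-3 placeholder instead of the full embedding matrix. Below is a minimal sketch of the kind of workaround involved, using DeepSpeed's deepspeed.zero.GatheredParameters context manager. This is not necessarily the fix that later landed in LLaMA-Factory, and gathering every parameter of a 70B model on each rank may not fit in GPU memory; it only illustrates the mechanism:

import torch
import deepspeed

@torch.no_grad()
def generate_with_gathered_weights(unwrapped_model, **gen_kwargs):
    # Materialize the ZeRO-3 partitioned weights as full tensors for the duration
    # of generate(); modifier_rank=None means read-only access, and the weights
    # are re-partitioned automatically when the context exits.
    params = list(unwrapped_model.parameters())
    with deepspeed.zero.GatheredParameters(params, modifier_rank=None):
        return unwrapped_model.generate(**gen_kwargs)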
@hiyouga hiyouga added the pending This problem is yet to be addressed label Apr 3, 2024
@butujvzipi

This problem still exists in the new version of the framework. When will it be fixed?

@butujvzipi

@hiyouga

@Ricardokevins

Any update on this issue, @hiyouga? QwQ

@yukiwayx

yukiwayx commented Jun 5, 2024

Same problem

@rahul1921

Same problem with Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding):

CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft_ds.yaml

hiyouga added a commit that referenced this issue Jun 6, 2024
@hiyouga
Owner

hiyouga commented Jun 6, 2024

fixed

@hiyouga hiyouga added solved This problem has been already solved and removed pending This problem is yet to be addressed labels Jun 6, 2024
@hiyouga hiyouga closed this as completed Jun 6, 2024
@ldknight

ldknight commented Jul 4, 2024

@hiyouga Hi, can you tell me how you solved the problem? Thanks!
