Is flash_attn mandatory for training models like InternLM2? #4398

Closed
1 task done
gaoyang07 opened this issue Jun 20, 2024 · 1 comment

Labels
solved This problem has been already solved

Comments

@gaoyang07

Reminder

  • I have read the README and searched the existing issues.

System Info

Question from InternLM/InternLM#747

Reproduction

### model
model_name_or_path: internlm/internlm2-chat-7b

### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all

### dataset
dataset: data
template: intern2
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/internlm2-chat-7b/lora/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
fp16: true

### eval
val_size: 0.1
per_device_eval_batch_size: 1
evaluation_strategy: steps
eval_steps: 500
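
A recipe like this is normally saved as a YAML file and handed to the trainer. The snippet below is a minimal sketch of one way to launch it; the file name internlm2_lora_sft.yaml and the llamafactory-cli train entry point are assumptions, not details taken from this thread.

import subprocess

# Launch the LoRA SFT recipe above, assuming it was saved as internlm2_lora_sft.yaml
# and that the llamafactory-cli entry point is on PATH (both are assumptions).
subprocess.run(["llamafactory-cli", "train", "internlm2_lora_sft.yaml"], check=True)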

Expected behavior

No response

Others

No response

@github-actions github-actions bot added the pending This problem is yet to be addressed label Jun 20, 2024

hiyouga commented Jun 20, 2024

Replace these lines:
https://huggingface.co/internlm/internlm2-chat-7b/blob/main/modeling_internlm2.py#L56-L58
with

try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
except ImportError:
    pass

It should resolve this problem.

Hugging Face Transformers checks the imports of remote-code modules and raises an error if the flash_attn package is imported without a try/except guard: https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/dynamic_module_utils.py#L161-L186
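
For illustration, here is a minimal sketch (not the actual Transformers implementation) of why the try/except guard works: imports wrapped in a try/except block are stripped from the module source before the dependency check runs, so an optional package such as flash_attn no longer counts as a hard requirement.

import re

# Simplified stand-in for the import check in transformers.dynamic_module_utils:
# try/except blocks are removed first, so guarded imports are treated as optional.
SOURCE_WITH_GUARD = """
try:
    from flash_attn import flash_attn_func
except ImportError:
    pass
"""

def naive_get_imports(source):
    # Drop try/except blocks, then collect top-level imports (simplified regex).
    source = re.sub(r"\s*try\s*:.*?except.*?:\s*pass", "", source, flags=re.DOTALL)
    return re.findall(r"^\s*(?:from|import)\s+(\w+)", source, flags=re.MULTILINE)

print(naive_get_imports(SOURCE_WITH_GUARD))  # prints [] -> flash_attn is treated as optional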

@hiyouga hiyouga added solved This problem has been already solved and removed pending This problem is yet to be addressed labels Jun 20, 2024
@hiyouga hiyouga closed this as completed Jun 20, 2024
hiyouga added a commit that referenced this issue Jun 30, 2024
PrimaLuz pushed a commit to PrimaLuz/LLaMA-Factory that referenced this issue Jul 1, 2024
xtchen96 pushed a commit to xtchen96/LLaMA-Factory that referenced this issue Jul 17, 2024