When I use flash_attn2 and shift_attn together, I get an error: IndexError: too many indices for tensor of dimension 2. There is no problem when I use flash_attn2 or shift_attn separately. What could be the reason for this?
Reminder
Reproduction
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 src/train_bash.py \
    --stage sft \
    --do_train True \
    --model_name_or_path xxx \
    --finetuning_type lora \
    --template llama3 \
    --flash_attn fa2 \
    --dataset_dir data \
    --dataset LongAlpaca-12k \
    --cutoff_len 32768 \
    --learning_rate 2e-05 \
    --num_train_epochs 3.0 \
    --max_samples 100000 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --lr_scheduler_type constant_with_warmup \
    --max_grad_norm 1.0 \
    --logging_steps 5 \
    --save_steps 140 \
    --warmup_steps 20 \
    --optim adamw_torch \
    --shift_attn True \
    --report_to none \
    --output_dir xxx \
    --fp16 True \
    --lora_rank 8 \
    --lora_alpha 16 \
    --lora_dropout 0.1 \
    --use_dora True \
    --lora_target all \
    --plot_loss True
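For context, below is a minimal, self-contained sketch of the grouping idea behind --shift_attn (LongLoRA's shift short attention, S²-Attn), not LLaMA-Factory's actual implementation; the tensor shapes, group size, and function name are illustrative assumptions. It only shows how the shifted path regroups attention inputs, which is the step that interacts with the FlashAttention-2 code path; whether that interaction is what triggers the 2-D IndexError here is only a guess.

import torch

def shift_group_attention_inputs(hidden: torch.Tensor, group_size: int) -> torch.Tensor:
    """Illustrative S²-Attn-style regrouping (assumed shapes, not library code).

    hidden: (batch, seq_len, num_heads, head_dim)
    Rolls the second half of the heads by half a group along the sequence axis,
    then folds fixed-size groups into the batch dimension so attention is
    computed within each group.
    """
    bsz, seq_len, n_heads, head_dim = hidden.shape
    assert seq_len % group_size == 0, "seq_len must be divisible by group_size"

    shifted = hidden.clone()
    # Shift half of the heads so neighboring groups overlap.
    shifted[:, :, n_heads // 2:] = torch.roll(
        shifted[:, :, n_heads // 2:], shifts=-group_size // 2, dims=1
    )
    # Fold groups into the batch dimension: (bsz * num_groups, group_size, heads, dim).
    return shifted.reshape(bsz * (seq_len // group_size), group_size, n_heads, head_dim)

# Example: 1 sample, 8192 tokens, 32 heads, head_dim 128, groups of 2048 tokens.
states = torch.randn(1, 8192, 32, 128)
print(shift_group_attention_inputs(states, group_size=2048).shape)
# torch.Size([4, 2048, 32, 128])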
Expected behavior
Training with flash_attn2 and shift_attn enabled together should run without errors, just as it does when either option is used on its own.
System Info
transformers version: 4.40.0

Others

flash_attn version: 2.5.8