
Flash attention integration failed #36

Open
SparkJiao opened this issue Jul 4, 2023 · 0 comments
Hello,

When I try to use flash attention, I run into the following error:

│ /export/home2/fangkai/merit-v2/trainer_base_ds_mp.py:346 in main             │
│                                                                              │
│   343 │   │   │   logger.info("Resuming training from the latest checkpoint: │
│   344 │   │   │   continue_from_global_step = int(checkpoint.split('-')[-1]) │
│   345 │   │                                                                  │
│ ❱ 346 │   │   global_step, tr_loss = train(cfg, model_pipe, tokenizer, conti │
│   347 │   │   logger.info(" global_step = %s, average loss = %s", global_ste │
│   348                                                                        │
│   349                                                                        │
│                                                                              │
│ /export/home2/fangkai/merit-v2/trainer_base_ds_mp.py:236 in train            │
│                                                                              │
│   233 │   │   │   │   │   continue                                           │
│   234 │   │   │   │                                                          │
│   235 │   │   │   │   model.train()                                          │
│ ❱ 236 │   │   │   │   loss = model.train_batch(data_iter=sub_train_dataloade │
│   237 │   │   │   │   global_step += 1                                       │
│   238 │   │   │   │                                                          │
│   239 │   │   │   │   tr_loss += loss.item()                                 │
│                                                                              │
│ /export/home2/fangkai/anaconda3/envs/torch2.0/lib/python3.9/site-packages/de │
│ epspeed/runtime/pipe/engine.py:336 in train_batch                            │
│                                                                              │
│    333 │   │   sched = schedule.TrainSchedule(micro_batches=self.micro_batch │
│    334 │   │   │   │   │   │   │   │   │      stages=self.num_stages,        │
│    335 │   │   │   │   │   │   │   │   │      stage_id=self.stage_id)        │
│ ❱  336 │   │   self._exec_schedule(sched)                                    │
│    337 │   │   self.agg_train_loss = self._aggregate_total_loss()            │
│    338 │   │                                                                 │
│    339 │   │   self.timers('train_batch').stop()                             │
│                                                                              │
│ /export/home2/fangkai/anaconda3/envs/torch2.0/lib/python3.9/site-packages/de │
│ epspeed/runtime/pipe/engine.py:1307 in _exec_schedule                        │
│                                                                              │
│   1304 │   │   │   │                                                         │
│   1305 │   │   │   │   # Equivalent to: self._exec_forward_pass(buffer_id=0) │
│   1306 │   │   │   │   self._exec_instr = MethodType(self._INSTRUCTION_MAP[t │
│ ❱ 1307 │   │   │   │   self._exec_instr(**cmd.kwargs)                        │
│   1308                                                                       │
│                                                                              │
│ /export/home2/fangkai/anaconda3/envs/torch2.0/lib/python3.9/site-packages/de │
│ epspeed/runtime/pipe/engine.py:996 in _exec_send_grads                       │
│                                                                              │
│    993 │   │   │   │   │   if not buffer.is_floating_point():                │
│    994 │   │   │   │   │   │   assert buffer.grad is None                    │
│    995 │   │   │   │   │   │   continue                                      │
│ ❱  996 │   │   │   │   │   assert buffer.grad is not None                    │
│    997 │   │   │   │   │   p2p.send(buffer.grad, self.prev_stage)            │
│    998 │   │                                                                 │
│    999 │   │   # We can free up the input buffer now                         │
╰──────────────────────────────────────────────────────────────────────────────╯
AssertionError

I also tested with torch.nn.functional.scaled_dot_product_attention, which implements flash attention in PyTorch 2.0, but I hit the same assertion. Have you encountered this problem?

Thanks very much for your help!

Best,
Fangkai
