[BUG] Flash attention seemingly cannot be integrated with pipeline parallelism due to missing input grad #3868

Open
SparkJiao opened this issue Jul 4, 2023 · 1 comment
Labels
bug Something isn't working training


@SparkJiao

Describe the bug
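Neither flash attention implementation (the original flash-attn package or torch.nn.functional.scaled_dot_product_attention from PyTorch 2.0) can be integrated with pipeline-parallel training of LLaMA. train_batch fails with an AssertionError in DeepSpeed's _exec_send_grads because an input buffer of a pipeline stage has no gradient.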

ds_report output

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.0
 [WARNING]  using untested triton version (2.0.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
No CUDA runtime is found, using CUDA_HOME='/cm/shared/apps/cuda11.6/toolkit/11.6.0'
DeepSpeed general environment info:
torch install path ............... ['/export/home2/fangkai/anaconda3/envs/torch2.0/lib/python3.9/site-packages/torch']
torch version .................... 2.0.0+cu117
deepspeed install path ........... ['/export/home2/fangkai/anaconda3/envs/torch2.0/lib/python3.9/site-packages/deepspeed']
deepspeed info ................... 0.9.5, unknown, unknown
torch cuda version ............... 11.7
torch hip version ................ None
nvcc version ..................... 11.6
deepspeed wheel compiled w. ...... torch 2.0, cuda 11.7

Screenshots
The error information is as follows:

│ /export/home2/fangkai/merit-v2/trainer_base_ds_mp.py:408 in main             │
│                                                                              │
│   405 │   │   │   logger.info("Resuming training from the latest checkpoint: │
│   406 │   │   │   continue_from_global_step = int(checkpoint.split('-')[-1]) │
│   407 │   │                                                                  │
│ ❱ 408 │   │   global_step, tr_loss = train(cfg, model_pipe, tokenizer, conti │
│   409 │   │   logger.info(" global_step = %s, average loss = %s", global_ste │
│   410                                                                        │
│   411                                                                        │
│                                                                              │
│ /export/home2/fangkai/merit-v2/trainer_base_ds_mp.py:298 in train            │
│                                                                              │
│   295 │   │   │   │   │   continue                                           │
│   296 │   │   │   │                                                          │
│   297 │   │   │   │   model.train()                                          │
│ ❱ 298 │   │   │   │   loss = model.train_batch(data_iter=sub_train_dataloade │
│   299 │   │   │   │   global_step += 1                                       │
│   300 │   │   │   │                                                          │
│   301 │   │   │   │   tr_loss += loss.item()                                 │
│                                                                              │
│ /export/home2/fangkai/anaconda3/envs/torch2.0/lib/python3.9/site-packages/de │
│ epspeed/runtime/pipe/engine.py:336 in train_batch                            │
│                                                                              │
│    333 │   │   sched = schedule.TrainSchedule(micro_batches=self.micro_batch │
│    334 │   │   │   │   │   │   │   │   │      stages=self.num_stages,        │
│    335 │   │   │   │   │   │   │   │   │      stage_id=self.stage_id)        │
│ ❱  336 │   │   self._exec_schedule(sched)                                    │
│    337 │   │   self.agg_train_loss = self._aggregate_total_loss()            │
│    338 │   │                                                                 │
│    339 │   │   self.timers('train_batch').stop()                             │
│                                                                              │
│ /export/home2/fangkai/anaconda3/envs/torch2.0/lib/python3.9/site-packages/de │
│ epspeed/runtime/pipe/engine.py:1307 in _exec_schedule                        │
│                                                                              │
│   1304 │   │   │   │                                                         │
│   1305 │   │   │   │   # Equivalent to: self._exec_forward_pass(buffer_id=0) │
│   1306 │   │   │   │   self._exec_instr = MethodType(self._INSTRUCTION_MAP[t │
│ ❱ 1307 │   │   │   │   self._exec_instr(**cmd.kwargs)                        │
│   1308                                                                       │
│                                                                              │
│ /export/home2/fangkai/anaconda3/envs/torch2.0/lib/python3.9/site-packages/de │
│ epspeed/runtime/pipe/engine.py:996 in _exec_send_grads                       │
│                                                                              │
│    993 │   │   │   │   │   if not buffer.is_floating_point():                │
│    994 │   │   │   │   │   │   assert buffer.grad is None                    │
│    995 │   │   │   │   │   │   continue                                      │
│ ❱  996 │   │   │   │   │   assert buffer.grad is not None                    │
│    997 │   │   │   │   │   p2p.send(buffer.grad, self.prev_stage)            │
│    998 │   │                                                                 │
│    999 │   │   # We can free up the input buffer now                         │
╰──────────────────────────────────────────────────────────────────────────────╯
AssertionError
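The failing assertion checks every floating-point input buffer of a stage for a gradient after the backward pass and skips non-floating-point buffers (engine.py lines 993-997 in the traceback). The minimal PyTorch sketch below mirrors that check with a hypothetical `check_send_grads` helper and a toy stage. It assumes, without confirmation from this thread, that the buffer lacking a gradient is a floating-point tensor the stage never uses, such as an attention mask: both patched attention forwards shown later in this issue discard the incoming `attention_mask` (one overwrites it with `torch.ones`, the other relies on `is_causal=True`), and autograd leaves `.grad` as `None` on a leaf tensor that does not contribute to the loss.

```python
import torch


def check_send_grads(inputs):
    """Hypothetical helper mirroring the check at engine.py:993-997 in the
    traceback: floating-point buffers must have a gradient, other buffers
    are skipped. Returns the indices where DeepSpeed's assert would fire."""
    missing = []
    for i, buffer in enumerate(inputs):
        if not buffer.is_floating_point():
            assert buffer.grad is None
            continue
        if buffer.grad is None:
            missing.append(i)
    return missing


def toy_stage(hidden_states, attention_mask):
    # Like the patched attention forwards, this toy stage ignores the mask.
    return (hidden_states * 2.0).sum()


# Case 1: a float mask marked as requiring grad (as a received stage input
# would be, by assumption) but never used -> its .grad stays None.
hidden = torch.randn(2, 4, requires_grad=True)
float_mask = torch.ones(2, 4, requires_grad=True)
toy_stage(hidden, float_mask).backward()
print(check_send_grads((hidden, float_mask)))  # [1]

# Case 2: the same mask as a bool tensor is skipped by the check.
hidden2 = torch.randn(2, 4, requires_grad=True)
bool_mask = torch.ones(2, 4, dtype=torch.bool)
toy_stage(hidden2, bool_mask).backward()
print(check_send_grads((hidden2, bool_mask)))  # []
```

Case 2 suggests, as an untested assumption, that passing such a mask between stages as a bool or integer tensor would keep it out of this check, since the engine skips non-floating-point buffers.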

System info:

  • OS: Ubuntu 20.04
  • GPU count and types: 2x A100 40GB
  • Interconnects: PCIe
  • Python version: 3.9
  • DeepSpeed version: 0.9.5 (per ds_report above), installed from git

Launcher context
deepspeed launcher

The code for implementing flash attention in my own project is as follows:

""" https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py.
"""

from typing import List, Optional, Tuple, Dict

import torch
import transformers
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

from einops import rearrange
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input


def smart_tokenizer_and_embedding_resize(
        special_tokens_dict: Dict,
        tokenizer: transformers.PreTrainedTokenizer,
        model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    # TODO: padding embedding size for being divisible by 64.
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg


def llama_flash_attn_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Input shape: Batch x Time x Channel

    attention_mask: [bsz, q_len]
    """
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    # [bsz, q_len, nh, hd]
    # [bsz, nh, q_len, hd]

    kv_seq_len = key_states.shape[-2]
    assert past_key_value is None, "past_key_value is not supported"

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
    # [bsz, nh, t, hd]
    assert not output_attentions, "output_attentions is not supported"
    assert not use_cache, "use_cache is not supported"

    # Flash attention codes from
    # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py

    # transform the data into the format required by flash attention
    qkv = torch.stack([query_states, key_states, value_states], dim=2)  # [bsz, nh, 3, q_len, hd]
    qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
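    # We have disabled _prepare_decoder_attention_mask in LlamaModel,
    # so the attention_mask should be the same as the key_padding_mask.
    # NOTE: the incoming attention_mask is discarded here and replaced by an
    # all-ones mask, i.e. every position is treated as a real (unpadded) token.
    attention_mask = torch.ones((bsz, q_len), device=qkv.device)
    key_padding_mask = attention_mask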

    if key_padding_mask is None:
        qkv = rearrange(qkv, 'b s ... -> (b s) ...')
        max_s = q_len
        cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32,
                                 device=qkv.device)
        output = flash_attn_unpadded_qkvpacked_func(
            qkv, cu_q_lens, max_s, 0.0,
            softmax_scale=None, causal=True
        )
        output = rearrange(output, '(b s) ... -> b s ...', b=bsz)
    else:
        nheads = qkv.shape[-2]
        x = rearrange(qkv, 'b s three h d -> b s (three h d)')
        x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
        x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
        output_unpad = flash_attn_unpadded_qkvpacked_func(
            x_unpad, cu_q_lens, max_s, 0.0,
            softmax_scale=None, causal=True
        )
        output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
                                     indices, bsz, q_len), 'b s (h d) -> b s h d', h=nheads)
    return self.o_proj(rearrange(output, 'b s h d -> b s (h d)')), None, None


def llama_flash_attn_forward_pytorch(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Input shape: Batch x Time x Channel

    attention_mask: [bsz, q_len]
    """
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    # [bsz, q_len, nh, hd]
    # [bsz, nh, q_len, hd]

    kv_seq_len = key_states.shape[-2]
    assert past_key_value is None, "past_key_value is not supported"

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
    # [bsz, nh, t, hd]
    assert not output_attentions, "output_attentions is not supported"
    assert not use_cache, "use_cache is not supported"

    with torch.backends.cuda.sdp_kernel(
            enable_flash=True,
            enable_math=False,
            enable_mem_efficient=False
    ):
        out = F.scaled_dot_product_attention(
            query_states, key_states, value_states,
            is_causal=True,
        )

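    out = out.transpose(1, 2)  # [bsz, q_len, nh, hd]
    out = out.reshape(bsz, q_len, self.hidden_size)

    # Apply the output projection, as llama_flash_attn_forward does above.
    return self.o_proj(out), None, None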

# Hack: call this function in main.py to monkey-patch LlamaAttention.forward.
def replace_llama_attn_with_flash_attn():
    transformers.models.llama.modeling_llama.LlamaAttention.forward = llama_flash_attn_forward
    # transformers.models.llama.modeling_llama.LlamaAttention.forward = llama_flash_attn_forward_pytorch
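In llama_flash_attn_forward the incoming mask is replaced by `torch.ones((bsz, q_len))`, so `key_padding_mask` is never `None` and the `else` branch always runs. The pure-Python sketch below reproduces the bookkeeping that, assuming flash-attn 1.x semantics, `unpad_input` derives from a `[bsz, q_len]` key-padding mask: the flattened indices of real tokens, the cumulative sequence lengths with a leading zero (`cu_seqlens`), and the maximum sequence length. `unpad_bookkeeping` is a hypothetical name, not a flash-attn API. It shows that for an all-ones mask the result matches the `cu_q_lens` the `if` branch builds with `torch.arange`, so both branches agree in this setup.

```python
from itertools import accumulate
from typing import List, Sequence, Tuple


def unpad_bookkeeping(mask: Sequence[Sequence[int]]) -> Tuple[List[int], List[int], int]:
    """Hypothetical pure-Python sketch of the indices, cu_seqlens and
    max_seqlen computed from a [bsz, q_len] key-padding mask
    (1 = real token, 0 = padding), assuming flash-attn 1.x semantics."""
    flat = [v for row in mask for v in row]
    # Positions (b * q_len + s) of real tokens in the flattened batch.
    indices = [i for i, v in enumerate(flat) if v]
    seqlens = [sum(1 for v in row if v) for row in mask]
    # Sequence b occupies rows cu_seqlens[b]:cu_seqlens[b + 1] of the packed tensor.
    cu_seqlens = [0] + list(accumulate(seqlens))
    return indices, cu_seqlens, max(seqlens)


bsz, q_len = 2, 3

# All-ones mask, as built by the torch.ones((bsz, q_len)) line in the patch.
_, cu_q_lens, max_s = unpad_bookkeeping([[1] * q_len for _ in range(bsz)])
print(cu_q_lens)  # [0, 3, 6]
# Same values as torch.arange(0, (bsz + 1) * q_len, step=q_len) in the if branch.
print(cu_q_lens == list(range(0, (bsz + 1) * q_len, q_len)))  # True

# A padded batch: the second sequence ends with one padding token.
indices, cu_q_lens, max_s = unpad_bookkeeping([[1, 1, 1], [1, 1, 0]])
print(indices, cu_q_lens, max_s)  # [0, 1, 2, 3, 4] [0, 3, 5] 3
```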

@SparkJiao SparkJiao added bug Something isn't working training labels Jul 4, 2023
@sx1999

sx1999 commented Jan 26, 2024

Hello, I also encountered this problem. Could you tell me how to solve it?
