From e4717a35f850e6a2cd8b4d8b4c994fad30fd5c72 Mon Sep 17 00:00:00 2001
From: YangXiuyu
Date: Tue, 15 Nov 2022 15:53:19 +0800
Subject: [PATCH] feat: Integrate flash attention (#853)

* feat: add flash attention
* fix: flash-attn dependency
* fix: flash_attn dependency
* fix: use cuda in flash_attn
* fix: import flash_attn
* fix: errors with batch_size > 1
* fix: remove use_flash
* fix: optimize shape operation
* fix: add causal mask
* fix: flash attention import
* fix: setup.py
* fix: flash attention import
* fix: test ci
* fix: test ci
* fix: test ci
* fix: test ci
* fix: test ci
* fix: test ci
* fix: test ci
* fix: test ci
* fix: test ci
* fix: passby flash attention
* fix: test ci
* Revert "fix: test ci"

This reverts commit 979a42cdd855671f852d26e681164567e10e0092.
---
 .github/workflows/ci.yml                    |   6 +
 server/clip_server/model/flash_attention.py | 133 ++++++++++++++++++++
 server/clip_server/model/model.py           |  18 ++-
 3 files changed, 156 insertions(+), 1 deletion(-)
 create mode 100644 server/clip_server/model/flash_attention.py

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index fd38adf89..87dae3336 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -155,6 +155,12 @@ jobs:
         run: |
           python -m pip install --upgrade pip
           python -m pip install wheel pytest pytest-cov nvidia-pyindex
+          {
+            python -m pip install torch==1.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
+            python -m pip install git+https://github.com/HazyResearch/flash-attention.git
+          } || {
+            echo "flash attention was not installed."
+          }
           pip install -e "client/[test]"
           pip install -e "server/[tensorrt]"
       - name: Test
diff --git a/server/clip_server/model/flash_attention.py b/server/clip_server/model/flash_attention.py
new file mode 100644
index 000000000..fe368f33e
--- /dev/null
+++ b/server/clip_server/model/flash_attention.py
@@ -0,0 +1,133 @@
import torch
import torch.nn as nn
from torch import Tensor
from typing import Optional, Tuple

from torch.nn.functional import linear
from flash_attn.flash_attn_interface import flash_attn_unpadded_func


class MultiheadAttention(nn.MultiheadAttention):
    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(
            embed_dim,
            num_heads,
            dropout,
            bias,
            add_bias_kv,
            add_zero_attn,
            kdim,
            vdim,
            batch_first,
            device,
            dtype,
        )

    def attention(
        self,
        q,
        k,
        v,
        batch_size=1,
        seqlen=77,
        softmax_scale=None,
        attention_dropout=0.0,
        causal=False,
        cu_seqlens=None,
        max_s=None,
        need_weights=False,
    ):
        """Implements multihead softmax attention with the FlashAttention kernel.

        Arguments
        ---------
            q, k, v: query, key, and value tensors, each of shape (B*S, H, D)
            batch_size, seqlen: the batch size B and sequence length S
            causal: whether to apply a causal (lower-triangular) attention mask
        """
        assert not need_weights
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda

        if cu_seqlens is None:
            max_s = seqlen
            cu_seqlens = torch.arange(
                0,
                (batch_size + 1) * seqlen,
                step=seqlen,
                dtype=torch.int32,
                device=q.device,
            )
        output = flash_attn_unpadded_func(
            q,
            k,
            v,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            attention_dropout,
            softmax_scale=softmax_scale,
            causal=causal,
        )

        return output

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = False,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        # set up shape vars
        seqlen, batch_size, embed_dim = query.shape

        # in-projection and rearrange `s b (h d) -> (b s) h d`
        q, k, v = linear(query, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
        q = (
            q.transpose(0, 1)
            .contiguous()
            .view(batch_size * seqlen, self.num_heads, self.head_dim)
        )
        k = (
            k.transpose(0, 1)
            .contiguous()
            .view(batch_size * seqlen, self.num_heads, self.head_dim)
        )
        v = (
            v.transpose(0, 1)
            .contiguous()
            .view(batch_size * seqlen, self.num_heads, self.head_dim)
        )

        # flash attention (a causal mask is applied whenever attn_mask is given)
        causal = attn_mask is not None
        attn_output = self.attention(q, k, v, batch_size, seqlen, causal=causal)

        # out-projection and rearrange `(b s) h d -> s b (h d)`
        attn_output = attn_output.contiguous().view(
            batch_size, seqlen, self.num_heads, self.head_dim
        )
        attn_output = (
            attn_output.transpose(0, 1).contiguous().view(seqlen, batch_size, embed_dim)
        )
        attn_output = linear(attn_output, self.out_proj.weight, self.out_proj.bias)

        return attn_output, None
diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py
index 7c4d6633e..e87c56411 100644
--- a/server/clip_server/model/model.py
+++ b/server/clip_server/model/model.py
@@ -26,6 +26,14 @@
 from open_clip.utils import freeze_batch_norm_2d
 from open_clip.factory import _MODEL_CONFIGS
 
+# Use flash attention if the optional dependency is importable
+try:
+    from clip_server.model.flash_attention import MultiheadAttention
+
+    FLASH_ATTENTION_AVAILABLE = True
+except ImportError:
+    FLASH_ATTENTION_AVAILABLE = False
+
 
 # From PyTorch internals
 def _ntuple(n):
@@ -279,9 +287,17 @@ def __init__(
         scale_fc: bool = False,
     ):
         super().__init__()
+        head_dim = d_model // n_head
+        self.flash_attention = head_dim % 8 == 0 and head_dim <= 128
 
         self.ln_1 = LayerNorm(d_model)
-        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.attn = (
+            MultiheadAttention(d_model, n_head)
+            if FLASH_ATTENTION_AVAILABLE
+            and torch.cuda.is_available()
+            and self.flash_attention
+            else nn.MultiheadAttention(d_model, n_head)
+        )
         self.ln_attn = LayerNorm(d_model) if scale_attn else nn.Identity()
         self.ln_2 = LayerNorm(d_model)
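
A minimal smoke-test sketch for the new wrapper (not part of the patch): it assumes a CUDA device, fp16 weights, the flash-attn package installed as in the CI step above, and CLIP text-tower shapes (seq_len=77); the tolerance is illustrative.

# Smoke-test sketch: compare the flash-attention module against nn.MultiheadAttention
import torch
import torch.nn as nn

from clip_server.model.flash_attention import MultiheadAttention

d_model, n_head, seqlen, batch_size = 512, 8, 77, 4
x = torch.randn(seqlen, batch_size, d_model, device='cuda', dtype=torch.float16)

flash_mha = MultiheadAttention(d_model, n_head, device='cuda', dtype=torch.float16)
ref_mha = nn.MultiheadAttention(d_model, n_head, device='cuda', dtype=torch.float16)
ref_mha.load_state_dict(flash_mha.state_dict())  # same weights in both modules

# additive causal mask (-inf strictly above the diagonal), as used by the text tower
causal_mask = torch.full(
    (seqlen, seqlen), float('-inf'), device='cuda', dtype=torch.float16
).triu(1)

with torch.no_grad():
    out_flash, _ = flash_mha(x, x, x, need_weights=False, attn_mask=causal_mask)
    out_ref, _ = ref_mha(x, x, x, need_weights=False, attn_mask=causal_mask)

print('max abs diff:', (out_flash - out_ref).abs().max().item())
print('close:', torch.allclose(out_flash, out_ref, atol=1e-2))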
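
And a standalone sketch of the gate added to ResidualAttentionBlock, for readers who want the selection logic outside the diff; the helper name make_attention is illustrative, not part of the patch.

# Sketch of the selection logic: flash attention is used only when the optional
# package imports, a GPU is present, and the head dimension fits the kernel
# (divisible by 8 and at most 128).
import torch
import torch.nn as nn

try:
    from clip_server.model.flash_attention import MultiheadAttention

    FLASH_ATTENTION_AVAILABLE = True
except ImportError:
    FLASH_ATTENTION_AVAILABLE = False


def make_attention(d_model: int, n_head: int) -> nn.Module:
    head_dim = d_model // n_head
    flash_ok = head_dim % 8 == 0 and head_dim <= 128
    if FLASH_ATTENTION_AVAILABLE and torch.cuda.is_available() and flash_ok:
        return MultiheadAttention(d_model, n_head)
    return nn.MultiheadAttention(d_model, n_head)


# ViT-B/32's text tower (width 512, 8 heads) has head_dim 64, which passes the size
# check; the flash path is still taken only if a GPU and the package are available.
print(type(make_attention(512, 8)).__name__)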