diff --git a/server/clip_server/model/flash_attention.py b/server/clip_server/model/flash_attention.py
new file mode 100644
index 000000000..4db0c2e14
--- /dev/null
+++ b/server/clip_server/model/flash_attention.py
@@ -0,0 +1,89 @@
+import torch
+import torch.nn as nn
+from torch import Tensor
+from typing import Optional, Tuple
+
+from torch.nn.functional import linear
+try:
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+except ImportError:
+    flash_attn_unpadded_func = None
+
+
+class MultiheadAttention(nn.MultiheadAttention):
+    def __init__(self, embed_dim, num_heads, dropout=0, bias=True, add_bias_kv=False, add_zero_attn=False,
+                 kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
+        assert flash_attn_unpadded_func is not None, "FlashAttention is not installed."
+        super().__init__(embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn,
+                         kdim, vdim, batch_first, device, dtype)
+
+    def attention(
+        self,
+        q, k, v,
+        batch_size=1,
+        seqlen=77,
+        key_padding_mask=None,
+        softmax_scale=None,
+        attention_dropout=0.0,
+        causal=False,
+        cu_seqlens=None,
+        max_s=None,
+        need_weights=False
+    ):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+        q, k, v: the tensors containing the query, key, and value, each of shape (B*S, H, D)
+        key_padding_mask: a bool tensor of shape (B, S)
+
+        """
+        assert not need_weights
+        assert q.dtype in [torch.float16, torch.bfloat16]
+        assert q.is_cuda
+
+        if cu_seqlens is None:
+            max_s = seqlen
+            cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
+                                      device=q.device)
+        output = flash_attn_unpadded_func(
+            q, k, v, cu_seqlens, cu_seqlens, max_s, max_s, attention_dropout,
+            softmax_scale=softmax_scale, causal=causal
+        )
+
+        return output
+
+    def forward(
+        self,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        key_padding_mask: Optional[Tensor] = None,
+        need_weights: bool = False,
+        attn_mask: Optional[Tensor] = None,
+        average_attn_weights: bool = True
+    ) -> Tuple[Tensor, Optional[Tensor]]:
+        # set up shape vars
+        seqlen, batch_size, embed_dim = query.shape
+
+        if isinstance(embed_dim, torch.Tensor):
+            # embed_dim can be a tensor when JIT tracing
+            head_dim = embed_dim.div(self.num_heads, rounding_mode='trunc')
+        else:
+            head_dim = embed_dim // self.num_heads
+        assert head_dim * self.num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {self.num_heads}"
+
+        # in-projection
+        q, k, v = linear(query, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
+        q = q.contiguous().view((batch_size * seqlen, self.num_heads, head_dim))
+        k = k.contiguous().view((batch_size * seqlen, self.num_heads, head_dim))
+        v = v.contiguous().view((batch_size * seqlen, self.num_heads, head_dim))
+
+        # flash attention
+        attn_output = self.attention(q, k, v, batch_size, seqlen)
+
+        # out-projection
+        attn_output = attn_output.contiguous().view(seqlen * batch_size, embed_dim)
+        attn_output = linear(attn_output, self.out_proj.weight, self.out_proj.bias)
+        attn_output = attn_output.view(seqlen, batch_size, embed_dim)
+
+        return attn_output, None
\ No newline at end of file
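A minimal smoke-test sketch of the new wrapper, not part of the PR: it assumes flash-attn is installed, a CUDA device is available, and that the file above is importable as `clip_server.model.flash_attention`. The shapes and dtypes follow the asserts in `attention()` — half-precision CUDA tensors in the (seq_len, batch, embed_dim) layout that open_clip's transformer blocks pass to attention (`attn` and `x` are illustrative names).

```python
# Hypothetical smoke test, not from the PR: requires flash-attn and a CUDA GPU.
import torch

from clip_server.model.flash_attention import MultiheadAttention

attn = MultiheadAttention(embed_dim=512, num_heads=8).cuda().half()

# (seq_len, batch, embed_dim): the LND layout used inside open_clip's transformer;
# attention() asserts fp16/bf16 tensors on CUDA.
x = torch.randn(77, 4, 512, device='cuda', dtype=torch.float16)

out, weights = attn(x, x, x, need_weights=False)
assert out.shape == (77, 4, 512)
assert weights is None  # forward() never returns attention weights
```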
diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py
index 7c4d6633e..190195ada 100644
--- a/server/clip_server/model/model.py
+++ b/server/clip_server/model/model.py
@@ -25,6 +25,7 @@
 from open_clip.timm_model import TimmModel
 from open_clip.utils import freeze_batch_norm_2d
 from open_clip.factory import _MODEL_CONFIGS
+from clip_server.model.flash_attention import MultiheadAttention
 
 # From PyTorch internals
@@ -277,11 +278,13 @@     def __init__(
         scale_heads: bool = False,
         scale_attn: bool = False,
         scale_fc: bool = False,
+        use_flash: bool = True,
     ):
         super().__init__()
 
         self.ln_1 = LayerNorm(d_model)
-        self.attn = nn.MultiheadAttention(d_model, n_head)
+        # TODO: `use_flash` needs to be verified
+        self.attn = MultiheadAttention(d_model, n_head) if use_flash else nn.MultiheadAttention(d_model, n_head)
         self.ln_attn = LayerNorm(d_model) if scale_attn else nn.Identity()
         self.ln_2 = LayerNorm(d_model)
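One caveat with defaulting `use_flash=True` here: the subclass's constructor asserts that `flash_attn_unpadded_func` imported successfully, so block construction fails outright when flash-attn is absent. A possible guard, sketched with a hypothetical `build_attn` helper rather than the PR's actual wiring, keys the choice on both the flag and the module-level fallback to `None`:

```python
# Hypothetical helper, not part of the PR: choose the attention implementation
# based on the flag and on whether flash-attn actually imported.
import torch.nn as nn

from clip_server.model.flash_attention import MultiheadAttention, flash_attn_unpadded_func


def build_attn(d_model: int, n_head: int, use_flash: bool = True) -> nn.Module:
    if use_flash and flash_attn_unpadded_func is not None:
        return MultiheadAttention(d_model, n_head)
    # fall back to stock PyTorch attention when flash-attn is unavailable or disabled
    return nn.MultiheadAttention(d_model, n_head)
```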