feat: add flash attention
OrangeSodahub committed Nov 9, 2022
1 parent 4fcbf68 commit 82fdb56
Showing 2 changed files with 93 additions and 1 deletion.
89 changes: 89 additions & 0 deletions server/clip_server/model/flash_attention.py
@@ -0,0 +1,89 @@
import torch
import torch.nn as nn
from torch import Tensor
from typing import Optional, Tuple

from torch.nn.functional import linear
try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
    flash_attn_unpadded_func = None


class MultiheadAttention(nn.MultiheadAttention):
    def __init__(self, embed_dim, num_heads, dropout=0, bias=True, add_bias_kv=False, add_zero_attn=False,
                 kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
        assert flash_attn_unpadded_func is not None, "FlashAttention is not installed."
        super().__init__(embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn,
                         kdim, vdim, batch_first, device, dtype)

    def attention(
        self,
        q, k, v,
        batch_size=1,
        seqlen=77,
        key_padding_mask=None,
        softmax_scale=None,
        attention_dropout=0.0,
        causal=False,
        cu_seqlens=None,
        max_s=None,
        need_weights=False
    ):
        """Implements multihead softmax attention with FlashAttention.

        Arguments
        ---------
        q, k, v: tensors containing the query, key and value, each of shape (B * S, H, D)
        key_padding_mask: a bool tensor of shape (B, S)
        """
        assert not need_weights
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda

        if cu_seqlens is None:
            # all sequences share the same (padded) length, so the cumulative
            # sequence lengths are simply multiples of `seqlen`
            max_s = seqlen
            cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                      device=q.device)
        output = flash_attn_unpadded_func(
            q, k, v, cu_seqlens, cu_seqlens, max_s, max_s, attention_dropout,
            softmax_scale=softmax_scale, causal=causal
        )

        return output

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = False,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True
    ) -> Tuple[Tensor, Optional[Tensor]]:
        # NOTE: `key_padding_mask` and `attn_mask` are currently ignored; the kernel
        # is always called with causal=False (see the TODO in model.py)
        # set up shape vars: inputs are (seqlen, batch, embed_dim) since batch_first=False
        seqlen, batch_size, embed_dim = query.shape

        if isinstance(embed_dim, torch.Tensor):
            # embed_dim can be a tensor when JIT tracing
            head_dim = embed_dim.div(self.num_heads, rounding_mode='trunc')
        else:
            head_dim = embed_dim // self.num_heads
        assert head_dim * self.num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {self.num_heads}"

        # in-projection: self-attention, so a single fused projection of `query` yields q, k and v
        q, k, v = linear(query, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
        # flash_attn_unpadded_func expects tokens grouped per sequence, so move the batch
        # dimension first before flattening to (batch * seqlen, num_heads, head_dim)
        q = q.transpose(0, 1).contiguous().view((batch_size * seqlen, self.num_heads, head_dim))
        k = k.transpose(0, 1).contiguous().view((batch_size * seqlen, self.num_heads, head_dim))
        v = v.transpose(0, 1).contiguous().view((batch_size * seqlen, self.num_heads, head_dim))

        # flash attention
        attn_output = self.attention(q, k, v, batch_size, seqlen)

        # out-projection: back to (seqlen, batch, embed_dim)
        attn_output = attn_output.contiguous().view((batch_size, seqlen, embed_dim)).transpose(0, 1)
        attn_output = linear(attn_output, self.out_proj.weight, self.out_proj.bias)

        return attn_output, None
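
A minimal usage sketch of the new module (assuming a CUDA GPU with the flash-attn package installed and clip_server importable as a package; the dimensions below are illustrative): it is meant as a drop-in for nn.MultiheadAttention with inputs in the default (seqlen, batch, embed_dim) layout and half-precision weights.

import torch
from clip_server.model.flash_attention import MultiheadAttention  # assumed install path

# hypothetical CLIP-like sizes: width 512, 8 heads, context length 77, batch 4
attn = MultiheadAttention(embed_dim=512, num_heads=8).cuda().half()
x = torch.randn(77, 4, 512, device='cuda', dtype=torch.float16)  # (seqlen, batch, embed_dim)
out, _ = attn(x, x, x, need_weights=False)
print(out.shape)  # torch.Size([77, 4, 512])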
5 changes: 4 additions & 1 deletion server/clip_server/model/model.py
@@ -25,6 +25,7 @@
from open_clip.timm_model import TimmModel
from open_clip.utils import freeze_batch_norm_2d
from open_clip.factory import _MODEL_CONFIGS
from clip_server.model.flash_attention import MultiheadAttention


# From PyTorch internals
@@ -277,11 +278,13 @@ def __init__(
        scale_heads: bool = False,
        scale_attn: bool = False,
        scale_fc: bool = False,
        use_flash: bool = True,
    ):
        super().__init__()

        self.ln_1 = LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_head)
        # TODO: `use_flash` needs to be verified
        self.attn = MultiheadAttention(d_model, n_head) if use_flash else nn.MultiheadAttention(d_model, n_head)
        self.ln_attn = LayerNorm(d_model) if scale_attn else nn.Identity()

        self.ln_2 = LayerNorm(d_model)
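
A short sketch of how the new flag selects the attention implementation, assuming the enclosing class is the ResidualAttentionBlock defined in this model.py (the class name is not visible in the hunk), the import path is hypothetical, and flash-attn is installed, since the default use_flash=True otherwise trips the assert in MultiheadAttention.__init__.

from clip_server.model.model import ResidualAttentionBlock  # hypothetical import path

# flash path: requires CUDA and fp16/bf16 activations at call time
block = ResidualAttentionBlock(d_model=512, n_head=8, use_flash=True).cuda().half()

# reference path: stock nn.MultiheadAttention, also works on CPU and in fp32
block_ref = ResidualAttentionBlock(d_model=512, n_head=8, use_flash=False)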
