feat: Integrate flash attention #853

Merged 24 commits on Nov 15, 2022
Changes from 9 commits
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
@@ -158,6 +158,7 @@ jobs:
        run: |
          python -m pip install --upgrade pip
          python -m pip install wheel pytest pytest-cov nvidia-pyindex
+         python -m pip install git+https://github.com/HazyResearch/flash-attention.git
pip install -e "client/[test]"
pip install -e "server/[tensorrt]"
- name: Test
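Outside CI, a quick local sanity check (not part of this workflow change) is to attempt the same import the server uses and confirm a CUDA device is present, mirroring the fallback logic added below:

```python
# Local sanity check (not part of this PR): verify flash-attn is importable and
# that a CUDA device is present, mirroring the fallback behavior in clip_server.
import torch

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
    print("flash-attn kernel available")
except ImportError:
    print("flash-attn not installed; nn.MultiheadAttention will be used instead")

print("CUDA available:", torch.cuda.is_available())
```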
138 changes: 138 additions & 0 deletions server/clip_server/model/flash_attention.py
@@ -0,0 +1,138 @@
import torch
import torch.nn as nn
from torch import Tensor
from typing import Optional, Tuple

from torch.nn.functional import linear

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
    flash_attn_unpadded_func = None


class MultiheadAttention(nn.MultiheadAttention):
    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        device=None,
        dtype=None,
    ) -> None:
        assert flash_attn_unpadded_func is not None, "FlashAttention is not installed."
        super().__init__(
            embed_dim,
            num_heads,
            dropout,
            bias,
            add_bias_kv,
            add_zero_attn,
            kdim,
            vdim,
            batch_first,
            device,
            dtype,
        )

    def attention(
        self,
        q,
        k,
        v,
        batch_size=1,
        seqlen=77,
        softmax_scale=None,
        attention_dropout=0.0,
        causal=False,
        cu_seqlens=None,
        max_s=None,
        need_weights=False,
    ):
        """Implements the multihead softmax attention.

        Arguments
        ---------
        q, k, v: tensors containing the query, key, and value, each of shape (B*S, H, D)
        batch_size, seqlen: B and S, used to build `cu_seqlens` when it is not given
        softmax_scale: optional scaling applied to QK^T before the softmax
        attention_dropout: dropout probability applied to the attention weights
        causal: whether to apply a causal (lower-triangular) mask
        cu_seqlens, max_s: cumulative sequence lengths (int32, shape (B+1,)) and the
            maximum sequence length expected by the unpadded FlashAttention kernel;
            derived from `batch_size` and `seqlen` when not provided
        need_weights: must be False; attention weights are not supported
        """
        assert not need_weights
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda

        if cu_seqlens is None:
            max_s = seqlen
            cu_seqlens = torch.arange(
                0,
                (batch_size + 1) * seqlen,
                step=seqlen,
                dtype=torch.int32,
                device=q.device,
            )
        output = flash_attn_unpadded_func(
            q,
            k,
            v,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            attention_dropout,
            softmax_scale=softmax_scale,
            causal=causal,
        )

        return output

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = False,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        # `key_padding_mask`, `need_weights` and `average_attn_weights` are accepted
        # for API compatibility with nn.MultiheadAttention but are not used here;
        # attention weights are never returned.

        # set up shape vars
        seqlen, batch_size, embed_dim = query.shape

        # in-projection and rearrange `b s (h d) -> (b s) h d`
        q, k, v = linear(query, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
        q = (
            q.transpose(0, 1)
            .contiguous()
            .view(batch_size * seqlen, self.num_heads, self.head_dim)
        )
        k = (
            k.transpose(0, 1)
            .contiguous()
            .view(batch_size * seqlen, self.num_heads, self.head_dim)
        )
        v = (
            v.transpose(0, 1)
            .contiguous()
            .view(batch_size * seqlen, self.num_heads, self.head_dim)
        )

        # flash attention (a causal mask is applied whenever `attn_mask` is given)
        causal = attn_mask is not None
        attn_output = self.attention(q, k, v, batch_size, seqlen, causal=causal)

        # out-projection
        # `(b s) h d -> s b (h d)`
        attn_output = attn_output.contiguous().view(
            batch_size, seqlen, self.num_heads, self.head_dim
        )
        attn_output = (
            attn_output.transpose(0, 1).contiguous().view(seqlen, batch_size, embed_dim)
        )
        attn_output = linear(attn_output, self.out_proj.weight, self.out_proj.bias)

        return attn_output, None
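A minimal usage sketch of the drop-in layer above (illustrative only, not part of the diff); it assumes flash-attn is installed, a CUDA device is available, fp16 weights and inputs, and the default CLIP text sequence length of 77:

```python
# Illustrative smoke test for clip_server.model.flash_attention.MultiheadAttention.
# Assumes flash-attn is installed and a CUDA GPU is available; shapes follow
# nn.MultiheadAttention with batch_first=False: (seq_len, batch, embed_dim).
import torch
from clip_server.model.flash_attention import MultiheadAttention

d_model, n_head, seqlen, batch = 512, 8, 77, 4   # head_dim = 64
attn = MultiheadAttention(d_model, n_head).cuda().half()

x = torch.randn(seqlen, batch, d_model, device='cuda', dtype=torch.float16)
mask = torch.empty(seqlen, seqlen)  # only its presence matters: it switches causal=True
out, weights = attn(x, x, x, attn_mask=mask)

assert out.shape == (seqlen, batch, d_model)
assert weights is None  # attention weights are never returned by this layer
```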
9 changes: 8 additions & 1 deletion server/clip_server/model/model.py
@@ -25,6 +25,7 @@
from open_clip.timm_model import TimmModel
from open_clip.utils import freeze_batch_norm_2d
from open_clip.factory import _MODEL_CONFIGS
+from clip_server.model.flash_attention import MultiheadAttention


# From PyTorch internals
@@ -279,9 +280,15 @@ def __init__(
        scale_fc: bool = False,
    ):
        super().__init__()
+       head_dim = d_model // n_head
+       self.flash_attention = head_dim % 8 == 0 and head_dim <= 128

        self.ln_1 = LayerNorm(d_model)
-       self.attn = nn.MultiheadAttention(d_model, n_head)
+       self.attn = (
+           MultiheadAttention(d_model, n_head)
+           if torch.cuda.is_available() and self.flash_attention
+           else nn.MultiheadAttention(d_model, n_head)
+       )
        self.ln_attn = LayerNorm(d_model) if scale_attn else nn.Identity()

        self.ln_2 = LayerNorm(d_model)
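As a worked example of the gating added above (illustrative, not from the PR): the ViT-B/32 text tower uses d_model=512 and n_head=8, so head_dim=64 satisfies both conditions and the flash path is taken whenever CUDA is available; blocks whose head_dim is not a multiple of 8 or exceeds 128 keep using nn.MultiheadAttention.

```python
# Standalone restatement of the gating condition above (illustrative only).
def uses_flash_attention(d_model: int, n_head: int, cuda_available: bool) -> bool:
    head_dim = d_model // n_head
    return cuda_available and head_dim % 8 == 0 and head_dim <= 128

print(uses_flash_attention(512, 8, cuda_available=True))    # True: head_dim = 64
print(uses_flash_attention(1024, 4, cuda_available=True))   # False: head_dim = 256 > 128
print(uses_flash_attention(512, 8, cuda_available=False))   # False: no GPU, falls back
```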