Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Integrate flash attention #853

Merged
merged 24 commits into from
Nov 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,12 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install wheel pytest pytest-cov nvidia-pyindex
{
python -m pip install torch==1.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install git+https://github.com/HazyResearch/flash-attention.git
} || {
echo "flash attention was not installed."
}
pip install -e "client/[test]"
pip install -e "server/[tensorrt]"
- name: Test
Expand Down
133 changes: 133 additions & 0 deletions server/clip_server/model/flash_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import torch
numb3r3 marked this conversation as resolved.
Show resolved Hide resolved
import torch.nn as nn
from torch import Tensor
from typing import Optional, Tuple

from torch.nn.functional import linear
from flash_attn.flash_attn_interface import flash_attn_unpadded_func


class MultiheadAttention(nn.MultiheadAttention):
    """``nn.MultiheadAttention`` variant whose forward pass uses FlashAttention.

    The parameters (``in_proj_weight``, ``in_proj_bias``, ``out_proj``) are
    the ones created by ``nn.MultiheadAttention``; only the attention
    computation is replaced by ``flash_attn_unpadded_func``.

    Only a restricted configuration is supported:

    * self-attention, i.e. ``query``, ``key`` and ``value`` are the same tensor;
    * sequence-first input of shape ``(seqlen, batch, embed_dim)``;
    * no ``key_padding_mask``;
    * ``attn_mask`` is either ``None`` or a causal mask (its values are never
      read -- see :meth:`forward`);
    * fp16 / bf16 tensors located on a CUDA device.

    Configurations outside this set raise instead of silently returning
    incorrect results.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        device=None,
        dtype=None,
    ) -> None:
        """Build the module; every argument is forwarded to ``nn.MultiheadAttention``."""
        # Forward by keyword so the call does not depend on the positional
        # order of the parent's parameters.
        super().__init__(
            embed_dim,
            num_heads,
            dropout=dropout,
            bias=bias,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            kdim=kdim,
            vdim=vdim,
            batch_first=batch_first,
            device=device,
            dtype=dtype,
        )

    def attention(
        self,
        q,
        k,
        v,
        batch_size=1,
        seqlen=77,
        softmax_scale=None,
        attention_dropout=0.0,
        causal=False,
        cu_seqlens=None,
        max_s=None,
        need_weights=False,
    ):
        """Compute multi-head softmax attention with FlashAttention.

        Arguments
        ---------
            q, k, v: query, key and value tensors, each of shape (B*S, H, D),
                fp16 or bf16, on a CUDA device.
            batch_size: number of sequences B packed into the first dimension.
                Only used when ``cu_seqlens`` is None.
            seqlen: length S shared by all sequences. Only used when
                ``cu_seqlens`` is None.
            softmax_scale: forwarded to ``flash_attn_unpadded_func``.
            attention_dropout: dropout probability forwarded to
                ``flash_attn_unpadded_func``.
            causal: forwarded to ``flash_attn_unpadded_func``.
            cu_seqlens: int32 tensor of cumulative sequence boundaries. When
                None, boundaries ``[0, S, 2S, ..., B*S]`` are generated.
            max_s: maximum sequence length. Must be supplied by the caller
                whenever ``cu_seqlens`` is supplied; otherwise it is set to
                ``seqlen``.
            need_weights: must be False; attention weights are not available.

        Returns
        -------
            The value returned by ``flash_attn_unpadded_func``.
        """
        assert not need_weights, "flash attention cannot return attention weights"
        assert q.dtype in (
            torch.float16,
            torch.bfloat16,
        ), "flash attention requires fp16 or bf16 inputs"
        assert q.is_cuda, "flash attention requires CUDA tensors"

        if cu_seqlens is None:
            # Every sequence has the same length, so the boundaries are the
            # multiples of `seqlen`: [0, S, 2S, ..., B*S].
            max_s = seqlen
            cu_seqlens = torch.arange(
                0,
                (batch_size + 1) * seqlen,
                step=seqlen,
                dtype=torch.int32,
                device=q.device,
            )

        # The same boundaries / max length are used for queries and keys.
        return flash_attn_unpadded_func(
            q,
            k,
            v,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            attention_dropout,
            softmax_scale=softmax_scale,
            causal=causal,
        )

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = False,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Self-attention forward pass using FlashAttention.

        Arguments
        ---------
            query: input of shape (seqlen, batch, embed_dim).
            key, value: must be the very same tensor object as ``query``.
            key_padding_mask: must be None (not supported).
            need_weights: accepted for signature compatibility; attention
                weights are never computed.
            attn_mask: when not None, causal attention is applied. The mask
                values themselves are not read.
            average_attn_weights: accepted for signature compatibility; unused.

        Returns
        -------
            ``(attn_output, None)`` where ``attn_output`` has shape
            (seqlen, batch, embed_dim).

        Raises
        ------
            NotImplementedError: for ``batch_first=True`` modules, for
                ``key``/``value`` different from ``query``, for modules built
                with distinct ``kdim``/``vdim``, or for a ``key_padding_mask``.
        """
        # Fail loudly on inputs this implementation would otherwise ignore,
        # which would yield silently wrong outputs.
        if self.batch_first:
            raise NotImplementedError(
                "flash attention MultiheadAttention does not support batch_first=True"
            )
        if key is not query or value is not query:
            raise NotImplementedError(
                "flash attention MultiheadAttention only supports self-attention "
                "(query, key and value must be the same tensor)"
            )
        if self.in_proj_weight is None:
            raise NotImplementedError(
                "flash attention MultiheadAttention requires kdim == vdim == embed_dim"
            )
        if key_padding_mask is not None:
            raise NotImplementedError(
                "flash attention MultiheadAttention does not support key_padding_mask"
            )

        # set up shape vars
        seqlen, batch_size, embed_dim = query.shape
        packed_shape = (batch_size * seqlen, self.num_heads, self.head_dim)

        # in-projection and rearrange `s b (h d) -> (b s) h d`
        q, k, v = (
            proj.transpose(0, 1).contiguous().view(packed_shape)
            for proj in linear(query, self.in_proj_weight, self.in_proj_bias).chunk(
                3, dim=-1
            )
        )

        # flash attention: any attn_mask is interpreted as a causal mask
        causal = attn_mask is not None
        # Apply the configured dropout only in training mode, as
        # nn.MultiheadAttention does.
        dropout_p = self.dropout if self.training else 0.0
        attn_output = self.attention(
            q,
            k,
            v,
            batch_size,
            seqlen,
            attention_dropout=dropout_p,
            causal=causal,
        )

        # out-projection
        # `(b s) h d -> s b (h d)`
        attn_output = (
            attn_output.contiguous()
            .view(batch_size, seqlen, embed_dim)
            .transpose(0, 1)
            .contiguous()
        )
        attn_output = linear(attn_output, self.out_proj.weight, self.out_proj.bias)

        return attn_output, None
18 changes: 17 additions & 1 deletion server/clip_server/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@
from open_clip.utils import freeze_batch_norm_2d
from open_clip.factory import _MODEL_CONFIGS

# Use flash attention
try:
from clip_server.model.flash_attention import MultiheadAttention
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also need to check CUDA and PyTorch compatibility here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Insert the CUDA and PyTorch version check here.


FLASH_ATTENTION_AVAILABLE = True
except:
FLASH_ATTENTION_AVAILABLE = False


# From PyTorch internals
def _ntuple(n):
Expand Down Expand Up @@ -279,9 +287,17 @@ def __init__(
scale_fc: bool = False,
):
super().__init__()
head_dim = d_model // n_head
self.flash_attention = head_dim % 8 == 0 and head_dim <= 128

self.ln_1 = LayerNorm(d_model)
self.attn = nn.MultiheadAttention(d_model, n_head)
self.attn = (
MultiheadAttention(d_model, n_head)
if FLASH_ATTENTION_AVAILABLE
and torch.cuda.is_available()
and self.flash_attention
else nn.MultiheadAttention(d_model, n_head)
)
self.ln_attn = LayerNorm(d_model) if scale_attn else nn.Identity()

self.ln_2 = LayerNorm(d_model)
Expand Down