Skip to content

Commit

Permalink
feat: Integrate flash attention (#853)
Browse files Browse the repository at this point in the history
* feat: add flash attention

* fix: flash-attn dependency

* fix: flash_attn dependency

* fix: use cuda in flash_attn

* fix: import flash_attn

* fix: errors with batch_size > 1

* fix: remove use_flash

* fix: optimize shape operation

* fix: add causal mask

* fix: flash attention import

* fix: setup.py

* fix: flash attention import

* fix: test ci

* fix: test ci

* fix: test ci

* fix: test ci

* fix: test ci

* fix: test ci

* fix: test ci

* fix: test ci

* fix: test ci

* fix: passby flash attention

* fix: test ci

* Revert "fix: test ci"

This reverts commit 979a42c.
  • Loading branch information
OrangeSodahub authored Nov 15, 2022
1 parent d2ecec6 commit e4717a3
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 1 deletion.
6 changes: 6 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install wheel pytest pytest-cov nvidia-pyindex
{
python -m pip install torch==1.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install git+https://github.com/HazyResearch/flash-attention.git
} || {
echo "flash attention was not installed."
}
pip install -e "client/[test]"
pip install -e "server/[tensorrt]"
- name: Test
Expand Down
133 changes: 133 additions & 0 deletions server/clip_server/model/flash_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import torch
import torch.nn as nn
from torch import Tensor
from typing import Optional, Tuple

from torch.nn.functional import linear
from flash_attn.flash_attn_interface import flash_attn_unpadded_func


class MultiheadAttention(nn.MultiheadAttention):
def __init__(
self,
embed_dim,
num_heads,
dropout=0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
kdim=None,
vdim=None,
batch_first=False,
device=None,
dtype=None,
) -> None:
super().__init__(
embed_dim,
num_heads,
dropout,
bias,
add_bias_kv,
add_zero_attn,
kdim,
vdim,
batch_first,
device,
dtype,
)

def attention(
self,
q,
k,
v,
batch_size=1,
seqlen=77,
softmax_scale=None,
attention_dropout=0.0,
causal=False,
cu_seqlens=None,
max_s=None,
need_weights=False,
):
"""Implements the multihead softmax attention.
Arguments
---------
q,k,v: The tensor containing the query, key, and value. each of (B*S, H, D)
key_padding_mask: a bool tensor of shape (B, S)
"""
assert not need_weights
assert q.dtype in [torch.float16, torch.bfloat16]
assert q.is_cuda

if cu_seqlens is None:
max_s = seqlen
cu_seqlens = torch.arange(
0,
(batch_size + 1) * seqlen,
step=seqlen,
dtype=torch.int32,
device=q.device,
)
output = flash_attn_unpadded_func(
q,
k,
v,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
attention_dropout,
softmax_scale=softmax_scale,
causal=causal,
)

return output

def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = False,
attn_mask: Optional[Tensor] = None,
average_attn_weights: bool = True,
) -> Tuple[Tensor, Optional[Tensor]]:
# set up shape vars
seqlen, batch_size, embed_dim = query.shape

# in-projection and rearrange `b s (h d) -> (b s) h d`
q, k, v = linear(query, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
q = (
q.transpose(0, 1)
.contiguous()
.view(batch_size * seqlen, self.num_heads, self.head_dim)
)
k = (
k.transpose(0, 1)
.contiguous()
.view(batch_size * seqlen, self.num_heads, self.head_dim)
)
v = (
v.transpose(0, 1)
.contiguous()
.view(batch_size * seqlen, self.num_heads, self.head_dim)
)

# flash attention (use causal mask)
causal = attn_mask is not None
attn_output = self.attention(q, k, v, batch_size, seqlen, causal=causal)

# out-projection
# `(b s) h d -> s b (h d)`
attn_output = attn_output.contiguous().view(
batch_size, seqlen, self.num_heads, self.head_dim
)
attn_output = (
attn_output.transpose(0, 1).contiguous().view(seqlen, batch_size, embed_dim)
)
attn_output = linear(attn_output, self.out_proj.weight, self.out_proj.bias)

return attn_output, None
18 changes: 17 additions & 1 deletion server/clip_server/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@
from open_clip.utils import freeze_batch_norm_2d
from open_clip.factory import _MODEL_CONFIGS

# Use flash attention
try:
from clip_server.model.flash_attention import MultiheadAttention

FLASH_ATTENTION_AVAILABLE = True
except:
FLASH_ATTENTION_AVAILABLE = False


# From PyTorch internals
def _ntuple(n):
Expand Down Expand Up @@ -279,9 +287,17 @@ def __init__(
scale_fc: bool = False,
):
super().__init__()
head_dim = d_model // n_head
self.flash_attention = head_dim % 8 == 0 and head_dim <= 128

self.ln_1 = LayerNorm(d_model)
self.attn = nn.MultiheadAttention(d_model, n_head)
self.attn = (
MultiheadAttention(d_model, n_head)
if FLASH_ATTENTION_AVAILABLE
and torch.cuda.is_available()
and self.flash_attention
else nn.MultiheadAttention(d_model, n_head)
)
self.ln_attn = LayerNorm(d_model) if scale_attn else nn.Identity()

self.ln_2 = LayerNorm(d_model)
Expand Down

0 comments on commit e4717a3

Please sign in to comment.