
feat: Integrate flash attention #853

Merged · 24 commits into jina-ai:main · Nov 15, 2022

Conversation

@OrangeSodahub (Contributor) commented Nov 9, 2022

Goal: integrate flash attention (official repo) into the CLIP model; it will replace `torch.nn.MultiheadAttention`.

Requirements: CUDA 11, PyTorch >= 1.8, and a Turing or Ampere GPU (e.g., A100, RTX 3090, T4, RTX 2080).
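For context, the integration follows a guarded-import pattern: the flash-attention-backed module is used when it imports successfully, otherwise the model keeps `torch.nn.MultiheadAttention`. A minimal sketch, not the exact code in this PR (the flag and helper names are illustrative):

```python
import torch
import torch.nn as nn

try:
    # flash-attention-backed drop-in added by this PR
    from clip_server.model.flash_attention import MultiheadAttention
    FLASH_ATTENTION_AVAILABLE = True   # illustrative flag name
except ImportError:
    FLASH_ATTENTION_AVAILABLE = False


def build_attention(embed_dim: int, num_heads: int) -> nn.Module:
    """Prefer the flash-attention module on CUDA; otherwise fall back to torch."""
    if FLASH_ATTENTION_AVAILABLE and torch.cuda.is_available():
        # assumes the constructor mirrors nn.MultiheadAttention(embed_dim, num_heads)
        return MultiheadAttention(embed_dim, num_heads)
    return nn.MultiheadAttention(embed_dim, num_heads)
```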

Performance:

Benchmark setup: the left block of columns was measured on a Tesla T4, the right block on an RTX 3080; each measurement is repeated 100 times (Times = 100), using a causal mask.
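A minimal sketch of the kind of harness behind these numbers (the actual benchmark script may aggregate differently; `baseline_model` / `flash_model` are placeholders): time the forward passes and report the mean absolute difference between the two outputs.

```python
import time
import torch


def benchmark(module, x, n_runs=100):
    """Time n_runs forward passes on CUDA and return (seconds, last output)."""
    torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        for _ in range(n_runs):
            out = module(x)
    torch.cuda.synchronize()
    return time.perf_counter() - start, out


# baseline_model / flash_model: the same CLIP encoder with torch vs. flash attention;
# x: e.g. a (16, 3, 224, 224) image batch or a (16, 77) token batch on the GPU.
# t_base, y_base = benchmark(baseline_model, x)
# t_flash, y_flash = benchmark(flash_model, x)
# print(f"speed up: {t_base / t_flash:.4f}, mean diff: {(y_base - y_flash).abs().mean():.6f}")
```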


Model: `ViT-L-14::laion2b-s32b-b82k`

|       shape       | T4 baseline (s) | T4 flash_attn (s) | T4 speed up (x) | T4 mean diff |   | RTX 3080 baseline (s) | RTX 3080 flash_attn (s) | RTX 3080 speed up (x) | RTX 3080 mean diff |
|:-----------------:|:---------------:|:-----------------:|:---------------:|:------------:|:-:|:---------------------:|:-----------------------:|:---------------------:|-------------------:|
| (1, 77)           |     0.81193 |       0.66437 |      1.22211 |   0.00066 |      | 0.66602     | 0.55968         | 1.1899       | 0.000599  |
| (2, 77)           |      0.8264 |       0.70035 |      1.17998 |    0.0007 |      | 0.69619     | 0.60348         | 1.1536       | 0.000570  |
| (4, 77)           |     0.81998 |       0.69887 |       1.1733 |   0.00064 |      | 0.75353     | 0.65664         | 1.1225       | 0.000573  |
| (8, 77)           |     1.05975 |       0.85742 |      1.23597 |   0.00054 |      | 0.73511     | 0.64203         | 1.1449       | 0.000530  |
| (16, 77)          |     2.12992 |       1.68367 |      1.26504 |   0.00057 |      | 0.78381     | 0.72705         | 1.0780       | 0.000545  |
| (1, 3, 224, 224)  |     2.00593 |       1.34507 |      1.49131 |   0.00155 |      | 1.25425     | 1.10564         | 1.1344       | 0.001713  |
| (2, 3, 224, 224)  |     3.74493 |       2.29818 |      1.62952 |   0.00122 |      | 1.32202     | 1.22698         | 1.0775       | 0.001739  |
| (4, 3, 224, 224)  |     7.35365 |       4.44447 |      1.65456 |   0.00159 |      | 2.36880     | 2.11119         | 1.0926       | 0.001227  |
| (8, 3, 224, 224)  |    14.63006 |       9.05604 |       1.6155 |   0.00135 |      | 3.88597     | 3.76666         | 1.0316       | 0.001441  |
| (16, 3, 224, 224) |    29.63732 |      18.29142 |      1.62028 |   0.00155 |      | 7.84954     | 7.20817         | 1.0889       | 0.001533  |

Model: `ViT-B-32::laion2b-s34b-b79k`

|       shape       | T4 baseline (s) | T4 flash_attn (s) | T4 speed up (x) | T4 mean diff |   | RTX 3080 baseline (s) | RTX 3080 flash_attn (s) | RTX 3080 speed up (x) | RTX 3080 mean diff |
|:-----------------:|:---------------:|:-----------------:|:---------------:|:------------:|:-:|:---------------------:|:-----------------------:|:---------------------:|-------------------:|
| (1, 77)           |     0.42692 |       0.37867 |       1.1274 |   0.00076 |      | 0.73374     | 0.63919         | 1.0914       | 0.000599  |
| (2, 77)           |     0.46548 |       0.42958 |      1.08358 |   0.00077 |      | 0.72493     | 0.66131         | 1.0967       | 0.000570  |
| (4, 77)           |     0.49818 |       0.46378 |      1.07417 |   0.00079 |      | 0.70449     | 0.64658         | 1.0946       | 0.000573  |
| (8, 77)           |     0.48738 |       0.45324 |       1.0753 |   0.00068 |      | 0.71474     | 0.68031         | 1.0902       | 0.000530  |
| (16, 77)          |      0.4764 |       0.44315 |      1.07502 |   0.00067 |      | 0.71016     | 0.70077         | 1.0960       | 0.000545  |
| (1, 3, 224, 224)  |      0.4349 |       0.40392 |       1.0767 |   0.00054 |      | 0.65047     | 0.59596         | 1.0914       | 0.000601  |
| (2, 3, 224, 224)  |      0.4445 |       0.42513 |      1.04555 |   0.00058 |      | 0.67756     | 0.63876         | 1.0607       | 0.000556  |
| (4, 3, 224, 224)  |     0.43605 |       0.41524 |      1.05009 |   0.00054 |      | 0.71534     | 0.67424         | 1.0609       | 0.000552  |
| (8, 3, 224, 224)  |     0.47367 |       0.45316 |      1.04527 |   0.00064 |      | 0.66634     | 0.62738         | 1.0620       | 0.000597  |
| (16, 3, 224, 224) |     0.51586 |       0.50555 |       1.0204 |   0.00058 |      | 0.6718      | 0.63712         | 1.0545       | 0.000578  |

Model: `ViT-B-16::laion400m_e31`

|       shape       | T4 baseline (s) | T4 flash_attn (s) | T4 speed up (x) | T4 mean diff |   | RTX 3080 baseline (s) | RTX 3080 flash_attn (s) | RTX 3080 speed up (x) | RTX 3080 mean diff |
|:-----------------:|:---------------:|:-----------------:|:---------------:|:------------:|:-:|:---------------------:|:-----------------------:|:---------------------:|-------------------:|
| (1, 77)           |     0.42736 |       0.38143 |      1.12042 |   0.00051 |      | 0.69159     | 0.61381         | 1.1267       | 0.000541  |
| (2, 77)           |     0.46499 |         0.432 |      1.07635 |   0.00025 |      | 0.77649     | 0.70859         | 1.0958       | 0.000248  |
| (4, 77)           |     0.50052 |       0.46798 |      1.06953 |   0.00045 |      | 0.78205     | 0.71427         | 1.0948       | 0.000425  |
| (8, 77)           |     0.48717 |       0.45675 |      1.06659 |   0.00043 |      | 0.72512     | 0.66471         | 1.0908       | 0.000419  |
| (16, 77)          |     0.47855 |       0.44717 |      1.07017 |    0.0005 |      | 0.72748     | 0.66784         | 1.0893       | 0.000457  |
| (1, 3, 224, 224)  |     0.45754 |       0.38839 |      1.17803 |   0.00044 |      | 0.70322     | 0.59588         | 1.1801       | 0.000410  |
| (2, 3, 224, 224)  |     0.50152 |       0.44523 |      1.12641 |   0.00042 |      | 0.73814     | 0.64270         | 1.1484       | 0.000412  |
| (4, 3, 224, 224)  |     0.56444 |       0.51215 |       1.1021 |   0.00044 |      | 0.75904     | 0.66596         | 1.1397       | 0.000395  |
| (8, 3, 224, 224)  |     0.93341 |       0.82501 |      1.13139 |   0.00045 |      | 1.18015     | 1.03530         | 1.1399       | 0.000427  |
| (16, 3, 224, 224) |     1.77167 |       1.58981 |      1.11439 |   0.00044 |      | 2.22424     | 1.97961         | 1.1235       | 0.000429  |

Model: `ViT-g-14::laion2b-s12b-b42k`

|       shape       | baseline(s) | flash_attn(s) | speed up (x) | mean diff |
|:-----------------:|:-----------:|:-------------:|:------------:|:---------:|
| (1, 77)           |      0.8894 |        0.7984 |      1.11397 |   0.00093 |
| (2, 77)           |     0.89632 |       0.83137 |      1.07812 |   0.00079 |
| (4, 77)           |     0.93494 |       0.87282 |      1.07117 |    0.0008 |
| (8, 77)           |     1.15918 |       1.08828 |      1.06514 |   0.00073 |
| (16, 77)          |     2.05242 |       1.92306 |      1.06727 |   0.00058 |
| (1, 3, 224, 224)  |     1.94179 |       1.85599 |      1.04623 |   0.00115 |
| (2, 3, 224, 224)  |     3.29228 |       3.10451 |      1.06048 |   0.00109 |
| (4, 3, 224, 224)  |     5.87704 |       5.55928 |      1.05715 |   0.00115 |
| (8, 3, 224, 224)  |    10.37121 |        9.8581 |      1.05204 |   0.00101 |
| (16, 3, 224, 224) |    20.85575 |      19.86552 |      1.04984 |   0.00112 |

Model: `ViT-H-14::laion2b-s32b-b79k`

|       shape       | baseline(s) | flash_attn(s) | speed up (x) | mean diff |
|:-----------------:|:-----------:|:-------------:|:------------:|:---------:|
| (1, 77)           |     0.90747 |       0.81617 |      1.11186 |    0.0008 |
| (2, 77)           |     0.90999 |       0.84415 |      1.07799 |   0.00077 |
| (4, 77)           |     0.97304 |        0.9068 |      1.07305 |   0.00053 |
| (8, 77)           |     2.55404 |       2.48595 |      1.02738 |   0.00067 |
| (16, 77)          |     4.36827 |       4.11261 |      1.06216 |   0.00067 |
| (1, 3, 224, 224)  |     1.33663 |       1.26552 |      1.05619 |    0.0009 |
| (2, 3, 224, 224)  |       2.273 |       2.12032 |        1.072 |    0.0008 |
| (4, 3, 224, 224)  |     4.04465 |       3.78626 |      1.06824 |   0.00082 |
| (8, 3, 224, 224)  |     7.23102 |       6.80665 |      1.06234 |    0.0008 |
| (16, 3, 224, 224) |    13.64971 |       12.8397 |      1.06308 |   0.00087 |

@OrangeSodahub force-pushed the integrate-flash-attention branch from a50fc2d to 24deab4 on November 9, 2022 05:27
@numb3r3 (Member) left a comment:

Please also present the performance and memory difference with flash_attention in the description.

Review thread on server/clip_server/model/model.py (outdated, resolved)
Review thread on server/clip_server/model/flash_attention.py (resolved)
@numb3r3 (Member) commented Nov 9, 2022

Also, please confirm that flash_attention can load the pre-trained parameters and produces output consistent with nn.MultiheadAttention.

@OrangeSodahub (Contributor, Author) commented Nov 9, 2022

> Also, please confirm that flash_attention can load the pre-trained parameters and produces output consistent with nn.MultiheadAttention.

I've confirmed the correctness of the results in this repo. The maximum error is at the 1e-3 level (<= 0.003).
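For reference, a consistency check along these lines could confirm both points; it assumes the new module mirrors `torch.nn.MultiheadAttention`'s parameter names and call signature, which may not match the PR's exact interface:

```python
import torch
from clip_server.model.flash_attention import MultiheadAttention as FlashMHA

embed_dim, num_heads = 768, 12
baseline = torch.nn.MultiheadAttention(embed_dim, num_heads).half().cuda().eval()
flash = FlashMHA(embed_dim, num_heads).half().cuda().eval()

# 1) pre-trained parameters must load without missing/unexpected keys
flash.load_state_dict(baseline.state_dict())

# 2) outputs must agree within tolerance on the same input
x = torch.randn(77, 2, embed_dim, dtype=torch.float16, device='cuda')  # (seq, batch, dim)
with torch.no_grad():
    ref, _ = baseline(x, x, x, need_weights=False)
    out, _ = flash(x, x, x, need_weights=False)
print((ref - out).abs().max())  # expected around the 1e-3 level reported above
```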

@numb3r3 (Member) commented Nov 9, 2022

@OrangeSodahub Please also present the performance under full-precision inference in the Openclip-flashattn repo

@codecov (bot) commented Nov 9, 2022

Codecov Report

Merging #853 (733d572) into main (f96ce54) will decrease coverage by 1.06%.
The diff coverage is 34.21%.

@@            Coverage Diff             @@
##             main     #853      +/-   ##
==========================================
- Coverage   79.32%   78.26%   -1.07%     
==========================================
  Files          21       22       +1     
  Lines        1596     1633      +37     
==========================================
+ Hits         1266     1278      +12     
- Misses        330      355      +25     
| Flag | Coverage Δ |
|------|------------|
| cas  | 78.26% <34.21%> (-1.07%) ⬇️ |

Flags with carried forward coverage won't be shown.

| Impacted Files | Coverage Δ |
|----------------|------------|
| server/clip_server/model/flash_attention.py | 20.00% <20.00%> (ø) |
| server/clip_server/model/model.py | 70.11% <87.50%> (+0.25%) ⬆️ |


@OrangeSodahub (Contributor, Author) commented Nov 9, 2022

> @OrangeSodahub Please also present the performance under full-precision inference in the Openclip-flashattn repo

Flash attention does not support fp32; see https://github.com/HazyResearch/flash-attention#alpha-release-01
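Because the kernels only cover half precision, any wrapper has to guard on dtype; a sketch of such a guard (hypothetical function names, not the PR's code):

```python
import torch


def attention_forward(flash_mha, torch_mha, q, k, v):
    """Route to flash attention only for fp16/bf16 CUDA tensors; otherwise use torch."""
    if q.is_cuda and q.dtype in (torch.float16, torch.bfloat16):
        return flash_mha(q, k, v)     # half-precision flash path (hypothetical signature)
    return torch_mha(q, k, v)[0]      # fp32 / CPU fallback via nn.MultiheadAttention
```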

Review threads on .github/workflows/ci.yml (outdated, resolved)
@numb3r3 (Member) commented Nov 11, 2022

@OrangeSodahub It seems that the GPU CI failed. Any clues about this problem?

@OrangeSodahub force-pushed the integrate-flash-attention branch from 65b4bce to 0ed3c3a on November 11, 2022 04:03
@OrangeSodahub (Contributor, Author) commented:

The current situation: the GPU machine only has CUDA 11.0, which matches only PyTorch <= 1.7.1, but PyTorch < 1.8 does not support flash_attn.

@@ -26,6 +26,14 @@
from open_clip.utils import freeze_batch_norm_2d
from open_clip.factory import _MODEL_CONFIGS

# Use flash attention
try:
from clip_server.model.flash_attention import MultiheadAttention
A reviewer (Member) commented:

We also need to check the compatibility of CUDA and PyTorch here.

A reviewer (Member) commented:

Insert the CUDA and PyTorch version checking here.
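One possible shape for that check, gating the import on PyTorch >= 1.8, CUDA 11, and a Turing/Ampere device; this is a sketch of the requested logic, not necessarily what was committed:

```python
import torch
from packaging import version


def flash_attention_supported() -> bool:
    """Return True only when the runtime meets the flash-attention requirements."""
    if not torch.cuda.is_available():
        return False
    if version.parse(torch.__version__.split('+')[0]) < version.parse('1.8'):
        return False
    cuda_ver = torch.version.cuda  # e.g. '11.3', or None for CPU-only builds
    if cuda_ver is None or version.parse(cuda_ver) < version.parse('11.0'):
        return False
    major, minor = torch.cuda.get_device_capability()
    return (major, minor) >= (7, 5)  # Turing (sm_75) or newer, e.g. Ampere (sm_8x)
```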


This reverts commit 979a42c.
@numb3r3 marked this pull request as ready for review on November 15, 2022 07:53
@numb3r3 merged commit e4717a3 into jina-ai:main on Nov 15, 2022
@OrangeSodahub deleted the integrate-flash-attention branch on November 15, 2022 09:56