feat: Integrate flash attention #853
Conversation
Force-pushed from a50fc2d to 24deab4.
Please also present the performance and memory difference with flash_attention in the description.
What's more, please confirm that flash_attention can load the pre-trained parameters and produces output consistent with the original attention implementation.
I've confirmed the correctness of the results in this repo. The maximum error is on the order of 1e-3 (<= 0.003).
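For reference, a minimal sketch of the kind of consistency check being described, assuming the PR's `MultiheadAttention` is a drop-in replacement with the same interface and parameter names as `torch.nn.MultiheadAttention` (the exact API, shapes, and tolerances here are illustrative, not taken from the PR):

```python
import torch
from torch.nn import MultiheadAttention as TorchMHA

# Assumption: the PR's module exposes a torch.nn.MultiheadAttention-compatible
# interface; the actual signature may differ in the implementation.
from clip_server.model.flash_attention import MultiheadAttention as FlashMHA

embed_dim, num_heads, seq_len, batch = 512, 8, 77, 4

torch_attn = TorchMHA(embed_dim, num_heads).half().cuda().eval()
flash_attn = FlashMHA(embed_dim, num_heads).half().cuda().eval()

# Load the same pre-trained weights into both modules (assumes matching
# parameter names, as expected for a drop-in replacement).
flash_attn.load_state_dict(torch_attn.state_dict())

# Default torch layout: (seq_len, batch, embed_dim); fp16 is required
# because flash attention does not support fp32.
x = torch.randn(seq_len, batch, embed_dim, dtype=torch.float16, device='cuda')

with torch.no_grad():
    ref, _ = torch_attn(x, x, x, need_weights=False)
    out, _ = flash_attn(x, x, x, need_weights=False)

# The PR reports a maximum error on the order of 1e-3.
print('max abs error:', (ref - out).abs().max().item())
```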
@OrangeSodahub Please also present the performance under full-precision inference in the description.
Codecov Report
@@            Coverage Diff             @@
##             main     #853      +/-   ##
==========================================
- Coverage   79.32%   78.26%   -1.07%
==========================================
  Files          21       22       +1
  Lines        1596     1633      +37
==========================================
+ Hits         1266     1278      +12
- Misses        330      355      +25
Flash attention does not support fp32; see https://github.com/HazyResearch/flash-attention#alpha-release-01
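Since the kernel only supports half precision, here is a hedged sketch of how a caller might dispatch between flash attention and a standard fallback based on dtype and device (`flash_fn` and `fallback_fn` are placeholders, not names from this PR):

```python
import torch


def dispatch_attention(q, k, v, flash_fn, fallback_fn):
    """Use flash attention only for fp16 CUDA tensors.

    Flash attention (alpha release) does not support fp32, so
    full-precision or CPU inputs are routed to the standard
    attention implementation instead.
    """
    if q.is_cuda and q.dtype == torch.float16:
        return flash_fn(q, k, v)
    return fallback_fn(q, k, v)
```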
@OrangeSodahub It seems that the CI on GPU failed. Any clues about this problem?
Force-pushed from 65b4bce to 0ed3c3a.
Now the situation is: the GPU device only has CUDA 11.0, and only PyTorch <= 1.7.1 matches it, but PyTorch < 1.8 does not support flash attention.
@@ -26,6 +26,14 @@
from open_clip.utils import freeze_batch_norm_2d
from open_clip.factory import _MODEL_CONFIGS

# Use flash attention
try:
    from clip_server.model.flash_attention import MultiheadAttention
We also need to check the compatibility of CUDA and PyTorch here.
Insert the CUDA and PyTorch version check here.
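A minimal sketch of what such a guard around the import could look like, using the thresholds stated in this PR's requirements (CUDA 11, PyTorch >= 1.8); the actual check adopted in the code may differ, and the `FLASH_ATTENTION_AVAILABLE` flag is illustrative:

```python
import torch
from packaging import version

# Fall back to the standard attention when the environment cannot run
# flash attention (requires CUDA 11+ and PyTorch >= 1.8 per this PR).
try:
    assert torch.cuda.is_available(), 'no CUDA device'
    assert torch.version.cuda is not None and \
        int(torch.version.cuda.split('.')[0]) >= 11, 'CUDA < 11'
    assert version.parse(torch.__version__) >= version.parse('1.8.0'), 'PyTorch < 1.8'
    from clip_server.model.flash_attention import MultiheadAttention
    FLASH_ATTENTION_AVAILABLE = True
except (ImportError, AssertionError):
    FLASH_ATTENTION_AVAILABLE = False
```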
This reverts commit 979a42c.
Goal: integrate flash attention (official repo) into the CLIP model; it will replace torch.nn.MultiheadAttention.
Required: CUDA 11, PyTorch >= 1.8, and Turing or Ampere GPUs (e.g., A100, RTX 3090, T4, RTX 2080).
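Turing and Ampere GPUs correspond to CUDA compute capability 7.5 and 8.x, so the hardware requirement can also be checked at runtime. A hedged sketch (not code from this PR; the function name is illustrative):

```python
import torch


def gpu_supports_flash_attention(device_index: int = 0) -> bool:
    """Return True if the GPU is Turing (sm_75) or newer, e.g. Ampere (sm_80/86)."""
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability(device_index)
    return (major, minor) >= (7, 5)
```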
Performance:
Tested on Tesla T4 (left) and RTX 3080 (right), Times = 100, using a causal mask. [Benchmark plots omitted.]