From 0223e6fa071534bfc1a3b2010dd7065623afd540 Mon Sep 17 00:00:00 2001
From: YangXiuyu
Date: Thu, 24 Nov 2022 13:33:34 +0800
Subject: [PATCH] fix: add pip installable flash attention (#863)

* fix: setup.py
* fix: ci
* fix: install flash-attn needs torch
* fix: typo
* fix: passby flash attn not installed
* fix: ci

Co-authored-by: Ziniu Yu
---
 .github/workflows/ci.yml | 7 +++----
 server/setup.py          | 1 +
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 87dae3336..2ae094b7b 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -155,14 +155,13 @@ jobs:
         run: |
           python -m pip install --upgrade pip
           python -m pip install wheel pytest pytest-cov nvidia-pyindex
+          pip install -e "client/[test]"
+          pip install -e "server/[tensorrt]"
           {
-            python -m pip install torch==1.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
-            python -m pip install git+https://github.com/HazyResearch/flash-attention.git
+            pip install -e "server/[flash-attn]"
           } || {
             echo "flash attention was not installed."
           }
-          pip install -e "client/[test]"
-          pip install -e "server/[tensorrt]"
       - name: Test
         id: test
         run: |
diff --git a/server/setup.py b/server/setup.py
index 78275b0c1..faaff6cfb 100644
--- a/server/setup.py
+++ b/server/setup.py
@@ -59,6 +59,7 @@
         'tensorrt': ['nvidia-tensorrt'],
         'transformers': ['transformers>=4.16.2'],
         'search': ['annlite>=0.3.10'],
+        'flash-attn': ['flash-attn'],
     },
     classifiers=[
         'Development Status :: 5 - Production/Stable',
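
Usage sketch: with the new extras entry, end users can pull in the optional
dependency through pip instead of building from the HazyResearch GitHub
source. A minimal sketch, assuming the published server package is named
`clip-server` (the CI above runs the local editable equivalent); torch is
installed first because flash-attn builds against it, and the `|| echo`
mirrors the CI's non-fatal fallback when the build fails:

    # torch must already be present for the flash-attn build step
    pip install torch
    pip install "clip-server[flash-attn]" \
      || echo "flash attention was not installed."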