Skip to content

Commit

Permalink
ci: disable sm90 for cuda11 (#683)
Browse files Browse the repository at this point in the history
SM90 (Hopper) is not available with CUDA 11, so disable it in CUDA 11 CI builds.
  • Loading branch information
yzh119 authored Dec 18, 2024
1 parent 58a0944 commit 0483042
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 19 deletions.
5 changes: 5 additions & 0 deletions scripts/run-ci-build-wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ FLASHINFER_LOCAL_VERSION="cu${CUDA_MAJOR}${CUDA_MINOR}torch${FLASHINFER_CI_TORCH
if [ -n "${FLASHINFER_GIT_SHA}" ]; then
FLASHINFER_LOCAL_VERSION="${FLASHINFER_GIT_SHA}.${FLASHINFER_LOCAL_VERSION}"
fi
# SM90 (Hopper) kernels require the CUDA 12 toolchain; older toolkits
# cannot compile sm90a targets, so disable them for CUDA < 12.
if [ "$CUDA_MAJOR" -ge 12 ]; then
  FLASHINFER_ENABLE_SM90=1
else
  FLASHINFER_ENABLE_SM90=0
fi
# Export the flag: setup.py reads it from os.environ in the child build
# process, and an unexported shell variable is invisible there.
export FLASHINFER_ENABLE_SM90

echo "::group::Install PyTorch"
pip install torch==$FLASHINFER_CI_TORCH_VERSION --index-url "https://download.pytorch.org/whl/cu${CUDA_MAJOR}${CUDA_MINOR}"
Expand Down
45 changes: 26 additions & 19 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
def _flag_enabled(var_name: str, default: str) -> bool:
    """Return True when env var *var_name* (falling back to *default*) is the literal string "1"."""
    return os.environ.get(var_name, default) == "1"


# Build-feature toggles, resolved once at import time from the environment.
# AOT kernel generation is opt-in; BF16, FP8 and SM90 support are opt-out.
enable_aot = _flag_enabled("FLASHINFER_ENABLE_AOT", "0")
enable_bf16 = _flag_enabled("FLASHINFER_ENABLE_BF16", "1")
enable_fp8 = _flag_enabled("FLASHINFER_ENABLE_FP8", "1")
enable_sm90 = _flag_enabled("FLASHINFER_ENABLE_SM90", "1")


def get_version():
Expand Down Expand Up @@ -79,16 +80,19 @@ def generate_cuda() -> None:
enable_bf16=enable_bf16,
enable_fp8=enable_fp8,
)
) + get_sm90_instantiation_cu(
argparse.Namespace(
path=gen_dir,
head_dims=head_dims,
pos_encoding_modes=pos_encoding_modes_sm90,
allow_fp16_qk_reductions=allow_fp16_qk_reductions_sm90,
mask_modes=mask_modes,
enable_bf16=enable_bf16,
)
)

if enable_sm90:
aot_kernel_uris += get_sm90_instantiation_cu(
argparse.Namespace(
path=gen_dir,
head_dims=head_dims,
pos_encoding_modes=pos_encoding_modes_sm90,
allow_fp16_qk_reductions=allow_fp16_qk_reductions_sm90,
mask_modes=mask_modes,
enable_bf16=enable_bf16,
)
)
aot_config_str = f"""prebuilt_ops_uri = set({aot_kernel_uris})"""
(root / "flashinfer" / "jit" / "aot_config.py").write_text(aot_config_str)

Expand Down Expand Up @@ -213,17 +217,20 @@ def __init__(self, *args, **kwargs) -> None:
"cxx": cxx_flags,
"nvcc": nvcc_flags,
},
),
torch_cpp_ext.CUDAExtension(
name="flashinfer._kernels_sm90",
sources=kernel_sm90_sources + prefill_sm90_sources,
include_dirs=include_dirs,
extra_compile_args={
"cxx": cxx_flags,
"nvcc": nvcc_flags + sm90a_flags,
},
),
)
]
if enable_sm90:
ext_modules += [
torch_cpp_ext.CUDAExtension(
name="flashinfer._kernels_sm90",
sources=kernel_sm90_sources + prefill_sm90_sources,
include_dirs=include_dirs,
extra_compile_args={
"cxx": cxx_flags,
"nvcc": nvcc_flags + sm90a_flags,
},
),
]

setuptools.setup(
version=get_version(),
Expand Down

0 comments on commit 0483042

Please sign in to comment.