Skip to content

Commit

Permalink
Merge pull request #506 from LDYang694/master
Browse files Browse the repository at this point in the history
add new cuda12.2 for g++11
  • Loading branch information
MenghaoGuo authored Apr 23, 2024
2 parents 0287f74 + 3092bba commit 63022e1
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
4 changes: 2 additions & 2 deletions python/jittor/src/misc/cuda_atomic.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ T cuda_atomic_mul(T* a, T b) {
return old_f;
}


#if CUDA_ARCH >= 800
template<> __device__
__half cuda_atomic_max(__half* a, __half b) {
auto old_f = *a;
Expand Down Expand Up @@ -166,7 +166,7 @@ __half cuda_atomic_min(__half* a, __half b) {
}
return old_f;
}

#endif
#if CUDA_ARCH >= 800
template<> __device__
__nv_bfloat16 cuda_atomic_max(__nv_bfloat16* a, __nv_bfloat16 b) {
Expand Down
5 changes: 4 additions & 1 deletion python/jittor_utils/install_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,10 @@ def install_cuda():
LOG.w(f"Unsupport cuda driver version: {cuda_driver_version}, at least 10.2")
return None
else:
if cuda_driver_version >= [11,2]:
if cuda_driver_version >= [12,2]:
cuda_tgz = "cuda12.2_cudnn8_linux.tgz"
md5 = "7afda9332a268f29354488f13b489f53"
elif cuda_driver_version >= [11,2]:
cuda_tgz = "cuda11.2_cudnn8_linux.tgz"
md5 = "b93a1a5d19098e93450ee080509e9836"
elif cuda_driver_version >= [11,]:
Expand Down

0 comments on commit 63022e1

Please sign in to comment.