Skip to content

Commit

Permalink
Add same-device checks in CUDA kernels (facebookresearch#161)
Browse files Browse the repository at this point in the history
  • Loading branch information
fmassa authored Jun 28, 2021
1 parent eb13109 commit 476ead0
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 0 deletions.
11 changes: 11 additions & 0 deletions xformers/components/attention/csrc/cuda/sddmm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,17 @@ at::Tensor sddmm_sputnik(
TORCH_CHECK(
!column_indices.is_sparse(), "column_offsets must be a dense tensor");

TORCH_CHECK(a.device() == b.device(), "a should be in the same device as b");
TORCH_CHECK(
a.device() == row_indices.device(),
"a should be in the same device as row_indices");
TORCH_CHECK(
a.device() == row_offsets.device(),
"a should be in the same device as row_offsets");
TORCH_CHECK(
a.device() == column_indices.device(),
"a should be in the same device as column_indices");

cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int batch = a.size(0);
int m = a.size(1);
Expand Down
18 changes: 18 additions & 0 deletions xformers/components/attention/csrc/cuda/sddmm2_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,15 @@ torch::Tensor coo_sddmm(
TORCH_CHECK(colind.dtype() == torch::kInt32);
TORCH_CHECK(D1.dtype() == torch::kFloat32);
TORCH_CHECK(D2.dtype() == torch::kFloat32);

TORCH_CHECK(
D1.device() == D2.device(), "a should be in the same device as b");
TORCH_CHECK(
D1.device() == rowind.device(),
"a should be in the same device as row_offsets");
TORCH_CHECK(
D1.device() == colind.device(),
"a should be in the same device as column_indices");
return ge_spmm::sddmm_cuda_coo(rowind, colind, D1, D2);
}

Expand All @@ -582,6 +591,15 @@ torch::Tensor csr_sddmm(
TORCH_CHECK(colind.dtype() == torch::kInt32);
TORCH_CHECK(D1.dtype() == torch::kFloat32);
TORCH_CHECK(D2.dtype() == torch::kFloat32);

TORCH_CHECK(
D1.device() == D2.device(), "a should be in the same device as b");
TORCH_CHECK(
D1.device() == rowptr.device(),
"a should be in the same device as row_offsets");
TORCH_CHECK(
D1.device() == colind.device(),
"a should be in the same device as column_indices");
return ge_spmm::sddmm_cuda_csr(rowptr, colind, D1, D2);
}

Expand Down
13 changes: 13 additions & 0 deletions xformers/components/attention/csrc/cuda/sparse_softmax.cu
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,19 @@ at::Tensor sparse_softmax_backward_sputnik(
TORCH_CHECK(
!column_indices.is_sparse(), "column_offsets must be a dense tensor");

TORCH_CHECK(
values.device() == grad.device(),
"values should be in the same device as grad");
TORCH_CHECK(
values.device() == row_indices.device(),
"a should be in the same device as row_indices");
TORCH_CHECK(
values.device() == row_offsets.device(),
"a should be in the same device as row_offsets");
TORCH_CHECK(
values.device() == column_indices.device(),
"a should be in the same device as column_indices");

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

int batch = values.size(0);
Expand Down
13 changes: 13 additions & 0 deletions xformers/components/attention/csrc/cuda/spmm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,19 @@ at::Tensor spmm_sputnik(
TORCH_CHECK(
!column_indices.is_sparse(), "column_offsets must be a dense tensor");

TORCH_CHECK(
values.device() == b.device(),
"values should be in the same device as b");
TORCH_CHECK(
values.device() == row_indices.device(),
"a should be in the same device as row_indices");
TORCH_CHECK(
values.device() == row_offsets.device(),
"a should be in the same device as row_offsets");
TORCH_CHECK(
values.device() == column_indices.device(),
"a should be in the same device as column_indices");

cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int batch = b.size(0);
int k = b.size(1);
Expand Down

0 comments on commit 476ead0

Please sign in to comment.