// log_matmul_kernel.cu
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
namespace {
template <typename scalar_t>
__global__ void log_matmul_forward_kernel(
    const torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> a,
    const torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> b,
    torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> out,
    const int m, const int p, const int n) {
  const int row = threadIdx.x + blockIdx.x * blockDim.x;
  const int col = threadIdx.y + blockIdx.y * blockDim.y;
  const int batch = blockIdx.z;
  // The grid is rounded up, so row/col may fall outside the output; skip those threads.
  if (row < m && col < n) {
    // Numerically stable logsumexp over the inner dimension:
    // out[batch][row][col] = log(sum_i exp(a[batch][row][i] + b[batch][i][col])).
    // First pass: find the largest term so the exponentials cannot overflow.
    scalar_t max = -1e9;
    for (int i = 0; i < p; i++) {
      scalar_t v = a[batch][row][i] + b[batch][i][col];
      if (v > max) {
        max = v;
      }
    }
    // Second pass: accumulate exp(v - max) and shift the result back by max.
    scalar_t val = 0.0;
    for (int i = 0; i < p; i++) {
      scalar_t v = a[batch][row][i] + b[batch][i][col];
      val += exp(v - max);
    }
    out[batch][row][col] = log(val) + max;
  }
}

} // namespace
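
// A reference check for the kernel above (not part of the original file): the
// same result can be obtained on any device with broadcasting and
// torch::logsumexp, which is handy in a unit test. The helper name
// log_matmul_reference is an illustrative assumption.
torch::Tensor log_matmul_reference(const torch::Tensor& log_a, const torch::Tensor& log_b) {
  // (B, m, p, 1) + (B, 1, p, n) broadcasts to (B, m, p, n);
  // logsumexp over the shared dimension p (dim 2) gives (B, m, n).
  return torch::logsumexp(log_a.unsqueeze(3) + log_b.unsqueeze(1), /*dim=*/2);
}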
/* Sum operation in the log semiring:
 * matrix multiplication in log space.
 *
 * Arguments
 * ---------
 * log_a: (B, m, p)
 * log_b: (B, p, n)
 *
 * Returns
 * -------
 * (B, m, n)
 */
torch::Tensor log_matmul_cuda(
    torch::Tensor& a,
    torch::Tensor& b) {
  TORCH_CHECK(a.is_cuda() && b.is_cuda(), "log_matmul_cuda expects CUDA tensors");
  TORCH_CHECK(a.dim() == 3 && b.dim() == 3, "log_matmul_cuda expects batched (3-D) inputs");

  const int batch = a.size(0);
  const int m = a.size(1);
  const int p = a.size(2);
  const int n = b.size(2);

  // One 32x32 thread block per output tile; ceil-divide so the grid covers
  // every output element, and use the z-dimension of the grid for the batch.
  const int nthreads = 32;
  const dim3 threads_per_block(nthreads, nthreads, 1);
  const dim3 blocks((m + nthreads - 1) / nthreads, (n + nthreads - 1) / nthreads, batch);

  auto options = torch::TensorOptions().dtype(a.dtype()).device(
      torch::kCUDA, a.device().index());
  auto out = torch::empty({batch, m, n}, options);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      a.scalar_type(), "log_matmul_cuda", ([&] {
        log_matmul_forward_kernel<scalar_t><<<blocks, threads_per_block>>>(
            a.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
            b.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
            out.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(),
            m, p, n);
      }));
  return out;
}
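
// The kernel and launcher above still need to be exposed to Python. The
// repository presumably ships its own binding (and autograd) code elsewhere,
// so the registration below is only a minimal sketch, assuming this file is
// built on its own with torch.utils.cpp_extension (which defines
// TORCH_EXTENSION_NAME); the exported name "log_matmul" is illustrative.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("log_matmul", &log_matmul_cuda,
        "Batched matrix multiplication in log space (CUDA)");
}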