Skip to content

Commit

Permalink
yet another take, partial sum
Browse files Browse the repository at this point in the history
  • Loading branch information
blefaudeux committed Dec 10, 2021
1 parent ce1ab1c commit c7ab5b0
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 35 deletions.
26 changes: 20 additions & 6 deletions xformers/triton/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def grid(meta):
triton.cdiv(N, meta["BLOCK_N"]),
)

GROUP_M = 128
BLOCK_M = GROUP_M // 4

# fmt: off
k_dropout_fw[grid](
y, x_,
Expand All @@ -53,7 +56,8 @@ def grid(meta):
M, N,
p,
USE_BIAS=bias is not None,
ACTIVATION=activation
ACTIVATION=activation,
BLOCK_M=BLOCK_M
)
# fmt: on

Expand Down Expand Up @@ -87,12 +91,21 @@ def backward(ctx, grad_out):
elif inputs.ndim > 2:
inputs = inputs.reshape(-1, N)

GROUP_M = 256
BLOCK_M = GROUP_M // 4
N_BLOCKS_M = triton.cdiv(M, GROUP_M)

if ctx.trainable_bias:
grad_bias = torch.empty((N,), device=grad_in.device, dtype=grad_in.dtype)
locks = torch.zeros(N // 2, dtype=torch.int32, device=grad_in.device)
grad_bias = torch.empty(
(
N_BLOCKS_M,
N,
),
device=grad_in.device,
dtype=grad_in.dtype,
)
else:
grad_bias = grad_in # will not be used
locks = grad_in

def grid(meta):
return (
Expand All @@ -104,13 +117,14 @@ def grid(meta):
k_dropout_bw[grid](
grad_in, grad_bias, grad_out_,
inputs, bias if bias is not None else inputs,
seeds, locks,
seeds,
grad_out_.stride(0), inputs.stride(0),
M, N,
ctx.p,
USE_BIAS=bias is not None,
ACTIVATION_GRAD=ctx.activation_grad,
TRAINABLE_BIAS=ctx.trainable_bias
TRAINABLE_BIAS=ctx.trainable_bias,
BLOCK_M=BLOCK_M
)
# fmt: on

Expand Down
34 changes: 5 additions & 29 deletions xformers/triton/k_dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,9 @@

# WARNING: For now, the number of threads must be the same as the N buffer, and warps have to be 4 (will be fixed)
k_configs = [
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=1),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32}, num_warps=1),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 32}, num_warps=1),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=2),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=2),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64}, num_warps=2),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128}, num_warps=4),
triton.Config({"BLOCK_N": 32}, num_warps=1),
triton.Config({"BLOCK_N": 64}, num_warps=2),
triton.Config({"BLOCK_N": 128}, num_warps=4),
]


Expand Down Expand Up @@ -131,7 +125,7 @@ def k_dropout_fw(
@triton.jit
def k_dropout_bw(
GRAD_IN, GRAD_BIAS, GRAD_OUT,
INPUTS, BIAS, SEEDS, LOCKS,
INPUTS, BIAS, SEEDS,
stride_grad, stride_inputs,
M, N,
p,
Expand Down Expand Up @@ -250,23 +244,5 @@ def k_dropout_bw(
rand_mask = rand_mask1

if TRAINABLE_BIAS:
lock_ptr = LOCKS + 2 * col_id
count_ptr = LOCKS + 2 * col_id + 1
grad_bias_ptr = GRAD_BIAS + cols

# Uniquely taking a lock over the col results
while tl.atomic_cas(lock_ptr, 0, 1) == 1:
pass

count = tl.load(count_ptr)
if count == 0:
# first store doesn't accumulate
tl.atomic_xchg(count_ptr, 1)
else:
# read and add back
grad_bias += tl.load(grad_bias_ptr, mask=cols < N)

grad_bias_ptr = GRAD_BIAS + row_id * N + cols
tl.store(grad_bias_ptr, grad_bias, mask=cols < N)

# release lock
tl.atomic_xchg(lock_ptr, 0)

0 comments on commit c7ab5b0

Please sign in to comment.