WIP fusing the normalization, kernel still super buggy and a bit long
blefaudeux committed Dec 22, 2021
1 parent 18744d8 commit 3f8b95f
Showing 1 changed file with 55 additions and 37 deletions.
92 changes: 55 additions & 37 deletions xformers/triton/k_mem_efficient_attention.py
@@ -25,66 +25,81 @@ def k_me_attention_fw(
     # extract metaparameters
     BLOCK_M = META["BLOCK_M"]
     BLOCK_N, BLOCK_L = META["BLOCK_N"], META["BLOCK_L"]
+    SINGLE_ROW_TILE = META["SINGLE_ROW_TILE"]
     scale = META["SCALE"]

     # *within groups*, programs are ordered in a column-major order
     # row-id /col-id of the program in the *launch grid*
     pid_m = tl.program_id(axis=0)
     pid_n = tl.program_id(axis=1)

-    # now compute the ranges that each program will go through
-    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
-    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

-    # initialize and iteratively update accumulator
-    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

+    # Compute QKt
+    # block level matrix multiplication.
+    # We fetch a memory block from both inputs, matmul and accumulate, then repeat
+    qkt = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

     i = 0
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
     rl = tl.arange(0, BLOCK_L)

     for _ in range(L, 0, -BLOCK_L):
-        rl_i = rl + i * BLOCK_L
+        rl_i = rl + i * BLOCK_L  # keep track of the masking
         q_ptrs = Q + rm[:, None] * L + rl_i[None, :]  # (BLOCK_M, BLOCK_L)
         k_ptrs = K + rn[None, :] * L + rl_i[:, None]  # (BLOCK_L, BLOCK_N)

         q = tl.load(q_ptrs, mask=((rm[:, None] < M) & (rl_i[None, :] < L)), other=0.0)  # (BLOCK_M, BLOCK_L)
-        q *= scale  # q /= sqrt(dim)

         k = tl.load(k_ptrs, mask=((rl_i[:, None] < L) & (rn[None, :] < N)), other=0.0)  # (BLOCK_L, BLOCK_N)

-        acc += tl.dot(q, k).to(tl.float32)  # (BLOCK_M, BLOCK_N)
+        q *= scale  # q /= sqrt(dim)
+        qkt += tl.dot(q, k).to(tl.float32)  # (BLOCK_M, BLOCK_N)

         # Update the pointers and counter
         i += 1

-    # pick the local max, safeguard the incoming exponential
-    # save so that an eventual mismatch can be fixed
-    max_acc = tl.max(acc, axis=1)  # (BLOCK_M)
+    # Pick the local max per row, safeguard the incoming exponential
+    max_qkt = tl.max(qkt, axis=1)  # (BLOCK_M)
     max_ptrs = MAXES + pid_n * stride_maxes + rm  # (BLOCK_M)
-    tl.store(max_ptrs, max_acc, mask=(rm < M))

-    # exponentiate the neutralized results
-    exp_acc = tl.exp(acc - max_acc[:, None])  # (BLOCK_M, BLOCK_N)
+    # Save so that an eventual mismatch can be fixed post-hoc
+    tl.store(max_ptrs, max_qkt, mask=(rm < M))

+    # Exponentiate the neutralized results
+    exp_acc = tl.exp(qkt - max_qkt[:, None])  # (BLOCK_M, BLOCK_N)

+    # Softmax normalization constant, save so that an eventual mismatch can be fixed
+    weights = tl.sum(exp_acc, axis=1)  # (BLOCK_M)
+    weights_ptrs = WEIGHTS + pid_n * stride_weights + rm
+    tl.store(weights_ptrs, weights, mask=(rm < M))

+    # If no posterior re-normalization is needed, fuse it here
+    if SINGLE_ROW_TILE:
+        exp_qkt = exp_acc / weights[:, None]
+    else:
+        exp_qkt = exp_acc

     # Now pre-compute exp_acc against V.
     # We proceed per chunk over L, and save as we go
     i = 0
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    rl = tl.arange(0, BLOCK_L)

-    out_ptrs = OUT + pid_n * stride_out + rm[:, None] * L + rl[None, :]  # (BLOCK_M, BLOCK_L)
-    v_ptrs = V + rn[:, None] * L + rl[None, :]  # (BLOCK_N, BLOCK_L)
     for _ in range(L, 0, -BLOCK_L):
-        rl_i = rl + i * BLOCK_L
+        rl_i = rl + i * BLOCK_L  # Useful to keep track of the masking

+        v_ptrs = V + rn[:, None] * L + rl_i[None, :]  # (BLOCK_N, BLOCK_L)
         v = tl.load(v_ptrs, mask=((rn[:, None] < N) & (rl_i[None, :] < L)), other=0.0)

-        qkv = tl.dot(exp_acc, v).to(tl.float32)  # (BLOCK_M, BLOCK_L)
+        qkv = tl.dot(exp_qkt, v)  # (BLOCK_M, BLOCK_L)

+        out_ptrs = OUT + pid_n * stride_out + rm[:, None] * L + rl_i[None, :]  # (BLOCK_M, BLOCK_L)
+        # FIXME: This is most probably super slow
         tl.store(out_ptrs, qkv, mask=(rm[:, None] < M) & (rl_i[None, :] < L))
+        i += 1

-        # save so that an eventual mismatch can be fixed
-        weights = tl.sum(exp_acc, axis=1)  # (BLOCK_M)
-        weights_ptrs = WEIGHTS + pid_n * stride_weights + rm
-        tl.store(weights_ptrs, weights, mask=(rm < M))
-        i += 1
-        out_ptrs += BLOCK_L
-        v_ptrs += BLOCK_L
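
To make the tiling above easier to follow, here is a minimal PyTorch sketch, illustration only and not part of the commit (reference_tile and its arguments are hypothetical names), of what a single (pid_m, pid_n) program computes: QKt accumulated in BLOCK_L chunks over the embedding dimension, then the tile-local softmax statistics that the kernel stores into MAXES and WEIGHTS.

import torch

def reference_tile(q_tile, k_tile, scale, block_l):
    # q_tile: (BLOCK_M, L), k_tile: (BLOCK_N, L); boundary masking omitted for clarity
    qkt = torch.zeros(q_tile.shape[0], k_tile.shape[0])
    for start in range(0, q_tile.shape[1], block_l):
        q_blk = q_tile[:, start:start + block_l] * scale
        k_blk = k_tile[:, start:start + block_l]
        qkt += q_blk @ k_blk.t()                 # accumulate partial dot products
    max_qkt = qkt.max(dim=1).values              # tile-local row max      -> MAXES
    exp_qkt = torch.exp(qkt - max_qkt[:, None])  # stabilized exponentials
    weights = exp_qkt.sum(dim=1)                 # tile-local denominator  -> WEIGHTS
    return exp_qkt, max_qkt, weights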


def mem_efficient_fw(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
@@ -119,25 +134,26 @@ def grid(META):

     maxes_n = torch.empty((tiles_n, M), dtype=q.dtype, device=q.device)
     weights_n = torch.empty((tiles_n, M), dtype=q.dtype, device=q.device)
+    out_n = torch.empty((tiles_n, M, L), dtype=q.dtype, device=q.device)

     # FIXME: handle bias
     # FIXME: improve on the batch dimension handling ?
     qkvs = []
     for i_b in range(B):
-        out = torch.empty((tiles_n, M, L), dtype=q.dtype, device=q.device)

         # Use a dedicated kernel to process the attention by blocks
         # fmt: off
         k_me_attention_fw[grid](
-            out, maxes_n, weights_n,    # outputs
+            out_n, maxes_n, weights_n,  # outputs
             q[i_b], k[i_b], v[i_b],     # inputs
             M, N, L,                    # dimensions
-            out.stride(0), maxes_n.stride(0), weights_n.stride(0),
+            out_n.stride(0), maxes_n.stride(0), weights_n.stride(0),
             BLOCK_M=BLOCK_M,
             BLOCK_N=BLOCK_N,
             BLOCK_L=BLOCK_L,
             BIAS=False,
-            SCALE=1. / math.sqrt(L)
+            SCALE=1. / math.sqrt(L),
+            SINGLE_ROW_TILE=False  # FIXME: This should work: tiles_n == 1
         )
         # fmt: on
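
The `grid` helper referenced in the hunk headers is collapsed in this view. A plausible reconstruction, given as an assumption rather than the committed code: one program per (BLOCK_M, BLOCK_N) tile pair, so that pid_m/pid_n in the kernel index query-row and key-column tiles.

def grid(META):
    # one Triton program per (query-row tile, key-column tile) pair
    return (
        triton.cdiv(M, META["BLOCK_M"]),  # axis 0 -> pid_m
        triton.cdiv(N, META["BLOCK_N"]),  # axis 1 -> pid_n
    )

Under that assumption, tiles_n equals triton.cdiv(N, META["BLOCK_N"]), which is what the SINGLE_ROW_TILE FIXME above alludes to: with a single column tile, the kernel could fuse the normalization.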

@@ -148,22 +164,24 @@ def grid(META):

             # Let's fix that:
             # - collect the real overall max per line
-            per_line_max, _ = maxes_n.max(dim=0)
+            global_max, _ = maxes_n.max(dim=0)

             # - compute the mistake that was done in real time
-            mismatch = torch.exp(maxes_n - per_line_max[None, :])
+            mismatch = torch.exp(maxes_n - global_max[None, :])

             # - update the computations to take the consolidated max/weights
-            out *= mismatch.unsqueeze(-1)
+            out_n *= mismatch.unsqueeze(-1)
             weights_n *= mismatch

-            out = torch.sum(out, dim=0)
+            out = torch.sum(out_n, dim=0)
             weights = torch.sum(weights_n, dim=0)

+            qkv = out / weights.unsqueeze(-1)

         else:
-            weights = weights_n
+            # FIXME: The result should already be normalized
+            qkv = out_n / weights_n.unsqueeze(-1)

-        # TODO: do this in the kernel if it owns the whole line
-        qkv = out / weights.unsqueeze(-1)
         qkvs.append(qkv)

     return torch.cat(qkvs, dim=0).reshape(q_shape)
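
The fix-up above rests on a standard identity: per-tile softmax statistics can be merged exactly by rescaling each tile with exp(tile_max - global_max). A small self-contained PyTorch check of that identity, illustrative only, with names mirroring the variables above:

import torch

x = torch.randn(4, 12)                  # 4 rows, 3 column tiles of width 4
tiles = x.split(4, dim=1)
maxes_n = torch.stack([t.max(dim=1).values for t in tiles])    # (tiles_n, M)
exps = [torch.exp(t - m[:, None]) for t, m in zip(tiles, maxes_n)]
weights_n = torch.stack([e.sum(dim=1) for e in exps])          # (tiles_n, M)

global_max, _ = maxes_n.max(dim=0)
mismatch = torch.exp(maxes_n - global_max[None, :])            # (tiles_n, M)

num = torch.cat([e * c[:, None] for e, c in zip(exps, mismatch)], dim=1)
den = (weights_n * mismatch).sum(dim=0)
assert torch.allclose(num / den[:, None], torch.softmax(x, dim=1))

In the kernel the same rescaling is applied to out_n, i.e. to the exp(QKt) @ V partial products rather than to the raw exponentials; that is valid because the correction is a per-row scalar and commutes with the matmul over V.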
