From 3f8b95fc854649c479a00d8bfc2d97a454ce95c2 Mon Sep 17 00:00:00 2001
From: Benjamin Lefaudeux
Date: Tue, 21 Dec 2021 19:43:10 -0800
Subject: [PATCH] WIP fusing the normalization, kernel still super buggy and a
 bit long

---
 xformers/triton/k_mem_efficient_attention.py | 92 ++++++++++++--------
 1 file changed, 55 insertions(+), 37 deletions(-)

diff --git a/xformers/triton/k_mem_efficient_attention.py b/xformers/triton/k_mem_efficient_attention.py
index 391835fece..3a3d57d51e 100644
--- a/xformers/triton/k_mem_efficient_attention.py
+++ b/xformers/triton/k_mem_efficient_attention.py
@@ -25,6 +25,7 @@ def k_me_attention_fw(
     # extract metaparameters
     BLOCK_M = META["BLOCK_M"]
     BLOCK_N, BLOCK_L = META["BLOCK_N"], META["BLOCK_L"]
+    SINGLE_ROW_TILE = META["SINGLE_ROW_TILE"]
     scale = META["SCALE"]
 
     # *within groups*, programs are ordered in a column-major order
@@ -32,59 +33,73 @@
     pid_m = tl.program_id(axis=0)
     pid_n = tl.program_id(axis=1)
 
-    # now compute the ranges that each program will go through
-    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
-    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
-
-    # initialize and iteratively update accumulator
-    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-
+    # Compute QKt
     # block level matrix multiplication.
     # We fetch a memory block from both inputs, matmul and accumulate, then repeat
+    qkt = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+    i = 0
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rl = tl.arange(0, BLOCK_L)
+
     for _ in range(L, 0, -BLOCK_L):
-        rl_i = rl + i * BLOCK_L
+        rl_i = rl + i * BLOCK_L  # track the offset in L, used for masking
         q_ptrs = Q + rm[:, None] * L + rl_i[None, :]  # (BLOCK_M, BLOCK_L)
         k_ptrs = K + rn[None, :] * L + rl_i[:, None]  # (BLOCK_L, BLOCK_N)
         q = tl.load(q_ptrs, mask=((rm[:, None] < M) & (rl_i[None, :] < L)), other=0.0)  # (BLOCK_M, BLOCK_L)
-        q *= scale  # q /= sqrt(dim)
-
         k = tl.load(k_ptrs, mask=((rl_i[:, None] < L) & (rn[None, :] < N)), other=0.0)  # (BLOCK_L, BLOCK_N)
-        acc += tl.dot(q, k).to(tl.float32)  # (BLOCK_M, BLOCK_N)
+        q *= scale  # q /= sqrt(dim)
+        qkt += tl.dot(q, k).to(tl.float32)  # (BLOCK_M, BLOCK_N)
+
+        # Update the pointers and counter
         i += 1
 
-    # pick the local max, safeguard the incoming exponential
-    # save so that an eventual mismatch can be fixed
-    max_acc = tl.max(acc, axis=1)  # (BLOCK_M)
+    # Pick the local max per row, safeguard the incoming exponential
+    max_qkt = tl.max(qkt, axis=1)  # (BLOCK_M)
     max_ptrs = MAXES + pid_n * stride_maxes + rm  # (BLOCK_M)
-    tl.store(max_ptrs, max_acc, mask=(rm < M))
-
-    # exponentiate the neutralized results
-    exp_acc = tl.exp(acc - max_acc[:, None])  # (BLOCK_M, BLOCK_N)
+
+    # Save it so that an eventual mismatch can be fixed post-hoc
+    tl.store(max_ptrs, max_qkt, mask=(rm < M))
+
+    # Exponentiate the neutralized results
+    exp_acc = tl.exp(qkt - max_qkt[:, None])  # (BLOCK_M, BLOCK_N)
+
+    # Softmax normalization constant, saved so that an eventual mismatch can be fixed
+    weights = tl.sum(exp_acc, axis=1)  # (BLOCK_M)
+    weights_ptrs = WEIGHTS + pid_n * stride_weights + rm
+    tl.store(weights_ptrs, weights, mask=(rm < M))
+
+    # If there is no posterior re-normalization pass, fuse the normalization here
+    if SINGLE_ROW_TILE:
+        exp_qkt = exp_acc / weights[:, None]
+    else:
+        exp_qkt = exp_acc
 
     # Now pre-compute exp_acc against V.
     # We proceed per chunk over L, and save as we go
     i = 0
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
     rl = tl.arange(0, BLOCK_L)
+
+    out_ptrs = OUT + pid_n * stride_out + rm[:, None] * L + rl[None, :]  # (BLOCK_M, BLOCK_L)
+    v_ptrs = V + rn[:, None] * L + rl[None, :]  # (BLOCK_N, BLOCK_L)
 
     for _ in range(L, 0, -BLOCK_L):
-        rl_i = rl + i * BLOCK_L
+        rl_i = rl + i * BLOCK_L  # track the offset in L, used for masking
-        v_ptrs = V + rn[:, None] * L + rl_i[None, :]  # (BLOCK_N, BLOCK_L)
         v = tl.load(v_ptrs, mask=((rn[:, None] < N) & (rl_i[None, :] < L)), other=0.0)
-        qkv = tl.dot(exp_acc, v).to(tl.float32)  # (BLOCK_M, BLOCK_L)
+        qkv = tl.dot(exp_qkt, v)  # (BLOCK_M, BLOCK_L)
-        out_ptrs = OUT + pid_n * stride_out + rm[:, None] * L + rl_i[None, :]  # (BLOCK_M, BLOCK_L)
+        # FIXME: This is most probably super slow
         tl.store(out_ptrs, qkv, mask=(rm[:, None] < M) & (rl_i[None, :] < L))
-        i += 1
-
-    # save so that an eventual mismatch can be fixed
-    weights = tl.sum(exp_acc, axis=1)  # (BLOCK_M)
-    weights_ptrs = WEIGHTS + pid_n * stride_weights + rm
-    tl.store(weights_ptrs, weights, mask=(rm < M))
+        i += 1
+        out_ptrs += BLOCK_L
+        v_ptrs += BLOCK_L
 
 
 def mem_efficient_fw(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
@@ -119,25 +134,26 @@ def grid(META):
     maxes_n = torch.empty((tiles_n, M), dtype=q.dtype, device=q.device)
     weights_n = torch.empty((tiles_n, M), dtype=q.dtype, device=q.device)
+    out_n = torch.empty((tiles_n, M, L), dtype=q.dtype, device=q.device)
 
     # FIXME: handle bias
     # FIXME: improve on the batch dimension handling?
     qkvs = []
 
     for i_b in range(B):
-        out = torch.empty((tiles_n, M, L), dtype=q.dtype, device=q.device)
         # Use a dedicated kernel to process the attention by blocks
         # fmt: off
         k_me_attention_fw[grid](
-            out, maxes_n, weights_n,    # outputs
+            out_n, maxes_n, weights_n,  # outputs
             q[i_b], k[i_b], v[i_b],     # inputs
             M, N, L,                    # dimensions
-            out.stride(0), maxes_n.stride(0), weights_n.stride(0),
+            out_n.stride(0), maxes_n.stride(0), weights_n.stride(0),
             BLOCK_M=BLOCK_M,
             BLOCK_N=BLOCK_N,
             BLOCK_L=BLOCK_L,
             BIAS=False,
-            SCALE=1. / math.sqrt(L)
+            SCALE=1. / math.sqrt(L),
+            SINGLE_ROW_TILE=False  # FIXME: should be tiles_n == 1 once the fused path works
         )
         # fmt: on
@@ -148,22 +164,24 @@ def grid(META):
 
             # Let's fix that:
             # - collect the real overall max per line
-            per_line_max, _ = maxes_n.max(dim=0)
+            global_max, _ = maxes_n.max(dim=0)
 
             # - compute the mismatch due to having used a local max
-            mismatch = torch.exp(maxes_n - per_line_max[None, :])
+            mismatch = torch.exp(maxes_n - global_max[None, :])
 
             # - update the partial results with the consolidated max/weights
-            out *= mismatch.unsqueeze(-1)
+            out_n *= mismatch.unsqueeze(-1)
             weights_n *= mismatch
 
-            out = torch.sum(out, dim=0)
+            out = torch.sum(out_n, dim=0)
             weights = torch.sum(weights_n, dim=0)
+
+            qkv = out / weights.unsqueeze(-1)
+
         else:
-            weights = weights_n
+            # FIXME: The result should already be normalized
+            qkv = out_n / weights_n.unsqueeze(-1)
 
-        # TODO: do this in the kernel if it owns the whole line
-        qkv = out / weights.unsqueeze(-1)
         qkvs.append(qkv)
 
     return torch.cat(qkvs, dim=0).reshape(q_shape)
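
Note on the post-hoc renormalization: each tile along N only sees its own slice of a row, so it stores a local max and a local exp-sum, and the host code merges them afterwards by rescaling every tile's contribution with exp(local_max - global_max). Below is a minimal PyTorch sketch of that merge, for illustration only; it is not part of the patch, and merge_tiles plus the explicit tile loop are assumed stand-ins for the kernel, not its API.

import math

import torch


def merge_tiles(maxes_n: torch.Tensor, weights_n: torch.Tensor, out_n: torch.Tensor) -> torch.Tensor:
    # maxes_n, weights_n: (tiles_n, M); out_n: (tiles_n, M, L).
    # Rescale each tile by exp(local max - global max), then sum the partial
    # outputs and softmax denominators before the final division.
    global_max, _ = maxes_n.max(dim=0)                   # (M,)
    mismatch = torch.exp(maxes_n - global_max[None, :])  # (tiles_n, M)
    out = (out_n * mismatch.unsqueeze(-1)).sum(dim=0)    # (M, L)
    weights = (weights_n * mismatch).sum(dim=0)          # (M,)
    return out / weights.unsqueeze(-1)


# Emulate what each program along the n axis computes, then check the merge
# against a single-pass softmax attention.
M, L, tiles_n, BLOCK_N = 8, 16, 3, 4
q = torch.randn(M, L)
k = torch.randn(tiles_n * BLOCK_N, L)
v = torch.randn(tiles_n * BLOCK_N, L)
qkt = (q / math.sqrt(L)) @ k.T  # (M, N)

maxes, sums, outs = [], [], []
for t in range(tiles_n):
    tile = qkt[:, t * BLOCK_N:(t + 1) * BLOCK_N]     # (M, BLOCK_N)
    local_max = tile.max(dim=1).values               # (M,)
    exp_tile = torch.exp(tile - local_max[:, None])  # un-normalized weights
    maxes.append(local_max)
    sums.append(exp_tile.sum(dim=1))
    outs.append(exp_tile @ v[t * BLOCK_N:(t + 1) * BLOCK_N])

merged = merge_tiles(torch.stack(maxes), torch.stack(sums), torch.stack(outs))
reference = torch.softmax(qkt, dim=1) @ v
assert torch.allclose(merged, reference, atol=1e-5)

This is the usual log-sum-exp merge: when a single tile owns the whole row (the SINGLE_ROW_TILE case), the local max is already the global max, the mismatch factor is 1, and the division can be fused into the kernel, which is what this patch starts to do.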