Warp-specialized FP8 rowwise GEMM kernel (#3532)
Summary:
X-link: facebookresearch/FBGEMM#614

Pull Request resolved: #3532

Adding a warp-specialized GEMM kernel. A couple of highlights:

- `tl.async_task` warp partitioning. Rewrote the well-known flattened 1D-loop persistent GEMM kernel as a more natural warp-specialized 2D-loop persistent kernel with a producer/consumer cooperative model (see the sketch after this list).
- One kernel for both modes, enabling autotuning across WS and non-WS configs with the same kernel, since WS may not always be beneficial. We have also updated the tutorial to support multiple modes, including non-WS.
- Use regular loads instead of TMA for scale loading within each consumer.
- Compiler-automated omission of accumulator initialization. The first iteration of the matmul K-loop does not need an accumulator input when computing its output, which is then fed to the next iteration as the accumulator. Peeling the first iteration out of the K-loop therefore avoids the zero-initialization of the accumulator. `tl.assume` is used to handle the case where the loop does not run at all.
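
A minimal sketch of the 2D-loop producer/consumer structure (illustrative only: it assumes the `tl.async_task` and experimental TMA descriptor APIs from this warp-specialization-enabled Triton build, fixes the accumulator/output dtypes, and omits swizzling, row-wise scaling, and bias; `_ws_gemm_sketch` and its block shapes are placeholders):

import triton
import triton.language as tl


@triton.jit
def _ws_gemm_sketch(
    A_ptr, B_ptr, C_ptr, M, N, K,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    NUM_CONSUMER_GROUPS: tl.constexpr,
):
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_tiles = tl.cdiv(M, BLOCK_M) * num_pid_n
    # Outer loop of the 2D-loop persistent kernel: walk output tiles grid-stride style.
    for pid in range(tl.program_id(0), num_tiles, tl.num_programs(0)):
        offs_m = (pid // num_pid_n) * BLOCK_M
        offs_n = (pid % num_pid_n) * BLOCK_N
        offs_k = 0
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # Hint that the K-loop runs at least once so the first iteration can be
        # peeled and the zero-initialization of acc elided.
        tl.assume(tl.cdiv(K, BLOCK_K) > 0)
        # Inner loop: march along K.
        for _ in range(0, tl.cdiv(K, BLOCK_K)):
            with tl.async_task([0]):  # producer warp group: TMA loads only
                a = tl._experimental_descriptor_load(
                    A_ptr, [offs_m, offs_k], [BLOCK_M, BLOCK_K], tl.float8e4nv
                )
                b = tl._experimental_descriptor_load(
                    B_ptr, [offs_n, offs_k], [BLOCK_N, BLOCK_K], tl.float8e4nv
                )
            acc = tl.dot(a, b.T, acc)  # consumer warp groups: MMA
            offs_k += BLOCK_K
        with tl.async_task([1, NUM_CONSUMER_GROUPS]):  # consumers: epilogue + store
            tl._experimental_descriptor_store(C_ptr, acc.to(tl.bfloat16), [offs_m, offs_n])

Callers opt into this path through the new `use_warp_specialization` argument of `matmul_fp8_row`.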

Reviewed By: jianyuh

Differential Revision: D67676051

fbshipit-source-id: 4c552e37c358dc48d19b26ea3019c7afcb1ef18a
htyu authored and facebook-github-bot committed Dec 28, 2024
1 parent 0c93fd0 commit 921e305
Showing 1 changed file with 341 additions and 0 deletions.
341 changes: 341 additions & 0 deletions fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -746,6 +746,7 @@ def _kernel_matmul_fp8_row_imprecise_acc(
"n_key",
"k_key",
],
use_cuda_graph=True,
)
@triton.heuristics(
{
@@ -898,6 +899,232 @@ def _kernel_matmul_fp8_row_tma_persistent(
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)


has_warp_specialization = hasattr(tl, "async_task")
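# `tl.async_task` only exists in Triton builds with warp-specialization support;
# the WS autotune configs below are gated on it.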


def get_ws_configs() -> List[Config]:
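"""Extra autotune configs that exercise warp specialization (empty if unsupported)."""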
if not has_warp_specialization:
return []
return [
# pyre-ignore
Config(
{
"BLOCK_M": 128,
"BLOCK_N": 256,
"BLOCK_K": 128,
"SPLIT_K": 1,
"NUM_CONSUMER_GROUPS": 2,
},
num_stages=3,
num_warps=4,
num_consumer_groups=2,
num_buffers_warp_spec=3,
),
# pyre-ignore
Config(
{
"BLOCK_M": 128,
"BLOCK_N": 128,
"BLOCK_K": 128,
"SPLIT_K": 1,
"NUM_CONSUMER_GROUPS": 2,
},
num_stages=4,
num_warps=4,
num_consumer_groups=2,
num_buffers_warp_spec=4,
),
# pyre-ignore
Config(
{
"BLOCK_M": 128,
"BLOCK_N": 256,
"BLOCK_K": 128,
"SPLIT_K": 1,
"NUM_CONSUMER_GROUPS": 1,
},
num_stages=3,
num_warps=8,
num_consumer_groups=0,
num_buffers_warp_spec=3,
),
# pyre-ignore
Config(
{
"BLOCK_M": 64,
"BLOCK_N": 64,
"BLOCK_K": 512,
"SPLIT_K": 1,
"NUM_CONSUMER_GROUPS": 1,
},
num_stages=3,
num_warps=4,
num_consumer_groups=0,
num_buffers_warp_spec=3,
),
]


@triton.autotune(
configs=[
Config(
{
"BLOCK_M": 128,
"BLOCK_N": 256,
"BLOCK_K": 128,
"SPLIT_K": 1,
"NUM_CONSUMER_GROUPS": 1,
},
num_stages=3,
num_warps=8,
),
]
+ get_ws_configs(),
key=[
"m_key",
"n_key",
"k_key",
],
use_cuda_graph=True,
)
@triton.heuristics(
{
"EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
}
)
@triton.jit
def _kernel_matmul_fp8_row_tma_persistent_ws_cooperative(
A_ptr,
B_ptr,
C_ptr,
M,
N,
K,
m_key,
n_key,
k_key,
A_scale,
B_scale,
Bias,
stride_am,
stride_ak,
stride_bn,
stride_bk,
stride_cm,
stride_cn,
dot_out_dtype: tl.constexpr,
c_dtype: tl.constexpr,
bias_dtype: tl.constexpr,
allow_tf32: tl.constexpr,
fp8_fast_accum: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr,
AB_DTYPE: tl.constexpr,
SPLIT_K: tl.constexpr,
EVEN_K: tl.constexpr,
NUM_SMS: tl.constexpr,
USE_BIAS: tl.constexpr,
NUM_CONSUMER_GROUPS: tl.constexpr,
) -> None:
"""Matmul kernel of [M, K] @ [N, K] with row-wise scales
performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.
Args:
A (TensorWrapper): [M , K] input tensor.
B (TensorWrapper): [N, K] input tensor.
C (TensorWrapper): [M, N] output tensor.
M (int): M dimension of input tensor.
N (int): N dimension of input tensor.
K (int): K dimension of input tensor.
A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A
B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B
stride_am (int): Stride of M dimension of A.
stride_ak (int): Stride of K dimension of A.
stride_bn (int): Stride of N dimension of B.
stride_bk (int): Stride of K dimension of B.
stride_cm (int): Stride of M dimension of C.
stride_cn (int): Stride of N dimension of C.
dot_out_dtype (torch.dtype): Output type of tensor core.
allow_tf32 (bool): Whether to use TF32 for tensor core.
fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
BLOCK_M (int): Block size for M dimension.
BLOCK_N (int): Block size for N dimension.
BLOCK_K (int): Block size for K dimension.
GROUP_M (int): Number of groups for M dimension swizzle.
SPLIT_K (int): Number of SMs to launch per row.
EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
"""
num_tiles = tl.cdiv(M, BLOCK_M) * tl.cdiv(N, BLOCK_N)
num_pid_m = tl.cdiv(M, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
dtype_fp8 = tl.float8e4nv
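# Persistent 2D-loop: the outer loop walks output tiles in grid-size strides;
# the inner loop below walks K.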
for pid in range(tl.program_id(0), num_tiles, tl.num_programs(0)):
num_pid_in_group = GROUP_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
# pyre-ignore
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_M, BLOCK_K] pointers
# `b_ptrs` is a block of [BLOCK_K, BLOCK_N] pointers
# See above `Pointer Arithmetic` section for details
offs_am = pid_m * BLOCK_M
offs_bn = pid_n * BLOCK_N
offs_k = 0
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
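# Hint that the K-loop executes at least once, so the compiler can peel the
# first iteration and elide the zero-initialization of the accumulator.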
# pyre-ignore
tl.assume(tl.cdiv(K, BLOCK_K) > 0)
for _ in range(0, tl.cdiv(K, BLOCK_K)):
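# Producer task [0] issues the TMA loads; the consumer warp groups run the dot products below.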
# pyre-ignore
with tl.async_task([0]):
a = tl._experimental_descriptor_load(
A_ptr,
[offs_am, offs_k],
[BLOCK_M, BLOCK_K],
dtype_fp8,
)
b = tl._experimental_descriptor_load(
B_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], dtype_fp8
)

if fp8_fast_accum:
acc = tl.dot(
a, b.T, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32
)
else:
acc += tl.dot(a, b.T, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)

offs_k += BLOCK_K
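# The epilogue below runs in the consumer warp groups: scale, optional bias, then store.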

# pyre-ignore
with tl.async_task([1, NUM_CONSUMER_GROUPS]):
# Invert scaling.
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
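# Row-wise scales are fetched with regular loads rather than TMA inside each consumer.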
a_scale = tl.load(A_scale + rm, mask=rm < M)
b_scale = tl.load(B_scale + rn, mask=rn < N)
scale = a_scale[:, None] * b_scale[None, :]
acc *= scale
# Load and add bias if specified.
if USE_BIAS:
bias = tl._experimental_descriptor_load(
Bias, [offs_bn], [BLOCK_N], bias_dtype
)
acc += bias[None, :]
acc = acc.to(c_dtype)
tl._experimental_descriptor_store(C_ptr, acc, [offs_am, offs_bn])


# check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498).
HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)

@@ -1004,6 +1231,7 @@ def matmul_fp8_row(
imprecise_acc: bool = False,
tma_persistent: bool = True,
no_use_persistent: bool = False,
use_warp_specialization: bool = False,
) -> torch.Tensor:
"""
Performs matmul on [M, K] and [N, K] fp8 matrices with row-wise scalings [M], [N].
@@ -1097,6 +1325,119 @@ def persistent_grid(META):
# USE_BIAS=bias is not None,
AB_DTYPE=False,
)
elif use_warp_specialization:
assert has_warp_specialization
# used by TMA warp specialization kernel
desc_helper = TmaAutoTuneHelper()
desc_helper.init_tma_descriptor("a")
desc_helper.init_tma_descriptor("b")
desc_helper.init_tma_descriptor("c")
desc_helper.init_tma_descriptor("a_scale")
desc_helper.init_tma_descriptor("b_scale")
desc_helper.init_tma_descriptor("bias")

def persistent_grid_tma_ws(META):
nonlocal desc_helper
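# Descriptors are (re)filled for each autotune config; the M box is
# BLOCK_M // NUM_CONSUMER_GROUPS since each consumer group handles its own slice of the M tile.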
desc_helper.fill_2d_tma_descriptor(
"a",
a.data_ptr(),
M,
K,
META["BLOCK_M"] // META["NUM_CONSUMER_GROUPS"],
META["BLOCK_K"],
a.element_size(),
)

desc_helper.fill_2d_tma_descriptor(
"b",
b.data_ptr(),
N,
K,
META["BLOCK_N"],
META["BLOCK_K"],
b.element_size(),
)
desc_helper.fill_2d_tma_descriptor(
"c",
c.data_ptr(),
M,
N,
META["BLOCK_M"] // META["NUM_CONSUMER_GROUPS"],
META["BLOCK_N"],
c.element_size(),
)
desc_helper.fill_1d_tma_descriptor(
"a_scale",
a_scale.data_ptr(),
M,
META["BLOCK_M"] // META["NUM_CONSUMER_GROUPS"],
a_scale.element_size(),
)
desc_helper.fill_1d_tma_descriptor(
"b_scale",
b_scale.data_ptr(),
N,
META["BLOCK_N"],
b_scale.element_size(),
)
if bias is not None:
desc_helper.fill_1d_tma_descriptor(
"bias",
bias.data_ptr(),
N,
META["BLOCK_N"],
bias.element_size(),
)
return (
min(
NUM_SMS,
triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
),
)

desc_a = desc_helper.get_tma_descriptor_kernel_param("a")
desc_b = desc_helper.get_tma_descriptor_kernel_param("b")
desc_c = desc_helper.get_tma_descriptor_kernel_param("c")
desc_a_scale = desc_helper.get_tma_descriptor_kernel_param("a_scale")
desc_b_scale = desc_helper.get_tma_descriptor_kernel_param("b_scale")
desc_bias = desc_helper.get_tma_descriptor_kernel_param("bias")

bias_dtype_triton = None
if bias is not None:
bias_dtype_triton = map_dtype_to_triton(bias.dtype)

# pyre-ignore
torch._library.capture_triton(
_kernel_matmul_fp8_row_tma_persistent_ws_cooperative
)[persistent_grid_tma_ws](
desc_a,
desc_b,
desc_c,
M,
N,
K,
m_key,
n_key,
k_key,
a_scale,
b_scale,
desc_bias,
a.stride(0),
a.stride(1),
b.stride(0),
b.stride(1),
c.stride(0),
c.stride(1),
dot_out_dtype=dot_out_dtype_triton,
c_dtype=c_dtype_triton,
bias_dtype=bias_dtype_triton,
allow_tf32=allow_tf32,
fp8_fast_accum=fp8_fast_accum,
GROUP_M=8,
AB_DTYPE=False,
NUM_SMS=NUM_SMS,
USE_BIAS=bias is not None,
)
elif tma_persistent:
# used by TMA persistent kernel
desc_helper = TmaAutoTuneHelper()
