import logging
from typing import Optional

import torch
import triton
import triton.language as tl

logger = logging.getLogger(__name__)

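# One program per expert: binary-search the expert-sorted `reorder_topk_ids`
# for the last slot assigned to this expert and write the exclusive end offset
# into seg_indptr[expert + 1] (seg_indptr[0] stays 0 from the caller's
# zero initialization).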
@triton.jit
def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
    expert = tl.program_id(0)
    low = 0
    high = num_toks - 1
    target_location = -1
    while low <= high:
        mid = (low + high) // 2

        if tl.load(reorder_topk_ids + mid) > expert:
            high = mid - 1
        else:
            low = mid + 1
            target_location = mid
    tl.store(seg_indptr + expert + 1, target_location + 1)

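# Invert the sort permutation from torch.sort: src2dst[reorder_ids[d]] = d,
# so every original (token, top-k) slot knows its destination row in the
# expert-sorted layout.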
@triton.jit
def compute_src2dst_triton_kernel(
    reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
):
    pid = tl.program_id(axis=0)
    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = dst_id < num_toks
    src_id = tl.load(reorder_ids + dst_id, mask=mask)
    tl.store(src2dst + src_id, dst_id, mask=mask)

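# Host-side preprocessing for expert-parallel MoE: stable-sort the flattened
# top-k expert assignments so slots routed to the same expert are contiguous,
# then build the per-expert segment pointer and the src->dst scatter map.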
def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
    seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)

    compute_seg_indptr_triton_kernel[(num_experts,)](
        reorder_topk_ids, seg_indptr, topk_ids.numel()
    )

    BLOCK_SIZE = 512
    grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)
    compute_src2dst_triton_kernel[grid](
        reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE
    )
    return reorder_topk_ids, src2dst, seg_indptr

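# One program per source token: for each of its top-k experts that falls in
# the local range (start_expert_id..end_expert_id), copy the token's hidden
# vector into the expert-sorted gateup_input buffer, optionally scaling by
# 1 / a1_scales[expert] for quantized activations.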
@triton.jit
def pre_reorder_triton_kernel(
    input_ptr,
    gateup_input_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    a1_scales_ptr,
    start_expert_id,
    end_expert_id,
    topk,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
):
    OutDtype = gateup_input_ptr.dtype.element_ty

    src_idx = tl.program_id(0)
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk

    src_ptr = input_ptr + src_idx * hidden_size
    for idx in range(topk):
        expert_id = tl.load(topk_ids_ptr + idx)
        if expert_id >= start_expert_id and expert_id <= end_expert_id:
            if a1_scales_ptr is not None:
                scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
            else:
                scale = 1.0

            dst_idx = tl.load(src2dst_ptr + idx)
            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
                offset = start_offset + tl.arange(0, BLOCK_SIZE)
                mask = offset < hidden_size
                in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
                out_data = (in_data * scale).to(OutDtype)
                tl.store(dst_ptr + offset, out_data, mask=mask)

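# One program per expert-sorted row: split the fused gate/up projection output
# in half, apply SiLU to the gate half, multiply by the up half, and
# optionally rescale by 1 / scales[expert] to produce the (possibly quantized)
# input of the down projection.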
@triton.jit
def silu_and_mul_triton_kernel(
    gateup_output,
    down_input,
    hidden_size,
    reorder_topk_ids,
    scales,
    start_expert_id,
    end_expert_id,
    BLOCK_SIZE: tl.constexpr,
):
    InDtype = gateup_output.dtype.element_ty
    OutDtype = down_input.dtype.element_ty

    half_hidden_size = hidden_size // 2

    pid = tl.program_id(0)
    expert_id = tl.load(reorder_topk_ids + pid)
    if expert_id >= start_expert_id and expert_id <= end_expert_id:
        gateup_output_ptr = gateup_output + pid * hidden_size
        gate_output_ptr = gateup_output_ptr
        up_output_ptr = gateup_output_ptr + half_hidden_size
        down_input_ptr = down_input + pid * half_hidden_size

        if scales is not None:
            scale = tl.load(scales + expert_id - start_expert_id)
            scale = (1 / scale).to(InDtype)
        else:
            scale = 1

        for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
            offset = start_offset + tl.arange(0, BLOCK_SIZE)
            mask = offset < half_hidden_size

            gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
            up_output = tl.load(up_output_ptr + offset, mask=mask)

            # silu & mul & quantize
            gate_output = gate_output * tl.sigmoid(gate_output)
            gate_output = gate_output.to(InDtype)

            silu_mul_output = gate_output * up_output * scale
            silu_mul_output = silu_mul_output.to(OutDtype)
            tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)

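# One program per source token: gather the down-projection outputs of the
# token's local experts back from the expert-sorted layout, accumulate them
# weighted by topk_weights, and zero-fill the output row when none of the
# token's experts are in the local range.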
@triton.jit
def post_reorder_triton_kernel(
    down_output_ptr,
    output_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    topk_weights_ptr,
    start_expert_id,
    end_expert_id,
    topk,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
):
    InDtype = down_output_ptr.dtype.element_ty

    src_idx = tl.program_id(0)
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk
    topk_weights_ptr = topk_weights_ptr + src_idx * topk

    computed = False
    store_ptr = output_ptr + src_idx * hidden_size
    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
        offset = start_offset + tl.arange(0, BLOCK_SIZE)
        mask = offset < hidden_size

        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
        for idx in range(topk):
            expert_id = tl.load(topk_ids_ptr + idx)
            if expert_id >= start_expert_id and expert_id <= end_expert_id:
                computed = True
                dst_idx = tl.load(src2dst_ptr + idx)
                weight_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
                load_ptr = down_output_ptr + dst_idx * hidden_size
                in_data = tl.load(load_ptr + offset, mask=mask)
                sum_vec += in_data * weight_scale
        tl.store(store_ptr + offset, sum_vec, mask=mask)

    if computed == False:
        for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
            offset = start_offset + tl.arange(0, BLOCK_SIZE)
            mask = offset < hidden_size
            tl.store(
                store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask
            )

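# Map an M-tile id to its batch segment: scan m_num_tiles_indptr for the
# segment containing `pid`, then return that tile's row range
# [m_range_start, m_range_end) and the weight (expert) index of the segment.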
@triton.jit
def compute_m_range(
    pid,
    batch_size,
    seg_indptr,
    weight_indices,
    m_num_tiles_indptr,
    BLOCK_SIZE_M: tl.constexpr,
):
    idx = 0
    for bs in range(batch_size):
        tiles = tl.load(m_num_tiles_indptr + bs)
        if pid >= tiles:
            idx = bs

    idx_start = tl.load(m_num_tiles_indptr + idx)

    m_range_start = tl.load(seg_indptr + idx) + (pid - idx_start) * BLOCK_SIZE_M
    m_range_end = min(tl.load(seg_indptr + idx + 1), m_range_start + BLOCK_SIZE_M)
    expert_id = tl.load(weight_indices + idx)
    return m_range_start, m_range_end, expert_id

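# Tiled grouped GEMM: each segment of A rows (the slots routed to one expert)
# is multiplied by that expert's column-major weight matrix B, with an
# optional per-expert fp8 rescale of the fp32 accumulator before the result
# is cast and stored into C.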
@triton.jit
def grouped_gemm_triton_kernel(
    a,
    b,
    c,
    batch_size,
    N,
    K,
    seg_indptr,
    weight_indices,
    m_num_tiles_indptr,
    use_fp8_w8a8,
    scale_a,
    scale_b,
    a_stride_0: tl.constexpr,
    b_stride_0: tl.constexpr,
    b_stride_1: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    c_dtype = c.dtype.element_ty

    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    total_m_block = tl.load(m_num_tiles_indptr + batch_size)
    if pid_m >= total_m_block:
        return

    m_range_start, m_range_end, expert_id = compute_m_range(
        pid_m, batch_size, seg_indptr, weight_indices, m_num_tiles_indptr, BLOCK_SIZE_M
    )
    if m_range_end - m_range_start == 0:
        return

    n_range_start = pid_n * BLOCK_SIZE_N
    n_range_end = min(n_range_start + BLOCK_SIZE_N, N)

    offs_am = tl.arange(0, BLOCK_SIZE_M)
    offs_bn = tl.arange(0, BLOCK_SIZE_N)

    offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0)
    offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0)
    offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptr = a + (m_range_start + offs_am[:, None]) * a_stride_0 + offs_k[None, :]
    b_ptr = b + (
        (expert_id * b_stride_0)
        + (n_range_start + offs_bn[:, None]) * b_stride_1
        + offs_k[None, :]
    )
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a_tile = tl.load(
            a_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
        )
        b_tile = tl.load(
            b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
        )
        accumulator = tl.dot(a_tile, b_tile.T, accumulator)
        a_ptr += BLOCK_SIZE_K
        b_ptr += BLOCK_SIZE_K

    if use_fp8_w8a8:
        scale_a_value = tl.load(scale_a + expert_id)
        scale_b_value = tl.load(scale_b + expert_id)
        accumulator *= scale_a_value * scale_b_value
    c_tile = accumulator.to(c_dtype)

    offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = n_range_start + tl.arange(0, BLOCK_SIZE_N)
    c_ptr = c + offs_cm[:, None] * N + offs_cn[None, :]
    c_mask = (offs_cm[:, None] < m_range_end) & (offs_cn[None, :] < n_range_end)
    tl.store(c_ptr, c_tile, mask=c_mask)

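# Single-program prefix sum over per-segment tile counts:
# m_num_tiles_indptr[bs + 1] accumulates cdiv(m_i, BLOCK_SIZE_M) over
# segments i <= bs, which compute_m_range uses to locate each tile's segment.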
@triton.jit
def compute_m_num_tiles_indptr(
    m_num_tiles_indptr, seg_indptr, batch_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr
):
    for bs in range(batch_size):
        m = tl.load(seg_indptr + bs + 1) - tl.load(seg_indptr + bs)
        cur_num_tiles = tl.cdiv(m, BLOCK_SIZE_M)
        pre_num_tiles = tl.load(m_num_tiles_indptr + bs)
        tl.store(m_num_tiles_indptr + bs + 1, pre_num_tiles + cur_num_tiles)

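# Host wrapper for the grouped GEMM: builds the tile prefix sums, then
# launches the kernel on a grid padded with `batch_size` extra M-tiles to
# absorb the per-segment rounding from cdiv.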
def grouped_gemm_triton(
    a: torch.Tensor,
    b: torch.Tensor,
    c: torch.Tensor,
    batch_size: int,
    weight_column_major: bool,
    seg_indptr: Optional[torch.Tensor] = None,
    weight_indices: Optional[torch.Tensor] = None,
    use_fp8_w8a8: bool = False,
    scale_a: Optional[torch.Tensor] = None,
    scale_b: Optional[torch.Tensor] = None,
):
    assert weight_column_major  # TODO: more
    if use_fp8_w8a8:
        assert scale_a is not None and scale_b is not None

    config = {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
    }

    m_num_tiles_indptr = torch.zeros(batch_size + 1, device=a.device, dtype=torch.int64)
    compute_m_num_tiles_indptr[(1,)](
        m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
    )

    grid = lambda META: (
        triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
        triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
    )

    grouped_gemm_triton_kernel[grid](
        a,
        b,
        c,
        batch_size,
        b.size(1),
        b.size(2),
        seg_indptr,
        weight_indices,
        m_num_tiles_indptr,
        use_fp8_w8a8,
        scale_a,
        scale_b,
        a.stride(0),
        b.stride(0),
        b.stride(1),
        **config,
    )
    return c
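

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original module: route a few
    # tokens across 8 experts with top-2 selection and sanity-check the
    # preprocessing outputs. Assumes a CUDA device with Triton available.
    torch.manual_seed(0)
    num_experts, num_tokens, topk = 8, 4, 2
    topk_ids = torch.randint(
        0, num_experts, (num_tokens, topk), device="cuda", dtype=torch.int64
    )
    reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts)

    # seg_indptr[e + 1] - seg_indptr[e] is the number of (token, top-k) slots
    # routed to expert e; together they must account for every slot.
    counts = seg_indptr[1:] - seg_indptr[:-1]
    assert counts.sum().item() == topk_ids.numel()
    logger.info("per-expert slot counts: %s", counts.tolist())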