Marlin 24 prefill performance improvement (about 25% better on average) #4983

Merged 9 commits on May 23, 2024
Changes from all commits
74 changes: 62 additions & 12 deletions benchmarks/kernels/benchmark_marlin.py
@@ -6,9 +6,13 @@

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MarlinWorkspace, marlin_quantize)
MarlinWorkspace, marlin_24_quantize, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, quantize_weights, sort_weights)

@@ -44,6 +48,10 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
marlin_rand_perm,
) = marlin_quantize(b, num_bits, group_size, act_order)

# Marlin_24 quant
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s) = marlin_24_quantize(b, num_bits, group_size)

# GPTQ quant
(w_ref, q_w, s, g_idx,
rand_perm) = quantize_weights(b, num_bits, group_size, act_order)
@@ -56,28 +64,43 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
(q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)

# Prepare
marlin_workspace = MarlinWorkspace(size_n)
marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)

marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_MAX_PARALLEL)

globals = {
# Gen params
"num_bits": num_bits,
"group_size": group_size,
"size_m": size_m,
"size_n": size_n,
"size_k": size_k,
"a": a,
"a_tmp": a_tmp,
# Marlin params
"marlin_w_ref": marlin_w_ref,
"marlin_q_w": marlin_q_w,
"marlin_s": marlin_s,
"marlin_g_idx": marlin_g_idx,
"marlin_sort_indices": marlin_sort_indices,
"marlin_rand_perm": marlin_rand_perm,
"marlin_workspace": marlin_workspace,
"is_k_full": is_k_full,
# Marlin_24 params
"marlin_24_w_ref": marlin_24_w_ref,
"marlin_24_q_w_comp": marlin_24_q_w_comp,
"marlin_24_meta": marlin_24_meta,
"marlin_24_s": marlin_24_s,
"marlin_24_workspace": marlin_24_workspace,
# GPTQ params
"q_w_gptq": q_w_gptq,
"repack_sort_indices": repack_sort_indices,
"num_bits": num_bits,
"group_size": group_size,
"size_m": size_m,
"size_n": size_n,
"size_k": size_k,
"is_k_full": is_k_full,
"a": a,
"a_tmp": a_tmp,
# Kernels
"gptq_marlin_gemm": ops.gptq_marlin_gemm,
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
"gptq_marlin_repack": ops.gptq_marlin_repack,
"marlin_workspace": marlin_workspace,
}

min_run_time = 1
@@ -105,6 +128,18 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
description="gptq_marlin_gemm",
).blocked_autorange(min_run_time=min_run_time))

if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES):
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_24_gemm",
).blocked_autorange(min_run_time=min_run_time))

results.append(
benchmark.Timer(
stmt=
@@ -135,8 +170,20 @@ def main(args):
continue

for act_order in ACT_ORDER_OPTS:
if len(args.limit_act_order
) > 0 and act_order not in args.limit_act_order:
continue

for is_k_full in K_FULL_OPTS:
if len(args.limit_k_full
) > 0 and is_k_full not in args.limit_k_full:
continue

for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
if len(args.limit_num_bits
) > 0 and num_bits not in args.limit_num_bits:
continue

for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
if len(
args.limit_group_size
@@ -159,7 +206,7 @@ def main(args):


# For quick benchmarking use:
# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 # noqa E501
# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501
#
if __name__ == "__main__":
parser = argparse.ArgumentParser(
@@ -178,6 +225,9 @@ def main(args):
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
parser.add_argument("--limit-group-size", nargs="+", type=int, default=[])
parser.add_argument("--limit-num-bits", nargs="+", type=int, default=[])
parser.add_argument("--limit-act-order", nargs="+", type=int, default=[])
parser.add_argument("--limit-k-full", nargs="+", type=int, default=[])

args = parser.parse_args()
main(args)
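
For context on what the new benchmark entry measures: the added Timer statement calls ops.gptq_marlin_24_gemm on tensors produced by marlin_24_quantize, guarded by the MARLIN_24 support lists. Below is a minimal standalone sketch of that call using the same names as the script; the problem size, dtypes, and random inputs are illustrative assumptions, not part of this PR.

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    MarlinWorkspace, marlin_24_quantize)

size_m, size_n, size_k = 16, 4096, 4096   # illustrative problem size
num_bits, group_size = 4, 128

a = torch.randn((size_m, size_k), dtype=torch.half, device="cuda")
b = torch.rand((size_k, size_n), dtype=torch.half, device="cuda")

# 2:4-sparse Marlin quantization of the weight, as in bench_run() above.
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
 marlin_24_s) = marlin_24_quantize(b, num_bits, group_size)

marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
                                      GPTQ_MARLIN_24_MAX_PARALLEL)

# The statement timed by the new benchmark entry:
output = ops.gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta,
                                 marlin_24_s, marlin_24_workspace.scratch,
                                 num_bits, size_m, size_n, size_k)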
55 changes: 40 additions & 15 deletions csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
@@ -48,12 +48,12 @@ namespace marlin_24 {
// than 1 warp per schedule allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.
static constexpr int THREADS = 256;
static constexpr int STAGES = 4; // 4 pipeline stages fit into shared memory
static constexpr int STAGES = 4;

static constexpr int min_thread_n = 128;

static constexpr int tile_size = 16;
static constexpr int max_par = 16;
static constexpr int max_par = 64;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

@@ -736,10 +736,10 @@ __global__ void Marlin_24(
for (int pipe = 0; pipe < stages;) {
fetch_to_shared((pipe + stages - 1) % stages, pipe,
slice_iters >= stages);
matmul(pipe);
wait_for_stage();

fetch_to_registers(pipe + 1, (pipe + 1) % stages);
matmul(pipe);

pipe++;
slice_iters--;
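
The hunk above is the core scheduling change: matmul(pipe) now runs after the register prefetch for the next sub-tile instead of before wait_for_stage(), which lets the shared-to-register loads be in flight while the tensor-core compute executes. A schematic Python rendering of the new issue order follows; the helper names mirror the kernel's device functions and the loop is simplified, so treat it as an illustration only, not the kernel itself.

# Schematic only: the real work happens inside the CUDA kernel above.
def main_loop(stages, slice_iters, fetch_to_shared, wait_for_stage,
              fetch_to_registers, matmul):
    pipe = 0
    while slice_iters > 0:
        # Kick off the async global->shared copy for a later stage first ...
        fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages)
        # ... block until the stage about to be consumed is resident ...
        wait_for_stage()
        # ... issue the shared->register loads for the next sub-tile ...
        fetch_to_registers(pipe + 1, (pipe + 1) % stages)
        # ... and only then run the current stage's matmul. Before this PR the
        # matmul was issued right after fetch_to_shared, ahead of wait_for_stage.
        matmul(pipe)
        pipe = (pipe + 1) % stages
        slice_iters -= 1

# Trace the issue order for a 4-stage pipeline over 6 iterations.
trace = []
main_loop(4, 6,
          fetch_to_shared=lambda dst, src, pred: trace.append("fetch_shared"),
          wait_for_stage=lambda: trace.append("wait"),
          fetch_to_registers=lambda k, buf: trace.append("fetch_regs"),
          matmul=lambda p: trace.append(f"matmul({p})"))
print(trace)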
@@ -899,9 +899,12 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
// than better compute utilization
thread_k = 128;
thread_m = 128;
} else {
} else if (prob_n <= 256) {
thread_k = 64;
thread_m = 256;
} else {
thread_k = 32;
thread_m = 512;
}
}
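
The heuristic above gains a third tier: problems wider than 256 now get 32x512 thread tiles instead of falling through to 64x256. A small sketch of the visible branches; the condition guarding the first 128x128 tier sits above this hunk, so it is represented here by a placeholder flag.

def pick_tile(prob_n: int, small_problem: bool) -> tuple[int, int]:
    """Return (thread_k, thread_m) as in the auto-selection branches above."""
    if small_problem:            # placeholder for the condition outside this hunk
        return 128, 128          # partitioning favored over compute utilization
    elif prob_n <= 256:
        return 64, 256
    else:
        return 32, 512           # new tier added by this PR for wide problems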

@@ -928,19 +931,21 @@
int4* C_ptr = (int4*)C;
const int4* s_ptr = (const int4*)s;

constexpr int max_m_blocks = 4;

int* locks = (int*)workspace;
for (int i = 0; i < tot_n_blocks; i += 4) {
for (int i = 0; i < tot_n_blocks; i += max_m_blocks) {
int thread_n_blocks = tot_n_blocks - i;
prob_n = tot_n - 16 * i;
int par = 1;
if (thread_n_blocks > 4) {
if (thread_n_blocks > max_m_blocks) {
// Note that parallel > 1 currently only works for inputs without any
// padding
par = (16 * thread_n_blocks - pad) / 64;
par = (16 * thread_n_blocks - pad) / (max_m_blocks * 16);
if (par > max_par) par = max_par;
prob_n = 64 * par;
i += 4 * (par - 1);
thread_n_blocks = 4;
prob_n = (max_m_blocks * 16) * par;
i += max_m_blocks * (par - 1);
thread_n_blocks = max_m_blocks;
}
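
This is where the prefill gain comes from: the host loop still steps in units of max_m_blocks (4) blocks of 16 rows, but a single launch may now cover up to max_par = 64 parallel problems instead of 16. A small Python sketch of the same arithmetic, for an unpadded input:

max_m_blocks = 4     # blocks of 16 rows handled per launch slice (see above)
max_par = 64         # raised from 16 in this PR

def partition(tot_n_blocks: int, pad: int = 0):
    """Sketch of how the host loop above splits work into parallel problems."""
    i = 0
    while i < tot_n_blocks:
        thread_n_blocks = tot_n_blocks - i
        par = 1
        if thread_n_blocks > max_m_blocks:
            # parallel > 1 currently only works for unpadded inputs
            par = min((16 * thread_n_blocks - pad) // (max_m_blocks * 16), max_par)
            prob_n = (max_m_blocks * 16) * par
            i += max_m_blocks * (par - 1)
        else:
            prob_n = 16 * thread_n_blocks
        yield prob_n, par
        i += max_m_blocks

# 256 blocks of 16 rows (e.g. a 4096-row prefill) now fit in one launch with
# par = 64, where the old max_par = 16 needed four launches of 1024 rows each.
print(list(partition(256)))   # [(4096, 64)]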

// For compilation speed, we only define the kernel configurations that have
@@ -951,8 +956,9 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
if (false) {
} // BMxBNxBK, group
// 4-bit
CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128
CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64
CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128
CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64

CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64
CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64
CALL_IF_2_4(4, 16, 2, 2, -1) // e.g., 32x256x64
Expand All @@ -962,9 +968,19 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
CALL_IF_2_4(4, 16, 4, 2, -1)
CALL_IF_2_4(4, 16, 4, 2, 4)

CALL_IF_2_4(4, 32, 1, 1, -1) // e.g., 16x256x64
CALL_IF_2_4(4, 32, 1, 1, 4) // e.g., 16x256x64, 64
CALL_IF_2_4(4, 32, 2, 1, -1) // e.g., 32x256x64
CALL_IF_2_4(4, 32, 2, 1, 4)
CALL_IF_2_4(4, 32, 3, 1, -1)
CALL_IF_2_4(4, 32, 3, 1, 4)
CALL_IF_2_4(4, 32, 4, 1, -1)
CALL_IF_2_4(4, 32, 4, 1, 4)

// 8-bit
CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128
CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64
CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128
CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64

CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64
CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64
CALL_IF_2_4(8, 16, 2, 2, -1) // e.g., 32x256x64
Expand All @@ -973,6 +989,15 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
CALL_IF_2_4(8, 16, 3, 2, 4)
CALL_IF_2_4(8, 16, 4, 2, -1)
CALL_IF_2_4(8, 16, 4, 2, 4)

CALL_IF_2_4(8, 32, 1, 1, -1) // e.g., 16x256x64
CALL_IF_2_4(8, 32, 1, 1, 4) // e.g., 16x256x64, 64
CALL_IF_2_4(8, 32, 2, 1, -1) // e.g., 32x256x64
CALL_IF_2_4(8, 32, 2, 1, 4)
CALL_IF_2_4(8, 32, 3, 1, -1)
CALL_IF_2_4(8, 32, 3, 1, 4)
CALL_IF_2_4(8, 32, 4, 1, -1)
CALL_IF_2_4(8, 32, 4, 1, 4)
else {
throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) +
", " + str(prob_k) + ", " + str(prob_n) + "]" +
@@ -1062,7 +1087,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
int thread_k = -1;
int thread_m = -1;
int sms = -1;
int max_par = 16;
int max_par = marlin_24::max_par;

int groupsize = -1;
if (b_scales.size(0) > 1) {
2 changes: 1 addition & 1 deletion tests/kernels/test_marlin_gemm.py
@@ -27,7 +27,7 @@
MARLIN_N_CHUNKS = [64, 128, 256]

MARLIN_24_K_CHUNKS = [128]
MARLIN_24_N_CHUNKS = [256]
MARLIN_24_N_CHUNKS = [512]

MNK_FACTORS = [
(1, 1, 1),
8 changes: 4 additions & 4 deletions vllm/model_executor/layers/quantization/gptq_marlin_24.py
@@ -15,7 +15,7 @@
GPTQ_MARLIN_24_TILE = 16
GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 16
GPTQ_MARLIN_24_MAX_PARALLEL = 64

GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
@@ -53,14 +53,14 @@ def __init__(
self.tile_size = 16

# Min out_features dim
self.min_n_threads = 128
self.min_n_threads = GPTQ_MARLIN_24_MIN_THREAD_N

# Min in_features dim
self.min_k_threads = 128
self.min_k_threads = GPTQ_MARLIN_24_MIN_THREAD_K

# Max parallel problems to solve at once (improves large
# batch performance)
self.max_parallel = 16
self.max_parallel = GPTQ_MARLIN_24_MAX_PARALLEL

# Permutation length used by the marlin kernels.
self.perm_len = 1024
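
Raising GPTQ_MARLIN_24_MAX_PARALLEL to 64 also grows the lock workspace each layer allocates. Assuming MarlinWorkspace sizes its scratch buffer as (out_features // min_thread_n) * max_parallel ints, which is an assumption about marlin_utils and not shown in this diff, the arithmetic works out as follows:

GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64   # was 16

def workspace_ints(out_features: int,
                   min_thread_n: int = GPTQ_MARLIN_24_MIN_THREAD_N,
                   max_parallel: int = GPTQ_MARLIN_24_MAX_PARALLEL) -> int:
    # Assumed sizing: one int lock slot per thread-N tile per parallel problem.
    return (out_features // min_thread_n) * max_parallel

# e.g. out_features = 4096: 2048 ints (8 KiB) instead of the previous 512 (2 KiB).
print(workspace_ints(4096))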