[Kernel][LoRA]Punica prefill kernels fusion #11234

Merged: 70 commits, merged Jan 7, 2025
Changes from 64 commits

Commits
ec3590d
Init
jeejeelee Dec 10, 2024
9474fb0
Sync main
jeejeelee Dec 10, 2024
8c2ac4c
Fix bug
jeejeelee Dec 10, 2024
2897d05
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 11, 2024
35aebea
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 11, 2024
628a567
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 11, 2024
d04121c
Back up
jeejeelee Dec 11, 2024
a306f42
shrink_sgmv Done
jeejeelee Dec 11, 2024
f6bccc7
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 11, 2024
e5cb72e
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 12, 2024
b6013db
Optimize ptr compute
jeejeelee Dec 12, 2024
7f088ec
Merge commit 'b6013db4' into punica-kernel-fusion
jeejeelee Dec 13, 2024
32c5279
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 13, 2024
8d3742b
Increase the tile size
jeejeelee Dec 13, 2024
9564b33
Clean up triton interface
jeejeelee Dec 13, 2024
3eb3ac3
Sync main
jeejeelee Dec 16, 2024
4012466
Backup
jeejeelee Dec 16, 2024
18bbadf
Optimize one sclice kernel
jeejeelee Dec 16, 2024
43aae70
Delete unused code
jeejeelee Dec 16, 2024
482de15
Refactor expand
jeejeelee Dec 16, 2024
259d382
format
jeejeelee Dec 16, 2024
00f1904
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 17, 2024
a0197e3
Optimize logic
jeejeelee Dec 17, 2024
38ba4f1
Add comments
jeejeelee Dec 17, 2024
3c37226
Fix bug
jeejeelee Dec 17, 2024
45180c1
Fix expand bug
jeejeelee Dec 17, 2024
2e52d2c
Backup
jeejeelee Dec 17, 2024
2146141
revert expand tile size
jeejeelee Dec 17, 2024
d724891
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 18, 2024
9719617
Clean up code
jeejeelee Dec 18, 2024
5d2c557
Optimize expand tile size
jeejeelee Dec 18, 2024
958500d
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 19, 2024
5c88ec4
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 19, 2024
3460308
improve expand (#3)
Abatom Dec 19, 2024
24e893c
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 20, 2024
c9747c6
Lora expand (#4)
Abatom Dec 20, 2024
f3ecfc6
Lora expand (#5)
Abatom Dec 20, 2024
5859da7
Fix K size
jeejeelee Dec 20, 2024
b3ea6fc
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 21, 2024
eb01089
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 23, 2024
ebc9519
revert (#6)
Abatom Dec 24, 2024
a4f46b6
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 24, 2024
2cdf459
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 24, 2024
ba2c444
Add unit test
jeejeelee Dec 24, 2024
394886d
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 24, 2024
36fbeac
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 25, 2024
0f7897b
Optimize unit test
jeejeelee Dec 25, 2024
3edb696
Optimize unit test
jeejeelee Dec 25, 2024
49c6c21
Fix comment
jeejeelee Dec 25, 2024
bf3b9ca
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 26, 2024
fe24a41
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 26, 2024
9d89f47
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 27, 2024
489eca1
Optimize code
jeejeelee Dec 28, 2024
04ae0dd
Add lock for unit test
jeejeelee Dec 28, 2024
fa489f2
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 30, 2024
ea19a7d
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 30, 2024
65d0f2f
Optimize arg
jeejeelee Dec 30, 2024
797ae77
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 30, 2024
2b9f928
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 31, 2024
09fb9a9
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Dec 31, 2024
f446454
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Jan 1, 2025
767b233
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Jan 2, 2025
421382e
Fix expand bug
jeejeelee Jan 2, 2025
90a9117
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Jan 3, 2025
2c79295
Reduce memory
jeejeelee Jan 3, 2025
7e8d3bd
Modify minicpmv test
jeejeelee Jan 4, 2025
02b1d80
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Jan 4, 2025
bd8cc45
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Jan 5, 2025
7ffd15e
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Jan 6, 2025
c1c5b4b
Merge branch 'vllm-project:main' into punica-kernel-fusion
jeejeelee Jan 6, 2025
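The updated tests/lora/test_punica_sizes.py below exercises the fused prefill (SGMV) kernels: sgmv_shrink and sgmv_expand now take a list of per-slice LoRA weights and process every slice in a single launch, so the earlier per-slice sgmv_expand_slice loop disappears and the remaining nslices test only covers the BGMV path. The snippet that follows is a minimal sketch of the fused shrink call, not the PR's own code: it assumes a CUDA device, a vLLM checkout where the test helper generate_data_for_nslices (tests/lora/utils.py) is importable, and illustrative sizes rather than the test's full parameter grid.

```python
import torch

from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from tests.lora.utils import generate_data_for_nslices  # vLLM test helper

device, dtype = "cuda:0", torch.float16
batches, num_loras, rank, hidden_size = 4, 4, 16, 128  # illustrative sizes
seq_length, nslices, scaling = 128, 3, 0.5
torch.set_default_device(device)

# The helper returns one input tensor, a list holding one LoRA-A weight per
# slice, pre-allocated output buffers, and the per-sequence metadata the
# kernel consumes.
(inputs, lora_a_lst, out, _ref_out, b_seq_start_loc, lora_indices,
 seq_lens, _token_lora_indices) = generate_data_for_nslices(
     batches, hidden_size, num_loras, rank, seq_length, nslices, dtype,
     "shrink", device)

max_seq_length = seq_lens.max().item()
token_nums = seq_lens.sum().item()

# One fused launch covers all nslices slices; before this PR each slice
# required its own shrink/expand kernel call.
sgmv_shrink(inputs, lora_a_lst, out, b_seq_start_loc, seq_lens, lora_indices,
            batches, max_seq_length, token_nums, scaling)
```

sgmv_expand follows the same pattern with the list of LoRA-B slice weights plus the offset_start and add_inputs arguments, as the updated test shows.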
166 changes: 80 additions & 86 deletions tests/lora/test_punica_sizes.py
@@ -4,19 +4,22 @@
whether the corresponding Triton kernel can run normally when tensor parallelism
is set to [1, 2, 4, 8, 16, 32, 64].
"""
from threading import Lock

import pytest
import torch

from vllm.lora.ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.lora.ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.platforms import current_platform

from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm)
from .utils import (assert_close, generate_data,
generate_data_for_expand_nslices,
generate_data_for_nslices, ref_torch_groupgemm)

HIDDEN_SIZES = [
128,
@@ -112,21 +115,15 @@
SEED = [0]
CUDA_DEVICES = [f"cuda:{0}"]


def assert_close(a, b):
rtol, atol = {
torch.float16: (6e-2, 6e-2),
torch.bfloat16: (6e-2, 6e-2),
torch.float32: (1e-2, 1e-2),
}[a.dtype]
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
_dict_lock = Lock()


@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@@ -137,6 +134,7 @@ def test_punica_sgmv(
rank: int,
hidden_size: int,
scaling: float,
nslices: int,
dtype: torch.dtype,
op_type: str,
seed: int,
@@ -148,19 +146,20 @@
seq_length = 128
(
inputs_tensor,
lora_weights,
lora_weights_lst,
our_out_tensor,
ref_out_tensor,
b_seq_start_loc,
lora_indices_tensor,
seq_len_tensor,
indices,
) = generate_data(
) = generate_data_for_nslices(
batches,
hidden_size,
num_loras,
rank,
seq_length,
nslices,
dtype,
op_type,
device,
@@ -172,43 +171,64 @@
else:
max_seq_length = max_seq_length.item()
if op_type == "shrink":
sgmv_shrink(
inputs_tensor,
lora_weights,
our_out_tensor,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
scaling,
)
# Preventing cache error pointer.
with _dict_lock:
_LORA_A_PTR_DICT.clear()
sgmv_shrink(
inputs_tensor,
lora_weights_lst,
our_out_tensor,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
scaling,
)
for index in range(nslices):
ref_torch_groupgemm(
ref_out_tensor[index],
inputs_tensor,
lora_weights_lst[index],
lora_indices_tensor,
seq_len_tensor,
batches,
scaling,
op_type,
)
else:
sgmv_expand(
inputs_tensor,
lora_weights,
our_out_tensor,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
add_inputs=True,
)
ref_torch_groupgemm(
ref_out_tensor,
inputs_tensor,
lora_weights,
lora_indices_tensor,
seq_len_tensor,
batches,
scaling if op_type == "shrink" else 1.0,
op_type,
)
if op_type == "shrink":
ref_out_tensor = ref_out_tensor.to(torch.float32)
with _dict_lock:
_LORA_B_PTR_DICT.clear()
sgmv_expand(
inputs_tensor,
lora_weights_lst,
our_out_tensor,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
offset_start=0,
add_inputs=True,
)

slice_offset = 0
for index in range(nslices):
lora_weights = lora_weights_lst[index]
ref_torch_groupgemm(
ref_out_tensor[:, slice_offset:slice_offset + hidden_size],
inputs_tensor[index],
lora_weights,
lora_indices_tensor,
seq_len_tensor,
batches,
1.0,
op_type,
)
slice_offset += hidden_size

assert_close(our_out_tensor, ref_out_tensor)


@@ -292,25 +312,22 @@ def test_punica_bgmv(
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_punica_expand_nslices(
def test_punica_bgmv_expand_nslices(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
nslices: int,
dtype: torch.dtype,
op_type: str,
seed: int,
device: str,
):

torch.set_default_device(device)
current_platform.seed_everything(seed)

seq_length = 128 if op_type == "sgmv" else 1
seq_length = 1
(
inputs_tensor,
lora_weights_lst,
@@ -330,41 +347,18 @@ def test_punica_expand_nslices(
nslices,
device,
)
max_seq_length = seq_len_tensor.max()
token_nums = seq_len_tensor.sum().item()
if isinstance(max_seq_length, tuple):
max_seq_length = max_seq_length[0].item()
else:
max_seq_length = max_seq_length.item()
slice_offset = 0
for index in range(nslices):
lora_weights = lora_weights_lst[index]
if op_type == "sgmv":
sgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
b_seq_start_loc,
seq_len_tensor,
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
slice_offset,
hidden_size,
add_inputs=True,
)
else:

bgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)
bgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)
ref_torch_groupgemm(
ref_outputs[:, slice_offset:slice_offset + hidden_size],
inputs_tensor,