From 8bddb735123204872788a8ffe117321de7550e6c Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Sun, 12 Jan 2025 13:01:52 +0000 Subject: [PATCH] [Hardware][CPU] Multi-LoRA implementation for the CPU backend (#11100) Signed-off-by: Akshat Tripathi Signed-off-by: Oleg Mosalov Signed-off-by: Jee Jee Li Co-authored-by: Oleg Mosalov Co-authored-by: Jee Jee Li Co-authored-by: Isotr0py <2037008807@qq.com> --- .buildkite/run-cpu-test.sh | 6 + docs/source/features/compatibility_matrix.md | 2 +- tests/lora/conftest.py | 32 +- tests/lora/test_layers.py | 39 +- tests/lora/test_lora_manager.py | 21 +- tests/lora/test_mixtral.py | 4 +- ...nica_sizes.py => test_punica_ops_sizes.py} | 124 ++++--- ...iation.py => test_punica_ops_variation.py} | 134 ++++--- tests/lora/test_quant_model.py | 3 +- tests/lora/utils.py | 27 -- vllm/executor/cpu_executor.py | 3 - vllm/lora/ops/torch_ops/__init__.py | 13 + vllm/lora/ops/torch_ops/lora_ops.py | 113 ++++++ vllm/lora/ops/triton_ops/__init__.py | 13 + vllm/lora/ops/{ => triton_ops}/bgmv_expand.py | 0 .../ops/{ => triton_ops}/bgmv_expand_slice.py | 0 vllm/lora/ops/{ => triton_ops}/bgmv_shrink.py | 0 vllm/lora/ops/{ => triton_ops}/sgmv_expand.py | 0 vllm/lora/ops/{ => triton_ops}/sgmv_shrink.py | 0 vllm/lora/ops/{ => triton_ops}/utils.py | 0 vllm/lora/punica_wrapper/punica_cpu.py | 346 ++++++++++++++++++ vllm/lora/punica_wrapper/punica_gpu.py | 10 +- vllm/lora/punica_wrapper/punica_selector.py | 5 + vllm/worker/cpu_model_runner.py | 133 ++++++- vllm/worker/cpu_worker.py | 20 +- 25 files changed, 855 insertions(+), 193 deletions(-) rename tests/lora/{test_punica_sizes.py => test_punica_ops_sizes.py} (77%) rename tests/lora/{test_punica_variation.py => test_punica_ops_variation.py} (74%) create mode 100644 vllm/lora/ops/torch_ops/__init__.py create mode 100644 vllm/lora/ops/torch_ops/lora_ops.py create mode 100644 vllm/lora/ops/triton_ops/__init__.py rename vllm/lora/ops/{ => triton_ops}/bgmv_expand.py (100%) rename vllm/lora/ops/{ => triton_ops}/bgmv_expand_slice.py (100%) rename vllm/lora/ops/{ => triton_ops}/bgmv_shrink.py (100%) rename vllm/lora/ops/{ => triton_ops}/sgmv_expand.py (100%) rename vllm/lora/ops/{ => triton_ops}/sgmv_shrink.py (100%) rename vllm/lora/ops/{ => triton_ops}/utils.py (100%) create mode 100644 vllm/lora/punica_wrapper/punica_cpu.py diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index 4ae66f6f3215a..9925db7bea593 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -75,6 +75,12 @@ function cpu_tests() { --num-prompts 20 \ --endpoint /v1/completions \ --tokenizer facebook/opt-125m" + + # Run multi-lora tests + docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " + set -e + pytest -s -v \ + tests/lora/test_qwen2vl.py" } # All of CPU tests are expected to be finished less than 25 mins. 
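[Editor's note] The new CI step above only runs the existing LoRA test file on the CPU image; from the user side, the feature this patch enables is ordinary multi-LoRA inference on a CPU build of vLLM. The sketch below is not part of the patch: it assumes a CPU build of vLLM with this change applied, and the base model name and adapter path are placeholders, not values used by the test.

```python
# Minimal sketch: offline LoRA inference on the vLLM CPU backend.
# Assumes a CPU build of vLLM that includes this patch; model and adapter
# paths are placeholders.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="Qwen/Qwen2-0.5B-Instruct",  # placeholder base model
    enable_lora=True,                  # now accepted by the CPU backend
    max_loras=2,
    max_lora_rank=8,
)

outputs = llm.generate(
    ["Give me a short self introduction."],
    SamplingParams(temperature=0.0, max_tokens=32),
    # Adapter id 1 loaded from a local path; "my-adapter" is a placeholder name.
    lora_request=LoRARequest("my-adapter", 1, "/path/to/lora_adapter"),
)
print(outputs[0].outputs[0].text)
```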
diff --git a/docs/source/features/compatibility_matrix.md b/docs/source/features/compatibility_matrix.md index 8d8f7dca2e5b5..ea1d545ff3d73 100644 --- a/docs/source/features/compatibility_matrix.md +++ b/docs/source/features/compatibility_matrix.md @@ -359,7 +359,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar - ✅ - ✅ - ✅ - - [✗](gh-pr:4830) + - ✅ - ✅ * - prmpt adptr - ✅ diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 57ebaa424fc59..e7378d00765f0 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -21,6 +21,7 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader import get_model +from vllm.platforms import current_platform class ContextIDInfo(TypedDict): @@ -65,13 +66,16 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool): @pytest.fixture def dist_init(): temp_file = tempfile.mkstemp()[1] - init_distributed_environment( - world_size=1, - rank=0, - distributed_init_method=f"file://{temp_file}", - local_rank=0, - backend="nccl", - ) + + backend = "nccl" + if current_platform.is_cpu(): + backend = "gloo" + + init_distributed_environment(world_size=1, + rank=0, + distributed_init_method=f"file://{temp_file}", + local_rank=0, + backend=backend) initialize_model_parallel(1, 1) yield cleanup_dist_env_and_memory(shutdown_ray=True) @@ -81,13 +85,15 @@ def dist_init(): def dist_init_torch_only(): if torch.distributed.is_initialized(): return + backend = "nccl" + if current_platform.is_cpu(): + backend = "gloo" + temp_file = tempfile.mkstemp()[1] - torch.distributed.init_process_group( - backend="nccl", - world_size=1, - rank=0, - init_method=f"file://{temp_file}", - ) + torch.distributed.init_process_group(world_size=1, + rank=0, + init_method=f"file://{temp_file}", + backend=backend) @pytest.fixture diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index fb8c0b2a7ba26..08a589d7ee29c 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -48,10 +48,14 @@ torch.float32: (5e-3, 5e-3), torch.bfloat16: (3e-2, 2e-2), } -# TODO: Modify this based on platform -DEVICES = [ + +pytestmark = pytest.mark.skipif( + not (current_platform.is_cuda_alike() or current_platform.is_cpu()), + reason="Backend not supported") + +DEVICES = ([ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +] if current_platform.is_cuda_alike() else ["cpu"]) #For GPU, we will launch different triton kernels between the prefill and decode # stages, so we need to verify this. prefill stage(True) or decode stage(False) @@ -198,6 +202,10 @@ def check_punica_wrapper(punica_wrapper) -> bool: from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU return type(punica_wrapper) is PunicaWrapperGPU + elif current_platform.is_cpu(): + from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU + + return type(punica_wrapper) is PunicaWrapperCPU else: return False @@ -211,7 +219,8 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: # For multi-GPU testing of Triton kernel, we must explicitly set the CUDA # device, see: https://github.com/triton-lang/triton/issues/2925 # Same below. 
- torch.cuda.set_device(device) + if current_platform.is_cuda_alike(): + torch.cuda.set_device(device) torch.set_default_device(device) max_loras = 8 @@ -313,7 +322,9 @@ def create_random_embedding_layer(): def test_embeddings_with_new_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: - torch.cuda.set_device(device) + if current_platform.is_cuda_alike(): + torch.cuda.set_device(device) + torch.set_default_device(device) max_loras = 8 punica_wrapper = get_punica_wrapper(8192, 256, device) @@ -450,7 +461,9 @@ def create_random_embedding_layer(): def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, stage) -> None: - torch.cuda.set_device(device) + if current_platform.is_cuda_alike(): + torch.cuda.set_device(device) + torch.set_default_device(device) max_loras = 8 punica_wrapper = get_punica_wrapper(8192, 256, device) @@ -582,7 +595,9 @@ def _pretest(): def test_linear_replicated(dist_init, num_loras, device, stage, bias_enabled) -> None: - torch.cuda.set_device(device) + if current_platform.is_cuda_alike(): + torch.cuda.set_device(device) + torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device) assert check_punica_wrapper(punica_wrapper) @@ -695,7 +710,9 @@ def create_random_linear_replicated_layer(): def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, device, stage, bias_enabled) -> None: - torch.cuda.set_device(device) + if current_platform.is_cuda_alike(): + torch.cuda.set_device(device) + torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device) assert check_punica_wrapper(punica_wrapper) @@ -818,7 +835,9 @@ def create_random_linear_parallel_layer(): def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, device, stage, bias_enabled) -> None: - torch.cuda.set_device(device) + if current_platform.is_cuda_alike(): + torch.cuda.set_device(device) + torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device) assert check_punica_wrapper(punica_wrapper) @@ -971,6 +990,8 @@ class FakeConfig: @pytest.mark.parametrize("rotary_dim", [None, 32]) @pytest.mark.parametrize("head_size", [32, 108]) @pytest.mark.parametrize("seq_len", [11, 1024]) +@pytest.mark.skipif(not current_platform.is_cuda_alike(), + reason="Only CUDA backends are supported") def test_rotary_embedding_long_context(dist_init, num_loras, device, scaling_factors, max_position, is_neox_style, rotary_dim, head_size, diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index a099f36b0a465..ca523c66abe42 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -20,6 +20,7 @@ from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager, WorkerLoRAManager) from vllm.model_executor.layers.linear import RowParallelLinear +from vllm.platforms import current_platform EMBEDDING_MODULES = { "embed_tokens": "input_embeddings", @@ -28,9 +29,9 @@ EMBEDDING_PADDING_MODULES = ["lm_head"] -CUDA_DEVICES = [ +DEVICES = ([ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +] if current_platform.is_cuda_alike() else ["cpu"]) def test_peft_helper(sql_lora_files): @@ -83,7 +84,7 @@ def test_peft_helper(sql_lora_files): PEFTHelper.from_dict(config) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_from_lora_tensors(sql_lora_files, device): tensors = load_file( os.path.join(sql_lora_files, "adapter_model.safetensors")) @@ -171,7 +172,7 @@ def 
test_replace_submodules(dist_init, dummy_model): manager = LoRAModelManager( model, 1, 1, 1, LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8), - torch.device("cuda")) + torch.device(DEVICES[0])) model = manager.model assert isinstance(model.get_submodule("dense1"), @@ -183,7 +184,7 @@ def test_replace_submodules(dist_init, dummy_model): RowParallelLinearWithLoRA) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_lora_model_manager(dist_init, dummy_model, device): model = dummy_model model.supported_lora_modules = ["dense1", "dense2", "lm_head"] @@ -244,7 +245,7 @@ def test_lora_model_manager(dist_init, dummy_model, device): assert manager.punica_wrapper.device == device -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_lora_lru_cache_model_manager(dist_init, dummy_model, device): model = dummy_model model.supported_lora_modules = ["dense1", "dense2", "lm_head"] @@ -336,7 +337,7 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device): assert manager.device == device -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_lru_lora_model_manager(dist_init, dummy_model, device): # This tests just the LRU cache functionality, everything else is # tested in test_lora_model_manager @@ -466,7 +467,7 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device): assert manager.device == device -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, sql_lora_files, device): lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) @@ -545,7 +546,7 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, device) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, sql_lora_files, device): # Should remove every LoRA not specified in the request. 
@@ -621,7 +622,7 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, device) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_packed_loras(dist_init, dummy_model_gate_up, device): model = dummy_model_gate_up model.supported_lora_modules = ["gate_up_proj"] diff --git a/tests/lora/test_mixtral.py b/tests/lora/test_mixtral.py index 797a495201d33..940a865228806 100644 --- a/tests/lora/test_mixtral.py +++ b/tests/lora/test_mixtral.py @@ -5,6 +5,7 @@ import vllm from vllm.lora.request import LoRARequest +from vllm.platforms import current_platform MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1" @@ -31,7 +32,8 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int, @pytest.mark.parametrize("tp_size", [4]) def test_mixtral_lora(mixtral_lora_files, tp_size): """Original test, the LoRA model has the common target modules, not all""" - if torch.cuda.device_count() < tp_size: + if torch.cuda.device_count( + ) < tp_size and tp_size > 1 and current_platform.is_cuda_alike(): pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") prompts = [ diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_ops_sizes.py similarity index 77% rename from tests/lora/test_punica_sizes.py rename to tests/lora/test_punica_ops_sizes.py index 0351fedd1cfa5..433ca7577d084 100644 --- a/tests/lora/test_punica_sizes.py +++ b/tests/lora/test_punica_ops_sizes.py @@ -9,17 +9,16 @@ import pytest import torch -from vllm.lora.ops.bgmv_expand import bgmv_expand -from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice -from vllm.lora.ops.bgmv_shrink import bgmv_shrink -from vllm.lora.ops.sgmv_expand import sgmv_expand -from vllm.lora.ops.sgmv_shrink import sgmv_shrink -from vllm.lora.ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT +import vllm.lora.ops.triton_ops # noqa: F401 +from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink, sgmv_expand, + sgmv_expand_slice, sgmv_shrink) +from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT from vllm.platforms import current_platform from .utils import (assert_close, generate_data, generate_data_for_expand_nslices, - generate_data_for_nslices, ref_torch_groupgemm) + generate_data_for_nslices) HIDDEN_SIZES = [ 128, @@ -113,7 +112,7 @@ MAX_RANKS = [32] SCALES = [0.5] SEED = [0] -CUDA_DEVICES = [f"cuda:{0}"] +DEVICES = [f"cuda:{0}"] _dict_lock = Lock() @@ -127,7 +126,7 @@ @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("op_type", ["shrink", "expand"]) @pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_punica_sgmv( batches: int, num_loras: int, @@ -174,7 +173,7 @@ def test_punica_sgmv( # Preventing cache error pointer. 
with _dict_lock: _LORA_A_PTR_DICT.clear() - sgmv_shrink( + torch.ops.vllm.sgmv_shrink( inputs_tensor, lora_weights_lst, our_out_tensor, @@ -187,20 +186,23 @@ def test_punica_sgmv( scaling, ) for index in range(nslices): - ref_torch_groupgemm( - ref_out_tensor[index], + sgmv_shrink( inputs_tensor, lora_weights_lst[index], - lora_indices_tensor, + ref_out_tensor[index], + b_seq_start_loc, seq_len_tensor, + lora_indices_tensor, batches, + max_seq_length, + token_nums, scaling, - op_type, ) + else: with _dict_lock: _LORA_B_PTR_DICT.clear() - sgmv_expand( + torch.ops.vllm.sgmv_expand( inputs_tensor, lora_weights_lst, our_out_tensor, @@ -213,21 +215,39 @@ def test_punica_sgmv( offset_start=0, add_inputs=True, ) - - slice_offset = 0 - for index in range(nslices): - lora_weights = lora_weights_lst[index] - ref_torch_groupgemm( - ref_out_tensor[:, slice_offset:slice_offset + hidden_size], - inputs_tensor[index], - lora_weights, - lora_indices_tensor, + if nslices == 1: + # Verify the torch's sgmv_expand op + sgmv_expand( + inputs_tensor[0], + lora_weights_lst[0], + ref_out_tensor, + b_seq_start_loc, seq_len_tensor, + lora_indices_tensor, batches, - 1.0, - op_type, + max_seq_length, + token_nums, + add_inputs=True, ) - slice_offset += hidden_size + else: + slice_offset = 0 + for index in range(nslices): + lora_weights = lora_weights_lst[index] + sgmv_expand_slice( + inputs_tensor[index], + lora_weights, + ref_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + slice_offset, + hidden_size, + add_inputs=True, + ) + slice_offset += hidden_size assert_close(our_out_tensor, ref_out_tensor) @@ -240,7 +260,7 @@ def test_punica_sgmv( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("op_type", ["shrink", "expand"]) @pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_punica_bgmv( batches: int, num_loras: int, @@ -276,31 +296,38 @@ def test_punica_bgmv( device, ) if op_type == "shrink": - bgmv_shrink( + torch.ops.vllm.bgmv_shrink( inputs_tensor, lora_weights, our_out_tensor, indices, scaling, ) + + bgmv_shrink( + inputs_tensor, + lora_weights, + ref_out_tensor, + indices, + scaling, + ) + else: - bgmv_expand( + torch.ops.vllm.bgmv_expand( inputs_tensor, lora_weights, our_out_tensor, indices, add_inputs=True, ) - ref_torch_groupgemm( - ref_out_tensor, - inputs_tensor, - lora_weights, - lora_indices_tensor, - seq_len_tensor, - batches, - scaling if op_type == "shrink" else 1.0, - op_type, - ) + bgmv_expand( + inputs_tensor, + lora_weights, + ref_out_tensor, + indices, + add_inputs=True, + ) + if op_type == "shrink": ref_out_tensor = ref_out_tensor.to(torch.float32) assert_close(our_out_tensor, ref_out_tensor) @@ -313,7 +340,7 @@ def test_punica_bgmv( @pytest.mark.parametrize("nslices", [2, 3]) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_punica_bgmv_expand_nslices( batches: int, num_loras: int, @@ -350,7 +377,7 @@ def test_punica_bgmv_expand_nslices( slice_offset = 0 for index in range(nslices): lora_weights = lora_weights_lst[index] - bgmv_expand_slice( + torch.ops.vllm.bgmv_expand_slice( inputs_tensor, lora_weights, our_outputs, @@ -359,15 +386,14 @@ def test_punica_bgmv_expand_nslices( slice_size=hidden_size, add_inputs=True, ) - ref_torch_groupgemm( - ref_outputs[:, slice_offset:slice_offset + 
hidden_size], + bgmv_expand_slice( inputs_tensor, lora_weights, - lora_indices_tensor, - seq_len_tensor, - batches, - 1.0, - op_type="expand", + ref_outputs, + indices, + slice_offset, + slice_size=hidden_size, + add_inputs=True, ) slice_offset += hidden_size diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_ops_variation.py similarity index 74% rename from tests/lora/test_punica_variation.py rename to tests/lora/test_punica_ops_variation.py index 9ee10e7c23ee6..2bb84c1cf11e9 100644 --- a/tests/lora/test_punica_variation.py +++ b/tests/lora/test_punica_ops_variation.py @@ -9,19 +9,18 @@ import torch # Enable custom op register -import vllm.lora.ops.bgmv_expand -import vllm.lora.ops.bgmv_expand_slice -import vllm.lora.ops.bgmv_shrink -import vllm.lora.ops.sgmv_expand -import vllm.lora.ops.sgmv_shrink # noqa: F401 -from vllm.lora.ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT +import vllm.lora.ops.triton_ops # noqa: F401 +from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink, sgmv_expand, + sgmv_expand_slice, sgmv_shrink) +from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT from vllm.platforms import current_platform from .utils import (assert_close, generate_data, generate_data_for_expand_nslices, - generate_data_for_nslices, ref_torch_groupgemm) + generate_data_for_nslices) -HIDDEN_SIZES = [4097] +HIDDEN_SIZES = [2049] BATCHES = [1, 4, 16, 32] NUM_LORA = [1, 8, 32, 128] @@ -29,15 +28,7 @@ MAX_RANKS = [1, 4, 8, 16, 32, 64, 128, 256] SCALES = [0.5] SEED = [0] -CUDA_DEVICES = [f"cuda:{0}"] - -# Unlike test_punica_sizes.py, we directly utilize custom op for -# testing, which verifies the correct registration of these ops. -bgmv_expand = torch.ops.vllm.bgmv_expand -bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice -bgmv_shrink = torch.ops.vllm.bgmv_shrink -sgmv_expand = torch.ops.vllm.sgmv_expand -sgmv_shrink = torch.ops.vllm.sgmv_shrink +DEVICES = [f"cuda:{0}"] _dict_lock = Lock() @@ -51,7 +42,7 @@ @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("op_type", ["shrink", "expand"]) @pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_punica_sgmv( batches: int, num_loras: int, @@ -98,7 +89,7 @@ def test_punica_sgmv( # Preventing cache error pointer. 
with _dict_lock: _LORA_A_PTR_DICT.clear() - sgmv_shrink( + torch.ops.vllm.sgmv_shrink( inputs_tensor, lora_weights_lst, our_out_tensor, @@ -111,20 +102,23 @@ def test_punica_sgmv( scaling, ) for index in range(nslices): - ref_torch_groupgemm( - ref_out_tensor[index], + sgmv_shrink( inputs_tensor, lora_weights_lst[index], - lora_indices_tensor, + ref_out_tensor[index], + b_seq_start_loc, seq_len_tensor, + lora_indices_tensor, batches, + max_seq_length, + token_nums, scaling, - op_type, ) + else: with _dict_lock: _LORA_B_PTR_DICT.clear() - sgmv_expand( + torch.ops.vllm.sgmv_expand( inputs_tensor, lora_weights_lst, our_out_tensor, @@ -137,21 +131,39 @@ def test_punica_sgmv( offset_start=0, add_inputs=True, ) - slice_offset = 0 - for index in range(nslices): - lora_weights = lora_weights_lst[index] - ref_torch_groupgemm( - ref_out_tensor[:, slice_offset:slice_offset + hidden_size], - inputs_tensor[index], - lora_weights, - lora_indices_tensor, + if nslices == 1: + # Verify the torch's sgmv_expand op + sgmv_expand( + inputs_tensor[0], + lora_weights_lst[0], + ref_out_tensor, + b_seq_start_loc, seq_len_tensor, + lora_indices_tensor, batches, - 1.0, - op_type, + max_seq_length, + token_nums, + add_inputs=True, ) - slice_offset += hidden_size + else: + for index in range(nslices): + lora_weights = lora_weights_lst[index] + sgmv_expand_slice( + inputs_tensor[index], + lora_weights, + ref_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + slice_offset, + hidden_size, + add_inputs=True, + ) + slice_offset += hidden_size assert_close(our_out_tensor, ref_out_tensor) @@ -164,7 +176,7 @@ def test_punica_sgmv( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("op_type", ["shrink", "expand"]) @pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_punica_bgmv( batches: int, num_loras: int, @@ -176,7 +188,6 @@ def test_punica_bgmv( seed: int, device: str, ): - torch.set_default_device(device) current_platform.seed_everything(seed) @@ -201,32 +212,38 @@ def test_punica_bgmv( device, ) if op_type == "shrink": - bgmv_shrink( + torch.ops.vllm.bgmv_shrink( inputs_tensor, lora_weights, our_out_tensor, indices, scaling, ) - else: - bgmv_expand( + bgmv_shrink( + inputs_tensor, + lora_weights, + ref_out_tensor, + indices, + scaling, + ) + + else: + torch.ops.vllm.bgmv_expand( inputs_tensor, lora_weights, our_out_tensor, indices, add_inputs=True, ) - ref_torch_groupgemm( - ref_out_tensor, - inputs_tensor, - lora_weights, - lora_indices_tensor, - seq_len_tensor, - batches, - scaling if op_type == "shrink" else 1.0, - op_type, - ) + bgmv_expand( + inputs_tensor, + lora_weights, + ref_out_tensor, + indices, + add_inputs=True, + ) + if op_type == "shrink": ref_out_tensor = ref_out_tensor.to(torch.float32) assert_close(our_out_tensor, ref_out_tensor) @@ -239,7 +256,7 @@ def test_punica_bgmv( @pytest.mark.parametrize("nslices", [2, 3]) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_punica_bgmv_expand_nslices( batches: int, num_loras: int, @@ -276,7 +293,7 @@ def test_punica_bgmv_expand_nslices( slice_offset = 0 for index in range(nslices): lora_weights = lora_weights_lst[index] - bgmv_expand_slice( + torch.ops.vllm.bgmv_expand_slice( inputs_tensor, lora_weights, our_outputs, @@ -285,15 +302,14 @@ def 
test_punica_bgmv_expand_nslices( slice_size=hidden_size, add_inputs=True, ) - ref_torch_groupgemm( - ref_outputs[:, slice_offset:slice_offset + hidden_size], + bgmv_expand_slice( inputs_tensor, lora_weights, - lora_indices_tensor, - seq_len_tensor, - batches, - 1.0, - op_type="expand", + ref_outputs, + indices, + slice_offset, + slice_size=hidden_size, + add_inputs=True, ) slice_offset += hidden_size diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py index 026269667b473..26bf770cc0d4a 100644 --- a/tests/lora/test_quant_model.py +++ b/tests/lora/test_quant_model.py @@ -72,7 +72,8 @@ def format_prompt_tuples(prompt): @pytest.mark.parametrize("tp_size", [1]) def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model, tp_size): - if num_gpus_available < tp_size: + if num_gpus_available < tp_size and \ + tp_size > 1 and current_platform.is_cuda_alike(): pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") llm = vllm.LLM( diff --git a/tests/lora/utils.py b/tests/lora/utils.py index b66d18074a7bf..ce47546f2154b 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -104,33 +104,6 @@ def assert_close(a, b): torch.testing.assert_close(a, b, rtol=rtol, atol=atol) -def ref_torch_groupgemm( - out_tensor, - inputs, - lora_weights, - lora_indices_tensor, - seq_len_tensor, - batches, - scaling, - op_type, -) -> torch.Tensor: - out_list = [] - current_offset = 0 - for lora_index, b_length in zip(range(batches), seq_len_tensor): - input_weight = inputs[current_offset:b_length + current_offset, :] - current_offset += b_length - lora_weight = lora_weights[lora_indices_tensor[lora_index]] - result = torch.nn.functional.linear(input_weight, lora_weight) - result *= scaling - out_list.append(result) - cat_result = torch.cat(out_list, dim=0) - if op_type == "expand": - out_tensor += cat_result - else: - out_tensor.copy_(cat_result) - return - - def generate_data( batches, hidden_size, diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index c7f018d9a203e..b9a6bee5720fd 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -22,9 +22,6 @@ class CPUExecutor(ExecutorBase): def _init_executor(self) -> None: assert self.device_config.device_type == "cpu" - # Reminder: Please update docs/source/features/compatibility_matrix.md - # If the feature combo become valid - assert self.lora_config is None, "cpu backend doesn't support LoRA" # # Environment variables for CPU executor diff --git a/vllm/lora/ops/torch_ops/__init__.py b/vllm/lora/ops/torch_ops/__init__.py new file mode 100644 index 0000000000000..9c9159b95f308 --- /dev/null +++ b/vllm/lora/ops/torch_ops/__init__.py @@ -0,0 +1,13 @@ +from vllm.lora.ops.torch_ops.lora_ops import bgmv_expand # noqa: F401 +from vllm.lora.ops.torch_ops.lora_ops import (bgmv_expand_slice, bgmv_shrink, + sgmv_expand, sgmv_expand_slice, + sgmv_shrink) + +__all__ = [ + "bgmv_expand", + "bgmv_expand_slice", + "bgmv_shrink", + "sgmv_expand", + "sgmv_expand_slice", + "sgmv_shrink", +] diff --git a/vllm/lora/ops/torch_ops/lora_ops.py b/vllm/lora/ops/torch_ops/lora_ops.py new file mode 100644 index 0000000000000..5f5aafd516159 --- /dev/null +++ b/vllm/lora/ops/torch_ops/lora_ops.py @@ -0,0 +1,113 @@ +import torch + + +def sgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + 
add_inputs: bool = False): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, + seq_len_tensor) + + bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, + add_inputs) + + +def bgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True): + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + limit = output_tensor.shape[0] + if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: + limit = 1 + + if add_inputs: + output_tensor[:, :outputs.shape[1]] += outputs[:limit, :] + else: + output_tensor[:, :outputs.shape[1]] = outputs[:limit, :] + + +def sgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + scaling: float, +): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, + seq_len_tensor) + + bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, + scaling) + + +def bgmv_shrink(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0): + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] + + +def sgmv_expand_slice(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, + seq_len_tensor) + + bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices, + slice_offset, slice_size, add_inputs) + + +def bgmv_expand_slice(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True): + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + inputs = inputs.to(dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + if add_inputs: + output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:] + else: + output_tensor[:, slice_offset:slice_offset + slice_size] = outputs[:] diff --git a/vllm/lora/ops/triton_ops/__init__.py b/vllm/lora/ops/triton_ops/__init__.py new file mode 100644 index 0000000000000..9805b6dd5038e --- /dev/null +++ b/vllm/lora/ops/triton_ops/__init__.py @@ -0,0 +1,13 @@ +from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand +from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice +from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink +from 
vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand +from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink # noqa: F401 + +__all__ = [ + "bgmv_expand", + "bgmv_expand_slice", + "bgmv_shrink", + "sgmv_expand", + "sgmv_shrink", +] diff --git a/vllm/lora/ops/bgmv_expand.py b/vllm/lora/ops/triton_ops/bgmv_expand.py similarity index 100% rename from vllm/lora/ops/bgmv_expand.py rename to vllm/lora/ops/triton_ops/bgmv_expand.py diff --git a/vllm/lora/ops/bgmv_expand_slice.py b/vllm/lora/ops/triton_ops/bgmv_expand_slice.py similarity index 100% rename from vllm/lora/ops/bgmv_expand_slice.py rename to vllm/lora/ops/triton_ops/bgmv_expand_slice.py diff --git a/vllm/lora/ops/bgmv_shrink.py b/vllm/lora/ops/triton_ops/bgmv_shrink.py similarity index 100% rename from vllm/lora/ops/bgmv_shrink.py rename to vllm/lora/ops/triton_ops/bgmv_shrink.py diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/triton_ops/sgmv_expand.py similarity index 100% rename from vllm/lora/ops/sgmv_expand.py rename to vllm/lora/ops/triton_ops/sgmv_expand.py diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/triton_ops/sgmv_shrink.py similarity index 100% rename from vllm/lora/ops/sgmv_shrink.py rename to vllm/lora/ops/triton_ops/sgmv_shrink.py diff --git a/vllm/lora/ops/utils.py b/vllm/lora/ops/triton_ops/utils.py similarity index 100% rename from vllm/lora/ops/utils.py rename to vllm/lora/ops/triton_ops/utils.py diff --git a/vllm/lora/punica_wrapper/punica_cpu.py b/vllm/lora/punica_wrapper/punica_cpu.py new file mode 100644 index 0000000000000..b9ae3e07492c0 --- /dev/null +++ b/vllm/lora/punica_wrapper/punica_cpu.py @@ -0,0 +1,346 @@ +from typing import Callable, Optional, Tuple, Union + +import torch + +from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink, sgmv_expand, + sgmv_expand_slice, sgmv_shrink) + +from .punica_base import PunicaWrapperBase + + +# The platforms that are compatible with the PyTorch-native implementation can +# inherit this class +class PunicaWrapperCPU(PunicaWrapperBase): + """ + PunicaWrapperCPU is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the pytorch punica ops. 
+ """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, + device) + + def _shrink_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_shrink( + x, + w_t_all, + y, + *self.prefill_metadata, + scale, + ) + + def _shrink_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + + def _expand_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_inputs: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand( + x, + w_t_all, + y, + *self.prefill_metadata, + add_inputs, + ) + + def _expand_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_inputs: bool, + ): + bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) + + def _expand_slice_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand_slice( + x, + w_t_all, + y, + *self.prefill_metadata, + y_offset, + y_slice_size, + add_inputs, + ) + + def _expand_slice_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool, + ): + bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, + y_slice_size, add_inputs) + + def _apply_expand( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool = True, + ): + """ + Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` + computation, which is suitable for the + GEMM of lora'b. + """ + + expand_slice_fun: Callable = (self._expand_slice_prefill + if self.is_prefill else + self._expand_slice_decode) + expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs) + + def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, scale: float): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + shrink_fun: Callable = (self._shrink_prefill + if self.is_prefill else self._shrink_decode) + shrink_fun(y, x, w_t_all, scale) + y = y.view_as(y_org) + + def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, **kwargs): + """ + Performs GEMM for multiple slices of lora_a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. 
+ + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + """ + + x = x.view(-1, x.shape[-1]) + # TODO fuse these kernels + for slice_idx in range(len(lora_a_stacked)): + self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], + scale) + + def add_expand(self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + add_inputs (bool): Defaults to True. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + offset_left = offset_start + if lora_bias_stacked is not None: + self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + for slice_idx in range(len(lora_b_stacked)): + self._apply_expand( + y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_left, + output_slices[slice_idx], + add_inputs=add_inputs, + ) + offset_left += output_slices[slice_idx] + y = y.view_as(y_org) + + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_inputs (bool): Default to True. + """ + + # Embedding layer only need expand op + expand_fun: Callable = (self._expand_prefill + if self.is_prefill else self._expand_decode) + expand_fun(y, x, lora_b_stacked, add_inputs) + + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. 
+ """ + + assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + y = self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + + if buffer is None: + r = lora_b_stacked[0].size(-1) + # We set the buffer to be float32 by default, consistent with the + # triton op + buffer = tuple( + torch.zeros( + (x.size(0), r), dtype=torch.float32, device=x.device) + for _ in range(len(output_slices))) + self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) + self.add_expand(y, + buffer, + lora_b_stacked, + None, + output_slices, + add_inputs=True, + **kwargs) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = lora_b_stacked.size(-1) + if buffer is None: + # We set the buffer to be float32 by default, consistent with the + # triton op + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + # LogitsProcessorWithLoRA always using bgmv. + bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) + bgmv_expand(buffer, + lora_b_stacked, + y, + self.sampler_indices, + add_inputs=True) + y = y.view_as(y_org) diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 278f7b5a8e9f4..451f23e49f27c 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -12,11 +12,11 @@ from vllm.triton_utils import HAS_TRITON if HAS_TRITON: - from vllm.lora.ops.bgmv_expand import bgmv_expand - from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice - from vllm.lora.ops.bgmv_shrink import bgmv_shrink - from vllm.lora.ops.sgmv_expand import sgmv_expand - from vllm.lora.ops.sgmv_shrink import sgmv_shrink + from vllm.lora.ops.triton_ops import bgmv_expand + from vllm.lora.ops.triton_ops import bgmv_expand_slice + from vllm.lora.ops.triton_ops import bgmv_shrink + from vllm.lora.ops.triton_ops import sgmv_expand + from vllm.lora.ops.triton_ops import sgmv_shrink from .punica_base import PunicaWrapperBase diff --git a/vllm/lora/punica_wrapper/punica_selector.py b/vllm/lora/punica_wrapper/punica_selector.py index 9791d492d8e48..9f1606e672dea 100644 --- a/vllm/lora/punica_wrapper/punica_selector.py +++ b/vllm/lora/punica_wrapper/punica_selector.py @@ -12,6 +12,11 @@ def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase: from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU logger.info_once("Using PunicaWrapperGPU.") return PunicaWrapperGPU(*args, **kwargs) + elif current_platform.is_cpu(): + # Lazy import to avoid ImportError + from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU + logger.info_once("Using PunicaWrapperCPU.") + return PunicaWrapperCPU(*args, **kwargs) elif current_platform.is_hpu(): # Lazy import to avoid ImportError from vllm.lora.punica_wrapper.punica_hpu import PunicaWrapperHPU diff --git 
a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index d99db4e0c6c40..303d9a15e9c3c 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -2,8 +2,8 @@ import weakref from collections import defaultdict from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, - Union) +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Type, + TypeVar, Union) import torch from torch import nn @@ -12,10 +12,14 @@ from vllm.config import VllmConfig from vllm.forward_context import set_forward_context from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model +from vllm.model_executor.models import supports_lora, supports_multimodal from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalKwargs, MultiModalPlaceholderMap) from vllm.sequence import (IntermediateTensors, SequenceData, @@ -49,6 +53,8 @@ class ModelInputForCPU(ModelRunnerInputBase): virtual_engine: Optional[int] = None seq_lens: Optional[List[int]] = None query_lens: Optional[List[int]] = None + lora_mapping: Optional["LoRAMapping"] = None + lora_requests: Optional[Set[LoRARequest]] = None def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: @@ -57,6 +63,8 @@ def as_broadcastable_tensor_dict( "input_positions": self.input_positions, "token_type_ids": self.token_type_ids, "multi_modal_kwargs": self.multi_modal_kwargs, + "lora_requests": self.lora_requests, + "lora_mapping": self.lora_mapping, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) @@ -143,7 +151,11 @@ def __init__(self, or runner.cache_config.enable_prefix_caching) self.model_input_cls = self.runner._model_input_cls self.attn_backend = self.runner.attn_backend + self.sliding_window = self.runner.sliding_window + self.block_size = self.runner.block_size + self.device = self.runner.device self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper + self.enable_lora = self.runner.lora_config is not None self.input_data = ModelInputForCPUBuilder.ModelInputData( self.runner.model_config.uses_mrope) self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()( @@ -183,15 +195,28 @@ def build(self) -> ModelInputForCPU: attn_metadata = self.att_metadata_builder.build( input_data.seq_lens, input_data.query_lens, -1, -1) - return self.model_input_cls( - input_tokens=input_tokens, - input_positions=input_positions, - token_type_ids=token_type_ids, - seq_lens=input_data.seq_lens, - query_lens=input_data.query_lens, - attn_metadata=attn_metadata, - multi_modal_kwargs=multi_modal_kwargs, - ) + is_prompt = (self.seq_group_metadata_list[0].is_prompt + if self.seq_group_metadata_list else None) + # LoRA data. 
+ lora_requests = set() + lora_mapping = None + if self.enable_lora: + lora_requests = set(seq.lora_request + for seq in self.seq_group_metadata_list + if seq.lora_request is not None) + + lora_mapping = self._prepare_lora_input( + self.seq_group_metadata_list, is_prompt) + + return self.model_input_cls(input_tokens=input_tokens, + input_positions=input_positions, + token_type_ids=token_type_ids, + seq_lens=input_data.seq_lens, + query_lens=input_data.query_lens, + attn_metadata=attn_metadata, + multi_modal_kwargs=multi_modal_kwargs, + lora_mapping=lora_mapping, + lora_requests=lora_requests) def _build_input_data(self): for seq_group_metadata in self.seq_group_metadata_list: @@ -381,6 +406,24 @@ def _compute_multi_modal_input(self, self.input_data.multi_modal_placeholder_maps[modality].extend( placeholder_map) + def _prepare_lora_input( + self, seq_group_metadata_list: List[SequenceGroupMetadata], + is_prefill: bool) -> LoRAMapping: + index_mapping = [] + prompt_mapping = [] + for seq in seq_group_metadata_list: + lora_id = seq.lora_int_id + query_len = seq.token_chunk_size + + index_mapping += [lora_id] * query_len + prompt_mapping += [lora_id] * ( + query_len if seq.sampling_params + and seq.sampling_params.prompt_logprobs is not None else 1) + + return LoRAMapping(index_mapping=tuple(index_mapping), + prompt_mapping=tuple(prompt_mapping), + is_prefill=is_prefill) + class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): """ @@ -431,10 +474,41 @@ def __init__( # Lazy initialization. self.model: nn.Module # Set after init_Model + # Set after load_model. + self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None def load_model(self) -> None: self.model = get_model(vllm_config=self.vllm_config) + if self.lora_config: + assert supports_lora( + self.model + ), f"{self.model.__class__.__name__} does not support LoRA yet." + + if supports_multimodal(self.model): + logger.warning("Regarding multimodal models, vLLM currently " + "only supports adding LoRA to language model.") + + # It's necessary to distinguish between the max_position_embeddings + # of VLMs and LLMs. 
+ if hasattr(self.model.config, "max_position_embeddings"): + max_pos_embeddings = self.model.config.max_position_embeddings + else: + max_pos_embeddings = ( + self.model.config.text_config.max_position_embeddings) + + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.vocab_size, + self.lora_config, + self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules, + max_position_embeddings=max_pos_embeddings, + ) + self.model = self.lora_manager.create_lora_manager(self.model) + def _prepare_model_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -459,6 +533,37 @@ def sampler(self): def vocab_size(self) -> int: return self.model_config.get_vocab_size() + def remove_all_loras(self): + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.remove_all_adapters() + + def set_active_loras(self, lora_requests: Set[LoRARequest], + lora_mapping: LoRAMapping) -> None: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.set_active_adapters(lora_requests, lora_mapping) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.add_adapter(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.remove_adapter(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.pin_adapter(lora_id) + + def list_loras(self) -> Set[int]: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.list_adapters() + class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]): _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = ( @@ -515,6 +620,12 @@ def execute_model( raise ValueError( "CPU worker does not support multi-step execution.") + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + model_executable = self.model multimodal_kwargs = {} diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 494c6506f3c0f..3e5fcf11b9e16 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -1,5 +1,5 @@ """A CPU worker class.""" -from typing import Dict, List, Optional, Tuple, Type +from typing import Dict, List, Optional, Set, Tuple, Type import torch import torch.distributed @@ -11,14 +11,14 @@ from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger +from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner -from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, - LoraNotSupportedWorkerBase, WorkerBase, +from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, WorkerInput) logger = init_logger(__name__) @@ 
-111,7 +111,7 @@ def get_cache_block_size( return dtype_size * total -class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): +class CPUWorker(LocalOrDistributedWorkerBase): """A worker class that executes (a partition of) the model on a CPU socket. Each worker is associated with a single CPU socket. The worker is @@ -266,6 +266,18 @@ def initialize_cache(self, num_gpu_blocks: int, # Initialize the cache. self._init_cache_engine() + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_runner.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.model_runner.remove_lora(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + return self.model_runner.pin_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.model_runner.list_loras() + def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None: """Raise errors if the num_cpu_blocks is invalid. """
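[Editor's note] The docstrings in the new `punica_cpu.py` and the ops in `vllm/lora/ops/torch_ops/lora_ops.py` describe the shrink/expand semantics in prose: the stacked LoRA weights are gathered per token by a LoRA-index tensor, and the `sgmv_*` variants simply expand per-sequence indices to per-token indices with `torch.repeat_interleave` before taking the `bgmv_*` path. The following standalone sketch illustrates those semantics with plain PyTorch and made-up shapes; it is a simplified illustration (3-D stacked weights, no extra squeeze of a 4-D layout), not code from the patch.

```python
# Standalone sketch of the semantics behind the torch-native LoRA ops.
import torch

num_loras, rank, hidden = 3, 4, 8
token_lora_idx = torch.tensor([0, 0, 2, 1])       # one LoRA id per token
x = torch.randn(4, hidden)                        # token activations
lora_a = torch.randn(num_loras, rank, hidden)     # stacked lora_a weights

# bgmv_shrink-style computation: y = scaling * (x @ lora_a[idx].T) per token.
scaling = 0.5
selected = lora_a[token_lora_idx]                 # (tokens, rank, hidden)
shrunk = scaling * torch.einsum("bi,boi->bo", x, selected)

# Reference: loop over tokens and apply each token's own adapter.
ref = torch.stack([
    scaling * torch.nn.functional.linear(x[i], lora_a[token_lora_idx[i]])
    for i in range(x.size(0))
])
torch.testing.assert_close(shrunk, ref)

# sgmv-style indexing: per-sequence LoRA ids expanded to per-token ids.
seq_lora_idx = torch.tensor([0, 2])               # two sequences
seq_lens = torch.tensor([3, 1])                   # 3 + 1 = 4 tokens
expanded = torch.repeat_interleave(seq_lora_idx, seq_lens)
assert expanded.tolist() == [0, 0, 0, 2]
```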