[Hardware][CPU] Multi-LoRA implementation for the CPU backend (#11100)
Signed-off-by: Akshat Tripathi <[email protected]>
Signed-off-by: Oleg Mosalov <[email protected]>
Signed-off-by: Jee Jee Li <[email protected]>
Co-authored-by: Oleg Mosalov <[email protected]>
Co-authored-by: Jee Jee Li <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
4 people authored Jan 12, 2025
1 parent f967e51 commit 8bddb73
Showing 25 changed files with 855 additions and 193 deletions.
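
To make the scope of the change concrete, the sketch below shows multi-LoRA inference through vLLM's public `LLM` API. It is a minimal sketch, not part of this commit: it assumes a CPU build of vLLM (e.g. built with `VLLM_TARGET_DEVICE=cpu`), and the adapter name and path are placeholders.

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Qwen2-VL matches the model exercised by the new CPU CI test below.
llm = LLM(model="Qwen/Qwen2-VL-2B-Instruct", enable_lora=True, max_loras=2)

sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

# Each request can carry its own adapter; requests using different
# adapters may be batched together in one step.
outputs = llm.generate(
    ["Describe tensor parallelism in one sentence."],
    sampling_params,
    lora_request=LoRARequest("my-adapter", 1, "/path/to/lora_adapter"),
)
print(outputs[0].outputs[0].text)
```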
6 changes: 6 additions & 0 deletions .buildkite/run-cpu-test.sh
@@ -75,6 +75,12 @@ function cpu_tests() {
--num-prompts 20 \
--endpoint /v1/completions \
--tokenizer facebook/opt-125m"

# Run multi-lora tests
docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c "
set -e
pytest -s -v \
tests/lora/test_qwen2vl.py"
}

# All of CPU tests are expected to be finished less than 25 mins.
2 changes: 1 addition & 1 deletion docs/source/features/compatibility_matrix.md
@@ -359,7 +359,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardware
- ✅
- ✅
- ✅
- [✗](gh-pr:4830)
- ✅
- ✅
* - <abbr title="Prompt Adapter">prmpt adptr</abbr>
- ✅
32 changes: 19 additions & 13 deletions tests/lora/conftest.py
@@ -21,6 +21,7 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model
from vllm.platforms import current_platform


class ContextIDInfo(TypedDict):
@@ -65,13 +66,16 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
@pytest.fixture
def dist_init():
temp_file = tempfile.mkstemp()[1]
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method=f"file://{temp_file}",
local_rank=0,
backend="nccl",
)

backend = "nccl"
if current_platform.is_cpu():
backend = "gloo"

init_distributed_environment(world_size=1,
rank=0,
distributed_init_method=f"file://{temp_file}",
local_rank=0,
backend=backend)
initialize_model_parallel(1, 1)
yield
cleanup_dist_env_and_memory(shutdown_ray=True)
@@ -81,13 +85,15 @@ def dist_init():
def dist_init_torch_only():
if torch.distributed.is_initialized():
return
backend = "nccl"
if current_platform.is_cpu():
backend = "gloo"

temp_file = tempfile.mkstemp()[1]
torch.distributed.init_process_group(
backend="nccl",
world_size=1,
rank=0,
init_method=f"file://{temp_file}",
)
torch.distributed.init_process_group(world_size=1,
rank=0,
init_method=f"file://{temp_file}",
backend=backend)


@pytest.fixture
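The fixture changes above reduce to one rule: NCCL requires CUDA devices, so CPU-only hosts must fall back to gloo. Here is a standalone sketch of that pattern using plain `torch.distributed` (the helper name is made up for illustration):

```python
import tempfile

import torch.distributed as dist


def init_single_rank_group(use_cpu: bool) -> None:
    # NCCL only runs on CUDA devices; gloo also works on CPU-only
    # hosts, which is why the fixtures switch backends per platform.
    backend = "gloo" if use_cpu else "nccl"
    init_file = tempfile.mkstemp()[1]
    dist.init_process_group(backend=backend,
                            init_method=f"file://{init_file}",
                            world_size=1,
                            rank=0)
```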
39 changes: 30 additions & 9 deletions tests/lora/test_layers.py
@@ -48,10 +48,14 @@
torch.float32: (5e-3, 5e-3),
torch.bfloat16: (3e-2, 2e-2),
}
# TODO: Modify this based on platform
DEVICES = [

pytestmark = pytest.mark.skipif(
not (current_platform.is_cuda_alike() or current_platform.is_cpu()),
reason="Backend not supported")

DEVICES = ([
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
] if current_platform.is_cuda_alike() else ["cpu"])

#For GPU, we will launch different triton kernels between the prefill and decode
# stages, so we need to verify this. prefill stage(True) or decode stage(False)
@@ -198,6 +202,10 @@ def check_punica_wrapper(punica_wrapper) -> bool:
from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU

return type(punica_wrapper) is PunicaWrapperGPU
elif current_platform.is_cpu():
from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU

return type(punica_wrapper) is PunicaWrapperCPU
else:
return False

@@ -211,7 +219,8 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
# For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
# device, see: https://github.com/triton-lang/triton/issues/2925
# Same below.
torch.cuda.set_device(device)
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)

torch.set_default_device(device)
max_loras = 8
@@ -313,7 +322,9 @@ def create_random_embedding_layer():
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
vocab_size, stage) -> None:

torch.cuda.set_device(device)
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)

torch.set_default_device(device)
max_loras = 8
punica_wrapper = get_punica_wrapper(8192, 256, device)
@@ -450,7 +461,9 @@ def create_random_embedding_layer():
def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
stage) -> None:

torch.cuda.set_device(device)
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)

torch.set_default_device(device)
max_loras = 8
punica_wrapper = get_punica_wrapper(8192, 256, device)
@@ -582,7 +595,9 @@ def _pretest():
def test_linear_replicated(dist_init, num_loras, device, stage,
bias_enabled) -> None:

torch.cuda.set_device(device)
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)

torch.set_default_device(device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
@@ -695,7 +710,9 @@ def create_random_linear_replicated_layer():
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
device, stage, bias_enabled) -> None:

torch.cuda.set_device(device)
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)

torch.set_default_device(device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
@@ -818,7 +835,9 @@ def create_random_linear_parallel_layer():
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
device, stage, bias_enabled) -> None:

torch.cuda.set_device(device)
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)

torch.set_default_device(device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
@@ -971,6 +990,8 @@ class FakeConfig:
@pytest.mark.parametrize("rotary_dim", [None, 32])
@pytest.mark.parametrize("head_size", [32, 108])
@pytest.mark.parametrize("seq_len", [11, 1024])
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only CUDA backends are supported")
def test_rotary_embedding_long_context(dist_init, num_loras, device,
scaling_factors, max_position,
is_neox_style, rotary_dim, head_size,
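Every parametrized test in this file now applies the same guard before pinning a device. A condensed sketch of that shared pattern (the helper name is hypothetical):

```python
import torch

from vllm.platforms import current_platform


def set_test_device(device: str) -> None:
    # Pinning the CUDA device is needed for multi-GPU Triton kernels
    # (see triton-lang/triton#2925); on CPU there is nothing to pin.
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)
    torch.set_default_device(device)
```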
21 changes: 11 additions & 10 deletions tests/lora/test_lora_manager.py
@@ -20,6 +20,7 @@
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
WorkerLoRAManager)
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.platforms import current_platform

EMBEDDING_MODULES = {
"embed_tokens": "input_embeddings",
@@ -28,9 +29,9 @@

EMBEDDING_PADDING_MODULES = ["lm_head"]

CUDA_DEVICES = [
DEVICES = ([
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
] if current_platform.is_cuda_alike() else ["cpu"])


def test_peft_helper(sql_lora_files):
@@ -83,7 +84,7 @@ def test_peft_helper(sql_lora_files):
PEFTHelper.from_dict(config)


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
tensors = load_file(
os.path.join(sql_lora_files, "adapter_model.safetensors"))
@@ -171,7 +172,7 @@ def test_replace_submodules(dist_init, dummy_model):
manager = LoRAModelManager(
model, 1, 1, 1,
LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
torch.device("cuda"))
torch.device(DEVICES[0]))
model = manager.model

assert isinstance(model.get_submodule("dense1"),
@@ -183,7 +184,7 @@
RowParallelLinearWithLoRA)


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device):
model = dummy_model
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
@@ -244,7 +245,7 @@ def test_lora_model_manager(dist_init, dummy_model, device):
assert manager.punica_wrapper.device == device


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
model = dummy_model
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
@@ -336,7 +337,7 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
assert manager.device == device


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_lru_lora_model_manager(dist_init, dummy_model, device):
# This tests just the LRU cache functionality, everything else is
# tested in test_lora_model_manager
@@ -466,7 +467,7 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
assert manager.device == device


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files, device):
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
@@ -545,7 +546,7 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
device)


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files, device):
# Should remove every LoRA not specified in the request.
@@ -621,7 +622,7 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
device)


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_packed_loras(dist_init, dummy_model_gate_up, device):
model = dummy_model_gate_up
model.supported_lora_modules = ["gate_up_proj"]
4 changes: 3 additions & 1 deletion tests/lora/test_mixtral.py
@@ -5,6 +5,7 @@

import vllm
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform

MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1"

@@ -31,7 +32,8 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int,
@pytest.mark.parametrize("tp_size", [4])
def test_mixtral_lora(mixtral_lora_files, tp_size):
"""Original test, the LoRA model has the common target modules, not all"""
if torch.cuda.device_count() < tp_size:
if torch.cuda.device_count(
) < tp_size and tp_size > 1 and current_platform.is_cuda_alike():
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

prompts = [
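The reworked skip condition in `test_mixtral_lora` packs three checks into one expression. Factored out, the intent is: skip only when a CUDA-alike platform needs more GPUs than it has; on the CPU backend the test always runs. A hypothetical refactoring sketch:

```python
import torch

from vllm.platforms import current_platform


def needs_more_gpus(tp_size: int) -> bool:
    # Tensor parallelism > 1 requires physical GPUs on CUDA-alike
    # platforms; the CPU backend runs with tp_size == 1 in CI.
    return (current_platform.is_cuda_alike() and tp_size > 1
            and torch.cuda.device_count() < tp_size)


# Usage inside the test:
#     if needs_more_gpus(tp_size):
#         pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
```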
(Diffs for the remaining 19 changed files are not shown.)