[Hardware][CPU] Multi-LoRA implementation for the CPU backend #11100

Merged (33 commits) on Jan 12, 2025
Commits
c8a8c5b
Added Multi-LoRA implementation for the CPU backend
Akshat-Tripathi Dec 11, 2024
3545ac1
Readded final
Akshat-Tripathi Dec 11, 2024
2b3c650
Decoupled PunicaWrapperCPU from PunicaWrapperGPU
Akshat-Tripathi Dec 11, 2024
cc0bd6c
Fixed typing
Akshat-Tripathi Dec 11, 2024
dc9091a
Renamed lora op directories
Akshat-Tripathi Dec 12, 2024
410b746
Fixed tests
Akshat-Tripathi Dec 12, 2024
8eab6a6
Merge branch 'main' into multi_lora_cpu
Akshat-Tripathi Dec 12, 2024
5dfcf62
Removed one-to-one correspondence between triton and torch ops
Akshat-Tripathi Dec 13, 2024
651dc04
Fixed mypy error
Akshat-Tripathi Dec 13, 2024
d1026bb
Removed redundant optionals
Akshat-Tripathi Dec 13, 2024
2840445
Merge branch 'main' into multi_lora_cpu
mosalov Dec 17, 2024
41c518f
Renamed add_input to add_inputs in punica_cpu.py.
mosalov Dec 17, 2024
abb03e5
Merge branch 'main' into multi_lora_cpu
mosalov Dec 19, 2024
d212fc9
Optimize kernel test
jeejeelee Dec 24, 2024
dcb1794
Merge branch 'main' of https://github.com/vllm-project/vllm into mult…
jeejeelee Dec 30, 2024
3a684b8
Modify test name
jeejeelee Dec 30, 2024
c9b9f63
Merge branch 'main' of https://github.com/vllm-project/vllm into mult…
jeejeelee Dec 31, 2024
acdb4e3
Merge branch 'main' into multi_lora_cpu
Akshat-Tripathi Jan 7, 2025
cfa082f
Removed assert
Akshat-Tripathi Jan 8, 2025
8eaf7ec
Merge branch 'main' into multi_lora_cpu
Akshat-Tripathi Jan 8, 2025
bd54243
Merge branch 'main' of https://github.com/vllm-project/vllm into mult…
jeejeelee Jan 9, 2025
21c3799
Resolve test conflicts
jeejeelee Jan 9, 2025
aeaf078
Merge branch 'main' of https://github.com/vllm-project/vllm into mult…
jeejeelee Jan 9, 2025
61c42cf
Sync main
jeejeelee Jan 9, 2025
0d19f03
Make isort happy
jeejeelee Jan 9, 2025
5af9cbb
Make yapf happy
jeejeelee Jan 9, 2025
202aca3
run-cpu-test.sh now runs multi-lora tests
Akshat-Tripathi Jan 10, 2025
76444f4
Updated compatibility docs
Akshat-Tripathi Jan 10, 2025
64c700f
Update .buildkite/run-cpu-test.sh
Akshat-Tripathi Jan 10, 2025
06399bc
Lint
Akshat-Tripathi Jan 10, 2025
69eb3dc
Lint2
Akshat-Tripathi Jan 10, 2025
e208468
Merge branch 'main' of https://github.com/vllm-project/vllm into mult…
jeejeelee Jan 12, 2025
d8cb9af
Reduce test memory
jeejeelee Jan 12, 2025
6 changes: 6 additions & 0 deletions .buildkite/run-cpu-test.sh
@@ -75,6 +75,12 @@ function cpu_tests() {
--num-prompts 20 \
--endpoint /v1/completions \
--tokenizer facebook/opt-125m"

# Run multi-lora tests
docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c "
set -e
pytest -s -v \
tests/lora/test_qwen2vl.py"
}

# All of CPU tests are expected to be finished less than 25 mins.
2 changes: 1 addition & 1 deletion docs/source/features/compatibility_matrix.md
@@ -359,7 +359,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar
- ✅
- ✅
- ✅
- [✗](gh-pr:4830)
- ✅
- ✅
* - <abbr title="Prompt Adapter">prmpt adptr</abbr>
- ✅
32 changes: 19 additions & 13 deletions tests/lora/conftest.py
@@ -21,6 +21,7 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model
from vllm.platforms import current_platform


class ContextIDInfo(TypedDict):
@@ -65,13 +66,16 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
@pytest.fixture
def dist_init():
temp_file = tempfile.mkstemp()[1]
init_distributed_environment(
world_size=1,
rank=0,
distributed_init_method=f"file://{temp_file}",
local_rank=0,
backend="nccl",
)

backend = "nccl"
if current_platform.is_cpu():
backend = "gloo"

init_distributed_environment(world_size=1,
rank=0,
distributed_init_method=f"file://{temp_file}",
local_rank=0,
backend=backend)
initialize_model_parallel(1, 1)
yield
cleanup_dist_env_and_memory(shutdown_ray=True)
@@ -81,13 +85,15 @@ def dist_init():
def dist_init_torch_only():
if torch.distributed.is_initialized():
return
backend = "nccl"
if current_platform.is_cpu():
backend = "gloo"

temp_file = tempfile.mkstemp()[1]
torch.distributed.init_process_group(
backend="nccl",
world_size=1,
rank=0,
init_method=f"file://{temp_file}",
)
torch.distributed.init_process_group(world_size=1,
rank=0,
init_method=f"file://{temp_file}",
backend=backend)

Collaborator: @DarkLight1337 Is our CPU vendor willing to provide hardware testing for these? Who should we contact?

Collaborator: I remember there is no NUMA node restriction for CPU TP, so I think we can just change VLLM_CPU_OMP_THREADS_BIND=48-92 to something like VLLM_CPU_OMP_THREADS_BIND=48-70|71-92 to enable TP, even if these cores are on the same NUMA node.
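
To make the binding suggestion above concrete, here is a minimal sketch (not part of this PR) of running the CPU backend with tensor parallelism and one OpenMP core range per rank. The core ranges and the model name are illustrative placeholders and depend on the host topology.

```python
# Sketch only: CPU tensor parallelism with one OpenMP core range per TP rank.
# Core ranges (48-70, 71-92) and the model are placeholders; assumes a CPU
# build of vLLM.
import os

from vllm import LLM

# One core range per tensor-parallel rank, separated by "|".
os.environ["VLLM_CPU_OMP_THREADS_BIND"] = "48-70|71-92"

llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2)
print(llm.generate("Hello, my name is")[0].outputs[0].text)
```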

Collaborator: My bad, I meant to ask if the CPU LoRA testing should be placed here. If so, we might need to contact the CPU vendor?

Contributor: Yes, the target is to enable a test that verifies the LoRA code path on the CPU backend.
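
For reference, a minimal sketch of the code path such a test exercises: offline inference with LoRA enabled on a CPU build of vLLM. The base model and adapter path below are placeholders, not the fixtures used by the actual tests.

```python
# Sketch only: exercising the Multi-LoRA path on the CPU backend.
# Model name and adapter path are placeholders; assumes a CPU build of vLLM.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True, max_loras=2)

outputs = llm.generate(
    "Write a SQL query that lists all users.",
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=LoRARequest("sql_adapter", 1, "/path/to/sql_lora_adapter"),
)
print(outputs[0].outputs[0].text)
```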



@pytest.fixture
39 changes: 30 additions & 9 deletions tests/lora/test_layers.py
@@ -48,10 +48,14 @@
torch.float32: (5e-3, 5e-3),
torch.bfloat16: (3e-2, 2e-2),
}
# TODO: Modify this based on platform
DEVICES = [

pytestmark = pytest.mark.skipif(
not (current_platform.is_cuda_alike() or current_platform.is_cpu()),
reason="Backend not supported")

DEVICES = ([
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
] if current_platform.is_cuda_alike() else ["cpu"])

#For GPU, we will launch different triton kernels between the prefill and decode
# stages, so we need to verify this. prefill stage(True) or decode stage(False)
@@ -198,6 +202,10 @@ def check_punica_wrapper(punica_wrapper) -> bool:
from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU

return type(punica_wrapper) is PunicaWrapperGPU
elif current_platform.is_cpu():
from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU

return type(punica_wrapper) is PunicaWrapperCPU
else:
return False

@@ -211,7 +219,8 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
# For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
# device, see: https://github.com/triton-lang/triton/issues/2925
# Same below.
torch.cuda.set_device(device)
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)

torch.set_default_device(device)
max_loras = 8
@@ -313,7 +322,9 @@ def create_random_embedding_layer():
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
vocab_size, stage) -> None:

torch.cuda.set_device(device)
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)

torch.set_default_device(device)
max_loras = 8
punica_wrapper = get_punica_wrapper(8192, 256, device)
@@ -450,7 +461,9 @@ def create_random_embedding_layer():
def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
stage) -> None:

torch.cuda.set_device(device)
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)

torch.set_default_device(device)
max_loras = 8
punica_wrapper = get_punica_wrapper(8192, 256, device)
@@ -582,7 +595,9 @@ def _pretest():
def test_linear_replicated(dist_init, num_loras, device, stage,
bias_enabled) -> None:

torch.cuda.set_device(device)
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)

torch.set_default_device(device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
@@ -695,7 +710,9 @@ def create_random_linear_replicated_layer():
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
device, stage, bias_enabled) -> None:

torch.cuda.set_device(device)
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)

torch.set_default_device(device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
@@ -818,7 +835,9 @@ def create_random_linear_parallel_layer():
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
device, stage, bias_enabled) -> None:

torch.cuda.set_device(device)
if current_platform.is_cuda_alike():
torch.cuda.set_device(device)

torch.set_default_device(device)
punica_wrapper = get_punica_wrapper(8192, 256, device)
assert check_punica_wrapper(punica_wrapper)
@@ -971,6 +990,8 @@ class FakeConfig:
@pytest.mark.parametrize("rotary_dim", [None, 32])
@pytest.mark.parametrize("head_size", [32, 108])
@pytest.mark.parametrize("seq_len", [11, 1024])
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only CUDA backends are supported")
def test_rotary_embedding_long_context(dist_init, num_loras, device,
scaling_factors, max_position,
is_neox_style, rotary_dim, head_size,
21 changes: 11 additions & 10 deletions tests/lora/test_lora_manager.py
@@ -20,6 +20,7 @@
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
WorkerLoRAManager)
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.platforms import current_platform

EMBEDDING_MODULES = {
"embed_tokens": "input_embeddings",
@@ -28,9 +29,9 @@

EMBEDDING_PADDING_MODULES = ["lm_head"]

CUDA_DEVICES = [
DEVICES = ([
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
] if current_platform.is_cuda_alike() else ["cpu"])


def test_peft_helper(sql_lora_files):
@@ -83,7 +84,7 @@ def test_peft_helper(sql_lora_files):
PEFTHelper.from_dict(config)


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
tensors = load_file(
os.path.join(sql_lora_files, "adapter_model.safetensors"))
@@ -171,7 +172,7 @@ def test_replace_submodules(dist_init, dummy_model):
manager = LoRAModelManager(
model, 1, 1, 1,
LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
torch.device("cuda"))
torch.device(DEVICES[0]))
model = manager.model

assert isinstance(model.get_submodule("dense1"),
@@ -183,7 +184,7 @@
RowParallelLinearWithLoRA)


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device):
model = dummy_model
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
@@ -244,7 +245,7 @@ def test_lora_model_manager(dist_init, dummy_model, device):
assert manager.punica_wrapper.device == device


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
model = dummy_model
model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
@@ -336,7 +337,7 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
assert manager.device == device


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_lru_lora_model_manager(dist_init, dummy_model, device):
# This tests just the LRU cache functionality, everything else is
# tested in test_lora_model_manager
@@ -466,7 +467,7 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
assert manager.device == device


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files, device):
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
@@ -545,7 +546,7 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
device)


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files, device):
# Should remove every LoRA not specified in the request.
@@ -621,7 +622,7 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
device)


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("device", DEVICES)
def test_packed_loras(dist_init, dummy_model_gate_up, device):
model = dummy_model_gate_up
model.supported_lora_modules = ["gate_up_proj"]
4 changes: 3 additions & 1 deletion tests/lora/test_mixtral.py
@@ -5,6 +5,7 @@

import vllm
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform

MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1"

@@ -31,7 +32,8 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int,
@pytest.mark.parametrize("tp_size", [4])
def test_mixtral_lora(mixtral_lora_files, tp_size):
"""Original test, the LoRA model has the common target modules, not all"""
if torch.cuda.device_count() < tp_size:
if torch.cuda.device_count(
) < tp_size and tp_size > 1 and current_platform.is_cuda_alike():
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

prompts = [