Skip to content

Commit

Permalink
[Hardware][Intel-Gaudi] Enable LoRA support for Intel Gaudi (HPU) (vl…
Browse files Browse the repository at this point in the history
…lm-project#10565)

Signed-off-by: Sanju C Sudhakaran <[email protected]>
  • Loading branch information
SanjuCSudhakaran authored Dec 12, 2024
1 parent f092153 commit 8195824
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 14 deletions.
2 changes: 1 addition & 1 deletion requirements-hpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ pandas
tabulate
setuptools>=61
setuptools-scm>=8
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@e096d6f
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@4312768
6 changes: 6 additions & 0 deletions vllm/lora/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
LinearScalingRotaryEmbedding, RotaryEmbedding)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.platforms import current_platform

if TYPE_CHECKING:
from vllm.lora.punica_wrapper import PunicaWrapperBase
Expand Down Expand Up @@ -1068,6 +1069,11 @@ def _get_logits(
).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
posinf=float("inf"),
neginf=float("-inf")))

# HPU needs special handling to prune out dummy samples.
if current_platform.is_hpu():
lora_logits = lora_logits[:logits.shape[0], :]

logits[:,
self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
lora_logits.shape[1]] = lora_logits
Expand Down
87 changes: 87 additions & 0 deletions vllm/lora/punica_wrapper/punica_hpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from typing import Optional, Tuple, Union, final

import torch
from vllm_hpu_extension.ops import (dispatch_bgmv_embedding,
dispatch_bgmv_linear)

from .punica_base import PunicaWrapperBase


@final
class PunicaWrapperHPU(PunicaWrapperBase):

def __init__(self, max_num_batched_tokens: int, max_batches: int,
device: Union[torch.device, str], **kwargs):
# Increasing max_num_batched_tokens by 3x to handle increase in
# tensor size due to padding.
PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens,
max_batches, device)

def add_lora_embedding(self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_input: bool = True,
**kwargs) -> None:
dispatch_bgmv_embedding(y, x, lora_b_stacked, 0)

def add_lora_linear(self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
scale: float,
output_slices: Tuple[int, ...],
*,
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
**kwargs) -> None:
y_org = y
x = x.view(-1, x.shape[-1])
y = y.view(-1, y.shape[-1])
offset_left = 0

for slice_idx in range(len(output_slices)):
dispatch_bgmv_linear(
y[:, offset_left:offset_left + output_slices[slice_idx]], x,
lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], 0, scale)
offset_left += output_slices[slice_idx]
y = y.view_as(y_org)

def add_lora_logits(self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: Optional[torch.Tensor] = None,
**kwargs) -> None:
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
dispatch_bgmv_linear(y, x, lora_a_stacked, lora_b_stacked, 0, scale)
y = y.view_as(y_org)

def add_shrink(
self,
y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor, ...],
scale: float,
**kwargs,
) -> None:
raise NotImplementedError

def add_expand(
self,
y: torch.Tensor,
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
output_slices: Tuple[int, ...],
offset_start: int = 0,
add_input=True,
**kwargs,
) -> None:
raise NotImplementedError
5 changes: 5 additions & 0 deletions vllm/lora/punica_wrapper/punica_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,10 @@ def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
print_info_once("Using PunicaWrapperGPU.")
return PunicaWrapperGPU(*args, **kwargs)
elif current_platform.is_hpu():
# Lazy import to avoid ImportError
from vllm.lora.punica_wrapper.punica_hpu import PunicaWrapperHPU
print_info_once("Using PunicaWrapperHPU.")
return PunicaWrapperHPU(*args, **kwargs)
else:
raise NotImplementedError
21 changes: 8 additions & 13 deletions vllm/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,10 @@ def load_model(self) -> None:
assert hasattr(
self.model, "embedding_padding_modules"
), "Model does not have embedding_padding_modules"
assert not self.lora_config.bias_enabled, \
"Bias support in LoRA is not enabled in HPU yet."
assert not self.lora_config.fully_sharded_loras, \
"Fully sharded LoRAs is not enabled in HPU yet."
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens,
Expand Down Expand Up @@ -1282,11 +1286,9 @@ def create_dummy_seq_group_metadata(self,
def profile_run(self) -> None:
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
max_batch_size = self.bucketing_global_state.prompt_bs_bucket_cfg[-1]
max_seq_len = min(
self.bucketing_global_state.prompt_seq_bucket_cfg[-1],
self.max_num_batched_tokens // max_batch_size)

max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1]
max_batch_size = min(self.max_num_batched_tokens // max_seq_len,
self.scheduler_config.max_num_seqs)
self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
False, True)
return
Expand All @@ -1304,7 +1306,6 @@ def warmup_scenario(self,
f"bs{batch_size}_"
f"seq{seq_len}_"
f"graphs{'T' if use_graphs else 'F'}")
max_num_seqs = self.scheduler_config.max_num_seqs
# This represents the maximum number of different requests
# that will have unique loras, an therefore the max amount of memory
# consumption create dummy lora request copies from the lora request
Expand All @@ -1326,16 +1327,10 @@ def warmup_scenario(self,
dummy_lora_requests.append(dummy_lora_request)
dummy_lora_requests_per_seq = [
dummy_lora_requests[idx % len(dummy_lora_requests)]
for idx in range(max_num_seqs)
for idx in range(batch_size)
]
self.profiler.start('internal', scenario_name)
times = 3 if use_graphs or is_pt_profiler_run else 1
if self.lora_config and not is_lora_profile_run:
lora_mapping = LoRAMapping(
**dict(index_mapping=[0] * batch_size * seq_len,
prompt_mapping=[0] * batch_size * seq_len,
is_prefill=is_prompt))
self.set_active_loras(set(), lora_mapping)
if is_prompt:
seqs = [
self.create_dummy_seq_group_metadata(
Expand Down

0 comments on commit 8195824

Please sign in to comment.