forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Hardware][Intel-Gaudi] Enable LoRA support for Intel Gaudi (HPU) (vl…
…lm-project#10565) Signed-off-by: Sanju C Sudhakaran <[email protected]>
- Loading branch information
1 parent
f092153
commit 8195824
Showing
5 changed files
with
107 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from typing import Optional, Tuple, Union, final | ||
|
||
import torch | ||
from vllm_hpu_extension.ops import (dispatch_bgmv_embedding, | ||
dispatch_bgmv_linear) | ||
|
||
from .punica_base import PunicaWrapperBase | ||
|
||
|
||
@final | ||
class PunicaWrapperHPU(PunicaWrapperBase): | ||
|
||
def __init__(self, max_num_batched_tokens: int, max_batches: int, | ||
device: Union[torch.device, str], **kwargs): | ||
# Increasing max_num_batched_tokens by 3x to handle increase in | ||
# tensor size due to padding. | ||
PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens, | ||
max_batches, device) | ||
|
||
def add_lora_embedding(self, | ||
y: torch.Tensor, | ||
x: torch.Tensor, | ||
lora_b_stacked: torch.Tensor, | ||
add_input: bool = True, | ||
**kwargs) -> None: | ||
dispatch_bgmv_embedding(y, x, lora_b_stacked, 0) | ||
|
||
def add_lora_linear(self, | ||
y: torch.Tensor, | ||
x: torch.Tensor, | ||
lora_a_stacked: Tuple[torch.Tensor, ...], | ||
lora_b_stacked: Tuple[torch.Tensor, ...], | ||
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], | ||
scale: float, | ||
output_slices: Tuple[int, ...], | ||
*, | ||
buffer: Optional[Tuple[torch.Tensor, ...]] = None, | ||
**kwargs) -> None: | ||
y_org = y | ||
x = x.view(-1, x.shape[-1]) | ||
y = y.view(-1, y.shape[-1]) | ||
offset_left = 0 | ||
|
||
for slice_idx in range(len(output_slices)): | ||
dispatch_bgmv_linear( | ||
y[:, offset_left:offset_left + output_slices[slice_idx]], x, | ||
lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], 0, scale) | ||
offset_left += output_slices[slice_idx] | ||
y = y.view_as(y_org) | ||
|
||
def add_lora_logits(self, | ||
y: torch.Tensor, | ||
x: torch.Tensor, | ||
lora_a_stacked: torch.Tensor, | ||
lora_b_stacked: torch.Tensor, | ||
scale, | ||
*, | ||
buffer: Optional[torch.Tensor] = None, | ||
**kwargs) -> None: | ||
y_org = y | ||
y = y.view(-1, y.shape[-1]) | ||
x = x.view(-1, x.shape[-1]) | ||
dispatch_bgmv_linear(y, x, lora_a_stacked, lora_b_stacked, 0, scale) | ||
y = y.view_as(y_org) | ||
|
||
def add_shrink( | ||
self, | ||
y: Union[Tuple[torch.Tensor, ...], torch.Tensor], | ||
x: torch.Tensor, | ||
lora_a_stacked: Tuple[torch.Tensor, ...], | ||
scale: float, | ||
**kwargs, | ||
) -> None: | ||
raise NotImplementedError | ||
|
||
def add_expand( | ||
self, | ||
y: torch.Tensor, | ||
x: Union[Tuple[torch.Tensor, ...], torch.Tensor], | ||
lora_b_stacked: Tuple[torch.Tensor, ...], | ||
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], | ||
output_slices: Tuple[int, ...], | ||
offset_start: int = 0, | ||
add_input=True, | ||
**kwargs, | ||
) -> None: | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters