From ec3590d9f70a62682a7a55accec8ba8e6b9b966e Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 10 Dec 2024 01:49:47 +0000 Subject: [PATCH 01/35] Init Signed-off-by: Jee Jee Li --- vllm/lora/ops/sgmv_expand_slice.py | 170 ++++++++++++++++------------- vllm/lora/punica.py | 57 +++++++--- 2 files changed, 138 insertions(+), 89 deletions(-) diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py index 55c4fb68ed128..f0d239e20678b 100644 --- a/vllm/lora/ops/sgmv_expand_slice.py +++ b/vllm/lora/ops/sgmv_expand_slice.py @@ -1,10 +1,12 @@ """ Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. https://arxiv.org/abs/2310.18547 """ +from typing import List, Tuple + import torch import triton import triton.language as tl @@ -22,14 +24,14 @@ def _sgmv_expand_slice_kernel( b_seq_start_loc, seq_lens, lora_indices, + slice_start_loc, xm_stride, xk_stride, # 1 - l0_stride, # hidden_size*max_rank - lora_k_stride, - lora_n_stride, + ls_d0_ptr, # lora stride(0) + ls_d1_ptr, + ls_d2_ptr, cm_stride, cn_stride, - slice_offset, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, @@ -39,14 +41,15 @@ def _sgmv_expand_slice_kernel( ): """ - Similar to the 'sgmv_expand' operator, but with an added parameter - 'slice_offset'. The reason for not reusing the 'sgmv_expand' operator - might be that in the future, we could implement a fusion operator to - achieve the current functionality instead of having to call it multiple + Similar to the 'sgmv_expand' operator, but with an added parameter + 'slice_offset'. The reason for not reusing the 'sgmv_expand' operator + might be that in the future, we could implement a fusion operator to + achieve the current functionality instead of having to call it multiple times. 
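
    Note: the launch grid built in _sgmv_expand_slice below is 3-D. Axis 0
    walks the (BLOCK_M, BLOCK_N) output tiles, axis 1 the batches, and
    axis 2 the LoRA slices, so one launch covers every slice of the fused
    output instead of one kernel call per slice.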
""" pid = tl.program_id(axis=0) cur_batch = tl.program_id(axis=1) + slice_id = tl.program_id(axis=2) cta_n_num = tl.cdiv(N, BLOCK_N) pid_m = pid // cta_n_num pid_n = pid % cta_n_num @@ -56,6 +59,7 @@ def _sgmv_expand_slice_kernel( lora_index = tl.load(lora_indices + cur_batch) if lora_index == -1: return + cur_seq_start = tl.load(b_seq_start_loc + cur_batch) offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N @@ -63,10 +67,26 @@ def _sgmv_expand_slice_kernel( ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) - a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + + # if CAST_TYPE: + # # fp32 + # cur_input_ptr=tl.load(input_ptr + slice_id).to( + # tl.pointer_type(tl.float32)) + # else: + # cur_input_ptr=tl.load(input_ptr + slice_id).to( + # out_ptr.dtype.element_ty) + # cur_input_ptr = tl.load(input_ptr + slice_id * xm_stride * xk_stride) + cur_lora_ptr = tl.load(lora_ptr + slice_id).to( + tl.pointer_type(out_ptr.dtype.element_ty)) + cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) + cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) + cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) + a_ptr = (input_ptr + slice_id * xm_stride * xk_stride + + cur_seq_start * xm_stride + ram[:, None] * xm_stride + offset_k[None, :] * xk_stride, ) - b_ptr = (lora_ptr + l0_stride * lora_index + - offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride) + # lora + b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + + offset_k[:, None] * cur_lora_d2_stride + + rbn[None, :] * cur_lora_d1_stride) accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(tl.cdiv(K, BLOCK_K)): if EVEN_K: @@ -80,26 +100,30 @@ def _sgmv_expand_slice_kernel( mask=offset_k[:, None] < K - k * BLOCK_K, other=0) if CAST_TYPE: - tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + tiled_a = tiled_a.to(cur_lora_ptr.dtype.element_ty) accumulator += tl.dot( tiled_a, tiled_b, ) a_ptr += BLOCK_K * xk_stride - b_ptr += BLOCK_K * lora_n_stride - tiled_c = accumulator.to(lora_ptr.dtype.element_ty) + b_ptr += BLOCK_K * cur_lora_d2_stride + + tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty) + # 获取每个slice的偏移地址 + if slice_start_loc is not None: + cur_slice_start = tl.load(slice_start_loc + slice_id) + else: + cur_slice_start = 0 offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M - offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_offset + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + offset_cn[None, :] * cn_stride) M = tl.load(seq_lens + cur_batch) - c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] < - (slice_offset + N)) + c_mask = (offset_cm[:, None] < + (cur_seq_start + M)) & (offset_cn[None, :] < + (cur_slice_start + N)) if ADD_INPUTS: - # explicitly pass in other=None to tell triton that masked values - # can be uninitialized. 
This is OK because the later tl.store operation - # uses the same mask, eliminating the risk of garbage values propagating - tiled_out = tl.load(c_ptr, mask=c_mask, other=None) + tiled_out = tl.load(c_ptr, mask=c_mask) tiled_c += tiled_out tl.store(c_ptr, tiled_c, mask=c_mask) @@ -107,7 +131,7 @@ def _sgmv_expand_slice_kernel( @torch.inference_mode() def _sgmv_expand_slice( inputs: torch.Tensor, - lora_b_weights: torch.Tensor, + lora_b_stacked: List[torch.Tensor], output_tensor: torch.Tensor, b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor, @@ -115,59 +139,53 @@ def _sgmv_expand_slice( batches: int, max_seq_length: int, token_nums: int, - slice_offset: int, - slice_size: int, add_inputs: bool = False, ) -> None: - """_summary_ - - Args: - inputs (torch.Tensor): input tensor - lora_b_weights (torch.Tensor): lora'a weight - output_tensor (torch.Tensor): output tensor - b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative - sequence lengths of the sequences in the batch, used to index - into sequence. E.g., if the sequence length is [4, 6], it is - [0, 4, 10]. - seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence - length of the sequences in the batch - lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index - corresponding to each batch. An index of -1 means no lora should be - applied. - batches (int): batch size - max_seq_length (int): The max sequence lengths of the sequences - in the batch - token_nums (int): The token numbers in the batch. Used to verify if the - token numbers in the inputs matches the one in the metadata. - slice_offset (int): output_tensor's offset - slice_size (int): current output_tensor's size - add_inputs (bool, optional): Defaults to False, adds the final lora - results to the output. 
- """ - assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] - assert lora_b_weights.dtype in [ + assert lora_b_stacked[0].dtype in [ torch.float16, torch.bfloat16, ] - assert inputs.size(0) == token_nums - assert inputs.size(1) == lora_b_weights.size(-1) + + assert inputs.size(1) == token_nums + assert inputs.size(0) == len(lora_b_stacked) + assert b_seq_start_loc.size(0) == batches assert lora_indices_tensor.size(0) == batches - assert slice_size == lora_b_weights.size(-2) - assert inputs.is_contiguous() assert output_tensor.is_contiguous() + slice_offset_lst = [] + tensor_ptrs = [] + lora_strides_d0 = [] + lora_strides_d1 = [] + lora_strides_d2 = [] + slice_offset = 0 + for lora_b_weight in lora_b_stacked: + if lora_b_weight.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weight.size(1) == 1 + lora_b_weight = lora_b_weight.squeeze(dim=1) + else: + assert lora_b_weight.ndim == 3 # shape:(lora_num,size,rank) + assert lora_b_weight.is_contiguous() + tensor_ptrs.append(lora_b_weight.data_ptr()) + lora_strides_d0.append(lora_b_weight.stride(0)) + lora_strides_d1.append(lora_b_weight.stride(1)) + lora_strides_d2.append(lora_b_weight.stride(2)) + slice_offset_lst.append(slice_offset) + slice_offset += lora_b_weight.size(1) - if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) - assert lora_b_weights.size(1) == 1 - lora_b_weights = lora_b_weights.squeeze(dim=1) - else: - assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) - - assert lora_b_weights.is_contiguous() + slice_start_tensor = torch.tensor(slice_offset_lst, + device=b_seq_start_loc.device) + # note these are device tensors + lora_ptr_tensor = torch.tensor(tensor_ptrs, device=b_seq_start_loc.device) + lora_strides_d0_tensor = torch.tensor(lora_strides_d0, + device=b_seq_start_loc.device) + lora_strides_d1_tensor = torch.tensor(lora_strides_d1, + device=b_seq_start_loc.device) + lora_strides_d2_tensor = torch.tensor(lora_strides_d2, + device=b_seq_start_loc.device) # TODO tuning this config - N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size + N, K = lora_b_stacked[0].shape[-2:] # K= rank,N=hidden_size BLOCK_M = 32 BLOCK_N = 32 @@ -175,7 +193,8 @@ def _sgmv_expand_slice( EVEN_K = K % BLOCK_K == 0 ADD_INPUTS = add_inputs CAST_TYPE = False - if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + + if inputs.dtype == torch.float32 and lora_b_stacked[0].dtype in [ torch.float16, torch.bfloat16, ]: @@ -183,24 +202,25 @@ def _sgmv_expand_slice( grid = ( triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), batches, + len(lora_ptr_tensor), ) _sgmv_expand_slice_kernel[grid]( inputs, - lora_b_weights, + lora_ptr_tensor, output_tensor, N, K, b_seq_start_loc, seq_len_tensor, lora_indices_tensor, - inputs.stride(0), + slice_start_tensor, inputs.stride(1), - lora_b_weights.stride(0), - lora_b_weights.stride(1), - lora_b_weights.stride(2), + inputs.stride(2), + lora_strides_d0_tensor, + lora_strides_d1_tensor, + lora_strides_d2_tensor, output_tensor.stride(0), output_tensor.stride(1), - slice_offset, BLOCK_M, BLOCK_N, BLOCK_K, @@ -211,9 +231,9 @@ def _sgmv_expand_slice( return -def sgmv_expand_slice_fake( +def _sgmv_expand_slice_fake( inputs: torch.Tensor, - lora_b_weights: torch.Tensor, + lora_b_stacked: Tuple[torch.Tensor, ...], output_tensor: torch.Tensor, b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor, @@ -233,7 +253,7 @@ def sgmv_expand_slice_fake( op_name="sgmv_expand_slice", op_func=_sgmv_expand_slice, mutates_args=["output_tensor"], - 
fake_impl=sgmv_expand_slice_fake, + fake_impl=_sgmv_expand_slice_fake, ) sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index 563d1181d6fcb..abf2f847c740e 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -438,6 +438,24 @@ def _expand_slice_prefill( add_input, ) + def _expand_nslices_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: Tuple[torch.Tensor, ...], + add_input: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand_slice( + x, + w_t_all, + y, + *self.prefill_metadata, + add_input, + ) + def _expand_slice_decode( self, y: torch.Tensor, @@ -588,16 +606,20 @@ def add_expand( if lora_bias_stacked is not None: self._apply_bias(self.token_lora_indices, y, output_slices, lora_bias_stacked) - for slice_idx in range(len(lora_b_stacked)): - self._apply_expand( - y, - x[slice_idx], - lora_b_stacked[slice_idx], - offset_left, - output_slices[slice_idx], - add_input=add_input, - ) - offset_left += output_slices[slice_idx] + if self.is_prefill: + # NOTE fused kernel + self._expand_nslices_prefill(y, x, lora_b_stacked, add_input=True) + else: + for slice_idx in range(len(lora_b_stacked)): + self._apply_expand( + y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_left, + output_slices[slice_idx], + add_input=add_input, + ) + offset_left += output_slices[slice_idx] y = y.view_as(y_org) def add_lora_embedding( @@ -670,10 +692,17 @@ def add_lora_linear( r = lora_b_stacked[0].size(-1) # We set the buffer to be float32 by default ,refer to: # https://github.com/triton-lang/triton/issues/1387 - buffer = tuple( - torch.zeros( - (x.size(0), r), dtype=torch.float32, device=x.device) - for _ in range(len(output_slices))) + # buffer = tuple( + # torch.zeros( + # (x.size(0), r), dtype=torch.float32, device=x.device) + # for _ in range(len(output_slices))) + + buffer = torch.zeros( + (len(output_slices), x.size(0), r), + dtype=torch.float32, + device=x.device, + ) + self.add_shrink(buffer, x, lora_a_stacked, scale) self.add_expand(y, buffer, From 8c2ac4ca4d2d497e97e031d9edc736131849ec8c Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 10 Dec 2024 11:06:09 +0000 Subject: [PATCH 02/35] Fix bug Signed-off-by: Jee Jee Li --- vllm/lora/ops/sgmv_expand_slice.py | 39 ++++++++++++-------------- vllm/lora/punica_wrapper/punica_gpu.py | 20 +++++++------ 2 files changed, 29 insertions(+), 30 deletions(-) diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py index f0d239e20678b..285a64cffb6c8 100644 --- a/vllm/lora/ops/sgmv_expand_slice.py +++ b/vllm/lora/ops/sgmv_expand_slice.py @@ -25,8 +25,9 @@ def _sgmv_expand_slice_kernel( seq_lens, lora_indices, slice_start_loc, - xm_stride, - xk_stride, # 1 + input_d0_stride, + input_d1_stride, + input_d2_stride, # 1 ls_d0_ptr, # lora stride(0) ls_d1_ptr, ls_d2_ptr, @@ -67,23 +68,18 @@ def _sgmv_expand_slice_kernel( ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) - # if CAST_TYPE: - # # fp32 - # cur_input_ptr=tl.load(input_ptr + slice_id).to( - # tl.pointer_type(tl.float32)) - # else: - # cur_input_ptr=tl.load(input_ptr + slice_id).to( - # out_ptr.dtype.element_ty) - # cur_input_ptr = tl.load(input_ptr + slice_id * xm_stride * xk_stride) + # input + cur_input_ptr = input_ptr + slice_id * input_d0_stride + a_ptr = (cur_input_ptr + cur_seq_start * input_d1_stride + + ram[:, None] * input_d1_stride + + offset_k[None, :] * 
input_d2_stride, ) + # lora cur_lora_ptr = tl.load(lora_ptr + slice_id).to( tl.pointer_type(out_ptr.dtype.element_ty)) cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) - a_ptr = (input_ptr + slice_id * xm_stride * xk_stride + - cur_seq_start * xm_stride + ram[:, None] * xm_stride + - offset_k[None, :] * xk_stride, ) - # lora + b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + offset_k[:, None] * cur_lora_d2_stride + rbn[None, :] * cur_lora_d1_stride) @@ -105,15 +101,13 @@ def _sgmv_expand_slice_kernel( tiled_a, tiled_b, ) - a_ptr += BLOCK_K * xk_stride + a_ptr += BLOCK_K * input_d2_stride b_ptr += BLOCK_K * cur_lora_d2_stride tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty) - # 获取每个slice的偏移地址 - if slice_start_loc is not None: - cur_slice_start = tl.load(slice_start_loc + slice_id) - else: - cur_slice_start = 0 + + cur_slice_start = tl.load(slice_start_loc + slice_id) + offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + @@ -139,6 +133,7 @@ def _sgmv_expand_slice( batches: int, max_seq_length: int, token_nums: int, + offset_start: int = 0, add_inputs: bool = False, ) -> None: assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] @@ -153,12 +148,13 @@ def _sgmv_expand_slice( assert b_seq_start_loc.size(0) == batches assert lora_indices_tensor.size(0) == batches assert output_tensor.is_contiguous() + # TODO Optimize the following code slice_offset_lst = [] tensor_ptrs = [] lora_strides_d0 = [] lora_strides_d1 = [] lora_strides_d2 = [] - slice_offset = 0 + slice_offset = offset_start for lora_b_weight in lora_b_stacked: if lora_b_weight.ndim == 4: # shape:(lora_num,1,size,rank) assert lora_b_weight.size(1) == 1 @@ -214,6 +210,7 @@ def _sgmv_expand_slice( seq_len_tensor, lora_indices_tensor, slice_start_tensor, + inputs.stride(0), inputs.stride(1), inputs.stride(2), lora_strides_d0_tensor, diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 38d29d18f8c4a..8a8c7dc7ac789 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -116,6 +116,7 @@ def _expand_nslices_prefill( y: torch.Tensor, x: torch.Tensor, w_t_all: Tuple[torch.Tensor, ...], + offset_start: int, add_input: bool, ): #No LoRA request, so return directly @@ -126,6 +127,7 @@ def _expand_nslices_prefill( w_t_all, y, *self.prefill_metadata, + offset_start, add_input, ) @@ -235,24 +237,28 @@ def add_expand(self, """ y_org = y y = y.view(-1, y.shape[-1]) - offset_left = offset_start + # offset_ = offset_start if lora_bias_stacked is not None: self._apply_bias(self.token_lora_indices, y, output_slices, lora_bias_stacked) if self.is_prefill: # NOTE fused kernel - self._expand_nslices_prefill(y, x, lora_b_stacked, add_input=True) + self._expand_nslices_prefill(y, + x, + lora_b_stacked, + offset_start, + add_input=True) else: for slice_idx in range(len(lora_b_stacked)): self._apply_expand( y, x[slice_idx], lora_b_stacked[slice_idx], - offset_left, + offset_start, output_slices[slice_idx], add_input=add_input, ) - offset_left += output_slices[slice_idx] + offset_start += output_slices[slice_idx] y = y.view_as(y_org) def add_lora_embedding(self, @@ -323,11 +329,6 @@ def add_lora_linear(self, r = lora_b_stacked[0].size(-1) # We set the buffer to be float32 by default ,refer to: # 
https://github.com/triton-lang/triton/issues/1387 - # buffer = tuple( - # torch.zeros( - # (x.size(0), r), dtype=torch.float32, device=x.device) - # for _ in range(len(output_slices))) - buffer = torch.zeros( (len(output_slices), x.size(0), r), dtype=torch.float32, @@ -341,6 +342,7 @@ def add_lora_linear(self, output_slices, add_input=True, **kwargs) + pass def add_lora_logits(self, y: torch.Tensor, From d04121ca0176aaecddbe80e166db980953533780 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 11 Dec 2024 05:42:41 +0000 Subject: [PATCH 03/35] Back up Signed-off-by: Jee Jee Li --- vllm/lora/punica_wrapper/punica_gpu.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 8a8c7dc7ac789..d125b09a071e0 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -111,7 +111,7 @@ def _expand_slice_prefill( add_input, ) - def _expand_nslices_prefill( + def _apply_expand_prefill( self, y: torch.Tensor, x: torch.Tensor, @@ -131,7 +131,7 @@ def _expand_nslices_prefill( add_input, ) - def _expand_slice_decode( + def _apply_expand_decode( self, y: torch.Tensor, x: torch.Tensor, @@ -237,20 +237,17 @@ def add_expand(self, """ y_org = y y = y.view(-1, y.shape[-1]) - # offset_ = offset_start if lora_bias_stacked is not None: self._apply_bias(self.token_lora_indices, y, output_slices, lora_bias_stacked) if self.is_prefill: # NOTE fused kernel - self._expand_nslices_prefill(y, - x, - lora_b_stacked, - offset_start, - add_input=True) + self._apply_expand_prefill( + y, x, lora_b_stacked, offset_start, add_input=True + ) else: for slice_idx in range(len(lora_b_stacked)): - self._apply_expand( + self._apply_expand_decode( y, x[slice_idx], lora_b_stacked[slice_idx], From a306f424c933aa960ae9e8978057a5c776eac632 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 11 Dec 2024 08:35:55 +0000 Subject: [PATCH 04/35] shrink_sgmv Done Signed-off-by: Jee Jee Li --- vllm/lora/ops/sgmv_shrink.py | 89 +++++++++++++++++--------- vllm/lora/punica_wrapper/punica_gpu.py | 35 ++++++++-- 2 files changed, 88 insertions(+), 36 deletions(-) diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py index 37d1dc84eebca..50904f2fd87fe 100644 --- a/vllm/lora/ops/sgmv_shrink.py +++ b/vllm/lora/ops/sgmv_shrink.py @@ -5,6 +5,8 @@ https://arxiv.org/abs/2310.18547 """ +from typing import List + import torch import triton import triton.language as tl @@ -15,7 +17,7 @@ @triton.jit def _sgmv_shrink_kernel( input_ptr, - lora_ptr, + lora_ptr, #1-3 out_ptr, N, K, @@ -25,9 +27,10 @@ def _sgmv_shrink_kernel( scaling, xm_stride, # hidden_size xk_stride, # 1 - l0_stride, # hidden_size*max_rank - lora_k_stride, - lora_n_stride, + ls_d0_ptr, # hidden_size*max_rank + ls_d1_ptr, + ls_d2_ptr, + c0_stride, cm_stride, cn_stride, BLOCK_M: tl.constexpr, @@ -42,12 +45,13 @@ def _sgmv_shrink_kernel( introducing SPLIT-K can improve performance """ pid = tl.program_id(axis=0) - pid_sk = tl.program_id(axis=1) + pid_mix = tl.program_id(axis=1) cur_batch = tl.program_id(axis=2) cta_n_num = tl.cdiv(N, BLOCK_N) pid_m = pid // cta_n_num pid_n = pid % cta_n_num - + slice_id = pid_mix // SPLIT_K + pid_sk = pid_mix % SPLIT_K M = tl.load(seq_lens + cur_batch) if pid_m * BLOCK_M > M: return @@ -64,8 +68,16 @@ def _sgmv_shrink_kernel( a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + offset_k[None, :] * xk_stride) - b_ptr = (lora_ptr + l0_stride * lora_index + rbn[None, :] * 
lora_k_stride + - offset_k[:, None] * lora_n_stride) + + cur_lora_ptr = tl.load(lora_ptr + slice_id).to( + tl.pointer_type(input_ptr.dtype.element_ty)) + cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) + cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) + cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) + + b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + + rbn[None, :] * cur_lora_d1_stride + + offset_k[:, None] * cur_lora_d2_stride) accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): @@ -83,12 +95,13 @@ def _sgmv_shrink_kernel( accumulator += tl.dot(tiled_a, tiled_b) a_ptr += BLOCK_K * SPLIT_K * xk_stride - b_ptr += BLOCK_K * SPLIT_K * lora_n_stride + b_ptr += BLOCK_K * SPLIT_K * cur_lora_d2_stride offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N - c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + - offset_cn[None, :] * cn_stride) + cur_out_ptr = out_ptr + slice_id * c0_stride + c_ptr = cur_out_ptr + offset_cm[:, None] * cm_stride + offset_cn[ + None, :] * cn_stride c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] < N) accumulator *= scaling @@ -102,7 +115,7 @@ def _sgmv_shrink_kernel( @torch.inference_mode() def _sgmv_shrink( inputs: torch.Tensor, - lora_a_weights: torch.Tensor, + lora_a_weights: List[torch.Tensor], output_tensor: torch.Tensor, b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor, @@ -134,27 +147,44 @@ def _sgmv_shrink( token numbers in the inputs matches the one in the metadata. scaling (float): Scaling factor. """ - assert inputs.dtype == lora_a_weights.dtype + assert inputs.dtype == lora_a_weights[0].dtype assert inputs.dtype in [torch.float16, torch.bfloat16] - assert lora_a_weights.dtype in [ + assert lora_a_weights[0].dtype in [ torch.float16, torch.bfloat16, ] assert inputs.size(0) == token_nums - assert inputs.size(1) == lora_a_weights.size(-1) + assert inputs.size(1) == lora_a_weights[0].size(-1) assert b_seq_start_loc.size(0) == batches assert lora_indices_tensor.size(0) == batches assert inputs.is_contiguous() - - if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size) - assert lora_a_weights.size(1) == 1 - lora_a_weights = lora_a_weights.squeeze(dim=1) - else: - assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size) - assert lora_a_weights.is_contiguous() assert output_tensor.is_contiguous() + lora_strides_d0 = [] + lora_strides_d1 = [] + lora_strides_d2 = [] + tensor_ptrs = [] + for lora_a_weight in lora_a_weights: + if lora_a_weight.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_a_weight.size(1) == 1 + lora_a_weight = lora_a_weight.squeeze(dim=1) + else: + assert lora_a_weight.ndim == 3 # shape:(lora_num,size,rank) + assert lora_a_weight.is_contiguous() + tensor_ptrs.append(lora_a_weight.data_ptr()) + lora_strides_d0.append(lora_a_weight.stride(0)) + lora_strides_d1.append(lora_a_weight.stride(1)) + lora_strides_d2.append(lora_a_weight.stride(2)) + + lora_ptr_tensor = torch.tensor(tensor_ptrs, device=b_seq_start_loc.device) + lora_strides_d0_tensor = torch.tensor(lora_strides_d0, + device=b_seq_start_loc.device) + lora_strides_d1_tensor = torch.tensor(lora_strides_d1, + device=b_seq_start_loc.device) + lora_strides_d2_tensor = torch.tensor(lora_strides_d2, + device=b_seq_start_loc.device) + # TODO tuning this config - N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank + N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank BLOCK_M = 32 BLOCK_N = 16 
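    # NOTE: the second grid axis below is packed as SPLIT_K * num_slices;
    # the kernel recovers both indices with slice_id = pid_mix // SPLIT_K
    # and pid_sk = pid_mix % SPLIT_K.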
BLOCK_K = 32 @@ -162,13 +192,13 @@ def _sgmv_shrink( EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 grid = ( triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), - SPLIT_K, + SPLIT_K * len(lora_a_weights), batches, ) _sgmv_shrink_kernel[grid]( inputs, - lora_a_weights, + lora_ptr_tensor, output_tensor, N, K, @@ -178,11 +208,12 @@ def _sgmv_shrink( scaling, inputs.stride(0), inputs.stride(1), - lora_a_weights.stride(0), - lora_a_weights.stride(1), - lora_a_weights.stride(2), + lora_strides_d0_tensor, + lora_strides_d1_tensor, + lora_strides_d2_tensor, output_tensor.stride(0), output_tensor.stride(1), + output_tensor.stride(2), BLOCK_M, BLOCK_N, BLOCK_K, @@ -194,7 +225,7 @@ def _sgmv_shrink( def sgmv_shrink_fake( inputs: torch.Tensor, - lora_a_weights: torch.Tensor, + lora_a_weights: List[torch.Tensor], output_tensor: torch.Tensor, b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor, diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index d125b09a071e0..4f986ba7cdf85 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -160,7 +160,7 @@ def _apply_expand( expand_slice_fun: Callable = (self._expand_slice_prefill if self.is_prefill else - self._expand_slice_decode) + self._apply_expand_decode) expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input) def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, @@ -180,6 +180,21 @@ def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, shrink_fun(y, x, w_t_all, scale) y = y.view_as(y_org) + def _apply_shrink_nslices_prefill(self, y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, scale: float): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. 
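
        Note: in this fused path, y is the stacked shrink buffer of shape
        (num_slices, num_tokens, rank), matching the float32 buffer that
        add_lora_linear now allocates as
        torch.zeros((len(output_slices), x.size(0), r)).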
+ """ + y_org = y + + self._shrink_prefill(y, x, w_t_all, scale) + y = y.view_as(y_org) + def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], scale: float, **kwargs): @@ -203,9 +218,13 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x = x.view(-1, x.shape[-1]) # TODO fuse these kernels - for slice_idx in range(len(lora_a_stacked)): - self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], - scale) + if self.is_prefill: + # NOTE fused kernel + self._apply_shrink_nslices_prefill(y, x, lora_a_stacked, scale) + else: + for slice_idx in range(len(lora_a_stacked)): + self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], + scale) def add_expand(self, y: torch.Tensor, @@ -242,9 +261,11 @@ def add_expand(self, lora_bias_stacked) if self.is_prefill: # NOTE fused kernel - self._apply_expand_prefill( - y, x, lora_b_stacked, offset_start, add_input=True - ) + self._apply_expand_prefill(y, + x, + lora_b_stacked, + offset_start, + add_input=True) else: for slice_idx in range(len(lora_b_stacked)): self._apply_expand_decode( From b6013db47b0e0304c4e1a1b4e8a398c9a41c2b3f Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 12 Dec 2024 04:51:36 +0000 Subject: [PATCH 05/35] Optimize ptr compute Signed-off-by: Jee Jee Li --- vllm/lora/ops/sgmv_expand_slice.py | 85 +++++++++++++++++++----------- vllm/lora/ops/sgmv_shrink.py | 70 +++++++++++++++--------- 2 files changed, 98 insertions(+), 57 deletions(-) diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py index 285a64cffb6c8..991ba0da36c35 100644 --- a/vllm/lora/ops/sgmv_expand_slice.py +++ b/vllm/lora/ops/sgmv_expand_slice.py @@ -5,7 +5,7 @@ https://arxiv.org/abs/2310.18547 """ -from typing import List, Tuple +from typing import Dict, List, Tuple import torch import triton @@ -122,6 +122,51 @@ def _sgmv_expand_slice_kernel( tl.store(c_ptr, tiled_c, mask=c_mask) +_LORA_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} + + +#TODO Optimize +def _get_lora_ptr(lora_weights, offset_start, device): + + key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights) + if _LORA_PTR_DICT.get(key) is None: + slice_offset_lst = [] + tensor_ptrs = [] + lora_strides_d0 = [] + lora_strides_d1 = [] + lora_strides_d2 = [] + slice_offset = offset_start + for lora_b_weight in lora_weights: + if lora_b_weight.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weight.size(1) == 1 + lora_b_weight = lora_b_weight.squeeze(dim=1) + else: + assert lora_b_weight.ndim == 3 # shape:(lora_num,size,rank) + assert lora_b_weight.is_contiguous() + tensor_ptrs.append(lora_b_weight.data_ptr()) + lora_strides_d0.append(lora_b_weight.stride(0)) + lora_strides_d1.append(lora_b_weight.stride(1)) + lora_strides_d2.append(lora_b_weight.stride(2)) + slice_offset_lst.append(slice_offset) + slice_offset += lora_b_weight.size(1) + + slice_start_tensor = torch.tensor(slice_offset_lst, device=device) + # note these are device tensors + lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device) + lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device) + lora_strides_d1_tensor = torch.tensor(lora_strides_d1, device=device) + lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device) + + _LORA_PTR_DICT[key] = ( + slice_start_tensor, + lora_ptr_tensor, + lora_strides_d0_tensor, + lora_strides_d1_tensor, + lora_strides_d2_tensor, + ) + return _LORA_PTR_DICT.get(key) + + @torch.inference_mode() def 
_sgmv_expand_slice( inputs: torch.Tensor, @@ -148,37 +193,13 @@ def _sgmv_expand_slice( assert b_seq_start_loc.size(0) == batches assert lora_indices_tensor.size(0) == batches assert output_tensor.is_contiguous() - # TODO Optimize the following code - slice_offset_lst = [] - tensor_ptrs = [] - lora_strides_d0 = [] - lora_strides_d1 = [] - lora_strides_d2 = [] - slice_offset = offset_start - for lora_b_weight in lora_b_stacked: - if lora_b_weight.ndim == 4: # shape:(lora_num,1,size,rank) - assert lora_b_weight.size(1) == 1 - lora_b_weight = lora_b_weight.squeeze(dim=1) - else: - assert lora_b_weight.ndim == 3 # shape:(lora_num,size,rank) - assert lora_b_weight.is_contiguous() - tensor_ptrs.append(lora_b_weight.data_ptr()) - lora_strides_d0.append(lora_b_weight.stride(0)) - lora_strides_d1.append(lora_b_weight.stride(1)) - lora_strides_d2.append(lora_b_weight.stride(2)) - slice_offset_lst.append(slice_offset) - slice_offset += lora_b_weight.size(1) - - slice_start_tensor = torch.tensor(slice_offset_lst, - device=b_seq_start_loc.device) - # note these are device tensors - lora_ptr_tensor = torch.tensor(tensor_ptrs, device=b_seq_start_loc.device) - lora_strides_d0_tensor = torch.tensor(lora_strides_d0, - device=b_seq_start_loc.device) - lora_strides_d1_tensor = torch.tensor(lora_strides_d1, - device=b_seq_start_loc.device) - lora_strides_d2_tensor = torch.tensor(lora_strides_d2, - device=b_seq_start_loc.device) + ( + slice_start_tensor, + lora_ptr_tensor, + lora_strides_d0_tensor, + lora_strides_d1_tensor, + lora_strides_d2_tensor, + ) = _get_lora_ptr(lora_b_stacked, offset_start, b_seq_start_loc.device) # TODO tuning this config N, K = lora_b_stacked[0].shape[-2:] # K= rank,N=hidden_size diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py index 50904f2fd87fe..d2fdbf02da2b2 100644 --- a/vllm/lora/ops/sgmv_shrink.py +++ b/vllm/lora/ops/sgmv_shrink.py @@ -5,7 +5,7 @@ https://arxiv.org/abs/2310.18547 """ -from typing import List +from typing import Dict, List, Tuple import torch import triton @@ -112,6 +112,44 @@ def _sgmv_shrink_kernel( tl.atomic_add(c_ptr, accumulator, mask=c_mask) +_LORA_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} + + +#TODO FIX THIS +def _get_lora_ptr(lora_a_weights, device): + + key = tuple(lora_weight.data_ptr() for lora_weight in lora_a_weights) + + if _LORA_PTR_DICT.get(key) is None: + lora_strides_d0 = [] + lora_strides_d1 = [] + lora_strides_d2 = [] + tensor_ptrs = [] + for lora_a_weight in lora_a_weights: + if lora_a_weight.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_a_weight.size(1) == 1 + lora_a_weight = lora_a_weight.squeeze(dim=1) + else: + assert lora_a_weight.ndim == 3 # shape:(lora_num,size,rank) + assert lora_a_weight.is_contiguous() + tensor_ptrs.append(lora_a_weight.data_ptr()) + lora_strides_d0.append(lora_a_weight.stride(0)) + lora_strides_d1.append(lora_a_weight.stride(1)) + lora_strides_d2.append(lora_a_weight.stride(2)) + + lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device) + lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device) + lora_strides_d1_tensor = torch.tensor(lora_strides_d1, device=device) + lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device) + _LORA_PTR_DICT[key] = ( + lora_ptr_tensor, + lora_strides_d0_tensor, + lora_strides_d1_tensor, + lora_strides_d2_tensor, + ) + return _LORA_PTR_DICT.get(key) + + @torch.inference_mode() def _sgmv_shrink( inputs: torch.Tensor, @@ -159,30 +197,12 @@ def _sgmv_shrink( assert lora_indices_tensor.size(0) == batches 
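    # Shape reminder: inputs is (token_nums, hidden_size), while
    # output_tensor is the stacked (num_slices, token_nums, rank) buffer,
    # which is why three output strides are handed to the kernel.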
assert inputs.is_contiguous() assert output_tensor.is_contiguous() - lora_strides_d0 = [] - lora_strides_d1 = [] - lora_strides_d2 = [] - tensor_ptrs = [] - for lora_a_weight in lora_a_weights: - if lora_a_weight.ndim == 4: # shape:(lora_num,1,size,rank) - assert lora_a_weight.size(1) == 1 - lora_a_weight = lora_a_weight.squeeze(dim=1) - else: - assert lora_a_weight.ndim == 3 # shape:(lora_num,size,rank) - assert lora_a_weight.is_contiguous() - tensor_ptrs.append(lora_a_weight.data_ptr()) - lora_strides_d0.append(lora_a_weight.stride(0)) - lora_strides_d1.append(lora_a_weight.stride(1)) - lora_strides_d2.append(lora_a_weight.stride(2)) - - lora_ptr_tensor = torch.tensor(tensor_ptrs, device=b_seq_start_loc.device) - lora_strides_d0_tensor = torch.tensor(lora_strides_d0, - device=b_seq_start_loc.device) - lora_strides_d1_tensor = torch.tensor(lora_strides_d1, - device=b_seq_start_loc.device) - lora_strides_d2_tensor = torch.tensor(lora_strides_d2, - device=b_seq_start_loc.device) - + ( + lora_ptr_tensor, + lora_strides_d0_tensor, + lora_strides_d1_tensor, + lora_strides_d2_tensor, + ) = _get_lora_ptr(lora_a_weights, b_seq_start_loc.device) # TODO tuning this config N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank BLOCK_M = 32 From 8d3742bae83e28787d936386d65131ad2a490838 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 13 Dec 2024 09:28:52 +0000 Subject: [PATCH 06/35] Increase the tile size Signed-off-by: Jee Jee Li --- vllm/lora/ops/sgmv_expand_slice.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py index 991ba0da36c35..a26aa9783c4f8 100644 --- a/vllm/lora/ops/sgmv_expand_slice.py +++ b/vllm/lora/ops/sgmv_expand_slice.py @@ -204,9 +204,9 @@ def _sgmv_expand_slice( # TODO tuning this config N, K = lora_b_stacked[0].shape[-2:] # K= rank,N=hidden_size - BLOCK_M = 32 - BLOCK_N = 32 - BLOCK_K = 16 + BLOCK_M = 64 + BLOCK_N = 64 + BLOCK_K = 32 EVEN_K = K % BLOCK_K == 0 ADD_INPUTS = add_inputs CAST_TYPE = False From 9564b33d220afbcd6ba41fc5173137885c28f72a Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 13 Dec 2024 12:49:12 +0000 Subject: [PATCH 07/35] Clean up triton interface Signed-off-by: Jee Jee Li --- vllm/lora/ops/sgmv_expand.py | 194 ++++++++++------- vllm/lora/ops/sgmv_expand_slice.py | 279 ------------------------- vllm/lora/punica_wrapper/punica_gpu.py | 157 +++----------- 3 files changed, 156 insertions(+), 474 deletions(-) delete mode 100644 vllm/lora/ops/sgmv_expand_slice.py diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index 77c5178493c44..74b428a32e3ff 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -1,10 +1,12 @@ """ Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
https://arxiv.org/abs/2310.18547 """ +from typing import Dict, List, Tuple + import torch import triton import triton.language as tl @@ -22,11 +24,13 @@ def _sgmv_expand_kernel( b_seq_start_loc, seq_lens, lora_indices, - xm_stride, - xk_stride, # 1 - l0_stride, # hidden_size*max_rank - lora_k_stride, - lora_n_stride, + slice_start_loc, + input_d0_stride, + input_d1_stride, + input_d2_stride, # 1 + ls_d0_ptr, # lora stride(0) + ls_d1_ptr, + ls_d2_ptr, cm_stride, cn_stride, BLOCK_M: tl.constexpr, @@ -37,10 +41,16 @@ def _sgmv_expand_kernel( CAST_TYPE: tl.constexpr, ): """ - The sgmv's expand triton kernel is based on GroupGEMM. + + Similar to the 'sgmv_expand' operator, but with an added parameter + 'slice_offset'. The reason for not reusing the 'sgmv_expand' operator + might be that in the future, we could implement a fusion operator to + achieve the current functionality instead of having to call it multiple + times. """ pid = tl.program_id(axis=0) cur_batch = tl.program_id(axis=1) + slice_id = tl.program_id(axis=2) cta_n_num = tl.cdiv(N, BLOCK_N) pid_m = pid // cta_n_num pid_n = pid % cta_n_num @@ -50,6 +60,7 @@ def _sgmv_expand_kernel( lora_index = tl.load(lora_indices + cur_batch) if lora_index == -1: return + cur_seq_start = tl.load(b_seq_start_loc + cur_batch) offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N @@ -57,10 +68,21 @@ def _sgmv_expand_kernel( ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) - a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + - offset_k[None, :] * xk_stride, ) - b_ptr = (lora_ptr + l0_stride * lora_index + - offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride) + # input + cur_input_ptr = input_ptr + slice_id * input_d0_stride + a_ptr = (cur_input_ptr + cur_seq_start * input_d1_stride + + ram[:, None] * input_d1_stride + + offset_k[None, :] * input_d2_stride, ) + # lora + cur_lora_ptr = tl.load(lora_ptr + slice_id).to( + tl.pointer_type(out_ptr.dtype.element_ty)) + cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) + cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) + cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) + + b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + + offset_k[:, None] * cur_lora_d2_stride + + rbn[None, :] * cur_lora_d1_stride) accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(tl.cdiv(K, BLOCK_K)): if EVEN_K: @@ -74,34 +96,81 @@ def _sgmv_expand_kernel( mask=offset_k[:, None] < K - k * BLOCK_K, other=0) if CAST_TYPE: - tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + tiled_a = tiled_a.to(cur_lora_ptr.dtype.element_ty) accumulator += tl.dot( tiled_a, tiled_b, ) - a_ptr += BLOCK_K * xk_stride - b_ptr += BLOCK_K * lora_n_stride - tiled_c = accumulator.to(lora_ptr.dtype.element_ty) + a_ptr += BLOCK_K * input_d2_stride + b_ptr += BLOCK_K * cur_lora_d2_stride + + tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty) + + cur_slice_start = tl.load(slice_start_loc + slice_id) + offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M - offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + offset_cn[None, :] * cn_stride) M = tl.load(seq_lens + cur_batch) c_mask = (offset_cm[:, None] < - (cur_seq_start + M)) & (offset_cn[None, :] < N) + (cur_seq_start + M)) & (offset_cn[None, :] < + 
(cur_slice_start + N)) if ADD_INPUTS: - # explicitly pass in other=None to tell triton that masked values - # can be uninitialized. This is OK because the later tl.store operation - # uses the same mask, eliminating the risk of garbage values propagating - tiled_out = tl.load(c_ptr, mask=c_mask, other=None) + tiled_out = tl.load(c_ptr, mask=c_mask) tiled_c += tiled_out tl.store(c_ptr, tiled_c, mask=c_mask) +_LORA_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} + + +#TODO Optimize +def _get_lora_ptr(lora_weights, offset_start, device): + + key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights) + if _LORA_PTR_DICT.get(key) is None: + slice_offset_lst = [] + tensor_ptrs = [] + lora_strides_d0 = [] + lora_strides_d1 = [] + lora_strides_d2 = [] + slice_offset = offset_start + for lora_b_weight in lora_weights: + if lora_b_weight.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weight.size(1) == 1 + lora_b_weight = lora_b_weight.squeeze(dim=1) + else: + assert lora_b_weight.ndim == 3 # shape:(lora_num,size,rank) + assert lora_b_weight.is_contiguous() + tensor_ptrs.append(lora_b_weight.data_ptr()) + lora_strides_d0.append(lora_b_weight.stride(0)) + lora_strides_d1.append(lora_b_weight.stride(1)) + lora_strides_d2.append(lora_b_weight.stride(2)) + slice_offset_lst.append(slice_offset) + slice_offset += lora_b_weight.size(1) + + slice_start_tensor = torch.tensor(slice_offset_lst, device=device) + # note these are device tensors + lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device) + lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device) + lora_strides_d1_tensor = torch.tensor(lora_strides_d1, device=device) + lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device) + + _LORA_PTR_DICT[key] = ( + slice_start_tensor, + lora_ptr_tensor, + lora_strides_d0_tensor, + lora_strides_d1_tensor, + lora_strides_d2_tensor, + ) + return _LORA_PTR_DICT.get(key) + + @torch.inference_mode() def _sgmv_expand( inputs: torch.Tensor, - lora_b_weights: torch.Tensor, + lora_b_stacked: List[torch.Tensor], output_tensor: torch.Tensor, b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor, @@ -109,61 +178,40 @@ def _sgmv_expand( batches: int, max_seq_length: int, token_nums: int, + offset_start: int = 0, add_inputs: bool = False, ) -> None: - """ - Args: - inputs (torch.Tensor): input tensor - lora_b_weights (torch.Tensor): lora'a weight - output_tensor (torch.Tensor): output tensor - b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative - sequence lengths of the sequences in the batch, used to index - into sequence. E.g., if the sequence length is [4, 6], it is - [0, 4, 10]. - seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence - length of the sequences in the batch. - lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index - corresponding to each batch. An index of -1 means no lora should be - applied. - batches (int): batch size - max_seq_length (int): The max sequence lengths of the sequences in the - batch. - token_nums (int): The token numbers in the batch. Used to verify if the - token numbers in the inputs matches the one in the metadata. - add_inputs (bool, optional): Defaults to False, adds the final lora - results to the output. 
- """ - assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] - assert lora_b_weights.dtype in [ + assert lora_b_stacked[0].dtype in [ torch.float16, torch.bfloat16, ] - assert inputs.size(0) == token_nums - assert inputs.size(1) == lora_b_weights.size(-1) + + assert inputs.size(1) == token_nums + assert inputs.size(0) == len(lora_b_stacked) + assert b_seq_start_loc.size(0) == batches assert lora_indices_tensor.size(0) == batches - assert inputs.is_contiguous() assert output_tensor.is_contiguous() - - if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) - assert lora_b_weights.size(1) == 1 - lora_b_weights = lora_b_weights.squeeze(dim=1) - else: - assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) - - assert lora_b_weights.is_contiguous() + ( + slice_start_tensor, + lora_ptr_tensor, + lora_strides_d0_tensor, + lora_strides_d1_tensor, + lora_strides_d2_tensor, + ) = _get_lora_ptr(lora_b_stacked, offset_start, b_seq_start_loc.device) # TODO tuning this config + N, K = lora_b_stacked[0].shape[-2:] # K= rank,N=hidden_size - N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size - BLOCK_M = 32 - BLOCK_N = 32 - BLOCK_K = 16 + BLOCK_M = 64 + BLOCK_N = 64 + BLOCK_K = 32 EVEN_K = K % BLOCK_K == 0 ADD_INPUTS = add_inputs CAST_TYPE = False - if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + + if inputs.dtype == torch.float32 and lora_b_stacked[0].dtype in [ torch.float16, torch.bfloat16, ]: @@ -171,21 +219,24 @@ def _sgmv_expand( grid = ( triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), batches, + len(lora_ptr_tensor), ) _sgmv_expand_kernel[grid]( inputs, - lora_b_weights, + lora_ptr_tensor, output_tensor, N, K, b_seq_start_loc, seq_len_tensor, lora_indices_tensor, + slice_start_tensor, inputs.stride(0), inputs.stride(1), - lora_b_weights.stride(0), - lora_b_weights.stride(1), - lora_b_weights.stride(2), + inputs.stride(2), + lora_strides_d0_tensor, + lora_strides_d1_tensor, + lora_strides_d2_tensor, output_tensor.stride(0), output_tensor.stride(1), BLOCK_M, @@ -198,9 +249,9 @@ def _sgmv_expand( return -def sgmv_expand_fake( +def _sgmv_expand_fake( inputs: torch.Tensor, - lora_b_weights: torch.Tensor, + lora_b_stacked: List[torch.Tensor], output_tensor: torch.Tensor, b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor, @@ -208,18 +259,19 @@ def sgmv_expand_fake( batches: int, max_seq_length: int, token_nums: int, + slice_offset: int, + slice_size: int, add_inputs: bool = False, ) -> None: return try: - direct_register_custom_op( op_name="sgmv_expand", op_func=_sgmv_expand, mutates_args=["output_tensor"], - fake_impl=sgmv_expand_fake, + fake_impl=_sgmv_expand_fake, ) sgmv_expand = torch.ops.vllm.sgmv_expand diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py deleted file mode 100644 index a26aa9783c4f8..0000000000000 --- a/vllm/lora/ops/sgmv_expand_slice.py +++ /dev/null @@ -1,279 +0,0 @@ -""" -Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. 
-https://arxiv.org/abs/2310.18547 -""" - -from typing import Dict, List, Tuple - -import torch -import triton -import triton.language as tl - -from vllm.utils import direct_register_custom_op - - -@triton.jit -def _sgmv_expand_slice_kernel( - input_ptr, - lora_ptr, - out_ptr, - N, - K, - b_seq_start_loc, - seq_lens, - lora_indices, - slice_start_loc, - input_d0_stride, - input_d1_stride, - input_d2_stride, # 1 - ls_d0_ptr, # lora stride(0) - ls_d1_ptr, - ls_d2_ptr, - cm_stride, - cn_stride, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - EVEN_K: tl.constexpr, - ADD_INPUTS: tl.constexpr, - CAST_TYPE: tl.constexpr, -): - """ - - Similar to the 'sgmv_expand' operator, but with an added parameter - 'slice_offset'. The reason for not reusing the 'sgmv_expand' operator - might be that in the future, we could implement a fusion operator to - achieve the current functionality instead of having to call it multiple - times. - """ - pid = tl.program_id(axis=0) - cur_batch = tl.program_id(axis=1) - slice_id = tl.program_id(axis=2) - cta_n_num = tl.cdiv(N, BLOCK_N) - pid_m = pid // cta_n_num - pid_n = pid % cta_n_num - M = tl.load(seq_lens + cur_batch) - if pid_m * BLOCK_M > M: - return - lora_index = tl.load(lora_indices + cur_batch) - if lora_index == -1: - return - - cur_seq_start = tl.load(b_seq_start_loc + cur_batch) - offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M - offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N - offset_k = tl.arange(0, BLOCK_K) - ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) - - # input - cur_input_ptr = input_ptr + slice_id * input_d0_stride - a_ptr = (cur_input_ptr + cur_seq_start * input_d1_stride + - ram[:, None] * input_d1_stride + - offset_k[None, :] * input_d2_stride, ) - # lora - cur_lora_ptr = tl.load(lora_ptr + slice_id).to( - tl.pointer_type(out_ptr.dtype.element_ty)) - cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) - cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) - cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) - - b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + - offset_k[:, None] * cur_lora_d2_stride + - rbn[None, :] * cur_lora_d1_stride) - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - for k in range(tl.cdiv(K, BLOCK_K)): - if EVEN_K: - tiled_a = tl.load(a_ptr) - tiled_b = tl.load(b_ptr) - else: - tiled_a = tl.load(a_ptr, - mask=offset_k[None, :] < K - k * BLOCK_K, - other=0) - tiled_b = tl.load(b_ptr, - mask=offset_k[:, None] < K - k * BLOCK_K, - other=0) - if CAST_TYPE: - tiled_a = tiled_a.to(cur_lora_ptr.dtype.element_ty) - accumulator += tl.dot( - tiled_a, - tiled_b, - ) - a_ptr += BLOCK_K * input_d2_stride - b_ptr += BLOCK_K * cur_lora_d2_stride - - tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty) - - cur_slice_start = tl.load(slice_start_loc + slice_id) - - offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M - offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start - c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + - offset_cn[None, :] * cn_stride) - M = tl.load(seq_lens + cur_batch) - c_mask = (offset_cm[:, None] < - (cur_seq_start + M)) & (offset_cn[None, :] < - (cur_slice_start + N)) - if ADD_INPUTS: - tiled_out = tl.load(c_ptr, mask=c_mask) - tiled_c += tiled_out - tl.store(c_ptr, tiled_c, mask=c_mask) - - -_LORA_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} - - -#TODO Optimize -def _get_lora_ptr(lora_weights, offset_start, 
device): - - key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights) - if _LORA_PTR_DICT.get(key) is None: - slice_offset_lst = [] - tensor_ptrs = [] - lora_strides_d0 = [] - lora_strides_d1 = [] - lora_strides_d2 = [] - slice_offset = offset_start - for lora_b_weight in lora_weights: - if lora_b_weight.ndim == 4: # shape:(lora_num,1,size,rank) - assert lora_b_weight.size(1) == 1 - lora_b_weight = lora_b_weight.squeeze(dim=1) - else: - assert lora_b_weight.ndim == 3 # shape:(lora_num,size,rank) - assert lora_b_weight.is_contiguous() - tensor_ptrs.append(lora_b_weight.data_ptr()) - lora_strides_d0.append(lora_b_weight.stride(0)) - lora_strides_d1.append(lora_b_weight.stride(1)) - lora_strides_d2.append(lora_b_weight.stride(2)) - slice_offset_lst.append(slice_offset) - slice_offset += lora_b_weight.size(1) - - slice_start_tensor = torch.tensor(slice_offset_lst, device=device) - # note these are device tensors - lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device) - lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device) - lora_strides_d1_tensor = torch.tensor(lora_strides_d1, device=device) - lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device) - - _LORA_PTR_DICT[key] = ( - slice_start_tensor, - lora_ptr_tensor, - lora_strides_d0_tensor, - lora_strides_d1_tensor, - lora_strides_d2_tensor, - ) - return _LORA_PTR_DICT.get(key) - - -@torch.inference_mode() -def _sgmv_expand_slice( - inputs: torch.Tensor, - lora_b_stacked: List[torch.Tensor], - output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, - offset_start: int = 0, - add_inputs: bool = False, -) -> None: - assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] - assert lora_b_stacked[0].dtype in [ - torch.float16, - torch.bfloat16, - ] - - assert inputs.size(1) == token_nums - assert inputs.size(0) == len(lora_b_stacked) - - assert b_seq_start_loc.size(0) == batches - assert lora_indices_tensor.size(0) == batches - assert output_tensor.is_contiguous() - ( - slice_start_tensor, - lora_ptr_tensor, - lora_strides_d0_tensor, - lora_strides_d1_tensor, - lora_strides_d2_tensor, - ) = _get_lora_ptr(lora_b_stacked, offset_start, b_seq_start_loc.device) - - # TODO tuning this config - N, K = lora_b_stacked[0].shape[-2:] # K= rank,N=hidden_size - - BLOCK_M = 64 - BLOCK_N = 64 - BLOCK_K = 32 - EVEN_K = K % BLOCK_K == 0 - ADD_INPUTS = add_inputs - CAST_TYPE = False - - if inputs.dtype == torch.float32 and lora_b_stacked[0].dtype in [ - torch.float16, - torch.bfloat16, - ]: - CAST_TYPE = True - grid = ( - triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), - batches, - len(lora_ptr_tensor), - ) - _sgmv_expand_slice_kernel[grid]( - inputs, - lora_ptr_tensor, - output_tensor, - N, - K, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - slice_start_tensor, - inputs.stride(0), - inputs.stride(1), - inputs.stride(2), - lora_strides_d0_tensor, - lora_strides_d1_tensor, - lora_strides_d2_tensor, - output_tensor.stride(0), - output_tensor.stride(1), - BLOCK_M, - BLOCK_N, - BLOCK_K, - EVEN_K, - ADD_INPUTS, - CAST_TYPE, - ) - return - - -def _sgmv_expand_slice_fake( - inputs: torch.Tensor, - lora_b_stacked: Tuple[torch.Tensor, ...], - output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, - slice_offset: int, - 
slice_size: int, - add_inputs: bool = False, -) -> None: - return - - -try: - direct_register_custom_op( - op_name="sgmv_expand_slice", - op_func=_sgmv_expand_slice, - mutates_args=["output_tensor"], - fake_impl=_sgmv_expand_slice_fake, - ) - sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice - -except AttributeError: - sgmv_expand_slice = _sgmv_expand_slice diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 4f986ba7cdf85..41dfa8daf7b2a 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -5,7 +5,7 @@ https://arxiv.org/abs/2310.18547 """ -from typing import Callable, Optional, Tuple, Union, final +from typing import Optional, Tuple, Union, final import torch @@ -15,8 +15,8 @@ from vllm.lora.ops.bgmv_expand import bgmv_expand from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice from vllm.lora.ops.bgmv_shrink import bgmv_shrink + # from vllm.lora.ops.sgmv_expand import sgmv_expand from vllm.lora.ops.sgmv_expand import sgmv_expand - from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.sgmv_shrink import sgmv_shrink from .punica_base import PunicaWrapperBase @@ -35,11 +35,11 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) - def _shrink_prefill( + def _apply_shrink_prefill( self, y: torch.Tensor, x: torch.Tensor, - w_t_all: torch.Tensor, + w_t_all: Tuple[torch.Tensor, ...], scale: float, ): #No LoRA request, so return directly @@ -53,7 +53,7 @@ def _shrink_prefill( scale, ) - def _shrink_decode( + def _apply_shrink_decode( self, y: torch.Tensor, x: torch.Tensor, @@ -62,73 +62,25 @@ def _shrink_decode( ): bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) - def _expand_prefill( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - add_input: bool, - ): - #No LoRA request, so return directly - if self.no_lora: - return - sgmv_expand( - x, - w_t_all, - y, - *self.prefill_metadata, - add_input, - ) - - def _expand_decode( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - add_input: bool, - ): - bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input) - - def _expand_slice_prefill( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - y_offset: Optional[int], - y_slice_size: Optional[int], - add_input: bool, - ): - #No LoRA request, so return directly - if self.no_lora: - return - sgmv_expand_slice( - x, - w_t_all, - y, - *self.prefill_metadata, - y_offset, - y_slice_size, - add_input, - ) - def _apply_expand_prefill( self, y: torch.Tensor, x: torch.Tensor, - w_t_all: Tuple[torch.Tensor, ...], + w_t_all: torch.Tensor, offset_start: int, - add_input: bool, + add_inputs: bool, ): #No LoRA request, so return directly if self.no_lora: return - sgmv_expand_slice( + + sgmv_expand( x, w_t_all, y, *self.prefill_metadata, - offset_start, - add_input, + offset_start=offset_start, + add_inputs=add_inputs, ) def _apply_expand_decode( @@ -138,62 +90,10 @@ def _apply_expand_decode( w_t_all: torch.Tensor, y_offset: Optional[int], y_slice_size: Optional[int], - add_input: bool, + add_inputs: bool, ): bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, - y_slice_size, add_input) - - def _apply_expand( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - y_offset: Optional[int], - y_slice_size: Optional[int], - add_input: bool = True, - ): - """ - Perform the ` 
y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` - computation, which is suitable for the - GEMM of lora'b. - """ - - expand_slice_fun: Callable = (self._expand_slice_prefill - if self.is_prefill else - self._apply_expand_decode) - expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input) - - def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, - w_t_all: torch.Tensor, scale: float): - """ - Perform the ` y+=x@w_t_all` computation, which is suitable for the - GEMM of lora'a. - When `is_prefill is` true, it indicates that it is currently the - prefill stage, and the `_shrink_prefill` function should be called. - Otherwise, it is the decode stage, and the _shrink_decode function - should be called. - """ - y_org = y - y = y.view(-1, y.shape[-1]) - shrink_fun: Callable = (self._shrink_prefill - if self.is_prefill else self._shrink_decode) - shrink_fun(y, x, w_t_all, scale) - y = y.view_as(y_org) - - def _apply_shrink_nslices_prefill(self, y: torch.Tensor, x: torch.Tensor, - w_t_all: torch.Tensor, scale: float): - """ - Perform the ` y+=x@w_t_all` computation, which is suitable for the - GEMM of lora'a. - When `is_prefill is` true, it indicates that it is currently the - prefill stage, and the `_shrink_prefill` function should be called. - Otherwise, it is the decode stage, and the _shrink_decode function - should be called. - """ - y_org = y - - self._shrink_prefill(y, x, w_t_all, scale) - y = y.view_as(y_org) + y_slice_size, add_inputs) def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], @@ -217,14 +117,15 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], """ x = x.view(-1, x.shape[-1]) - # TODO fuse these kernels + if self.is_prefill: # NOTE fused kernel - self._apply_shrink_nslices_prefill(y, x, lora_a_stacked, scale) + self._apply_shrink_prefill(y, x, lora_a_stacked, scale) else: + # TODO fuse these kernels for slice_idx in range(len(lora_a_stacked)): - self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], - scale) + self._apply_shrink_decode(y[slice_idx], x, + lora_a_stacked[slice_idx], scale) def add_expand(self, y: torch.Tensor, @@ -265,7 +166,7 @@ def add_expand(self, x, lora_b_stacked, offset_start, - add_input=True) + add_inputs=True) else: for slice_idx in range(len(lora_b_stacked)): self._apply_expand_decode( @@ -274,7 +175,7 @@ def add_expand(self, lora_b_stacked[slice_idx], offset_start, output_slices[slice_idx], - add_input=add_input, + add_inputs=add_input, ) offset_start += output_slices[slice_idx] y = y.view_as(y_org) @@ -298,10 +199,18 @@ def add_lora_embedding(self, add_input (bool): Default to True. 
""" - # Embedding layer only need expand op - expand_fun: Callable = (self._expand_prefill - if self.is_prefill else self._expand_decode) - expand_fun(y, x, lora_b_stacked, add_input) + if self.is_prefill: + sgmv_expand( + x.unsqueeze(dim=0), + [lora_b_stacked], + y, + *self.prefill_metadata, + offset_start=0, + add_inputs=add_input, + ) + else: + bgmv_expand(x, lora_b_stacked, y, self.token_lora_indices, + add_input) def add_lora_linear(self, y: torch.Tensor, @@ -358,7 +267,7 @@ def add_lora_linear(self, lora_b_stacked, None, output_slices, - add_input=True, + add_inputs=True, **kwargs) pass From 40124669ee0b68b0419ae421f8fd1043e3b25b95 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Mon, 16 Dec 2024 06:47:40 +0000 Subject: [PATCH 08/35] Backup Signed-off-by: Jee Jee Li --- vllm/lora/punica_wrapper/punica_gpu.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index d8dbaa6daaeb0..c70300d32e2da 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -168,6 +168,7 @@ def add_expand(self, offset_start, add_inputs=True) else: + # TODO fuse these kernels for slice_idx in range(len(lora_b_stacked)): self._apply_expand_decode( y, From 18bbadf16227f27770214c01a86ebb813a590af3 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Mon, 16 Dec 2024 15:26:25 +0000 Subject: [PATCH 09/35] Optimize one sclice kernel Signed-off-by: Jee Jee Li --- vllm/lora/ops/sgmv_shrink.py | 126 +++++++++++++------------ vllm/lora/punica_wrapper/punica_gpu.py | 2 +- 2 files changed, 65 insertions(+), 63 deletions(-) diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py index d2fdbf02da2b2..c0ffdf1bfd1e2 100644 --- a/vllm/lora/ops/sgmv_shrink.py +++ b/vllm/lora/ops/sgmv_shrink.py @@ -16,29 +16,29 @@ @triton.jit def _sgmv_shrink_kernel( - input_ptr, - lora_ptr, #1-3 - out_ptr, - N, - K, - b_seq_start_loc, - seq_lens, - lora_indices, - scaling, - xm_stride, # hidden_size - xk_stride, # 1 - ls_d0_ptr, # hidden_size*max_rank - ls_d1_ptr, - ls_d2_ptr, - c0_stride, - cm_stride, - cn_stride, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - EVEN_K: tl.constexpr, - SPLIT_K: tl.constexpr, -): + input_ptr, + lora_ptr, #1-3 + out_ptr, + N, + K, + b_seq_start_loc, + seq_lens, + lora_indices, + scaling, + xm_stride, # hidden_size + xk_stride, # 1 + ls_d0_ptr, # hidden_size*max_rank + ls_d1_ptr, + ls_d2_ptr, + c0_stride, + cm_stride, + cn_stride, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, + NSLICE_NUM: tl.constexpr): """ The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K. The GEMM of Multi-LoRA can be considered as GroupGEMM. 
Additionally, @@ -50,8 +50,14 @@ def _sgmv_shrink_kernel( cta_n_num = tl.cdiv(N, BLOCK_N) pid_m = pid // cta_n_num pid_n = pid % cta_n_num - slice_id = pid_mix // SPLIT_K - pid_sk = pid_mix % SPLIT_K + if NSLICE_NUM == 1: + slice_id = 0 + pid_sk = tl.program_id(axis=1) + else: + pid_mix = tl.program_id(axis=1) + slice_id = pid_mix // SPLIT_K + pid_sk = pid_mix % SPLIT_K + M = tl.load(seq_lens + cur_batch) if pid_m * BLOCK_M > M: return @@ -68,12 +74,17 @@ def _sgmv_shrink_kernel( a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + offset_k[None, :] * xk_stride) - - cur_lora_ptr = tl.load(lora_ptr + slice_id).to( - tl.pointer_type(input_ptr.dtype.element_ty)) - cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) - cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) - cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) + if NSLICE_NUM == 1: + cur_lora_ptr = lora_ptr + cur_lora_d0_stride = ls_d0_ptr + cur_lora_d1_stride = ls_d1_ptr + cur_lora_d2_stride = ls_d2_ptr + else: + cur_lora_ptr = tl.load(lora_ptr + slice_id).to( + tl.pointer_type(input_ptr.dtype.element_ty)) + cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) + cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) + cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + rbn[None, :] * cur_lora_d1_stride + @@ -99,7 +110,7 @@ def _sgmv_shrink_kernel( offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N - cur_out_ptr = out_ptr + slice_id * c0_stride + cur_out_ptr = out_ptr if NSLICE_NUM == 1 else out_ptr + slice_id * c0_stride c_ptr = cur_out_ptr + offset_cm[:, None] * cm_stride + offset_cn[ None, :] * cn_stride c_mask = (offset_cm[:, None] < @@ -115,9 +126,7 @@ def _sgmv_shrink_kernel( _LORA_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} -#TODO FIX THIS def _get_lora_ptr(lora_a_weights, device): - key = tuple(lora_weight.data_ptr() for lora_weight in lora_a_weights) if _LORA_PTR_DICT.get(key) is None: @@ -137,10 +146,19 @@ def _get_lora_ptr(lora_a_weights, device): lora_strides_d1.append(lora_a_weight.stride(1)) lora_strides_d2.append(lora_a_weight.stride(2)) - lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device) - lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device) - lora_strides_d1_tensor = torch.tensor(lora_strides_d1, device=device) - lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device) + if len(lora_a_weights) > 1: + lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device) + lora_strides_d0_tensor = torch.tensor(lora_strides_d0, + device=device) + lora_strides_d1_tensor = torch.tensor(lora_strides_d1, + device=device) + lora_strides_d2_tensor = torch.tensor(lora_strides_d2, + device=device) + else: + lora_ptr_tensor = lora_a_weights[0] + lora_strides_d0_tensor = lora_strides_d0[0] + lora_strides_d1_tensor = lora_strides_d1[0] + lora_strides_d2_tensor = lora_strides_d2[0] _LORA_PTR_DICT[key] = ( lora_ptr_tensor, lora_strides_d0_tensor, @@ -216,30 +234,14 @@ def _sgmv_shrink( batches, ) - _sgmv_shrink_kernel[grid]( - inputs, - lora_ptr_tensor, - output_tensor, - N, - K, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - scaling, - inputs.stride(0), - inputs.stride(1), - lora_strides_d0_tensor, - lora_strides_d1_tensor, - lora_strides_d2_tensor, - output_tensor.stride(0), - output_tensor.stride(1), - output_tensor.stride(2), - BLOCK_M, - BLOCK_N, - BLOCK_K, - EVEN_K, - SPLIT_K, - ) + _sgmv_shrink_kernel[grid](inputs, 
lora_ptr_tensor, output_tensor, N, K, + b_seq_start_loc, seq_len_tensor, + lora_indices_tensor, scaling, inputs.stride(0), + inputs.stride(1), lora_strides_d0_tensor, + lora_strides_d1_tensor, lora_strides_d2_tensor, + output_tensor.stride(0), output_tensor.stride(1), + output_tensor.stride(2), BLOCK_M, BLOCK_N, + BLOCK_K, EVEN_K, SPLIT_K, len(lora_a_weights)) return diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index c70300d32e2da..a27316ab1cd28 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -168,7 +168,7 @@ def add_expand(self, offset_start, add_inputs=True) else: - # TODO fuse these kernels + # TODO fuse these kernels for slice_idx in range(len(lora_b_stacked)): self._apply_expand_decode( y, From 43aae7025b24cef82adf8c758e3204c7e453836d Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Mon, 16 Dec 2024 15:38:37 +0000 Subject: [PATCH 10/35] Delete unused code Signed-off-by: Jee Jee Li --- vllm/lora/punica_wrapper/punica_gpu.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index a27316ab1cd28..076a563583b62 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -15,7 +15,6 @@ from vllm.lora.ops.bgmv_expand import bgmv_expand from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice from vllm.lora.ops.bgmv_shrink import bgmv_shrink - # from vllm.lora.ops.sgmv_expand import sgmv_expand from vllm.lora.ops.sgmv_expand import sgmv_expand from vllm.lora.ops.sgmv_shrink import sgmv_shrink From 482de154c1419408aa3f7e1a5c61b9c7640e79d5 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Mon, 16 Dec 2024 16:12:40 +0000 Subject: [PATCH 11/35] Refactor expand Signed-off-by: Jee Jee Li --- vllm/lora/ops/sgmv_expand.py | 161 ++++++++++++++++++----------------- vllm/lora/ops/sgmv_shrink.py | 10 +-- 2 files changed, 89 insertions(+), 82 deletions(-) diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index 74b428a32e3ff..44feeaa3b69e1 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -16,30 +16,30 @@ @triton.jit def _sgmv_expand_kernel( - input_ptr, - lora_ptr, - out_ptr, - N, - K, - b_seq_start_loc, - seq_lens, - lora_indices, - slice_start_loc, - input_d0_stride, - input_d1_stride, - input_d2_stride, # 1 - ls_d0_ptr, # lora stride(0) - ls_d1_ptr, - ls_d2_ptr, - cm_stride, - cn_stride, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - EVEN_K: tl.constexpr, - ADD_INPUTS: tl.constexpr, - CAST_TYPE: tl.constexpr, -): + input_ptr, + lora_ptr, + out_ptr, + N, + K, + b_seq_start_loc, + seq_lens, + lora_indices, + slice_start_loc, + input_d0_stride, + input_d1_stride, + input_d2_stride, # 1 + ls_d0_ptr, # lora stride(0) + ls_d1_ptr, + ls_d2_ptr, + cm_stride, + cn_stride, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, + SLICE_NUM: tl.constexpr): """ Similar to the 'sgmv_expand' operator, but with an added parameter @@ -50,7 +50,10 @@ def _sgmv_expand_kernel( """ pid = tl.program_id(axis=0) cur_batch = tl.program_id(axis=1) - slice_id = tl.program_id(axis=2) + if SLICE_NUM == 1: + slice_id: tl.constexpr = 0 + else: + slice_id = tl.program_id(axis=2) cta_n_num = tl.cdiv(N, BLOCK_N) pid_m = pid // cta_n_num pid_n = pid % cta_n_num @@ -69,16 +72,26 @@ def _sgmv_expand_kernel( rbn = 
tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) # input - cur_input_ptr = input_ptr + slice_id * input_d0_stride - a_ptr = (cur_input_ptr + cur_seq_start * input_d1_stride + - ram[:, None] * input_d1_stride + - offset_k[None, :] * input_d2_stride, ) - # lora - cur_lora_ptr = tl.load(lora_ptr + slice_id).to( - tl.pointer_type(out_ptr.dtype.element_ty)) - cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) - cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) - cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) + if SLICE_NUM == 1: + a_ptr = (input_ptr + cur_seq_start * input_d1_stride + + ram[:, None] * input_d1_stride + + offset_k[None, :] * input_d2_stride, ) + # lora + cur_lora_ptr = lora_ptr + cur_lora_d0_stride = ls_d0_ptr + cur_lora_d1_stride = ls_d1_ptr + cur_lora_d2_stride = ls_d2_ptr + else: + cur_input_ptr = input_ptr + slice_id * input_d0_stride + a_ptr = (cur_input_ptr + cur_seq_start * input_d1_stride + + ram[:, None] * input_d1_stride + + offset_k[None, :] * input_d2_stride, ) + # lora + cur_lora_ptr = tl.load(lora_ptr + slice_id).to( + tl.pointer_type(out_ptr.dtype.element_ty)) + cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) + cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) + cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + offset_k[:, None] * cur_lora_d2_stride + @@ -105,8 +118,10 @@ def _sgmv_expand_kernel( b_ptr += BLOCK_K * cur_lora_d2_stride tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty) - - cur_slice_start = tl.load(slice_start_loc + slice_id) + if SLICE_NUM == 1: + cur_slice_start = slice_start_loc + else: + cur_slice_start = tl.load(slice_start_loc + slice_id) offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start @@ -149,14 +164,22 @@ def _get_lora_ptr(lora_weights, offset_start, device): lora_strides_d2.append(lora_b_weight.stride(2)) slice_offset_lst.append(slice_offset) slice_offset += lora_b_weight.size(1) - - slice_start_tensor = torch.tensor(slice_offset_lst, device=device) - # note these are device tensors - lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device) - lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device) - lora_strides_d1_tensor = torch.tensor(lora_strides_d1, device=device) - lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device) - + if len(lora_weights) > 1: + # note these are device tensors + slice_start_tensor = torch.tensor(slice_offset_lst, device=device) + lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device) + lora_strides_d0_tensor = torch.tensor(lora_strides_d0, + device=device) + lora_strides_d1_tensor = torch.tensor(lora_strides_d1, + device=device) + lora_strides_d2_tensor = torch.tensor(lora_strides_d2, + device=device) + else: + slice_start_tensor = slice_offset_lst[0] + lora_ptr_tensor = lora_weights[0] + lora_strides_d0_tensor = lora_strides_d0[0] + lora_strides_d1_tensor = lora_strides_d1[0] + lora_strides_d2_tensor = lora_strides_d2[0] _LORA_PTR_DICT[key] = ( slice_start_tensor, lora_ptr_tensor, @@ -170,7 +193,7 @@ def _get_lora_ptr(lora_weights, offset_start, device): @torch.inference_mode() def _sgmv_expand( inputs: torch.Tensor, - lora_b_stacked: List[torch.Tensor], + lora_b_weights: List[torch.Tensor], output_tensor: torch.Tensor, b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor, @@ -182,13 +205,13 @@ def _sgmv_expand( add_inputs: bool = False, ) -> None: assert inputs.dtype in [torch.float16, 
torch.bfloat16, torch.float32] - assert lora_b_stacked[0].dtype in [ + assert lora_b_weights[0].dtype in [ torch.float16, torch.bfloat16, ] assert inputs.size(1) == token_nums - assert inputs.size(0) == len(lora_b_stacked) + assert inputs.size(0) == len(lora_b_weights) assert b_seq_start_loc.size(0) == batches assert lora_indices_tensor.size(0) == batches @@ -199,10 +222,10 @@ def _sgmv_expand( lora_strides_d0_tensor, lora_strides_d1_tensor, lora_strides_d2_tensor, - ) = _get_lora_ptr(lora_b_stacked, offset_start, b_seq_start_loc.device) + ) = _get_lora_ptr(lora_b_weights, offset_start, b_seq_start_loc.device) # TODO tuning this config - N, K = lora_b_stacked[0].shape[-2:] # K= rank,N=hidden_size + N, K = lora_b_weights[0].shape[-2:] # K= rank,N=hidden_size BLOCK_M = 64 BLOCK_N = 64 @@ -211,7 +234,7 @@ def _sgmv_expand( ADD_INPUTS = add_inputs CAST_TYPE = False - if inputs.dtype == torch.float32 and lora_b_stacked[0].dtype in [ + if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [ torch.float16, torch.bfloat16, ]: @@ -221,37 +244,21 @@ def _sgmv_expand( batches, len(lora_ptr_tensor), ) - _sgmv_expand_kernel[grid]( - inputs, - lora_ptr_tensor, - output_tensor, - N, - K, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - slice_start_tensor, - inputs.stride(0), - inputs.stride(1), - inputs.stride(2), - lora_strides_d0_tensor, - lora_strides_d1_tensor, - lora_strides_d2_tensor, - output_tensor.stride(0), - output_tensor.stride(1), - BLOCK_M, - BLOCK_N, - BLOCK_K, - EVEN_K, - ADD_INPUTS, - CAST_TYPE, - ) + _sgmv_expand_kernel[grid](inputs, lora_ptr_tensor, output_tensor, N, K, + b_seq_start_loc, seq_len_tensor, + lora_indices_tensor, slice_start_tensor, + inputs.stride(0), inputs.stride(1), + inputs.stride(2), lora_strides_d0_tensor, + lora_strides_d1_tensor, lora_strides_d2_tensor, + output_tensor.stride(0), output_tensor.stride(1), + BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, ADD_INPUTS, + CAST_TYPE, len(lora_b_weights)) return def _sgmv_expand_fake( inputs: torch.Tensor, - lora_b_stacked: List[torch.Tensor], + lora_b_weights: List[torch.Tensor], output_tensor: torch.Tensor, b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor, diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py index c0ffdf1bfd1e2..6e3aa72ca8ea9 100644 --- a/vllm/lora/ops/sgmv_shrink.py +++ b/vllm/lora/ops/sgmv_shrink.py @@ -38,7 +38,7 @@ def _sgmv_shrink_kernel( BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr, - NSLICE_NUM: tl.constexpr): + SLICE_NUM: tl.constexpr): """ The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K. The GEMM of Multi-LoRA can be considered as GroupGEMM. 
Additionally, @@ -50,8 +50,8 @@ def _sgmv_shrink_kernel( cta_n_num = tl.cdiv(N, BLOCK_N) pid_m = pid // cta_n_num pid_n = pid % cta_n_num - if NSLICE_NUM == 1: - slice_id = 0 + if SLICE_NUM == 1: + slice_id: tl.constexpr = 0 pid_sk = tl.program_id(axis=1) else: pid_mix = tl.program_id(axis=1) @@ -74,7 +74,7 @@ def _sgmv_shrink_kernel( a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + offset_k[None, :] * xk_stride) - if NSLICE_NUM == 1: + if SLICE_NUM == 1: cur_lora_ptr = lora_ptr cur_lora_d0_stride = ls_d0_ptr cur_lora_d1_stride = ls_d1_ptr @@ -110,7 +110,7 @@ def _sgmv_shrink_kernel( offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N - cur_out_ptr = out_ptr if NSLICE_NUM == 1 else out_ptr + slice_id * c0_stride + cur_out_ptr = out_ptr if SLICE_NUM == 1 else out_ptr + slice_id * c0_stride c_ptr = cur_out_ptr + offset_cm[:, None] * cm_stride + offset_cn[ None, :] * cn_stride c_mask = (offset_cm[:, None] < From 259d382f1e93440b0c827eae32e93d9b41fdd6f3 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Mon, 16 Dec 2024 16:32:55 +0000 Subject: [PATCH 12/35] format Signed-off-by: Jee Jee Li --- vllm/lora/ops/sgmv_expand.py | 39 ++++++++++++++++++++++++++---------- vllm/lora/ops/sgmv_shrink.py | 37 +++++++++++++++++++++++++--------- 2 files changed, 55 insertions(+), 21 deletions(-) diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index 44feeaa3b69e1..cbde5d5ab47a2 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -141,7 +141,7 @@ def _sgmv_expand_kernel( #TODO Optimize -def _get_lora_ptr(lora_weights, offset_start, device): +def _get_lora_b_ptr(lora_weights, offset_start, device): key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights) if _LORA_PTR_DICT.get(key) is None: @@ -222,7 +222,7 @@ def _sgmv_expand( lora_strides_d0_tensor, lora_strides_d1_tensor, lora_strides_d2_tensor, - ) = _get_lora_ptr(lora_b_weights, offset_start, b_seq_start_loc.device) + ) = _get_lora_b_ptr(lora_b_weights, offset_start, b_seq_start_loc.device) # TODO tuning this config N, K = lora_b_weights[0].shape[-2:] # K= rank,N=hidden_size @@ -244,15 +244,32 @@ def _sgmv_expand( batches, len(lora_ptr_tensor), ) - _sgmv_expand_kernel[grid](inputs, lora_ptr_tensor, output_tensor, N, K, - b_seq_start_loc, seq_len_tensor, - lora_indices_tensor, slice_start_tensor, - inputs.stride(0), inputs.stride(1), - inputs.stride(2), lora_strides_d0_tensor, - lora_strides_d1_tensor, lora_strides_d2_tensor, - output_tensor.stride(0), output_tensor.stride(1), - BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, ADD_INPUTS, - CAST_TYPE, len(lora_b_weights)) + _sgmv_expand_kernel[grid]( + inputs, + lora_ptr_tensor, + output_tensor, + N, + K, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + slice_start_tensor, + inputs.stride(0), + inputs.stride(1), + inputs.stride(2), + lora_strides_d0_tensor, + lora_strides_d1_tensor, + lora_strides_d2_tensor, + output_tensor.stride(0), + output_tensor.stride(1), + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + ADD_INPUTS, + CAST_TYPE, + len(lora_b_weights), + ) return diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py index 6e3aa72ca8ea9..c539fcb36d679 100644 --- a/vllm/lora/ops/sgmv_shrink.py +++ b/vllm/lora/ops/sgmv_shrink.py @@ -126,7 +126,7 @@ def _sgmv_shrink_kernel( _LORA_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} -def _get_lora_ptr(lora_a_weights, device): +def _get_lora_a_ptr(lora_a_weights, device): key = 
tuple(lora_weight.data_ptr() for lora_weight in lora_a_weights) if _LORA_PTR_DICT.get(key) is None: @@ -220,7 +220,7 @@ def _sgmv_shrink( lora_strides_d0_tensor, lora_strides_d1_tensor, lora_strides_d2_tensor, - ) = _get_lora_ptr(lora_a_weights, b_seq_start_loc.device) + ) = _get_lora_a_ptr(lora_a_weights, b_seq_start_loc.device) # TODO tuning this config N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank BLOCK_M = 32 @@ -234,14 +234,31 @@ def _sgmv_shrink( batches, ) - _sgmv_shrink_kernel[grid](inputs, lora_ptr_tensor, output_tensor, N, K, - b_seq_start_loc, seq_len_tensor, - lora_indices_tensor, scaling, inputs.stride(0), - inputs.stride(1), lora_strides_d0_tensor, - lora_strides_d1_tensor, lora_strides_d2_tensor, - output_tensor.stride(0), output_tensor.stride(1), - output_tensor.stride(2), BLOCK_M, BLOCK_N, - BLOCK_K, EVEN_K, SPLIT_K, len(lora_a_weights)) + _sgmv_shrink_kernel[grid]( + inputs, + lora_ptr_tensor, + output_tensor, + N, + K, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + scaling, + inputs.stride(0), + inputs.stride(1), + lora_strides_d0_tensor, + lora_strides_d1_tensor, + lora_strides_d2_tensor, + output_tensor.stride(0), + output_tensor.stride(1), + output_tensor.stride(2), + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + SPLIT_K, + len(lora_a_weights), + ) return From a0197e3e4f6a17f9d34ffcd39c8c254aecd1822d Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 17 Dec 2024 03:55:21 +0000 Subject: [PATCH 13/35] Optimize logic Signed-off-by: Jee Jee Li --- vllm/lora/ops/sgmv_expand.py | 48 +++++++++++++++++++++++++----------- vllm/lora/ops/sgmv_shrink.py | 44 +++++++++++++++++++++------------ 2 files changed, 62 insertions(+), 30 deletions(-) diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index cbde5d5ab47a2..1f4646f05b6b7 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -39,7 +39,8 @@ def _sgmv_expand_kernel( EVEN_K: tl.constexpr, ADD_INPUTS: tl.constexpr, CAST_TYPE: tl.constexpr, - SLICE_NUM: tl.constexpr): + SLICE_NUM: tl.constexpr, + SAME_STRIDE: tl.constexpr): """ Similar to the 'sgmv_expand' operator, but with an added parameter @@ -71,17 +72,25 @@ def _sgmv_expand_kernel( ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) - # input + if SAME_STRIDE: + cur_lora_d0_stride = ls_d0_ptr + cur_lora_d1_stride = ls_d1_ptr + cur_lora_d2_stride = ls_d2_ptr + else: + cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) + cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) + cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) + if SLICE_NUM == 1: + # input a_ptr = (input_ptr + cur_seq_start * input_d1_stride + ram[:, None] * input_d1_stride + offset_k[None, :] * input_d2_stride, ) # lora cur_lora_ptr = lora_ptr - cur_lora_d0_stride = ls_d0_ptr - cur_lora_d1_stride = ls_d1_ptr - cur_lora_d2_stride = ls_d2_ptr + else: + # input cur_input_ptr = input_ptr + slice_id * input_d0_stride a_ptr = (cur_input_ptr + cur_seq_start * input_d1_stride + ram[:, None] * input_d1_stride + @@ -89,9 +98,6 @@ def _sgmv_expand_kernel( # lora cur_lora_ptr = tl.load(lora_ptr + slice_id).to( tl.pointer_type(out_ptr.dtype.element_ty)) - cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) - cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) - cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + offset_k[:, None] * cur_lora_d2_stride + @@ -164,28 +170,40 @@ def _get_lora_b_ptr(lora_weights, 
offset_start, device): lora_strides_d2.append(lora_b_weight.stride(2)) slice_offset_lst.append(slice_offset) slice_offset += lora_b_weight.size(1) + if len(lora_weights) > 1: # note these are device tensors slice_start_tensor = torch.tensor(slice_offset_lst, device=device) lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device) - lora_strides_d0_tensor = torch.tensor(lora_strides_d0, - device=device) - lora_strides_d1_tensor = torch.tensor(lora_strides_d1, - device=device) - lora_strides_d2_tensor = torch.tensor(lora_strides_d2, - device=device) + else: slice_start_tensor = slice_offset_lst[0] lora_ptr_tensor = lora_weights[0] + + # If each lora has the same stride, there's no need to use a + # tensor for storage. + if ((set(lora_strides_d0) == 1) and (set(lora_strides_d1) == 1) + and (set(lora_strides_d2) == 1)): lora_strides_d0_tensor = lora_strides_d0[0] lora_strides_d1_tensor = lora_strides_d1[0] lora_strides_d2_tensor = lora_strides_d2[0] + same_stride = True + else: + lora_strides_d0_tensor = torch.tensor(lora_strides_d0, + device=device) + lora_strides_d1_tensor = torch.tensor(lora_strides_d1, + device=device) + lora_strides_d2_tensor = torch.tensor(lora_strides_d2, + device=device) + same_stride = False + _LORA_PTR_DICT[key] = ( slice_start_tensor, lora_ptr_tensor, lora_strides_d0_tensor, lora_strides_d1_tensor, lora_strides_d2_tensor, + same_stride, ) return _LORA_PTR_DICT.get(key) @@ -222,6 +240,7 @@ def _sgmv_expand( lora_strides_d0_tensor, lora_strides_d1_tensor, lora_strides_d2_tensor, + same_stride, ) = _get_lora_b_ptr(lora_b_weights, offset_start, b_seq_start_loc.device) # TODO tuning this config @@ -269,6 +288,7 @@ def _sgmv_expand( ADD_INPUTS, CAST_TYPE, len(lora_b_weights), + same_stride, ) return diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py index c539fcb36d679..dcbac4b82d2bf 100644 --- a/vllm/lora/ops/sgmv_shrink.py +++ b/vllm/lora/ops/sgmv_shrink.py @@ -38,7 +38,8 @@ def _sgmv_shrink_kernel( BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr, - SLICE_NUM: tl.constexpr): + SLICE_NUM: tl.constexpr, + SAME_STRIDE: tl.constexpr): """ The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K. The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally, @@ -74,18 +75,21 @@ def _sgmv_shrink_kernel( a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + offset_k[None, :] * xk_stride) - if SLICE_NUM == 1: - cur_lora_ptr = lora_ptr + if SAME_STRIDE: cur_lora_d0_stride = ls_d0_ptr cur_lora_d1_stride = ls_d1_ptr cur_lora_d2_stride = ls_d2_ptr else: - cur_lora_ptr = tl.load(lora_ptr + slice_id).to( - tl.pointer_type(input_ptr.dtype.element_ty)) cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) + if SLICE_NUM == 1: + cur_lora_ptr = lora_ptr + else: + cur_lora_ptr = tl.load(lora_ptr + slice_id).to( + tl.pointer_type(input_ptr.dtype.element_ty)) + b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + rbn[None, :] * cur_lora_d1_stride + offset_k[:, None] * cur_lora_d2_stride) @@ -148,23 +152,29 @@ def _get_lora_a_ptr(lora_a_weights, device): if len(lora_a_weights) > 1: lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device) + + else: + lora_ptr_tensor = lora_a_weights[0] + # If each lora has the same stride, there's no need to use a + # tensor for storage. 
+ if ((set(lora_strides_d0) == 1) and (set(lora_strides_d1) == 1) + and (set(lora_strides_d2) == 1)): + lora_strides_d0_tensor = lora_strides_d0[0] + lora_strides_d1_tensor = lora_strides_d1[0] + lora_strides_d2_tensor = lora_strides_d2[0] + same_stride = True + else: lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device) lora_strides_d1_tensor = torch.tensor(lora_strides_d1, device=device) lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device) - else: - lora_ptr_tensor = lora_a_weights[0] - lora_strides_d0_tensor = lora_strides_d0[0] - lora_strides_d1_tensor = lora_strides_d1[0] - lora_strides_d2_tensor = lora_strides_d2[0] - _LORA_PTR_DICT[key] = ( - lora_ptr_tensor, - lora_strides_d0_tensor, - lora_strides_d1_tensor, - lora_strides_d2_tensor, - ) + same_stride = False + + _LORA_PTR_DICT[key] = (lora_ptr_tensor, lora_strides_d0_tensor, + lora_strides_d1_tensor, lora_strides_d2_tensor, + same_stride) return _LORA_PTR_DICT.get(key) @@ -220,6 +230,7 @@ def _sgmv_shrink( lora_strides_d0_tensor, lora_strides_d1_tensor, lora_strides_d2_tensor, + same_stride, ) = _get_lora_a_ptr(lora_a_weights, b_seq_start_loc.device) # TODO tuning this config N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank @@ -258,6 +269,7 @@ def _sgmv_shrink( EVEN_K, SPLIT_K, len(lora_a_weights), + same_stride, ) return From 38ba4f1c71dc27e8d188a199ad4f2e635e11bb57 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 17 Dec 2024 04:53:29 +0000 Subject: [PATCH 14/35] Add comments Signed-off-by: Jee Jee Li --- vllm/lora/ops/sgmv_expand.py | 18 ++++++------------ vllm/lora/ops/sgmv_shrink.py | 4 +++- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index 1f4646f05b6b7..e92d44d1e9781 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -80,13 +80,13 @@ def _sgmv_expand_kernel( cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) - + if SLICE_NUM == 1: # input a_ptr = (input_ptr + cur_seq_start * input_d1_stride + ram[:, None] * input_d1_stride + offset_k[None, :] * input_d2_stride, ) - # lora + # current lora ptr cur_lora_ptr = lora_ptr else: @@ -95,7 +95,7 @@ def _sgmv_expand_kernel( a_ptr = (cur_input_ptr + cur_seq_start * input_d1_stride + ram[:, None] * input_d1_stride + offset_k[None, :] * input_d2_stride, ) - # lora + # current lora ptr cur_lora_ptr = tl.load(lora_ptr + slice_id).to( tl.pointer_type(out_ptr.dtype.element_ty)) @@ -146,7 +146,6 @@ def _sgmv_expand_kernel( _LORA_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} -#TODO Optimize def _get_lora_b_ptr(lora_weights, offset_start, device): key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights) @@ -197,14 +196,9 @@ def _get_lora_b_ptr(lora_weights, offset_start, device): device=device) same_stride = False - _LORA_PTR_DICT[key] = ( - slice_start_tensor, - lora_ptr_tensor, - lora_strides_d0_tensor, - lora_strides_d1_tensor, - lora_strides_d2_tensor, - same_stride, - ) + _LORA_PTR_DICT[key] = (slice_start_tensor, lora_ptr_tensor, + lora_strides_d0_tensor, lora_strides_d1_tensor, + lora_strides_d2_tensor, same_stride) return _LORA_PTR_DICT.get(key) diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py index dcbac4b82d2bf..496230967b26a 100644 --- a/vllm/lora/ops/sgmv_shrink.py +++ b/vllm/lora/ops/sgmv_shrink.py @@ -72,7 +72,7 @@ def _sgmv_shrink_kernel( ram = 
tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) - + # input ptr a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + offset_k[None, :] * xk_stride) if SAME_STRIDE: @@ -85,8 +85,10 @@ def _sgmv_shrink_kernel( cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) if SLICE_NUM == 1: + # current lora ptr cur_lora_ptr = lora_ptr else: + # current lora ptr cur_lora_ptr = tl.load(lora_ptr + slice_id).to( tl.pointer_type(input_ptr.dtype.element_ty)) From 3c372265bb639015d155382311fb0d85d16abf5b Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 17 Dec 2024 05:24:15 +0000 Subject: [PATCH 15/35] Fix bug Signed-off-by: Jee Jee Li --- vllm/lora/ops/sgmv_expand.py | 7 ++++--- vllm/lora/ops/sgmv_shrink.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index e92d44d1e9781..2946f03662d79 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -181,12 +181,13 @@ def _get_lora_b_ptr(lora_weights, offset_start, device): # If each lora has the same stride, there's no need to use a # tensor for storage. - if ((set(lora_strides_d0) == 1) and (set(lora_strides_d1) == 1) - and (set(lora_strides_d2) == 1)): + if (len(set(lora_strides_d0)) == 1 and len(set(lora_strides_d1)) == 1 + and len(set(lora_strides_d2)) == 1): lora_strides_d0_tensor = lora_strides_d0[0] lora_strides_d1_tensor = lora_strides_d1[0] lora_strides_d2_tensor = lora_strides_d2[0] same_stride = True + else: lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device) @@ -195,7 +196,7 @@ def _get_lora_b_ptr(lora_weights, offset_start, device): lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device) same_stride = False - + print(f"same_stride:{same_stride}") _LORA_PTR_DICT[key] = (slice_start_tensor, lora_ptr_tensor, lora_strides_d0_tensor, lora_strides_d1_tensor, lora_strides_d2_tensor, same_stride) diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py index 496230967b26a..fd3e36bf53e8f 100644 --- a/vllm/lora/ops/sgmv_shrink.py +++ b/vllm/lora/ops/sgmv_shrink.py @@ -159,8 +159,8 @@ def _get_lora_a_ptr(lora_a_weights, device): lora_ptr_tensor = lora_a_weights[0] # If each lora has the same stride, there's no need to use a # tensor for storage. 
- if ((set(lora_strides_d0) == 1) and (set(lora_strides_d1) == 1) - and (set(lora_strides_d2) == 1)): + if (len(set(lora_strides_d0)) == 1 and len(set(lora_strides_d1)) == 1 + and len(set(lora_strides_d2)) == 1): lora_strides_d0_tensor = lora_strides_d0[0] lora_strides_d1_tensor = lora_strides_d1[0] lora_strides_d2_tensor = lora_strides_d2[0] From 45180c13a20032f9745b6a330c1ea506a9d0ed7b Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 17 Dec 2024 10:00:26 +0000 Subject: [PATCH 16/35] Fix expand bug Signed-off-by: Jee Jee Li --- vllm/lora/ops/sgmv_expand.py | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index 2946f03662d79..8096d3c517b93 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -51,10 +51,8 @@ def _sgmv_expand_kernel( """ pid = tl.program_id(axis=0) cur_batch = tl.program_id(axis=1) - if SLICE_NUM == 1: - slice_id: tl.constexpr = 0 - else: - slice_id = tl.program_id(axis=2) + + slice_id = tl.program_id(axis=2) cta_n_num = tl.cdiv(N, BLOCK_N) pid_m = pid // cta_n_num pid_n = pid % cta_n_num @@ -80,25 +78,18 @@ def _sgmv_expand_kernel( cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) - if SLICE_NUM == 1: - # input - a_ptr = (input_ptr + cur_seq_start * input_d1_stride + - ram[:, None] * input_d1_stride + - offset_k[None, :] * input_d2_stride, ) - # current lora ptr + cur_input_ptr = input_ptr cur_lora_ptr = lora_ptr else: - # input cur_input_ptr = input_ptr + slice_id * input_d0_stride - a_ptr = (cur_input_ptr + cur_seq_start * input_d1_stride + - ram[:, None] * input_d1_stride + - offset_k[None, :] * input_d2_stride, ) - # current lora ptr cur_lora_ptr = tl.load(lora_ptr + slice_id).to( tl.pointer_type(out_ptr.dtype.element_ty)) + a_ptr = (cur_input_ptr + cur_seq_start * input_d1_stride + + ram[:, None] * input_d1_stride + + offset_k[None, :] * input_d2_stride, ) b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + offset_k[:, None] * cur_lora_d2_stride + rbn[None, :] * cur_lora_d1_stride) @@ -171,13 +162,12 @@ def _get_lora_b_ptr(lora_weights, offset_start, device): slice_offset += lora_b_weight.size(1) if len(lora_weights) > 1: - # note these are device tensors - slice_start_tensor = torch.tensor(slice_offset_lst, device=device) + #note these are device tensors lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device) - + slice_start_tensor = torch.tensor(slice_offset_lst, device=device) else: slice_start_tensor = slice_offset_lst[0] - lora_ptr_tensor = lora_weights[0] + lora_ptr_tensor = lora_b_weight[0] # If each lora has the same stride, there's no need to use a # tensor for storage. 
@@ -196,7 +186,6 @@ def _get_lora_b_ptr(lora_weights, offset_start, device): lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device) same_stride = False - print(f"same_stride:{same_stride}") _LORA_PTR_DICT[key] = (slice_start_tensor, lora_ptr_tensor, lora_strides_d0_tensor, lora_strides_d1_tensor, lora_strides_d2_tensor, same_stride) @@ -256,7 +245,7 @@ def _sgmv_expand( grid = ( triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), batches, - len(lora_ptr_tensor), + len(lora_b_weights), ) _sgmv_expand_kernel[grid]( inputs, From 2e52d2c420a01997a2001fa0900f412bc1360ef7 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 17 Dec 2024 13:22:58 +0000 Subject: [PATCH 17/35] Backup Signed-off-by: Jee Jee Li --- vllm/lora/ops/sgmv_expand.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index 8096d3c517b93..cfef692f8abd6 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -162,7 +162,7 @@ def _get_lora_b_ptr(lora_weights, offset_start, device): slice_offset += lora_b_weight.size(1) if len(lora_weights) > 1: - #note these are device tensors + # note these are device tensors lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device) slice_start_tensor = torch.tensor(slice_offset_lst, device=device) else: From 2146141b459486cbbcc7b0bd0c7087a4ea67ab04 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 17 Dec 2024 14:41:48 +0000 Subject: [PATCH 18/35] revert expand tile size Signed-off-by: Jee Jee Li --- vllm/lora/ops/sgmv_expand.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index cfef692f8abd6..02bbcebfd4734 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -230,9 +230,9 @@ def _sgmv_expand( # TODO tuning this config N, K = lora_b_weights[0].shape[-2:] # K= rank,N=hidden_size - BLOCK_M = 64 - BLOCK_N = 64 - BLOCK_K = 32 + BLOCK_M = 32 + BLOCK_N = 32 + BLOCK_K = 16 EVEN_K = K % BLOCK_K == 0 ADD_INPUTS = add_inputs CAST_TYPE = False From 9719617c36e850270c7ba43e855bf4ec6728ce41 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 18 Dec 2024 08:39:13 +0000 Subject: [PATCH 19/35] Clean up code Signed-off-by: Jee Jee Li --- requirements-common.txt | 2 +- vllm/lora/ops/sgmv_expand.py | 62 +----------------- vllm/lora/ops/sgmv_shrink.py | 55 +--------------- vllm/lora/ops/utils.py | 122 ++++++++++++++++++++++++++++++++++- 4 files changed, 128 insertions(+), 113 deletions(-) diff --git a/requirements-common.txt b/requirements-common.txt index bd2b4b7a01668..43ebed471db2c 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -11,7 +11,7 @@ protobuf # Required by LlamaTokenizer. 
fastapi >= 0.107.0, < 0.113.0; python_version < '3.9' fastapi >= 0.107.0, != 0.113.*, != 0.114.0; python_version >= '3.9' aiohttp -openai >= 1.45.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field support) +openai >= 1.52.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field support) uvicorn[standard] pydantic >= 2.9 # Required for fastapi >= 0.113.0 pillow # Required for image processing diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index 02bbcebfd4734..4cb7f2dc7ccf0 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -5,7 +5,7 @@ https://arxiv.org/abs/2310.18547 """ -from typing import Dict, List, Tuple +from typing import List import torch import triton @@ -13,6 +13,8 @@ from vllm.utils import direct_register_custom_op +from .utils import _get_lora_b_ptr + @triton.jit def _sgmv_expand_kernel( @@ -134,64 +136,6 @@ def _sgmv_expand_kernel( tl.store(c_ptr, tiled_c, mask=c_mask) -_LORA_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} - - -def _get_lora_b_ptr(lora_weights, offset_start, device): - - key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights) - if _LORA_PTR_DICT.get(key) is None: - slice_offset_lst = [] - tensor_ptrs = [] - lora_strides_d0 = [] - lora_strides_d1 = [] - lora_strides_d2 = [] - slice_offset = offset_start - for lora_b_weight in lora_weights: - if lora_b_weight.ndim == 4: # shape:(lora_num,1,size,rank) - assert lora_b_weight.size(1) == 1 - lora_b_weight = lora_b_weight.squeeze(dim=1) - else: - assert lora_b_weight.ndim == 3 # shape:(lora_num,size,rank) - assert lora_b_weight.is_contiguous() - tensor_ptrs.append(lora_b_weight.data_ptr()) - lora_strides_d0.append(lora_b_weight.stride(0)) - lora_strides_d1.append(lora_b_weight.stride(1)) - lora_strides_d2.append(lora_b_weight.stride(2)) - slice_offset_lst.append(slice_offset) - slice_offset += lora_b_weight.size(1) - - if len(lora_weights) > 1: - # note these are device tensors - lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device) - slice_start_tensor = torch.tensor(slice_offset_lst, device=device) - else: - slice_start_tensor = slice_offset_lst[0] - lora_ptr_tensor = lora_b_weight[0] - - # If each lora has the same stride, there's no need to use a - # tensor for storage. 
- if (len(set(lora_strides_d0)) == 1 and len(set(lora_strides_d1)) == 1 - and len(set(lora_strides_d2)) == 1): - lora_strides_d0_tensor = lora_strides_d0[0] - lora_strides_d1_tensor = lora_strides_d1[0] - lora_strides_d2_tensor = lora_strides_d2[0] - same_stride = True - - else: - lora_strides_d0_tensor = torch.tensor(lora_strides_d0, - device=device) - lora_strides_d1_tensor = torch.tensor(lora_strides_d1, - device=device) - lora_strides_d2_tensor = torch.tensor(lora_strides_d2, - device=device) - same_stride = False - _LORA_PTR_DICT[key] = (slice_start_tensor, lora_ptr_tensor, - lora_strides_d0_tensor, lora_strides_d1_tensor, - lora_strides_d2_tensor, same_stride) - return _LORA_PTR_DICT.get(key) - - @torch.inference_mode() def _sgmv_expand( inputs: torch.Tensor, diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py index fd3e36bf53e8f..d1b5141443715 100644 --- a/vllm/lora/ops/sgmv_shrink.py +++ b/vllm/lora/ops/sgmv_shrink.py @@ -5,7 +5,7 @@ https://arxiv.org/abs/2310.18547 """ -from typing import Dict, List, Tuple +from typing import List import torch import triton @@ -13,6 +13,8 @@ from vllm.utils import direct_register_custom_op +from .utils import _get_lora_a_ptr + @triton.jit def _sgmv_shrink_kernel( @@ -129,57 +131,6 @@ def _sgmv_shrink_kernel( tl.atomic_add(c_ptr, accumulator, mask=c_mask) -_LORA_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} - - -def _get_lora_a_ptr(lora_a_weights, device): - key = tuple(lora_weight.data_ptr() for lora_weight in lora_a_weights) - - if _LORA_PTR_DICT.get(key) is None: - lora_strides_d0 = [] - lora_strides_d1 = [] - lora_strides_d2 = [] - tensor_ptrs = [] - for lora_a_weight in lora_a_weights: - if lora_a_weight.ndim == 4: # shape:(lora_num,1,size,rank) - assert lora_a_weight.size(1) == 1 - lora_a_weight = lora_a_weight.squeeze(dim=1) - else: - assert lora_a_weight.ndim == 3 # shape:(lora_num,size,rank) - assert lora_a_weight.is_contiguous() - tensor_ptrs.append(lora_a_weight.data_ptr()) - lora_strides_d0.append(lora_a_weight.stride(0)) - lora_strides_d1.append(lora_a_weight.stride(1)) - lora_strides_d2.append(lora_a_weight.stride(2)) - - if len(lora_a_weights) > 1: - lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device) - - else: - lora_ptr_tensor = lora_a_weights[0] - # If each lora has the same stride, there's no need to use a - # tensor for storage. 
- if (len(set(lora_strides_d0)) == 1 and len(set(lora_strides_d1)) == 1 - and len(set(lora_strides_d2)) == 1): - lora_strides_d0_tensor = lora_strides_d0[0] - lora_strides_d1_tensor = lora_strides_d1[0] - lora_strides_d2_tensor = lora_strides_d2[0] - same_stride = True - else: - lora_strides_d0_tensor = torch.tensor(lora_strides_d0, - device=device) - lora_strides_d1_tensor = torch.tensor(lora_strides_d1, - device=device) - lora_strides_d2_tensor = torch.tensor(lora_strides_d2, - device=device) - same_stride = False - - _LORA_PTR_DICT[key] = (lora_ptr_tensor, lora_strides_d0_tensor, - lora_strides_d1_tensor, lora_strides_d2_tensor, - same_stride) - return _LORA_PTR_DICT.get(key) - - @torch.inference_mode() def _sgmv_shrink( inputs: torch.Tensor, diff --git a/vllm/lora/ops/utils.py b/vllm/lora/ops/utils.py index 7c3e27313ad97..635840beb9a9f 100644 --- a/vllm/lora/ops/utils.py +++ b/vllm/lora/ops/utils.py @@ -1,5 +1,7 @@ import functools -from typing import Dict +from typing import Dict, Tuple + +import torch @functools.lru_cache @@ -44,3 +46,121 @@ def get_lora_op_configs(op_type: str, batch: int, if not config: config = _get_default_config(op_type, batch, hidden_size) return config + + +_LORA_A_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} +_LORA_B_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} + + +def _get_lora_a_ptr(lora_a_weights, device): + """ + `_LORA_A_PTR_DICT` collects the required information during `profile_run`, + and subsequent usage is through LUT. + Refer to: + https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py + """ + key = tuple(lora_weight.data_ptr() for lora_weight in lora_a_weights) + + if values := _LORA_A_PTR_DICT.get(key): + return values + + lora_strides_d0 = [] + lora_strides_d1 = [] + lora_strides_d2 = [] + tensor_ptrs = [] + for lora_a_weight in lora_a_weights: + if lora_a_weight.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_a_weight.size(1) == 1 + lora_a_weight = lora_a_weight.squeeze(dim=1) + else: + assert lora_a_weight.ndim == 3 # shape:(lora_num,size,rank) + assert lora_a_weight.is_contiguous() + tensor_ptrs.append(lora_a_weight.data_ptr()) + lora_strides_d0.append(lora_a_weight.stride(0)) + lora_strides_d1.append(lora_a_weight.stride(1)) + lora_strides_d2.append(lora_a_weight.stride(2)) + + if len(lora_a_weights) > 1: + lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device) + + else: + lora_ptr_tensor = lora_a_weights[0] + # If each lora has the same stride, there's no need to use a + # tensor for storage. + if (len(set(lora_strides_d0)) == 1 and len(set(lora_strides_d1)) == 1 + and len(set(lora_strides_d2)) == 1): + lora_strides_d0_tensor = lora_strides_d0[0] + lora_strides_d1_tensor = lora_strides_d1[0] + lora_strides_d2_tensor = lora_strides_d2[0] + same_stride = True + else: + lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device) + lora_strides_d1_tensor = torch.tensor(lora_strides_d1, device=device) + lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device) + same_stride = False + + _LORA_A_PTR_DICT[key] = (lora_ptr_tensor, lora_strides_d0_tensor, + lora_strides_d1_tensor, lora_strides_d2_tensor, + same_stride) + return _LORA_A_PTR_DICT.get(key) + + +def _get_lora_b_ptr(lora_weights, offset_start, device): + """ + `_LORA_B_PTR_DICT` collects the required information during `profile_run`, + and subsequent usage is through LUT. 
+ Refer to: + https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py + + """ + + key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights) + if values := _LORA_B_PTR_DICT.get(key): + return values + slice_offset_lst = [] + tensor_ptrs = [] + lora_strides_d0 = [] + lora_strides_d1 = [] + lora_strides_d2 = [] + slice_offset = offset_start + for lora_b_weight in lora_weights: + if lora_b_weight.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weight.size(1) == 1 + lora_b_weight = lora_b_weight.squeeze(dim=1) + else: + assert lora_b_weight.ndim == 3 # shape:(lora_num,size,rank) + assert lora_b_weight.is_contiguous() + tensor_ptrs.append(lora_b_weight.data_ptr()) + lora_strides_d0.append(lora_b_weight.stride(0)) + lora_strides_d1.append(lora_b_weight.stride(1)) + lora_strides_d2.append(lora_b_weight.stride(2)) + slice_offset_lst.append(slice_offset) + slice_offset += lora_b_weight.size(1) + + if len(lora_weights) > 1: + # note these are device tensors + lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device) + slice_start_tensor = torch.tensor(slice_offset_lst, device=device) + else: + slice_start_tensor = slice_offset_lst[0] + lora_ptr_tensor = lora_b_weight[0] + + # If each lora has the same stride, there's no need to use a + # tensor for storage. + if (len(set(lora_strides_d0)) == 1 and len(set(lora_strides_d1)) == 1 + and len(set(lora_strides_d2)) == 1): + lora_strides_d0_tensor = lora_strides_d0[0] + lora_strides_d1_tensor = lora_strides_d1[0] + lora_strides_d2_tensor = lora_strides_d2[0] + same_stride = True + + else: + lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device) + lora_strides_d1_tensor = torch.tensor(lora_strides_d1, device=device) + lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device) + same_stride = False + + _LORA_B_PTR_DICT[key] = (slice_start_tensor, lora_ptr_tensor, + lora_strides_d0_tensor, lora_strides_d1_tensor, + lora_strides_d2_tensor, same_stride) + return _LORA_B_PTR_DICT.get(key) From 5d2c557a55e0a6fa3dd736c054e72c150de6ff1c Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 18 Dec 2024 09:54:04 +0000 Subject: [PATCH 20/35] Optimize expand tile size Signed-off-by: Jee Jee Li --- vllm/lora/ops/sgmv_expand.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index 4cb7f2dc7ccf0..6a5e1d697c236 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -174,8 +174,8 @@ def _sgmv_expand( # TODO tuning this config N, K = lora_b_weights[0].shape[-2:] # K= rank,N=hidden_size - BLOCK_M = 32 - BLOCK_N = 32 + BLOCK_M = 64 + BLOCK_N = 128 BLOCK_K = 16 EVEN_K = K % BLOCK_K == 0 ADD_INPUTS = add_inputs From 3460308fba44b7449a5f04798215095c38dc5034 Mon Sep 17 00:00:00 2001 From: Zhonghua Deng Date: Thu, 19 Dec 2024 21:17:25 +0800 Subject: [PATCH 21/35] improve expand (#3) Signed-off-by: Abatom --- vllm/lora/ops/sgmv_expand.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index 6a5e1d697c236..c1f100c541e38 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -53,12 +53,20 @@ def _sgmv_expand_kernel( """ pid = tl.program_id(axis=0) cur_batch = tl.program_id(axis=1) - slice_id = tl.program_id(axis=2) - cta_n_num = tl.cdiv(N, BLOCK_N) - pid_m = pid // cta_n_num - pid_n = pid % cta_n_num + M = tl.load(seq_lens + cur_batch) + + num_pid_m = tl.cdiv(M, BLOCK_M) + 
num_pid_n = tl.cdiv(N, BLOCK_N) + GROUP_SIZE_M: tl.constexpr = 1 + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + if pid_m * BLOCK_M > M: return lora_index = tl.load(lora_indices + cur_batch) From c9747c632e96733e6d7ef78ca5a9d4a4816f569a Mon Sep 17 00:00:00 2001 From: Zhonghua Deng Date: Fri, 20 Dec 2024 11:48:28 +0800 Subject: [PATCH 22/35] Lora expand (#4) * L2 Signed-off-by: Abatom * L2 Signed-off-by: Abatom --------- Signed-off-by: Abatom --- vllm/lora/ops/sgmv_expand.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index c1f100c541e38..36dbbc54d62f2 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -57,15 +57,16 @@ def _sgmv_expand_kernel( M = tl.load(seq_lens + cur_batch) - num_pid_m = tl.cdiv(M, BLOCK_M) - num_pid_n = tl.cdiv(N, BLOCK_N) - GROUP_SIZE_M: tl.constexpr = 1 - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + GROUP_M: tl.constexpr = 1 + width = GROUP_M * grid_n + group_id = pid // width + first_pid_m = group_id * GROUP_M + group_idx = pid % width + group_size_m = min(grid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (group_idx % group_size_m) + pid_n = group_idx // group_size_m if pid_m * BLOCK_M > M: return @@ -79,6 +80,7 @@ def _sgmv_expand_kernel( offset_k = tl.arange(0, BLOCK_K) ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + rak = tl.max_contiguous(tl.multiple_of(offset_k % K, BLOCK_K), BLOCK_K) if SAME_STRIDE: cur_lora_d0_stride = ls_d0_ptr @@ -99,9 +101,9 @@ def _sgmv_expand_kernel( a_ptr = (cur_input_ptr + cur_seq_start * input_d1_stride + ram[:, None] * input_d1_stride + - offset_k[None, :] * input_d2_stride, ) + rak[None, :] * input_d2_stride, ) b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + - offset_k[:, None] * cur_lora_d2_stride + + rak[:, None] * cur_lora_d2_stride + rbn[None, :] * cur_lora_d1_stride) accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(tl.cdiv(K, BLOCK_K)): @@ -110,10 +112,10 @@ def _sgmv_expand_kernel( tiled_b = tl.load(b_ptr) else: tiled_a = tl.load(a_ptr, - mask=offset_k[None, :] < K - k * BLOCK_K, + mask=rak[None, :] < K - k * BLOCK_K, other=0) tiled_b = tl.load(b_ptr, - mask=offset_k[:, None] < K - k * BLOCK_K, + mask=rak[:, None] < K - k * BLOCK_K, other=0) if CAST_TYPE: tiled_a = tiled_a.to(cur_lora_ptr.dtype.element_ty) From f3ecfc64acff2232421e8cb3a5d4f71d524f064e Mon Sep 17 00:00:00 2001 From: Zhonghua Deng Date: Fri, 20 Dec 2024 14:01:27 +0800 Subject: [PATCH 23/35] Lora expand (#5) * L2 Signed-off-by: Abatom * L2 Signed-off-by: Abatom * L2 Signed-off-by: Abatom --------- Signed-off-by: Abatom From 5859da77873558a594571a98ab0a3f084e2b2642 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 20 Dec 2024 06:56:29 +0000 Subject: [PATCH 24/35] Fix K size Signed-off-by: Jee Jee Li --- 
vllm/lora/ops/sgmv_expand.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index 36dbbc54d62f2..cbfeb0ec6de02 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -80,7 +80,6 @@ def _sgmv_expand_kernel( offset_k = tl.arange(0, BLOCK_K) ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) - rak = tl.max_contiguous(tl.multiple_of(offset_k % K, BLOCK_K), BLOCK_K) if SAME_STRIDE: cur_lora_d0_stride = ls_d0_ptr @@ -101,9 +100,9 @@ def _sgmv_expand_kernel( a_ptr = (cur_input_ptr + cur_seq_start * input_d1_stride + ram[:, None] * input_d1_stride + - rak[None, :] * input_d2_stride, ) + offset_k[None, :] * input_d2_stride, ) b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + - rak[:, None] * cur_lora_d2_stride + + offset_k[:, None] * cur_lora_d2_stride + rbn[None, :] * cur_lora_d1_stride) accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(tl.cdiv(K, BLOCK_K)): @@ -112,10 +111,10 @@ def _sgmv_expand_kernel( tiled_b = tl.load(b_ptr) else: tiled_a = tl.load(a_ptr, - mask=rak[None, :] < K - k * BLOCK_K, + mask=offset_k[None, :] < K - k * BLOCK_K, other=0) tiled_b = tl.load(b_ptr, - mask=rak[:, None] < K - k * BLOCK_K, + mask=offset_k[:, None] < K - k * BLOCK_K, other=0) if CAST_TYPE: tiled_a = tiled_a.to(cur_lora_ptr.dtype.element_ty) From ebc9519ba7c7371a8d9771d95e643a10a83f3838 Mon Sep 17 00:00:00 2001 From: Zhonghua Deng Date: Tue, 24 Dec 2024 10:57:08 +0800 Subject: [PATCH 25/35] revert (#6) Signed-off-by: Abatom --- vllm/lora/ops/sgmv_expand.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index cbfeb0ec6de02..2aebdfa964619 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -55,19 +55,10 @@ def _sgmv_expand_kernel( cur_batch = tl.program_id(axis=1) slice_id = tl.program_id(axis=2) + cta_n_num = tl.cdiv(N, BLOCK_N) + pid_m = pid // cta_n_num + pid_n = pid % cta_n_num M = tl.load(seq_lens + cur_batch) - - grid_m = tl.cdiv(M, BLOCK_M) - grid_n = tl.cdiv(N, BLOCK_N) - GROUP_M: tl.constexpr = 1 - width = GROUP_M * grid_n - group_id = pid // width - first_pid_m = group_id * GROUP_M - group_idx = pid % width - group_size_m = min(grid_m - first_pid_m, GROUP_M) - pid_m = first_pid_m + (group_idx % group_size_m) - pid_n = group_idx // group_size_m - if pid_m * BLOCK_M > M: return lora_index = tl.load(lora_indices + cur_batch) From ba2c4442ea4f357e596553ad475eb1fa77fe10e6 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 24 Dec 2024 10:34:06 +0000 Subject: [PATCH 26/35] Add unit test Signed-off-by: Jee Jee Li --- tests/lora/test_punica_sizes.py | 75 ++++++++++++++++++++++++--------- vllm/lora/ops/sgmv_expand.py | 3 +- 2 files changed, 55 insertions(+), 23 deletions(-) diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py index 66b5f82bbb97d..d7ea44d23eee6 100644 --- a/tests/lora/test_punica_sizes.py +++ b/tests/lora/test_punica_sizes.py @@ -11,8 +11,8 @@ from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice from vllm.lora.ops.bgmv_shrink import bgmv_shrink from vllm.lora.ops.sgmv_expand import sgmv_expand -from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.sgmv_shrink import sgmv_shrink +from vllm.lora.ops.utils import _LORA_B_PTR_DICT from vllm.platforms import current_platform from 
.utils import (generate_data, generate_data_for_expand_nslices, @@ -172,9 +172,10 @@ def test_punica_sgmv( else: max_seq_length = max_seq_length.item() if op_type == "shrink": + our_out_tensor = our_out_tensor.unsqueeze(dim=0) sgmv_shrink( inputs_tensor, - lora_weights, + (lora_weights, ), our_out_tensor, b_seq_start_loc, seq_len_tensor, @@ -184,10 +185,13 @@ def test_punica_sgmv( token_nums, scaling, ) + our_out_tensor = our_out_tensor.squeeze(dim=0) else: + inputs_tensor = inputs_tensor.unsqueeze(dim=0) + sgmv_expand( inputs_tensor, - lora_weights, + (lora_weights, ), our_out_tensor, b_seq_start_loc, seq_len_tensor, @@ -197,6 +201,7 @@ def test_punica_sgmv( token_nums, add_inputs=True, ) + inputs_tensor = inputs_tensor.squeeze(dim=0) ref_torch_groupgemm( ref_out_tensor, inputs_tensor, @@ -292,7 +297,7 @@ def test_punica_bgmv( @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("nslices", [2, 3]) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"]) +@pytest.mark.parametrize("op_type", ["sgmv"]) @pytest.mark.parametrize("seed", SEED) @pytest.mark.parametrize("device", CUDA_DEVICES) def test_punica_expand_nslices( @@ -337,25 +342,28 @@ def test_punica_expand_nslices( else: max_seq_length = max_seq_length.item() slice_offset = 0 + + if op_type == "sgmv": + _LORA_B_PTR_DICT.clear() + inputs_tensor_nslices = inputs_tensor.unsqueeze(dim=0) + inputs_tensor_nslices = inputs_tensor_nslices.repeat((nslices, 1, 1)) + sgmv_expand( + inputs_tensor_nslices, + lora_weights_lst, + our_outputs, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + slice_offset, + add_inputs=True, + ) + for index in range(nslices): lora_weights = lora_weights_lst[index] - if op_type == "sgmv": - sgmv_expand_slice( - inputs_tensor, - lora_weights, - our_outputs, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - slice_offset, - hidden_size, - add_inputs=True, - ) - else: - + if op_type == "bgmv": bgmv_expand_slice( inputs_tensor, lora_weights, @@ -378,3 +386,28 @@ def test_punica_expand_nslices( slice_offset += hidden_size assert_close(our_outputs, ref_outputs) + + +if __name__ == "__main__": + from itertools import product + + for ele in product( + BATCHES, + NUM_LORA, + MAX_RANKS, + HIDDEN_SIZES, + [2], + DTYPES, + [ + "sgmv", + ], + SEED, + CUDA_DEVICES, + ): + try: + print(f"{ele} start...") + test_punica_expand_nslices(*ele) + print(f"{ele} passed") + except Exception as error: + raise error + print("Done") diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index 2aebdfa964619..0f5c595c38678 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -231,8 +231,7 @@ def _sgmv_expand_fake( batches: int, max_seq_length: int, token_nums: int, - slice_offset: int, - slice_size: int, + offset_start: int = 0, add_inputs: bool = False, ) -> None: return From 0f7897b6ff21dff764e4aeeb17f70843216df6f0 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 25 Dec 2024 08:02:00 +0000 Subject: [PATCH 27/35] Optimize unit test Signed-off-by: Jee Jee Li --- tests/lora/test_punica_sizes.py | 147 +++++++++---------------- tests/lora/test_punica_variation.py | 118 +++++++++----------- tests/lora/utils.py | 144 ++++++++++++++++++++---- vllm/lora/punica_wrapper/punica_gpu.py | 1 - 4 files changed, 227 insertions(+), 183 deletions(-) diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py index 
d7ea44d23eee6..1de38358a67de 100644 --- a/tests/lora/test_punica_sizes.py +++ b/tests/lora/test_punica_sizes.py @@ -12,11 +12,12 @@ from vllm.lora.ops.bgmv_shrink import bgmv_shrink from vllm.lora.ops.sgmv_expand import sgmv_expand from vllm.lora.ops.sgmv_shrink import sgmv_shrink -from vllm.lora.ops.utils import _LORA_B_PTR_DICT +from vllm.lora.ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT from vllm.platforms import current_platform -from .utils import (generate_data, generate_data_for_expand_nslices, - ref_torch_groupgemm) +from .utils import (assert_close, generate_data, + generate_data_for_expand_nslices, + generate_data_for_nslices, ref_torch_groupgemm) HIDDEN_SIZES = [ 128, @@ -113,20 +114,12 @@ CUDA_DEVICES = [f"cuda:{0}"] -def assert_close(a, b): - rtol, atol = { - torch.float16: (6e-2, 6e-2), - torch.bfloat16: (6e-2, 6e-2), - torch.float32: (1e-2, 1e-2), - }[a.dtype] - torch.testing.assert_close(a, b, rtol=rtol, atol=atol) - - @pytest.mark.parametrize("batches", BATCHES) @pytest.mark.parametrize("num_loras", NUM_LORA) @pytest.mark.parametrize("rank", MAX_RANKS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("scaling", SCALES) +@pytest.mark.parametrize("nslices", [1, 2, 3]) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("op_type", ["shrink", "expand"]) @pytest.mark.parametrize("seed", SEED) @@ -137,6 +130,7 @@ def test_punica_sgmv( rank: int, hidden_size: int, scaling: float, + nslices: int, dtype: torch.dtype, op_type: str, seed: int, @@ -148,19 +142,20 @@ def test_punica_sgmv( seq_length = 128 ( inputs_tensor, - lora_weights, + lora_weights_lst, our_out_tensor, ref_out_tensor, b_seq_start_loc, lora_indices_tensor, seq_len_tensor, indices, - ) = generate_data( + ) = generate_data_for_nslices( batches, hidden_size, num_loras, rank, seq_length, + nslices, dtype, op_type, device, @@ -172,10 +167,11 @@ def test_punica_sgmv( else: max_seq_length = max_seq_length.item() if op_type == "shrink": - our_out_tensor = our_out_tensor.unsqueeze(dim=0) + # Preventing cache error pointer. 
+ _LORA_A_PTR_DICT.clear() sgmv_shrink( inputs_tensor, - (lora_weights, ), + lora_weights_lst, our_out_tensor, b_seq_start_loc, seq_len_tensor, @@ -185,13 +181,22 @@ def test_punica_sgmv( token_nums, scaling, ) - our_out_tensor = our_out_tensor.squeeze(dim=0) + for index in range(nslices): + ref_torch_groupgemm( + ref_out_tensor[index], + inputs_tensor, + lora_weights_lst[index], + lora_indices_tensor, + seq_len_tensor, + batches, + scaling, + op_type, + ) else: - inputs_tensor = inputs_tensor.unsqueeze(dim=0) - + _LORA_B_PTR_DICT.clear() sgmv_expand( inputs_tensor, - (lora_weights, ), + lora_weights_lst, our_out_tensor, b_seq_start_loc, seq_len_tensor, @@ -199,21 +204,25 @@ def test_punica_sgmv( batches, max_seq_length, token_nums, + offset_start=0, add_inputs=True, ) - inputs_tensor = inputs_tensor.squeeze(dim=0) - ref_torch_groupgemm( - ref_out_tensor, - inputs_tensor, - lora_weights, - lora_indices_tensor, - seq_len_tensor, - batches, - scaling if op_type == "shrink" else 1.0, - op_type, - ) - if op_type == "shrink": - ref_out_tensor = ref_out_tensor.to(torch.float32) + + slice_offset = 0 + for index in range(nslices): + lora_weights = lora_weights_lst[index] + ref_torch_groupgemm( + ref_out_tensor[:, slice_offset:slice_offset + hidden_size], + inputs_tensor[index], + lora_weights, + lora_indices_tensor, + seq_len_tensor, + batches, + 1.0, + op_type, + ) + slice_offset += hidden_size + assert_close(our_out_tensor, ref_out_tensor) @@ -297,25 +306,22 @@ def test_punica_bgmv( @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("nslices", [2, 3]) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("op_type", ["sgmv"]) @pytest.mark.parametrize("seed", SEED) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_punica_expand_nslices( +def test_punica_bgmv_expand_nslices( batches: int, num_loras: int, rank: int, hidden_size: int, nslices: int, dtype: torch.dtype, - op_type: str, seed: int, device: str, ): - torch.set_default_device(device) current_platform.seed_everything(seed) - seq_length = 128 if op_type == "sgmv" else 1 + seq_length = 1 ( inputs_tensor, lora_weights_lst, @@ -335,44 +341,18 @@ def test_punica_expand_nslices( nslices, device, ) - max_seq_length = seq_len_tensor.max() - token_nums = seq_len_tensor.sum().item() - if isinstance(max_seq_length, tuple): - max_seq_length = max_seq_length[0].item() - else: - max_seq_length = max_seq_length.item() slice_offset = 0 - - if op_type == "sgmv": - _LORA_B_PTR_DICT.clear() - inputs_tensor_nslices = inputs_tensor.unsqueeze(dim=0) - inputs_tensor_nslices = inputs_tensor_nslices.repeat((nslices, 1, 1)) - sgmv_expand( - inputs_tensor_nslices, - lora_weights_lst, + for index in range(nslices): + lora_weights = lora_weights_lst[index] + bgmv_expand_slice( + inputs_tensor, + lora_weights, our_outputs, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, + indices, slice_offset, + slice_size=hidden_size, add_inputs=True, ) - - for index in range(nslices): - lora_weights = lora_weights_lst[index] - if op_type == "bgmv": - bgmv_expand_slice( - inputs_tensor, - lora_weights, - our_outputs, - indices, - slice_offset, - slice_size=hidden_size, - add_inputs=True, - ) ref_torch_groupgemm( ref_outputs[:, slice_offset:slice_offset + hidden_size], inputs_tensor, @@ -386,28 +366,3 @@ def test_punica_expand_nslices( slice_offset += hidden_size assert_close(our_outputs, ref_outputs) - - -if __name__ == "__main__": - from itertools import product - - for ele in 
product( - BATCHES, - NUM_LORA, - MAX_RANKS, - HIDDEN_SIZES, - [2], - DTYPES, - [ - "sgmv", - ], - SEED, - CUDA_DEVICES, - ): - try: - print(f"{ele} start...") - test_punica_expand_nslices(*ele) - print(f"{ele} passed") - except Exception as error: - raise error - print("Done") diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py index 3b20033271d26..4fd16925eac1a 100644 --- a/tests/lora/test_punica_variation.py +++ b/tests/lora/test_punica_variation.py @@ -11,12 +11,13 @@ import vllm.lora.ops.bgmv_expand_slice import vllm.lora.ops.bgmv_shrink import vllm.lora.ops.sgmv_expand -import vllm.lora.ops.sgmv_expand_slice import vllm.lora.ops.sgmv_shrink # noqa: F401 +from vllm.lora.ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT from vllm.platforms import current_platform -from .utils import (generate_data, generate_data_for_expand_nslices, - ref_torch_groupgemm) +from .utils import (assert_close, generate_data, + generate_data_for_expand_nslices, + generate_data_for_nslices, ref_torch_groupgemm) HIDDEN_SIZES = [4097] @@ -28,23 +29,12 @@ SEED = [0] CUDA_DEVICES = [f"cuda:{0}"] - -def assert_close(a, b): - rtol, atol = { - torch.float16: (6e-2, 6e-2), - torch.bfloat16: (6e-2, 6e-2), - torch.float32: (1e-2, 1e-2), - }[a.dtype] - torch.testing.assert_close(a, b, rtol=rtol, atol=atol) - - # Unlike test_punica_sizes.py, we directly utilize custom op for # testing, which verifies the correct registration of these ops. bgmv_expand = torch.ops.vllm.bgmv_expand bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice bgmv_shrink = torch.ops.vllm.bgmv_shrink sgmv_expand = torch.ops.vllm.sgmv_expand -sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice sgmv_shrink = torch.ops.vllm.sgmv_shrink @@ -53,6 +43,7 @@ def assert_close(a, b): @pytest.mark.parametrize("rank", MAX_RANKS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("scaling", SCALES) +@pytest.mark.parametrize("nslices", [1, 2, 3]) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("op_type", ["shrink", "expand"]) @pytest.mark.parametrize("seed", SEED) @@ -63,6 +54,7 @@ def test_punica_sgmv( rank: int, hidden_size: int, scaling: float, + nslices: int, dtype: torch.dtype, op_type: str, seed: int, @@ -74,19 +66,20 @@ def test_punica_sgmv( seq_length = 128 ( inputs_tensor, - lora_weights, + lora_weights_lst, our_out_tensor, ref_out_tensor, b_seq_start_loc, lora_indices_tensor, seq_len_tensor, indices, - ) = generate_data( + ) = generate_data_for_nslices( batches, hidden_size, num_loras, rank, seq_length, + nslices, dtype, op_type, device, @@ -98,9 +91,11 @@ def test_punica_sgmv( else: max_seq_length = max_seq_length.item() if op_type == "shrink": + # Preventing cache error pointer. 
+ _LORA_A_PTR_DICT.clear() sgmv_shrink( inputs_tensor, - lora_weights, + lora_weights_lst, our_out_tensor, b_seq_start_loc, seq_len_tensor, @@ -110,10 +105,22 @@ def test_punica_sgmv( token_nums, scaling, ) + for index in range(nslices): + ref_torch_groupgemm( + ref_out_tensor[index], + inputs_tensor, + lora_weights_lst[index], + lora_indices_tensor, + seq_len_tensor, + batches, + scaling, + op_type, + ) else: + _LORA_B_PTR_DICT.clear() sgmv_expand( inputs_tensor, - lora_weights, + lora_weights_lst, our_out_tensor, b_seq_start_loc, seq_len_tensor, @@ -121,20 +128,25 @@ def test_punica_sgmv( batches, max_seq_length, token_nums, + offset_start=0, add_inputs=True, ) - ref_torch_groupgemm( - ref_out_tensor, - inputs_tensor, - lora_weights, - lora_indices_tensor, - seq_len_tensor, - batches, - scaling if op_type == "shrink" else 1.0, - op_type, - ) - if op_type == "shrink": - ref_out_tensor = ref_out_tensor.to(torch.float32) + + slice_offset = 0 + for index in range(nslices): + lora_weights = lora_weights_lst[index] + ref_torch_groupgemm( + ref_out_tensor[:, slice_offset:slice_offset + hidden_size], + inputs_tensor[index], + lora_weights, + lora_indices_tensor, + seq_len_tensor, + batches, + 1.0, + op_type, + ) + slice_offset += hidden_size + assert_close(our_out_tensor, ref_out_tensor) @@ -220,24 +232,22 @@ def test_punica_bgmv( @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("nslices", [2, 3]) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"]) @pytest.mark.parametrize("seed", SEED) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_punica_expand_nslices( +def test_punica_bgmv_expand_nslices( batches: int, num_loras: int, rank: int, hidden_size: int, nslices: int, dtype: torch.dtype, - op_type: str, seed: int, device: str, ): torch.set_default_device(device) current_platform.seed_everything(seed) - seq_length = 128 if op_type == "sgmv" else 1 + seq_length = 1 ( inputs_tensor, lora_weights_lst, @@ -257,40 +267,18 @@ def test_punica_expand_nslices( nslices, device, ) - max_seq_length = seq_len_tensor.max() - token_nums = seq_len_tensor.sum().item() - if isinstance(max_seq_length, tuple): - max_seq_length = max_seq_length[0].item() - else: - max_seq_length = max_seq_length.item() slice_offset = 0 for index in range(nslices): lora_weights = lora_weights_lst[index] - if op_type == "sgmv": - sgmv_expand_slice( - inputs_tensor, - lora_weights, - our_outputs, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - slice_offset, - hidden_size, - add_inputs=True, - ) - else: - bgmv_expand_slice( - inputs_tensor, - lora_weights, - our_outputs, - indices, - slice_offset, - slice_size=hidden_size, - add_inputs=True, - ) + bgmv_expand_slice( + inputs_tensor, + lora_weights, + our_outputs, + indices, + slice_offset, + slice_size=hidden_size, + add_inputs=True, + ) ref_torch_groupgemm( ref_outputs[:, slice_offset:slice_offset + hidden_size], inputs_tensor, diff --git a/tests/lora/utils.py b/tests/lora/utils.py index e394c33b3f9ea..b66d18074a7bf 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -18,11 +18,13 @@ def set_module_lora(self, module_name: str, lora: LoRALayerWeights): def get_module_lora(self, module_name: str) -> LoRALayerWeights: return self._loras[module_name] - def init_random_lora(self, - module_name: str, - weight: torch.Tensor, - rank: int = 8, - generate_embeddings_tensor: int = 0): + def init_random_lora( + self, + module_name: str, + weight: 
torch.Tensor, + rank: int = 8, + generate_embeddings_tensor: int = 0, + ): lora = LoRALayerWeights( module_name, rank=rank, @@ -35,21 +37,25 @@ def init_random_lora(self, device=self._device), ) if generate_embeddings_tensor: - lora.embeddings_tensor = torch.rand(5, - generate_embeddings_tensor, - dtype=weight.dtype, - device=self._device) + lora.embeddings_tensor = torch.rand( + 5, + generate_embeddings_tensor, + dtype=weight.dtype, + device=self._device, + ) self.set_module_lora(module_name, lora) return lora - def init_lora(self, - module_name: str, - input_dim: int, - output_dim: int, - rank=8, - noop=False, - embeddings_tensor=None): + def init_lora( + self, + module_name: str, + input_dim: int, + output_dim: int, + rank=8, + noop=False, + embeddings_tensor=None, + ): lora = LoRALayerWeights( module_name, rank=rank, @@ -125,8 +131,16 @@ def ref_torch_groupgemm( return -def generate_data(batches, hidden_size, lora_nums, max_rank, seq_length, dtype, - op_type, device): +def generate_data( + batches, + hidden_size, + lora_nums, + max_rank, + seq_length, + dtype, + op_type, + device, +): seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches, )).to(device) b_seq_start_loc = torch.cumsum( @@ -187,8 +201,16 @@ def generate_data(batches, hidden_size, lora_nums, max_rank, seq_length, dtype, ) -def generate_data_for_expand_nslices(batches, hidden_size, lora_nums, max_rank, - seq_length, dtype, nslices, device): +def generate_data_for_expand_nslices( + batches, + hidden_size, + lora_nums, + max_rank, + seq_length, + dtype, + nslices, + device, +): seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches, )).to(device) b_seq_start_loc = torch.cumsum( @@ -221,7 +243,87 @@ def generate_data_for_expand_nslices(batches, hidden_size, lora_nums, max_rank, for b_id in range(batches): lora_index = lora_indices_tensor[b_id] indices[current_offset:current_offset + - seq_len_tensor[b_id]] = lora_index.item() + seq_len_tensor[b_id]] = (lora_index.item()) + current_offset += seq_len_tensor[b_id].item() + + lora_indices_tensor = lora_indices_tensor.to(device) + return ( + inputs_tensor, + lora_weights_lst, + our_out_tensor, + ref_out_tensor, + b_seq_start_loc, + lora_indices_tensor, + seq_len_tensor, + indices, + ) + + +def generate_data_for_nslices( + batches, + hidden_size, + lora_nums, + max_rank, + seq_length, + nslices, + dtype, + op_type, + device, +): + seq_len_tensor = torch.randint(seq_length, seq_length + 1, + (batches, )).to(device) + b_seq_start_loc = torch.cumsum( + torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long), + dim=0, + ).to(device) + total_tokens = seq_len_tensor.sum() + + lora_weights_lst = [] + if op_type == "shrink": + + inputs_tensor = torch.rand((total_tokens, hidden_size), + dtype=dtype).to(device) + + for _ in range(nslices): + if op_type == "shrink": + lora_weights_lst.append( + torch.rand( + (lora_nums, max_rank, hidden_size), # col-major + dtype=dtype, + ).to(device)) + # NOTE shrink kernel using torch.float32 as output type + # shrink op need atomic_add, so output is initinized by 0 + our_out_tensor = torch.zeros( + (nslices, total_tokens, max_rank), + dtype=torch.float32, + ).to(device) + else: + inputs_tensor = torch.rand( + (nslices, total_tokens, max_rank), + dtype=dtype, + ).to(device) + for _ in range(nslices): + lora_weights_lst.append( + torch.rand( + (lora_nums, hidden_size, max_rank), # col-major + dtype=dtype, + ).to(device)) + # expand op needs to complete y+=a@lora_b, so output is + # initinized randomly + our_out_tensor = 
torch.rand((total_tokens, hidden_size * nslices), + dtype=dtype).to(device) + + # Ensure the same input. + ref_out_tensor = our_out_tensor.clone() + lora_indices_tensor = torch.randint(0, + lora_nums - 1 if lora_nums > 1 else 1, + (batches, )) + indices = torch.zeros((total_tokens), dtype=torch.long).to(device) + current_offset = 0 + for b_id in range(batches): + lora_index = lora_indices_tensor[b_id] + indices[current_offset:current_offset + + seq_len_tensor[b_id]] = (lora_index.item()) current_offset += seq_len_tensor[b_id].item() lora_indices_tensor = lora_indices_tensor.to(device) diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 076a563583b62..278f7b5a8e9f4 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -269,7 +269,6 @@ def add_lora_linear(self, output_slices, add_inputs=True, **kwargs) - pass def add_lora_logits(self, y: torch.Tensor, From 3edb696bad6a9c855df90494510c6fe69d3652fc Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 25 Dec 2024 08:26:30 +0000 Subject: [PATCH 28/35] Optimize unit test Signed-off-by: Jee Jee Li --- vllm/lora/ops/sgmv_expand.py | 24 ++++++++++++++++++++++++ vllm/lora/ops/sgmv_shrink.py | 3 +-- vllm/lora/ops/utils.py | 4 ++-- 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index 0f5c595c38678..934beee012a39 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -150,6 +150,30 @@ def _sgmv_expand( offset_start: int = 0, add_inputs: bool = False, ) -> None: + """ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (List[torch.Tensor]): lora'a weight + output_tensor (torch.Tensor): output tensor + b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative + sequence lengths of the sequences in the batch, used to index + into sequence. E.g., if the sequence length is [4, 6], it is + [0, 4]. + seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence + length of the sequences in the batch. + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch. An index of -1 means no lora should be + applied. + batches (int): batch size + max_seq_length (int): The max sequence lengths of the sequences in the + batch. + token_nums (int): The token numbers in the batch. Used to verify if the + token numbers in the inputs matches the one in the metadata. + offset_start (int, optional): Offset start for output_tensor. + Defaults to 0. + add_inputs (bool, optional): Whether to add the input tensor to the + output tensor. Defaults to False. + """ assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] assert lora_b_weights[0].dtype in [ torch.float16, diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py index d1b5141443715..1ce0d22cdd890 100644 --- a/vllm/lora/ops/sgmv_shrink.py +++ b/vllm/lora/ops/sgmv_shrink.py @@ -145,10 +145,9 @@ def _sgmv_shrink( scaling: float, ) -> None: """ - Args: inputs (torch.Tensor): input tensor - lora_a_weights (torch.Tensor): lora'a weight + lora_a_weights (List[torch.Tensor]): lora'a weight output_tensor (torch.Tensor): output tensor b_seq_start_loc (torch.Tensor): (batch_size,). 
The cumulative sequence lengths of the sequences in the batch, used to index diff --git a/vllm/lora/ops/utils.py b/vllm/lora/ops/utils.py index 635840beb9a9f..2a3d94d785a46 100644 --- a/vllm/lora/ops/utils.py +++ b/vllm/lora/ops/utils.py @@ -55,7 +55,7 @@ def get_lora_op_configs(op_type: str, batch: int, def _get_lora_a_ptr(lora_a_weights, device): """ `_LORA_A_PTR_DICT` collects the required information during `profile_run`, - and subsequent usage is through LUT. + After this, it remains constant and subsequent usage is through LUT. Refer to: https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py """ @@ -108,7 +108,7 @@ def _get_lora_a_ptr(lora_a_weights, device): def _get_lora_b_ptr(lora_weights, offset_start, device): """ `_LORA_B_PTR_DICT` collects the required information during `profile_run`, - and subsequent usage is through LUT. + After this, it remains constant and subsequent usage is through LUT. Refer to: https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py From 49c6c21dc2b17db0a5dcb5fd7d6815c77afbebfe Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 25 Dec 2024 08:41:00 +0000 Subject: [PATCH 29/35] Fix comment Signed-off-by: Jee Jee Li --- vllm/lora/ops/sgmv_expand.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index 934beee012a39..2bb14985bad98 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -153,7 +153,7 @@ def _sgmv_expand( """ Args: inputs (torch.Tensor): input tensor - lora_b_weights (List[torch.Tensor]): lora'a weight + lora_b_weights (List[torch.Tensor]): lora'b weight output_tensor (torch.Tensor): output tensor b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative sequence lengths of the sequences in the batch, used to index From 489eca1e77fef4985321b88e870c4caa44fc833f Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sat, 28 Dec 2024 01:56:06 +0000 Subject: [PATCH 30/35] Optimize code Signed-off-by: Jee Jee Li --- tests/lora/test_punica_sizes.py | 10 ++++++++-- tests/lora/test_punica_variation.py | 10 ++++++++-- vllm/lora/ops/sgmv_expand.py | 10 +++++----- vllm/lora/ops/sgmv_shrink.py | 10 ++++++---- vllm/lora/ops/utils.py | 7 ++++--- 5 files changed, 31 insertions(+), 16 deletions(-) diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py index 1de38358a67de..27f4dc21c7ed2 100644 --- a/tests/lora/test_punica_sizes.py +++ b/tests/lora/test_punica_sizes.py @@ -4,6 +4,8 @@ whether the corresponding Triton kernel can run normally when tensor parallelism is set to [1, 2, 4, 8, 16, 32, 64]. """ +from threading import Lock + import pytest import torch @@ -113,6 +115,8 @@ SEED = [0] CUDA_DEVICES = [f"cuda:{0}"] +_dict_lock = Lock() + @pytest.mark.parametrize("batches", BATCHES) @pytest.mark.parametrize("num_loras", NUM_LORA) @@ -168,7 +172,8 @@ def test_punica_sgmv( max_seq_length = max_seq_length.item() if op_type == "shrink": # Preventing cache error pointer. 
- _LORA_A_PTR_DICT.clear() + with _dict_lock: + _LORA_A_PTR_DICT.clear() sgmv_shrink( inputs_tensor, lora_weights_lst, @@ -193,7 +198,8 @@ def test_punica_sgmv( op_type, ) else: - _LORA_B_PTR_DICT.clear() + with _dict_lock: + _LORA_B_PTR_DICT.clear() sgmv_expand( inputs_tensor, lora_weights_lst, diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py index 4fd16925eac1a..6b9a16630af0e 100644 --- a/tests/lora/test_punica_variation.py +++ b/tests/lora/test_punica_variation.py @@ -3,6 +3,8 @@ under different conditions, including various batches, numbers of LoRA , and maximum ranks. """ +from threading import Lock + import pytest import torch @@ -37,6 +39,8 @@ sgmv_expand = torch.ops.vllm.sgmv_expand sgmv_shrink = torch.ops.vllm.sgmv_shrink +_dict_lock = Lock() + @pytest.mark.parametrize("batches", BATCHES) @pytest.mark.parametrize("num_loras", NUM_LORA) @@ -92,7 +96,8 @@ def test_punica_sgmv( max_seq_length = max_seq_length.item() if op_type == "shrink": # Preventing cache error pointer. - _LORA_A_PTR_DICT.clear() + with _dict_lock: + _LORA_A_PTR_DICT.clear() sgmv_shrink( inputs_tensor, lora_weights_lst, @@ -117,7 +122,8 @@ def test_punica_sgmv( op_type, ) else: - _LORA_B_PTR_DICT.clear() + with _dict_lock: + _LORA_B_PTR_DICT.clear() sgmv_expand( inputs_tensor, lora_weights_lst, diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index 2bb14985bad98..f6a9ac0777f51 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -71,12 +71,14 @@ def _sgmv_expand_kernel( offset_k = tl.arange(0, BLOCK_K) ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) - + # ls_d*_ptr can be either an integer or a pointer if SAME_STRIDE: + # integer cur_lora_d0_stride = ls_d0_ptr cur_lora_d1_stride = ls_d1_ptr cur_lora_d2_stride = ls_d2_ptr else: + # pointer cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) @@ -175,10 +177,8 @@ def _sgmv_expand( output tensor. Defaults to False. 
""" assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] - assert lora_b_weights[0].dtype in [ - torch.float16, - torch.bfloat16, - ] + for weight in lora_b_weights: + assert weight.dtype in [torch.float16, torch.bfloat16] assert inputs.size(1) == token_nums assert inputs.size(0) == len(lora_b_weights) diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py index 1ce0d22cdd890..13a5e09662a56 100644 --- a/vllm/lora/ops/sgmv_shrink.py +++ b/vllm/lora/ops/sgmv_shrink.py @@ -77,11 +77,14 @@ def _sgmv_shrink_kernel( # input ptr a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + offset_k[None, :] * xk_stride) + # ls_d*_ptr can be either an integer or a pointer if SAME_STRIDE: + # integer cur_lora_d0_stride = ls_d0_ptr cur_lora_d1_stride = ls_d1_ptr cur_lora_d2_stride = ls_d2_ptr else: + # pointer cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) @@ -167,10 +170,9 @@ def _sgmv_shrink( """ assert inputs.dtype == lora_a_weights[0].dtype assert inputs.dtype in [torch.float16, torch.bfloat16] - assert lora_a_weights[0].dtype in [ - torch.float16, - torch.bfloat16, - ] + for weight in lora_a_weights: + assert weight.dtype in [torch.float16, torch.bfloat16] + assert inputs.size(0) == token_nums assert inputs.size(1) == lora_a_weights[0].size(-1) assert b_seq_start_loc.size(0) == batches diff --git a/vllm/lora/ops/utils.py b/vllm/lora/ops/utils.py index 2a3d94d785a46..1507c9c0aeee9 100644 --- a/vllm/lora/ops/utils.py +++ b/vllm/lora/ops/utils.py @@ -1,5 +1,5 @@ import functools -from typing import Dict, Tuple +from typing import Dict, List, Tuple import torch @@ -52,7 +52,7 @@ def get_lora_op_configs(op_type: str, batch: int, _LORA_B_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} -def _get_lora_a_ptr(lora_a_weights, device): +def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: str): """ `_LORA_A_PTR_DICT` collects the required information during `profile_run`, After this, it remains constant and subsequent usage is through LUT. @@ -105,7 +105,8 @@ def _get_lora_a_ptr(lora_a_weights, device): return _LORA_A_PTR_DICT.get(key) -def _get_lora_b_ptr(lora_weights, offset_start, device): +def _get_lora_b_ptr(lora_weights: List[torch.Tensor], offset_start: int, + device: str): """ `_LORA_B_PTR_DICT` collects the required information during `profile_run`, After this, it remains constant and subsequent usage is through LUT. From 04ae0dd7494bfbd325f7a21b513022cf6ec4456e Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sat, 28 Dec 2024 02:04:33 +0000 Subject: [PATCH 31/35] Add lock for unit test Signed-off-by: Jee Jee Li --- tests/lora/test_punica_sizes.py | 50 ++++++++++++++--------------- tests/lora/test_punica_variation.py | 50 ++++++++++++++--------------- 2 files changed, 50 insertions(+), 50 deletions(-) diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py index 27f4dc21c7ed2..0351fedd1cfa5 100644 --- a/tests/lora/test_punica_sizes.py +++ b/tests/lora/test_punica_sizes.py @@ -174,18 +174,18 @@ def test_punica_sgmv( # Preventing cache error pointer. 
with _dict_lock: _LORA_A_PTR_DICT.clear() - sgmv_shrink( - inputs_tensor, - lora_weights_lst, - our_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - scaling, - ) + sgmv_shrink( + inputs_tensor, + lora_weights_lst, + our_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + scaling, + ) for index in range(nslices): ref_torch_groupgemm( ref_out_tensor[index], @@ -200,19 +200,19 @@ def test_punica_sgmv( else: with _dict_lock: _LORA_B_PTR_DICT.clear() - sgmv_expand( - inputs_tensor, - lora_weights_lst, - our_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - offset_start=0, - add_inputs=True, - ) + sgmv_expand( + inputs_tensor, + lora_weights_lst, + our_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + offset_start=0, + add_inputs=True, + ) slice_offset = 0 for index in range(nslices): diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py index 6b9a16630af0e..9ee10e7c23ee6 100644 --- a/tests/lora/test_punica_variation.py +++ b/tests/lora/test_punica_variation.py @@ -98,18 +98,18 @@ def test_punica_sgmv( # Preventing cache error pointer. with _dict_lock: _LORA_A_PTR_DICT.clear() - sgmv_shrink( - inputs_tensor, - lora_weights_lst, - our_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - scaling, - ) + sgmv_shrink( + inputs_tensor, + lora_weights_lst, + our_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + scaling, + ) for index in range(nslices): ref_torch_groupgemm( ref_out_tensor[index], @@ -124,19 +124,19 @@ def test_punica_sgmv( else: with _dict_lock: _LORA_B_PTR_DICT.clear() - sgmv_expand( - inputs_tensor, - lora_weights_lst, - our_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - offset_start=0, - add_inputs=True, - ) + sgmv_expand( + inputs_tensor, + lora_weights_lst, + our_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + offset_start=0, + add_inputs=True, + ) slice_offset = 0 for index in range(nslices): From 65d0f2f21db332a3897414206035fa01d6f04edb Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Mon, 30 Dec 2024 07:31:27 +0000 Subject: [PATCH 32/35] Optimize arg Signed-off-by: Jee Jee Li --- vllm/lora/ops/sgmv_expand.py | 12 ++++++------ vllm/lora/ops/sgmv_shrink.py | 28 +++++++++++++++------------- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index f6a9ac0777f51..d2e7187d84f3b 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -30,11 +30,11 @@ def _sgmv_expand_kernel( input_d0_stride, input_d1_stride, input_d2_stride, # 1 - ls_d0_ptr, # lora stride(0) + ls_d0_ptr, ls_d1_ptr, - ls_d2_ptr, - cm_stride, - cn_stride, + ls_d2_ptr, # 1 + output_d0_stride, + output_d1_stride, # 1 BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, @@ -126,8 +126,8 @@ def _sgmv_expand_kernel( offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start - c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + - offset_cn[None, :] * cn_stride) + c_ptr = (out_ptr 
+ offset_cm[:, None] * output_d0_stride + + offset_cn[None, :] * output_d1_stride) M = tl.load(seq_lens + cur_batch) c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] < diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py index 13a5e09662a56..cd52579198607 100644 --- a/vllm/lora/ops/sgmv_shrink.py +++ b/vllm/lora/ops/sgmv_shrink.py @@ -27,14 +27,14 @@ def _sgmv_shrink_kernel( seq_lens, lora_indices, scaling, - xm_stride, # hidden_size - xk_stride, # 1 - ls_d0_ptr, # hidden_size*max_rank + input_d0_stride, + input_d1_stride, # 1 + ls_d0_ptr, ls_d1_ptr, - ls_d2_ptr, - c0_stride, - cm_stride, - cn_stride, + ls_d2_ptr, # 1 + output_d0_stride, + output_d1_stride, + output_d2_stride, # 1 BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, @@ -75,8 +75,9 @@ def _sgmv_shrink_kernel( ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) # input ptr - a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + - offset_k[None, :] * xk_stride) + a_ptr = (input_ptr + cur_seq_start * input_d0_stride + + ram[:, None] * input_d0_stride + + offset_k[None, :] * input_d1_stride) # ls_d*_ptr can be either an integer or a pointer if SAME_STRIDE: # integer @@ -116,14 +117,15 @@ def _sgmv_shrink_kernel( other=0.0) accumulator += tl.dot(tiled_a, tiled_b) - a_ptr += BLOCK_K * SPLIT_K * xk_stride + a_ptr += BLOCK_K * SPLIT_K * input_d1_stride b_ptr += BLOCK_K * SPLIT_K * cur_lora_d2_stride offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N - cur_out_ptr = out_ptr if SLICE_NUM == 1 else out_ptr + slice_id * c0_stride - c_ptr = cur_out_ptr + offset_cm[:, None] * cm_stride + offset_cn[ - None, :] * cn_stride + cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr + + slice_id * output_d0_stride) + c_ptr = cur_out_ptr + offset_cm[:, None] * output_d1_stride + offset_cn[ + None, :] * output_d2_stride c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] < N) accumulator *= scaling From 421382e033824f1e2642a4476ad89fd95fa1a66b Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 2 Jan 2025 08:47:37 +0000 Subject: [PATCH 33/35] Fix expand bug Signed-off-by: Jee Jee Li --- vllm/lora/ops/sgmv_expand.py | 32 ++++++++++++++----------- vllm/lora/ops/sgmv_shrink.py | 45 +++++++++++------------------------- vllm/lora/ops/utils.py | 42 ++++++++++++++++----------------- 3 files changed, 51 insertions(+), 68 deletions(-) diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index d2e7187d84f3b..8af44b703810b 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -35,6 +35,7 @@ def _sgmv_expand_kernel( ls_d2_ptr, # 1 output_d0_stride, output_d1_stride, # 1 + output_hs_ptr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, @@ -54,13 +55,18 @@ def _sgmv_expand_kernel( pid = tl.program_id(axis=0) cur_batch = tl.program_id(axis=1) slice_id = tl.program_id(axis=2) - cta_n_num = tl.cdiv(N, BLOCK_N) + # When the output dimensions of each slice are the same,cur_n=N, otherwise + # cur_n=tl.load(output_hs_ptr + slice_id), this situation exists in GQA's + # qkv linear. 
+ curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id) pid_m = pid // cta_n_num pid_n = pid % cta_n_num M = tl.load(seq_lens + cur_batch) if pid_m * BLOCK_M > M: return + if pid_n * BLOCK_N > curr_N: + return lora_index = tl.load(lora_indices + cur_batch) if lora_index == -1: return @@ -70,7 +76,8 @@ def _sgmv_expand_kernel( offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N offset_k = tl.arange(0, BLOCK_K) ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + rbn = tl.max_contiguous(tl.multiple_of(offset_n % curr_N, BLOCK_N), + BLOCK_N) # ls_d*_ptr can be either an integer or a pointer if SAME_STRIDE: # integer @@ -131,7 +138,7 @@ def _sgmv_expand_kernel( M = tl.load(seq_lens + cur_batch) c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] < - (cur_slice_start + N)) + (cur_slice_start + curr_N)) if ADD_INPUTS: tiled_out = tl.load(c_ptr, mask=c_mask) tiled_c += tiled_out @@ -186,17 +193,13 @@ def _sgmv_expand( assert b_seq_start_loc.size(0) == batches assert lora_indices_tensor.size(0) == batches assert output_tensor.is_contiguous() - ( - slice_start_tensor, - lora_ptr_tensor, - lora_strides_d0_tensor, - lora_strides_d1_tensor, - lora_strides_d2_tensor, - same_stride, - ) = _get_lora_b_ptr(lora_b_weights, offset_start, b_seq_start_loc.device) + (slice_start_tensor, lora_ptr_tensor, lora_strides_d0_tensor, + lora_strides_d1_tensor, lora_strides_d2_tensor, hidden_sizes_tensor, + same_stride, MAX_N) = _get_lora_b_ptr(lora_b_weights, offset_start, + b_seq_start_loc.device) # TODO tuning this config - N, K = lora_b_weights[0].shape[-2:] # K= rank,N=hidden_size + K = lora_b_weights[0].shape[-1] # K= rank BLOCK_M = 64 BLOCK_N = 128 @@ -211,7 +214,7 @@ def _sgmv_expand( ]: CAST_TYPE = True grid = ( - triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), + triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N), batches, len(lora_b_weights), ) @@ -219,7 +222,7 @@ def _sgmv_expand( inputs, lora_ptr_tensor, output_tensor, - N, + MAX_N, K, b_seq_start_loc, seq_len_tensor, @@ -233,6 +236,7 @@ def _sgmv_expand( lora_strides_d2_tensor, output_tensor.stride(0), output_tensor.stride(1), + hidden_sizes_tensor, BLOCK_M, BLOCK_N, BLOCK_K, diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py index cd52579198607..3d2ebe8286f56 100644 --- a/vllm/lora/ops/sgmv_shrink.py +++ b/vllm/lora/ops/sgmv_shrink.py @@ -29,9 +29,9 @@ def _sgmv_shrink_kernel( scaling, input_d0_stride, input_d1_stride, # 1 - ls_d0_ptr, - ls_d1_ptr, - ls_d2_ptr, # 1 + lora_d0_stride, + lora_d1_stride, + lora_d2_stride, # 1 output_d0_stride, output_d1_stride, output_d2_stride, # 1 @@ -40,8 +40,7 @@ def _sgmv_shrink_kernel( BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr, - SLICE_NUM: tl.constexpr, - SAME_STRIDE: tl.constexpr): + SLICE_NUM: tl.constexpr): """ The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K. The GEMM of Multi-LoRA can be considered as GroupGEMM. 
Additionally, @@ -78,17 +77,6 @@ def _sgmv_shrink_kernel( a_ptr = (input_ptr + cur_seq_start * input_d0_stride + ram[:, None] * input_d0_stride + offset_k[None, :] * input_d1_stride) - # ls_d*_ptr can be either an integer or a pointer - if SAME_STRIDE: - # integer - cur_lora_d0_stride = ls_d0_ptr - cur_lora_d1_stride = ls_d1_ptr - cur_lora_d2_stride = ls_d2_ptr - else: - # pointer - cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) - cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) - cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) if SLICE_NUM == 1: # current lora ptr @@ -98,9 +86,9 @@ def _sgmv_shrink_kernel( cur_lora_ptr = tl.load(lora_ptr + slice_id).to( tl.pointer_type(input_ptr.dtype.element_ty)) - b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + - rbn[None, :] * cur_lora_d1_stride + - offset_k[:, None] * cur_lora_d2_stride) + b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index + + rbn[None, :] * lora_d1_stride + + offset_k[:, None] * lora_d2_stride) accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): @@ -118,7 +106,7 @@ def _sgmv_shrink_kernel( accumulator += tl.dot(tiled_a, tiled_b) a_ptr += BLOCK_K * SPLIT_K * input_d1_stride - b_ptr += BLOCK_K * SPLIT_K * cur_lora_d2_stride + b_ptr += BLOCK_K * SPLIT_K * lora_d2_stride offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N @@ -181,13 +169,8 @@ def _sgmv_shrink( assert lora_indices_tensor.size(0) == batches assert inputs.is_contiguous() assert output_tensor.is_contiguous() - ( - lora_ptr_tensor, - lora_strides_d0_tensor, - lora_strides_d1_tensor, - lora_strides_d2_tensor, - same_stride, - ) = _get_lora_a_ptr(lora_a_weights, b_seq_start_loc.device) + (lora_ptr_tensor, lora_strides_d0, lora_strides_d1, + lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, b_seq_start_loc.device) # TODO tuning this config N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank BLOCK_M = 32 @@ -200,7 +183,6 @@ def _sgmv_shrink( SPLIT_K * len(lora_a_weights), batches, ) - _sgmv_shrink_kernel[grid]( inputs, lora_ptr_tensor, @@ -213,9 +195,9 @@ def _sgmv_shrink( scaling, inputs.stride(0), inputs.stride(1), - lora_strides_d0_tensor, - lora_strides_d1_tensor, - lora_strides_d2_tensor, + lora_strides_d0, + lora_strides_d1, + lora_strides_d2, output_tensor.stride(0), output_tensor.stride(1), output_tensor.stride(2), @@ -225,7 +207,6 @@ def _sgmv_shrink( EVEN_K, SPLIT_K, len(lora_a_weights), - same_stride, ) return diff --git a/vllm/lora/ops/utils.py b/vllm/lora/ops/utils.py index 1507c9c0aeee9..7df5bc2c225e5 100644 --- a/vllm/lora/ops/utils.py +++ b/vllm/lora/ops/utils.py @@ -79,29 +79,21 @@ def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: str): lora_strides_d0.append(lora_a_weight.stride(0)) lora_strides_d1.append(lora_a_weight.stride(1)) lora_strides_d2.append(lora_a_weight.stride(2)) - if len(lora_a_weights) > 1: lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device) - else: lora_ptr_tensor = lora_a_weights[0] - # If each lora has the same stride, there's no need to use a - # tensor for storage. 
- if (len(set(lora_strides_d0)) == 1 and len(set(lora_strides_d1)) == 1 - and len(set(lora_strides_d2)) == 1): - lora_strides_d0_tensor = lora_strides_d0[0] - lora_strides_d1_tensor = lora_strides_d1[0] - lora_strides_d2_tensor = lora_strides_d2[0] - same_stride = True - else: - lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device) - lora_strides_d1_tensor = torch.tensor(lora_strides_d1, device=device) - lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device) - same_stride = False - _LORA_A_PTR_DICT[key] = (lora_ptr_tensor, lora_strides_d0_tensor, - lora_strides_d1_tensor, lora_strides_d2_tensor, - same_stride) + if (len(set(lora_strides_d0)) > 1 or len(set(lora_strides_d1)) > 1 + or len(set(lora_strides_d2)) > 1): + raise ValueError("All LoRA weights must have the same stride.") + + _LORA_A_PTR_DICT[key] = ( + lora_ptr_tensor, + lora_strides_d0[0], + lora_strides_d1[0], + lora_strides_d2[0], + ) return _LORA_A_PTR_DICT.get(key) @@ -123,6 +115,7 @@ def _get_lora_b_ptr(lora_weights: List[torch.Tensor], offset_start: int, lora_strides_d0 = [] lora_strides_d1 = [] lora_strides_d2 = [] + hidden_sizes = [] slice_offset = offset_start for lora_b_weight in lora_weights: if lora_b_weight.ndim == 4: # shape:(lora_num,1,size,rank) @@ -137,6 +130,7 @@ def _get_lora_b_ptr(lora_weights: List[torch.Tensor], offset_start: int, lora_strides_d2.append(lora_b_weight.stride(2)) slice_offset_lst.append(slice_offset) slice_offset += lora_b_weight.size(1) + hidden_sizes.append(lora_b_weight.size(1)) if len(lora_weights) > 1: # note these are device tensors @@ -148,20 +142,24 @@ def _get_lora_b_ptr(lora_weights: List[torch.Tensor], offset_start: int, # If each lora has the same stride, there's no need to use a # tensor for storage. - if (len(set(lora_strides_d0)) == 1 and len(set(lora_strides_d1)) == 1 - and len(set(lora_strides_d2)) == 1): + if (len(set(lora_strides_d0)) == 1 and len(set(lora_strides_d1)) == 1 and + len(set(lora_strides_d2)) == 1) and len(set(hidden_sizes)) == 1: lora_strides_d0_tensor = lora_strides_d0[0] lora_strides_d1_tensor = lora_strides_d1[0] lora_strides_d2_tensor = lora_strides_d2[0] + hidden_sizes_tensor = hidden_sizes[0] same_stride = True else: lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device) lora_strides_d1_tensor = torch.tensor(lora_strides_d1, device=device) lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device) + hidden_sizes_tensor = torch.tensor(hidden_sizes, device=device) same_stride = False - + # MAX_N is the maximum hidden size among all the lora_b weights + MAX_N = max(hidden_sizes) _LORA_B_PTR_DICT[key] = (slice_start_tensor, lora_ptr_tensor, lora_strides_d0_tensor, lora_strides_d1_tensor, - lora_strides_d2_tensor, same_stride) + lora_strides_d2_tensor, hidden_sizes_tensor, + same_stride, MAX_N) return _LORA_B_PTR_DICT.get(key) From 2c79295285bfe709876a04b0217868c32b8f1434 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 3 Jan 2025 12:59:14 +0000 Subject: [PATCH 34/35] Reduce memory Signed-off-by: Jee Jee Li --- tests/lora/test_minicpmv.py | 5 +++-- tests/lora/test_minicpmv_tp.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/lora/test_minicpmv.py b/tests/lora/test_minicpmv.py index 78bf5a1617233..8017c21b29da0 100644 --- a/tests/lora/test_minicpmv.py +++ b/tests/lora/test_minicpmv.py @@ -64,8 +64,9 @@ def test_minicpmv_lora(minicpmv_lora_files): MODEL_PATH, max_num_seqs=2, enable_lora=True, - max_loras=4, - max_lora_rank=64, + max_loras=2, + max_lora_rank=8, + 
enforce_eager=True, trust_remote_code=True, enable_chunked_prefill=True, ) diff --git a/tests/lora/test_minicpmv_tp.py b/tests/lora/test_minicpmv_tp.py index 930f177953a5f..05aad03d9d280 100644 --- a/tests/lora/test_minicpmv_tp.py +++ b/tests/lora/test_minicpmv_tp.py @@ -64,8 +64,8 @@ def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded): MODEL_PATH, enable_lora=True, max_num_seqs=2, - max_loras=4, - max_lora_rank=64, + max_loras=2, + max_lora_rank=8, tensor_parallel_size=2, trust_remote_code=True, fully_sharded_loras=fully_sharded, @@ -89,6 +89,7 @@ def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded): max_lora_rank=64, tensor_parallel_size=4, trust_remote_code=True, + enforce_eager=True, fully_sharded_loras=fully_sharded, enable_chunked_prefill=True, ) From 7e8d3bd38b80495bd5c8d779d376daef45e64cf6 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sat, 4 Jan 2025 06:47:47 +0000 Subject: [PATCH 35/35] Modify minicpmv test Signed-off-by: Jee Jee Li --- .buildkite/test-pipeline.yaml | 3 +- tests/lora/test_minicpmv.py | 78 ---------------------------------- tests/lora/test_minicpmv_tp.py | 64 +++++++++++++++++++--------- 3 files changed, 46 insertions(+), 99 deletions(-) delete mode 100644 tests/lora/test_minicpmv.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index c6f8316412e2f..f3294c3c88192 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -242,7 +242,7 @@ steps: source_file_dependencies: - vllm/lora - tests/lora - command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py + command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py parallelism: 4 - label: "PyTorch Fullgraph Smoke Test" # 9min @@ -533,6 +533,7 @@ steps: # requires multi-GPU testing for validation. - pytest -v -s -x lora/test_chatglm3_tp.py - pytest -v -s -x lora/test_llama_tp.py + - pytest -v -s -x lora/test_minicpmv_tp.py - label: Weight Loading Multiple GPU Test # 33min diff --git a/tests/lora/test_minicpmv.py b/tests/lora/test_minicpmv.py deleted file mode 100644 index 8017c21b29da0..0000000000000 --- a/tests/lora/test_minicpmv.py +++ /dev/null @@ -1,78 +0,0 @@ -from typing import List - -import pytest - -import vllm -from vllm.assets.image import ImageAsset -from vllm.lora.request import LoRARequest -from vllm.platforms import current_platform - -MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5" - -PROMPT_TEMPLATE = ( - "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" - "(./)\nWhat is in the image?<|eot_id|>" - "<|start_header_id|>assistant<|end_header_id|>\n\n") - -IMAGE_ASSETS = [ - ImageAsset("stop_sign"), - ImageAsset("cherry_blossom"), -] - -# After fine-tuning with LoRA, all generated content should start begin `A`. 
-EXPECTED_OUTPUT = [ - "A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501 - "A pink cherry blossom tree with a blue sky in the background.", -] - - -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: - sampling_params = vllm.SamplingParams( - temperature=0, - max_tokens=5, - stop_token_ids=[128001, 128009], # eos_id, eot_id - ) - - inputs = [{ - "prompt": PROMPT_TEMPLATE, - "multi_modal_data": { - "image": asset.pil_image - }, - } for asset in IMAGE_ASSETS] - - outputs = llm.generate( - inputs, - sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None, - ) - # Print the outputs. - generated_texts: List[str] = [] - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text.strip() - generated_texts.append(generated_text) - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - return generated_texts - - -@pytest.mark.xfail( - current_platform.is_rocm(), - reason="MiniCPM-V dependency xformers incompatible with ROCm") -def test_minicpmv_lora(minicpmv_lora_files): - llm = vllm.LLM( - MODEL_PATH, - max_num_seqs=2, - enable_lora=True, - max_loras=2, - max_lora_rank=8, - enforce_eager=True, - trust_remote_code=True, - enable_chunked_prefill=True, - ) - output1 = do_sample(llm, minicpmv_lora_files, lora_id=1) - for i in range(len(EXPECTED_OUTPUT)): - assert EXPECTED_OUTPUT[i].startswith(output1[i]) - output2 = do_sample(llm, minicpmv_lora_files, lora_id=2) - for i in range(len(EXPECTED_OUTPUT)): - assert EXPECTED_OUTPUT[i].startswith(output2[i]) diff --git a/tests/lora/test_minicpmv_tp.py b/tests/lora/test_minicpmv_tp.py index 05aad03d9d280..3b0f18325a40b 100644 --- a/tests/lora/test_minicpmv_tp.py +++ b/tests/lora/test_minicpmv_tp.py @@ -3,10 +3,10 @@ import pytest import vllm +from tests.utils import fork_new_process_for_each_test from vllm.assets.image import ImageAsset from vllm.lora.request import LoRARequest - -from ..utils import multi_gpu_test +from vllm.platforms import current_platform MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5" @@ -17,13 +17,11 @@ IMAGE_ASSETS = [ ImageAsset("stop_sign"), - ImageAsset("cherry_blossom"), ] # After fine-tuning with LoRA, all generated content should start begin `A`. EXPECTED_OUTPUT = [ "A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501 - "A pink cherry blossom tree with a blue sky in the background.", ] @@ -50,37 +48,40 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: # Print the outputs. 
generated_texts: List[str] = [] for output in outputs: - prompt = output.prompt generated_text = output.outputs[0].text.strip() generated_texts.append(generated_text) - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + print(f"Generated text: {generated_text!r}") return generated_texts -@multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize("fully_sharded", [True, False]) -def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded): +@pytest.mark.xfail( + current_platform.is_rocm(), + reason="MiniCPM-V dependency xformers incompatible with ROCm") +@fork_new_process_for_each_test +def test_minicpmv_lora(minicpmv_lora_files): llm = vllm.LLM( MODEL_PATH, - enable_lora=True, max_num_seqs=2, + enable_lora=True, max_loras=2, max_lora_rank=8, - tensor_parallel_size=2, + enforce_eager=True, trust_remote_code=True, - fully_sharded_loras=fully_sharded, enable_chunked_prefill=True, ) - - output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) - + output1 = do_sample(llm, minicpmv_lora_files, lora_id=1) for i in range(len(EXPECTED_OUTPUT)): - assert EXPECTED_OUTPUT[i].startswith(output_tp[i]) + assert EXPECTED_OUTPUT[i].startswith(output1[i]) + output2 = do_sample(llm, minicpmv_lora_files, lora_id=2) + for i in range(len(EXPECTED_OUTPUT)): + assert EXPECTED_OUTPUT[i].startswith(output2[i]) -@multi_gpu_test(num_gpus=4) -@pytest.mark.parametrize("fully_sharded", [True, False]) -def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded): +@pytest.mark.xfail( + current_platform.is_rocm(), + reason="MiniCPM-V dependency xformers incompatible with ROCm") +@fork_new_process_for_each_test +def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files): llm = vllm.LLM( MODEL_PATH, enable_lora=True, @@ -90,9 +91,32 @@ def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded): tensor_parallel_size=4, trust_remote_code=True, enforce_eager=True, - fully_sharded_loras=fully_sharded, enable_chunked_prefill=True, ) output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) for i in range(len(EXPECTED_OUTPUT)): assert EXPECTED_OUTPUT[i].startswith(output_tp[i]) + + +@pytest.mark.xfail( + current_platform.is_rocm(), + reason="MiniCPM-V dependency xformers incompatible with ROCm") +@fork_new_process_for_each_test +def test_minicpmv_tp4_fully_sharded_loras(minicpmv_lora_files): + llm = vllm.LLM( + MODEL_PATH, + enable_lora=True, + max_num_seqs=2, + max_loras=2, + max_lora_rank=8, + tensor_parallel_size=4, + trust_remote_code=True, + fully_sharded_loras=True, + enable_chunked_prefill=True, + ) + output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) + for i in range(len(EXPECTED_OUTPUT)): + assert EXPECTED_OUTPUT[i].startswith(output_tp[i]) + output_tp = do_sample(llm, minicpmv_lora_files, lora_id=2) + for i in range(len(EXPECTED_OUTPUT)): + assert EXPECTED_OUTPUT[i].startswith(output_tp[i])
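
Editor's note (series summary, not part of any patch): after PATCH 35 the fused multi-slice path is exercised through a single sgmv_expand call instead of one sgmv_expand_slice launch per slice, with the slice index carried on grid axis 2 of the kernel. The sketch below mirrors the call pattern and tensor shapes used by generate_data_for_nslices in the unit tests above; the concrete sizes, the device string, and the all-zeros LoRA index choice are illustrative assumptions, and a CUDA device with a vLLM build from this series is required to actually run it.

import torch

from vllm.lora.ops.sgmv_expand import sgmv_expand

nslices, num_loras, rank, hidden_size = 3, 4, 16, 128
batches, seq_len = 2, 8
total_tokens = batches * seq_len
device, dtype = "cuda:0", torch.float16

# Per-slice expand inputs (in the real pipeline, the shrink outputs):
# shape (nslices, total_tokens, rank).
inputs = torch.rand(nslices, total_tokens, rank, dtype=dtype, device=device)
# One lora_b weight per slice: (num_loras, hidden_size, rank).
lora_b_weights = [
    torch.rand(num_loras, hidden_size, rank, dtype=dtype, device=device)
    for _ in range(nslices)
]
# Fused output holds all slices side by side:
# shape (total_tokens, nslices * hidden_size).
output = torch.zeros(total_tokens, nslices * hidden_size,
                     dtype=dtype, device=device)

seq_len_tensor = torch.full((batches, ), seq_len, dtype=torch.long,
                            device=device)
b_seq_start_loc = torch.cumsum(
    torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
    dim=0).to(device)
# Every sequence uses LoRA 0 here; an index of -1 would mean "no LoRA".
lora_indices_tensor = torch.zeros(batches, dtype=torch.long, device=device)

# A single launch writes all nslices column blocks of `output`.
sgmv_expand(
    inputs,
    lora_b_weights,
    output,
    b_seq_start_loc,
    seq_len_tensor,
    lora_indices_tensor,
    batches,
    seq_len,        # max_seq_length
    total_tokens,   # token_nums
    offset_start=0,
    add_inputs=True,
)

In the tests, _LORA_B_PTR_DICT is cleared under a lock before such a call because the LUT caches device pointers from profile_run; a fresh process, as in this sketch, simply starts with an empty LUT.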