From f9b7c64f7d81eed1e7ab1fa28e14826f64dded02 Mon Sep 17 00:00:00 2001 From: xiaobochen <35516720+xiaobochen123@users.noreply.github.com> Date: Fri, 6 Dec 2024 02:44:01 +0800 Subject: [PATCH 1/5] MoE Expert Parallel Impl (#2203) Co-authored-by: HAI --- .github/workflows/pr-test.yml | 6 + python/sglang/srt/layers/ep_moe/__init__.py | 0 python/sglang/srt/layers/ep_moe/kernels.py | 349 +++++++++ python/sglang/srt/layers/ep_moe/layer.py | 661 ++++++++++++++++++ python/sglang/srt/managers/schedule_batch.py | 1 + .../sglang/srt/model_executor/model_runner.py | 1 + python/sglang/srt/models/deepseek_v2.py | 8 +- python/sglang/srt/models/mixtral.py | 18 +- python/sglang/srt/server_args.py | 23 + test/srt/test_moe_ep.py | 113 +++ 10 files changed, 1172 insertions(+), 8 deletions(-) create mode 100644 python/sglang/srt/layers/ep_moe/__init__.py create mode 100644 python/sglang/srt/layers/ep_moe/kernels.py create mode 100644 python/sglang/srt/layers/ep_moe/layer.py create mode 100644 test/srt/test_moe_ep.py diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 59f0006e128..49c6ec88327 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -105,6 +105,12 @@ jobs: cd test/srt python3 test_update_weights_from_distributed.py + - name: Evaluate MoE EP accuracy (TP=2) + timeout-minutes: 10 + run: | + cd test/srt + python3 test_moe_ep.py + performance-test-1-gpu-part-1: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' runs-on: 1-gpu-runner diff --git a/python/sglang/srt/layers/ep_moe/__init__.py b/python/sglang/srt/layers/ep_moe/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/python/sglang/srt/layers/ep_moe/kernels.py b/python/sglang/srt/layers/ep_moe/kernels.py new file mode 100644 index 00000000000..e0486891aa7 --- /dev/null +++ b/python/sglang/srt/layers/ep_moe/kernels.py @@ -0,0 +1,349 @@ +import logging +from typing import Optional + +import torch +import triton +import triton.language as tl + +logger = logging.getLogger(__name__) + + +@triton.jit +def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks): + expert = tl.program_id(0) + low = 0 + high = num_toks - 1 + target_location = -1 + while low <= high: + mid = (low + high) // 2 + + if tl.load(reorder_topk_ids + mid) > expert: + high = mid - 1 + else: + low = mid + 1 + target_location = mid + tl.store(seg_indptr + expert + 1, target_location + 1) + + +@triton.jit +def compute_src2dst_triton_kernel( + reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr +): + pid = tl.program_id(axis=0) + dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = dst_id < num_toks + src_id = tl.load(reorder_ids + dst_id, mask=mask) + tl.store(src2dst + src_id, dst_id, mask=mask) + + +def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int): + reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True) + seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64) + src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32) + + compute_seg_indptr_triton_kernel[(num_experts,)]( + reorder_topk_ids, seg_indptr, topk_ids.numel() + ) + + BLOCK_SIZE = 512 + grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),) + compute_src2dst_triton_kernel[grid]( + reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE + ) + return reorder_topk_ids, src2dst, seg_indptr + + +@triton.jit +def pre_reorder_triton_kernel( + input_ptr, + gateup_input_ptr, + src2dst_ptr, + topk_ids_ptr, 
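+    # Per-expert input activation scales (may be None when unquantized);
+    # the kernel multiplies inputs by 1/scale before casting to the gateup dtype.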
+ a1_scales_ptr, + start_expert_id, + end_expert_id, + topk, + hidden_size, + BLOCK_SIZE: tl.constexpr, +): + OutDtype = gateup_input_ptr.dtype.element_ty + + src_idx = tl.program_id(0) + src2dst_ptr = src2dst_ptr + src_idx * topk + topk_ids_ptr = topk_ids_ptr + src_idx * topk + + src_ptr = input_ptr + src_idx * hidden_size + for idx in range(topk): + expert_id = tl.load(topk_ids_ptr + idx) + if expert_id >= start_expert_id and expert_id <= end_expert_id: + if a1_scales_ptr is not None: + scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id) + else: + scale = 1.0 + + dst_idx = tl.load(src2dst_ptr + idx) + dst_ptr = gateup_input_ptr + dst_idx * hidden_size + for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): + offset = start_offset + tl.arange(0, BLOCK_SIZE) + mask = offset < hidden_size + in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32) + out_data = (in_data * scale).to(OutDtype) + tl.store(dst_ptr + offset, out_data, mask=mask) + + +@triton.jit +def silu_and_mul_triton_kernel( + gateup_output, + down_input, + hidden_size, + reorder_topk_ids, + scales, + start_expert_id, + end_expert_id, + BLOCK_SIZE: tl.constexpr, +): + InDtype = gateup_output.dtype.element_ty + OutDtype = down_input.dtype.element_ty + + half_hidden_size = hidden_size // 2 + + pid = tl.program_id(0) + expert_id = tl.load(reorder_topk_ids + pid) + if expert_id >= start_expert_id and expert_id <= end_expert_id: + gateup_output_ptr = gateup_output + pid * hidden_size + gate_output_ptr = gateup_output_ptr + up_output_ptr = gateup_output_ptr + half_hidden_size + down_input_ptr = down_input + pid * half_hidden_size + + if scales is not None: + scale = tl.load(scales + expert_id - start_expert_id) + scale = (1 / scale).to(InDtype) + else: + scale = 1 + + for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE): + offset = start_offset + tl.arange(0, BLOCK_SIZE) + mask = offset < half_hidden_size + + gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32) + up_output = tl.load(up_output_ptr + offset, mask=mask) + + # silu & mul & quantize + gate_output = gate_output * tl.sigmoid(gate_output) + gate_output = gate_output.to(InDtype) + + silu_mul_output = gate_output * up_output * scale + silu_mul_output = silu_mul_output.to(OutDtype) + tl.store(down_input_ptr + offset, silu_mul_output, mask=mask) + + +@triton.jit +def post_reorder_triton_kernel( + down_output_ptr, + output_ptr, + src2dst_ptr, + topk_ids_ptr, + topk_weights_ptr, + start_expert_id, + end_expert_id, + topk, + hidden_size, + BLOCK_SIZE: tl.constexpr, +): + InDtype = down_output_ptr.dtype.element_ty + + src_idx = tl.program_id(0) + src2dst_ptr = src2dst_ptr + src_idx * topk + topk_ids_ptr = topk_ids_ptr + src_idx * topk + topk_weights_ptr = topk_weights_ptr + src_idx * topk + + computed = False + store_ptr = output_ptr + src_idx * hidden_size + for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): + offset = start_offset + tl.arange(0, BLOCK_SIZE) + mask = offset < hidden_size + + sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype) + for idx in range(topk): + expert_id = tl.load(topk_ids_ptr + idx) + if expert_id >= start_expert_id and expert_id <= end_expert_id: + computed = True + dst_idx = tl.load(src2dst_ptr + idx) + weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype) + load_ptr = down_output_ptr + dst_idx * hidden_size + in_data = tl.load(load_ptr + offset, mask=mask) + sum_vec += in_data * weigh_scale + tl.store(store_ptr + offset, sum_vec, mask=mask) + + if computed == False: + for start_offset in 
tl.range(0, hidden_size, BLOCK_SIZE): + offset = start_offset + tl.arange(0, BLOCK_SIZE) + mask = offset < hidden_size + tl.store( + store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask + ) + + +@triton.jit +def compute_m_range( + pid, + batch_size, + seg_indptr, + weight_indices, + m_num_tiles_indptr, + BLOCK_SIZE_M: tl.constexpr, +): + idx = 0 + for bs in range(batch_size): + tiles = tl.load(m_num_tiles_indptr + bs) + if pid >= tiles: + idx = bs + + idx_start = tl.load(m_num_tiles_indptr + idx) + + m_range_start = tl.load(seg_indptr + idx) + (pid - idx_start) * BLOCK_SIZE_M + m_range_end = min(tl.load(seg_indptr + idx + 1), m_range_start + BLOCK_SIZE_M) + expert_id = tl.load(weight_indices + idx) + return m_range_start, m_range_end, expert_id + + +@triton.jit +def grouped_gemm_triton_kernel( + a, + b, + c, + batch_size, + N, + K, + seg_indptr, + weight_indices, + m_num_tiles_indptr, + use_fp8_w8a8, + scale_a, + scale_b, + a_stride_0: tl.constexpr, + b_stride_0: tl.constexpr, + b_stride_1: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + c_dtype = c.dtype.element_ty + + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + total_m_block = tl.load(m_num_tiles_indptr + batch_size) + if pid_m >= total_m_block: + return + + m_range_start, m_range_end, expert_id = compute_m_range( + pid_m, batch_size, seg_indptr, weight_indices, m_num_tiles_indptr, BLOCK_SIZE_M + ) + if m_range_end - m_range_start == 0: + return + + n_range_start = pid_n * BLOCK_SIZE_N + n_range_end = min(n_range_start + BLOCK_SIZE_N, N) + + offs_am = tl.arange(0, BLOCK_SIZE_M) + offs_bn = tl.arange(0, BLOCK_SIZE_N) + + offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0) + offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0) + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + a_ptr = a + (m_range_start + offs_am[:, None]) * a_stride_0 + offs_k[None, :] + b_ptr = b + ( + (expert_id * b_stride_0) + + (n_range_start + offs_bn[:, None]) * b_stride_1 + + offs_k[None, :] + ) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a_tile = tl.load( + a_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0 + ) + b_tile = tl.load( + b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0 + ) + accumulator = tl.dot(a_tile, b_tile.T, accumulator) + a_ptr += BLOCK_SIZE_K + b_ptr += BLOCK_SIZE_K + + if use_fp8_w8a8: + scale_a_value = tl.load(scale_a + expert_id) + scale_b_value = tl.load(scale_b + expert_id) + accumulator *= scale_a_value * scale_b_value + c_tile = accumulator.to(c_dtype) + + offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_range_start + tl.arange(0, BLOCK_SIZE_N) + c_ptr = c + offs_cm[:, None] * N + offs_cn[None, :] + c_mask = (offs_cm[:, None] < m_range_end) & (offs_cn[None, :] < n_range_end) + tl.store(c_ptr, c_tile, mask=c_mask) + + +@triton.jit +def compute_m_num_tiles_indptr( + m_num_tiles_indptr, seg_indptr, batch_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr +): + for bs in range(batch_size): + m = tl.load(seg_indptr + bs + 1) - tl.load(seg_indptr + bs) + cur_num_tiles = tl.cdiv(m, BLOCK_SIZE_M) + pre_num_tiles = tl.load(m_num_tiles_indptr + bs) + tl.store(m_num_tiles_indptr + bs + 1, pre_num_tiles + cur_num_tiles) + + +def grouped_gemm_triton( + a: 
torch.Tensor, + b: torch.Tensor, + c: torch.Tensor, + batch_size: int, + weight_column_major: bool, + seg_indptr: Optional[torch.Tensor] = None, + weight_indices: Optional[torch.Tensor] = None, + use_fp8_w8a8: bool = False, + scale_a: torch.Tensor = None, + scale_b: torch.Tensor = None, +): + assert weight_column_major == True # TODO: more + if use_fp8_w8a8: + assert scale_a is not None and scale_b is not None + + config = { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + } + + m_num_tiles_indptr = torch.zeros(batch_size + 1, device=a.device, dtype=torch.int64) + compute_m_num_tiles_indptr[(1,)]( + m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"] + ) + + grid = lambda META: ( + triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size, + triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]), + ) + + grouped_gemm_triton_kernel[grid]( + a, + b, + c, + batch_size, + b.size(1), + b.size(2), + seg_indptr, + weight_indices, + m_num_tiles_indptr, + use_fp8_w8a8, + scale_a, + scale_b, + a.stride(0), + b.stride(0), + b.stride(1), + **config, + ) + return c diff --git a/python/sglang/srt/layers/ep_moe/layer.py b/python/sglang/srt/layers/ep_moe/layer.py new file mode 100644 index 00000000000..eca119845a7 --- /dev/null +++ b/python/sglang/srt/layers/ep_moe/layer.py @@ -0,0 +1,661 @@ +import logging +from typing import Callable, List, Optional, Tuple + +import torch +from torch.nn import Module +from vllm import _custom_ops as ops +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod + +from sglang.srt.layers.custom_op_util import register_custom_op +from sglang.srt.layers.ep_moe.kernels import ( + grouped_gemm_triton, + post_reorder_triton_kernel, + pre_reorder_triton_kernel, + run_moe_ep_preproess, + silu_and_mul_triton_kernel, +) +from sglang.srt.layers.fused_moe_triton.fused_moe import fused_topk, grouped_topk +from sglang.srt.layers.fused_moe_triton.layer import FusedMoEMethodBase +from sglang.srt.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.srt.utils import is_hip, set_weight_attrs + +logger = logging.getLogger(__name__) + + +class GroupedGemmRunner(torch.nn.Module): + flashinfer_gemm_warpper = None + + def __init__(self, device, use_flashinfer: bool = False): + super().__init__() + self.device = device + self.use_flashinfer = use_flashinfer + if self.use_flashinfer and GroupedGemmRunner.flashinfer_gemm_warpper is None: + GroupedGemmRunner._init_flashinfer_wrapper(device) + + @classmethod + def _init_flashinfer_wrapper(cls, device): + from flashinfer import SegmentGEMMWrapper + + workspace_buffer = torch.empty( + 128 * 1024 * 1024, dtype=torch.int8, device=device + ) + cls.flashinfer_gemm_warpper = SegmentGEMMWrapper(workspace_buffer) + + # c = a * b + def forward( + self, + a: torch.Tensor, + b: torch.Tensor, + c: torch.Tensor, + batch_size: int, + weight_column_major: bool, + seg_indptr: Optional[torch.Tensor] = None, + weight_indices: Optional[torch.Tensor] = None, + use_fp8_w8a8: bool = False, + scale_a: torch.Tensor = None, + scale_b: torch.Tensor = None, + ): + if self.use_flashinfer: + # TODO: flashinfer + assert False + assert GroupedGemmRunner.flashinfer_gemm_warpper is not None + c = GroupedGemmRunner.flashinfer_gemm_warpper.run( + x=a, + weights=b, + batch_size=batch_size, + weight_column_major=weight_column_major, 
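+                # seg_indptr: row offsets of each expert's segment in the sorted input;
+                # weight_indices: which expert weight matrix each segment uses.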
+ seg_indptr=seg_indptr, + weight_indices=weight_indices, + ) + else: + assert weight_column_major == True + c = grouped_gemm_triton( + a, + b, + c, + batch_size, + weight_column_major, + seg_indptr, + weight_indices, + use_fp8_w8a8, + scale_a, + scale_b, + ) + return c + + +class EPMoE(torch.nn.Module): + """ + MoE Expert Parallel Impl + + + """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = "", + ): + super().__init__() + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + self.tp_size = ( + tp_size if tp_size is not None else get_tensor_model_parallel_world_size() + ) + self.tp_rank = get_tensor_model_parallel_rank() + + self.num_experts = num_experts + assert self.num_experts % self.tp_size == 0 + self.num_experts_per_partition = self.num_experts // self.tp_size + self.start_expert_id = self.tp_rank * self.num_experts_per_partition + self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1 + + self.top_k = top_k + self.intermediate_size = intermediate_size + self.renormalize = renormalize + self.use_grouped_topk = use_grouped_topk + if self.use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + self.num_expert_group = num_expert_group + self.topk_group = topk_group + + if quant_config is None: + self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod() + self.use_fp8_w8a8 = False + self.activation_scheme = None + else: + self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod( + quant_config + ) + self.use_fp8_w8a8 = True + self.fp8_dtype = torch.float8_e4m3fn + self.activation_scheme = quant_config.activation_scheme + + self.quant_method.create_weights( + layer=self, + num_experts_per_partition=self.num_experts_per_partition, + hidden_size=hidden_size, + intermediate_size=self.intermediate_size, + params_dtype=params_dtype, + weight_loader=self.weight_loader, + ) + + self.grouped_gemm_runner = None + + def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): + assert self.quant_method is not None + + if self.grouped_gemm_runner is None: + self.grouped_gemm_runner = GroupedGemmRunner( + hidden_states.device, use_flashinfer=False # TODO: use flashinfer + ) + + topk_weights, topk_ids = self.select_experts( + hidden_states, + router_logits, + self.top_k, + self.renormalize, + self.topk_group, + self.num_expert_group, + ) + + reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess( + topk_ids, self.num_experts + ) + + gateup_input = torch.empty( + (int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]), + device=hidden_states.device, + dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype, + ) + if self.activation_scheme == "dynamic": + max_value = ( + torch.max(hidden_states) + .repeat(self.num_experts_per_partition) + .to(torch.float32) + ) + self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max + + # PreReorder + pre_reorder_triton_kernel[(hidden_states.shape[0],)]( + hidden_states, + gateup_input, + src2dst, + topk_ids, + self.w13_input_scale, + self.start_expert_id, + self.end_expert_id, + self.top_k, + hidden_states.shape[1], + BLOCK_SIZE=512, + ) + + seg_indptr_cur_rank = 
seg_indptr[self.start_expert_id : self.end_expert_id + 2] + weight_indices_cur_rank = torch.arange( + 0, + self.num_experts_per_partition, + device=hidden_states.device, + dtype=torch.int64, + ) + # GroupGemm-0 + gateup_output = torch.empty( + gateup_input.shape[0], + self.w13_weight.shape[1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + gateup_output = self.grouped_gemm_runner( + a=gateup_input, + b=self.w13_weight, + c=gateup_output, + batch_size=self.num_experts_per_partition, + weight_column_major=True, + seg_indptr=seg_indptr_cur_rank, + weight_indices=weight_indices_cur_rank, + use_fp8_w8a8=self.use_fp8_w8a8, + scale_a=self.w13_input_scale, + scale_b=self.w13_weight_scale, + ) + + # Act + down_input = torch.empty( + gateup_output.shape[0], + gateup_output.shape[1] // 2, + device=gateup_output.device, + dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype, + ) + if self.w2_input_scale is None: + self.w2_input_scale = torch.ones( + self.num_experts_per_partition, + dtype=torch.float32, + device=hidden_states.device, + ) + silu_and_mul_triton_kernel[(gateup_output.shape[0],)]( + gateup_output, + down_input, + gateup_output.shape[1], + reorder_topk_ids, + self.w2_input_scale, + self.start_expert_id, + self.end_expert_id, + BLOCK_SIZE=512, + ) + + # GroupGemm-1 + down_output = torch.empty( + down_input.shape[0], + self.w2_weight.shape[1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + down_output = self.grouped_gemm_runner( + a=down_input, + b=self.w2_weight, + c=down_output, + batch_size=self.num_experts_per_partition, + weight_column_major=True, + seg_indptr=seg_indptr_cur_rank, + weight_indices=weight_indices_cur_rank, + use_fp8_w8a8=self.use_fp8_w8a8, + scale_a=self.w2_input_scale, + scale_b=self.w2_weight_scale, + ) + + # PostReorder + output = torch.empty_like(hidden_states) + post_reorder_triton_kernel[(hidden_states.size(0),)]( + down_output, + output, + src2dst, + topk_ids, + topk_weights, + self.start_expert_id, + self.end_expert_id, + self.top_k, + hidden_states.size(1), + BLOCK_SIZE=512, + ) + return output + + def select_experts( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + ): + if self.use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + topk_weights, topk_ids = grouped_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + ) + else: + topk_weights, topk_ids = fused_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + ) + return topk_weights, topk_ids.to(torch.int32) + + @classmethod + def make_expert_params_mapping( + cls, + ckpt_gate_proj_name: str, + ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int, + ) -> List[Tuple[str, str, int, str]]: + + return [ + # (param_name, weight_name, expert_id, shard_id) + ( + ( + "experts.w13_" + if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name] + else "experts.w2_" + ), + f"experts.{expert_id}.{weight_name}.", + expert_id, + shard_id, + ) + for expert_id in range(num_experts) + for shard_id, weight_name in [ + ("w1", ckpt_gate_proj_name), + ("w2", ckpt_down_proj_name), + ("w3", ckpt_up_proj_name), + ] + ] + + def weight_loader( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: 
str, + shard_id: str, + expert_id: int, + ) -> None: + if expert_id < self.start_expert_id or expert_id > self.end_expert_id: + return + expert_id = expert_id - self.start_expert_id + + if shard_id not in ("w1", "w2", "w3"): + raise ValueError( + f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}." + ) + + # Special case for fp8 scales. + if "scale" in weight_name: + self._load_fp8_scale( + param.data, loaded_weight, weight_name, shard_id, expert_id + ) + return + + expert_data = param.data[expert_id] + if shard_id == "w2": + param.data[expert_id] = loaded_weight + elif shard_id == "w1": + param.data[expert_id][: self.intermediate_size, :] = loaded_weight + elif shard_id == "w3": + param.data[expert_id][self.intermediate_size :, :] = loaded_weight + else: + raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}") + + def _load_fp8_scale( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, + ) -> None: + param_data = param.data + + # Input scales can be loaded directly and should be equal. + if "input_scale" in weight_name: + if ( + param_data[expert_id] != 1 + and (param_data[expert_id] - loaded_weight).abs() > 1e-5 + ): + raise ValueError( + "input_scales of w1 and w3 of a layer " + f"must be equal. But got {param_data[expert_id]} " + f"vs. {loaded_weight}" + ) + param_data[expert_id] = loaded_weight + # Weight scales + elif "weight_scale" in weight_name: + # If we are in merged column case (gate_up_proj) + if shard_id in ("w1", "w3"): + # We have to keep the weight scales of w1 and w3 because + # we need to re-quantize w1/w3 weights after weight loading. + idx = 0 if shard_id == "w1" else 1 + param_data[expert_id][idx] = loaded_weight + # If we are in the row parallel case (down_proj) + else: + param_data[expert_id] = loaded_weight + + +@register_custom_op("sglang_unquantized_ep_moe") +class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp): + def create_weights( + self, + layer: torch.nn.Module, + num_experts_per_partition: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts_per_partition, + 2 * intermediate_size, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + # down_proj (row parallel) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts_per_partition, + hidden_size, + intermediate_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # scale + ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32) + w13_input_scale = torch.nn.Parameter( + ones_tensor, + requires_grad=False, + ) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter( + ones_tensor, + requires_grad=False, + ) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + w13_weight_scale = torch.nn.Parameter( + ones_tensor, + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + ones_tensor, + 
requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + raise NotImplementedError + + +class Fp8EPMoEMethod(Fp8MoEMethod): + """MoE method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: Module, + num_experts_per_partition: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + if self.quant_config.is_checkpoint_fp8_serialized: + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts_per_partition, + 2 * intermediate_size, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts_per_partition, + hidden_size, + intermediate_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, 2, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update({"quant_method": "tensor"}) + # If loading fp8 checkpoint, pass the weight loaders. + # If loading an fp16 checkpoint, do not (we will quantize in + # process_weights_after_loading() + if self.quant_config.is_checkpoint_fp8_serialized: + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.quant_config.activation_scheme == "static": + if not self.quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "Found static activation scheme for checkpoint that " + "was not serialized fp8." + ) + + w13_input_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: Module) -> None: + + # If checkpoint is fp16, quantize in place. 
+ if not self.quant_config.is_checkpoint_fp8_serialized: + # If rocm, use float8_e4m3fnuz as dtype + fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn + w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) + w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) + + layer.w13_weight_scale = torch.nn.Parameter( + torch.ones( + layer.num_experts_per_partition, + dtype=torch.float32, + device=w13_weight.device, + ), + requires_grad=False, + ) + + for expert in range(layer.num_experts_per_partition): + w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( + ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( + ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) + ) + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + return + + # If checkpoint is fp8, we need to handle that the + # MoE kernels require single activation scale and single weight + # scale for w13 per expert. + else: + if self.quant_config.activation_scheme == "static": + if layer.w13_input_scale is None or layer.w2_input_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None." + ) + return + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + raise NotImplementedError diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 28677efeac4..5855d4248ff 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -58,6 +58,7 @@ "torchao_config": ServerArgs.torchao_config, "enable_nan_detection": ServerArgs.enable_nan_detection, "enable_dp_attention": ServerArgs.enable_dp_attention, + "enable_ep_moe": ServerArgs.enable_ep_moe, } diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 4eaedbccbff..3f0cbecac15 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -141,6 +141,7 @@ def __init__( "torchao_config": server_args.torchao_config, "enable_nan_detection": server_args.enable_nan_detection, "enable_dp_attention": server_args.enable_dp_attention, + "enable_ep_moe": server_args.enable_ep_moe, } ) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 80db9a35c71..63cea92c289 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -31,6 +31,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.ep_moe.layer import EPMoE from sglang.srt.layers.fused_moe_triton import FusedMoE from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -113,12 +114,12 @@ def __init__( "Only silu is supported for now." 
) - self.experts = FusedMoE( + MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE + self.experts = MoEImpl( num_experts=config.n_routed_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, - reduce_results=False, renormalize=config.norm_topk_prob, quant_config=quant_config, use_grouped_topk=True, @@ -834,7 +835,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( + MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE + expert_params_mapping = MoEImpl.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index f1ae1f57a3d..f3fad226091 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -21,9 +21,13 @@ import torch from torch import nn from transformers import MixtralConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import ( + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.model_executor.layers.rotary_embedding import get_rope +from sglang.srt.layers.ep_moe.layer import EPMoE from sglang.srt.layers.fused_moe_triton import FusedMoE from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -38,6 +42,7 @@ ParallelLMHead, VocabParallelEmbedding, ) +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -63,6 +68,7 @@ def __init__( prefix: str = "", ): super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() self.hidden_size = hidden_size # Gate always runs at half / full precision for now. 
@@ -74,14 +80,13 @@ def __init__( quant_config=None, prefix=f"{prefix}.gate", ) - - self.experts = FusedMoE( + MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE + self.experts = MoEImpl( num_experts=num_experts, top_k=top_k, hidden_size=hidden_size, intermediate_size=intermediate_size, params_dtype=params_dtype, - reduce_results=True, renormalize=True, quant_config=quant_config, tp_size=tp_size, @@ -95,6 +100,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, router_logits) + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(orig_shape) @@ -319,7 +326,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( + MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE + expert_params_mapping = MoEImpl.make_expert_params_mapping( ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", ckpt_up_proj_name="w3", diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 7b337500fd7..8719d919068 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -93,6 +93,8 @@ class ServerArgs: # Data parallelism dp_size: int = 1 load_balance_method: str = "round_robin" + # Expert parallelism + ep_size: int = 1 # Multi-node distributed serving dist_init_addr: Optional[str] = None @@ -130,6 +132,7 @@ class ServerArgs: disable_overlap_schedule: bool = False enable_mixed_chunk: bool = False enable_dp_attention: bool = False + enable_ep_moe: bool = False enable_torch_compile: bool = False torch_compile_max_bs: int = 32 cuda_graph_max_bs: Optional[int] = None @@ -216,6 +219,12 @@ def __post_init__(self): "Data parallel size is adjusted to be the same as tensor parallel size. " "Overlap scheduler is disabled." ) + # Expert parallelism + if self.enable_ep_moe: + self.ep_size = self.tp_size + logger.info( + f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." + ) # GGUF if ( @@ -526,6 +535,14 @@ def add_cli_args(parser: argparse.ArgumentParser): "shortest_queue", ], ) + # Expert parallelism + parser.add_argument( + "--expert-parallel-size", + "--ep-size", + type=int, + default=ServerArgs.ep_size, + help="The expert parallelism size.", + ) # Multi-node distributed serving parser.add_argument( @@ -681,6 +698,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.", ) + parser.add_argument( + "--enable-ep-moe", + action="store_true", + help="Enabling expert parallelism for moe. 
The ep size is equal to the tp size.", + ) parser.add_argument( "--enable-torch-compile", action="store_true", @@ -760,6 +782,7 @@ def add_cli_args(parser: argparse.ArgumentParser): def from_cli_args(cls, args: argparse.Namespace): args.tp_size = args.tensor_parallel_size args.dp_size = args.data_parallel_size + args.ep_size = args.expert_parallel_size attrs = [attr.name for attr in dataclasses.fields(cls)] return cls(**{attr: getattr(args, attr) for attr in attrs}) diff --git a/test/srt/test_moe_ep.py b/test/srt/test_moe_ep.py new file mode 100644 index 00000000000..4d9fd435edb --- /dev/null +++ b/test/srt/test_moe_ep.py @@ -0,0 +1,113 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MLA_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestEpMoE(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "2", + "--ep-size", + "2", + "--enable-ep-moe", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.5 + + def test_mgsm_en(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mgsm_en", + num_examples=None, + num_threads=1024, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.8 + + +class TestEpMoEFP8(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "2", + "--ep-size", + "2", + "--enable-ep-moe", + "--quantization", + "fp8", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.5 + + def test_mgsm_en(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mgsm_en", + num_examples=None, + num_threads=1024, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.8 + + +if __name__ == "__main__": + unittest.main() From 2fdaffbe1719054c14f5edb101abc1eec2b39ae9 Mon Sep 17 00:00:00 2001 From: Fred Reiss Date: Fri, 6 Dec 2024 17:02:11 -0800 Subject: [PATCH 2/5] Initial version of Granite connector --- python/sglang/lang/chat_template.py | 31 ++ python/sglang/srt/models/granite.py | 518 ++++++++++++++++++++++++++++ 2 files changed, 549 insertions(+) create mode 100644 python/sglang/srt/models/granite.py diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index 3e5ac8dd522..5a0474fa894 100644 --- a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -320,6 +320,28 @@ def get_chat_template_by_model_path(model_path): ) ) 
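+# Granite 3.x instruct format: every turn is rendered as
+#   <|start_of_role|>{role}<|end_of_role|>{content}<|end_of_text|>
+# and generation stops at <|end_of_text|>.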
+register_chat_template( + ChatTemplate( + name="granite3-instruct", + default_system_prompt=None, + role_prefix_and_suffix={ + "system": ( + "<|start_of_role|>system<|end_of_role|>", + "<|end_of_text|>", + ), + "user": ( + "<|start_of_role|>user<|end_of_role|>", + "<|end_of_text|>", + ), + "assistant": ( + "<|start_of_role|>assistant<|end_of_role|>", + "<|end_of_text|>", + ), + }, + stop_str=("<|end_of_text|>",), + ) +) + @register_chat_template_matching_function def match_dbrx(model_path: str): @@ -402,6 +424,15 @@ def match_c4ai_command_r(model_path: str): return get_chat_template("c4ai-command-r") +@register_chat_template_matching_function +def match_granite_instruct(model_path: str): + model_path = model_path.lower() + # When future versions of Granite are released, this code may + # need to be updated. For now, assume that the Granite 3.0 + # template works across the board. + if "granite" in model_path and "instruct" in model_path: + return get_chat_template("granite-3-instruct") + if __name__ == "__main__": messages = [ {"role": "system", "content": None}, # None means default diff --git a/python/sglang/srt/models/granite.py b/python/sglang/srt/models/granite.py new file mode 100644 index 00000000000..f4f7b067290 --- /dev/null +++ b/python/sglang/srt/models/granite.py @@ -0,0 +1,518 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +# Adapted from +# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1 +"""Inference-only Granite model compatible with HuggingFace weights.""" + +import logging +from typing import Any, Dict, Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import GraniteConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.rotary_embedding import get_rope + +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.pooler import Pooler, PoolingType +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.utils import get_exception_traceback + +logger = logging.getLogger(__name__) + + +class GraniteMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class GraniteAttention(nn.Module): + def __init__( + self, + config: GraniteConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + layer_id: int = 0, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + rope_is_neox_style: bool = True, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr( + config, "head_dim", self.hidden_size // self.total_num_heads + ) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = config.attention_multiplier + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=rope_is_neox_style, + ) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, forward_batch) + output, _ = self.o_proj(attn_output) + return output + + +class GraniteDecoderLayer(nn.Module): + def __init__( + self, + config: GraniteConfig, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.residual_multiplier = config.residual_multiplier + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None + ): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings + ) + rope_is_neox_style = getattr(config, "rope_is_neox_style", True) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + self.self_attn = GraniteAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + layer_id=layer_id, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + rope_is_neox_style=rope_is_neox_style, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = GraniteMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + # if residual is None: + # residual = hidden_states + # hidden_states = self.input_layernorm(hidden_states) + # else: + # hidden_states, residual = self.input_layernorm(hidden_states, residual) 
+ # hidden_states = self.self_attn( + # positions=positions, + # hidden_states=hidden_states, + # forward_batch=forward_batch, + # ) + # Code from vllm head + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + hidden_states = residual + hidden_states * self.residual_multiplier + + # Fully Connected + # Old code + #hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + # Code from vllm head + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states * self.residual_multiplier + return hidden_states, residual + + +class GraniteModel(nn.Module): + def __init__( + self, + config: GraniteConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList( + [ + GraniteDecoderLayer( + config, i, quant_config=quant_config, prefix=f"model.layers.{i}" + ) + for i in range(config.num_hidden_layers) + ] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + residual = None + hidden_states *= self.config.embedding_multiplier + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + forward_batch, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class GraniteForCausalLM(nn.Module): + def __init__( + self, + config: GraniteConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = GraniteModel(config, quant_config=quant_config) + # If tie_word_embeddings == True, then input and output embeddings are + # the same tensor. Enforce during object creation so that weights will + # load correctly even if the LM head weights don't have a separate entry + # in the state dict. 
+ if self.config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) + self.logits_processor = LogitsProcessor(config) + self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + self.stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + get_embedding: bool = False, + ) -> LogitsProcessorOutput: + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) + if not get_embedding: + logits_processor_output: LogitsProcessorOutput = self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + #print(f"Logits processor output before:\n{logits_processor_output}") + logits_processor_output.next_token_logits /= self.config.logits_scaling + #print(f"Logits processor output after:\n{logits_processor_output}") + # TODO: Divide logits indside logits_processor_output by self.config.logits_scaling + return logits_processor_output + else: + return self.pooler(hidden_states, forward_batch) + + def get_hidden_dim(self, module_name): + # return input_dim, output_dim + if module_name in ["q_proj", "o_proj", "qkv_proj"]: + return self.config.hidden_size, self.config.hidden_size + elif module_name in ["kv_proj"]: + return self.config.hidden_size, self.config.hidden_size // ( + self.config.num_attention_heads // self.config.num_key_value_heads + ) + elif module_name == "gate_up_proj": + return self.config.hidden_size, self.config.intermediate_size + elif module_name == "down_proj": + return self.config.intermediate_size, self.config.hidden_size + else: + raise NotImplementedError() + + def get_module_name(self, name): + params_mapping = { + "q_proj": "qkv_proj", + "k_proj": "qkv_proj", + "v_proj": "qkv_proj", + "gate_proj": "gate_up_proj", + "up_proj": "gate_up_proj", + } + return params_mapping.get(name, name) + + def get_module_name_from_weight_name(self, name): + for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping: + if weight_name in name: + return ( + name.replace(weight_name, param_name)[: -len(".weight")], + num_shard, + ) + return name[: -len(".weight")], 1 + + def get_num_params(self): + params_dict = dict(self.named_parameters()) + return len(params_dict) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name or "projector" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. 
+ continue + if name.startswith("model.vision_tower") and name not in params_dict: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip loading kv_scale from ckpts towards new design. + if name.endswith(".kv_scale") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + def get_weights_by_name( + self, name: str, truncate_size: int = 100, tp_size: int = 1 + ) -> Optional[torch.Tensor]: + """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face. + + Only used for unit test with an unoptimized performance. + For optimized performance, please use torch.save and torch.load. + """ + try: + if name == "lm_head.weight" and self.config.tie_word_embeddings: + logger.info( + "word embedding is tied for this model, return embed_tokens.weight as lm_head.weight." + ) + return ( + self.model.embed_tokens.weight.cpu() + .to(torch.float32) + .numpy() + .tolist()[:truncate_size] + ) + + mapped_name = name + mapped_shard_id = None + for param_name, weight_name, shard_id in self.stacked_params_mapping: + if weight_name in name: + mapped_name = name.replace(weight_name, param_name) + mapped_shard_id = shard_id + break + params_dict = dict(self.named_parameters()) + param = params_dict[mapped_name] + if mapped_shard_id is not None: + if mapped_shard_id in ["q", "k", "v"]: + num_heads = self.config.num_attention_heads // tp_size + num_kv_heads = self.config.num_key_value_heads // tp_size + head_dim = ( + self.config.hidden_size // self.config.num_attention_heads + ) + if mapped_shard_id == "q": + offset = 0 + size = num_heads * head_dim + elif mapped_shard_id == "k": + offset = num_heads * head_dim + size = num_kv_heads * head_dim + elif mapped_shard_id == "v": + offset = (num_heads + num_kv_heads) * head_dim + size = num_kv_heads * head_dim + weight = param.data.narrow(0, offset, size) + elif mapped_shard_id in [0, 1]: + intermediate_size = self.config.intermediate_size + slice_size = intermediate_size // tp_size + if mapped_shard_id == 0: # gate_proj + offset = 0 + size = slice_size + elif mapped_shard_id == 1: # up_proj + offset = slice_size + size = slice_size + + weight = param.data.narrow(0, offset, size) + else: + weight = param.data + else: + weight = param.data + if tp_size > 1 and ("o_proj" in name or "down_proj" in name): + gathered_weights = [torch.zeros_like(weight) for _ in range(tp_size)] + torch.distributed.all_gather(gathered_weights, weight) + weight = torch.cat(gathered_weights, dim=1) + return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size] + + except Exception: + logger.error( + f"Error getting weights by name {name} in GraniteForCausalLM: {get_exception_traceback()}" + ) + return None + +EntryClass = [GraniteForCausalLM] From dbd6194be60d4baf882e08526ccfdd6f621394c2 Mon Sep 17 00:00:00 2001 From: Fred Reiss Date: Tue, 10 Dec 2024 17:15:05 -0800 Subject: [PATCH 3/5] Additional updates for Granite compatibility --- 
python/sglang/lang/chat_template.py | 2 +- python/sglang/srt/layers/logits_processor.py | 11 +++- python/sglang/srt/models/granite.py | 60 +++++++++----------- test/srt/models/test_generation_models.py | 3 +- 4 files changed, 41 insertions(+), 35 deletions(-) diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index 5a0474fa894..41b8ec0d820 100644 --- a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -322,7 +322,7 @@ def get_chat_template_by_model_path(model_path): register_chat_template( ChatTemplate( - name="granite3-instruct", + name="granite-3-instruct", default_system_prompt=None, role_prefix_and_suffix={ "system": ( diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 274c4c311ec..c7121094efd 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -89,9 +89,11 @@ def from_forward_batch(cls, forward_batch: ForwardBatch): class LogitsProcessor(nn.Module): - def __init__(self, config, skip_all_gather: bool = False): + def __init__(self, config, skip_all_gather: bool = False, + logit_scale: Optional[float] = None): super().__init__() self.config = config + self.logit_scale = logit_scale self.do_tensor_parallel_all_gather = ( not skip_all_gather and get_tensor_model_parallel_world_size() > 1 ) @@ -233,6 +235,9 @@ def forward( all_logits = self._get_logits(states, lm_head) if self.do_tensor_parallel_all_gather: all_logits = tensor_model_parallel_all_gather(all_logits) + + # The LM head's weights may be zero-padded for parallelism. Remove any + # extra logits that this padding may have produced. all_logits = all_logits[:, : self.config.vocab_size].float() if hasattr(self.config, "final_logit_softcapping"): @@ -288,6 +293,10 @@ def _get_logits( else: # GGUF models logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias) + + # Optional scaling factor, backported from vLLM 0.4 + if self.logit_scale is not None: + logits.mul_(self.logit_scale) # In-place multiply return logits diff --git a/python/sglang/srt/models/granite.py b/python/sglang/srt/models/granite.py index f4f7b067290..d681fae97f6 100644 --- a/python/sglang/srt/models/granite.py +++ b/python/sglang/srt/models/granite.py @@ -227,34 +227,20 @@ def forward( residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention - # if residual is None: - # residual = hidden_states - # hidden_states = self.input_layernorm(hidden_states) - # else: - # hidden_states, residual = self.input_layernorm(hidden_states, residual) - # hidden_states = self.self_attn( - # positions=positions, - # hidden_states=hidden_states, - # forward_batch=forward_batch, - # ) - # Code from vllm head - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, forward_batch=forward_batch, - ) - hidden_states = residual + hidden_states * self.residual_multiplier + ) * self.residual_multiplier # multiplier for Maximal Update Parameterization # Fully Connected - # Old code - #hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) - # Code from vllm head - residual = hidden_states - hidden_states = 
self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states * self.residual_multiplier + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) * self.residual_multiplier return hidden_states, residual @@ -270,7 +256,7 @@ def __init__( self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( config.vocab_size, - config.hidden_size, + config.hidden_size ) self.layers = nn.ModuleList( [ @@ -321,13 +307,19 @@ def __init__( # the same tensor. Enforce during object creation so that weights will # load correctly even if the LM head weights don't have a separate entry # in the state dict. + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) if self.config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens + self.lm_head.tie_weights(self.model.embed_tokens) + + # Granite logit scaling factors are applied via division, but + # LogitsProcessor expects a multiplicative factor. + if hasattr(config, "logits_scaling"): + logit_scale = 1.0 / config.logits_scaling else: - self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, quant_config=quant_config - ) - self.logits_processor = LogitsProcessor(config) + logit_scale = None + self.logits_processor = LogitsProcessor(config, logit_scale=logit_scale) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) self.stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -352,10 +344,6 @@ def forward( logits_processor_output: LogitsProcessorOutput = self.logits_processor( input_ids, hidden_states, self.lm_head, forward_batch ) - #print(f"Logits processor output before:\n{logits_processor_output}") - logits_processor_output.next_token_logits /= self.config.logits_scaling - #print(f"Logits processor output after:\n{logits_processor_output}") - # TODO: Divide logits indside logits_processor_output by self.config.logits_scaling return logits_processor_output else: return self.pooler(hidden_states, forward_batch) @@ -419,6 +407,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue if name.startswith("model.vision_tower") and name not in params_dict: continue + if "lm_head.weight" in name and self.config.tie_word_embeddings: + # Input and output embeddings are tied, so the output embeddings + # may not be present in the checkpoint. We assume that the input + # embeddings are always present in the checkpoint. + continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: @@ -432,6 +425,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, shard_id) break else: + # This block only runs if the preceding for loop doesn't find + # a match for `name` in `stacked_params_mapping`. + # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index d9f1795341c..b7b847c9447 100644 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -57,6 +57,7 @@ class ModelCase: ModelCase("openai-community/gpt2"), ModelCase("microsoft/Phi-3-small-8k-instruct"), ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True), + ModelCase("ibm-granite/granite-3.0-2b-instruct", skip_long_prompt=True) ] TORCH_DTYPES = [torch.float16] @@ -67,7 +68,7 @@ class TestGenerationModels(unittest.TestCase): @classmethod def setUpClass(cls): mp.set_start_method("spawn", force=True) - + def assert_close_logits_and_output_strs( self, prompts: List[str], From 2ce2719da1bbc2ddba08a4f6bfb3bf5696d3c828 Mon Sep 17 00:00:00 2001 From: Fred Reiss Date: Tue, 10 Dec 2024 17:16:37 -0800 Subject: [PATCH 4/5] Reformat code files --- python/sglang/lang/chat_template.py | 39 ++++++++++---------- python/sglang/srt/layers/logits_processor.py | 5 ++- python/sglang/srt/models/granite.py | 21 ++++++----- test/srt/models/test_generation_models.py | 4 +- 4 files changed, 37 insertions(+), 32 deletions(-) diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index 41b8ec0d820..4a774c4fb6b 100644 --- a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -321,25 +321,25 @@ def get_chat_template_by_model_path(model_path): ) register_chat_template( - ChatTemplate( - name="granite-3-instruct", - default_system_prompt=None, - role_prefix_and_suffix={ - "system": ( - "<|start_of_role|>system<|end_of_role|>", - "<|end_of_text|>", - ), - "user": ( - "<|start_of_role|>user<|end_of_role|>", - "<|end_of_text|>", - ), - "assistant": ( - "<|start_of_role|>assistant<|end_of_role|>", - "<|end_of_text|>", - ), - }, - stop_str=("<|end_of_text|>",), - ) + ChatTemplate( + name="granite-3-instruct", + default_system_prompt=None, + role_prefix_and_suffix={ + "system": ( + "<|start_of_role|>system<|end_of_role|>", + "<|end_of_text|>", + ), + "user": ( + "<|start_of_role|>user<|end_of_role|>", + "<|end_of_text|>", + ), + "assistant": ( + "<|start_of_role|>assistant<|end_of_role|>", + "<|end_of_text|>", + ), + }, + stop_str=("<|end_of_text|>",), + ) ) @@ -433,6 +433,7 @@ def match_granite_instruct(model_path: str): if "granite" in model_path and "instruct" in model_path: return get_chat_template("granite-3-instruct") + if __name__ == "__main__": messages = [ {"role": "system", "content": None}, # None means default diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index c7121094efd..6dba86f96bc 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -89,8 +89,9 @@ def from_forward_batch(cls, forward_batch: ForwardBatch): class LogitsProcessor(nn.Module): - def __init__(self, config, skip_all_gather: bool = False, - logit_scale: Optional[float] = None): + def __init__( + self, config, skip_all_gather: bool = False, logit_scale: Optional[float] = None + ): super().__init__() self.config = config self.logit_scale = logit_scale diff --git a/python/sglang/srt/models/granite.py b/python/sglang/srt/models/granite.py index d681fae97f6..d207ff61b26 100644 --- a/python/sglang/srt/models/granite.py +++ b/python/sglang/srt/models/granite.py @@ -122,7 +122,7 @@ def __init__( ) self.q_size = self.num_heads * self.head_dim self.kv_size = 
self.num_kv_heads * self.head_dim - self.scaling = config.attention_multiplier + self.scaling = config.attention_multiplier self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings @@ -232,11 +232,14 @@ def forward( hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm(hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - forward_batch=forward_batch, - ) * self.residual_multiplier # multiplier for Maximal Update Parameterization + hidden_states = ( + self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + * self.residual_multiplier + ) # multiplier for Maximal Update Parameterization # Fully Connected hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) @@ -255,8 +258,7 @@ def __init__( self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size + config.vocab_size, config.hidden_size ) self.layers = nn.ModuleList( [ @@ -304,7 +306,7 @@ def __init__( self.quant_config = quant_config self.model = GraniteModel(config, quant_config=quant_config) # If tie_word_embeddings == True, then input and output embeddings are - # the same tensor. Enforce during object creation so that weights will + # the same tensor. Enforce during object creation so that weights will # load correctly even if the LM head weights don't have a separate entry # in the state dict. self.lm_head = ParallelLMHead( @@ -511,4 +513,5 @@ def get_weights_by_name( ) return None + EntryClass = [GraniteForCausalLM] diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index b7b847c9447..fd27d5c07b6 100644 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -57,7 +57,7 @@ class ModelCase: ModelCase("openai-community/gpt2"), ModelCase("microsoft/Phi-3-small-8k-instruct"), ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True), - ModelCase("ibm-granite/granite-3.0-2b-instruct", skip_long_prompt=True) + ModelCase("ibm-granite/granite-3.0-2b-instruct", skip_long_prompt=True), ] TORCH_DTYPES = [torch.float16] @@ -68,7 +68,7 @@ class TestGenerationModels(unittest.TestCase): @classmethod def setUpClass(cls): mp.set_start_method("spawn", force=True) - + def assert_close_logits_and_output_strs( self, prompts: List[str], From 268f6a6f6f06375a22f7f2906f6c077a31b136f3 Mon Sep 17 00:00:00 2001 From: Fred Reiss Date: Tue, 10 Dec 2024 17:27:41 -0800 Subject: [PATCH 5/5] Add Granite to list of supported models --- docs/references/supported_models.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md index bf1044f8498..9dafc3d2a3d 100644 --- a/docs/references/supported_models.md +++ b/docs/references/supported_models.md @@ -29,6 +29,7 @@ - SmolLM - GLM-4 - Phi-3-Small +- IBM Granite 3 ## Embedding Models
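
For reference, the decoder-layer rewrite in patches 3 and 4 applies Granite's `residual_multiplier` (the "Maximal Update Parameterization" scaling noted in the diff) to both the attention and MLP outputs before they are added back to the residual stream. The following is a minimal standalone sketch of that dataflow only, using placeholder linear sublayers and an explicit residual add in place of SGLang's fused add-RMSNorm and attention/MLP modules:

import torch
import torch.nn as nn

class TinyGraniteBlock(nn.Module):
    """Stand-in for GraniteDecoderLayer: only the residual/scaling pattern is real."""

    def __init__(self, hidden_size: int, residual_multiplier: float):
        super().__init__()
        self.residual_multiplier = residual_multiplier
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)
        self.self_attn = nn.Linear(hidden_size, hidden_size)  # placeholder for attention
        self.mlp = nn.Linear(hidden_size, hidden_size)        # placeholder for the MLP

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Attention sublayer: normalize, run attention, scale, add back to the residual.
        residual = hidden_states
        hidden_states = self.self_attn(self.input_layernorm(hidden_states))
        hidden_states = residual + hidden_states * self.residual_multiplier

        # MLP sublayer: same pattern with the post-attention norm.
        residual = hidden_states
        hidden_states = self.mlp(self.post_attention_layernorm(hidden_states))
        return residual + hidden_states * self.residual_multiplier

x = torch.randn(1, 3, 8)
block = TinyGraniteBlock(hidden_size=8, residual_multiplier=0.22)  # illustrative values
print(block(x).shape)

With residual_multiplier = 1.0 this reduces to a standard pre-norm transformer block, which is why the patch only rescales the sublayer outputs rather than changing the residual topology.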