[optimization] remove python function call for custom op (#11750)
Signed-off-by: youkaichao <[email protected]>
youkaichao authored Jan 7, 2025
1 parent c0efe92 commit 869579a
Showing 4 changed files with 15 additions and 13 deletions.
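
The change is mechanical, but the motivation is worth spelling out: each call to SiluAndMul (and each fused-MoE call site) previously went through a thin Python wrapper in vllm._custom_ops before reaching the registered torch.ops._C.silu_and_mul kernel. This commit resolves the kernel handle once, in __init__ or directly at the call site, so the hot path saves one Python function call per invocation. Below is a minimal sketch of the pattern; the class name and the fallback branch are illustrative additions, not code from this commit.

import torch
import torch.nn.functional as F


class FusedSiluMulSketch(torch.nn.Module):
    """Illustrative stand-in for a CustomOp-style layer (hypothetical name)."""

    def __init__(self):
        super().__init__()
        # Resolve the kernel handle once at construction time; forward() then
        # calls it directly instead of going through an extra Python wrapper.
        try:
            self.op = torch.ops._C.silu_and_mul  # vLLM's compiled kernel, if registered
        except (AttributeError, RuntimeError):
            self.op = None  # no compiled kernel loaded; use the native path below

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        if self.op is not None:
            out = torch.empty(x.shape[:-1] + (d, ), dtype=x.dtype, device=x.device)
            self.op(out, x)  # out-parameter convention, matching the diffs below
            return out
        # Pure-PyTorch equivalent, mirroring SiluAndMul.forward_native
        return F.silu(x[..., :d]) * x[..., d:]
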
vllm/_custom_ops.py (4 changes: 0 additions & 4 deletions)
@@ -35,10 +35,6 @@ def register_fake(fn):
 
 
 # activation ops
-def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-    torch.ops._C.silu_and_mul(out, x)
-
-
 def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
     torch.ops._C.gelu_and_mul(out, x)
 
vllm/model_executor/layers/activation.py (17 changes: 11 additions & 6 deletions)
@@ -10,6 +10,7 @@
     get_tensor_model_parallel_world_size)
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 from vllm.utils import LazyDict
 
 
@@ -58,27 +59,31 @@ class SiluAndMul(CustomOp):
     return: (num_tokens, d) or (batch_size, seq_len, d)
     """
 
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda_alike():
+            self.op = torch.ops._C.silu_and_mul
+        elif current_platform.is_xpu():
+            import intel_extension_for_pytorch as ipex
+            self.op = ipex.llm.functional.silu_and_mul
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        ops.silu_and_mul(out, x)
+        self.op(out, x)
         return out
 
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        ops.silu_and_mul(out, x)
+        self.op(out, x)
         return out
 
 
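
The activation math is unchanged by this commit; only the dispatch path is shorter. As a reference point, the computation that forward_native implements can be checked without any vLLM kernels (shapes follow the docstring above):

import torch
import torch.nn.functional as F

x = torch.randn(4, 2 * 128)            # (num_tokens, 2 * d)
d = x.shape[-1] // 2
ref = F.silu(x[..., :d]) * x[..., d:]  # same computation as SiluAndMul.forward_native
print(ref.shape)                       # torch.Size([4, 128])
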
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py (4 changes: 2 additions & 2 deletions)
@@ -4,7 +4,6 @@
 
 import torch
 
-from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, moe_align_block_size, try_get_optimal_moe_config)
 from vllm.scalar_type import scalar_types
@@ -301,7 +300,8 @@ def fused_marlin_moe(
         False,
     )
 
-    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
+    torch.ops._C.silu_and_mul(intermediate_cache2,
+                              intermediate_cache1.view(-1, 2 * N))
 
     intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
         intermediate_cache2,
vllm/model_executor/layers/fused_moe/fused_moe.py (3 changes: 2 additions & 1 deletion)
@@ -753,7 +753,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                                 use_int8_w8a16=use_int8_w8a16,
                                 block_shape=block_shape)
 
-        ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        torch.ops._C.silu_and_mul(intermediate_cache2,
+                                  intermediate_cache1.view(-1, N))
 
         invoke_fused_moe_kernel(intermediate_cache2,
                                 w2,
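
Both fused-MoE files apply the same idea at the call site: the vllm._custom_ops wrapper is dropped and torch.ops._C.silu_and_mul is invoked directly, writing into a preallocated buffer. A sketch of that calling convention, under the assumption that a GPU build of vLLM is installed and its compiled _C ops are registered (importing vllm._custom_ops is one way this typically happens):

import torch

from vllm import _custom_ops  # noqa: F401  (assumed to load and register torch.ops._C)

N, num_tokens = 256, 8
cache1 = torch.randn(num_tokens, 2 * N, device="cuda", dtype=torch.float16)
cache2 = torch.empty(num_tokens, N, device="cuda", dtype=torch.float16)

# Direct call into the registered op: writes silu(cache1[..., :N]) * cache1[..., N:] into cache2.
torch.ops._C.silu_and_mul(cache2, cache1.view(-1, 2 * N))
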
