Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[optimization] remove python function call for custom op #11750

Merged
merged 2 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@ def register_fake(fn):


# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.silu_and_mul(out, x)


def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_and_mul(out, x)

Expand Down
17 changes: 11 additions & 6 deletions vllm/model_executor/layers/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
get_tensor_model_parallel_world_size)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import LazyDict


Expand Down Expand Up @@ -58,27 +59,31 @@ class SiluAndMul(CustomOp):
return: (num_tokens, d) or (batch_size, seq_len, d)
"""

def __init__(self):
super().__init__()
if current_platform.is_cuda_alike():
self.op = torch.ops._C.silu_and_mul
elif current_platform.is_xpu():
import intel_extension_for_pytorch as ipex
self.op = ipex.llm.functional.silu_and_mul

def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]

def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
from vllm import _custom_ops as ops

d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
ops.silu_and_mul(out, x)
self.op(out, x)
return out

def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
from vllm._ipex_ops import ipex_ops as ops

d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
ops.silu_and_mul(out, x)
self.op(out, x)
return out


Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size, try_get_optimal_moe_config)
from vllm.scalar_type import scalar_types
Expand Down Expand Up @@ -301,7 +300,8 @@ def fused_marlin_moe(
False,
)

ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, 2 * N))

intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
intermediate_cache2,
Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int8_w8a16=use_int8_w8a16,
block_shape=block_shape)

ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N))

invoke_fused_moe_kernel(intermediate_cache2,
w2,
Expand Down
Loading