Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[optimization] remove python function call for custom activation op #11885

Merged
merged 4 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 0 additions & 27 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,33 +34,6 @@ def register_fake(fn):
from torch.library import impl_abstract as register_fake


# activation ops
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_and_mul(out, x)


def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_tanh_and_mul(out, x)


def fatrelu_and_mul(out: torch.Tensor,
x: torch.Tensor,
threshold: float = 0.0) -> None:
torch.ops._C.fatrelu_and_mul(out, x, threshold)


def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_fast(out, x)


def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_new(out, x)


def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_quick(out, x)


# page attention ops
def paged_attention_v1(
out: torch.Tensor,
Expand Down
74 changes: 45 additions & 29 deletions vllm/model_executor/layers/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class FatreluAndMul(CustomOp):
def __init__(self, threshold: float = 0.):
super().__init__()
self.threshold = threshold
if current_platform.is_cuda_alike():
self.op = torch.ops._C.fatrelu_and_mul

def forward_native(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
Expand All @@ -39,12 +41,10 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
return x1 * x2

def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
from vllm import _custom_ops as ops

d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
ops.fatrelu_and_mul(out, x, self.threshold)
self.op(out, x, self.threshold)
return out


Expand Down Expand Up @@ -103,34 +103,37 @@ def __init__(self, approximate: str = "none"):
self.approximate = approximate
if approximate not in ("none", "tanh"):
raise ValueError(f"Unknown approximate mode: {approximate}")
if current_platform.is_cuda_alike():
self.gelu_and_mul = torch.ops._C.gelu_and_mul
self.gelu_tanh_and_mul = torch.ops._C.gelu_tanh_and_mul
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops
self.gelu_and_mul = ipex_ops.gelu_and_mul
self.gelu_tanh_and_mul = ipex_ops.gelu_tanh_and_mul

def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
from vllm import _custom_ops as ops

d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
if self.approximate == "none":
ops.gelu_and_mul(out, x)
self.gelu_and_mul(out, x)
elif self.approximate == "tanh":
ops.gelu_tanh_and_mul(out, x)
self.gelu_tanh_and_mul(out, x)
cennn marked this conversation as resolved.
Show resolved Hide resolved
return out

def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
from vllm._ipex_ops import ipex_ops as ops

d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
if self.approximate == "none":
ops.gelu_and_mul(out, x)
self.gelu_and_mul(out, x)
elif self.approximate == "tanh":
ops.gelu_tanh_and_mul(out, x)
self.gelu_tanh_and_mul(out, x)
return out

def extra_repr(self) -> str:
Expand All @@ -140,65 +143,78 @@ def extra_repr(self) -> str:
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):

def __init__(self):
super().__init__()
if current_platform.is_cuda_alike():
self.op = torch.ops._C.gelu_new
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops
self.op = ipex_ops.gelu_new

def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
c = math.sqrt(2.0 / math.pi)
return 0.5 * x * (1.0 + torch.tanh(c *
(x + 0.044715 * torch.pow(x, 3.0))))

def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
from vllm import _custom_ops as ops

out = torch.empty_like(x)
ops.gelu_new(out, x)
self.op(out, x)
return out

def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
from vllm._ipex_ops import ipex_ops as ops

return ops.gelu_new(x)
return self.op(x)


@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):

def __init__(self):
super().__init__()
if current_platform.is_cuda_alike():
cennn marked this conversation as resolved.
Show resolved Hide resolved
self.op = torch.ops._C.gelu_fast
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops
self.op = ipex_ops.gelu_fast

def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
(1.0 + 0.044715 * x * x)))

def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
from vllm import _custom_ops as ops

out = torch.empty_like(x)
ops.gelu_fast(out, x)
self.op(out, x)
return out

def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
from vllm._ipex_ops import ipex_ops as ops

return ops.gelu_fast(x)
return self.op(x)


@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
# https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
def __init__(self):
super().__init__()
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
self.op = ops.gelu_quick
cennn marked this conversation as resolved.
Show resolved Hide resolved
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops
self.op = ipex_ops.gelu_quick

def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
return x * torch.sigmoid(1.702 * x)

def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
from vllm import _custom_ops as ops

out = torch.empty_like(x)
ops.gelu_quick(out, x)
self.op(out, x)
return out

def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
from vllm._ipex_ops import ipex_ops as ops

out = torch.empty_like(x)
ops.gelu_quick(out, x)
self.op(out, x)
return out

# TODO implement forward_xpu for QuickGELU
Expand Down
Loading