Skip to content

Commit

Permalink
[platform] support pytorch custom op pluggable (vllm-project#11328)
Browse files Browse the repository at this point in the history
Signed-off-by: wangxiyuan <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
  • Loading branch information
wangxiyuan authored and hmellor committed Jan 12, 2025
1 parent 48b7b71 commit d2cc908
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
7 changes: 7 additions & 0 deletions vllm/model_executor/custom_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ def forward_hpu(self, *args, **kwargs):
# PyTorch-native implementation.
return self.forward_native(*args, **kwargs)

def forward_oot(self, *args, **kwargs):
# By default, we assume that OOT ops are compatible with the
# PyTorch-native implementation.
return self.forward_native(*args, **kwargs)

def dispatch_forward(self):
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.
Expand All @@ -81,6 +86,8 @@ def dispatch_forward(self):
return self.forward_tpu
elif current_platform.is_xpu():
return self.forward_xpu
elif current_platform.is_out_of_tree():
return self.forward_oot
else:
return self.forward_cuda

Expand Down
4 changes: 4 additions & 0 deletions vllm/platforms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class PlatformEnum(enum.Enum):
CPU = enum.auto()
NEURON = enum.auto()
OPENVINO = enum.auto()
OOT = enum.auto()
UNSPECIFIED = enum.auto()


Expand Down Expand Up @@ -107,6 +108,9 @@ def is_neuron(self) -> bool:
def is_openvino(self) -> bool:
return self._enum == PlatformEnum.OPENVINO

def is_out_of_tree(self) -> bool:
return self._enum == PlatformEnum.OOT

def is_cuda_alike(self) -> bool:
"""Stateless version of :func:`torch.cuda.is_available`."""
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
Expand Down

0 comments on commit d2cc908

Please sign in to comment.