Skip to content

Commit

Permalink
Merge branch 'patch_2' into apply_plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
wangxiyuan committed Jan 2, 2025
2 parents 3ab6610 + 0f3e0eb commit bdc342f
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions vllm/model_executor/custom_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@ def __init__(self):
super().__init__()
self._forward_method = self.dispatch_forward()

@classmethod
def set_foward_method(cls, method):
"""Provide a way to register a custom forward method for a specific
backend."""
if getattr(cls, f"forward_{current_platform.device_name}", None):
raise ValueError(
f"Custom op {cls.__class__.__name__} already has a "
f"forward_{current_platform.device_name} method")
setattr(cls, f"forward_{current_platform.device_name}", method)

def forward(self, *args, **kwargs):
return self._forward_method(*args, **kwargs)

Expand Down Expand Up @@ -72,18 +82,15 @@ def dispatch_forward(self):
if not enabled:
return self.forward_native

if current_platform.is_rocm():
return self.forward_hip
elif current_platform.is_cpu():
return self.forward_cpu
elif current_platform.is_hpu():
return self.forward_hpu
elif current_platform.is_tpu():
return self.forward_tpu
elif current_platform.is_xpu():
return self.forward_xpu
else:
return self.forward_cuda
custom_forward_func = \
getattr(self, f"forward_{current_platform.device_name}", None)
if not custom_forward_func:
logger.warning(
"Custom op %s is not supported on %s, falling back "
"to native.", self.__class__.__name__,
current_platform.device_name)
return self.forward_native
return custom_forward_func

@classmethod
def enabled(cls) -> bool:
Expand Down

0 comments on commit bdc342f

Please sign in to comment.