Skip to content

Commit

Permalink
[Bugfix] Override dunder methods of placeholder modules (#11882)
Browse files Browse the repository at this point in the history
Signed-off-by: DarkLight1337 <[email protected]>
  • Loading branch information
DarkLight1337 authored Jan 9, 2025
1 parent 310aca8 commit 0bd1ff4
Show file tree
Hide file tree
Showing 2 changed files with 220 additions and 16 deletions.
47 changes: 44 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import torch
from vllm_test_utils import monitor

from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs,
get_open_port, memory_profiling, merge_async_iterators,
supports_kw)
from vllm.utils import (FlexibleArgumentParser, PlaceholderModule,
StoreBoolean, deprecate_kwargs, get_open_port,
memory_profiling, merge_async_iterators, supports_kw)

from .utils import error_on_warning, fork_new_process_for_each_test

Expand Down Expand Up @@ -323,3 +323,44 @@ def measure_current_non_torch():
del weights
lib.cudaFree(handle1)
lib.cudaFree(handle2)


def test_placeholder_module_error_handling():
placeholder = PlaceholderModule("placeholder_1234")

def build_ctx():
return pytest.raises(ModuleNotFoundError,
match="No module named")

with build_ctx():
int(placeholder)

with build_ctx():
placeholder()

with build_ctx():
_ = placeholder.some_attr

with build_ctx():
# Test conflict with internal __name attribute
_ = placeholder.name

# OK to print the placeholder or use it in a f-string
_ = repr(placeholder)
_ = str(placeholder)

# No error yet; only error when it is used downstream
placeholder_attr = placeholder.placeholder_attr("attr")

with build_ctx():
int(placeholder_attr)

with build_ctx():
placeholder_attr()

with build_ctx():
_ = placeholder_attr.some_attr

with build_ctx():
# Test conflict with internal __module attribute
_ = placeholder_attr.module
189 changes: 176 additions & 13 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
import zmq.asyncio
from packaging.version import Version
from torch.library import Library
from typing_extensions import ParamSpec, TypeIs, assert_never
from typing_extensions import Never, ParamSpec, TypeIs, assert_never

import vllm.envs as envs
from vllm.logger import enable_trace_function_call, init_logger
Expand Down Expand Up @@ -1627,24 +1627,183 @@ def get_vllm_optional_dependencies():
}


@dataclass(frozen=True)
class PlaceholderModule:
class _PlaceholderBase:
"""
Disallows downstream usage of placeholder modules.
We need to explicitly override each dunder method because
:meth:`__getattr__` is not called when they are accessed.
See also:
[Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-lookup)
"""

def __getattr__(self, key: str) -> Never:
"""
The main class should implement this to throw an error
for attribute accesses representing downstream usage.
"""
raise NotImplementedError

# [Basic customization]

def __lt__(self, other: object):
return self.__getattr__("__lt__")

def __le__(self, other: object):
return self.__getattr__("__le__")

def __eq__(self, other: object):
return self.__getattr__("__eq__")

def __ne__(self, other: object):
return self.__getattr__("__ne__")

def __gt__(self, other: object):
return self.__getattr__("__gt__")

def __ge__(self, other: object):
return self.__getattr__("__ge__")

def __hash__(self):
return self.__getattr__("__hash__")

def __bool__(self):
return self.__getattr__("__bool__")

# [Callable objects]

def __call__(self, *args: object, **kwargs: object):
return self.__getattr__("__call__")

# [Container types]

def __len__(self):
return self.__getattr__("__len__")

def __getitem__(self, key: object):
return self.__getattr__("__getitem__")

def __setitem__(self, key: object, value: object):
return self.__getattr__("__setitem__")

def __delitem__(self, key: object):
return self.__getattr__("__delitem__")

# __missing__ is optional according to __getitem__ specification,
# so it is skipped

# __iter__ and __reversed__ have a default implementation
# based on __len__ and __getitem__, so they are skipped.

# [Numeric Types]

def __add__(self, other: object):
return self.__getattr__("__add__")

def __sub__(self, other: object):
return self.__getattr__("__sub__")

def __mul__(self, other: object):
return self.__getattr__("__mul__")

def __matmul__(self, other: object):
return self.__getattr__("__matmul__")

def __truediv__(self, other: object):
return self.__getattr__("__truediv__")

def __floordiv__(self, other: object):
return self.__getattr__("__floordiv__")

def __mod__(self, other: object):
return self.__getattr__("__mod__")

def __divmod__(self, other: object):
return self.__getattr__("__divmod__")

def __pow__(self, other: object, modulo: object = ...):
return self.__getattr__("__pow__")

def __lshift__(self, other: object):
return self.__getattr__("__lshift__")

def __rshift__(self, other: object):
return self.__getattr__("__rshift__")

def __and__(self, other: object):
return self.__getattr__("__and__")

def __xor__(self, other: object):
return self.__getattr__("__xor__")

def __or__(self, other: object):
return self.__getattr__("__or__")

# r* and i* methods have lower priority than
# the methods for left operand so they are skipped

def __neg__(self):
return self.__getattr__("__neg__")

def __pos__(self):
return self.__getattr__("__pos__")

def __abs__(self):
return self.__getattr__("__abs__")

def __invert__(self):
return self.__getattr__("__invert__")

# __complex__, __int__ and __float__ have a default implementation
# based on __index__, so they are skipped.

def __index__(self):
return self.__getattr__("__index__")

def __round__(self, ndigits: object = ...):
return self.__getattr__("__round__")

def __trunc__(self):
return self.__getattr__("__trunc__")

def __floor__(self):
return self.__getattr__("__floor__")

def __ceil__(self):
return self.__getattr__("__ceil__")

# [Context managers]

def __enter__(self):
return self.__getattr__("__enter__")

def __exit__(self, *args: object, **kwargs: object):
return self.__getattr__("__exit__")


class PlaceholderModule(_PlaceholderBase):
"""
A placeholder object to use when a module does not exist.
This enables more informative errors when trying to access attributes
of a module that does not exists.
"""
name: str

def __init__(self, name: str) -> None:
super().__init__()

# Apply name mangling to avoid conflicting with module attributes
self.__name = name

def placeholder_attr(self, attr_path: str):
return _PlaceholderModuleAttr(self, attr_path)

def __getattr__(self, key: str):
name = self.name
name = self.__name

try:
importlib.import_module(self.name)
importlib.import_module(name)
except ImportError as exc:
for extra, names in get_vllm_optional_dependencies().items():
if name in names:
Expand All @@ -1657,17 +1816,21 @@ def __getattr__(self, key: str):
"when the original module can be imported")


@dataclass(frozen=True)
class _PlaceholderModuleAttr:
module: PlaceholderModule
attr_path: str
class _PlaceholderModuleAttr(_PlaceholderBase):

def __init__(self, module: PlaceholderModule, attr_path: str) -> None:
super().__init__()

# Apply name mangling to avoid conflicting with module attributes
self.__module = module
self.__attr_path = attr_path

def placeholder_attr(self, attr_path: str):
return _PlaceholderModuleAttr(self.module,
f"{self.attr_path}.{attr_path}")
return _PlaceholderModuleAttr(self.__module,
f"{self.__attr_path}.{attr_path}")

def __getattr__(self, key: str):
getattr(self.module, f"{self.attr_path}.{key}")
getattr(self.__module, f"{self.__attr_path}.{key}")

raise AssertionError("PlaceholderModule should not be used "
"when the original module can be imported")
Expand Down

0 comments on commit 0bd1ff4

Please sign in to comment.