Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Override dunder methods of placeholder modules #11882

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import torch
from vllm_test_utils import monitor

from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs,
get_open_port, memory_profiling, merge_async_iterators,
supports_kw)
from vllm.utils import (FlexibleArgumentParser, PlaceholderModule,
StoreBoolean, deprecate_kwargs, get_open_port,
memory_profiling, merge_async_iterators, supports_kw)

from .utils import error_on_warning, fork_new_process_for_each_test

Expand Down Expand Up @@ -323,3 +323,44 @@ def measure_current_non_torch():
del weights
lib.cudaFree(handle1)
lib.cudaFree(handle2)


def test_placeholder_module_error_handling():
placeholder = PlaceholderModule("placeholder_1234")

def build_ctx():
return pytest.raises(ModuleNotFoundError,
match="No module named")

with build_ctx():
int(placeholder)

with build_ctx():
placeholder()

with build_ctx():
_ = placeholder.some_attr

with build_ctx():
# Test conflict with internal __name attribute
_ = placeholder.name

# OK to print the placeholder or use it in a f-string
_ = repr(placeholder)
_ = str(placeholder)

# No error yet; only error when it is used downstream
placeholder_attr = placeholder.placeholder_attr("attr")

with build_ctx():
int(placeholder_attr)

with build_ctx():
placeholder_attr()

with build_ctx():
_ = placeholder_attr.some_attr

with build_ctx():
# Test conflict with internal __module attribute
_ = placeholder_attr.module
189 changes: 176 additions & 13 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
import zmq.asyncio
from packaging.version import Version
from torch.library import Library
from typing_extensions import ParamSpec, TypeIs, assert_never
from typing_extensions import Never, ParamSpec, TypeIs, assert_never

import vllm.envs as envs
from vllm.logger import enable_trace_function_call, init_logger
Expand Down Expand Up @@ -1594,24 +1594,183 @@ def get_vllm_optional_dependencies():
}


@dataclass(frozen=True)
class PlaceholderModule:
class _PlaceholderBase:
"""
Disallows downstream usage of placeholder modules.

We need to explicitly override each dunder method because
:meth:`__getattr__` is not called when they are accessed.

See also:
[Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-lookup)
"""

def __getattr__(self, key: str) -> Never:
"""
The main class should implement this to throw an error
for attribute accesses representing downstream usage.
"""
raise NotImplementedError

# [Basic customization]

def __lt__(self, other: object):
return self.__getattr__("__lt__")

def __le__(self, other: object):
return self.__getattr__("__le__")

def __eq__(self, other: object):
return self.__getattr__("__eq__")

def __ne__(self, other: object):
return self.__getattr__("__ne__")

def __gt__(self, other: object):
return self.__getattr__("__gt__")

def __ge__(self, other: object):
return self.__getattr__("__ge__")

def __hash__(self):
return self.__getattr__("__hash__")

def __bool__(self):
return self.__getattr__("__bool__")

# [Callable objects]

def __call__(self, *args: object, **kwargs: object):
return self.__getattr__("__call__")

# [Container types]

def __len__(self):
return self.__getattr__("__len__")

def __getitem__(self, key: object):
return self.__getattr__("__getitem__")

def __setitem__(self, key: object, value: object):
return self.__getattr__("__setitem__")

def __delitem__(self, key: object):
return self.__getattr__("__delitem__")

# __missing__ is optional according to __getitem__ specification,
# so it is skipped

# __iter__ and __reversed__ have a default implementation
# based on __len__ and __getitem__, so they are skipped.

# [Numeric Types]

def __add__(self, other: object):
return self.__getattr__("__add__")

def __sub__(self, other: object):
return self.__getattr__("__sub__")

def __mul__(self, other: object):
return self.__getattr__("__mul__")

def __matmul__(self, other: object):
return self.__getattr__("__matmul__")

def __truediv__(self, other: object):
return self.__getattr__("__truediv__")

def __floordiv__(self, other: object):
return self.__getattr__("__floordiv__")

def __mod__(self, other: object):
return self.__getattr__("__mod__")

def __divmod__(self, other: object):
return self.__getattr__("__divmod__")

def __pow__(self, other: object, modulo: object = ...):
return self.__getattr__("__pow__")

def __lshift__(self, other: object):
return self.__getattr__("__lshift__")

def __rshift__(self, other: object):
return self.__getattr__("__rshift__")

def __and__(self, other: object):
return self.__getattr__("__and__")

def __xor__(self, other: object):
return self.__getattr__("__xor__")

def __or__(self, other: object):
return self.__getattr__("__or__")

# r* and i* methods have lower priority than
# the methods for left operand so they are skipped

def __neg__(self):
return self.__getattr__("__neg__")

def __pos__(self):
return self.__getattr__("__pos__")

def __abs__(self):
return self.__getattr__("__abs__")

def __invert__(self):
return self.__getattr__("__invert__")

# __complex__, __int__ and __float__ have a default implementation
# based on __index__, so they are skipped.

def __index__(self):
return self.__getattr__("__index__")

def __round__(self, ndigits: object = ...):
return self.__getattr__("__round__")

def __trunc__(self):
return self.__getattr__("__trunc__")

def __floor__(self):
return self.__getattr__("__floor__")

def __ceil__(self):
return self.__getattr__("__ceil__")

# [Context managers]

def __enter__(self):
return self.__getattr__("__enter__")

def __exit__(self, *args: object, **kwargs: object):
return self.__getattr__("__exit__")


class PlaceholderModule(_PlaceholderBase):
"""
A placeholder object to use when a module does not exist.

This enables more informative errors when trying to access attributes
of a module that does not exists.
"""
name: str

def __init__(self, name: str) -> None:
super().__init__()

# Apply name mangling to avoid conflicting with module attributes
self.__name = name

def placeholder_attr(self, attr_path: str):
return _PlaceholderModuleAttr(self, attr_path)

def __getattr__(self, key: str):
name = self.name
name = self.__name

try:
importlib.import_module(self.name)
importlib.import_module(name)
except ImportError as exc:
for extra, names in get_vllm_optional_dependencies().items():
if name in names:
Expand All @@ -1624,17 +1783,21 @@ def __getattr__(self, key: str):
"when the original module can be imported")


@dataclass(frozen=True)
class _PlaceholderModuleAttr:
module: PlaceholderModule
attr_path: str
class _PlaceholderModuleAttr(_PlaceholderBase):

def __init__(self, module: PlaceholderModule, attr_path: str) -> None:
super().__init__()

# Apply name mangling to avoid conflicting with module attributes
self.__module = module
self.__attr_path = attr_path

def placeholder_attr(self, attr_path: str):
return _PlaceholderModuleAttr(self.module,
f"{self.attr_path}.{attr_path}")
return _PlaceholderModuleAttr(self.__module,
f"{self.__attr_path}.{attr_path}")

def __getattr__(self, key: str):
getattr(self.module, f"{self.attr_path}.{key}")
getattr(self.__module, f"{self.__attr_path}.{key}")

raise AssertionError("PlaceholderModule should not be used "
"when the original module can be imported")
Expand Down
Loading