Mypy checking for vllm/compilation #11496

Merged · 3 commits · Dec 26, 2024
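For context on the diffs below: the point of this PR is to make the vllm/compilation package pass mypy. A minimal sketch of checking it programmatically; the flags here are illustrative assumptions, not necessarily what vLLM's CI tooling uses:

    # Sketch: check the package this PR annotates via mypy's public API.
    # The flags are illustrative, not necessarily vLLM's CI configuration.
    from mypy import api

    stdout, stderr, exit_status = api.run(
        ["--follow-imports", "skip", "vllm/compilation"])
    print(stdout, end="")
    if exit_status != 0:
        raise SystemExit("mypy found type errors in vllm/compilation")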
10 changes: 5 additions & 5 deletions vllm/compilation/backends.py
@@ -141,14 +141,14 @@ def produce_guards_expression(self, *args, **kwargs):
         return ""
 
 
-def wrap_inductor(graph,
+def wrap_inductor(graph: fx.GraphModule,
                   example_inputs,
                   additional_inductor_config,
                   compilation_config: CompilationConfig,
                   graph_index: int = 0,
                   num_graphs: int = 1,
                   runtime_shape: Optional[int] = None,
-                  use_inductor: bool = True):
+                  use_inductor: bool = True) -> Any:
     if graph_index == 0:
         # before compiling the first graph, record the start time
         global compilation_start_time
@@ -209,7 +209,7 @@ def wrap_inductor(graph,
     returns_tuple = graph_returns_tuple(graph)
 
     # this is the graph we return to Dynamo to run
-    def compiled_graph(*args):
Review comment (Member): This type is not available in PyTorch 2.5 yet.

+    def compiled_graph(*args) -> Optional[fx.CompiledFxGraph]:
Review comment (Member): @lucas-tucker this is actually incorrect. The function is named compiled_graph, and the function itself functions as a graph; it does not return a graph. The same applies to copy_and_call below.

Reply (Contributor, author): My bad. I was basing this off of the result of the _lookup_graph call from inductor_compiled_graph. I suppose the returns_tuple variable tells us that the neural network's output is represented as a tuple data structure rather than a proper graph?

         # convert args to list
         list_args = list(args)
         graph_output = inductor_compiled_graph(list_args)
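The thread above hinges on a real distinction: compiled_graph is the callable handed back to Dynamo, so calling it yields the graph's outputs, never a compiled-graph object (and, per the first comment, the CompiledFxGraph type cannot even be imported under PyTorch 2.5, so any such annotation would also need a TYPE_CHECKING guard). A minimal sketch of the shape involved, with illustrative names:

    # Sketch of the distinction; all names here are illustrative, not
    # vLLM's actual code.
    from typing import Any, Callable

    def make_dynamo_callable(inductor_compiled_graph: Callable) -> Callable:
        # From Dynamo's point of view, `compiled_graph` *is* the graph: Dynamo
        # calls it with the flattened arguments and receives the graph outputs.
        # Annotating its return as a compiled-graph type would claim it hands
        # back the compiled artifact itself, which it never does.
        def compiled_graph(*args: Any) -> Any:
            list_args = list(args)  # the inductor entry point takes a list
            graph_output = inductor_compiled_graph(list_args)
            return graph_output
        return compiled_graph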
@@ -247,7 +247,7 @@ def _check_can_cache(*args, **kwargs):
         # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
         return
 
-    def _get_shape_env():
+    def _get_shape_env() -> AlwaysHitShapeEnv:
         return AlwaysHitShapeEnv()
 
     with patch(  # for hijacking the hash of the compiled graph
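For context: the hunk above is part of a cache-hijacking trick. During the inductor cache lookup, the shape environment is swapped for AlwaysHitShapeEnv, whose produce_guards_expression returns an empty string (first hunk), so a previously compiled graph can be reused without re-evaluating shape guards. A generic sketch of the patching pattern, under stand-in names:

    # Generic sketch of swapping a method out for the duration of a lookup,
    # mirroring the hijack above. GraphCache is a hypothetical stand-in, not
    # torch._inductor's real cache class.
    from unittest.mock import patch

    class AlwaysHitShapeEnv:
        """Shape env whose guard expression is empty, so guards always pass."""
        def produce_guards_expression(self, *args, **kwargs) -> str:
            return ""

    class GraphCache:
        """Hypothetical stand-in for the compiled-graph cache being patched."""
        @staticmethod
        def _get_shape_env():
            raise RuntimeError("a real lookup would build a full ShapeEnv")

    def _get_shape_env() -> AlwaysHitShapeEnv:
        return AlwaysHitShapeEnv()

    with patch.object(GraphCache, "_get_shape_env", _get_shape_env):
        # Inside the block, the cache sees a shape env that always hits:
        assert GraphCache._get_shape_env().produce_guards_expression() == ""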
@@ -537,7 +537,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
             example_inputs[x].clone() for x in self.sym_tensor_indices
         ]
 
-        def copy_and_call(*args):
Review comment (Member): This is incorrect. Let me fix it.

+        def copy_and_call(*args) -> fx.GraphModule:
             list_args = list(args)
             for i, index in enumerate(self.sym_tensor_indices):
                 runtime_tensor = list_args[index]
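The copy_and_call wrapper pairs with the clone above it: inputs with symbolic shapes get long-lived static buffers, and each call copies the runtime tensors into those buffers before invoking the compiled callable. A hedged reconstruction of the pattern:

    # Hedged reconstruction; the closure variables mirror the self.* names in
    # the visible context, and the first-dimension slicing is an assumption
    # about how variable-length inputs map onto the buffers.
    import torch

    def make_copy_and_call(compiled_callable, sym_tensor_indices, static_tensors):
        # sym_tensor_indices: argument positions whose first dim is symbolic.
        # static_tensors: cloned example inputs serving as long-lived buffers.
        def copy_and_call(*args):
            list_args = list(args)
            for i, index in enumerate(sym_tensor_indices):
                runtime_tensor = list_args[index]
                # Copy runtime data into a slice of the static buffer so the
                # compiled code always sees the same storage.
                static_tensor = static_tensors[i][:runtime_tensor.shape[0]]
                static_tensor.copy_(runtime_tensor)
                list_args[index] = static_tensor
            # Returns the graph's *outputs* -- not an fx.GraphModule, which is
            # why the annotation in this hunk drew a review comment.
            return compiled_callable(*list_args)
        return copy_and_call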
3 changes: 2 additions & 1 deletion vllm/compilation/multi_output_match.py
@@ -7,6 +7,7 @@
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
 from torch._inductor import pattern_matcher as pm
 from torch._ops import OpOverload
+from torch.fx import Node
 
 from vllm.compilation.fx_utils import find_auto_fn

@@ -97,7 +98,7 @@ def insert_getitems(self, tuple_node: fx.Node,
             self.graph.call_function(operator.getitem, (tuple_node, idx))
             for idx in indices)
 
-    def insert_auto_fn(self, op: OpOverload, kwargs):
Review comment (Member): cc @ProExpertProg to check if this is accurate.

Reply (Contributor): I believe this is right, but for consistency, can we use fx.Node instead of Node, like elsewhere?

+    def insert_auto_fn(self, op: OpOverload, kwargs) -> Node:
"""
Insert an auto_functionalized node with the given op and kwargs.
"""
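For context on the annotation under review, a plausible sketch of what insert_auto_fn does: auto_functionalized is the higher-order op that wraps a mutating op so it can live in a purely functional fx graph, and inserting it yields the new node.

    # Plausible sketch of insert_auto_fn's body, built from the imports in the
    # hunk above; the enclosing class is a stand-in, and the real method may
    # differ in details.
    from torch._higher_order_ops.auto_functionalize import auto_functionalized
    from torch._ops import OpOverload
    from torch.fx import Graph, Node

    class MultiOutputMatchSketch:
        def __init__(self, graph: Graph):
            self.graph = graph

        def insert_auto_fn(self, op: OpOverload, kwargs) -> Node:
            """Insert an auto_functionalized node with the given op and kwargs."""
            # Graph.call_function records a call node at the current insertion
            # point and returns that Node, which is what makes the `-> Node`
            # (or `-> fx.Node`, per the review) annotation accurate.
            return self.graph.call_function(auto_functionalized, (op,),
                                            kwargs=kwargs)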
4 changes: 2 additions & 2 deletions vllm/compilation/pass_manager.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Any, Dict, List
 
 from torch import fx as fx

@@ -53,7 +53,7 @@ def add(self, pass_: InductorPass):
         assert isinstance(pass_, InductorPass)
         self.passes.append(pass_)
 
-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, List[Any]]:
         """
         Custom pickling for the pass manager, as some passes cannot be pickled.
         Pickling occurs because the pass manager is set as the value of
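The docstring explains why custom pickling exists at all: inductor pickles its config, pass manager included, when building cache keys, and some passes hold unpicklable state. A hedged sketch of the pattern the Dict[str, List[Any]] annotation describes:

    # Hedged sketch of the custom-pickling pattern; reducing each pass to its
    # repr is an assumption -- the real manager may use a stable ID instead.
    from typing import Any, Dict, List

    class PassManagerSketch:
        def __init__(self) -> None:
            self.passes: List[Any] = []

        def __getstate__(self) -> Dict[str, List[Any]]:
            # Reduce each pass to something picklable so inductor can fold the
            # pass list into its cache key.
            return {"pass_reprs": [repr(p) for p in self.passes]}

        def __setstate__(self, state: Dict[str, List[Any]]) -> None:
            # The original pass objects cannot be rebuilt from their reprs;
            # refusing to unpickle is one defensible behavior here.
            raise ValueError("cannot restore passes from pickled state")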