Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torch.compile] consider relevant code in compilation cache #11614

Merged
merged 11 commits into from
Jan 8, 2025
70 changes: 62 additions & 8 deletions vllm/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def wrap_inductor(graph: fx.GraphModule,
example_inputs,
additional_inductor_config,
compilation_config: CompilationConfig,
vllm_backend: "VllmBackend",
graph_index: int = 0,
num_graphs: int = 1,
runtime_shape: Optional[int] = None,
Expand Down Expand Up @@ -176,7 +177,7 @@ def wrap_inductor(graph: fx.GraphModule,
# see https://github.com/pytorch/pytorch/issues/138980
graph = copy.deepcopy(graph)

cache_data = compilation_config.inductor_hash_cache
cache_data = vllm_backend.inductor_hash_cache
if (runtime_shape, graph_index) in cache_data:
# we compiled this graph before
# so we can directly lookup the compiled graph via hash
Expand All @@ -196,7 +197,7 @@ def wrap_inductor(graph: fx.GraphModule,
hash_str, example_inputs, True, False)
assert inductor_compiled_graph is not None, (
"Inductor cache lookup failed. Please remove"
f"the cache file {compilation_config.inductor_hash_cache.cache_file_path} and try again." # noqa
f"the cache file {cache_data.cache_file_path} and try again." # noqa
)

# Inductor calling convention (function signature):
Expand Down Expand Up @@ -354,14 +355,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):

def __init__(self, module: torch.fx.GraphModule,
compile_submod_names: List[str], vllm_config: VllmConfig,
graph_pool):
graph_pool, vllm_backend: "VllmBackend"):
super().__init__(module)
from torch._guards import detect_fake_mode
self.fake_mode = detect_fake_mode()
self.compile_submod_names = compile_submod_names
self.compilation_config = vllm_config.compilation_config
self.graph_pool = graph_pool
self.vllm_config = vllm_config
self.vllm_backend = vllm_backend

def run(self, *args):
fake_args = [
Expand Down Expand Up @@ -389,6 +391,7 @@ def call_module(self, target: torch.fx.node.Target,
args,
self.compilation_config.inductor_compile_config,
self.compilation_config,
self.vllm_backend,
graph_index=index,
num_graphs=len(self.compile_submod_names),
runtime_shape=None,
Expand All @@ -397,7 +400,7 @@ def call_module(self, target: torch.fx.node.Target,
self.module.__dict__[target] = PiecewiseBackend(
submod, self.vllm_config, self.graph_pool, index,
len(self.compile_submod_names), sym_shape_indices,
compiled_graph_for_general_shape)
compiled_graph_for_general_shape, self.vllm_backend)

compilation_counter.num_piecewise_capturable_graphs_seen += 1

Expand Down Expand Up @@ -430,6 +433,7 @@ class VllmBackend:
post_grad_passes: Sequence[Callable]
sym_tensor_indices: List[int]
input_buffers: List[torch.Tensor]
inductor_hash_cache: InductorHashCache

def __init__(
self,
Expand Down Expand Up @@ -472,6 +476,53 @@ def configure_post_pass(self):

def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:

if not self.compilation_config.cache_dir:
# no provided cache dir, generate one based on the known factors
# that affects the compilation. if none of the factors change,
# the cache dir will be the same so that we can reuse the compiled
# graph.

# 1. factors come from the vllm_config (it mainly summarizes how the
# model is created)
vllm_config = self.vllm_config
config_hash = vllm_config.compute_hash()

# 2. factors come from the code files that are traced by Dynamo (
# it mainly summarizes how the model is used in forward pass)
forward_code_files = list(
sorted(self.compilation_config.traced_files))
self.compilation_config.traced_files.clear()
logger.debug(
"Traced files (to be considered for compilation cache):\n%s",
"\n".join(forward_code_files))
hash_content = []
for filepath in forward_code_files:
hash_content.append(filepath)
with open(filepath) as f:
hash_content.append(f.read())
import hashlib
code_hash = hashlib.md5(
"\n".join(hash_content).encode()).hexdigest()

# combine the two hashes to generate the cache dir
hash_key = hashlib.md5(
f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
Comment on lines +508 to +509
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not take the whole hash?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this keeps consistent with previous behavior that VllmConfig.compute_hash uses 10 chars for the hash key.

cache_dir = os.path.join(
envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key,
f"rank_{vllm_config.parallel_config.rank}")
else:
cache_dir = self.compilation_config.cache_dir
os.makedirs(cache_dir, exist_ok=True)

disabled = envs.VLLM_DISABLE_COMPILE_CACHE
self.inductor_hash_cache: InductorHashCache = InductorHashCache(
cache_dir, disabled=disabled)
if disabled:
logger.info("vLLM's torch.compile cache is disabled.")
else:
logger.info("Using cache directory: %s for vLLM's torch.compile",
cache_dir)

# when dynamo calls the backend, it means the bytecode
# transform and analysis are done
compilation_counter.num_graphs_seen += 1
Expand Down Expand Up @@ -507,8 +558,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
# propagate the split graph to the piecewise backend,
# compile submodules with symbolic shapes
PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
self.vllm_config,
self.graph_pool).run(*example_inputs)
self.vllm_config, self.graph_pool,
self).run(*example_inputs)

self._called = True

Expand Down Expand Up @@ -577,7 +628,8 @@ class PiecewiseBackend:
def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
graph_pool: Any, piecewise_compile_index: int,
total_piecewise_compiles: int, sym_shape_indices: List[int],
compiled_graph_for_general_shape: Callable):
compiled_graph_for_general_shape: Callable,
vllm_backend: VllmBackend):
"""
The backend for piecewise compilation.
It mainly handles the compilation and cudagraph capturing.
Expand All @@ -597,6 +649,7 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
self.graph_pool = graph_pool
self.piecewise_compile_index = piecewise_compile_index
self.total_piecewise_compiles = total_piecewise_compiles
self.vllm_backend = vllm_backend

self.is_first_graph = piecewise_compile_index == 0
self.is_last_graph = (
Expand Down Expand Up @@ -634,7 +687,7 @@ def check_for_ending_compilation(self):
if self.is_last_graph and not self.to_be_compiled_sizes:
# no specific sizes to compile
# save the hash of the inductor graph for the next run
self.compilation_config.inductor_hash_cache.save_to_file()
self.vllm_backend.inductor_hash_cache.save_to_file()
end_monitoring_torch_compile(self.vllm_config)

def __call__(self, *args) -> Any:
Expand Down Expand Up @@ -662,6 +715,7 @@ def __call__(self, *args) -> Any:
args,
self.compilation_config.inductor_compile_config,
self.compilation_config,
self.vllm_backend,
graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles,
runtime_shape=runtime_shape,
Expand Down
28 changes: 27 additions & 1 deletion vllm/compilation/decorators.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import inspect
from typing import Callable, Dict, List, Optional, TypeVar, Union, overload
from unittest.mock import patch

import torch
import torch.nn as nn
from torch._dynamo.symbolic_convert import InliningInstructionTranslator

from vllm.compilation.counter import compilation_counter
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
Expand Down Expand Up @@ -196,7 +198,31 @@ def __call__(self, *args, **kwargs):
# we need to control all the compilation of the model.
torch._dynamo.eval_frame.remove_from_cache(
self.original_code_object)
return self.compiled_callable(*args, **kwargs)

# collect all relevant files traced by Dynamo,
# so that the compilation cache can trigger re-compilation
# properly when any of these files change.

# 1. the file containing the top-level forward function
self.vllm_config.compilation_config.traced_files.add(
self.original_code_object.co_filename)

# 2. every time Dynamo sees a function call, it will inline
# the function by calling InliningInstructionTranslator.inline_call
# we hijack this function to know all the functions called
# during Dynamo tracing, and their corresponding files
inline_call = InliningInstructionTranslator.inline_call

def patched_inline_call(parent, func, args, kwargs):
code = func.get_code()
self.vllm_config.compilation_config.traced_files.add(
code.co_filename)
return inline_call(parent, func, args, kwargs)

with patch.object(InliningInstructionTranslator, 'inline_call',
patched_inline_call):
output = self.compiled_callable(*args, **kwargs)
return output

# usually, capturing the model once is enough, and then we can
# dispatch to the compiled code directly, without going through
Expand Down
29 changes: 3 additions & 26 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import enum
import hashlib
import json
import os
import sys
import warnings
from contextlib import contextmanager
Expand Down Expand Up @@ -2778,9 +2777,8 @@ def model_post_init(self, __context: Any) -> None:
# keep track of enabled and disabled custom ops
enabled_custom_ops: Counter[str] = PrivateAttr
disabled_custom_ops: Counter[str] = PrivateAttr
traced_files: Set[str] = PrivateAttr
compilation_time: float = PrivateAttr
# should be InductorHashCache, but Pydantic does not support it
inductor_hash_cache: Any = PrivateAttr

# Per-model forward context
# Mainly used to store attention cls
Expand Down Expand Up @@ -2818,6 +2816,7 @@ def __repr__(self) -> str:
"compilation_time",
"bs_to_padded_graph_size",
"pass_config",
"traced_files",
}
return self.model_dump_json(exclude=exclude, exclude_unset=True)

Expand Down Expand Up @@ -2877,6 +2876,7 @@ def model_post_init(self, __context: Any) -> None:

self.enabled_custom_ops = Counter()
self.disabled_custom_ops = Counter()
self.traced_files = set()
self.static_forward_context = {}
self.compilation_time = 0.0

Expand All @@ -2899,29 +2899,6 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
# merge with the config use_inductor
assert self.level == CompilationLevel.PIECEWISE

if not self.cache_dir:
# no provided cache dir, generate one based on the known factors
# that affects the compilation. if none of the factors change,
# the cache dir will be the same so that we can reuse the compiled
# graph.
hash_key = vllm_config.compute_hash()
cache_dir = os.path.join(
envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key,
f"rank_{vllm_config.parallel_config.rank}")
os.makedirs(cache_dir, exist_ok=True)
self.cache_dir = cache_dir

disabled = envs.VLLM_DISABLE_COMPILE_CACHE
from vllm.compilation.backends import InductorHashCache
self.inductor_hash_cache: InductorHashCache = InductorHashCache(
self.cache_dir, disabled=disabled)
if disabled:
logger.info("vLLM's torch.compile cache is disabled.")
else:
logger.info(
"Using cache directory: %s for vLLM's torch.compile",
self.cache_dir)

from vllm.compilation.backends import VllmBackend
return VllmBackend(vllm_config)

Expand Down
7 changes: 7 additions & 0 deletions vllm/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,6 +1108,13 @@ class IntermediateTensors:

tensors: Dict[str, torch.Tensor]

def __init__(self, tensors):
# manually define this function, so that
# Dynamo knows `IntermediateTensors()` comes from this file.
# Otherwise, dataclass will generate this function by evaluating
# a string, and we will lose the information about the source file.
self.tensors = tensors
Comment on lines +1111 to +1116
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this something that we'll need to do for every dataclass that's used during model execution?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, we need to do this for every dataclass that is created during model execution. Right now I only find IntermediateTensors, do you know anything else?


def __getitem__(self, key: Union[str, slice]):
if isinstance(key, str):
return self.tensors[key]
Expand Down
Loading