[torch.compile] consider relevant code in compilation cache (#11614)
Signed-off-by: youkaichao <[email protected]>
youkaichao authored Jan 8, 2025
1 parent cfd3219 commit f121411
Showing 4 changed files with 99 additions and 35 deletions.
70 changes: 62 additions & 8 deletions vllm/compilation/backends.py
@@ -145,6 +145,7 @@ def wrap_inductor(graph: fx.GraphModule,
example_inputs,
additional_inductor_config,
compilation_config: CompilationConfig,
vllm_backend: "VllmBackend",
graph_index: int = 0,
num_graphs: int = 1,
runtime_shape: Optional[int] = None,
@@ -176,7 +177,7 @@ def wrap_inductor(graph: fx.GraphModule,
# see https://github.com/pytorch/pytorch/issues/138980
graph = copy.deepcopy(graph)

cache_data = compilation_config.inductor_hash_cache
cache_data = vllm_backend.inductor_hash_cache
if (runtime_shape, graph_index) in cache_data:
# we compiled this graph before
# so we can directly lookup the compiled graph via hash
@@ -196,7 +197,7 @@ def wrap_inductor(graph: fx.GraphModule,
hash_str, example_inputs, True, False)
assert inductor_compiled_graph is not None, (
"Inductor cache lookup failed. Please remove"
f"the cache file {compilation_config.inductor_hash_cache.cache_file_path} and try again." # noqa
f"the cache file {cache_data.cache_file_path} and try again." # noqa
)

# Inductor calling convention (function signature):
@@ -354,14 +355,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):

def __init__(self, module: torch.fx.GraphModule,
compile_submod_names: List[str], vllm_config: VllmConfig,
graph_pool):
graph_pool, vllm_backend: "VllmBackend"):
super().__init__(module)
from torch._guards import detect_fake_mode
self.fake_mode = detect_fake_mode()
self.compile_submod_names = compile_submod_names
self.compilation_config = vllm_config.compilation_config
self.graph_pool = graph_pool
self.vllm_config = vllm_config
self.vllm_backend = vllm_backend

def run(self, *args):
fake_args = [
@@ -389,6 +391,7 @@ def call_module(self, target: torch.fx.node.Target,
args,
self.compilation_config.inductor_compile_config,
self.compilation_config,
self.vllm_backend,
graph_index=index,
num_graphs=len(self.compile_submod_names),
runtime_shape=None,
@@ -397,7 +400,7 @@ def call_module(self, target: torch.fx.node.Target,
self.module.__dict__[target] = PiecewiseBackend(
submod, self.vllm_config, self.graph_pool, index,
len(self.compile_submod_names), sym_shape_indices,
compiled_graph_for_general_shape)
compiled_graph_for_general_shape, self.vllm_backend)

compilation_counter.num_piecewise_capturable_graphs_seen += 1

@@ -430,6 +433,7 @@ class VllmBackend:
post_grad_passes: Sequence[Callable]
sym_tensor_indices: List[int]
input_buffers: List[torch.Tensor]
inductor_hash_cache: InductorHashCache

def __init__(
self,
@@ -472,6 +476,53 @@ def configure_post_pass(self):

def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:

if not self.compilation_config.cache_dir:
# no provided cache dir, generate one based on the known factors
# that affects the compilation. if none of the factors change,
# the cache dir will be the same so that we can reuse the compiled
# graph.

# 1. factors come from the vllm_config (it mainly summarizes how the
# model is created)
vllm_config = self.vllm_config
config_hash = vllm_config.compute_hash()

# 2. factors come from the code files that are traced by Dynamo (
# it mainly summarizes how the model is used in forward pass)
forward_code_files = list(
sorted(self.compilation_config.traced_files))
self.compilation_config.traced_files.clear()
logger.debug(
"Traced files (to be considered for compilation cache):\n%s",
"\n".join(forward_code_files))
hash_content = []
for filepath in forward_code_files:
hash_content.append(filepath)
with open(filepath) as f:
hash_content.append(f.read())
import hashlib
code_hash = hashlib.md5(
"\n".join(hash_content).encode()).hexdigest()

# combine the two hashes to generate the cache dir
hash_key = hashlib.md5(
f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
cache_dir = os.path.join(
envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key,
f"rank_{vllm_config.parallel_config.rank}")
else:
cache_dir = self.compilation_config.cache_dir
os.makedirs(cache_dir, exist_ok=True)

disabled = envs.VLLM_DISABLE_COMPILE_CACHE
self.inductor_hash_cache: InductorHashCache = InductorHashCache(
cache_dir, disabled=disabled)
if disabled:
logger.info("vLLM's torch.compile cache is disabled.")
else:
logger.info("Using cache directory: %s for vLLM's torch.compile",
cache_dir)

# when dynamo calls the backend, it means the bytecode
# transform and analysis are done
compilation_counter.num_graphs_seen += 1
@@ -507,8 +558,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
# propagate the split graph to the piecewise backend,
# compile submodules with symbolic shapes
PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
self.vllm_config,
self.graph_pool).run(*example_inputs)
self.vllm_config, self.graph_pool,
self).run(*example_inputs)

self._called = True

@@ -577,7 +628,8 @@ class PiecewiseBackend:
def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
graph_pool: Any, piecewise_compile_index: int,
total_piecewise_compiles: int, sym_shape_indices: List[int],
compiled_graph_for_general_shape: Callable):
compiled_graph_for_general_shape: Callable,
vllm_backend: VllmBackend):
"""
The backend for piecewise compilation.
It mainly handles the compilation and cudagraph capturing.
@@ -597,6 +649,7 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
self.graph_pool = graph_pool
self.piecewise_compile_index = piecewise_compile_index
self.total_piecewise_compiles = total_piecewise_compiles
self.vllm_backend = vllm_backend

self.is_first_graph = piecewise_compile_index == 0
self.is_last_graph = (
@@ -634,7 +687,7 @@ def check_for_ending_compilation(self):
if self.is_last_graph and not self.to_be_compiled_sizes:
# no specific sizes to compile
# save the hash of the inductor graph for the next run
self.compilation_config.inductor_hash_cache.save_to_file()
self.vllm_backend.inductor_hash_cache.save_to_file()
end_monitoring_torch_compile(self.vllm_config)

def __call__(self, *args) -> Any:
Expand Down Expand Up @@ -662,6 +715,7 @@ def __call__(self, *args) -> Any:
args,
self.compilation_config.inductor_compile_config,
self.compilation_config,
self.vllm_backend,
graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles,
runtime_shape=runtime_shape,
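Note for readers of this diff: the cache key that `VllmBackend.__call__` now computes combines a hash of the config with a hash of every source file Dynamo traced, then derives the cache directory from the result. Below is a minimal standalone sketch of that scheme; the names `stable_config_repr`, `traced_files`, `cache_root`, and `rank` are hypothetical stand-ins for illustration, not vLLM APIs.

```python
# A minimal sketch of the cache-key scheme added in backends.py above.
import hashlib
import os


def compute_cache_dir(stable_config_repr: str, traced_files: set,
                      cache_root: str, rank: int) -> str:
    # 1. hash the factors describing how the model is created (the config)
    config_hash = hashlib.md5(stable_config_repr.encode()).hexdigest()

    # 2. hash the source files Dynamo traced (how the model is used in forward)
    hash_content = []
    for filepath in sorted(traced_files):
        hash_content.append(filepath)
        with open(filepath) as f:
            hash_content.append(f.read())
    code_hash = hashlib.md5("\n".join(hash_content).encode()).hexdigest()

    # 3. combine both: editing any traced file (or the config) changes the
    #    directory, so stale compiled artifacts are never silently reused
    hash_key = hashlib.md5(
        f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
    return os.path.join(cache_root, "torch_compile_cache", hash_key,
                        f"rank_{rank}")
```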
28 changes: 27 additions & 1 deletion vllm/compilation/decorators.py
@@ -1,8 +1,10 @@
import inspect
from typing import Callable, Dict, List, Optional, TypeVar, Union, overload
from unittest.mock import patch

import torch
import torch.nn as nn
from torch._dynamo.symbolic_convert import InliningInstructionTranslator

from vllm.compilation.counter import compilation_counter
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
@@ -196,7 +198,31 @@ def __call__(self, *args, **kwargs):
# we need to control all the compilation of the model.
torch._dynamo.eval_frame.remove_from_cache(
self.original_code_object)
return self.compiled_callable(*args, **kwargs)

# collect all relevant files traced by Dynamo,
# so that the compilation cache can trigger re-compilation
# properly when any of these files change.

# 1. the file containing the top-level forward function
self.vllm_config.compilation_config.traced_files.add(
self.original_code_object.co_filename)

# 2. every time Dynamo sees a function call, it will inline
# the function by calling InliningInstructionTranslator.inline_call
# we hijack this function to know all the functions called
# during Dynamo tracing, and their corresponding files
inline_call = InliningInstructionTranslator.inline_call

def patched_inline_call(parent, func, args, kwargs):
code = func.get_code()
self.vllm_config.compilation_config.traced_files.add(
code.co_filename)
return inline_call(parent, func, args, kwargs)

with patch.object(InliningInstructionTranslator, 'inline_call',
patched_inline_call):
output = self.compiled_callable(*args, **kwargs)
return output

# usually, capturing the model once is enough, and then we can
# dispatch to the compiled code directly, without going through
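The hook in this file is the heart of the change: every function Dynamo inlines passes through `InliningInstructionTranslator.inline_call`, so wrapping it yields the set of source files the forward pass actually touches. The following self-contained sketch mirrors that pattern, assuming the `inline_call(parent, func, args, kwargs)` signature shown in the diff (a private torch API that may differ across versions); `helper`, `fn`, and `traced_files` are illustrative names.

```python
from unittest.mock import patch

import torch
from torch._dynamo.symbolic_convert import InliningInstructionTranslator

traced_files = set()
inline_call = InliningInstructionTranslator.inline_call


def patched_inline_call(parent, func, args, kwargs):
    # record the source file of every function Dynamo inlines
    traced_files.add(func.get_code().co_filename)
    return inline_call(parent, func, args, kwargs)


def helper(x):
    return x * 2


@torch.compile(backend="eager")
def fn(x):
    return helper(x) + 1


with patch.object(InliningInstructionTranslator, "inline_call",
                  patched_inline_call):
    fn(torch.ones(4))

print(traced_files)  # typically contains this script's own path, via `helper`
```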
29 changes: 3 additions & 26 deletions vllm/config.py
@@ -3,7 +3,6 @@
import enum
import hashlib
import json
import os
import sys
import warnings
from contextlib import contextmanager
@@ -2778,9 +2777,8 @@ def model_post_init(self, __context: Any) -> None:
# keep track of enabled and disabled custom ops
enabled_custom_ops: Counter[str] = PrivateAttr
disabled_custom_ops: Counter[str] = PrivateAttr
traced_files: Set[str] = PrivateAttr
compilation_time: float = PrivateAttr
# should be InductorHashCache, but Pydantic does not support it
inductor_hash_cache: Any = PrivateAttr

# Per-model forward context
# Mainly used to store attention cls
@@ -2818,6 +2816,7 @@ def __repr__(self) -> str:
"compilation_time",
"bs_to_padded_graph_size",
"pass_config",
"traced_files",
}
return self.model_dump_json(exclude=exclude, exclude_unset=True)

@@ -2877,6 +2876,7 @@ def model_post_init(self, __context: Any) -> None:

self.enabled_custom_ops = Counter()
self.disabled_custom_ops = Counter()
self.traced_files = set()
self.static_forward_context = {}
self.compilation_time = 0.0

@@ -2899,29 +2899,6 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
# merge with the config use_inductor
assert self.level == CompilationLevel.PIECEWISE

if not self.cache_dir:
# no provided cache dir, generate one based on the known factors
# that affects the compilation. if none of the factors change,
# the cache dir will be the same so that we can reuse the compiled
# graph.
hash_key = vllm_config.compute_hash()
cache_dir = os.path.join(
envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key,
f"rank_{vllm_config.parallel_config.rank}")
os.makedirs(cache_dir, exist_ok=True)
self.cache_dir = cache_dir

disabled = envs.VLLM_DISABLE_COMPILE_CACHE
from vllm.compilation.backends import InductorHashCache
self.inductor_hash_cache: InductorHashCache = InductorHashCache(
self.cache_dir, disabled=disabled)
if disabled:
logger.info("vLLM's torch.compile cache is disabled.")
else:
logger.info(
"Using cache directory: %s for vLLM's torch.compile",
self.cache_dir)

from vllm.compilation.backends import VllmBackend
return VllmBackend(vllm_config)

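A note on the new `traced_files` field in `CompilationConfig`: it is runtime state that grows while Dynamo traces, which is why it is excluded from the serialized `__repr__` that downstream code may hash. The toy illustration below (not vLLM's actual classes) shows why such a volatile field must stay out of the stable representation.

```python
# Why `traced_files` is kept out of the config's serialized repr: including
# it would make the repr (and any hash built on it) change during tracing.
import hashlib
import json
from dataclasses import dataclass, field
from typing import Set


@dataclass
class ToyCompilationConfig:
    level: int = 3
    cache_dir: str = ""
    traced_files: Set[str] = field(default_factory=set)  # volatile, excluded

    def stable_repr(self) -> str:
        data = {"level": self.level, "cache_dir": self.cache_dir}
        return json.dumps(data, sort_keys=True)


cfg = ToyCompilationConfig()
before = hashlib.md5(cfg.stable_repr().encode()).hexdigest()
cfg.traced_files.add("/path/to/model.py")
after = hashlib.md5(cfg.stable_repr().encode()).hexdigest()
assert before == after  # hash is unaffected by tracing progress
```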
7 changes: 7 additions & 0 deletions vllm/sequence.py
@@ -1108,6 +1108,13 @@ class IntermediateTensors:

tensors: Dict[str, torch.Tensor]

def __init__(self, tensors):
# manually define this function, so that
# Dynamo knows `IntermediateTensors()` comes from this file.
# Otherwise, dataclass will generate this function by evaluating
# a string, and we will lose the information about the source file.
self.tensors = tensors

def __getitem__(self, key: Union[str, slice]):
if isinstance(key, str):
return self.tensors[key]
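The manual `__init__` above exists because a dataclass-generated `__init__` is built by `exec()`-ing a string, so its code object carries a synthetic filename and Dynamo cannot attribute it to `vllm/sequence.py` when collecting traced files. A quick check of that behavior (the exact synthetic name, e.g. "<string>", may vary across CPython versions):

```python
# Compare the recorded source file of a generated vs. a hand-written __init__.
from dataclasses import dataclass


@dataclass
class Generated:
    x: int


class Manual:
    def __init__(self, x: int):
        self.x = x


print(Generated.__init__.__code__.co_filename)  # typically "<string>"
print(Manual.__init__.__code__.co_filename)     # this file's real path
```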
