Skip to content

Commit

Permalink
[v1] fix compilation cache (vllm-project#11598)
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: xcnick <[email protected]>
  • Loading branch information
youkaichao authored and xcnick committed Dec 31, 2024
1 parent a54cf6c commit e455973
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 14 deletions.
15 changes: 13 additions & 2 deletions tests/compile/piecewise/test_toy_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
initialized randomly with a fixed seed.
"""
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, List, Optional, Tuple

import torch
from torch import nn
Expand Down Expand Up @@ -54,6 +54,16 @@ class LlamaConfig:
tractable_init: bool = False
random_seed: int = 0

def compute_hash(self) -> str:
factors: List[Any] = []
for k, v in self.__dict__.items():
if k == "random_seed":
continue
factors.append((k, v))
factors.sort()
import hashlib
return hashlib.md5(str(factors).encode()).hexdigest()

def __post_init__(self):
assert self.mlp_size >= self.hidden_size

Expand Down Expand Up @@ -263,7 +273,8 @@ def run_model(llama_config,
compilation_config = CompilationConfig(
level=CompilationLevel.NO_COMPILATION, )

vllm_config = VllmConfig(compilation_config=compilation_config)
vllm_config = VllmConfig(compilation_config=compilation_config,
additional_config=llama_config)
with set_current_vllm_config(vllm_config):
model = LlamaModel(config=llama_config,
vllm_config=vllm_config,
Expand Down
22 changes: 13 additions & 9 deletions vllm/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,21 +619,28 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
# the entries for different shapes that we need to either
# compile or capture cudagraph
self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
self.to_be_compiled_sizes: Set[int] = self.compile_sizes.union(
self.capture_sizes)

# to_be_compiled_sizes tracks the remaining sizes to compile,
# and updates during the compilation process, so we need to copy it
self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
for shape in self.compile_sizes.union(self.capture_sizes):
self.concrete_size_entries[shape] = ConcreteSizeEntry(
runtime_shape=shape,
need_to_compile=shape in self.compile_sizes,
use_cudagraph=shape in self.capture_sizes,
)

def check_for_ending_compilation(self):
if self.is_last_graph and not self.to_be_compiled_sizes:
# no specific sizes to compile
# save the hash of the inductor graph for the next run
self.compilation_config.inductor_hash_cache.save_to_file()
end_monitoring_torch_compile(self.vllm_config)

def __call__(self, *args) -> Any:
if not self.first_run_finished:
self.first_run_finished = True
# no specific sizes to compile
if self.is_last_graph and not self.to_be_compiled_sizes:
end_monitoring_torch_compile(self.vllm_config)
self.check_for_ending_compilation()
return self.compiled_graph_for_general_shape(*args)

runtime_shape = args[self.sym_shape_indices[0]]
Expand Down Expand Up @@ -662,10 +669,7 @@ def __call__(self, *args) -> Any:

# finished compilations for all required shapes
if self.is_last_graph and not self.to_be_compiled_sizes:

# save the hash of the inductor graph for the next run
self.compilation_config.inductor_hash_cache.save_to_file()
end_monitoring_torch_compile(self.vllm_config)
self.check_for_ending_compilation()

if not entry.use_cudagraph:
return entry.runnable(*args)
Expand Down
45 changes: 42 additions & 3 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from dataclasses import dataclass, field, replace
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict,
Final, List, Literal, Mapping, Optional, Set, Tuple, Type,
Union)
Final, List, Literal, Mapping, Optional, Protocol, Set,
Tuple, Type, Union)

import torch
from pydantic import BaseModel, Field, PrivateAttr
Expand Down Expand Up @@ -75,6 +75,12 @@
PretrainedConfig]]


class SupportsHash(Protocol):

def compute_hash(self) -> str:
...


class ModelConfig:
"""Configuration for the model.
Expand Down Expand Up @@ -2969,6 +2975,10 @@ class VllmConfig:
init=True) # type: ignore
kv_transfer_config: KVTransferConfig = field(default=None,
init=True) # type: ignore
# some opaque config, only used to provide additional information
# for the hash computation, mainly used for testing and debugging.
additional_config: SupportsHash = field(default=None,
init=True) # type: ignore
instance_id: str = ""

def compute_hash(self) -> str:
Expand Down Expand Up @@ -3000,33 +3010,62 @@ def compute_hash(self) -> str:
vllm_factors.append(__version__)
if self.model_config:
vllm_factors.append(self.model_config.compute_hash())
else:
vllm_factors.append("None")
if self.cache_config:
vllm_factors.append(self.cache_config.compute_hash())
else:
vllm_factors.append("None")
if self.parallel_config:
vllm_factors.append(self.parallel_config.compute_hash())
else:
vllm_factors.append("None")
if self.scheduler_config:
vllm_factors.append(self.scheduler_config.compute_hash())
else:
vllm_factors.append("None")
if self.device_config:
vllm_factors.append(self.device_config.compute_hash())
else:
vllm_factors.append("None")
if self.load_config:
vllm_factors.append(self.load_config.compute_hash())
else:
vllm_factors.append("None")
if self.lora_config:
vllm_factors.append(self.lora_config.compute_hash())
else:
vllm_factors.append("None")
if self.speculative_config:
vllm_factors.append(self.speculative_config.compute_hash())
else:
vllm_factors.append("None")
if self.decoding_config:
vllm_factors.append(self.decoding_config.compute_hash())
else:
vllm_factors.append("None")
if self.observability_config:
vllm_factors.append(self.observability_config.compute_hash())
else:
vllm_factors.append("None")
if self.prompt_adapter_config:
vllm_factors.append(self.prompt_adapter_config.compute_hash())
else:
vllm_factors.append("None")
if self.quant_config:
pass # should be captured by model_config.quantization
if self.compilation_config:
vllm_factors.append(self.compilation_config.compute_hash())
else:
vllm_factors.append("None")
if self.kv_transfer_config:
vllm_factors.append(self.kv_transfer_config.compute_hash())

else:
vllm_factors.append("None")
if self.additional_config:
vllm_factors.append(self.additional_config.compute_hash())
else:
vllm_factors.append("None")
factors.append(vllm_factors)

hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10]
Expand Down
1 change: 1 addition & 0 deletions vllm/v1/worker/gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config

self.parallel_config.rank = rank
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
Expand Down

0 comments on commit e455973

Please sign in to comment.