From 540bb7870b764bdd0a599e68a3d48a5c3b07fd94 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 29 Dec 2024 17:03:49 +0800 Subject: [PATCH] [v1][bugfix] fix cudagraph with inplace buffer assignment (#11596) Signed-off-by: youkaichao --- vllm/compilation/wrapper.py | 10 +++++++++- vllm/model_executor/layers/rotary_embedding.py | 11 +---------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index c10241b483169..e3260a10c02ae 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -28,11 +28,12 @@ def __init__(self, compiled_callable: Optional[Callable] = None, compilation_level: int = 0): + vllm_config = get_current_vllm_config() + self.vllm_config = vllm_config if compiled_callable is None: # default compilation settings # compiling the forward method - vllm_config = get_current_vllm_config() backend = vllm_config.compilation_config.init_backend(vllm_config) compiled_callable = torch.compile( @@ -82,6 +83,13 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType): self.compiled_codes.append(new_code) + if self.vllm_config.compilation_config.use_cudagraph and \ + "update" in new_code.co_names: + import depyf + src = depyf.decompile(new_code) + msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src # noqa + raise RuntimeError(msg) + @contextmanager def dispatch_to_code(self, index: int): """Context manager to dispatch to the compiled code. diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 117fe086e5e87..6695d44dfa32b 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -541,19 +541,12 @@ def __init__( short_cache = self._compute_cos_sin_cache( original_max_position_embeddings, short_factor, short_mscale) short_cache = short_cache.to(dtype) - self.register_buffer("short_cos_sin_cache", - short_cache, - persistent=False) long_cache = self._compute_cos_sin_cache(max_position_embeddings, long_factor, long_mscale) long_cache = long_cache.to(dtype) - self.register_buffer("long_cos_sin_cache", - long_cache, - persistent=False) - long_short_cache = torch.cat( - [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0) + long_short_cache = torch.cat([short_cache, long_cache], dim=0) self.register_buffer("long_short_cos_sin_cache", long_short_cache, persistent=False) @@ -593,8 +586,6 @@ def forward( torch.full_like(positions, k)).long() idx = (torch.add(positions, long_prompt_offset) if long_prompt_offset is not None else positions) - self.long_short_cos_sin_cache: torch.Tensor = ( - self.long_short_cos_sin_cache.to(idx.device)) idx = torch.add(idx, offsets) if offsets is not None else idx cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)