
Commit

[v1][bugfix] fix cudagraph with inplace buffer assignment (vllm-project#11596)

Signed-off-by: youkaichao <[email protected]>
youkaichao authored and BKitor committed Dec 30, 2024
1 parent 52b17d7 commit 540bb78
Showing 2 changed files with 10 additions and 11 deletions.
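
For context, a minimal sketch of the pattern this bugfix targets (a hypothetical module, not vLLM code): reassigning a registered buffer inside forward(), for example to move it to the input's device, is the kind of in-place buffer modification that can silently corrupt results under CUDA graph capture and that the compiler wrapper now rejects.

```python
import torch
import torch.nn as nn


class BadCache(nn.Module):
    """Hypothetical example of the anti-pattern this commit guards against."""

    def __init__(self):
        super().__init__()
        self.register_buffer("cache", torch.zeros(8), persistent=False)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        # Reassigning a registered buffer during the forward pass is the
        # in-place buffer assignment that can produce silent errors once the
        # graph is captured with CUDA graphs.
        self.cache = self.cache.to(idx.device)
        return torch.index_select(self.cache, 0, idx)
```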
vllm/compilation/wrapper.py (10 changes: 9 additions & 1 deletion)

@@ -28,11 +28,12 @@ def __init__(self,
                  compiled_callable: Optional[Callable] = None,
                  compilation_level: int = 0):
 
+        vllm_config = get_current_vllm_config()
+        self.vllm_config = vllm_config
         if compiled_callable is None:
             # default compilation settings
             # compiling the forward method
 
-            vllm_config = get_current_vllm_config()
             backend = vllm_config.compilation_config.init_backend(vllm_config)
 
             compiled_callable = torch.compile(
@@ -82,6 +83,13 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
 
         self.compiled_codes.append(new_code)
 
+        if self.vllm_config.compilation_config.use_cudagraph and \
+            "update" in new_code.co_names:
+            import depyf
+            src = depyf.decompile(new_code)
+            msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src  # noqa
+            raise RuntimeError(msg)
+
     @contextmanager
     def dispatch_to_code(self, index: int):
         """Context manager to dispatch to the compiled code.
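
A standalone illustration (not vLLM code) of the check above: names referenced by a function's bytecode, including method names like `update`, appear in its code object's co_names tuple, which is what the new guard inspects on the Dynamo-transformed code.

```python
def mutates(d: dict) -> None:
    # Calling d.update puts the name "update" into co_names.
    d.update({"k": 1})


def reads(d: dict) -> int:
    # Plain subscription references no attribute or global names.
    return d["k"]


print("update" in mutates.__code__.co_names)  # True
print("update" in reads.__code__.co_names)    # False
```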
vllm/model_executor/layers/rotary_embedding.py (11 changes: 1 addition & 10 deletions)

@@ -541,19 +541,12 @@ def __init__(
         short_cache = self._compute_cos_sin_cache(
             original_max_position_embeddings, short_factor, short_mscale)
         short_cache = short_cache.to(dtype)
-        self.register_buffer("short_cos_sin_cache",
-                             short_cache,
-                             persistent=False)
 
         long_cache = self._compute_cos_sin_cache(max_position_embeddings,
                                                  long_factor, long_mscale)
         long_cache = long_cache.to(dtype)
-        self.register_buffer("long_cos_sin_cache",
-                             long_cache,
-                             persistent=False)
 
-        long_short_cache = torch.cat(
-            [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
+        long_short_cache = torch.cat([short_cache, long_cache], dim=0)
         self.register_buffer("long_short_cos_sin_cache",
                              long_short_cache,
                              persistent=False)
@@ -593,8 +586,6 @@ def forward(
                 torch.full_like(positions, k)).long()
         idx = (torch.add(positions, long_prompt_offset)
                if long_prompt_offset is not None else positions)
-        self.long_short_cos_sin_cache: torch.Tensor = (
-            self.long_short_cos_sin_cache.to(idx.device))
         idx = torch.add(idx, offsets) if offsets is not None else idx
         cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
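
A minimal sketch of the pattern the rotary-embedding change moves to (a hypothetical module, assuming only the standard register_buffer API): concatenate the short and long caches once in __init__, register only the combined tensor, and never reassign the buffer in forward(), so no module state is mutated during graph capture.

```python
import torch
import torch.nn as nn


class CombinedCache(nn.Module):
    """Hypothetical module mirroring the fixed pattern above."""

    def __init__(self, short_cache: torch.Tensor, long_cache: torch.Tensor):
        super().__init__()
        # The intermediate tensors stay as plain locals; only the tensor that
        # forward() reads is registered as a (non-persistent) buffer.
        self.register_buffer("long_short_cos_sin_cache",
                             torch.cat([short_cache, long_cache], dim=0),
                             persistent=False)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        # No buffer reassignment here; module.to(device) has already placed
        # the buffer on the right device before capture.
        return torch.index_select(self.long_short_cos_sin_cache, 0, idx)
```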
