[platform] support custom torch.compile backend key (vllm-project#11318)
Signed-off-by: wangxiyuan <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Signed-off-by: Fred Reiss <[email protected]>
2 people authored and frreiss committed Jan 10, 2025
1 parent 066031b commit 205611d
Showing 5 changed files with 14 additions and 5 deletions.
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/rejection_sampler.py
@@ -9,6 +9,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.spec_decode_base_sampler import (
     SpecDecodeStochasticBaseSampler)
+from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
 
@@ -368,7 +369,7 @@ def _smallest_positive_value(self) -> float:
 # Note that we always sample with replacement.
 # probs will be modified in place, but this is fine, as we pass
 # in a copy already.
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
 def _multinomial(
     probs: torch.Tensor,
     num_samples: int,
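The change is mechanical across the four model files: every torch.compile-decorated helper that previously used Dynamo's default backend now reads the backend name from current_platform at import time. A minimal standalone sketch of the pattern, assuming only stock PyTorch (_DemoPlatform is a hypothetical stand-in, not vLLM code); "eager" is a built-in Dynamo backend that runs the captured graph without compilation:

import torch

class _DemoPlatform:
    # An out-of-tree device that cannot run Inductor would put its own
    # backend name here; CUDA/CPU platforms keep the "inductor" default.
    simple_compile_backend: str = "eager"

current_platform = _DemoPlatform()

@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def quick_gelu(x: torch.Tensor) -> torch.Tensor:
    return x * torch.sigmoid(1.702 * x)

print(quick_gelu(torch.randn(4)))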
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -133,7 +133,7 @@ def __post_init__(self):
         assert self.num_added_elements <= self.num_added_elements_padded
 
 
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
 def get_masked_input_and_mask(
         input_: torch.Tensor, org_vocab_start_index: int,
         org_vocab_end_index: int, num_org_vocab_padding: int,
3 changes: 2 additions & 1 deletion vllm/model_executor/models/commandr.py
@@ -45,6 +45,7 @@
     row_parallel_weight_loader)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
@@ -53,7 +54,7 @@
                     maybe_prefix)
 
 
-@torch.compile
+@torch.compile(backend=current_platform.simple_compile_backend)
 def layer_norm_func(hidden_states, weight, variance_epsilon):
     input_dtype = hidden_states.dtype
     hidden_states = hidden_states.to(torch.float32)
5 changes: 3 additions & 2 deletions vllm/model_executor/models/phi3_small.py
@@ -20,6 +20,7 @@
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsPP
@@ -54,12 +55,12 @@ def weight_loader(self, param: torch.nn.Parameter,
         return load_column_parallel_weight(param, loaded_weight)
 
 
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
 def quick_gelu(x):
     return x * torch.sigmoid(1.702 * x)
 
 
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
 def gegelu(input, limit: Optional[float] = None):
     a_gelu, a_linear = input[..., ::2], input[..., 1::2]
     if limit is not None:
6 changes: 6 additions & 0 deletions vllm/platforms/interface.py
@@ -82,6 +82,12 @@ class Platform:
     # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
     # use "CPU" as a fallback for platforms not registered in PyTorch
     dispatch_key: str = "CPU"
+    # The torch.compile backend for compiling simple and
+    # standalone functions. The default value is "inductor" to keep
+    # the same behavior as PyTorch.
+    # NOTE: for the forward part of the model, vLLM has another separate
+    # compilation strategy.
+    simple_compile_backend: str = "inductor"
     supported_quantization: list[str] = []
 
     def is_cuda(self) -> bool:
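Because simple_compile_backend lives on the Platform base class, an out-of-tree platform only has to override one attribute to redirect every helper touched by this commit. A hedged sketch of such an override (the platform class and backend name below are hypothetical; torch._dynamo.register_backend is Dynamo's hook for registering a named backend, which needs a callable taking an FX graph and example inputs):

import torch
from torch._dynamo import register_backend
from vllm.platforms.interface import Platform

@register_backend
def my_accel_backend(gm: torch.fx.GraphModule, example_inputs):
    # A real backend would lower the FX graph for the target device;
    # returning gm.forward just runs the captured graph eagerly.
    return gm.forward

class MyAccelPlatform(Platform):
    simple_compile_backend: str = "my_accel_backend"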