Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[platform] support custom torch.compile backend key #11318

Merged
merged 3 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/rejection_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeStochasticBaseSampler)
from vllm.platforms import current_platform

logger = init_logger(__name__)

Expand Down Expand Up @@ -368,7 +369,7 @@ def _smallest_positive_value(self) -> float:
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
@torch.compile(dynamic=True)
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def _multinomial(
probs: torch.Tensor,
num_samples: int,
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/vocab_parallel_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __post_init__(self):
assert self.num_added_elements <= self.num_added_elements_padded


@torch.compile(dynamic=True)
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def get_masked_input_and_mask(
input_: torch.Tensor, org_vocab_start_index: int,
org_vocab_end_index: int, num_org_vocab_padding: int,
Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/models/commandr.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
row_parallel_weight_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
Expand All @@ -53,7 +54,7 @@
maybe_prefix)


@torch.compile
@torch.compile(backend=current_platform.simple_compile_backend)
def layer_norm_func(hidden_states, weight, variance_epsilon):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
Expand Down
5 changes: 3 additions & 2 deletions vllm/model_executor/models/phi3_small.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
Expand Down Expand Up @@ -54,12 +55,12 @@ def weight_loader(self, param: torch.nn.Parameter,
return load_column_parallel_weight(param, loaded_weight)


@torch.compile(dynamic=True)
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def quick_gelu(x):
return x * torch.sigmoid(1.702 * x)


@torch.compile(dynamic=True)
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def gegelu(input, limit: Optional[float] = None):
a_gelu, a_linear = input[..., ::2], input[..., 1::2]
if limit is not None:
Expand Down
6 changes: 6 additions & 0 deletions vllm/platforms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ class Platform:
# check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
# use "CPU" as a fallback for platforms not registered in PyTorch
dispatch_key: str = "CPU"
# The torch.compile backend for compiling simple and
# standalone functions. The default value is "inductor" to keep
# the same behavior as PyTorch.
# NOTE: for the forward part of the model, vLLM has another separate
# compilation strategy.
simple_compile_backend: str = "inductor"
supported_quantization: list[str] = []

def is_cuda(self) -> bool:
Expand Down
Loading