[RFC][Exploratory]: vLLM Neuron Backend with V1 Architecture #11152

liangfu opened this issue Dec 12, 2024 · 0 comments · May be fixed by #11277
liangfu commented Dec 12, 2024

Motivation.

To leverage the vLLM V1 architecture change, we propose a new integration approach for the Neuron backend that integrates seamlessly with vLLM, while maintaining high performance and treating prefix caching as a first-class feature.

Background

(ref: #8779)
vLLM is on a path toward 1/ full support for torch.compile, 2/ turning on chunked prefill, prefix caching, and speculative decoding by default, and 3/ supporting more than 60 model variants. Meanwhile, the current Neuron backend is supported via the transformers-neuronx library, which has limited support for the combination of these features.

To support a wide range of model variants, vLLM maintains a modular design around the vllm.model_executor.layers module. This enables new model developers to easily contribute support for new model variants to vLLM. For instance, the Mistral team released the pixtral-large model weights and brought pixtral-large support to vLLM with Pixtral (vllm-project#8377).

Proposed Change.

Embracing torch.compile support

As part of the Neuron SDK 2.21 release, we are able to support torch.compile with the openxla backend. For instance, we can implement copy_blocks with:

from typing import List, Tuple

import torch


class NeuronAttentionBackend:

    @staticmethod
    @torch.compile(backend="openxla")
    def copy_blocks(
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        src_to_dists: Tuple[torch.Tensor, torch.Tensor],
    ) -> None:
        src_indices, dst_indices = src_to_dists
        for k_cache, v_cache in kv_caches:
            # Mark the caches as buffer donors so XLA can update them in place.
            torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
            k_cache[:, dst_indices] = k_cache[:, src_indices]
            torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
            v_cache[:, dst_indices] = v_cache[:, src_indices]

# Alternatively, compile at the call site instead of via the decorator,
# so the same code path can run eagerly when compilation is disabled.
if use_torch_compile:
    compiled_code = torch.compile(PallasAttentionBackend.copy_blocks, backend="openxla")
    compiled_code(kv_caches, src_to_dists)
else:
    PallasAttentionBackend.copy_blocks(kv_caches, src_to_dists)
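
For concreteness, a minimal usage sketch follows. The cache layout [num_kv_heads, num_blocks, block_size, head_size] and the block indices are assumptions for illustration only, and the sketch presumes torch_xla is installed with an XLA (Neuron) device available.

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Assumed layout for the sketch: block index on dim 1, matching the slicing above.
num_layers, num_kv_heads, num_blocks, block_size, head_size = 2, 4, 8, 16, 64
kv_caches = [
    (torch.zeros(num_kv_heads, num_blocks, block_size, head_size, device=device),
     torch.zeros(num_kv_heads, num_blocks, block_size, head_size, device=device))
    for _ in range(num_layers)
]

# Copy block 0 into block 5 and block 1 into block 6, in place, for every layer.
src_to_dists = (torch.tensor([0, 1], device=device),
                torch.tensor([5, 6], device=device))
NeuronAttentionBackend.copy_blocks(kv_caches, src_to_dists)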

Build neuron attention backend with NKI

We may build an NKI-based flash-attention kernel with paged KV cache as part of the vllm.attention.ops module. This is similar to the Triton-based flash-attention in vLLM (ref: triton_flash_attention.py).
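
A rough sketch of the shape such a module could take is below; the module path and the name flash_attn_varlen_nki are hypothetical placeholders, and the NKI kernel body is intentionally left unimplemented since writing it is the subject of this work item. The wrapper would mirror the Triton wrapper's contract and forward the paged-KV metadata to the kernel.

# Hypothetical module: vllm/attention/ops/nki_flash_attention.py
from typing import Optional

import torch


def flash_attn_varlen_nki(
    query: torch.Tensor,         # [num_tokens, num_heads, head_size]
    key_cache: torch.Tensor,     # [num_kv_heads, num_blocks, block_size, head_size]
    value_cache: torch.Tensor,   # [num_kv_heads, num_blocks, block_size, head_size]
    block_tables: torch.Tensor,  # [num_seqs, max_blocks_per_seq]
    context_lens: torch.Tensor,  # [num_seqs]
    scale: float,
    alibi_slopes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Paged flash-attention entry point backed by a (hypothetical) NKI kernel."""
    # The actual kernel would be written with the Neuron Kernel Interface (NKI)
    # and launched here; this placeholder only documents the intended contract.
    raise NotImplementedError("NKI paged flash-attention kernel is not implemented yet")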

Introduce forward_neuron into vllm.model_executor.layers module

As with many other backends in vLLM, the native kernels may not be highly performant on a specific hardware backend. vLLM builds and maintains the default behavior in the forward_native function, while hardware-specific optimizations can be enabled via a forward_xxx function (e.g., forward_cuda, forward_neuron).

We can reuse the performant components in the neuronx_distributed package to further improve performance.

from typing import Optional, Tuple, Union

import torch

from vllm.model_executor.custom_op import CustomOp


@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    """

    def forward_neuron(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        try:
            from neuronx_distributed.ops import NeuronFusedRMSNorm
        except ImportError:
            NeuronFusedRMSNorm = None
        if NeuronFusedRMSNorm is None:
            # Fall back to the reference implementation when the fused op
            # is unavailable.
            return self.forward_native(x, residual)
        if residual is not None:
            # Fused residual-add path: add the residual in place, then
            # normalize the updated residual.
            orig_shape = x.shape
            residual += x.view(residual.shape)
            x = NeuronFusedRMSNorm.apply(residual, self.weight, self.variance_epsilon)
            return x.view(orig_shape), residual

        x = NeuronFusedRMSNorm.apply(x, self.weight, self.variance_epsilon)
        return x
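
At the call site nothing changes: the layer is invoked as usual, and CustomOp's platform dispatch (extended to recognize the Neuron platform, as proposed here) would route the call to forward_neuron. A minimal sketch, with the hidden size and tensor shapes chosen arbitrarily:

import torch

hidden_size = 4096
num_tokens = 8
layer = RMSNorm(hidden_size, eps=1e-6)

x = torch.randn(num_tokens, hidden_size)
residual = torch.randn(num_tokens, hidden_size)

# CustomOp dispatches to forward_neuron on Neuron devices and to forward_native
# otherwise; callers never select the implementation manually.
out, residual = layer(x, residual)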

Development Progress

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.