[RFC]: Hybrid Memory Allocator #11382

Open
1 task done
heheda12345 opened this issue Dec 20, 2024 · 0 comments

Motivation.

In addition to standard self-attention-only models, there are now more and more hybrid models with more than one type of layer, for example:

  1. Sliding window attention + self-attention: Gemma2, Ministral ([Feature]: Alternating local-global attention layers #9464)
  2. Cross-attention + self-attention: mllama
  3. Mamba layer + self-attention: Jamba, Bamba

The KV cache size of different tokens is no longer the same in the above models. However, vLLM can only allocate the same KV cache size for all tokens, as shown in the figure below (mllama with BLOCK_SIZE=1). The memory waste can be 79.6% in mllama, 25% in Gemma-2, and 56.25% in Ministral.
figure1
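The Gemma-2 and Ministral figures follow from simple arithmetic if the waste is measured as the share of a uniform per-token allocation that sliding-window layers never read. The sketch below assumes the layer splits and window/context lengths given in its comments (a Gemma-2-like 1:1 alternation with a 4096-token window over 8192 tokens, and a Ministral-like 3-in-4 split with a 32K window over 128K tokens, chosen because they are consistent with the quoted numbers); `sliding_window_waste` is a hypothetical helper. The mllama figure depends on image and text token counts and is not modeled here.

```python
from fractions import Fraction


def sliding_window_waste(sw_layer_fraction, window, context_len):
    """Share of a uniform per-token KV allocation that sliding-window layers never use.

    Assumes every layer is given KV cache for all `context_len` tokens, while
    sliding-window layers only ever need the last `window` tokens.
    """
    if window >= context_len:
        return Fraction(0)
    unused_per_sw_layer = 1 - Fraction(window, context_len)
    return sw_layer_fraction * unused_per_sw_layer


# Assumed Gemma-2-like setup: half the layers use a 4096-token window, 8192-token context.
gemma2 = sliding_window_waste(Fraction(1, 2), 4096, 8192)
# Assumed Ministral-like setup: 3 of every 4 layers use a 32K window, 128K context.
ministral = sliding_window_waste(Fraction(3, 4), 32 * 1024, 128 * 1024)
print(float(gemma2), float(ministral))  # 0.25 0.5625
```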

For Mamba layers, vLLM has a dedicated MambaCacheManager in model_executor/models/mamba_cache.py, which is not compatible with prefix caching.

We want a new memory manager to:

  1. Allocate KV cache for these models with minimal fragmentation
  2. Support prefix caching for these models

We plan to support them in two milestones:
Milestone 1: per-layer memory allocation, where each layer has the same KV cache size per token
In this milestone, we assume that all layers have the same kv_hidden_size but may store different numbers of tokens depending on the layer type (e.g., encoder, sliding window). Then every layer can use the same page size, which greatly simplifies the design of the memory allocator. This assumption covers almost all models currently in vLLM (except Jamba).

We set the page size to [BLOCK_SIZE, kv_hidden_size] (in current vLLM it is [num_layer, BLOCK_SIZE, kv_hidden_size]). For each request, we call the memory allocator num_layer times to get one block table per layer. Each layer then has a different number of slots, depending on its layer type, and its own slot mapping.
figure2
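A minimal sketch of per-layer allocation with equal-sized pages, assuming a single shared pool of [BLOCK_SIZE, kv_hidden_size] pages. `BlockPool`, `allocate_layer`, and the layer-type strings are hypothetical names for illustration, not vLLM APIs. Each layer gets its own block table and slot mapping; the sliding-window layer needs fewer pages because it keeps only the last `window` tokens.

```python
from collections import deque

BLOCK_SIZE = 4  # tokens per page; kept small so the numbers are easy to follow


class BlockPool:
    """Hypothetical pool of equal-sized pages, each shaped [BLOCK_SIZE, kv_hidden_size]."""

    def __init__(self, num_blocks):
        self.free = deque(range(num_blocks))

    def allocate(self, n):
        if n > len(self.free):
            raise MemoryError("out of KV cache pages")
        return [self.free.popleft() for _ in range(n)]


def needed_positions(layer_type, num_tokens, window=None):
    """Token positions whose KV this layer must keep (sliding window: last `window` only)."""
    start = max(0, num_tokens - window) if layer_type == "sliding_window" else 0
    return range(start, num_tokens)


def allocate_layer(pool, layer_type, num_tokens, window=None):
    """One allocator call for one layer: returns (block_table, slot_mapping)."""
    positions = needed_positions(layer_type, num_tokens, window)
    logical_blocks = sorted({pos // BLOCK_SIZE for pos in positions})
    block_table = dict(zip(logical_blocks, pool.allocate(len(logical_blocks))))
    # Slot = physical page * BLOCK_SIZE + offset of the token inside the page.
    slot_mapping = {pos: block_table[pos // BLOCK_SIZE] * BLOCK_SIZE + pos % BLOCK_SIZE
                    for pos in positions}
    return block_table, slot_mapping


pool = BlockPool(num_blocks=16)
full_table, full_slots = allocate_layer(pool, "self_attention", num_tokens=10)
sw_table, sw_slots = allocate_layer(pool, "sliding_window", num_tokens=10, window=4)
print(full_table)  # {0: 0, 1: 1, 2: 2}
print(sw_table)    # {1: 3, 2: 4}
```

Because every page has the same shape, a page released by any layer can later be handed to any other layer from the same pool.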

The software architecture will be as follows, with the memory manager for each layer being one of [SelfAttentionManager, SlidingWindowManager, MambaManager, CrossAttentionManager …]:
figure3
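The per-type managers could share a small interface in which each type only decides which token positions it must keep. The classes below reuse manager names from the list above, but their methods (`token_range`, `num_blocks`) and the `plan_allocation` helper are assumptions made for this sketch, which only counts the pages each layer needs and does not allocate them.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class Request:
    num_tokens: int               # decoder (text) tokens seen so far
    num_encoder_tokens: int = 0   # e.g. image tokens consumed by cross-attention


def blocks_covering(start, end, block_size):
    """Number of pages touched by token positions [start, end)."""
    if end <= start:
        return 0
    return (end - 1) // block_size - start // block_size + 1


class LayerManager(ABC):
    """Hypothetical base class; each layer type decides which tokens it stores."""

    def __init__(self, block_size):
        self.block_size = block_size

    @abstractmethod
    def token_range(self, req):
        """Return (start, end) of the token positions this layer keeps."""

    def num_blocks(self, req):
        start, end = self.token_range(req)
        return blocks_covering(start, end, self.block_size)


class SelfAttentionManager(LayerManager):
    def token_range(self, req):
        return 0, req.num_tokens


class SlidingWindowManager(LayerManager):
    def __init__(self, block_size, window):
        super().__init__(block_size)
        self.window = window

    def token_range(self, req):
        return max(0, req.num_tokens - self.window), req.num_tokens


class CrossAttentionManager(LayerManager):
    def token_range(self, req):
        return 0, req.num_encoder_tokens


def plan_allocation(managers, req):
    """Global allocator step: ask every manager how many pages it needs for req."""
    return {name: m.num_blocks(req) for name, m in managers.items()}


managers = {
    "self_attn": SelfAttentionManager(block_size=16),
    "sliding": SlidingWindowManager(block_size=16, window=32),
    "cross_attn": CrossAttentionManager(block_size=16),
}
print(plan_allocation(managers, Request(num_tokens=100, num_encoder_tokens=40)))
# {'self_attn': 7, 'sliding': 3, 'cross_attn': 3}
```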

To speed up per-layer memory allocation, we can group the layers so that all layers inside a group share the same block table. The groups must satisfy the following properties:

  • Layers in each group need to have the same number of tokens
  • All groups have the same number of layers

For example, mllama 11B has 8 cross-attention layers and 32 self-attention layers. They will be grouped into 5 groups: (cross-attn × 8), (self-attn × 8), (self-attn × 8), (self-attn × 8), (self-attn × 8), and each group has one memory manager.
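One way to form such groups, assumed here rather than specified by the RFC, is to use the GCD of the per-type layer counts as the group size. The sketch treats each layer-type string as one token-count class (two different sliding-window sizes would need two type strings). The mllama-11B-like layout with a cross-attention layer every fifth layer is an assumption that only affects the layer indices, not the group count.

```python
from functools import reduce
from math import gcd


def group_layers(layer_types):
    """Split layers into equal-sized groups, each containing a single layer type."""
    by_type = {}
    for idx, layer_type in enumerate(layer_types):
        by_type.setdefault(layer_type, []).append(idx)
    # Largest group size that evenly divides every per-type layer count.
    group_size = reduce(gcd, (len(idxs) for idxs in by_type.values()))
    groups = []
    for layer_type, idxs in by_type.items():
        for i in range(0, len(idxs), group_size):
            groups.append((layer_type, idxs[i:i + group_size]))
    return groups


# Assumed mllama-11B-like layout: 40 layers, every fifth one is cross-attention.
layer_types = ["cross_attn" if i % 5 == 3 else "self_attn" for i in range(40)]
groups = group_layers(layer_types)
print(len(groups), [t for t, _ in groups])
# 5 ['self_attn', 'self_attn', 'self_attn', 'self_attn', 'cross_attn']
```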

We can still use LRU to manage the cached blocks by putting the free blocks of all layers in a single queue and evicting the least recently used block across all layers.
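A sketch of a single LRU free queue shared by all layers, assuming equal page sizes (Milestone 1) so that any free page can serve any layer. `SharedLRUFreeQueue` and its methods are hypothetical; an `OrderedDict` stands in for vLLM's FreeKVCacheBlockQueue, whose actual interface is not reproduced here.

```python
from collections import OrderedDict


class SharedLRUFreeQueue:
    """Hypothetical free-block queue shared by every layer.

    Freed blocks keep their cached contents until they are reused; the front
    of the queue is the least recently freed block and is evicted first.
    """

    def __init__(self, num_blocks):
        self._queue = OrderedDict((b, None) for b in range(num_blocks))
        self.cached = {}  # block_id -> (layer_name, block_hash) for prefix-cache hits

    def free(self, block_id, layer_name=None, block_hash=None):
        if block_hash is not None:
            self.cached[block_id] = (layer_name, block_hash)
        self._queue[block_id] = None  # append at the most-recently-freed end

    def remove_for_reuse(self, block_id):
        """Prefix-cache hit: take a cached free block back out of the queue."""
        del self._queue[block_id]

    def allocate(self):
        # Evict the LRU block regardless of which layer owned it (KeyError if empty).
        block_id, _ = self._queue.popitem(last=False)
        self.cached.pop(block_id, None)  # its old contents are overwritten
        return block_id


q = SharedLRUFreeQueue(num_blocks=0)
q.free(7, "layer0", "hash-a")
q.free(3, "layer1", "hash-b")
q.free(5, "layer0", "hash-c")
q.remove_for_reuse(3)  # prefix-cache hit on layer1's block
print(q.allocate())    # 7: the least recently freed block, whichever layer owned it
```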

Moreover, we gain the freedom to customize the eviction strategy for each layer type:

  • Mamba layers only need to cache the state of the last tokens instead of the full KV cache.
    For the request [sys1 sys2 prompt1 prompt2 prompt3], we can evict sys1, prompt1, and prompt2 and cache only sys2 (the end of the system prompt) and prompt3 (the end of the request, for multi-turn conversation).
  • Sliding window layers only need to cache the tokens inside the sliding window.

The customized eviction strategies can be implemented within the memory manager of each type by carefully assigning LRU time tags (or by inserting blocks into the FreeKVCacheBlockQueue in a carefully chosen order).
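The ordering approach can be sketched as follows: when a request finishes, blocks that are no longer useful for prefix caching are appended to the free queue first, so LRU evicts them first. `eviction_order` and `sliding_window_keep` are hypothetical helpers; the Mamba usage mirrors the [sys1 sys2 prompt1 prompt2 prompt3] example above, treating each name as one block.

```python
def eviction_order(block_ids, keep):
    """Order for appending a finished request's blocks to the free queue.

    Blocks whose index is not in `keep` go first (evicted first); blocks worth
    caching go last, so under LRU they survive longest.
    """
    drop = [b for i, b in enumerate(block_ids) if i not in keep]
    retain = [b for i, b in enumerate(block_ids) if i in keep]
    return drop + retain


def sliding_window_keep(num_blocks, num_tokens, window, block_size):
    """Indices of blocks that overlap the last `window` tokens."""
    first = max(0, num_tokens - window) // block_size
    return set(range(first, num_blocks))


# Mamba example from the text: keep only the ends of the system prompt and the request.
blocks = ["sys1", "sys2", "prompt1", "prompt2", "prompt3"]
order = eviction_order(blocks, keep={1, 4})
print(order)  # ['sys1', 'prompt1', 'prompt2', 'sys2', 'prompt3']
```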

Milestone 2: allow allocation of pages with different sizes via LCM pages
In this step, we want to build a more general memory allocator that removes the same-kv_hidden_size assumption.

Some use cases are listed here:

  • Jamba
    • The Mamba state per layer is 256x the kv_hidden_size per token per layer
    • We can treat the Mamba state as a page whose size differs from the self-attention page size
  • Llava-style multi-modal models: v1 uses a VisionEncoderCache to cache the output of the vision encoder, and the vision embedding size is typically different from the per-token size of self-attention layers.
  • Lower allocation overhead for models in Milestone 1: even after grouping, the number of memory managers can be large, e.g., 5 managers in mllama 11B. If we can allocate pages with different sizes, we only need 2 managers: one for self-attention layers and one for cross-attention layers.

These use cases introduce different page sizes, so we need a new allocator for them.
The basic idea is to introduce a two-level page table:

  • Large page allocator: the page size is the LCM of the page sizes of the different memory types
  • Small page allocator: one memory manager per page type, which partitions “large” pages into “small” pages whose size is the page size of that type

For example, consider mllama with 2 cross-attention layers (holding KV cache for image tokens) and 3 self-attention layers (holding KV cache for text tokens). With kv_hidden_size 128, there are two page sizes (2*128=256 and 3*128=384), giving the following memory layout:
figure4
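A minimal sketch of the two-level scheme, using the page sizes from the mllama example (256 and 384, in the same units as above). `TwoLevelAllocator` is hypothetical: it carves a free large page into small pages of one type on demand and returns the large page to the first level once all of its small pages are free again. It requires Python 3.9+ for multi-argument `math.lcm`.

```python
from math import lcm  # Python 3.9+


class TwoLevelAllocator:
    """Hypothetical sketch: large pages of LCM size, carved into per-type small pages."""

    def __init__(self, page_sizes, num_large_pages):
        self.page_sizes = page_sizes                    # page type -> small page size
        self.large_size = lcm(*page_sizes.values())     # every small size divides this
        self.free_large = list(range(num_large_pages))
        self.free_small = {t: [] for t in page_sizes}   # type -> [(large_page, offset)]
        self.in_use = {}                                # large_page -> small pages handed out

    def allocate(self, page_type):
        free = self.free_small[page_type]
        if not free:
            if not self.free_large:
                raise MemoryError("no free large pages")
            large = self.free_large.pop(0)
            self.in_use[large] = 0
            step = self.page_sizes[page_type]
            free.extend((large, off) for off in range(0, self.large_size, step))
        page = free.pop(0)
        self.in_use[page[0]] += 1
        return page

    def free(self, page_type, page):
        large = page[0]
        self.free_small[page_type].append(page)
        self.in_use[large] -= 1
        if self.in_use[large] == 0:
            # All small pages of this large page are free: return it to the first level.
            self.free_small[page_type] = [p for p in self.free_small[page_type] if p[0] != large]
            del self.in_use[large]
            self.free_large.append(large)


alloc = TwoLevelAllocator({"cross_attn": 256, "self_attn": 384}, num_large_pages=4)
print(alloc.large_size)              # 768
print(alloc.allocate("cross_attn"))  # (0, 0): large page 0 split into 3 pages of 256
print(alloc.allocate("self_attn"))   # (1, 0): large page 1 split into 2 pages of 384
```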

Preliminary results
I’ve implemented a prototype on v0 and got the following results on an H100.
The speedup comes from both less fragmentation and better prefix caching.

figure5

Proposed Change.

The allocator will be implemented in v1, in the following steps:

  1. First PR: support sliding window attention and interleaved sliding window / self-attention:
    1. KV cache manager
      • One manager class per layer type, with customized allocate / free / prefix caching logic. The global allocator calls each per-type manager to perform allocation in one step.
    2. Runner.initialize_kv_cache
      • Instead of allocating a fixed-size tensor for each layer, we allocate one large tensor and let each layer have a different offset.
    3. AttentionMetadata
      • Build one per layer
      • Change AttentionMetadata to Dict[Attention.layer_name, AttentionMetadata]; this can be done inside class Attention and class FlashAttentionImpl and does not need model-side changes
  2. Add new allocators and attention metadata builders to support:
    1. MLA layer (after [WIP] Deepseek V2 MLA #10927)
    2. mamba layer
    3. KV cache sharing (needed by [New Model]: Support Tencent-Hunyuan-Large #10043 [New Model]: nvidia/Hymba-1.5B-Base #10783)
    4. cross attention layer (if we decide to support encoder-decoder in v1)
    5. embedding models (no LayerManager, so very little overhead)
  3. More features (further discussion needed)
    1. The two KV cache sizes in spec decode
    2. EncoderCacheManager for vision embedding
    3. Sparse KV cache (mentioned in [RFC]: Support KV Cache Compaction #10646)
  4. LCM Allocation
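The Runner.initialize_kv_cache and AttentionMetadata changes in the first PR can be sketched together: one flat buffer is carved into per-layer views at different offsets, and attention metadata becomes a dict keyed by layer name. A `bytearray`/`memoryview` stands in for the single large GPU tensor, and `initialize_kv_cache`, the simplified `AttentionMetadata`, and the layer names are hypothetical stand-ins, not vLLM's actual classes.

```python
from dataclasses import dataclass


@dataclass
class AttentionMetadata:  # hypothetical, heavily simplified
    block_table: list
    slot_mapping: list


def initialize_kv_cache(layer_num_pages, page_bytes):
    """Carve one flat buffer into per-layer views at different offsets.

    layer_num_pages: layer_name -> number of pages reserved for that layer.
    """
    total = sum(layer_num_pages.values()) * page_bytes
    buffer = bytearray(total)  # stand-in for one large GPU tensor
    view = memoryview(buffer)
    kv_caches, offset = {}, 0
    for name, num_pages in layer_num_pages.items():
        size = num_pages * page_bytes
        kv_caches[name] = view[offset:offset + size]  # zero-copy slice into the buffer
        offset += size
    return buffer, kv_caches


buffer, kv_caches = initialize_kv_cache({"layers.0.attn": 3, "layers.1.attn": 2}, page_bytes=64)
kv_caches["layers.1.attn"][0] = 1  # lands at byte offset 3 * 64 of the shared buffer
attn_metadata = {  # Dict[layer_name, AttentionMetadata]
    "layers.0.attn": AttentionMetadata(block_table=[0, 1, 2], slot_mapping=[0, 1]),
    "layers.1.attn": AttentionMetadata(block_table=[0, 1], slot_mapping=[0]),
}
```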

Feedback Period.

one week

CC List.

@comaniac @WoosukKwon @zhuohan123 @simon

Any Other Things.

No response
