
[V1] Move more control of kv cache initialization from model_executor to EngineCore #11960

Merged (22 commits) on Jan 17, 2025

Conversation

@heheda12345 (Collaborator) commented Jan 11, 2025

This PR changes the workflow of EngineCore._initialize_kv_caches to enable more flexible control over the kv cache format in the future.
It is split out from #11938 and is a preparation for #11382.
Original workflow:

num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks()
self.model_executor.initialize(num_gpu_blocks)

New workflow:

# Get all kv cache tensors needed by the model
kv_cache_spec = self.model_executor.get_kv_cache_spec()

# Profile the peak memory usage of the model to determine how much
# memory can be allocated for the kv cache.
available_gpu_memory = self.model_executor.get_available_memory()

# Get the kv cache tensor size
kv_cache_config, num_gpu_blocks = get_kv_cache_config(
    vllm_config, kv_cache_spec, available_gpu_memory)

# Initialize the kv cache and warm up the execution
self.model_executor.initialize(kv_cache_config)

This PR introduces two new concepts (a minimal sketch of both follows this list):

  1. KVCacheSpec, a data structure that represents the kv cache needed by each attention layer; it is constructed by asking the model runner to analyze all Attention modules. More Spec types will be added in the future, e.g., SlidingWindowSpec and MLASpec.
  2. KVCacheConfig, a class that represents how to allocate the kv cache tensors. It is quite simple now (all tensors have the same size), but it may be extended to the following cases:
    1. tensors with different sizes, to support MLA & spec decode
    2. allocating a global buffer and making the kv_cache tensors point to different offsets, so that multiple types of layer can share the same memory pool.
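
For illustration, here is a minimal sketch of what these two data structures might look like. The class name FullAttentionSpec and the field names (block_size, num_kv_heads, tensors, groups, ...) are assumptions for this example, not necessarily the exact definitions in vllm/v1/kv_cache_interface.py:

from dataclasses import dataclass
from typing import Dict, List

import torch


@dataclass
class FullAttentionSpec:
    """Illustrative kv cache spec of one full-attention layer."""
    block_size: int      # tokens per kv cache block
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype

    @property
    def page_size_bytes(self) -> int:
        # K and V tensors for one block of block_size tokens.
        dtype_size = torch.tensor([], dtype=self.dtype).element_size()
        return 2 * self.block_size * self.num_kv_heads * self.head_size * dtype_size

    @property
    def type_id(self) -> str:
        # Layers with the same type_id can be managed by the same policy.
        return f"full_attention_{self.block_size}_{self.page_size_bytes}"


# KVCacheSpec: layer_name -> spec of that layer's kv cache.
KVCacheSpec = Dict[str, FullAttentionSpec]


@dataclass
class KVCacheConfig:
    """Illustrative description of how to allocate the kv cache tensors."""
    # Size in bytes of the kv cache tensor of each layer (currently all equal).
    tensors: Dict[str, int]
    # Layer names grouped by identical spec; a single group when all layers
    # share the same spec.
    groups: List[List[str]]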

Signed-off-by: Chen Zhang <[email protected]>

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@heheda12345 changed the title from "[V1] Move more control of kv cache initialization from model_executor to to EngineCore" to "[V1] Move more control of kv cache initialization from model_executor to EngineCore" on Jan 11, 2025
@comaniac self-assigned this on Jan 11, 2025
@comaniac (Collaborator) left a comment:

Overall LGTM. Most comments are style changes. One important thing: please add docstrings to all functions and files in the desired format:

def func_name(arg1, arg2):
    """What does this function do?
    
    Args:
        arg1: ...
        arg2: ...
    
    Returns:
        ... (skip if return None) ...
    """

return kv_cache_config, num_gpu_blocks


def is_same_key(kv_cache_spec: KVCacheSpec) -> bool:
Collaborator:

Can we try to make this name more informative?

Collaborator Author:

KVCacheSpecBase.key -> KVCacheSpecBase.type_id
is_same_key -> is_same_type

vllm/v1/utils.py
Comment on lines 143 to 147
def bind_kv_cache(
    ctx: Dict[str, Any],
    runner_kv_caches: List[torch.Tensor],
    kv_caches: Dict[str, torch.Tensor],
) -> None:
Collaborator:

Should this be in kv_cache_utils.py, since it is kv cache related?

Collaborator Author:

bind_kv_cache is called by GPUModelRunner. I think it is strange to let GPUModelRunner call a function inside core. These two parts should be independent.
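
For readers unfamiliar with the helper, a minimal sketch of what such a binding function could do, based on the signature above. The layer-name pattern, the _extract_layer_index helper, and the kv_cache attribute on the forward-context entries are assumptions for this example, not the PR's actual code:

import re
from typing import Any, Dict, List

import torch


def _extract_layer_index(layer_name: str) -> int:
    # Assumption: layer names look like "model.layers.<i>.self_attn.attn".
    match = re.search(r"\.layers\.(\d+)\.", layer_name)
    assert match is not None, f"Cannot parse a layer index from {layer_name!r}"
    return int(match.group(1))


def bind_kv_cache(
    ctx: Dict[str, Any],
    runner_kv_caches: List[torch.Tensor],
    kv_caches: Dict[str, torch.Tensor],
) -> None:
    """Make the allocated kv cache tensors visible in two places:

    1. runner_kv_caches: the model runner's per-layer list, ordered by
       layer index, used when the runner iterates over layers.
    2. ctx: the forward context (layer_name -> attention module), so each
       attention module can read/write its own kv cache during forward.
    """
    # Fill the runner's list in layer-index order.
    for layer_name in sorted(kv_caches, key=_extract_layer_index):
        runner_kv_caches.append(kv_caches[layer_name])

    # Point each attention module in the forward context at its tensor.
    for layer_name, kv_cache in kv_caches.items():
        ctx[layer_name].kv_cache = kv_cache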

Comment on lines 61 to 63
# [group_id][layer_name in the group]. One group containing all
# layer_names if the Spec for kv_cache of all layers are the same
Collaborator:

I don't understand [group_id][layer_name in the group]. What's group ID? It might be better to just show an example.

# A list of kv-cache groups. Each group includes a set of layers with
# the same kv-cache spec. For example: ...

Collaborator Author:

Comments updated. Does it read better now?
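
For reference, an example of the grouping structure being documented here (the layer names are hypothetical):

# All four layers share the same kv cache spec -> a single group.
kv_cache_groups = [
    ["model.layers.0.attn", "model.layers.1.attn",
     "model.layers.2.attn", "model.layers.3.attn"],
]

# If some layers later get a different spec (e.g. sliding window),
# they would form a second group.
kv_cache_groups = [
    ["model.layers.0.attn", "model.layers.2.attn"],  # full attention
    ["model.layers.1.attn", "model.layers.3.attn"],  # sliding window
]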

heheda12345 and others added 9 commits January 13, 2025 23:25
@heheda12345 (Collaborator Author):

@comaniac Thank you for the review. I've updated the code based on your suggestions. Can you take another look?

@comaniac (Collaborator) left a comment:

This should be the last batch of comments. Approving to unblock the PR first.

f"initializing the engine.")


def is_same_type(kv_cache_spec: KVCacheSpec) -> bool:
Collaborator:

I feel this function name is a bit unclear. It actually checks whether the "kv cache specs" of "all" layers are the same, so it should be informative about "spec", "all" and "same". Maybe "is_uniformed_kv_cache_type" or something like that would be better.

Collaborator Author:

Changed to is_kv_cache_type_uniform.
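
A minimal sketch of what the renamed check might look like, assuming KVCacheSpec maps layer names to spec objects that expose a type_id (per the rename discussed earlier):

def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:
    """Whether all attention layers in the model share the same kv cache spec type."""
    type_ids = {layer_spec.type_id for layer_spec in kv_cache_spec.values()}
    return len(type_ids) == 1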

Comment on lines 405 to 406
def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec,
                        available_memory: int) -> Tuple[KVCacheConfig, int]:
Collaborator:

Again, this function name (and the callee _get_kv_cache_config_same_type) doesn't convey that it also returns the number of GPU blocks. It's weird to see config, num_blocks = get_kv_cache_config(...).

One possibility:

def get_kv_cache_config_and_available_blocks(...):
    check_enough_kv_cache_memory(...)

    # Later maybe you can introduce a registry when you have more policies.
    if is_uniformed_kv_cache_type(...):
        return _get_kv_cache_config_and_blocks_for_uniformed_type(...)
    return _get_kv_cache_config_and_blocks_for_xxx(...)

Collaborator Author:

Changed num_blocks to an attribute of KVCacheConfig
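
With num_blocks as an attribute, the call site from the PR description would change along these lines (illustrative):

# Before: the block count was returned alongside the config.
kv_cache_config, num_gpu_blocks = get_kv_cache_config(
    vllm_config, kv_cache_spec, available_gpu_memory)

# After: the block count travels with the config.
kv_cache_config = get_kv_cache_config(
    vllm_config, kv_cache_spec, available_gpu_memory)
num_gpu_blocks = kv_cache_config.num_blocks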

Comment on lines 118 to 120
kv_cache_specs = self.collective_rpc("get_kv_cache_spec")
assert all(lc == kv_cache_specs[0] for lc in kv_cache_specs)
return kv_cache_specs[0]
Collaborator:

How would this be extended later if you have different specs?

Collaborator Author:

It won't be extended. kv_cache_specs[i] is the spec for all layers of one GPU. Different TP GPUs always have the same spec, though the specs of individual layers inside one GPU can differ.
PP executors of different stages can have different specs.

vllm/v1/utils.py
Comment on lines 152 to 153
Bind kv_caches to the forward context and model_runner's kv_cache.
Args:
Collaborator:

Please elaborate more, because the concept of "binding" the kv cache is not familiar to most people. For example: bind what kv cache to what, and for what purpose?

vllm/v1/utils.py
Comment on lines 168 to 169
assert all(kv_caches[n] is kv_caches[layer_name]
           for n in layer_names[1:])
Collaborator:

Add a comment if you're going to extend this logic for xxx; otherwise it looks a bit weird.

Collaborator Author:

Changed to raise an error if multiple attention layers have the same layer index.
layer_name and layer_index are defined by the model implementation, not by this PR. There will be no layer_index conflicts in decoder-only models.
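
A sketch of the revised check, reusing the hypothetical _extract_layer_index helper from the bind_kv_cache sketch above (not the PR's actual code):

layer_indices = [_extract_layer_index(name) for name in kv_caches]
if len(set(layer_indices)) != len(layer_indices):
    raise ValueError(
        "Multiple attention layers map to the same layer index; "
        "this is not expected for decoder-only models.")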

@comaniac added the "ready" label (ONLY add when PR is ready to merge/full CI is needed) on Jan 16, 2025
@comaniac (Collaborator) left a comment:

LGTM. Please also sync with the main branch; the SPMD PR may have changes incompatible with this one.

Args:
kv_cache_spec (KVCacheSpec): The KVCacheSpec of the model

Returns:
Collaborator:

Indent

@heheda12345 (Collaborator Author) commented Jan 16, 2025

@comaniac Thanks for the review. I've updated this PR.

@comaniac enabled auto-merge (squash) on January 16, 2025 17:26
@heheda12345 (Collaborator Author):

Trying to fix the CI failure with #12138

@comaniac merged commit 69d765f into vllm-project:main on Jan 17, 2025
54 checks passed
Member:

This is not really v1-specific; I think we should just make it work for both v0 and v1 directly. I really hate code duplication: it makes later bugfixes and v0/v1-agnostic features difficult to develop.

Member:

this is heavy code duplication.

@@ -134,3 +141,48 @@ def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str):
        socket_file = ipc_socket.replace("ipc://", "")
        if os and os.path.exists(socket_file):
            os.remove(socket_file)


def bind_kv_cache(
Member:

this is heavy code duplication, too.

        self._run_workers("compile_or_warm_up_model")

    def get_kv_cache_spec(self) -> KVCacheSpec:
Member:

This executor is not used. You should add it in vllm/executor/executor_base.py, and do not use _run_workers; use collective_rpc instead.
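
A sketch of the suggested direction, mirroring the collective_rpc pattern shown in the multiproc executor snippet earlier; the placement in ExecutorBase is the reviewer's suggestion, not code from this PR:

class ExecutorBase:
    # ... existing executor methods ...

    def get_kv_cache_spec(self) -> KVCacheSpec:
        """Ask every worker for its kv cache spec; all TP workers must agree."""
        kv_cache_specs = self.collective_rpc("get_kv_cache_spec")
        assert all(spec == kv_cache_specs[0] for spec in kv_cache_specs)
        return kv_cache_specs[0]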
