Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] Add prefix-caching support for phi-3-small-8k/128k model triton kernel #8345

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

congcongchen123
Copy link

@congcongchen123 congcongchen123 commented Sep 10, 2024

This PR is a feature enrichment of #4799 which introduces blocksparse flash attention, Microsoft Phi-3-Small-8K and Phi-3-Small-128K models.

This PR modifies the block-sparse attention prefill Triton kernel to add prefix-caching support.

  • Tested utilizing tool: benchmarks/benchmark_prefix_caching.py --model=microsoft/Phi-3-small-8k-instruct --enable-prefix-caching with different prompts, and verified that the output is correct.
  • Added unit test.
  • Benchmark phi3-small 8k/128k model with prefix-caching enabled, and verified that it has achieved significant TTFT latency gain with different prefix-cache hit rates.

Thanks for offline discussion @linxihui, @wschin !

FIX #xxxx (link existing issues this PR will resolve)

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@congcongchen123 congcongchen123 marked this pull request as ready for review September 10, 2024 21:46
@wschin
Copy link
Contributor

wschin commented Sep 12, 2024

@comaniac, would you mind find someone to review this PR? This is an important model. If reviewer feels reviewing kernel takes too long, maybe requesting more numerical tests and profiling result? Many thanks.

@comaniac
Copy link
Collaborator

@comaniac, would you mind find someone to review this PR? This is an important model. If reviewer feels reviewing kernel takes too long, maybe requesting more numerical tests and profiling result? Many thanks.

I'll review it ASAP. @mgoin @tlrmchlsmth @pcmoritz could any of you also help review the kernel implementation?
Meanwhile, numerical tests are definitely required, and do you need to tune the kernel configs offline like we did for the MoE kernels?

@@ -3,6 +3,407 @@
import triton.language as tl


@triton.jit
def _context_fwd_kernel_inner(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's better to write done the math formulation implemented by this kernel. The same for other kernel functions.



@torch.inference_mode()
def context_blocksparse_flash_attn_varlen_fwd(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If math equation is added, people can improve and replace it easier in the future.

Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made a pass through the PR. I left a couple of superficial comments and had a question on CUDA graphs.

Also congcongchen123 looks like you need to run bash format.sh

@@ -49,6 +51,10 @@ def __init__(
self.use_spda = use_spda
self.dtype = dtype
self.device = device
# block size used for blocksparse attention, used in `local_blocks`,
# `vert_stride`.
# It is different from kv_cache block size which is he size of of a
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: "he size of of" -> "the size of"

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually can we avoid use block_size to reduce confusion? And how this size usually be set? Is it set via LLM engine config or a predefined constant?

Comment on lines +334 to +335
query_start_block_id = (context_lens_tensor // sparse_block_size).cpu()
query_end_block_id = ((seq_lens_tensor - 1) // sparse_block_size).cpu()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does moving these to the CPU break CUDA graphs?

k_cache.stride(2),
k_cache.stride(3),
k_cache.stride(
4), #[num_blocks, num_kv_heads, head_size/x, block_size, x]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Could you take a look at the comments here? It looks like the formatter made this a little hard to read

context_len,
offs_n)
# flash-attn 2
# m_i += tl.math.log2(l_i)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this debug cruft?

Comment on lines +225 to +227
k = tl.load(K_cache + off_k,
mask=(start_n + offs_n[None, :]) < context_len,
other=0.0) # [D,N]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

based on a lack of masking along N here, looks like we require that the sparse block size divides the hidden dim? I think we should assert that somewhere if we don't already. I acknowledge this is paranoia.

@@ -49,6 +51,10 @@ def __init__(
self.use_spda = use_spda
self.dtype = dtype
self.device = device
# block size used for blocksparse attention, used in `local_blocks`,
# `vert_stride`.
# It is different from kv_cache block size which is he size of of a
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually can we avoid use block_size to reduce confusion? And how this size usually be set? Is it set via LLM engine config or a predefined constant?

context_lens_tensor: torch.Tensor,
sm_scale: float,
):
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also add descriptions about what this function does?

Comment on lines +261 to +281
query, key, value: shape = (num_tokens, num_heads_q/kv, head_size).
Support grouped attention, with `q[:, i*r:(i*r + r)]`
is correspondent to `k[:, i]`, where `r` is the q/k ratio.
key_cache: shape = (num_blocks, num_kv_heads,
head_size // x, kv_cache_block_size, x),
where x is defined in paged_attn.py.
value_cache: shape = (num_blocks, num_kv_heads, head_size,
kv_cache_block_size).
block_tables: shape = (batch_size, num_blocks).
query_start_loc: shape = (batch_size + 1,).
The cumulative subquery lengths of the sequences in
the batch, used to index into subquery. E.g., if
the subquery length is [4, 6], it is [0, 4, 10].
seq_lens_tensor: shape=(batch_size + 1,).
The sequence length per sequence.
context_lens_tensor: shape=(batch_size,).
The context length per sequence (tokens that are
computed so far).
sm_scale: softmax scale, default to 1/sqrt(head_size).

return: tensor of shape as q.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
query, key, value: shape = (num_tokens, num_heads_q/kv, head_size).
Support grouped attention, with `q[:, i*r:(i*r + r)]`
is correspondent to `k[:, i]`, where `r` is the q/k ratio.
key_cache: shape = (num_blocks, num_kv_heads,
head_size // x, kv_cache_block_size, x),
where x is defined in paged_attn.py.
value_cache: shape = (num_blocks, num_kv_heads, head_size,
kv_cache_block_size).
block_tables: shape = (batch_size, num_blocks).
query_start_loc: shape = (batch_size + 1,).
The cumulative subquery lengths of the sequences in
the batch, used to index into subquery. E.g., if
the subquery length is [4, 6], it is [0, 4, 10].
seq_lens_tensor: shape=(batch_size + 1,).
The sequence length per sequence.
context_lens_tensor: shape=(batch_size,).
The context length per sequence (tokens that are
computed so far).
sm_scale: softmax scale, default to 1/sqrt(head_size).
return: tensor of shape as q.
Args:
query, key, value: shape = (num_tokens, num_heads_q/kv, head_size).
Support grouped attention, with `q[:, i*r:(i*r + r)]`
is correspondent to `k[:, i]`, where `r` is the q/k ratio.
key_cache: shape = (num_blocks, num_kv_heads,
head_size // x, kv_cache_block_size, x),
where x is defined in paged_attn.py.
value_cache: shape = (num_blocks, num_kv_heads, head_size,
kv_cache_block_size).
block_tables: shape = (batch_size, num_blocks).
query_start_loc: shape = (batch_size + 1,).
The cumulative subquery lengths of the sequences in
the batch, used to index into subquery. E.g., if
the subquery length is [4, 6], it is [0, 4, 10].
seq_lens_tensor: shape=(batch_size + 1,).
The sequence length per sequence.
context_lens_tensor: shape=(batch_size,).
The context length per sequence (tokens that are
computed so far).
sm_scale: softmax scale, default to 1/sqrt(head_size).
Return: tensor of shape as q.


return: tensor of shape as q.
"""
assert (not self.use_spda), "forward_prefix does not support spda"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert (not self.use_spda), "forward_prefix does not support spda"
if self.use_spda:
raise ValueError("Prefix caching with block sparse attention does not support SPDA")

Comment on lines +290 to +292
assert (
self.block_size == self.q_block_size
or self.q_block_size is None), "Different block size not supported"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert (
self.block_size == self.q_block_size
or self.q_block_size is None), "Different block size not supported"
if self.block_size != self.q_block_size and self.q_block_size is not None:
raise NotImplementedError("Different block sizes are not supported")

Comment on lines +475 to +483
if torch.cuda.is_available():
torch.cuda.manual_seed(0)
torch.set_default_device(device)

# Need this, otherwise when we capture the graph the process
# for GPU 1 would run on both GPU0 and GPU1 and things would hang
#
# see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
torch.cuda.set_device(device)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have a better for to deal with this in tests and CI?
cc @youkaichao

@pytest.mark.parametrize("blocksparse_block_size", BLOCKSPARSE_BLOCK_SIZES)
@pytest.mark.parametrize("blocksparse_homo_heads", BLOCKSPARSE_HOMO_HEADS)
@torch.inference_mode()
def test_contexted_kv_attention(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you measure how long does this test take roughly?

@simon-mo simon-mo requested a review from WoosukKwon as a code owner November 26, 2024 05:49
Copy link

mergify bot commented Nov 26, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @congcongchen123.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 26, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants