
[misc] Add LoRA kernel micro benchmarks #11579

Merged · 17 commits · Jan 16, 2025

Conversation

Contributor
@varun-sundar-rabindranath commented Dec 28, 2024

Add LoRA kernel micro benchmarks for tuning/optimizing LoRA kernels

  • The benchmarking script creates a pool of tensors for each kernel argument and uses the tensors in order during benchmarking. A bigger argument pool helps mitigate caching effects during benchmarking (see the sketch after this list).
  • The benchmarking script can run the kernels inside a CUDA graph. This is particularly useful for benchmarking Triton kernels due to their launch overhead.
  • The benchmarking script also benchmarks torch.mm as a baseline.
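As a rough illustration of the argument-pool idea, here is a minimal sketch; it is not the PR's actual implementation, and the pool size, tensor shapes, and helper names are made up for illustration:

import torch

# Hypothetical sketch: pre-allocate a pool of distinct argument tensors and
# cycle through them so that back-to-back benchmark iterations do not keep
# re-reading the same cache-resident buffers.
ARG_POOL_SIZE = 32  # mirrors the --arg-pool-size CLI argument

arg_pool = [
    torch.randn(4096, 16, device="cuda", dtype=torch.float16)
    for _ in range(ARG_POOL_SIZE)
]

def run_from_pool(kernel_fn, num_iters: int) -> None:
    # Use the pooled tensors in order, wrapping around when the pool is exhausted.
    for i in range(num_iters):
        kernel_fn(arg_pool[i % ARG_POOL_SIZE])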

Added a utils.py in benchmarks/kernels/ that implements a Bench class. The Bench class is general enough to reuse in future benchmark implementations.
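For context, a minimal sketch of what a Bench-style context manager could look like on top of torch.utils.benchmark; the real class in benchmarks/kernels/utils.py also handles CUDA graph capture and argument pools, and its constructor differs (names below are illustrative only):

import torch.utils.benchmark as TBenchmark

class SimpleBench:
    """Illustrative stand-in for the PR's Bench class (not its real API)."""

    def __init__(self, label: str, sub_label: str, description: str, fn, **kwargs):
        self.timer = TBenchmark.Timer(
            stmt="fn(**kwargs)",
            globals={"fn": fn, "kwargs": kwargs},
            label=label,
            sub_label=sub_label,
            description=description,
        )

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        return False  # do not suppress exceptions

    def run(self, min_run_time: float = 1.0):
        # blocked_autorange repeats the statement until roughly min_run_time
        # seconds of measurements have been collected.
        return self.timer.blocked_autorange(min_run_time=min_run_time)

Measurements produced this way can be grouped and printed with TBenchmark.Compare, which yields tables similar to the console output shown later in this thread.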

The benchmarking script can run in one of three modes:

  1. range_bench
    Example: python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand sgmv_expand_slice bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8

Use this mode to benchmark a range of hidden-dimension sizes and LoRA ranks.

  2. list_bench
    Example: python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 2049 4096 8192 --lora-ranks 2 8 16 20 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand sgmv_expand_slice bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32

When range benchmarking is too restrictive, use this version to simply list the hidden-dimension sizes and lora-rank values.

  3. model_bench
    Example: python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand sgmv_expand_slice bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32

Specify a model so the benchmark uses that model's weight shapes, which helps in understanding model execution performance.

Some benchmarks were run on main using:

NUM_LORAS=(4)
BATCH_SIZES=(16 128 256 512 1024 2048 8192)
HIDDEN_SIZES=(1024 2048 4096 8192 16384)
RANKS=(16)

echo "Benchmarking bgmv punica kernels ..."
python3 benchmarks/kernels/benchmark_lora.py list_bench --dtype torch.float16 --arg-pool-size 32 --with-cuda-graph --num-loras ${NUM_LORAS[@]} --op-types bgmv_shrink bgmv_expand --seq-lengths 1 --hidden-sizes ${HIDDEN_SIZES[@]} --batch-sizes ${BATCH_SIZES[@]} --sort-by-lora-id 1

echo "Benchmarking sgmv punica kernels ..."
python3 benchmarks/kernels/benchmark_lora.py list_bench --dtype torch.float16 --arg-pool-size 32 --with-cuda-graph --num-loras ${NUM_LORAS[@]} --op-types sgmv_shrink sgmv_expand --seq-lengths 8 --hidden-sizes ${HIDDEN_SIZES[@]} --batch-sizes ${BATCH_SIZES[@]} --sort-by-lora-id 1

The results, collated afterwards, can be found here: https://docs.google.com/spreadsheets/d/16iA8nZyuhfOctNg6KSJ1Y0Ve5udZKDOMsiDYDORNyks/edit?usp=sharing


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@varun-sundar-rabindranath
Contributor Author

@jeejeelee This PR adds some tooling for benchmarking LoRA kernels. Should be useful for further optimizing LoRA kernels and for #11234 . Note that this PR emulates the *_expand_slice operations by calling the kernels back-to-back like in the tests. However, the change should be simple enough to support #11234. PTAL.

@mgoin fyi

seq_len_timers.append(
    bench_optype(_ctx, args.arg_pool_size, bench_op,
                 args.with_cuda_graph))
Collaborator
Perhaps we need to ensure the compute results are aligned

Contributor Author

For expand-related operations with add_inputs=True, testing the benchmarking results for correctness is hard because the function is run an indeterminate number of times.

I have added a test_correctness method to BenchmarkTensor class that can be invoked with a CLI argument --test-correctness. Note that this tests for correctness before the benchmarking is run. This should give us enough confidence about the validity of the results.

What do you think?

Collaborator

I'm particularly surprised by the execution times in the table, especially the result shown in A164. SGMV shouldn't be this slow, so I think we should first verify that the calculation results are correct.

Contributor Author

Thanks @jeejeelee

The numbers in the table are timings for 32 consecutive invocations of the benchmarking function, run from inside a cuda graph. Dividing the timings by 32 yields the time per single invocation. Sorry about the confusion, I should have mentioned this earlier. I have added comments and print statements in the code to make this clear.

When run in cuda graph mode, the graph is captured with N invocations of the benchmarking function (see the with torch.cuda.stream(stream): block in the script). The reported time is the time taken for a single graph replay (see the return TBenchmark.Timer(...) call).
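As a hedged sketch of that capture-then-replay scheme (warm-up, argument pools, and error handling are omitted, and kernel_fn is a placeholder for the benchmarked op, not a function from the PR):

import torch
import torch.utils.benchmark as TBenchmark

def time_n_ops_in_graph(kernel_fn, n_ops: int = 32):
    stream = torch.cuda.Stream()
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.stream(stream):
        # Capture n_ops back-to-back invocations into a single CUDA graph.
        with torch.cuda.graph(graph, stream=stream):
            for _ in range(n_ops):
                kernel_fn()
    # The timed statement is one graph replay, i.e. n_ops kernel invocations
    # without per-launch overhead; divide the result by n_ops to get the
    # per-invocation time.
    return TBenchmark.Timer(
        stmt="graph.replay()",
        globals={"graph": graph},
        label="cuda-graph replay",
        description=f"{n_ops} ops per replay",
    ).blocked_autorange(min_run_time=1.0)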

I ran the benchmarks for SGMV expand again - Please look at rows 51 to 91 for the normalized timings.
https://docs.google.com/spreadsheets/d/1gSUNdZ08H-057SUnxeWhPKBWrg5Hc3QkRJS6YnAq6_E/edit?usp=sharing

  • In the table, for smaller problem shapes, you can see how the normalized cuda graph timings do not include the Triton kernel launch overheads.
  • In the table, you can also see how having a bigger pool of arguments helps mitigate caching effects during benchmarking.

About testing: I have added the functionality to test the outputs after the benchmarking run anyway 👍

Collaborator

Let me use {"bs": 8192, "sl": 8, "m": 65536, "k": 16, "n": 16384, "num_loras": 4, "sort_by_lora": true, "num_slices": 1} as an example. If I understand correctly, it would require 8192 executions of torch.mm versus 1 execution of the triton kernel, right?

Contributor Author

Not sure I understand, but for

{"bs": 8192, "sl": 8, "m": 65536, "k": 16, "n": 16384, "num_loras": 4, "sort_by_lora": true, "num_slices": 1}

torch.mm's M, K, N are "m": 65536, "k": 16, "n": 16384, i.e. A (65536 x 16) x B (16 x 16384) = C (65536 x 16384). It is a single torch.mm call that computes all of C.

For LoRA, the A matrix is the same size, but we have 4 B matrices (LoRA weights) and C is produced based on the LoRA ID mapping. Again, it is a single triton kernel call that computes all of C.
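To make the shapes concrete, a rough sketch of the two computations being compared (scaled-down sizes so it runs on a small GPU; the Python loop over LoRA IDs is only a readable emulation of what the single Triton kernel launch does, not how the punica kernels are implemented):

import torch

# Scaled-down stand-ins for the example's m=65536, k=16, n=16384, num_loras=4.
m, k, n, num_loras = 4096, 16, 1024, 4

A = torch.randn(m, k, device="cuda", dtype=torch.float16)
B = torch.randn(k, n, device="cuda", dtype=torch.float16)
lora_B = torch.randn(num_loras, k, n, device="cuda", dtype=torch.float16)
lora_ids = torch.randint(0, num_loras, (m,), device="cuda")

# Roofline: one torch.mm call computes all of C against a single B matrix.
C_roofline = torch.mm(A, B)

# Multi-LoRA: each row of A is multiplied by the B matrix of its LoRA ID.
# The punica triton kernel does this in one launch; the loop below is just a
# slow reference computation producing the same result.
C_lora = torch.empty(m, n, device="cuda", dtype=torch.float16)
for lora_id in range(num_loras):
    rows = lora_ids == lora_id
    C_lora[rows] = A[rows] @ lora_B[lora_id]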

num_ops_in_cuda_graph=arg_pool_size) if with_cuda_graph else None
with Bench(cuda_graph_params, ctx.bench_label(),
           ctx.bench_sublabel(op_type), description, torch.mm,
           **mm_kwargs) as bench:
Collaborator

QQ: Does torch.mm support grouped GEMM? If not, as a baseline, how does it compute the multi-LoRA GEMM?

Contributor Author

AFAIK, it does not. I meant for the torch.mm (just a matmul) benchmark to serve as a roofline. Sorry about the confusion; I have renamed the functions and added a comment.

    'max_seq_length': max_seq_len,
    'token_nums': num_tokens,
    'add_inputs': True,
}
Collaborator
@jeejeelee commented Dec 31, 2024

If add_inputs is True, the expand-related kernel performs group-gemm plus accumulation into the existing outputs, rather than just group-gemm alone.

Contributor Author

That was intentional, so that we benchmark the most-used and most expensive version. But I see the value in passing this via the CLI; I have added an --expand-fn-add-inputs argument to the CLI.
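For reference, the behavioral difference being discussed, sketched with plain torch ops (shapes are arbitrary and the real semantics live in the expand kernels; this is only an approximation of the add_inputs flag):

import torch

A = torch.randn(256, 16, device="cuda", dtype=torch.float16)
B = torch.randn(16, 2048, device="cuda", dtype=torch.float16)
out = torch.zeros(256, 2048, device="cuda", dtype=torch.float16)

# add_inputs=False (roughly): the expand result overwrites the output buffer.
out.copy_(A @ B)

# add_inputs=True (roughly): the expand result is accumulated into the
# existing output, so repeatedly invoking the op keeps changing `out`.
out += A @ B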

"case. It is provided as a roofline for comparing our LoRA Kernel "
"implementations. It is expected that the LoRA kernels will be "
"slower than torch.mm in cases where num_loras is big. But for "
"small num_loras the goal should be to match the torch.mm numbers.")
Contributor Author

@jeejeelee I have added this note on how to interpret the torch.mm numbers. The console output looks like

== All Results ====
[---------------------------------------------------------------------------------------- lora-torch.float16 | cugraph 32 ops ----------------------------------------------------------------------------------------]
                                                                                                                 |  single-lora roofline using torch.mm (f16xf16=>f16)  |  SGMV_EXPAND(add_inputs=False) (f32xf16=>f16)
1 threads: ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      {"bs": 16, "sl": 16, "m": 256, "k": 16, "n": 2048, "num_loras": 4, "sort_by_lora": true, "num_slices": 1}  |                        132.6                         |                     174.5                    

Times are in microseconds (us).

Note : The timings reported above is for 32 consecutive invocations of the benchmarking functions. Please divide by 32 for single invocation timings 
Note on Comparison with torch.mm : The torch.mm numbers are benchmark numbers of a simple matmul emulating the single lora case. It is provided as a roofline for comparing our LoRA Kernel implementations. It is expected that the LoRA kernels will be slower than torch.mm in cases where num_loras is big. But for small num_loras the goal should be to match the torch.mm numbers.
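For example, normalizing the SGMV_EXPAND cell of the sample row above is plain arithmetic, using the numbers printed in the sample output:

CUDA_GRAPH_NOPS = 32                  # --cuda-graph-nops used for the run
reported_us = 174.5                   # SGMV_EXPAND cell in the sample row
per_invocation_us = reported_us / CUDA_GRAPH_NOPS   # ~5.45 us per kernel call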

Collaborator

Can we output this information just once?

assert all([
    bt.test_correctness(op_type, expand_fn_add_inputs)
    for bt in bench_tensors
])
Contributor Author

@jeejeelee I have removed the correctness testing of benchmarking results. Instead, we now test the benchmarking function before running the benchmarks.

Matching the outputs of the benchmarking runs is very flaky and intractable. The root cause is the updates to the output matrices combined with the fact that the benchmarking script can run the benchmarking function multiple times:
  • For expand-related functions, when add_inputs=True, the output matrix is updated an arbitrary number of times, making correctness testing intractable.
  • For shrink functions, depending on whether SPLIT_K is used in the kernels, the results are either added to the output or stored directly. When results are added to the output, correctness testing becomes intractable.
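A hedged sketch of the test-before-benchmarking pattern described above (the real check lives on the BenchmarkTensors class and compares against vLLM's reference path; the function and argument names below are illustrative only):

import torch

def check_before_benchmark(kernel_fn, reference_fn, inputs, out,
                           atol: float = 1e-2, rtol: float = 1e-2) -> None:
    """Run the op once on copies of the output and compare against a reference.

    This is done before the benchmark loop because the benchmark may invoke
    the op an indeterminate number of times, repeatedly accumulating into
    the output buffer.
    """
    out_kernel = out.clone()
    out_ref = out.clone()
    kernel_fn(*inputs, out_kernel)
    reference_fn(*inputs, out_ref)
    torch.testing.assert_close(out_kernel, out_ref, atol=atol, rtol=rtol)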

Collaborator
@jeejeelee left a comment

Thank you for your contribution and patience. Overall LGTM after completing the modifications below.

from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.lora.ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.utils import FlexibleArgumentParser
Collaborator

We recently merged #11100; these imports need to be reimplemented.

Contributor Author

Thanks for the heads up @jeejeelee . I have fixed it 🙌

Varun Sundar Rabindranath added 12 commits January 16, 2025 10:54 (each Signed-off-by: Varun Sundar Rabindranath).
Varun Sundar Rabindranath added 5 commits January 16, 2025 10:54, including "test only benchmark tensors that participated" (each Signed-off-by: Varun Sundar Rabindranath).
@mgoin mgoin enabled auto-merge (squash) January 16, 2025 15:10
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 16, 2025
@mgoin mgoin merged commit 5fd24ec into vllm-project:main Jan 16, 2025
48 checks passed