[RFC]: Refactor FP8 kv-cache #4532
Comments
Thanks a lot for putting together this RFC! This sounds like a solid plan to me. Some more detailed comments:
I plan to support both cases. The JSON format is only for the case where people don't want an FP8 model but only an FP8 kv-cache (e.g., on pre-Hopper GPUs). In this case the model checkpoint is still in FP16, so we need a way to pass scaling factors anyway. For FP8 models we definitely load the scaling factors from the checkpoints directly.
I think I've tried several ways but haven't had luck so far (maybe I'm also missing something). I suppose the main reason is that kv quantization happens during kv-cache writes (cache_kernel) and reads (attention_kernel), so the quantization kernel has to be dispatched and invoked inside those kernels. That's why we have to pass the data type all the way down to these kernels. Additionally, since we now use …
But we could have an fp16 model with scales for the kv cache saved in the safetensors files:

- `model.layers.0.attn.k_proj.weight`
- `model.layers.0.attn.k_cache.act_scale`

So we would therefore not need to support the JSON case.
I'm +1 to supporting activation scales in the FP16 checkpoint and not in JSON. This way fewer configurations need to be supported and everything is uniform :)
Sounds good! I get why the data type needs to be passed through; what I don't really get is why "auto" needs to be handled in C++ -- it seems to me it could be mapped to a concrete type in Python, and then C++ only needs to handle the cases of concrete types. But I might be wrong, feel free to do what feels most natural while implementing this :)
Thanks for the valuable feedback!
Then we could have: …
Ah I see, in that case, …
w.r.t. …
w.r.t. …
Over time, it makes more sense to rule out …
The 2nd PR for #4532 (vllm-project#4893). This PR supports loading FP8 kv-cache scaling factors from an FP8 checkpoint (with the .kv_scale parameter).
Motivation.
Support float8_e4m3 for NVIDIA GPUs: The current FP8 kv-cache supports e5m2 on NVIDIA GPUs and e4m3 on AMD GPUs. While e5m2 seems to be an ideal format for kv-cache storage due to better performance (e5m2 has the same number of exponent bits as fp16, so its scaling overhead could be smaller than e4m3's), model checkpoints with FP8 weights mostly adopt the e4m3 format. As a result, in order to support FP8 models in vLLM, we need a float8_e4m3 kv-cache for both vendors.
Refactor FP8 kv-cache: In the current implementation, we use macros to dispatch FP8 tensor quantization logic in vLLM custom kernels. For example, in the current `cache_kernel.cu`, we use the following kind of code to quantize tensors for CUDA or ROCm.
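A minimal, self-contained C++ sketch of the kind of dispatch described here; the build-flag names (`ENABLE_FP8_E5M2`, `ENABLE_FP8_E4M3`) and the launcher stub are assumptions for this illustration, not the actual vLLM code:

```cpp
// Illustrative sketch only: a stand-in for the macro-based dispatch described
// above, not the actual cache_kernel.cu code.
#include <cstdint>
#include <stdexcept>
#include <string>

// Stand-in for the real kernel launcher.
template <typename SrcT, typename CacheT, bool kIsFp8>
void launch_reshape_and_cache() { /* kernel launch elided */ }

// Stand-in for the CALL_RESHAPE_AND_CACHE macro in cache_kernel.cu.
#define CALL_RESHAPE_AND_CACHE(SRC_T, CACHE_T, IS_FP8) \
  launch_reshape_and_cache<SRC_T, CACHE_T, IS_FP8>()

void reshape_and_cache_dispatch(const std::string& kv_cache_dtype) {
  if (kv_cache_dtype == "auto") {
    CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
#if defined(ENABLE_FP8_E5M2)    // assumed NVIDIA (CUDA) build flag
  } else if (kv_cache_dtype == "fp8") {
    // Which FP8 format "fp8" means is baked in by the build flag above.
    CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
#elif defined(ENABLE_FP8_E4M3)  // assumed AMD (ROCm) build flag
  } else if (kv_cache_dtype == "fp8") {
    CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
#endif
  } else {
    throw std::runtime_error("Unsupported kv_cache_dtype: " + kv_cache_dtype);
  }
}
```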
This creates the following issue: `cache_kernel.cu` has to know which FP8 format each vendor supports and carry the corresponding dispatching logic (so does `attention_kernel.cu`).

Proposed Change.
User Interface
`fp8` aliases to `float8_e4m3` for all vendors. `fp8_e4m3` and `fp8_e5m2` are also available for users. We let the vendor backend throw an error if a particular format is not supported in the current vLLM build.
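A minimal sketch of this aliasing logic; the enum and function names below (`Fp8KVCacheDataType`, `resolve_kv_cache_dtype`) are placeholders for illustration rather than vLLM's actual API:

```cpp
// Sketch of how the user-facing strings could map to concrete formats.
#include <stdexcept>
#include <string>

// Assumed enum; the real name and members in vLLM may differ.
enum class Fp8KVCacheDataType { kAuto, kFp8E4M3, kFp8E5M2 };

Fp8KVCacheDataType resolve_kv_cache_dtype(const std::string& dtype) {
  if (dtype == "auto") return Fp8KVCacheDataType::kAuto;
  // "fp8" aliases to float8_e4m3 for all vendors.
  if (dtype == "fp8" || dtype == "fp8_e4m3") return Fp8KVCacheDataType::kFp8E4M3;
  if (dtype == "fp8_e5m2") return Fp8KVCacheDataType::kFp8E5M2;
  // Whether a resolved format is actually usable is still up to the vendor
  // backend, which raises an error if the current build does not support it.
  throw std::invalid_argument("Unknown kv-cache dtype: " + dtype);
}
```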
Backend Interface
The vendor-specific FP8 data types and conversion logic live in `csrc/quantization/fp8/<vendor>/quant_utils.cuh`. Its interface to `cache_kernel.cu` and `attention_kernel.cu` is as follows. Note that we currently only have 2 GPU vendors, so I suppose this is still acceptable.

- A wrapper function `convert` that converts `Tin` to `Tout` with a particular FP8 format. For example, when writing values to the kv-cache, `Tin=uint16_t`, `Tout=uint8_t`, and `kv_dt=kFp8E4M3` (see the sketch after this list). In this way, vendors have full control over which formats are supported.
- A wrapper function `scaled_convert`, which is similar to `convert` but accepts a static scaling factor.
- The macro `DISPATCH_BY_KV_CACHE_DTYPE(FN)` that dispatches the `FN` macro based on the kv-cache data type (for example, see the sketch after this list). With this interface, the supported data types are transparent to each vendor backend; it can throw a runtime error when an unsupported data type is requested. Also, `FN` would be a macro defined in vendor-independent files, such as `CALL_RESHAPE_AND_CACHE` in `cache_kernel.cu` and `CALL_V2_LAUNCHER_BLOCK_SIZE` in `attention_kernel.cu`. The macro should accept 3 arguments in order: input data type (e.g., `uint16_t`), kv-cache data type (e.g., `uint8_t`), and real kv-cache data type (e.g., `kFp8E4M3`).
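A hedged sketch that pulls the pieces above together: simplified `convert` / `scaled_convert` signatures as they might appear in a vendor's `quant_utils.cuh`, plus a simplified `DISPATCH_BY_KV_CACHE_DTYPE`. The `Fp8KVCacheDataType` enum, the `DEVICE` shim, the stub bodies, and fixing the input type to `uint16_t` are assumptions for this illustration; the real interface may differ:

```cpp
// Hedged sketch of the proposed backend interface; not the final signatures.
#include <cstdint>
#include <stdexcept>
#include <string>

// Under nvcc/hipcc these helpers would be __device__; the shim also lets the
// sketch compile as plain C++.
#if defined(__CUDACC__) || defined(__HIPCC__)
#define DEVICE __device__
#else
#define DEVICE
#endif

// Assumed enum shared by the dispatch macro and the vendor headers.
enum class Fp8KVCacheDataType { kAuto, kFp8E4M3, kFp8E5M2 };

namespace vendor_fp8 {  // i.e., csrc/quantization/fp8/<vendor>/quant_utils.cuh

// convert: Tin -> Tout for a particular FP8 format. Each vendor implements
// only the (Tin, Tout, kv_dt) combinations it supports.
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
DEVICE Tout convert(const Tin& x) {
  // e.g. Tin=uint16_t (fp16 bits), Tout=uint8_t, kv_dt=kFp8E4M3 when writing
  // to the kv-cache; the actual conversion is elided in this sketch.
  return Tout{};
}

// scaled_convert: like convert, but takes a static scaling factor.
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
DEVICE Tout scaled_convert(const Tin& x, const float scale) {
  (void)scale;  // the real implementation scales x before converting
  return convert<Tout, Tin, kv_dt>(x);
}

}  // namespace vendor_fp8

// Simplified dispatch, assuming a std::string named kv_cache_dtype is in
// scope at the call site. It maps the runtime dtype string onto the three
// arguments FN expects: input type, kv-cache storage type, and real kv-cache
// data type. The input type is fixed to uint16_t to keep the sketch short.
#define DISPATCH_BY_KV_CACHE_DTYPE(FN)                                         \
  if (kv_cache_dtype == "auto") {                                              \
    FN(uint16_t, uint16_t, Fp8KVCacheDataType::kAuto);                         \
  } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {        \
    FN(uint16_t, uint8_t, Fp8KVCacheDataType::kFp8E4M3);                       \
  } else if (kv_cache_dtype == "fp8_e5m2") {                                   \
    FN(uint16_t, uint8_t, Fp8KVCacheDataType::kFp8E5M2);                       \
  } else {                                                                     \
    throw std::runtime_error("Unsupported kv-cache dtype: " + kv_cache_dtype); \
  }

// Example vendor-independent call site (e.g., in cache_kernel.cu); the stub
// stands in for the real kernel launch.
template <typename SrcT, typename CacheT, Fp8KVCacheDataType kv_dt>
void reshape_and_cache_stub() { /* kernel launch elided */ }

#define CALL_RESHAPE_AND_CACHE(SRC_T, CACHE_T, KV_DT) \
  reshape_and_cache_stub<SRC_T, CACHE_T, KV_DT>()

void reshape_and_cache(const std::string& kv_cache_dtype) {
  DISPATCH_BY_KV_CACHE_DTYPE(CALL_RESHAPE_AND_CACHE);
}
```

In this shape, adding a vendor or format only touches that vendor's `quant_utils.cuh` and the dispatch table, while `cache_kernel.cu` and `attention_kernel.cu` only define `FN`-style macros such as `CALL_RESHAPE_AND_CACHE`, which matches the transparency goal described above.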
Roadmap
Feedback Period.
Feedback can be provided directly on the PRs. Based on comments here, I can update the RFC to elaborate.
CC List.
@robertgshaw2-neuralmagic @tlrmchlsmth @pcmoritz @zhuohan123 @WoosukKwon @HaiShaw @zhaoyang-star
Any Other Things.
No response