
[RFC]: Refactor FP8 kv-cache #4532

Closed · 1 of 3 tasks
comaniac opened this issue May 1, 2024 · 11 comments

comaniac (Collaborator) commented May 1, 2024

Motivation.

Support float8_e4m3 for NVIDIA GPUs: The current FP8 kv-cache supports e5m2 on NVIDIA GPUs and e4m3 on AMD GPUs. While e5m2 seems to be an ideal format for kv-cache storage due to better performance (e5m2 has the same number of exponent bits as fp16, namely 5, so its scaling overhead could be smaller than that of e4m3), model checkpoints with FP8 weights mostly adopt the e4m3 format. As a result, to support FP8 models in vLLM, we need a float8_e4m3 kv-cache for both vendors.

Refactor FP8 kv-cache: In the current implementation, we use macros to dispatch FP8 tensor quantization logic in vLLM custom kernels. For example, in the current cache_kernel.cu, we use the following code to quantize tensors for CUDA or ROCm:

#if defined(ENABLE_FP8_E5M2) // Only for CUDA
      key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
      value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
#elif defined(ENABLE_FP8_E4M3) // Only for ROCm
      key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_key, kv_scale);
      value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_value, kv_scale);
#else
      assert(false);
#endif

This creates the following issues:

  1. We cannot enable e4m3 and e5m2 in the same build.
  2. The function arguments are fixed.
  3. Any vendor-specific change needs to touch cache_kernel.cu (and likewise attention_kernel.cu).
  4. It is hard to extend in the future to support more kv-cache data types (e.g., INT8).

Proposed Change.

User Interface

  1. The default data type fp8 aliases to float8_e4m3 for all vendors.
  2. fp8_e4m3 and fp8_e5m2 are also available for users. We let the vendor backend throw an error if a particular format is not supported in the current vLLM build.
  3. When running an FP16 model with an FP8 kv-cache, we optionally accept per-tensor kv-cache scaling factors in a JSON format. If not provided, we default to a scaling factor of 1.0. This is compatible with the current e5m2 strategy.
  4. When running an FP8 model, we load the kv-cache scaling factors from the model checkpoint.

Backend Interface

  1. All vendor specific FP8 KV-cache related kernels are in csrc/quantization/fp8/<vendor>/quant_utils.cuh.
  2. The only vendor-specific logic in cache_kernel.cu and attention_kernel.cu is the include shown below. Note that we currently only have 2 GPU vendors, so I suppose this is still acceptable.
#ifdef USE_ROCM
#include "quantization/fp8/amd/quant_utils.cuh"
#else
#include "quantization/fp8/nvidia/quant_utils.cuh"
#endif
  3. We define the following interfaces for each vendor to implement.
  • A wrapper function convert that converts Tin to Tout with a particular FP8 format. For example, when writing values to kv-cache, Tin=uint16_t, Tout=uint8_t, kv_dt=kFp8E4M3:

    template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
    __inline__ __device__ Tout convert(const Tin &x) {
        if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { /* e4m3 conversion */ }
        else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) { /* e5m2 conversion */ }
        else { /* unsupported format: error out */ }
    }
    

    In this way, vendors have full control over which formats are supported.

  • A wrapper function scaled_convert, which is similar to convert but accepts a static scaling factor.

  • The macro DISPATCH_BY_KV_CACHE_DTYPE(FN) that dispatches the FN macro based on kv-cache data type. For example:

    #define DISPATCH_BY_KV_CACHE_DTYPE(FN)                        \
      if (kv_cache_dtype == "auto") {                             \
        if (key.dtype() == at::ScalarType::Float) {               \
          FN(float, float, vllm::Fp8KVCacheDataType::kAuto);      \
        } else {                                                  \
         ...                                                      \
        }                                                         \
      } else {                                                    \
        if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {\
          if (key.dtype() == at::ScalarType::Float) {                 \
            FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4m3);   \
          } else {                                                    \
            ...                                                       \
          }                                                           \
        } else if (kv_cache_dtype == "fp8_e5m2") {                    \
          if (key.dtype() == at::ScalarType::Float) {                 \
            FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5m2);   \
          } else {                                                    \
            ...                                                       \
          }                                                           \
        }                                                             \
      }
    

    With this interface, the supported data types are transparent to each vendor backend, and an error can be thrown at runtime when an unsupported data type is requested. Also, FN would be a macro defined in vendor-independent files, such as CALL_RESHAPE_AND_CACHE in cache_kernel.cu and CALL_V2_LAUNCHER_BLOCK_SIZE in attention_kernel.cu. The macro should accept 3 arguments in order: the input data type (e.g., uint16_t), the kv-cache storage data type (e.g., uint8_t), and the real kv-cache data type (e.g., kFp8E4M3). A sketch of a caller follows.
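
For illustration, here is a hedged sketch of how a vendor-independent caller in cache_kernel.cu could consume DISPATCH_BY_KV_CACHE_DTYPE. The kernel name, launch configuration, and argument list below are assumptions made for the example, not verbatim vLLM code; only the macro names and the argument order come from this RFC.

// Hedged sketch: define the FN macro once, then let the dispatch macro pick
// the concrete (input type, storage type, kv-cache format) combination.
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE)                 \
  vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE>               \
      <<<grid, block, 0, stream>>>(                                     \
          reinterpret_cast<KV_T*>(key.data_ptr()),                      \
          reinterpret_cast<KV_T*>(value.data_ptr()),                    \
          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),             \
          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),           \
          slot_mapping.data_ptr<int64_t>(),                             \
          key_stride, value_stride, num_heads, head_size, block_size,   \
          kv_scale);

// The host-side entry point then needs a single dispatch call instead of
// the per-vendor #ifdef blocks shown earlier:
//   DISPATCH_BY_KV_CACHE_DTYPE(CALL_RESHAPE_AND_CACHE);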

Roadmap

  • Refactor backend interface.
  • Refactor user interface and support FP8 checkpoint with kv-cache scaling factors.
  • Performance improvements.

Feedback Period.

Feedback can be provided directly on the PRs. Based on the comments here, I can update the RFC to elaborate.

CC List.

@robertgshaw2-neuralmagic @tlrmchlsmth @pcmoritz @zhuohan123 @WoosukKwon @HaiShaw @zhaoyang-star

Any Other Things.

No response

pcmoritz (Collaborator) commented May 1, 2024

Thanks a lot for putting together this RFC! This sounds like a solid plan to me.

Some more detailed comments:

  • Why not store the KV scaling factors in the safetensors / model checkpoints instead of the JSON format? The scaling factors for the KV store are very similar to the scaling factors for activations, so it is very natural to handle them the same and the safetensors / model checkpoints are the most natural places to store them.
  • Have you considered not having a vllm::Fp8KVCacheDataType::kAuto datatype and instead resolving the kv cache dtype to the right dtype in the Python layer? We should have all the needed information there, and the C++ kernels shouldn't need to deal with this complexity; they could just use a concrete type, right? (unless I'm missing something)

comaniac (Collaborator, Author) commented May 1, 2024

> Why not store the KV scaling factors in the safetensors / model checkpoints instead of the JSON format? The scaling factors for the KV store are very similar to the scaling factors for activations, so it is very natural to handle them the same and the safetensors / model checkpoints are the most natural places to store them.

I plan to support both cases. The JSON format is only for the case where people don't want an FP8 model but only an FP8 kv-cache (e.g., on pre-Hopper GPUs). In this case the model checkpoint is still in FP16, so we need a way to pass scaling factors anyway. For FP8 models we definitely load the scaling factors from the checkpoints directly.

> Have you considered not having a vllm::Fp8KVCacheDataType::kAuto datatype and instead resolving the kv cache dtype to the right dtype in the Python layer? We should have all the needed information there, and the C++ kernels shouldn't need to deal with this complexity; they could just use a concrete type, right? (unless I'm missing something)

I think I've tried several ways but haven't had luck so far (maybe I'm also missing something). I suppose the main reason is that kv quantization happens during kv-cache writes (cache_kernel) and reads (attention_kernel), so the quantization has to be dispatched and invoked inside those kernels. That's why we have to pass the data type all the way down to these kernels. Additionally, since we now use uint8_t to store fp8 values, we cannot differentiate whether a function like uint8_t convert(uint16_t& a) { ... } is for E4M3 or E5M2, which also makes function overloading problematic. If we could use c10::Float8_e4m3fn as the kv-cache data type directly, we could just implement c10::Float8_e4m3fn convert(scalar_t& a) and c10::Float8_e5m2 convert(scalar_t& a) and let C++ find the right one.
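
To make the overloading issue concrete, here is a minimal illustrative sketch (not vLLM code), assuming torch's c10 FP8 headers and CUDA's fp16 intrinsics are available; as in the example above, the uint16_t argument is taken to hold fp16 bits.

#include <cstdint>
#include <cuda_fp16.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>

// With uint8_t as the storage type, the E4M3 and E5M2 conversions share the
// exact same signature, so plain overloading cannot tell them apart:
//   __device__ uint8_t convert(const uint16_t& a);  // E4M3? E5M2? Ambiguous.

// With distinct FP8 element types, the output type alone can select the
// conversion, e.g., via explicit specialization per format:
template <typename Tout>
__device__ Tout convert(const uint16_t& a);

template <>
__device__ c10::Float8_e4m3fn convert<c10::Float8_e4m3fn>(const uint16_t& a) {
  // Reinterpret the fp16 bits, widen to float, then narrow to e4m3.
  return c10::Float8_e4m3fn(__half2float(__ushort_as_half(a)));
}

template <>
__device__ c10::Float8_e5m2 convert<c10::Float8_e5m2>(const uint16_t& a) {
  // Same path, but the target type picks the e5m2 representation.
  return c10::Float8_e5m2(__half2float(__ushort_as_half(a)));
}

With such a scheme, the kv_dt template parameter (and the uint8_t punning) would no longer be needed, which is essentially the point raised later in this thread.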

robertgshaw2-redhat (Collaborator) commented May 2, 2024

> I plan to support both cases. The JSON format is only for the case where people don't want an FP8 model but only an FP8 kv-cache (e.g., on pre-Hopper GPUs). In this case the model checkpoint is still in FP16, so we need a way to pass scaling factors anyway. For FP8 models we definitely load the scaling factors from the checkpoints directly.

But we could have an fp16 model with scales for the kv cache saved in the safetensors files:

model.layers.0.attn.k_proj.weight
model.layers.0.attn.k_cache.act_scale

So we would not need to support the JSON case.

pcmoritz (Collaborator) commented May 2, 2024

I'm +1 to supporting activation scales in the FP16 checkpoint and not in JSON. This way fewer configurations need to be supported and everything is uniform :)

pcmoritz (Collaborator) commented May 2, 2024

> I think I've tried several ways but haven't had luck so far (maybe I'm also missing something). I suppose the main reason is that kv quantization happens during kv-cache writes (cache_kernel) and reads (attention_kernel), so the quantization has to be dispatched and invoked inside those kernels. That's why we have to pass the data type all the way down to these kernels. Additionally, since we now use uint8_t to store fp8 values, we cannot differentiate whether a function like uint8_t convert(uint16_t& a) { ... } is for E4M3 or E5M2, which also makes function overloading problematic. If we could use c10::Float8_e4m3fn as the kv-cache data type directly, we could just implement c10::Float8_e4m3fn convert(scalar_t& a) and c10::Float8_e5m2 convert(scalar_t& a) and let C++ find the right one.

Sounds good! I get why the data type needs to be passed through; what I don't really get is why "auto" needs to be handled in C++ -- it seems to me it could be mapped to a concrete type in Python, and then C++ would only need to handle the concrete types. But I might be wrong, feel free to do what feels most natural while implementing this :)

comaniac (Collaborator, Author) commented May 2, 2024

Thanks for the valuable feedback!

  • [Interface] I am actually supportive of embedding the kv-cache scaling factors in the model checkpoint (and 1.0 will be used if we cannot find them in the checkpoint). After all, the kv-cache scaling factors are always tightly coupled with the model itself. Let's try to align with other folks on this interface.
  • [kAuto] I see your point. I agree that the term kAuto is misleading; in fact, it should be something like kNoFp8. The reason to keep this flag is mainly for the case when FP8 is not enabled (e.g., CPU or CUDA < 8.0). In this case we completely disable the FP8-related kernels, so we need a compile-time flag to take a special path. For example:
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
  // Make sure we don't call "fp8::scaled_convert" when FP8 is disabled.
  key_cache[tgt_key_idx] = tgt_key;
  value_cache[tgt_value_idx] = tgt_value;
} else {
  key_cache[tgt_key_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
  value_cache[tgt_value_idx] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
}

Then we could have:

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
  switch (kv_dt) {
#ifdef ENABLE_FP8
  case Fp8KVCacheDataType::kFp8E4m3:
    return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
  case Fp8KVCacheDataType::kFp8E5m2:
    return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
#endif
  default: // Only this branch left when FP8 is disabled, so always throw an error.
    assert(false);
  }
}

pcmoritz (Collaborator) commented May 2, 2024

Ah I see, in that case kAuto is a good name since it is the same as "auto" in Python. I didn't realize it required a special code path :)

HaiShaw (Contributor) commented May 8, 2024

w.r.t. "We cannot enable e4m3 and e5m2 in the same build":
If we want a build with both formats supported on the same newer hardware, we most likely won't need both formats to function simultaneously, as that increases complexity with no use case, considering only e4m3 would be used in forward and inference computations. On the contrary, e5m2 is only practically feasible on older hardware, as a storage type with acceptable performance given the mantissa rounding; enabling e4m3 on older hardware isn't beneficial, neither for the cost of the cast nor for computation (no hardware support). Finally, having a build that is generic across generations of GPUs seems unnecessary.

HaiShaw (Contributor) commented May 8, 2024

w.r.t. "When running an FP8 model, we load the kv-cache scaling factors from the model checkpoint":
We shall have a serialized checkpoint with the various scaling factors defined, covering both the stationary scaling factors (for weights, at whatever granularity) and the updatable scaling factors (activations, KV caches); for the latter we need to define the update process, with the quantizer flow included.

HaiShaw (Contributor) commented May 8, 2024

> A wrapper function convert that converts Tin to Tout with a particular FP8 format. For example, when writing values to kv-cache, Tin=uint16_t, Tout=uint8_t, kv_dt=kFp8E4M3

Over time, it makes more sense to rule out uint8_t and move to the torch FP8 types; then kv_dt would be unnecessary.

pcmoritz pushed a commit that referenced this issue May 22, 2024
The 2nd PR for #4532.

This PR supports loading FP8 kv-cache scaling factors from a FP8 checkpoint (with .kv_scale parameter).
dtrifiro pushed a commit to opendatahub-io/vllm that referenced this issue May 31, 2024
robertgshaw2-redhat pushed a commit to neuralmagic/nm-vllm that referenced this issue Jun 8, 2024
joerunde pushed a commit to joerunde/vllm that referenced this issue Jun 17, 2024
robertgshaw2-redhat pushed a commit to neuralmagic/nm-vllm that referenced this issue Jul 14, 2024
jon-chuang (Contributor) commented
@comaniac can this be closed? as of: #4893

comaniac closed this as completed Aug 6, 2024
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this issue Sep 6, 2024