[Kernel] Refactor FP8 kv-cache with NVIDIA float8_e4m3 support #4535
Conversation
Force-pushed from dcf4178 to 79a4c19
Per offline discussion, this PR only includes backend refactoring for FP8 kv-cache related kernels and utilities. A follow-up PR will then cover the scaling factor loading. Thus, this PR is ready for review. cc @pcmoritz @robertgshaw2-neuralmagic @HaiShaw @WoosukKwon
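For orientation, here is a minimal sketch of a call site for the refactored entry points -- illustrative only, not code from this PR; the include is abbreviated, and the fp8 namespace is assumed to sit inside vllm, as the dispatch macro suggests. It follows the scaling convention documented in the new header, FP8 = Quantize(HP / scale) and HP = Dequant(FP8) * scale, so a roundtrip recovers the input up to fp8 rounding whenever hp / scale fits the fp8 range.

    #include "quant_utils.cuh"  // the new nvidia/quant_utils.cuh; path abbreviated

    template <vllm::Fp8KVCacheDataType kv_dt>
    __device__ float fp8_roundtrip(float hp, float scale) {
      // float -> fp8 byte: divides by scale, then saturating-converts to fp8.
      uint8_t q = vllm::fp8::scaled_convert<uint8_t, float, kv_dt>(hp, scale);
      // fp8 byte -> float: converts back to float, then multiplies by scale.
      return vllm::fp8::scaled_convert<float, uint8_t, kv_dt>(q, scale);
    }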
It is a bummer that GitHub doesn't render the diff between the old and new nvidia quant_utils.cuh -- for ease of reviewing, here is the diff: (base) pcmoritz@pcmoritz-DQ44HV60WX /tmp % diff quant_utils_old.cuh quant_utils_new.cuh
2a3,6
> #include "../../../attention/attention_dtypes.h"
> #include "../../../attention/dtype_bfloat16.cuh"
> #include "../../../attention/dtype_float16.cuh"
> #include "../../../attention/dtype_float32.cuh"
4d7
< #include <stdint.h>
5a9
> #include <stdint.h>
7,10d10
< #include "../../attention/attention_dtypes.h"
< #include "../../attention/dtype_float32.cuh"
< #include "../../attention/dtype_float16.cuh"
< #include "../../attention/dtype_bfloat16.cuh"
12d11
<
14,15d12
< #ifdef ENABLE_FP8_E5M2
< namespace fp8_e5m2_unscaled {
17,20c14,20
< template<typename Tout, typename Tin>
< __inline__ __device__ Tout vec_conversion(const Tin& x)
< {
< return x;
---
> namespace fp8 {
> #ifdef ENABLE_FP8
>
> template <typename Tout, typename Tin>
> __inline__ __device__ Tout
> vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) {
> return x;
24,28c24,28
< template<>
< __inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
< {
< __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
< return res.x;
---
> template <>
> __inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(
> const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
> __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
> return res.x;
32,42c32,42
< template<>
< __inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
< {
< union {
< uint16_t u16[2];
< uint32_t u32;
< } tmp;
< __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2);
< tmp.u16[0] = res.x;
< tmp.u16[1] = res.y;
< return tmp.u32;
---
> template <>
> __inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(
> const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
> union {
> uint16_t u16[2];
> uint32_t u32;
> } tmp;
> __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
> tmp.u16[0] = res.x;
> tmp.u16[1] = res.y;
> return tmp.u32;
46,55c46,56
< template<>
< __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
< {
< union {
< uint2 u32x2;
< uint32_t u32[2];
< } tmp;
< tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
< tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
< return tmp.u32x2;
---
> template <>
> __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(
> const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
> union {
> uint2 u32x2;
> uint32_t u32[2];
> } tmp;
> tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a, fp8_type);
> tmp.u32[1] =
> vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), fp8_type);
> return tmp.u32x2;
59,68c60,69
< template<>
< __inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
< {
< union {
< uint4 u64x2;
< uint2 u64[2];
< } tmp;
< tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
< tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
< return tmp.u64x2;
---
> template <>
> __inline__ __device__ uint4 vec_conversion<uint4, uint2>(
> const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
> union {
> uint4 u64x2;
> uint2 u64[2];
> } tmp;
> tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x, fp8_type);
> tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y, fp8_type);
> return tmp.u64x2;
72,80c73,81
< template<>
< __inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a)
< {
< // Note there is no direct convert function from fp8 to bf16.
< // fp8 -> half
< __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
< // half -> float -> bf16
< float tmp = half_to_float(res.x);
< return __float2bfloat16(tmp);
---
> template <>
> __inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(
> const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
> // Note there is no direct convert function from fp8 to bf16.
> // fp8 -> half
> __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
> // half -> float -> bf16
> float tmp = half_to_float(res.x);
> return __float2bfloat16(tmp);
84,90c85,91
< template<>
< __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a)
< {
< __nv_bfloat162 res;
< res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
< res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
< return res;
---
> template <>
> __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(
> const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
> __nv_bfloat162 res;
> res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, fp8_type);
> res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), fp8_type);
> return res;
94,100c95,102
< template<>
< __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
< {
< bf16_4_t res;
< res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
< res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
< return res;
---
> template <>
> __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(
> const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
> bf16_4_t res;
> res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, fp8_type);
> res.y =
> vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), fp8_type);
> return res;
104,115c106,117
< template<>
< __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
< {
< bf16_4_t tmp1, tmp2;
< tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
< tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
< bf16_8_t res;
< res.x = tmp1.x;
< res.y = tmp1.y;
< res.z = tmp2.x;
< res.w = tmp2.y;
< return res;
---
> template <>
> __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(
> const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
> bf16_4_t tmp1, tmp2;
> tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x, fp8_type);
> tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y, fp8_type);
> bf16_8_t res;
> res.x = tmp1.x;
> res.y = tmp1.y;
> res.z = tmp2.x;
> res.w = tmp2.y;
> return res;
119,125c121,128
< template<>
< __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
< {
< // fp8 -> half
< uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a);
< // half -> float
< return half_to_float(tmp);
---
> template <>
> __inline__ __device__ float
> vec_conversion<float, uint8_t>(const uint8_t &a,
> const __nv_fp8_interpretation_t fp8_type) {
> // fp8 -> half
> uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a, fp8_type);
> // half -> float
> return half_to_float(tmp);
129,135c132,138
< template<>
< __inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
< {
< // fp8x2 -> half2
< uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a);
< // half2 -> float2
< return half2_to_float2(tmp);
---
> template <>
> __inline__ __device__ float2 vec_conversion<float2, uint16_t>(
> const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
> // fp8x2 -> half2
> uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a, fp8_type);
> // half2 -> float2
> return half2_to_float2(tmp);
139,145c142,148
< template<>
< __inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
< {
< Float4_ res;
< res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
< res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
< return res;
---
> template <>
> __inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(
> const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
> Float4_ res;
> res.x = vec_conversion<float2, uint16_t>((uint16_t)a, fp8_type);
> res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), fp8_type);
> return res;
149,160c152,163
< template<>
< __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
< {
< Float4_ tmp1, tmp2;
< tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
< tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
< Float8_ res;
< res.x = tmp1.x;
< res.y = tmp1.y;
< res.z = tmp2.x;
< res.w = tmp2.y;
< return res;
---
> template <>
> __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(
> const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
> Float4_ tmp1, tmp2;
> tmp1 = vec_conversion<Float4_, uint32_t>(a.x, fp8_type);
> tmp2 = vec_conversion<Float4_, uint32_t>(a.y, fp8_type);
> Float8_ res;
> res.x = tmp1.x;
> res.y = tmp1.y;
> res.z = tmp2.x;
> res.w = tmp2.y;
> return res;
163d165
<
165,171c167,174
< template<>
< __inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
< {
< __half_raw tmp;
< tmp.x = a;
< __nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2);
< return (uint8_t)res;
---
> template <>
> __inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(
> const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
> __half_raw tmp;
> tmp.x = a;
> __nv_fp8_storage_t res =
> __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, fp8_type);
> return (uint8_t)res;
175,177c178,180
< template<>
< __inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a)
< {
---
> template <>
> __inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(
> const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) {
179c182
< assert(false);
---
> assert(false);
181,182c184,186
< __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2);
< return (uint8_t)res;
---
> __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(
> __nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type);
> return (uint8_t)res;
187,191c191,195
< template<>
< __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
< {
< __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2);
< return (uint8_t)res;
---
> template <>
> __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(
> const float &a, const __nv_fp8_interpretation_t fp8_type) {
> __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type);
> return (uint8_t)res;
195,200c199,204
< template<>
< __inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
< {
< Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
< float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
< return res;
---
> template <>
> __inline__ __device__ float4 vec_conversion<float4, uint32_t>(
> const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
> Float4_ tmp = vec_conversion<Float4_, uint32_t>(a, fp8_type);
> float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
> return res;
202a207,213
> template <>
> __inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(
> const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
> union {
> half2 float16;
> uint32_t uint32;
> };
204,210c215,217
< template<>
< __inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
< {
< union {
< half2 float16;
< uint32_t uint32;
< };
---
> float16 = __float22half2_rn(a);
> return uint32;
> }
212,213c219,232
< float16 = __float22half2_rn(a);
< return uint32;
---
> template <>
> __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(
> const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
> uint2 b;
> float2 val;
> val.x = a.x.x;
> val.y = a.x.y;
> b.x = vec_conversion<uint32_t, float2>(val, fp8_type);
>
> val.x = a.y.x;
> val.y = a.y.y;
> b.y = vec_conversion<uint32_t, float2>(val, fp8_type);
>
> return b;
216,223c235,244
< template<>
< __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
< {
< uint2 b;
< float2 val;
< val.x = a.x.x;
< val.y = a.x.y;
< b.x = vec_conversion<uint32_t, float2>(val);
---
> template <>
> __inline__ __device__ float4 vec_conversion<float4, Float4_>(
> const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
> float4 b;
> b.x = a.x.x;
> b.y = a.x.y;
> b.z = a.y.x;
> b.w = a.y.y;
> return b;
> }
225,227c246,255
< val.x = a.y.x;
< val.y = a.y.y;
< b.y = vec_conversion<uint32_t, float2>(val);
---
> template <>
> __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(
> const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
> uint4 b;
> b.x = vec_conversion<uint32_t, float2>(a.x, fp8_type);
> b.y = vec_conversion<uint32_t, float2>(a.y, fp8_type);
> b.z = vec_conversion<uint32_t, float2>(a.z, fp8_type);
> b.w = vec_conversion<uint32_t, float2>(a.w, fp8_type);
> return b;
> }
229c257,262
< return b;
---
> template <>
> __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(
> const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
> __nv_bfloat162 b;
> from_float(b, a);
> return b;
232,240c265,270
< template<>
< __inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
< {
< float4 b;
< b.x = a.x.x;
< b.y = a.x.y;
< b.z = a.y.x;
< b.w = a.y.y;
< return b;
---
> template <>
> __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(
> const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
> bf16_4_t b;
> from_float(b, a);
> return b;
243,251c273,278
< template<>
< __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
< {
< uint4 b;
< b.x = vec_conversion<uint32_t, float2>(a.x);
< b.y = vec_conversion<uint32_t, float2>(a.y);
< b.z = vec_conversion<uint32_t, float2>(a.z);
< b.w = vec_conversion<uint32_t, float2>(a.w);
< return b;
---
> template <>
> __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(
> const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
> bf16_8_t b;
> from_float(b, a);
> return b;
254,258c281,290
< template<>
< __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) {
< __nv_bfloat162 b;
< from_float(b, a);
< return b;
---
> /* Scaled and vectorized conversions, for data exchange between high and low
> precision domains Convention of the scale in API, e.g: FP8_data =
> Quantization( High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8
> Dequant(FP8) * scale => HP
> */
>
> template <typename Tout, typename Tin>
> __inline__ __device__ Tout scaled_vec_conversion(
> const Tin &x, const float scale, const __nv_fp8_interpretation_t fp8_type) {
> return x;
261,265c293,299
< template<>
< __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_ &a) {
< bf16_4_t b;
< from_float(b, a);
< return b;
---
> // fp8 -> half
> template <>
> __inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(
> const uint8_t &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> __half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type);
> return float_to_half(half_to_float(tmp.x) * scale);
268,272c302,314
< template<>
< __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_ &a) {
< bf16_8_t b;
< from_float(b, a);
< return b;
---
> // fp8x2 -> half2
> template <>
> __inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
> const uint16_t &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> union {
> uint16_t u16[2];
> uint32_t u32;
> } tmp;
> __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
> tmp.u16[0] = float_to_half(half_to_float(res.x) * scale);
> tmp.u16[1] = float_to_half(half_to_float(res.y) * scale);
> return tmp.u32;
275,276c317,576
< } // namespace fp8_e5m2_unscaled
< #endif // ENABLE_FP8_E5M2
---
> // fp8x4 -> half2x2
> template <>
> __inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(
> const uint32_t &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> union {
> uint2 u32x2;
> uint32_t u32[2];
> } tmp;
> tmp.u32[0] =
> scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale, fp8_type);
> tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U),
> scale, fp8_type);
> return tmp.u32x2;
> }
>
> // fp8x8 -> half2x4
> template <>
> __inline__ __device__ uint4
> scaled_vec_conversion<uint4, uint2>(const uint2 &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> union {
> uint4 u64x2;
> uint2 u64[2];
> } tmp;
> tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, fp8_type);
> tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, fp8_type);
> return tmp.u64x2;
> }
>
> // fp8 -> __nv_bfloat16
> template <>
> __inline__ __device__ __nv_bfloat16
> scaled_vec_conversion<__nv_bfloat16, uint8_t>(
> const uint8_t &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> // Note there is no direct convert function from fp8 to bf16.
> // fp8 -> half
> __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
> // half -> float -> bf16
> float tmp = half_to_float(res.x);
> return __float2bfloat16(tmp * scale);
> }
>
> // fp8x2 -> __nv_bfloat162
> template <>
> __inline__ __device__ __nv_bfloat162
> scaled_vec_conversion<__nv_bfloat162, uint16_t>(
> const uint16_t &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> __nv_bfloat162 res;
> res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale,
> fp8_type);
> res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U),
> scale, fp8_type);
> return res;
> }
>
> // fp8x4 -> bf16_4_t
> template <>
> __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
> const uint32_t &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> bf16_4_t res;
> res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale,
> fp8_type);
> res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
> scale, fp8_type);
> return res;
> }
>
> // fp8x8 -> bf16_8_t
> template <>
> __inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(
> const uint2 &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> bf16_4_t tmp1, tmp2;
> tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, fp8_type);
> tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, fp8_type);
> bf16_8_t res;
> res.x = tmp1.x;
> res.y = tmp1.y;
> res.z = tmp2.x;
> res.w = tmp2.y;
> return res;
> }
>
> // fp8 -> float
> template <>
> __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
> const uint8_t &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
>
> // fp8 -> half
> uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a, fp8_type);
> // half -> float
> return half_to_float(tmp) * scale;
> }
>
> // fp8x2 -> float2
> template <>
> __inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(
> const uint16_t &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> // fp8x2 -> half2
> uint32_t tmp = scaled_vec_conversion<uint32_t, uint16_t>(a, scale, fp8_type);
> // half2 -> float2
> return half2_to_float2(tmp);
> }
>
> // fp8x4 -> float4
> template <>
> __inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(
> const uint32_t &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> Float4_ res;
> res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, fp8_type);
> res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale,
> fp8_type);
> return res;
> }
>
> // fp8x8 -> float8
> template <>
> __inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(
> const uint2 &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> Float4_ tmp1, tmp2;
> tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, fp8_type);
> tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, fp8_type);
> Float8_ res;
> res.x = tmp1.x;
> res.y = tmp1.y;
> res.z = tmp2.x;
> res.w = tmp2.y;
> return res;
> }
>
> // half -> fp8
> template <>
> __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(
> const uint16_t &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> __nv_fp8_storage_t res =
> __nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type);
> return (uint8_t)res;
> }
>
> // bf16 -> fp8
> template <>
> __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
> const __nv_bfloat16 &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
> assert(false);
> #else
> __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale,
> __NV_SATFINITE, fp8_type);
> return (uint8_t)res;
> #endif
> }
>
> // float -> fp8
> template <>
> __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(
> const float &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> __nv_fp8_storage_t res =
> __nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type);
> return (uint8_t)res;
> }
>
> // fp8x4 -> float4
> template <>
> __inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(
> const uint32_t &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale, fp8_type);
> float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
> return res;
> }
> #endif // ENABLE_FP8
>
> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
> __inline__ __device__ Tout convert(const Tin &x) {
> switch (kv_dt) {
> #ifdef ENABLE_FP8
> case Fp8KVCacheDataType::kAuto:
> // When the type is auto, Tin should be able to be converted to
> // Tout directly. Thus, the corresponding vec_conversion function
> // should ignore the last argument (e.g. __NV_E4M3).
> case Fp8KVCacheDataType::kFp8E4m3:
> return vec_conversion<Tout, Tin>(x, __NV_E4M3);
> case Fp8KVCacheDataType::kFp8E5m2:
> return vec_conversion<Tout, Tin>(x, __NV_E5M2);
> #endif
> default:
> assert(false);
> }
> }
>
> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
> __inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
> switch (kv_dt) {
> #ifdef ENABLE_FP8
> case Fp8KVCacheDataType::kAuto:
> // When the type is auto, Tin should be able to be converted to
> // Tout directly. Thus, the corresponding vec_conversion function
> // should ignore the last argument (e.g. __NV_E4M3).
> case Fp8KVCacheDataType::kFp8E4m3:
> return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
> case Fp8KVCacheDataType::kFp8E5m2:
> return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
> #endif
> default:
> assert(false);
> }
> }
>
> // The following macro is used to dispatch the conversion function based on the
> // data type of the key and value cache. The FN is a macro that calls a function
> // with template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>.
> #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
> if (KV_DTYPE == "auto") { \
> if (SRC_DTYPE == at::ScalarType::Float) { \
> FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
> } else if (SRC_DTYPE == at::ScalarType::Half) { \
> FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
> } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
> FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
> } else { \
> TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
> } \
> } else { \
> if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
> if (SRC_DTYPE == at::ScalarType::Float) { \
> FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4m3); \
> } else if (SRC_DTYPE == at::ScalarType::Half) { \
> FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4m3); \
> } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
> FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4m3); \
> } else { \
> TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
> } \
> } else if (KV_DTYPE == "fp8_e5m2") { \
> if (SRC_DTYPE == at::ScalarType::Float) { \
> FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5m2); \
> } else if (SRC_DTYPE == at::ScalarType::Half) { \
> FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5m2); \
> } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
> FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5m2); \
> } else { \
> TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
> } \
> } else { \
> TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
> } \
> }
>
> } // namespace fp8
#ifdef ENABLE_FP8
  case Fp8KVCacheDataType::kAuto:
    // When the type is auto, Tin should be able to be converted to
    // Tout directly.
Let's add a comment that we are falling through to the next statement here (same below)
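Something along these lines (a sketch of the suggested annotation; [[fallthrough]] needs C++17, and a plain // fall through comment works on older standards):

    switch (kv_dt) {
    #ifdef ENABLE_FP8
      case Fp8KVCacheDataType::kAuto:
        // When the type is auto, Tin converts to Tout directly and the
        // interpretation argument is ignored -- intentionally fall through.
        [[fallthrough]];
      case Fp8KVCacheDataType::kFp8E4m3:
        return vec_conversion<Tout, Tin>(x, __NV_E4M3);
      case Fp8KVCacheDataType::kFp8E5m2:
        return vec_conversion<Tout, Tin>(x, __NV_E5M2);
    #endif
      default:
        assert(false);
    }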
Did you investigate the performance impact of passing fp8_type as a runtime argument?
Good question. It would be tedious to make this type a template parameter, because we have roughly 30 specialized functions. Since C++ doesn't allow partial specialization of function templates, we would have to manually duplicate them into 60 functions to cover both formats...
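To spell out the constraint described here (a sketch, not code from the PR): C++ permits partial specialization of class templates but not of function templates, so the format cannot become a template parameter while keeping one specialization per (Tout, Tin) pair.

    // Hypothetical signature with the format lifted into the template:
    template <typename Tout, typename Tin, __nv_fp8_interpretation_t fp8_type>
    __inline__ __device__ Tout vec_conversion(const Tin &x);

    // Ill-formed: function templates cannot be partially specialized.
    // template <__nv_fp8_interpretation_t fp8_type>
    // __inline__ __device__ uint16_t
    // vec_conversion<uint16_t, uint8_t, fp8_type>(const uint8_t &a) { ... }

    // Only full specializations are legal, i.e. one body per
    // (Tout, Tin, format) triple -- doubling the ~30 existing overloads.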
Why don't we test if there is a performance overhead? Probably the compiler is already smart enough to optimize that -- it should be, since the argument is constant in https://github.com/vllm-project/vllm/pull/4535/files#diff-97c4751eafe4ec7333fe2f140e29c84ea054f43d17d4286cc8c4e69a095d09aaR502 and similarly for scaled_convert.
If the performance is ok, we can go forward with this.
Thanks a lot for cleaning this up @comaniac ❤️
This code was not pretty and now it is much nicer!
The only thing I'm not a fan of is {nvidia, amd}/quant_utils.cuh. If anybody has ideas how to do that better, that would be very much appreciated!
I'll verify the performance. For naming, another way I could think of is ...
It is not about naming, more about having all these special cases and little conversion utilities :)
I benchmarked on an L4 GPU and the latency difference is within 1-2%, which should be acceptable.
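That result is consistent with kv_dt being a template parameter: each instantiation of convert/scaled_convert sees a compile-time constant switch and a literal fp8_type, which the compiler can fold away. Roughly (an illustration of the expected codegen, not verified against the emitted SASS):

    // convert<uint16_t, uint8_t, Fp8KVCacheDataType::kFp8E4m3>(a)
    // should reduce to the direct conversion:
    __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E4M3);
    return res.x;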
@HaiShaw @AdrianAbeyta I keep seeing the following error when building this PR with ROCm. It seems like the same ...
Force-pushed from 54d0ee2 to b61b3b7
I think it's better not to mix quant_utils with references to Fp8KVCacheDataType kv_dt in one file, as quant_utils could be used for other things as well -- e.g. activation quantization -- and we could maintain its autonomy in that sense.
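One way to honor that separation (a sketch of the suggestion with a hypothetical file layout, not what this PR ships):

    // kv-cache-side header (e.g. a dtype_fp8.cuh): dispatch enum and wrappers.
    enum class Fp8KVCacheDataType { kAuto, kFp8E4m3, kFp8E5m2 };
    template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
    __inline__ __device__ Tout scaled_convert(const Tin &x, const float scale);

    // quant_utils.cuh: only the generic (scaled_)vec_conversion overloads,
    // parameterized by the raw __nv_fp8_interpretation_t, so the same file
    // can serve activation quantization without knowing about the kv-cache.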
Can you ...
I cannot do that since it's on the CI instead of my local workspace. The issue remains even after I removed the dtype_fp8.cuh header from quant_utils.cuh...
CI passed, so we should be good to go. For the comment about not mixing ...
The first PR for #4532.
Task list:
BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE
PR Checklist
Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.
PR Title and Classification
Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:
- [Bugfix] for bug fixes.
- [CI/Build] for build or continuous integration improvements.
- [Doc] for documentation fixes and improvements.
- [Model] for adding a new model or improving an existing model. Model name should appear in the title.
- [Frontend] for changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
- [Kernel] for changes affecting CUDA kernels or other compute kernels.
- [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
- [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
- [Misc] for PRs that do not fit the above categories. Please use this sparingly.
Note: If the PR spans more than one category, please include all relevant prefixes.
Code Quality
The PR needs to meet the following code quality standards:
- Please use format.sh to format your code.
- Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM users understand and utilize the new features or changes.
Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.
The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feels confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:
- The reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!