From b2986d7aa5a40740b71c0d2f59a9277cfa10c67f Mon Sep 17 00:00:00 2001
From: HAI
Date: Wed, 4 Dec 2024 03:01:33 -0800
Subject: [PATCH] Adding SGLang FP8 Utils (#2348)

---
 .../srt/layers/quantization/fp8_utils.py | 27 +++++++++++++++++++
 1 file changed, 27 insertions(+)
 create mode 100644 python/sglang/srt/layers/quantization/fp8_utils.py

diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
new file mode 100644
index 00000000000..3ba381a373f
--- /dev/null
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -0,0 +1,27 @@
+from typing import Optional, Tuple
+
+import torch
+
+
+def normalize_e4m3fn_to_e4m3fnuz(
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    assert weight.dtype == torch.float8_e4m3fn
+    # The bit pattern 0b10000000 (int8 value -128) represents negative
+    # zero in e4m3fn but NaN in e4m3fnuz, so remap it to zero.
+    # https://onnx.ai/onnx/technical/float8.html
+    weight_as_int8 = weight.view(torch.int8)
+    ROCM_FP8_NAN_AS_INT = -128
+    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
+    weight = weight_as_int8.view(torch.float8_e4m3fnuz)
+
+    # For the same bit pattern, the e4m3fnuz value is half the
+    # e4m3fn value, so double the scaling factor to preserve the
+    # dequantized value.
+    # https://onnx.ai/onnx/technical/float8.html
+    weight_scale = weight_scale * 2.0
+    if input_scale is not None:
+        input_scale = input_scale * 2.0
+    return weight, weight_scale, input_scale
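
Usage sketch: a minimal, hypothetical example of calling the new helper when
preparing FP8 checkpoint weights for hardware whose native format is
e4m3fnuz (e.g. ROCm). The tensor shapes and scale values below are
illustrative assumptions, not taken from SGLang; this requires a PyTorch
build with the float8 dtypes.

    import torch

    from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz

    # Hypothetical checkpoint tensors: an e4m3fn weight plus per-tensor scales.
    weight = torch.randn(128, 256).to(torch.float8_e4m3fn)
    weight_scale = torch.tensor(0.05)
    input_scale = torch.tensor(0.1)

    # Remap the NaN-colliding bit pattern and double the scales so the
    # dequantized values match the original e4m3fn representation.
    weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
        weight, weight_scale, input_scale
    )
    assert weight.dtype == torch.float8_e4m3fnuz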
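
The factor of 2 comes from the exponent bias: e4m3fn uses bias 7 while
e4m3fnuz uses bias 8, so the same bit pattern decodes to half the value in
e4m3fnuz. A quick numeric check of this, again assuming a PyTorch build with
both float8 dtypes:

    import torch

    # 0x38 = sign 0, exponent 0111, mantissa 000.
    bits = torch.tensor([0x38], dtype=torch.uint8)
    print(bits.view(torch.float8_e4m3fn).float())    # tensor([1.])    -> 2^(7-7)
    print(bits.view(torch.float8_e4m3fnuz).float())  # tensor([0.5000]) -> 2^(7-8)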