Skip to content

Commit

Permalink
Adding SGLang FP8 Utils (#2348)
Browse files Browse the repository at this point in the history
  • Loading branch information
HaiShaw authored Dec 4, 2024
1 parent f8b0326 commit b2986d7
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions python/sglang/srt/layers/quantization/fp8_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Optional, Tuple

import torch


def normalize_e4m3fn_to_e4m3fnuz(
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
assert weight.dtype == torch.float8_e4m3fn
# The bits pattern 10000000(-128) represents zero in e4m3fn
# but NaN in e4m3fnuz. So here we set it to 0.
# https://onnx.ai/onnx/technical/float8.html
weight_as_int8 = weight.view(torch.int8)
ROCM_FP8_NAN_AS_INT = -128
weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
weight = weight_as_int8.view(torch.float8_e4m3fnuz)

# For the same bits representation, e4m3fnuz value is half of
# the e4m3fn value, so we should double the scaling factor to
# get the same dequantized value.
# https://onnx.ai/onnx/technical/float8.html
weight_scale = weight_scale * 2.0
if input_scale is not None:
input_scale = input_scale * 2.0
return weight, weight_scale, input_scale

0 comments on commit b2986d7

Please sign in to comment.