Skip to content

Commit

Permalink
[Bugfix] Fix for ROCM compressed tensor support (vllm-project#11561)
Browse files Browse the repository at this point in the history
  • Loading branch information
selalipop authored and BKitor committed Dec 30, 2024
1 parent b648857 commit eabdf4f
Showing 1 changed file with 7 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,12 @@ def process_weights_after_loading(self, layer) -> None:
)

if current_platform.is_rocm():
input_scale = getattr(layer, 'input_scale', None)

weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=max_w_scale,
input_scale=layer.input_scale)
input_scale=input_scale)
if input_scale is not None:
layer.input_scale = Parameter(input_scale,
requires_grad=False)
Expand All @@ -57,11 +59,13 @@ def process_weights_after_loading(self, layer) -> None:
weight = layer.weight

if current_platform.is_rocm():
input_scale = getattr(layer, 'input_scale', None)

weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale)
input_scale=input_scale)
if input_scale is not None:
layer.input_scale = Parameter(input_scale,
requires_grad=False)
Expand All @@ -76,7 +80,7 @@ def process_weights_after_loading(self, layer) -> None:
raise ValueError(f"Unknown quantization strategy {self.strategy}")

# INPUT SCALE
if self.is_static_input_scheme:
if self.is_static_input_scheme and hasattr(layer, 'input_scale'):
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
else:
Expand Down

0 comments on commit eabdf4f

Please sign in to comment.