[Quant Tool] Update QDQ Pad, Slice, Softmax #22676

Merged · 10 commits · Nov 6, 2024
2 changes: 2 additions & 0 deletions onnxruntime/python/tools/quantization/base_quantizer.py
@@ -529,4 +529,6 @@ def adjust_tensor_ranges(self):
                self.tensors_range[node.input[0]] = td
            # Adjust Softmax to range from 0.0 to 1.0
            elif node.op_type == "Softmax":
                if not self.should_quantize_node(node):
                    continue
                self.tensors_range[node.output[0]] = TensorData(lowest=np.float32(0.0), highest=np.float32(1.0))
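
For reference, a minimal NumPy sketch (not part of this PR) of why the Softmax range can be pinned: softmax outputs always lie in [0.0, 1.0], so replacing the calibrated range with the fixed one is always safe and usually tighter.

import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the row max for numerical stability; the result is unchanged.
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)

x = np.random.randn(4, 8).astype(np.float32)
y = softmax(x)
assert np.all((y >= 0.0) & (y <= 1.0))  # Softmax outputs are always in [0, 1].
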
72 changes: 72 additions & 0 deletions onnxruntime/python/tools/quantization/operators/pad.py
@@ -1,3 +1,12 @@
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from __future__ import annotations

from typing import Any

import numpy as np
import onnx

from ..quant_utils import (
@@ -8,6 +17,7 @@
    quantize_nparray,
)
from .base_operator import QuantOperatorBase
from .qdq_base_operator import QDQOperatorBase


class QPad(QuantOperatorBase):
@@ -98,3 +108,65 @@ def quantize(self):
        node.input[0] = quantized_input_value.q_name
        node.output[0] = quantized_output_value.q_name
        self.quantizer.new_nodes += [node]


class QDQPad(QDQOperatorBase):
    def __init__(self, onnx_quantizer, onnx_node):
        super().__init__(onnx_quantizer, onnx_node)

    def _get_pad_const_val(self, attrs_dict: dict[str, Any]) -> np.ndarray | None:
        """
        Returns the Pad's constant padding value. Returns `None` if the padding value is
        not constant (i.e., comes from a dynamic input).
        """
        const_val = None
        onnx_tensor_type = self.quantizer.model.get_tensor_type(self.node.input[0])
        if onnx_tensor_type is None:
            return None

        np_dtype = onnx.helper.tensor_dtype_to_np_dtype(onnx_tensor_type.elem_type)
        if self.quantizer.opset_version < 11:
            const_val = np.array(attrs_dict.get("value", 0), dtype=np_dtype)
        elif len(self.node.input) >= 3 and self.node.input[2]:
            const_val = self.quantizer.model.get_constant_value(self.node.input[2])
        else:
            const_val = np.array(0, dtype=np_dtype)

        return const_val

    def _should_quantize_output_same_as_input(self) -> bool:
        """
        Returns true if the Pad's output should use the same quantization parameters as input[0].
        """
        attrs_dict = {}
        for attribute in self.node.attribute:
            kv = attribute_to_kwarg(attribute)
            attrs_dict.update(kv)

        pad_mode = attrs_dict.get("mode", b"constant")
        if pad_mode in (b"reflect", b"edge", b"wrap"):
            # These modes pad the output with a value that already exists in the input,
            # so we can quantize the output the same as the input.
            return True

        # For 'constant' mode, if padding with 0, we can also quantize the output the same as the input
        # because our quantization floating-point range always includes 0.
if pad_mode == b"constant":
pad_val = self._get_pad_const_val(attrs_dict)
if pad_val is not None and pad_val.dtype in (np.float32, np.float16):
return float(pad_val.item()) == 0

return False

def quantize(self):
assert self.node.op_type == "Pad"

for input_name in self.node.input:
if input_name:
self.quantizer.quantize_activation_tensor(input_name)

if not self.disable_qdq_for_node_output:
if self._should_quantize_output_same_as_input():
self.quantizer.quantize_output_same_as_input(self.node.output[0], self.node.input[0], self.node.name)
else:
self.quantizer.quantize_activation_tensor(self.node.output[0])
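
A small sketch of the reasoning above (illustrative only; the uint8 range and scale/zero-point below are hypothetical, not taken from this PR): because the calibrated floating-point range always includes 0.0, a constant pad value of 0 quantizes exactly to the zero point, so sharing the input's scale/zero-point introduces no error at the padded positions.

import numpy as np

# Hypothetical calibrated range for the Pad input; calibration keeps 0.0 inside it.
rmin, rmax = -1.0, 3.0
scale = (rmax - rmin) / 255.0
zero_point = int(round(-rmin / scale))

def quantize_u8(x):
    return np.clip(np.round(x / scale) + zero_point, 0, 255).astype(np.uint8)

def dequantize(q):
    return (q.astype(np.float32) - zero_point) * scale

q_zero = quantize_u8(np.float32(0.0))
assert q_zero == zero_point                    # 0.0 maps exactly to the zero point...
assert dequantize(q_zero) == np.float32(0.0)   # ...and round-trips with no error.
# 'reflect', 'edge', and 'wrap' reuse values already present in the input,
# so those padded values are exactly representable under the same qparams too.
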
4 changes: 3 additions & 1 deletion onnxruntime/python/tools/quantization/registry.py
@@ -14,7 +14,7 @@
from .operators.matmul import MatMulInteger, QDQMatMul, QLinearMatMul
from .operators.maxpool import QDQMaxPool, QMaxPool
from .operators.norm import QDQNormalization
from .operators.pad import QPad
from .operators.pad import QDQPad, QPad
from .operators.pooling import QLinearPool
from .operators.qdq_base_operator import QDQOperatorBase
from .operators.resize import QDQResize, QResize
@@ -76,6 +76,8 @@
"Resize": QDQResize,
"MaxPool": QDQMaxPool,
"AveragePool": QDQDirect8BitOp,
"Slice": QDQDirect8BitOp,
"Pad": QDQPad,
"MatMul": QDQMatMul,
"Split": QDQSplit,
"Gather": QDQGather,
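
With these registrations, Pad is quantized by the new QDQPad class and Slice by QDQDirect8BitOp, which propagates the input's quantization parameters directly to the output. A rough sketch of how such a registry is typically consulted (the helper below is an assumption for illustration, not code from this PR):

# Hypothetical dispatch; the real quantizer's lookup may differ.
def create_op_quantizer(onnx_quantizer, node):
    op_cls = QDQRegistry.get(node.op_type, QDQOperatorBase)  # e.g., "Pad" -> QDQPad
    return op_cls(onnx_quantizer, node)
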
172 changes: 172 additions & 0 deletions onnxruntime/test/python/quantization/test_op_pad.py
@@ -4,8 +4,11 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations

import itertools
import os
import tempfile
import unittest

import numpy as np
@@ -519,5 +522,174 @@ def test_pad_with_empty_string_input_name(self):
        self.assertNotEqual(name, "_quantized")


class TestQDQPad(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.qdq.pad_")

        # Note: swap with the commented line below to inspect the models in the local test dir.
        cls._tmp_dir_path = cls._tmp_model_dir.name
        # cls._tmp_dir_path = "."

    @classmethod
    def tearDownClass(cls):
        cls._tmp_model_dir.cleanup()

    def build_pad_model(
        self,
        mode: str,
        constant_value: float | None = None,
        opset: int = 21,
        float_type: onnx.TensorProto.DataType = onnx.TensorProto.FLOAT,
    ) -> onnx.ModelProto:
        input_0 = onnx.helper.make_tensor_value_info("input_0", float_type, (3, 2))
        output_0 = onnx.helper.make_tensor_value_info("output_0", float_type, (3, 4))

        initializers = []
        pad_input_names = ["input_0"]
        attrs = {"mode": mode}

        pads_data = np.array([0, 2, 0, 0], dtype=np.int64)  # Pad two vals at beginning of axis 1.
        if opset >= 11:
            initializers.append(onnx.numpy_helper.from_array(pads_data, "pads"))
            pad_input_names.append("pads")
        else:
            attrs["pads"] = pads_data.tolist()

        if mode == "constant" and constant_value is not None:
            if opset >= 11:
                initializers.append(onnx.helper.make_tensor("constant_value", float_type, [], [constant_value]))
                pad_input_names.append("constant_value")
            else:
                attrs["value"] = float(constant_value)

        pad_node = onnx.helper.make_node("Pad", pad_input_names, ["output_0"], name="Pad0", **attrs)

        graph = onnx.helper.make_graph(
            [pad_node],
            "PadFloat",
            [input_0],
            [output_0],
            initializer=initializers,
        )
        opset_imports = [onnx.helper.make_opsetid("", opset)]
        model = onnx.helper.make_model(graph, opset_imports=opset_imports)
        model = onnx.shape_inference.infer_shapes(model)
        onnx.checker.check_model(model, True)
        return model

    def test_qdq_pad_qparams(self):
        """
        Test that QDQ Pad has equal scale/zero-point for its input and output for certain configurations.
        """
        test_configs = [
            # Opset 21
            ("constant", None, 21, onnx.TensorProto.FLOAT),
            ("constant", None, 21, onnx.TensorProto.FLOAT16),
            ("constant", 0, 21, onnx.TensorProto.FLOAT),
            ("constant", 0, 21, onnx.TensorProto.FLOAT16),
            ("constant", 10.0, 21, onnx.TensorProto.FLOAT),
            ("constant", 10.0, 21, onnx.TensorProto.FLOAT16),
            ("reflect", None, 21, onnx.TensorProto.FLOAT),
            ("reflect", None, 21, onnx.TensorProto.FLOAT16),
            ("edge", None, 21, onnx.TensorProto.FLOAT),
            ("edge", None, 21, onnx.TensorProto.FLOAT16),
            ("wrap", None, 21, onnx.TensorProto.FLOAT),
            ("wrap", None, 21, onnx.TensorProto.FLOAT16),
            # A model with opset 10 will use Pad of opset 2, which uses attributes instead of inputs.
            # Opset 10 Q/DQ ops don't support float16.
            ("constant", None, 10, onnx.TensorProto.FLOAT),
            ("constant", 0, 10, onnx.TensorProto.FLOAT),
            ("constant", 10.0, 10, onnx.TensorProto.FLOAT),
            ("reflect", None, 10, onnx.TensorProto.FLOAT),
            ("edge", None, 10, onnx.TensorProto.FLOAT),
        ]

        for pad_mode, constant_value, opset, float_type in test_configs:
            with self.subTest(pad_mode=pad_mode, constant_value=constant_value, opset=opset, float_type=float_type):
                label = f"_{pad_mode}_{constant_value}"
                float_model_path = os.path.join(self._tmp_dir_path, f"pad{label}.float.onnx")
                qdq_model_path = os.path.join(self._tmp_dir_path, f"pad{label}.qdq.onnx")

                float_model = self.build_pad_model(pad_mode, constant_value, opset=opset, float_type=float_type)
                onnx.save_model(float_model, float_model_path)

                # Create a data reader.
                np_dtype = onnx.helper.tensor_dtype_to_np_dtype(float_type)
                input_data_list = [
                    {"input_0": np.array([[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]], dtype=np_dtype)},
                    {"input_0": np.array([[2.3, 3.4], [4.5, 5.7], [1.0, 1.2]], dtype=np_dtype)},
                ]
                data_reader = TestDataFeeds(input_data_list)

                # Quantize the model to QDQ.
                quantize_static(
                    float_model_path,
                    qdq_model_path,
                    data_reader,
                    quant_format=QuantFormat.QDQ,
                    activation_type=QuantType.QUInt8,
                    weight_type=QuantType.QInt8,
                )

                expected_op_counts = {"DequantizeLinear": 2, "QuantizeLinear": 2, "Pad": 1}
                if constant_value is not None and opset >= 11:
                    expected_op_counts["DequantizeLinear"] += 1  # The constant padding value is quantized.
                check_op_type_count(self, qdq_model_path, **expected_op_counts)

                if pad_mode != "reflect":
                    # Do not check model correctness for 'reflect' mode because the ONNX Runtime implementation
                    # does not match the ONNX reference implementation. See the following issue:
                    # https://github.com/microsoft/onnxruntime/issues/20801
                    data_reader.rewind()
                    check_model_correctness(self, float_model_path, qdq_model_path, data_reader.get_next())

                qdq_model = onnx.load_model(qdq_model_path)
                quant_output_same_as_input = False

                if pad_mode in ("reflect", "edge", "wrap"):
                    quant_output_same_as_input = True

                if pad_mode == "constant" and constant_value in (None, 0):
                    quant_output_same_as_input = True

                consumers = {}
                producers = {}
                pad_node = None

                # Build dictionaries that map a tensor name to its consumer and producer nodes.
                for node in qdq_model.graph.node:
                    if node.op_type == "Pad":
                        pad_node = node

                    for input_name in node.input:
                        if input_name:
                            if input_name not in consumers:
                                consumers[input_name] = []

                            consumers[input_name].append(node)

                    for output_name in node.output:
                        producers[output_name] = node

                self.assertEqual(pad_node.op_type, "Pad")

                input_dq_node = producers.get(pad_node.input[0], None)
                self.assertNotEqual(input_dq_node, None)
                self.assertEqual(input_dq_node.op_type, "DequantizeLinear")

                output_q_node = consumers.get(pad_node.output[0], [None])[0]
                self.assertNotEqual(output_q_node, None)
                self.assertEqual(output_q_node.op_type, "QuantizeLinear")

                # Check that the Pad's input DQ uses the same scale/zp as the Pad's output Q.
                if quant_output_same_as_input:
                    self.assertEqual(input_dq_node.input[1], output_q_node.input[1])  # Same scale
                    self.assertEqual(input_dq_node.input[2], output_q_node.input[2])  # Same zero-point
                else:
                    self.assertNotEqual(input_dq_node.input[1], output_q_node.input[1])
                    self.assertNotEqual(input_dq_node.input[2], output_q_node.input[2])


if __name__ == "__main__":
    unittest.main()
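
To run just the new tests locally, an invocation along these lines should work (assuming pytest is installed; the plain unittest runner works as well):

python -m pytest onnxruntime/test/python/quantization/test_op_pad.py -k TestQDQPad -v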