From ad8d48797fc4554ee3afa13d8be5dc5bbf1b45af Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 31 Oct 2024 06:22:40 -0700 Subject: [PATCH 1/8] [Quant tool] Update QDQ Pad and QDQ Slice to quantize output the same as the input. Fix bug when softmax is excluded from QDQ quantization --- .../tools/quantization/base_quantizer.py | 2 + .../tools/quantization/operators/pad.py | 72 ++++++++ .../python/tools/quantization/registry.py | 4 +- .../test/python/quantization/test_op_pad.py | 149 ++++++++++++++++ .../test/python/quantization/test_op_slice.py | 162 ++++++++++++++++++ 5 files changed, 388 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/test/python/quantization/test_op_slice.py diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py index b20af5137d206..97e68733f38eb 100644 --- a/onnxruntime/python/tools/quantization/base_quantizer.py +++ b/onnxruntime/python/tools/quantization/base_quantizer.py @@ -529,4 +529,6 @@ def adjust_tensor_ranges(self): self.tensors_range[node.input[0]] = td # Adjust Softmax to range from 0.0 to 1.0 elif node.op_type == "Softmax": + if not self.should_quantize_node(node): + continue self.tensors_range[node.output[0]] = TensorData(lowest=np.float32(0.0), highest=np.float32(1.0)) diff --git a/onnxruntime/python/tools/quantization/operators/pad.py b/onnxruntime/python/tools/quantization/operators/pad.py index 5f3c1231e62d6..5a9152d930aae 100644 --- a/onnxruntime/python/tools/quantization/operators/pad.py +++ b/onnxruntime/python/tools/quantization/operators/pad.py @@ -1,3 +1,12 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from __future__ import annotations + +from typing import Any + +import numpy as np import onnx from ..quant_utils import ( @@ -8,6 +17,7 @@ quantize_nparray, ) from .base_operator import QuantOperatorBase +from .qdq_base_operator import QDQOperatorBase class QPad(QuantOperatorBase): @@ -98,3 +108,65 @@ def quantize(self): node.input[0] = quantized_input_value.q_name node.output[0] = quantized_output_value.q_name self.quantizer.new_nodes += [node] + + +class QDQPad(QDQOperatorBase): + def __init__(self, onnx_quantizer, onnx_node): + super().__init__(onnx_quantizer, onnx_node) + + def _get_pad_const_val(self, attrs_dict: dict[str, Any]) -> np.ndarray | None: + """ + Returns the Pad's constant padding value. Returns `None` if the padding value is + not constant (i.e., comes from a dynamic input). 
+ """ + const_val = None + onnx_tensor_type = self.quantizer.model.get_tensor_type(self.node.input[0]) + if onnx_tensor_type is None: + return None + + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(onnx_tensor_type.elem_type) + if self.quantizer.opset_version <= 2: + const_val = np.array(attrs_dict.get("value", 0), dtype=np_dtype) + elif len(self.node.input) >= 3 and self.node.input[2]: + const_val = self.quantizer.model.get_constant_value(self.node.input[2]) + else: + const_val = np.array(0, dtype=np_dtype) + + return const_val + + def _should_quantize_output_same_as_input(self) -> bool: + """ + Returns true if Pad's output should use the same quantization parameters as input[0] + """ + attrs_dict = {} + for attribute in self.node.attribute: + kv = attribute_to_kwarg(attribute) + attrs_dict.update(kv) + + pad_mode = attrs_dict.get("mode", b"constant") + if pad_mode in (b"reflect", b"edge", b"wrap"): + # These modes pad the output with a value that already exists in the input. + # So, we can quantize the output the same as the input. + return True + + # For 'constant' mode, if padding with 0, we can also quantize the output the same as the input + # because our quantization floating-point range always includes 0. + if pad_mode == b"constant": + pad_val = self._get_pad_const_val(attrs_dict) + if pad_val is not None and pad_val.dtype in (np.float32, np.float16): + return float(pad_val.item()) == 0 + + return False + + def quantize(self): + assert self.node.op_type == "Pad" + + for input_name in self.node.input: + if input_name: + self.quantizer.quantize_activation_tensor(input_name) + + if not self.disable_qdq_for_node_output: + if self._should_quantize_output_same_as_input(): + self.quantizer.quantize_output_same_as_input(self.node.output[0], self.node.input[0], self.node.name) + else: + self.quantizer.quantize_activation_tensor(self.node.output[0]) diff --git a/onnxruntime/python/tools/quantization/registry.py b/onnxruntime/python/tools/quantization/registry.py index 160b056e1de17..fbeae39c39d21 100644 --- a/onnxruntime/python/tools/quantization/registry.py +++ b/onnxruntime/python/tools/quantization/registry.py @@ -14,7 +14,7 @@ from .operators.matmul import MatMulInteger, QDQMatMul, QLinearMatMul from .operators.maxpool import QDQMaxPool, QMaxPool from .operators.norm import QDQNormalization -from .operators.pad import QPad +from .operators.pad import QDQPad, QPad from .operators.pooling import QLinearPool from .operators.qdq_base_operator import QDQOperatorBase from .operators.resize import QDQResize, QResize @@ -76,6 +76,8 @@ "Resize": QDQResize, "MaxPool": QDQMaxPool, "AveragePool": QDQDirect8BitOp, + "Slice": QDQDirect8BitOp, + "Pad": QDQPad, "MatMul": QDQMatMul, "Split": QDQSplit, "Gather": QDQGather, diff --git a/onnxruntime/test/python/quantization/test_op_pad.py b/onnxruntime/test/python/quantization/test_op_pad.py index 291bf42405d58..72d895dae8435 100644 --- a/onnxruntime/test/python/quantization/test_op_pad.py +++ b/onnxruntime/test/python/quantization/test_op_pad.py @@ -4,8 +4,11 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. 
# -------------------------------------------------------------------------- +from __future__ import annotations import itertools +import os +import tempfile import unittest import numpy as np @@ -519,5 +522,151 @@ def test_pad_with_empty_string_input_name(self): self.assertNotEqual(name, "_quantized") +class TestQDQPad(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.qdq.pad_") + + # Note: swap with the commented line if you want to see the models in local test dir. + cls._tmp_dir_path = cls._tmp_model_dir.name + # cls._tmp_dir_path = "." + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def build_pad_model( + self, + mode: str, + constant_value: float | None = None, + ) -> onnx.ModelProto: + input_0 = onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, (3, 2)) + output_0 = onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, (3, 4)) + + initializers = [] + pad_input_names = ["input_0"] + + pads_data = np.array([0, 2, 0, 0], dtype=np.int64) # Pad two vals at beginning of axis 1. + initializers.append(onnx.numpy_helper.from_array(pads_data, "pads")) + pad_input_names.append("pads") + + if mode == "constant" and constant_value is not None: + initializers.append(onnx.helper.make_tensor("constant_value", onnx.TensorProto.FLOAT, [], [constant_value])) + pad_input_names.append("constant_value") + + pad_node = onnx.helper.make_node("Pad", pad_input_names, ["output_0"], name="Pad0", mode=mode) + + graph = onnx.helper.make_graph( + [pad_node], + "PadFloat", + [input_0], + [output_0], + initializer=initializers, + ) + opset_imports = [onnx.helper.make_opsetid("", 21)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + onnx.checker.check_model(model, True) + return model + + def test_qdq_pad_qparams(self): + """ + Test that QDQ Pad has equal scale/zero-point for its input and output for certain configurations. + """ + test_configs = [ + ("constant", None), + ("constant", 0), + ("constant", 10.0), + ("reflect", None), + ("edge", None), + ("wrap", None), + ] + + for pad_mode, constant_value in test_configs: + with self.subTest(pad_mode=pad_mode, constant_value=constant_value): + label = f"_{pad_mode}_{constant_value}" + float_model_path = os.path.join(self._tmp_dir_path, f"pad{label}.float.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, f"pad{label}.qdq.onnx") + + float_model = self.build_pad_model(pad_mode, constant_value) + onnx.save_model(float_model, float_model_path) + + # Create a data reader + input_data_list = [ + {"input_0": np.array([[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]], dtype=np.float32)}, + {"input_0": np.array([[2.3, 3.4], [4.5, 5.7], [1.0, 1.2]], dtype=np.float32)}, + ] + data_reader = TestDataFeeds(input_data_list) + + # quantize model to QDQ + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QInt8, + extra_options={"ForceQuantizeNoInputCheck": True}, + ) + + expected_op_counts = {"DequantizeLinear": 2, "QuantizeLinear": 2, "Pad": 1} + if constant_value is not None: + expected_op_counts["DequantizeLinear"] += 1 # The constant padding value is quantized. 
+ check_op_type_count(self, qdq_model_path, **expected_op_counts) + + if pad_mode != "reflect": + # Do not check model correctness for 'reflect' mode because ONNX Runtime implementation does + # not match the ONNX reference implementation. See the following issue: + # https://github.com/microsoft/onnxruntime/issues/20801 + data_reader.rewind() + check_model_correctness(self, float_model_path, qdq_model_path, data_reader.get_next()) + + qdq_model = onnx.load_model(qdq_model_path) + quant_output_same_as_input = False + + if pad_mode in ("reflect", "edge", "wrap"): + quant_output_same_as_input = True + + if pad_mode == "constant" and constant_value in (None, 0): + quant_output_same_as_input = True + + consumers = {} + producers = {} + pad_node = None + + # Build dictionaries that map a tensor name to the consumer or producer nodes. + for node in qdq_model.graph.node: + if node.op_type == "Pad": + pad_node = node + + for input_name in node.input: + if input_name: + if input_name not in consumers: + consumers[input_name] = [] + + consumers[input_name].append(node) + + for output_name in node.output: + producers[output_name] = node + + self.assertEqual(pad_node.op_type, "Pad") + + input_dq_node = producers.get(pad_node.input[0], None) + self.assertNotEqual(input_dq_node, None) + self.assertEqual(input_dq_node.op_type, "DequantizeLinear") + + output_q_node = consumers.get(pad_node.output[0], [None])[0] + self.assertNotEqual(output_q_node, None) + self.assertEqual(output_q_node.op_type, "QuantizeLinear") + + # Check that the Pad's input DQ uses the same scale/zp as the Pad's output Q. + if quant_output_same_as_input: + self.assertEqual(input_dq_node.input[1], output_q_node.input[1]) # Same scale + self.assertEqual(input_dq_node.input[2], output_q_node.input[2]) # Same zero-point + else: + self.assertNotEqual(input_dq_node.input[1], output_q_node.input[1]) + self.assertNotEqual(input_dq_node.input[2], output_q_node.input[2]) + + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/quantization/test_op_slice.py b/onnxruntime/test/python/quantization/test_op_slice.py new file mode 100644 index 0000000000000..5769be27a9514 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_op_slice.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from __future__ import annotations + +import os +import tempfile +import unittest + +import numpy as np +import onnx +from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count + +from onnxruntime.quantization import QuantFormat, QuantType, quantize_static + + +class TestQDQSlice(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.qdq.slice_") + + # Note: swap with the commented line if you want to see the models in local test dir. + cls._tmp_dir_path = cls._tmp_model_dir.name + # cls._tmp_dir_path = "." 
+ + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def build_slice_model( + self, + input_shape: list[int], + input_tensor_type: onnx.TensorProto.DataType, + starts: list[int], + ends: list[int], + axes: list[int] | None = None, + steps: list[int] | None = None, + ) -> onnx.ModelProto: + """ + Returns an onnx.ModelProto with a single Slice operator. + """ + input_0 = onnx.helper.make_tensor_value_info("input_0", input_tensor_type, input_shape) + output_0 = onnx.helper.make_tensor_value_info("output_0", input_tensor_type, None) + + initializers = [ + onnx.numpy_helper.from_array(np.array(starts, dtype=np.int64), "starts"), + onnx.numpy_helper.from_array(np.array(ends, dtype=np.int64), "ends"), + ] + slice_input_names = ["input_0", "starts", "ends"] + + if axes: + initializers.append(onnx.numpy_helper.from_array(np.array(axes, dtype=np.int64), "axes")) + slice_input_names.append("axes") + + if steps: + if not axes: + slice_input_names.append("") # Empty axes input. + initializers.append(onnx.numpy_helper.from_array(np.array(steps, dtype=np.int64), "steps")) + slice_input_names.append("steps") + + slice_node = onnx.helper.make_node("Slice", slice_input_names, ["output_0"], name="Slice0") + + graph = onnx.helper.make_graph( + [slice_node], + "SliceGraph", + [input_0], + [output_0], + initializer=initializers, + ) + opset_imports = [onnx.helper.make_opsetid("", 21)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + onnx.checker.check_model(model, True) + return model + + def test_qdq_slice_qparams(self): + """ + Test that QDQ Slice has equal scale/zero-point for its input and output. + """ + test_configs = [onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT16] + + for onnx_tensor_type in test_configs: + with self.subTest(onnx_tensor_type=onnx_tensor_type): + label = f"{onnx.TensorProto.DataType.Name(onnx_tensor_type)}" + float_model_path = os.path.join(self._tmp_dir_path, f"slice.{label}.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, f"slice.{label}.qdq.onnx") + + input_shape = [2, 4] + float_model = self.build_slice_model( + input_shape=input_shape, + input_tensor_type=onnx_tensor_type, + starts=[1, 0], + ends=[2, 3], + axes=None, + steps=[1, 2], + ) + onnx.save_model(float_model, float_model_path) + + # Create a data reader + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(onnx_tensor_type) + input_data_list = [ + {"input_0": np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], dtype=np_dtype)}, + {"input_0": np.array([[-1.0, -2.0, -3.0, -4.0], [-5.0, -6.0, -7.0, -8.0]], dtype=np_dtype)}, + ] + data_reader = TestDataFeeds(input_data_list) + + # quantize model to QDQ + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QInt8, + extra_options={"ForceQuantizeNoInputCheck": True}, + ) + expected_op_counts = {"DequantizeLinear": 2, "QuantizeLinear": 2, "Slice": 1} + check_op_type_count(self, qdq_model_path, **expected_op_counts) + + data_reader.rewind() + check_model_correctness(self, float_model_path, qdq_model_path, data_reader.get_next()) + + qdq_model = onnx.load_model(qdq_model_path) + consumers = {} + producers = {} + slice_node = None + + # Build dictionaries that map a tensor name to the consumer or producer nodes. 
+ for node in qdq_model.graph.node: + if node.op_type == "Slice": + slice_node = node + + for input_name in node.input: + if input_name: + if input_name not in consumers: + consumers[input_name] = [] + + consumers[input_name].append(node) + + for output_name in node.output: + producers[output_name] = node + + self.assertEqual(slice_node.op_type, "Slice") + + input_dq_node = producers.get(slice_node.input[0], None) + self.assertNotEqual(input_dq_node, None) + self.assertEqual(input_dq_node.op_type, "DequantizeLinear") + + output_q_node = consumers.get(slice_node.output[0], [None])[0] + self.assertNotEqual(output_q_node, None) + self.assertEqual(output_q_node.op_type, "QuantizeLinear") + + # Check that the Slice's input DQ uses the same scale/zp as the Slice's output Q. + self.assertEqual(input_dq_node.input[1], output_q_node.input[1]) + self.assertEqual(input_dq_node.input[2], output_q_node.input[2]) + + +if __name__ == "__main__": + unittest.main() From 4e24125f83682dccc137b406315d90e55fddeb0e Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Mon, 4 Nov 2024 16:33:35 -0800 Subject: [PATCH 2/8] Properly test and handle Pad opset 2 --- .../tools/quantization/operators/pad.py | 2 +- .../test/python/quantization/test_op_pad.py | 49 +++++++++++++------ 2 files changed, 34 insertions(+), 17 deletions(-) diff --git a/onnxruntime/python/tools/quantization/operators/pad.py b/onnxruntime/python/tools/quantization/operators/pad.py index 5a9152d930aae..b3e9ddb5e6278 100644 --- a/onnxruntime/python/tools/quantization/operators/pad.py +++ b/onnxruntime/python/tools/quantization/operators/pad.py @@ -125,7 +125,7 @@ def _get_pad_const_val(self, attrs_dict: dict[str, Any]) -> np.ndarray | None: return None np_dtype = onnx.helper.tensor_dtype_to_np_dtype(onnx_tensor_type.elem_type) - if self.quantizer.opset_version <= 2: + if self.quantizer.opset_version < 11: const_val = np.array(attrs_dict.get("value", 0), dtype=np_dtype) elif len(self.node.input) >= 3 and self.node.input[2]: const_val = self.quantizer.model.get_constant_value(self.node.input[2]) diff --git a/onnxruntime/test/python/quantization/test_op_pad.py b/onnxruntime/test/python/quantization/test_op_pad.py index 72d895dae8435..5bc9885e39bbb 100644 --- a/onnxruntime/test/python/quantization/test_op_pad.py +++ b/onnxruntime/test/python/quantization/test_op_pad.py @@ -539,22 +539,32 @@ def build_pad_model( self, mode: str, constant_value: float | None = None, + opset: int = 21, ) -> onnx.ModelProto: input_0 = onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, (3, 2)) output_0 = onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, (3, 4)) initializers = [] pad_input_names = ["input_0"] + attrs = {"mode": mode} pads_data = np.array([0, 2, 0, 0], dtype=np.int64) # Pad two vals at beginning of axis 1. 
- initializers.append(onnx.numpy_helper.from_array(pads_data, "pads")) - pad_input_names.append("pads") + if opset >= 11: + initializers.append(onnx.numpy_helper.from_array(pads_data, "pads")) + pad_input_names.append("pads") + else: + attrs["pads"] = pads_data.tolist() if mode == "constant" and constant_value is not None: - initializers.append(onnx.helper.make_tensor("constant_value", onnx.TensorProto.FLOAT, [], [constant_value])) - pad_input_names.append("constant_value") + if opset >= 11: + initializers.append( + onnx.helper.make_tensor("constant_value", onnx.TensorProto.FLOAT, [], [constant_value]) + ) + pad_input_names.append("constant_value") + else: + attrs["value"] = float(constant_value) - pad_node = onnx.helper.make_node("Pad", pad_input_names, ["output_0"], name="Pad0", mode=mode) + pad_node = onnx.helper.make_node("Pad", pad_input_names, ["output_0"], name="Pad0", **attrs) graph = onnx.helper.make_graph( [pad_node], @@ -563,7 +573,7 @@ def build_pad_model( [output_0], initializer=initializers, ) - opset_imports = [onnx.helper.make_opsetid("", 21)] + opset_imports = [onnx.helper.make_opsetid("", opset)] model = onnx.helper.make_model(graph, opset_imports=opset_imports) model = onnx.shape_inference.infer_shapes(model) onnx.checker.check_model(model, True) @@ -574,21 +584,28 @@ def test_qdq_pad_qparams(self): Test that QDQ Pad has equal scale/zero-point for its input and output for certain configurations. """ test_configs = [ - ("constant", None), - ("constant", 0), - ("constant", 10.0), - ("reflect", None), - ("edge", None), - ("wrap", None), + # Opset 21 + ("constant", None, 21), + ("constant", 0, 21), + ("constant", 10.0, 21), + ("reflect", None, 21), + ("edge", None, 21), + ("wrap", None, 21), + # Model with opset 10 will use pad of opset 2, which uses attributes instead of inputs. + ("constant", None, 10), + ("constant", 0, 10), + ("constant", 10.0, 10), + ("reflect", None, 10), + ("edge", None, 10), ] - for pad_mode, constant_value in test_configs: - with self.subTest(pad_mode=pad_mode, constant_value=constant_value): + for pad_mode, constant_value, opset in test_configs: + with self.subTest(pad_mode=pad_mode, constant_value=constant_value, opset=opset): label = f"_{pad_mode}_{constant_value}" float_model_path = os.path.join(self._tmp_dir_path, f"pad{label}.float.onnx") qdq_model_path = os.path.join(self._tmp_dir_path, f"pad{label}.qdq.onnx") - float_model = self.build_pad_model(pad_mode, constant_value) + float_model = self.build_pad_model(pad_mode, constant_value, opset=opset) onnx.save_model(float_model, float_model_path) # Create a data reader @@ -610,7 +627,7 @@ def test_qdq_pad_qparams(self): ) expected_op_counts = {"DequantizeLinear": 2, "QuantizeLinear": 2, "Pad": 1} - if constant_value is not None: + if constant_value is not None and opset >= 11: expected_op_counts["DequantizeLinear"] += 1 # The constant padding value is quantized. 
                check_op_type_count(self, qdq_model_path, **expected_op_counts)

From 72766cda84374e4f99d7918ee51766185e9375d2 Mon Sep 17 00:00:00 2001
From: adrianlizarraga
Date: Mon, 4 Nov 2024 16:55:11 -0800
Subject: [PATCH 3/8] Add unittest for softmax bug fix (when softmax is excluded)

---
 .../python/quantization/test_op_softmax.py | 34 +++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/onnxruntime/test/python/quantization/test_op_softmax.py b/onnxruntime/test/python/quantization/test_op_softmax.py
index 3416198450137..e5bc6288c91e2 100644
--- a/onnxruntime/test/python/quantization/test_op_softmax.py
+++ b/onnxruntime/test/python/quantization/test_op_softmax.py
@@ -213,6 +213,40 @@ def test_quantize_softmax(self):
         self.quantize_softmax_test_qop(QuantType.QUInt8, QuantType.QUInt8)
         self.quantize_softmax_test_qdq(QuantType.QUInt8, QuantType.QUInt8)

+    def test_bug_fix_exclude_softmax(self):
+        """
+        Test the fix for a bug that occurs when Softmax is excluded from quantization,
+        but the quantization tool still tries to assign it a tensor range of [0.0, 1.0].
+        """
+        np.random.seed(1)
+        model_fp32_path = "softmax_fp32.onnx"
+        model_qdq_path = "softmax_bug_exclude_softmax.qdq.onnx"
+        self.construct_model_conv_softmax(
+            model_fp32_path,
+            [1, 2, 26, 42],
+            [3, 2, 3, 3],
+            [1, 3, 24, 40],
+            {"axis": -2},
+            [1, 3, 24, 40],
+            add_ms_domain_opset=False,
+        )
+        data_reader = self.input_feeds(1, {"input": [1, 2, 26, 42]})
+        data_reader.rewind()
+
+        # The bug caused an exception during quantization.
+        quantize_static(
+            model_fp32_path,
+            model_qdq_path,
+            data_reader,
+            quant_format=QuantFormat.QDQ,
+            activation_type=QuantType.QUInt8,
+            weight_type=QuantType.QInt8,
+            nodes_to_exclude=["Softmax"],
+        )
+
+        qdq_model = onnx.load(Path(model_qdq_path))
+        self.assertIn("Softmax", {node.op_type for node in qdq_model.graph.node})
+
     def test_quantize_softmax_s8s8(self):
         self.quantize_softmax_test_qop(
             QuantType.QInt8,

From c63df678afbc2de3959a488a8b45eb7d74351473 Mon Sep 17 00:00:00 2001
From: adrianlizarraga
Date: Tue, 5 Nov 2024 13:42:07 -0800
Subject: [PATCH 4/8] Add float16 testing for pad unit test

---
 .../test/python/quantization/test_op_pad.py | 49 +++++++++++--------
 1 file changed, 28 insertions(+), 21 deletions(-)

diff --git a/onnxruntime/test/python/quantization/test_op_pad.py b/onnxruntime/test/python/quantization/test_op_pad.py
index 5bc9885e39bbb..2935b40477d83 100644
--- a/onnxruntime/test/python/quantization/test_op_pad.py
+++ b/onnxruntime/test/python/quantization/test_op_pad.py
@@ -540,9 +540,10 @@ def build_pad_model(
         mode: str,
         constant_value: float | None = None,
         opset: int = 21,
+        float_type: onnx.TensorProto.DataType = onnx.TensorProto.FLOAT,
     ) -> onnx.ModelProto:
-        input_0 = onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, (3, 2))
-        output_0 = onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, (3, 4))
+        input_0 = onnx.helper.make_tensor_value_info("input_0", float_type, (3, 2))
+        output_0 = onnx.helper.make_tensor_value_info("output_0", float_type, (3, 4))

         initializers = []
         pad_input_names = ["input_0"]
@@ -557,9 +558,7 @@ def build_pad_model(

         if mode == "constant" and constant_value is not None:
             if opset >= 11:
-                initializers.append(
-                    onnx.helper.make_tensor("constant_value", onnx.TensorProto.FLOAT, [], [constant_value])
-                )
+                initializers.append(onnx.helper.make_tensor("constant_value", float_type, [], [constant_value]))
                 pad_input_names.append("constant_value")
             else:
                 attrs["value"] = float(constant_value)
@@ -585,33 +584,41 @@ def test_qdq_pad_qparams(self):
         """
         test_configs = [
             # Opset 21
-            ("constant", None, 21),
-            ("constant", 0, 21),
-            ("constant", 10.0, 21),
-            ("reflect", None, 21),
-            ("edge", None, 21),
-            ("wrap", None, 21),
+            ("constant", None, 21, onnx.TensorProto.FLOAT),
+            ("constant", None, 21, onnx.TensorProto.FLOAT16),
+            ("constant", 0, 21, onnx.TensorProto.FLOAT),
+            ("constant", 0, 21, onnx.TensorProto.FLOAT16),
+            ("constant", 10.0, 21, onnx.TensorProto.FLOAT),
+            ("constant", 10.0, 21, onnx.TensorProto.FLOAT16),
+            ("reflect", None, 21, onnx.TensorProto.FLOAT),
+            ("reflect", None, 21, onnx.TensorProto.FLOAT16),
+            ("edge", None, 21, onnx.TensorProto.FLOAT),
+            ("edge", None, 21, onnx.TensorProto.FLOAT16),
+            ("wrap", None, 21, onnx.TensorProto.FLOAT),
+            ("wrap", None, 21, onnx.TensorProto.FLOAT16),
             # Model with opset 10 will use pad of opset 2, which uses attributes instead of inputs.
-            ("constant", None, 10),
-            ("constant", 0, 10),
-            ("constant", 10.0, 10),
-            ("reflect", None, 10),
-            ("edge", None, 10),
+            # Opset 10 Q/DQ ops don't support float16.
+            ("constant", None, 10, onnx.TensorProto.FLOAT),
+            ("constant", 0, 10, onnx.TensorProto.FLOAT),
+            ("constant", 10.0, 10, onnx.TensorProto.FLOAT),
+            ("reflect", None, 10, onnx.TensorProto.FLOAT),
+            ("edge", None, 10, onnx.TensorProto.FLOAT),
         ]

-        for pad_mode, constant_value, opset in test_configs:
-            with self.subTest(pad_mode=pad_mode, constant_value=constant_value, opset=opset):
+        for pad_mode, constant_value, opset, float_type in test_configs:
+            with self.subTest(pad_mode=pad_mode, constant_value=constant_value, opset=opset, float_type=float_type):
                 label = f"_{pad_mode}_{constant_value}"
                 float_model_path = os.path.join(self._tmp_dir_path, f"pad{label}.float.onnx")
                 qdq_model_path = os.path.join(self._tmp_dir_path, f"pad{label}.qdq.onnx")

-                float_model = self.build_pad_model(pad_mode, constant_value, opset=opset)
+                float_model = self.build_pad_model(pad_mode, constant_value, opset=opset, float_type=float_type)
                 onnx.save_model(float_model, float_model_path)

                 # Create a data reader
+                np_dtype = onnx.helper.tensor_dtype_to_np_dtype(float_type)
                 input_data_list = [
-                    {"input_0": np.array([[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]], dtype=np.float32)},
-                    {"input_0": np.array([[2.3, 3.4], [4.5, 5.7], [1.0, 1.2]], dtype=np.float32)},
+                    {"input_0": np.array([[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]], dtype=np_dtype)},
+                    {"input_0": np.array([[2.3, 3.4], [4.5, 5.7], [1.0, 1.2]], dtype=np_dtype)},
                 ]
                 data_reader = TestDataFeeds(input_data_list)

From 343b0da598b5f9297892b400cb0e4970cbe1c01b Mon Sep 17 00:00:00 2001
From: adrianlizarraga
Date: Tue, 5 Nov 2024 13:45:05 -0800
Subject: [PATCH 5/8] Remove unnecessary extra option

---
 onnxruntime/test/python/quantization/test_op_pad.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/onnxruntime/test/python/quantization/test_op_pad.py b/onnxruntime/test/python/quantization/test_op_pad.py
index 2935b40477d83..f0c8d41ef0d56 100644
--- a/onnxruntime/test/python/quantization/test_op_pad.py
+++ b/onnxruntime/test/python/quantization/test_op_pad.py
@@ -630,7 +630,6 @@ def test_qdq_pad_qparams(self):
                     quant_format=QuantFormat.QDQ,
                     activation_type=QuantType.QUInt8,
                     weight_type=QuantType.QInt8,
-                    extra_options={"ForceQuantizeNoInputCheck": True},
                 )

                 expected_op_counts = {"DequantizeLinear": 2, "QuantizeLinear": 2, "Pad": 1}

From 6016c95b6d9938ce9ea5b342ea20c850f82722b9 Mon Sep 17 00:00:00 2001
From: adrianlizarraga
Date: Tue, 5 Nov 2024 15:40:53 -0800
Subject: [PATCH 6/8] Refactor common unittest utility function

---
.../test/python/quantization/op_test_utils.py | 33 +++++++++++++++++++ .../test/python/quantization/test_op_pad.py | 31 +++++++---------- .../test/python/quantization/test_op_slice.py | 29 ++++++---------- 3 files changed, 54 insertions(+), 39 deletions(-) diff --git a/onnxruntime/test/python/quantization/op_test_utils.py b/onnxruntime/test/python/quantization/op_test_utils.py index cf7fc292ea86b..0106f2e056afe 100644 --- a/onnxruntime/test/python/quantization/op_test_utils.py +++ b/onnxruntime/test/python/quantization/op_test_utils.py @@ -1,3 +1,10 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from __future__ import annotations + import uuid from pathlib import Path @@ -661,3 +668,29 @@ def generate_random_initializer(initializer_name, tensor_shape, tensor_dtype, me tensor = np.random.normal(mean, dev, tensor_shape).astype(tensor_dtype) init = onnx.numpy_helper.from_array(tensor, initializer_name) return init + + +def get_tensor_consumers_and_producers( + model: onnx.ModelProto, +) -> tuple[dict[str, list[onnx.NodeProto]], dict[str, onnx.NodeProto]]: + """ + Returns a tuple containing the following python dictionaries: + - consumers: maps a tensor name to the list of nodes that have that tensor as an input. + - producers: maps a tensor name to the node that generates this tensor as an output. + """ + consumers: dict[str, list[onnx.NodeProto]] = {} + producers: dict[str, onnx.NodeProto] = {} + for node in model.graph.node: + # Iterate throught node's inputs to build the consumers dictionary. + for input_name in node.input: + if input_name: + if input_name not in consumers: + consumers[input_name] = [] + + consumers[input_name].append(node) + + # Iterate through node's outputs to build the producers dictionary. + for output_name in node.output: + producers[output_name] = node + + return (consumers, producers) diff --git a/onnxruntime/test/python/quantization/test_op_pad.py b/onnxruntime/test/python/quantization/test_op_pad.py index f0c8d41ef0d56..5267d94c619c3 100644 --- a/onnxruntime/test/python/quantization/test_op_pad.py +++ b/onnxruntime/test/python/quantization/test_op_pad.py @@ -14,7 +14,13 @@ import numpy as np import onnx from onnx import TensorProto, helper -from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type +from op_test_utils import ( + TestDataFeeds, + check_model_correctness, + check_op_type_count, + check_qtype_by_node_type, + get_tensor_consumers_and_producers, +) from onnxruntime.quantization import QuantFormat, QuantType, quantize_dynamic, quantize_static @@ -653,27 +659,12 @@ def test_qdq_pad_qparams(self): if pad_mode == "constant" and constant_value in (None, 0): quant_output_same_as_input = True - consumers = {} - producers = {} - pad_node = None - - # Build dictionaries that map a tensor name to the consumer or producer nodes. 
- for node in qdq_model.graph.node: - if node.op_type == "Pad": - pad_node = node - - for input_name in node.input: - if input_name: - if input_name not in consumers: - consumers[input_name] = [] - - consumers[input_name].append(node) - - for output_name in node.output: - producers[output_name] = node - + pad_node = next((node for node in qdq_model.graph.node if node.op_type == "Pad"), None) + self.assertNotEqual(pad_node, None) self.assertEqual(pad_node.op_type, "Pad") + # Get the parent and child nodes of the Pad and check that they are DQ/Q. + consumers, producers = get_tensor_consumers_and_producers(qdq_model) input_dq_node = producers.get(pad_node.input[0], None) self.assertNotEqual(input_dq_node, None) self.assertEqual(input_dq_node.op_type, "DequantizeLinear") diff --git a/onnxruntime/test/python/quantization/test_op_slice.py b/onnxruntime/test/python/quantization/test_op_slice.py index 5769be27a9514..bfb9fc6b46bbd 100644 --- a/onnxruntime/test/python/quantization/test_op_slice.py +++ b/onnxruntime/test/python/quantization/test_op_slice.py @@ -12,7 +12,12 @@ import numpy as np import onnx -from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count +from op_test_utils import ( + TestDataFeeds, + check_model_correctness, + check_op_type_count, + get_tensor_consumers_and_producers, +) from onnxruntime.quantization import QuantFormat, QuantType, quantize_static @@ -124,27 +129,13 @@ def test_qdq_slice_qparams(self): check_model_correctness(self, float_model_path, qdq_model_path, data_reader.get_next()) qdq_model = onnx.load_model(qdq_model_path) - consumers = {} - producers = {} - slice_node = None - - # Build dictionaries that map a tensor name to the consumer or producer nodes. - for node in qdq_model.graph.node: - if node.op_type == "Slice": - slice_node = node - - for input_name in node.input: - if input_name: - if input_name not in consumers: - consumers[input_name] = [] - - consumers[input_name].append(node) - - for output_name in node.output: - producers[output_name] = node + slice_node = next((node for node in qdq_model.graph.node if node.op_type == "Slice"), None) + self.assertNotEqual(slice_node, None) self.assertEqual(slice_node.op_type, "Slice") + # Get the parent and child nodes of the Slice and check that they are DQ/Q. + consumers, producers = get_tensor_consumers_and_producers(qdq_model) input_dq_node = producers.get(slice_node.input[0], None) self.assertNotEqual(input_dq_node, None) self.assertEqual(input_dq_node.op_type, "DequantizeLinear") From 8d1562ab684fca69f4d0549fab2d777bb6535047 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Tue, 5 Nov 2024 15:43:11 -0800 Subject: [PATCH 7/8] Update onnxruntime/test/python/quantization/op_test_utils.py --- onnxruntime/test/python/quantization/op_test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/python/quantization/op_test_utils.py b/onnxruntime/test/python/quantization/op_test_utils.py index 0106f2e056afe..82193d08684c6 100644 --- a/onnxruntime/test/python/quantization/op_test_utils.py +++ b/onnxruntime/test/python/quantization/op_test_utils.py @@ -681,7 +681,7 @@ def get_tensor_consumers_and_producers( consumers: dict[str, list[onnx.NodeProto]] = {} producers: dict[str, onnx.NodeProto] = {} for node in model.graph.node: - # Iterate throught node's inputs to build the consumers dictionary. + # Iterate through node's inputs to build the consumers dictionary. 
for input_name in node.input: if input_name: if input_name not in consumers: From 9b26a0703b1d73ad3594a93f0389579a4c7b5fa0 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 6 Nov 2024 10:24:00 -0800 Subject: [PATCH 8/8] Address comment for unittest --- onnxruntime/test/python/quantization/test_op_pad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/python/quantization/test_op_pad.py b/onnxruntime/test/python/quantization/test_op_pad.py index 5267d94c619c3..05736019cd7c8 100644 --- a/onnxruntime/test/python/quantization/test_op_pad.py +++ b/onnxruntime/test/python/quantization/test_op_pad.py @@ -613,7 +613,7 @@ def test_qdq_pad_qparams(self): for pad_mode, constant_value, opset, float_type in test_configs: with self.subTest(pad_mode=pad_mode, constant_value=constant_value, opset=opset, float_type=float_type): - label = f"_{pad_mode}_{constant_value}" + label = f"_{pad_mode}_{constant_value}_opset{opset}_{onnx.TensorProto.DataType.Name(float_type)}" float_model_path = os.path.join(self._tmp_dir_path, f"pad{label}.float.onnx") qdq_model_path = os.path.join(self._tmp_dir_path, f"pad{label}.qdq.onnx")
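
A minimal standalone sketch of the reasoning behind this patch series (not taken
from the patches themselves; it assumes the quantization tool's default affine
uint8 scheme, which extends the calibrated floating-point range to include 0.0):

    import numpy as np

    # Calibration data from the Pad unit test above.
    x = np.array([[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]], dtype=np.float32)

    # Calibrated range, extended to include 0.0 (assumed quantizer behavior).
    rmin = min(float(x.min()), 0.0)
    rmax = max(float(x.max()), 0.0)
    scale = (rmax - rmin) / 255.0
    zero_point = round(-rmin / scale)

    # 'reflect', 'edge', and 'wrap' only replicate values already present in
    # the input, and 'constant' mode with a padding value of 0 stays inside
    # [rmin, rmax] because the range includes 0. The padded output therefore
    # fits the input's scale/zero-point exactly; the same argument applies to
    # Slice, which only selects existing values.
    padded = np.pad(x, ((0, 0), (2, 0)), mode="edge")
    assert rmin <= padded.min() <= padded.max() <= rmax

Padding with a nonzero constant (e.g., 10.0) can fall outside the calibrated
input range, which is why QDQPad falls back to quantize_activation_tensor for
the output in that case instead of reusing the input's parameters.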