support conv1d quant & skip calibrate zero-size tensor (#48912)
yghstill authored Dec 13, 2022
1 parent 5d49e3e commit 5eff6f0
Showing 2 changed files with 180 additions and 78 deletions.
@@ -398,6 +398,9 @@ def __init__(
self._best_calibration_loss = {}
# The threshold for algo = abs_max, mse or avg
self._quantized_threshold = {}
# If the tensor is zero-size during any calibration step,
# it will be stored in self._zero_size_var_names
self._zero_size_var_names = set()
self._same_scale_tensor_list = same_scale_tensor_list
self._freeze_model = freeze_model
self._scale_dict = scale_dict
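
The new `_zero_size_var_names` set records any activation that comes back empty during a calibration step, so later passes can leave it out. A minimal standalone sketch of the guard the samplers below use (illustrative names; note that `ndarray.any()` is `False` for an all-zeros tensor as well as a zero-size one, so all-zero batches are skipped the same way):

```python
import numpy as np

zero_size_var_names = set()
tensors = {
    "empty_act": np.empty((0, 16), dtype=np.float32),  # zero-size
    "zeros_act": np.zeros((4, 16), dtype=np.float32),  # all zeros
    "dense_act": np.random.randn(4, 16).astype(np.float32),
}

for var_name, var_tensor in tensors.items():
    # ndarray.any() is False for both empty and all-zero arrays,
    # so both cases are recorded and skipped by this guard.
    if not var_tensor.any():
        zero_size_var_names.add(var_name)
        continue
    print(var_name, "participates in calibration")
```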
@@ -465,9 +468,12 @@ def quantize(self):

if self._algo == 'avg':
for var_name in self._quantized_act_var_name:
if var_name not in self._quantized_var_avg:
continue
self._quantized_threshold[var_name] = np.array(
self._quantized_var_avg[var_name]
).mean()

if self._algo in ["KL", "hist"]:
self._calculate_kl_hist_threshold()
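
For `algo='avg'`, `quantize` now skips names that never produced a sample (for example, zero-size tensors) and reduces the per-step averages gathered by `_sample_avg` to a single threshold. A hedged sketch of that reduction with made-up sample values:

```python
import numpy as np

# Hypothetical per-step samples from _sample_avg: each entry is the
# batch-averaged abs-max observed at one calibration step.
quantized_var_avg = {"conv1d_0.tmp_0": [0.91, 1.07, 0.98]}
quantized_threshold = {}

for var_name, samples in quantized_var_avg.items():
    quantized_threshold[var_name] = np.array(samples).mean()

print(quantized_threshold)  # {'conv1d_0.tmp_0': 0.9866...}
```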

@@ -741,6 +747,9 @@ def _sample_mse(self):
_logger.info("MSE searching stage ...")
for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
self._zero_size_var_names.add(var_name)
continue
var_tensor = var_tensor.flatten()
abs_max_value = float(np.max(np.abs(var_tensor)))
abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
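
`_sample_mse` sweeps candidate clipping scales and keeps the one with the smallest fake-quant reconstruction error. A simplified self-contained version of that search (symmetric int8 rounding; the step count and function name are illustrative, not the commit's exact parameters):

```python
import numpy as np

def mse_search(var_tensor, bits=8, steps=20):
    """Return the clipping scale with the lowest fake-quant MSE (sketch)."""
    var_tensor = var_tensor.flatten()
    abs_max_value = max(float(np.max(np.abs(var_tensor))), 1e-8)
    q_max = 2 ** (bits - 1) - 1
    best_scale, best_loss = abs_max_value, float("inf")
    for i in range(1, steps + 1):
        scale = abs_max_value * i / steps
        quant = np.clip(np.round(var_tensor / scale * q_max), -q_max - 1, q_max)
        dequant = quant * scale / q_max
        loss = float(((dequant - var_tensor) ** 2).mean())
        if loss < best_loss:
            best_scale, best_loss = scale, loss
    return best_scale

print(mse_search(np.random.randn(1024).astype(np.float32)))
```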
@@ -792,6 +801,9 @@ def _sample_emd(self):
_logger.info("EMD searching stage ...")
for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
self._zero_size_var_names.add(var_name)
continue
var_tensor = var_tensor.flatten()
abs_max_value = float(np.max(np.abs(var_tensor)))
abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
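
`_sample_emd` runs the same sweep but scores each candidate scale with an earth-mover-style distance instead of MSE. A cheap proxy for that loss, usable in place of the `loss` line in the MSE sketch above (the exact formula is assumed, not copied from the commit):

```python
def emd_proxy_loss(var_tensor, dequant):
    # Distance between means plus distance between standard deviations:
    # a lightweight stand-in for a full earth-mover's distance.
    return abs(float(var_tensor.mean()) - float(dequant.mean())) + abs(
        float(var_tensor.std()) - float(dequant.std())
    )
```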
@@ -845,6 +857,9 @@ def _sample_avg(self):

for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
self._zero_size_var_names.add(var_name)
continue
abs_max_value = float(np.max(np.abs(var_tensor)))
if var_name not in self._quantized_var_avg:
self._quantized_var_avg[var_name] = []
@@ -857,7 +872,6 @@ def _sample_avg(self):
)
)
self._quantized_var_avg[var_name].append(abs_avg_value)
continue

def _sample_abs_max(self):
if self._quantized_threshold == {}:
@@ -884,6 +898,9 @@ def _sample_abs_max(self):

for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
self._zero_size_var_names.add(var_name)
continue
abs_max_value = float(np.max(np.abs(var_tensor)))
if (var_name not in self._quantized_threshold) or (
abs_max_value > self._quantized_threshold[var_name]
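
`_sample_abs_max` keeps a running maximum of each activation's abs-max across calibration batches, now ignoring tensors the guard flags. The accumulation pattern, sketched standalone with illustrative names:

```python
import numpy as np

quantized_threshold = {}

def sample_abs_max(var_name, var_tensor):
    if not var_tensor.any():  # zero-size (or all-zero) batch: skip
        return
    abs_max_value = float(np.max(np.abs(var_tensor)))
    # Keep the largest abs-max seen so far for this activation.
    if (var_name not in quantized_threshold) or (
        abs_max_value > quantized_threshold[var_name]
    ):
        quantized_threshold[var_name] = abs_max_value

for _ in range(4):  # four calibration batches
    sample_abs_max("act_0", np.random.randn(8, 16).astype(np.float32))
print(quantized_threshold["act_0"])
```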
@@ -916,6 +933,9 @@ def _sample_min_max(self):

for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
self._zero_size_var_names.add(var_name)
continue
min_value = float(np.min(var_tensor))
max_value = float(np.max(var_tensor))
if (var_name not in self._quantized_var_min) or (
@@ -930,6 +950,11 @@ def _sample_histogram(self):
def _sample_histogram(self):
for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if (not var_tensor.any()) or (
var_name not in self._sampling_act_histogram
):
self._zero_size_var_names.add(var_name)
continue
var_tensor_abs = np.abs(var_tensor)
bins = self._sampling_act_histogram[var_name][1]
hist, _ = np.histogram(var_tensor_abs, bins=bins)
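
`_sample_histogram` adds each batch's abs-value histogram onto a running one, using bin edges fixed earlier by `_init_sampling_act_histogram`; the extra membership check matters because a tensor that was zero-size during the min/max pass never received edges. The accumulation step, sketched with assumed edges:

```python
import numpy as np

bins = np.linspace(0.0, 1.0, num=2048 + 1)  # edges assumed prepared earlier
hist_sum = np.zeros(2048, dtype=np.int64)

for _ in range(4):  # four calibration batches
    var_tensor_abs = np.abs(np.random.rand(8, 16))
    hist, _ = np.histogram(var_tensor_abs, bins=bins)
    hist_sum += hist  # running histogram across steps
```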
@@ -964,6 +989,9 @@ def _sample_ptf(self):

for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
self._zero_size_var_names.add(var_name)
continue
abs_max_value = float(np.max(np.abs(var_tensor)))
q_max = 2 ** (self._activation_bits - 1) - 1
scale8 = abs_max_value / q_max
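
`_sample_ptf` starts from `scale8 = abs_max / q_max` and, per the PTF (power-of-two factor) idea, also considers the halved scales `scale4`, `scale2`, and `scale1`. The hunk is truncated here, so the selection rule below (pick the candidate with the lowest fake-quant MSE) is an illustrative guess, not the commit's exact criterion:

```python
import numpy as np

def sample_ptf(var_tensor, activation_bits=8):
    q_max = 2 ** (activation_bits - 1) - 1
    abs_max_value = max(float(np.max(np.abs(var_tensor))), 1e-8)
    scale8 = abs_max_value / q_max
    candidates = [scale8, scale8 / 2, scale8 / 4, scale8 / 8]  # 8, 4, 2, 1
    losses = []
    for scale in candidates:
        quant = np.clip(np.round(var_tensor / scale), -q_max - 1, q_max)
        losses.append(float(((quant * scale - var_tensor) ** 2).mean()))
    return candidates[int(np.argmin(losses))]

print(sample_ptf(np.random.randn(256).astype(np.float32)))
```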
@@ -1020,6 +1048,9 @@ def _collect_activation_abs_min_max(self):
'''
for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
self._zero_size_var_names.add(var_name)
continue
var_tensor = np.abs(var_tensor)
min_value = float(np.min(var_tensor))
max_value = float(np.max(var_tensor))
@@ -1039,6 +1070,10 @@ def _init_sampling_act_histogram(self):
Based on the min/max value, init the sampling_act_histogram.
'''
for var_name in self._quantized_act_var_name:
if (var_name in self._zero_size_var_names) and (
var_name not in self._sampling_act_abs_min_max
):
continue
if var_name not in self._sampling_act_histogram:
min_val = self._sampling_act_abs_min_max[var_name][0]
max_val = self._sampling_act_abs_min_max[var_name][1]
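
`_init_sampling_act_histogram` turns each recorded abs-min/max pair into fixed histogram edges; the new guard leaves out names that only ever appeared as zero-size tensors, since they have no recorded range. How the edges might be built, with an assumed bin count:

```python
import numpy as np

sampling_act_abs_min_max = {"act_0": [0.0, 3.2]}  # from the abs min/max pass
zero_size_var_names = {"empty_act"}               # never produced data
histogram_bins = 2048                             # assumed bin count
sampling_act_histogram = {}

for var_name in ["act_0", "empty_act"]:
    if var_name in zero_size_var_names and (
        var_name not in sampling_act_abs_min_max
    ):
        continue  # no range recorded, so nothing to histogram
    if var_name not in sampling_act_histogram:
        min_val, max_val = sampling_act_abs_min_max[var_name]
        edges = np.linspace(min_val, max_val, num=histogram_bins + 1)
        sampling_act_histogram[var_name] = [np.zeros(histogram_bins), edges]
```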
@@ -1077,6 +1112,10 @@ def _calculate_kl_hist_threshold(self):
self._quantized_var_threshold[var_name] = weight_threshold

for var_name in self._quantized_act_var_name:
if (var_name in self._zero_size_var_names) and (
var_name not in self._sampling_act_histogram
):
continue
hist, hist_edeges = self._sampling_act_histogram[var_name]
if self._algo == "KL":
bin_width = hist_edeges[1] - hist_edeges[0]
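
`_calculate_kl_hist_threshold` reduces each accumulated histogram to one number: for `hist`, a high-percentile cutoff of the distribution; for `KL`, the clip point whose clipped-and-requantized distribution diverges least from the original. A compact sketch of the hist-style reduction (the percentile is assumed):

```python
import numpy as np

def hist_threshold(hist, hist_edges, percent=0.99999):
    """Smallest edge covering `percent` of the mass (hist-style cutoff)."""
    cdf = np.cumsum(hist / float(hist.sum()))
    idx = int(np.searchsorted(cdf, percent))
    return float(hist_edges[min(idx + 1, len(hist_edges) - 1)])

hist, edges = np.histogram(np.abs(np.random.randn(10000)), bins=2048)
print(hist_threshold(hist, edges))
```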
@@ -1162,7 +1201,6 @@ def _update_program(self):
if self._same_scale_tensor_list is not None:
for tensor_list in self._same_scale_tensor_list:
max_scale = None
tmp_tensor_list = []
for tensor_name in tensor_list:
if '#' in tensor_name:
real_tensor_name, opera, scalar = tensor_name.split(
@@ -1261,21 +1299,40 @@ def _save_output_threshold(self):
self._calibration_scales = {}

def save_info(
op_node, out_var_name, threshold_map, out_info_name, quantized_type
op_node,
out_var_name,
threshold_map,
out_info_name,
argname_index,
quantized_type,
):
assert (
out_var_name in threshold_map
), "The output ({}) of {} node does not have threshold.".format(
out_var_name, op_node.type
)
if (out_var_name in self._zero_size_var_names) and (
out_var_name not in threshold_map
):
_logger.warning(
"{} is zero-size tensor and unable to calibrate, so skip quant it.".format(
out_var_name
)
)
return
else:
assert (
out_var_name in threshold_map
), "The output ({}) of {} node does not have threshold.".format(
out_var_name, op_node.type
)
if self._onnx_format:
# For easy extension, each var_node sets a dict to save its quant parameters.
self._calibration_scales[var_name] = {}
self._calibration_scales[var_name]['scale'] = threshold_map[
var_name
self._calibration_scales[out_var_name] = {}
self._calibration_scales[out_var_name]['scale'] = threshold_map[
out_var_name
]
else:
op_node._set_attr(out_info_name, threshold_map[var_name])
op_node._set_attr(out_info_name, threshold_map[out_var_name])
op_node._set_attr(
argname_index[0] + str(argname_index[1]) + "_threshold",
threshold_map[out_var_name],
)
op_node._set_attr("with_quant_attr", True)
if op_node.type in self._quantizable_op_type:
op._set_attr("quantization_type", quantized_type)
@@ -1285,52 +1342,23 @@ def analysis_and_save_info(op_node, out_var_name):
assert argname_index is not None, (
out_var_name + " is not the output of the op"
)
if self._algo == "KL":
# For compatibility, we save output threshold by two methods.
save_info(
op_node,
out_var_name,
self._quantized_var_threshold,
"out_threshold",
"post_kl",
)
save_info(
op_node,
out_var_name,
self._quantized_var_threshold,
argname_index[0] + str(argname_index[1]) + "_threshold",
"post_kl",
)
elif self._algo == "hist":
if self._algo in ["KL", "hist"]:
# For compatibility, we save output threshold by two methods.
save_info(
op_node,
out_var_name,
self._quantized_var_threshold,
"out_threshold",
"post_hist",
)
save_info(
op_node,
out_var_name,
self._quantized_var_threshold,
argname_index[0] + str(argname_index[1]) + "_threshold",
"post_hist",
argname_index,
"post_" + str(self._algo).lower(),
)

elif self._algo in ["avg", "abs_max", "mse", "emd", "ptf"]:
save_info(
op_node,
out_var_name,
self._quantized_threshold,
"out_threshold",
"post_" + str(self._algo),
)
save_info(
op_node,
out_var_name,
self._quantized_threshold,
argname_index[0] + str(argname_index[1]) + "_threshold",
argname_index,
"post_" + str(self._algo),
)
elif self._algo == "min_max":
Expand All @@ -1339,13 +1367,15 @@ def analysis_and_save_info(op_node, out_var_name):
out_var_name,
self._quantized_var_min,
"out_min",
argname_index,
"post_min_max",
)
save_info(
op_node,
out_var_name,
self._quantized_var_max,
"out_max",
argname_index,
"post_min_max",
)
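
After this refactor, `save_info` receives `argname_index` explicitly and writes both the generic `out_threshold` attribute and the per-argument `<argname><index>_threshold` attribute in a single call, which is what lets the separate `KL` and `hist` branches collapse into one. A sketch of the attribute layout this produces on one quantized op in the non-ONNX-format branch (the value and the `Output0` argname are assumed):

```python
# Hypothetical attributes left on a conv1d op after calibration:
op_attrs = {
    "out_threshold": 0.98,      # generic output threshold
    "Output0_threshold": 0.98,  # argname_index[0] + str(argname_index[1])
    "with_quant_attr": True,
    "quantization_type": "post_hist",
}
```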
