From fe55798dd855a8da2d572b635da658b91adc50af Mon Sep 17 00:00:00 2001 From: disktnk Date: Sun, 18 Aug 2019 01:05:36 +0900 Subject: [PATCH 1/3] fix to move implicit input manager to context --- onnx_chainer/context.py | 1 + onnx_chainer/export.py | 4 ++-- onnx_chainer/functions/normalization.py | 4 ++++ onnx_chainer/graph.py | 3 +-- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/onnx_chainer/context.py b/onnx_chainer/context.py index 8132a8e..be0f1b0 100644 --- a/onnx_chainer/context.py +++ b/onnx_chainer/context.py @@ -36,6 +36,7 @@ def __init__(self, model): self.name_list = dict() self.parameters = [] self.constants = [] + self.implicit_inputs = dict() # inputs which not connect to output namedlink = {n: l for n, l in model.namedlinks()} self.param_to_link = {} for name, param in model.namedparams(): diff --git a/onnx_chainer/export.py b/onnx_chainer/export.py index dc410d1..8d3ca6d 100644 --- a/onnx_chainer/export.py +++ b/onnx_chainer/export.py @@ -407,10 +407,10 @@ def _export(model, args, filename, export_params, graph_name, save_text, o = Graph(context, converters, opset_version, network_outputs) o.to_onnx_graph() - implicit_input_names = set(o.inputs.keys()) - param_names -\ + implicit_input_names = set(context.implicit_inputs.keys()) - param_names -\ set(network_inputs.keys()) for name in implicit_input_names: - tensor = convert_parameter(o.inputs[name], context) + tensor = convert_parameter(context.implicit_inputs[name], context) initializers.append(tensor) input_tensors.append(helper.make_tensor_value_info( name, tensor.data_type, tensor.dims)) diff --git a/onnx_chainer/functions/normalization.py b/onnx_chainer/functions/normalization.py index 0998b61..caa3a1b 100644 --- a/onnx_chainer/functions/normalization.py +++ b/onnx_chainer/functions/normalization.py @@ -49,15 +49,19 @@ def add_param(v, suffix): maen_name = add_param(mean, 'avg_mean') var_name = add_param(var, 'avg_var') if is_fixed_bn: + context.implicit_inputs.pop(input_names[3], None) + context.implicit_inputs.pop(input_names[4], None) input_names[3] = maen_name input_names[4] = var_name else: input_names.extend([maen_name, var_name]) if beta_param is None: + context.implicit_inputs.pop(input_names[2], None) beta_name = add_param(np.zeros_like(mean, dtype=mean.dtype), 'beta') input_names[2] = beta_name if gamma_param is None: + context.implicit_inputs.pop(input_names[1], None) gamma_name = add_param(np.ones_like(mean, dtype=mean.dtype), 'gamma') input_names[1] = gamma_name diff --git a/onnx_chainer/graph.py b/onnx_chainer/graph.py index 7c27cf4..c29959b 100644 --- a/onnx_chainer/graph.py +++ b/onnx_chainer/graph.py @@ -16,7 +16,6 @@ def __init__(self, context, converters, opset_version, network_outputs): self.graph = [] self.func_name_counts = collections.defaultdict(int) - self.inputs = {} # Input `Variable` objects keyed by string IDs self.outputs = set() # Output variable names self.specified_opset_version = opset_version self.network_outputs = network_outputs @@ -90,7 +89,7 @@ def convert_to_onnx_node(self, function): input_name = self.context.get_name(var) if input_name not in self.outputs: # register input variables to check implicit inputs - self.inputs[input_name] = var + self.context.implicit_inputs[input_name] = var input_names.append(input_name) # This is to get corresponding VariableNode id from the output From 67aaf50d05ad3900da1a3190d7e611ffba8dbabb Mon Sep 17 00:00:00 2001 From: disktnk Date: Sun, 18 Aug 2019 01:24:06 +0900 Subject: [PATCH 2/3] remove unnecessary variable set --- onnx_chainer/functions/normalization.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/onnx_chainer/functions/normalization.py b/onnx_chainer/functions/normalization.py index caa3a1b..c9a41bf 100644 --- a/onnx_chainer/functions/normalization.py +++ b/onnx_chainer/functions/normalization.py @@ -13,8 +13,8 @@ def convert_BatchNormalization( func, opset_version, input_names, output_names, context): is_fixed_bn = len(func.inputs) > 3 - # NOTE(disktnk): - # if `use_beta=False`, beta_param is None, `use_gamma=False` is same. + # NOTE: even if `use_beta=False` or `use_gamma=False`, beta or gamma + # are set in inputs by RetainHook, beta_param = func.inputs[2].get_variable_or_none() gamma_param = func.inputs[1].get_variable_or_none() namedlink = context.get_link(beta_param) or context.get_link(gamma_param) @@ -56,15 +56,6 @@ def add_param(v, suffix): else: input_names.extend([maen_name, var_name]) - if beta_param is None: - context.implicit_inputs.pop(input_names[2], None) - beta_name = add_param(np.zeros_like(mean, dtype=mean.dtype), 'beta') - input_names[2] = beta_name - if gamma_param is None: - context.implicit_inputs.pop(input_names[1], None) - gamma_name = add_param(np.ones_like(mean, dtype=mean.dtype), 'gamma') - input_names[1] = gamma_name - momentum = getattr(func, 'decay', 0.) # TODO(disktnk): On definition of ONNX's BatchNormalization operator, From 1d3feb91a30112bc0c3dde57af2304e20a0d62c9 Mon Sep 17 00:00:00 2001 From: disktnk Date: Sun, 18 Aug 2019 17:59:10 +0900 Subject: [PATCH 3/3] fix to add param count check --- tests/functions_tests/test_normalizations.py | 24 ++++++++++++-------- tests/helper.py | 6 ++--- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/tests/functions_tests/test_normalizations.py b/tests/functions_tests/test_normalizations.py index 1612b59..85dfa60 100644 --- a/tests/functions_tests/test_normalizations.py +++ b/tests/functions_tests/test_normalizations.py @@ -5,6 +5,7 @@ import numpy as np from onnx_chainer.testing import input_generator +from tests.helper import get_initializer_names from tests.helper import ONNXModelTest @@ -84,9 +85,10 @@ def test_output(self): name += '_' + self.condition def test_input_names(onnx_model): - input_names = set(v.name for v in onnx_model.graph.input) - assert 'param_bn_avg_mean' in input_names - assert 'param_bn_avg_var' in input_names + initializer_names = get_initializer_names(onnx_model) + assert len(initializer_names) == 4 + assert 'param_bn_avg_mean' in initializer_names + assert 'param_bn_avg_var' in initializer_names self.expect( self.model, self.x, name=name, train=train, @@ -130,9 +132,10 @@ def __call__(self, x): def test_output(self): def test_input_names(onnx_model): - input_names = set(v.name for v in onnx_model.graph.input) - assert 'BatchNormalization_0_param_avg_mean' in input_names - assert 'BatchNormalization_0_param_avg_var' in input_names + initializer_names = get_initializer_names(onnx_model) + assert len(initializer_names) == 4 + assert 'BatchNormalization_0_param_avg_mean' in initializer_names + assert 'BatchNormalization_0_param_avg_var' in initializer_names self.expect( self.model, self.x, custom_model_test_func=test_input_names) @@ -157,9 +160,12 @@ def __call__(self, x): def test_output(self): def test_input_names(onnx_model): - input_names = set(v.name for v in onnx_model.graph.input) - assert 'FixedBatchNormalization_0_param_avg_mean' in input_names - assert 'FixedBatchNormalization_0_param_avg_var' in input_names + initializer_names = get_initializer_names(onnx_model) + assert len(initializer_names) == 4 + assert 'FixedBatchNormalization_0_param_avg_mean' in\ + initializer_names + assert 'FixedBatchNormalization_0_param_avg_var' in\ + initializer_names self.expect( self.model, self.x, custom_model_test_func=test_input_names) diff --git a/tests/helper.py b/tests/helper.py index b2e3c88..b3ad9b9 100644 --- a/tests/helper.py +++ b/tests/helper.py @@ -143,7 +143,7 @@ def to_gpu(self, model, x): def check_all_connected_from_inputs(onnx_model): - edge_names = _get_initializer_names(onnx_model) |\ + edge_names = get_initializer_names(onnx_model) |\ _get_input_names(onnx_model) # Nodes which are not connected from the network inputs. orphan_nodes = [] @@ -159,7 +159,7 @@ def check_all_connected_from_inputs(onnx_model): assert not(orphan_nodes), '{}'.format(orphan_nodes) -def _get_initializer_names(onnx_model): +def get_initializer_names(onnx_model): return {i.name for i in onnx_model.graph.initializer} @@ -169,4 +169,4 @@ def _get_input_names(onnx_model): def _get_graph_input_names(onnx_model): return list( - _get_input_names(onnx_model) - _get_initializer_names(onnx_model)) + _get_input_names(onnx_model) - get_initializer_names(onnx_model))