Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move implicit input manager to Context #226

Merged
merged 3 commits into from
Aug 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions onnx_chainer/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(self, model):
self.name_list = dict()
self.parameters = []
self.constants = []
self.implicit_inputs = dict() # inputs which not connect to output
namedlink = {n: l for n, l in model.namedlinks()}
self.param_to_link = {}
for name, param in model.namedparams():
Expand Down
4 changes: 2 additions & 2 deletions onnx_chainer/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,10 +407,10 @@ def _export(model, args, filename, export_params, graph_name, save_text,
o = Graph(context, converters, opset_version, network_outputs)
o.to_onnx_graph()

implicit_input_names = set(o.inputs.keys()) - param_names -\
implicit_input_names = set(context.implicit_inputs.keys()) - param_names -\
set(network_inputs.keys())
for name in implicit_input_names:
tensor = convert_parameter(o.inputs[name], context)
tensor = convert_parameter(context.implicit_inputs[name], context)
initializers.append(tensor)
input_tensors.append(helper.make_tensor_value_info(
name, tensor.data_type, tensor.dims))
Expand Down
13 changes: 4 additions & 9 deletions onnx_chainer/functions/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ def convert_BatchNormalization(
func, opset_version, input_names, output_names, context):
is_fixed_bn = len(func.inputs) > 3

# NOTE(disktnk):
# if `use_beta=False`, beta_param is None, `use_gamma=False` is same.
# NOTE: even if `use_beta=False` or `use_gamma=False`, beta or gamma
# are set in inputs by RetainHook,
beta_param = func.inputs[2].get_variable_or_none()
gamma_param = func.inputs[1].get_variable_or_none()
namedlink = context.get_link(beta_param) or context.get_link(gamma_param)
Expand Down Expand Up @@ -49,18 +49,13 @@ def add_param(v, suffix):
maen_name = add_param(mean, 'avg_mean')
var_name = add_param(var, 'avg_var')
if is_fixed_bn:
context.implicit_inputs.pop(input_names[3], None)
context.implicit_inputs.pop(input_names[4], None)
input_names[3] = maen_name
input_names[4] = var_name
else:
input_names.extend([maen_name, var_name])

if beta_param is None:
beta_name = add_param(np.zeros_like(mean, dtype=mean.dtype), 'beta')
input_names[2] = beta_name
if gamma_param is None:
gamma_name = add_param(np.ones_like(mean, dtype=mean.dtype), 'gamma')
input_names[1] = gamma_name

momentum = getattr(func, 'decay', 0.)

# TODO(disktnk): On definition of ONNX's BatchNormalization operator,
Expand Down
3 changes: 1 addition & 2 deletions onnx_chainer/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def __init__(self, context, converters, opset_version, network_outputs):

self.graph = []
self.func_name_counts = collections.defaultdict(int)
self.inputs = {} # Input `Variable` objects keyed by string IDs
self.outputs = set() # Output variable names
self.specified_opset_version = opset_version
self.network_outputs = network_outputs
Expand Down Expand Up @@ -90,7 +89,7 @@ def convert_to_onnx_node(self, function):
input_name = self.context.get_name(var)
if input_name not in self.outputs:
# register input variables to check implicit inputs
self.inputs[input_name] = var
self.context.implicit_inputs[input_name] = var
input_names.append(input_name)

# This is to get corresponding VariableNode id from the output
Expand Down
24 changes: 15 additions & 9 deletions tests/functions_tests/test_normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np

from onnx_chainer.testing import input_generator
from tests.helper import get_initializer_names
from tests.helper import ONNXModelTest


Expand Down Expand Up @@ -84,9 +85,10 @@ def test_output(self):
name += '_' + self.condition

def test_input_names(onnx_model):
input_names = set(v.name for v in onnx_model.graph.input)
assert 'param_bn_avg_mean' in input_names
assert 'param_bn_avg_var' in input_names
initializer_names = get_initializer_names(onnx_model)
assert len(initializer_names) == 4
assert 'param_bn_avg_mean' in initializer_names
assert 'param_bn_avg_var' in initializer_names

self.expect(
self.model, self.x, name=name, train=train,
Expand Down Expand Up @@ -130,9 +132,10 @@ def __call__(self, x):
def test_output(self):

def test_input_names(onnx_model):
input_names = set(v.name for v in onnx_model.graph.input)
assert 'BatchNormalization_0_param_avg_mean' in input_names
assert 'BatchNormalization_0_param_avg_var' in input_names
initializer_names = get_initializer_names(onnx_model)
assert len(initializer_names) == 4
assert 'BatchNormalization_0_param_avg_mean' in initializer_names
assert 'BatchNormalization_0_param_avg_var' in initializer_names

self.expect(
self.model, self.x, custom_model_test_func=test_input_names)
Expand All @@ -157,9 +160,12 @@ def __call__(self, x):
def test_output(self):

def test_input_names(onnx_model):
input_names = set(v.name for v in onnx_model.graph.input)
assert 'FixedBatchNormalization_0_param_avg_mean' in input_names
assert 'FixedBatchNormalization_0_param_avg_var' in input_names
initializer_names = get_initializer_names(onnx_model)
assert len(initializer_names) == 4
assert 'FixedBatchNormalization_0_param_avg_mean' in\
initializer_names
assert 'FixedBatchNormalization_0_param_avg_var' in\
initializer_names

self.expect(
self.model, self.x, custom_model_test_func=test_input_names)
6 changes: 3 additions & 3 deletions tests/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def to_gpu(self, model, x):


def check_all_connected_from_inputs(onnx_model):
edge_names = _get_initializer_names(onnx_model) |\
edge_names = get_initializer_names(onnx_model) |\
_get_input_names(onnx_model)
# Nodes which are not connected from the network inputs.
orphan_nodes = []
Expand All @@ -159,7 +159,7 @@ def check_all_connected_from_inputs(onnx_model):
assert not(orphan_nodes), '{}'.format(orphan_nodes)


def _get_initializer_names(onnx_model):
def get_initializer_names(onnx_model):
return {i.name for i in onnx_model.graph.initializer}


Expand All @@ -169,4 +169,4 @@ def _get_input_names(onnx_model):

def _get_graph_input_names(onnx_model):
return list(
_get_input_names(onnx_model) - _get_initializer_names(onnx_model))
_get_input_names(onnx_model) - get_initializer_names(onnx_model))