From 1aa8cedced113806e3b35f653bcd03f2cf364d6b Mon Sep 17 00:00:00 2001
From: Mark Ibrahim
Date: Thu, 23 Apr 2020 10:07:43 -0700
Subject: [PATCH] Refactor ONNX conversion (#100)

Summary:
Pull Request resolved: https://github.com/facebookresearch/CrypTen/pull/100

This diff refactors CrypTen's conversion of ONNX models to allow for future
alterations of the imported ONNX graph. Changes include:

* moving the hairy ONNX-conversion logic into a separate module
  (`onnx_converter.py`) with its own tests
* adding several tests for the helper functions that map ONNX graphs to
  CrypTen modules
* refactoring more complex functions such as `from_pytorch` and `from_onnx`
  into smaller helpers for easier testing / debugging
* using a context manager for all `io.BytesIO()` streams to ensure the
  resource is released

Reviewed By: knottb

Differential Revision: D21072663

fbshipit-source-id: f0e7814195c7fd2e0e926a5656dec59dba9d6a43
---
 crypten/nn/__init__.py       | 274 ++----------------------
 crypten/nn/onnx_converter.py | 384 +++++++++++++++++++++++++++++++++
 crypten/nn/onnx_helper.py    | 151 -------------
 test/test_nn.py              | 237 ---------------------
 test/test_onnx_converter.py  | 396 +++++++++++++++++++++++++++++++++++
 5 files changed, 792 insertions(+), 650 deletions(-)
 create mode 100644 crypten/nn/onnx_converter.py
 delete mode 100644 crypten/nn/onnx_helper.py
 create mode 100644 test/test_onnx_converter.py

diff --git a/crypten/nn/__init__.py b/crypten/nn/__init__.py
index f147ff91..05b92119 100644
--- a/crypten/nn/__init__.py
+++ b/crypten/nn/__init__.py
@@ -5,13 +5,6 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-import io
-from collections import OrderedDict
-
-import onnx
-import torch
-import torch.onnx.utils
-from onnx import numpy_helper
 
 from .loss import BCELoss, BCEWithLogitsLoss, CrossEntropyLoss, L1Loss, MSELoss
 from .module import (
@@ -52,39 +45,19 @@
     Sum,
     Transpose,
     Unsqueeze,
-    _BatchNorm,
-    _ConstantPad,
-    _Pool2d,
-)
-from .onnx_helper import (
-    _sync_parameters,
-    _update_onnx_symbolic_registry,
-    get_attribute_value,
-    get_parameter_name,
 )
+from .onnx_converter import TF_AND_TF2ONNX, from_pytorch, from_tensorflow
 
-try:
-    import tensorflow as tf  # noqa
-    import tf2onnx
-
-    TF_AND_TF2ONNX = True
-except ImportError:
-    TF_AND_TF2ONNX = False
-
-
-# expose contents of package:
+# expose contents of package
 __all__ = [
-    "MSELoss",
-    "L1Loss",
-    "BCELoss",
-    "BCEWithLogitsLoss",
     "Add",
     "AvgPool2d",
-    "_BatchNorm",
     "BatchNorm1d",
     "BatchNorm2d",
     "BatchNorm3d",
+    "BCELoss",
+    "BCEWithLogitsLoss",
     "Concat",
     "Constant",
     "ConstantPad1d",
@@ -99,251 +72,28 @@
     "DropoutNd",
     "Exp",
     "Flatten",
+    "from_pytorch",
+    "from_tensorflow",
     "Gather",
     "GlobalAveragePool",
     "Graph",
+    "L1Loss",
     "Linear",
     "LogSoftmax",
+    "MatMul",
     "MaxPool2d",
+    "Mean",
     "Module",
-    "_Pool2d",
+    "MSELoss",
     "ReLU",
-    "Mean",
-    "Sum",
     "Reshape",
     "Sequential",
     "Shape",
     "Softmax",
     "Squeeze",
     "Sub",
+    "Sum",
+    "TF_AND_TF2ONNX",
     "Transpose",
     "Unsqueeze",
 ]
-
-# mapping from ONNX to crypten.nn:
-ONNX_TO_CRYPTEN = {
-    "Add": Add,
-    "AveragePool": AvgPool2d,
-    "BatchNormalization": _BatchNorm,
-    "Concat": Concat,
-    "Constant": Constant,
-    "Dropout": Dropout,
-    "Dropout2d": Dropout2d,
-    "Dropout3d": Dropout3d,
-    "DropoutNd": DropoutNd,
-    "Exp": Exp,
-    "Flatten": Flatten,
-    "Gather": Gather,
-    "Gemm": Linear,
-    "GlobalAveragePool": GlobalAveragePool,
-    "LogSoftmax": LogSoftmax,
-    "MatMul": MatMul,
-    "MaxPool": MaxPool2d,
-    "Pad": _ConstantPad,
-    "Relu": ReLU,
-    "ReduceMean": Mean,
-    
"ReduceSum": Sum, - "Reshape": Reshape, - "Shape": Shape, - "Softmax": Softmax, - "Squeeze": Squeeze, - "Sub": Sub, - "Transpose": Transpose, - "Unsqueeze": Unsqueeze, -} - - -def from_pytorch(pytorch_model, dummy_input): - """ - Static function that converts a PyTorch model into a CrypTen model. - """ - # Exporting model to ONNX graph: - # TODO: Currently export twice because the torch-to-ONNX symbolic registry - # only gets created on the first call. - - # export first time so symbolic registry is created - f = io.BytesIO() - try: - # current version of PyTorch requires us to use `enable_onnx_checker` - torch.onnx.export( - pytorch_model, - dummy_input, - f, - do_constant_folding=False, - export_params=True, - enable_onnx_checker=False, - input_names=["input"], - output_names=["output"], - ) - except TypeError: - # older versions of PyTorch require us to NOT use `enable_onnx_checker` - torch.onnx.export( - pytorch_model, - dummy_input, - f, - do_constant_folding=False, - export_params=True, - input_names=["input"], - output_names=["output"], - ) - - # update ONNX symbolic registry with CrypTen-specific functions - _update_onnx_symbolic_registry() - - # export again so the graph is created with CrypTen-specific registry - f = io.BytesIO() - try: - torch.onnx.export( - pytorch_model, - dummy_input, - f, - do_constant_folding=False, - export_params=True, - enable_onnx_checker=False, - input_names=["input"], - output_names=["output"], - ) - except TypeError: - torch.onnx.export( - pytorch_model, - dummy_input, - f, - do_constant_folding=False, - export_params=True, - input_names=["input"], - output_names=["output"], - ) - f.seek(0) - - # construct CrypTen model: - crypten_model = from_onnx(f) - - # make sure training / eval setting is copied: - crypten_model.train(mode=pytorch_model.training) - return crypten_model - - -def from_tensorflow(tensorflow_graph_def, inputs, outputs): - """ - Static function that converts Tensorflow model into CrypTen model based on - https://github.com/onnx/tensorflow-onnx/blob/master/tf2onnx/convert.py - The model is returned in evaluation mode. - Args: - `tensorflow_graph_def`: Input Tensorflow GraphDef to be converted - `inputs`: input nodes - `outputs`: output nodes - """ - # Exporting model to ONNX graph - if not TF_AND_TF2ONNX: - raise ImportError("Please install both tensorflow and tf2onnx packages") - - with tf.Graph().as_default() as tf_graph: - tf.import_graph_def(tensorflow_graph_def, name="") - with tf2onnx.tf_loader.tf_session(graph=tf_graph): - g = tf2onnx.tfonnx.process_tf_graph( - tf_graph, - opset=10, - continue_on_error=False, - input_names=inputs, - output_names=outputs, - ) - onnx_graph = tf2onnx.optimizer.optimize_graph(g) - model_proto = onnx_graph.make_model( - "converted from {}".format(tensorflow_graph_def) - ) - f = io.BytesIO() - f.write(model_proto.SerializeToString()) - - # construct CrypTen model - # Note: We don't convert crypten model to training mode, as Tensorflow - # models are used for both training and evaluation without the specific - # conversion of one mode to another - f.seek(0) - crypten_model = from_onnx(f) - return crypten_model - - -def from_onnx(onnx_string_or_file): - """ - Constructs a CrypTen model or module from an ONNX Protobuf string or file. 
- """ - - # if input is file, read string: - if hasattr(onnx_string_or_file, "seek"): # input is file-like - onnx_string_or_file.seek(0) - onnx_model = onnx.load(onnx_string_or_file) - else: - onnx_model = onnx.load_model_from_string(onnx_string_or_file) - - # create dict of all parameters, inputs, and outputs: - all_parameters = { - t.name: torch.from_numpy(numpy_helper.to_array(t)) - for t in onnx_model.graph.initializer - } - input_names = [input.name for input in onnx_model.graph.input] - output_names = [output.name for output in onnx_model.graph.output] - input_names = [ - name for name in input_names if name not in all_parameters.keys() - ] # parameters are not inputs - assert len(input_names) == 1, "number of inputs should be 1" - assert len(output_names) == 1, "number of outputs should be 1" - - # create graph by looping over nodes: - crypten_model = Graph(input_names[0], output_names[0]) - for node in onnx_model.graph.node: - # retrieve inputs, outputs, attributes, and parameters for this node: - node_output_name = list(node.output)[0] - node_input_names = list(node.input) # includes parameters - - # Create parameters: OrderedDict is required to figure out mapping - # between complex names and ONNX arguments - parameters = OrderedDict() - orig_parameter_names = [] - # add in all the parameters for the current module - for i, name in enumerate(node_input_names): - if name in all_parameters and name not in input_names: - key = get_parameter_name(name) - # the following is necessary because tf2onnx names multiple parameters - # identically if they have the same value - if TF_AND_TF2ONNX: - # only modify if we already have the key in parameters - if key in parameters: - key = key + "_" + str(i) - parameters[key] = all_parameters[name] - orig_parameter_names.append(get_parameter_name(name)) - node_input_names = [ - name - for name in node_input_names - if get_parameter_name(name) not in orig_parameter_names - ] - attributes = {attr.name: get_attribute_value(attr) for attr in node.attribute} - - # get operator type: - if node.op_type == "Conv": - dims = len(attributes["kernel_shape"]) - if dims == 1: - cls = Conv1d - elif dims == 2: - cls = Conv2d - else: - raise ValueError("CrypTen does not support op Conv%dd." % dims) - else: - if node.op_type not in ONNX_TO_CRYPTEN: - raise ValueError("CrypTen does not support op %s." % node.op_type) - cls = ONNX_TO_CRYPTEN[node.op_type] - - if TF_AND_TF2ONNX: - # sync parameter names so that they become what CrypTen expects - parameters = _sync_parameters(parameters, node.op_type) - - # add CrypTen module to graph: - crypten_module = cls.from_onnx(parameters=parameters, attributes=attributes) - crypten_model.add_module(node_output_name, crypten_module, node_input_names) - - # return model (or module when there is only one module): - num_modules = len(list(crypten_model.modules())) - if num_modules == 1: - for crypten_module in crypten_model.modules(): - return crypten_module - else: - return crypten_model diff --git a/crypten/nn/onnx_converter.py b/crypten/nn/onnx_converter.py new file mode 100644 index 00000000..276abaf2 --- /dev/null +++ b/crypten/nn/onnx_converter.py @@ -0,0 +1,384 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import io +from collections import OrderedDict + +import onnx +import torch +import torch.onnx.symbolic_helper as sym_help +import torch.onnx.symbolic_registry as sym_registry +import torch.onnx.utils +from onnx import numpy_helper + +from . import module + + +try: + import tensorflow as tf # noqa + import tf2onnx + + TF_AND_TF2ONNX = True +except ImportError: + TF_AND_TF2ONNX = False + + +# mapping from ONNX to crypten.nn for modules with different names: +ONNX_TO_CRYPTEN = { + "AveragePool": module.AvgPool2d, + "BatchNormalization": module._BatchNorm, + "Gemm": module.Linear, + "MaxPool": module.MaxPool2d, + "Pad": module._ConstantPad, + "Relu": module.ReLU, + "ReduceMean": module.Mean, + "ReduceSum": module.Sum, +} + + +def from_pytorch(pytorch_model, dummy_input): + """ + Static function that converts a PyTorch model into a CrypTen model. + """ + # construct CrypTen model: + f = _from_pytorch_to_bytes(pytorch_model, dummy_input) + crypten_model = from_onnx(f) + f.close() + + # make sure training / eval setting is copied: + crypten_model.train(mode=pytorch_model.training) + return crypten_model + + +def _from_pytorch_to_bytes(pytorch_model, dummy_input): + """Returns I/O stream containing onnx graph with crypten specific ops""" + # TODO: Currently export twice because the torch-to-ONNX symbolic registry + # only gets created on the first call. + with io.BytesIO() as f: + _export_pytorch_model(f, pytorch_model, dummy_input) + + # update ONNX symbolic registry with CrypTen-specific functions + _update_onnx_symbolic_registry() + + # export again so the graph is created with CrypTen-specific registry + f = io.BytesIO() + f = _export_pytorch_model(f, pytorch_model, dummy_input) + f.seek(0) + return f + + +def _export_pytorch_model(f, pytorch_model, dummy_input): + """Returns a Binary I/O stream containing exported model""" + kwargs = { + "do_constant_folding": False, + "export_params": True, + "enable_onnx_checker": False, + "input_names": ["input"], + "output_names": ["output"], + } + try: + # current version of PyTorch requires us to use `enable_onnx_checker` + torch.onnx.export(pytorch_model, dummy_input, f, **kwargs) + except TypeError: + # older versions of PyTorch require us to NOT use `enable_onnx_checker` + kwargs.pop("enable_onnx_checker") + torch.onnx.export(pytorch_model, dummy_input, f, **kwargs) + return f + + +def from_tensorflow(tensorflow_graph_def, inputs, outputs): + """ + Static function that converts Tensorflow model into CrypTen model based on + https://github.com/onnx/tensorflow-onnx/blob/master/tf2onnx/convert.py + The model is returned in evaluation mode. 
+ Args: + `tensorflow_graph_def`: Input Tensorflow GraphDef to be converted + `inputs`: input nodes + `outputs`: output nodes + """ + # Exporting model to ONNX graph + if not TF_AND_TF2ONNX: + raise ImportError("Please install both tensorflow and tf2onnx packages") + + with tf.Graph().as_default() as tf_graph: + tf.import_graph_def(tensorflow_graph_def, name="") + with tf2onnx.tf_loader.tf_session(graph=tf_graph): + g = tf2onnx.tfonnx.process_tf_graph( + tf_graph, + opset=10, + continue_on_error=False, + input_names=inputs, + output_names=outputs, + ) + onnx_graph = tf2onnx.optimizer.optimize_graph(g) + model_proto = onnx_graph.make_model( + "converted from {}".format(tensorflow_graph_def) + ) + f = io.BytesIO() + f.write(model_proto.SerializeToString()) + + # construct CrypTen model + # Note: We don't convert crypten model to training mode, as Tensorflow + # models are used for both training and evaluation without the specific + # conversion of one mode to another + f.seek(0) + crypten_model = from_onnx(f) + return crypten_model + + +def from_onnx(onnx_string_or_file): + """ + Constructs a CrypTen model or module from an ONNX Protobuf string or file. + """ + onnx_model = _load_onnx_model(onnx_string_or_file) + + # create dict of all parameters, inputs, and outputs: + all_parameters = { + t.name: torch.from_numpy(numpy_helper.to_array(t)) + for t in onnx_model.graph.initializer + } + input_names, output_names = _get_input_output_names(onnx_model, all_parameters) + + # create graph by looping over nodes: + crypten_model = module.Graph(input_names[0], output_names[0]) + for node in onnx_model.graph.node: + # retrieve inputs, outputs, attributes, and parameters for this node: + node_output_name = list(node.output)[0] + node_input_names = list(node.input) # includes parameters + + # Create parameters: OrderedDict is required to figure out mapping + # between complex names and ONNX arguments + parameters = OrderedDict() + orig_parameter_names = [] + # add in all the parameters for the current module + for i, name in enumerate(node_input_names): + if name in all_parameters and name not in input_names: + key = _get_parameter_name(name) + # the following is necessary because tf2onnx names multiple parameters + # identically if they have the same value + # only modify if we already have the key in parameters + if TF_AND_TF2ONNX and key in parameters: + key = key + "_" + str(i) + parameters[key] = all_parameters[name] + orig_parameter_names.append(_get_parameter_name(name)) + node_input_names = [ + name + for name in node_input_names + if _get_parameter_name(name) not in orig_parameter_names + ] + attributes = {attr.name: _get_attribute_value(attr) for attr in node.attribute} + + cls = _get_operator_class(node, attributes) + + if TF_AND_TF2ONNX: + # sync parameter names so that they become what CrypTen expects + parameters = _sync_tensorflow_parameters(parameters, node.op_type) + + # add CrypTen module to graph: + crypten_module = cls.from_onnx(parameters=parameters, attributes=attributes) + crypten_model.add_module(node_output_name, crypten_module, node_input_names) + + crypten_model = _get_model_or_module(crypten_model) + return crypten_model + + +def _load_onnx_model(onnx_string_or_file): + """Loads onnx model from file or string""" + # if input is file, read string + if hasattr(onnx_string_or_file, "seek"): + onnx_string_or_file.seek(0) + return onnx.load(onnx_string_or_file) + return onnx.load_model_from_string(onnx_string_or_file) + + +def _get_input_output_names(onnx_model, all_parameters): + 
"""Return input and output names""" + input_names = [] + for input in onnx_model.graph.input: + # parameters are not inputs + if input.name not in all_parameters: + input_names.append(input.name) + + output_names = [output.name for output in onnx_model.graph.output] + + assert len(input_names) == 1, "number of inputs should be 1" + assert len(output_names) == 1, "number of outputs should be 1" + + return input_names, output_names + + +def _get_operator_class(node, attributes): + """Returns CrypTen class of operator""" + # get operator type: + if node.op_type == "Conv": + dims = len(attributes["kernel_shape"]) + if dims == 1: + cls = module.Conv1d + elif dims == 2: + cls = module.Conv2d + else: + raise ValueError("CrypTen does not support op Conv%dd." % dims) + else: + crypten_module = getattr( + module, node.op_type, ONNX_TO_CRYPTEN.get(node.op_type, None) + ) + + if crypten_module is None: + raise ValueError("CrypTen does not support op %s." % node.op_type) + cls = crypten_module + return cls + + +def _get_model_or_module(crypten_model): + """ + Returns module if model contains only one module. Otherwise returns model. + """ + num_modules = len(list(crypten_model.modules())) + if num_modules == 1: + for crypten_module in crypten_model.modules(): + return crypten_module + return crypten_model + + +def _get_parameter_name(name): + """ + Gets parameter name from parameter key. + """ + return name[name.rfind(".") + 1 :] + + +def _get_attribute_value(attr): + """ + Retrieves value from attribute in ONNX graph. + """ + if attr.HasField("f"): # floating-point attribute + return attr.f + elif attr.HasField("i"): # integer attribute + return attr.i + elif attr.HasField("s"): # string attribute + return attr.s # TODO: Sanitize string. + elif attr.HasField("t"): # tensor attribute + return torch.from_numpy(numpy_helper.to_array(attr.t)) + elif len(attr.ints) > 0: + return list(attr.ints) + elif len(attr.floats) > 0: + return list(attr.floats) + raise ValueError("Unknown attribute type for attribute %s." % attr.name) + + +def _sync_tensorflow_parameters(parameter_map, module_name): + """ + Syncs parameters from parameter map to be consistent + with expected PyTorch parameter map + """ + + def _map_module_parameters(parameter_map, module_param_names): + for i, key in enumerate(parameter_map.keys()): + value = parameter_map[key] + new_parameter_map[module_param_names[i]] = value + + new_parameter_map = {} + if module_name == "Conv": + module_param_names = ["weight", "bias"] + _map_module_parameters(parameter_map, module_param_names) + elif module_name == "BatchNormalization": + module_param_names = [ + "weight", + "bias", + "running_mean", + "running_var", + "training_mode", + ] + _map_module_parameters(parameter_map, module_param_names) + else: + new_parameter_map = parameter_map + return new_parameter_map + + +def _update_onnx_symbolic_registry(): + """ + Updates the ONNX symbolic registry for operators that need a CrypTen-specific + implementation and custom operators. 
+ """ + for version_key, version_val in sym_registry._registry.items(): + for function_key in version_val.keys(): + if function_key == "softmax": + sym_registry._registry[version_key][ + function_key + ] = _onnx_crypten_softmax + if function_key == "log_softmax": + sym_registry._registry[version_key][ + function_key + ] = _onnx_crypten_logsoftmax + if function_key == "dropout": + sym_registry._registry[version_key][ + function_key + ] = _onnx_crypten_dropout + if function_key == "feature_dropout": + sym_registry._registry[version_key][ + function_key + ] = _onnx_crypten_feature_dropout + + +@sym_help.parse_args("v", "i", "none") +def _onnx_crypten_softmax(g, input, dim, dtype=None): + """ + This function converts PyTorch's Softmax module to a Softmax module in + the ONNX model. It overrides PyTorch's default conversion of Softmax module + to a sequence of Exp, ReduceSum and Div modules, since this default + conversion can cause numerical overflow when applied to CrypTensors. + """ + result = g.op("Softmax", input, axis_i=dim) + if dtype and dtype.node().kind() != "prim::Constant": + parsed_dtype = sym_help._get_const(dtype, "i", "dtype") + result = g.op("Cast", result, to_i=sym_help.scalar_type_to_onnx[parsed_dtype]) + return result + + +@sym_help.parse_args("v", "i", "none") +def _onnx_crypten_logsoftmax(g, input, dim, dtype=None): + """ + This function converts PyTorch's LogSoftmax module to a LogSoftmax module in + the ONNX model. It overrides PyTorch's default conversion of LogSoftmax module + to avoid potentially creating Transpose operators. + """ + result = g.op("LogSoftmax", input, axis_i=dim) + if dtype and dtype.node().kind() != "prim::Constant": + parsed_dtype = sym_help._get_const(dtype, "i", "dtype") + result = g.op("Cast", result, to_i=sym_help.scalar_type_to_onnx[parsed_dtype]) + return result + + +@sym_help.parse_args("v", "f", "i") +def _onnx_crypten_dropout(g, input, p, train): + """ + This function converts PyTorch's Dropout module to a Dropout module in the ONNX + model. It overrides PyTorch's default implementation to ignore the Dropout module + during the conversion. PyTorch assumes that ONNX models are only used for + inference and therefore Dropout modules are not required in the ONNX model. + However, CrypTen needs to convert ONNX models to trainable + CrypTen models, and so the Dropout module needs to be included in the + CrypTen-specific conversion. + """ + r, _ = g.op("Dropout", input, ratio_f=p, outputs=2) + return r + + +@sym_help.parse_args("v", "f", "i") +def _onnx_crypten_feature_dropout(g, input, p, train): + """ + This function converts PyTorch's DropoutNd module to a DropoutNd module in the ONNX + model. It overrides PyTorch's default implementation to ignore the DropoutNd module + during the conversion. PyTorch assumes that ONNX models are only used for + inference and therefore DropoutNd modules are not required in the ONNX model. + However, CrypTen needs to convert ONNX models to trainable + CrypTen models, and so the DropoutNd module needs to be included in the + CrypTen-specific conversion. + """ + r, _ = g.op("DropoutNd", input, ratio_f=p, outputs=2) + return r diff --git a/crypten/nn/onnx_helper.py b/crypten/nn/onnx_helper.py deleted file mode 100644 index ba3e9756..00000000 --- a/crypten/nn/onnx_helper.py +++ /dev/null @@ -1,151 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
- -import torch -import torch.onnx.symbolic_helper as sym_help -import torch.onnx.symbolic_registry as sym_registry -from onnx import numpy_helper - - -def get_parameter_name(name): - """ - Gets parameter name from parameter key. - """ - return name[name.rfind(".") + 1 :] - - -def get_attribute_value(attr): - """ - Retrieves value from attribute in ONNX graph. - """ - if attr.HasField("f"): # floating-point attribute - return attr.f - elif attr.HasField("i"): # integer attribute - return attr.i - elif attr.HasField("s"): # string attribute - return attr.s # TODO: Sanitize string. - elif attr.HasField("t"): # tensor attribute - return torch.from_numpy(numpy_helper.to_array(attr.t)) - elif len(attr.ints) > 0: - return list(attr.ints) - elif len(attr.floats) > 0: - return list(attr.floats) - else: - raise ValueError("Unknown attribute type for attribute %s." % attr.name) - - -def _sync_parameters(parameter_map, module_name): - """ - Syncs parameters from parameter map to be consistent - with expected PyTorch parameter map - """ - - def _map_module_parameters(parameter_map, module_param_names): - for i, key in enumerate(parameter_map.keys()): - value = parameter_map[key] - new_parameter_map[module_param_names[i]] = value - - new_parameter_map = {} - if module_name == "Conv": - module_param_names = ["weight", "bias"] - _map_module_parameters(parameter_map, module_param_names) - elif module_name == "BatchNormalization": - module_param_names = [ - "weight", - "bias", - "running_mean", - "running_var", - "training_mode", - ] - _map_module_parameters(parameter_map, module_param_names) - else: - new_parameter_map = parameter_map - return new_parameter_map - - -def _update_onnx_symbolic_registry(): - """ - Updates the ONNX symbolic registry for operators that need a CrypTen-specific - implementation and custom operators. - """ - for version_key, version_val in sym_registry._registry.items(): - for function_key in version_val.keys(): - if function_key == "softmax": - sym_registry._registry[version_key][ - function_key - ] = _onnx_crypten_softmax - if function_key == "log_softmax": - sym_registry._registry[version_key][ - function_key - ] = _onnx_crypten_logsoftmax - if function_key == "dropout": - sym_registry._registry[version_key][ - function_key - ] = _onnx_crypten_dropout - if function_key == "feature_dropout": - sym_registry._registry[version_key][ - function_key - ] = _onnx_crypten_feature_dropout - - -@sym_help.parse_args("v", "i", "none") -def _onnx_crypten_softmax(g, input, dim, dtype=None): - """ - This function converts PyTorch's Softmax module to a Softmax module in - the ONNX model. It overrides PyTorch's default conversion of Softmax module - to a sequence of Exp, ReduceSum and Div modules, since this default - conversion can cause numerical overflow when applied to CrypTensors. - """ - result = g.op("Softmax", input, axis_i=dim) - if dtype and dtype.node().kind() != "prim::Constant": - parsed_dtype = sym_help._get_const(dtype, "i", "dtype") - result = g.op("Cast", result, to_i=sym_help.scalar_type_to_onnx[parsed_dtype]) - return result - - -@sym_help.parse_args("v", "i", "none") -def _onnx_crypten_logsoftmax(g, input, dim, dtype=None): - """ - This function converts PyTorch's LogSoftmax module to a LogSoftmax module in - the ONNX model. It overrides PyTorch's default conversion of LogSoftmax module - to avoid potentially creating Transpose operators. 
- """ - result = g.op("LogSoftmax", input, axis_i=dim) - if dtype and dtype.node().kind() != "prim::Constant": - parsed_dtype = sym_help._get_const(dtype, "i", "dtype") - result = g.op("Cast", result, to_i=sym_help.scalar_type_to_onnx[parsed_dtype]) - return result - - -@sym_help.parse_args("v", "f", "i") -def _onnx_crypten_dropout(g, input, p, train): - """ - This function converts PyTorch's Dropout module to a Dropout module in the ONNX - model. It overrides PyTorch's default implementation to ignore the Dropout module - during the conversion. PyTorch assumes that ONNX models are only used for - inference and therefore Dropout modules are not required in the ONNX model. - However, CrypTen needs to convert ONNX models to trainable - CrypTen models, and so the Dropout module needs to be included in the - CrypTen-specific conversion. - """ - r, _ = g.op("Dropout", input, ratio_f=p, outputs=2) - return r - - -@sym_help.parse_args("v", "f", "i") -def _onnx_crypten_feature_dropout(g, input, p, train): - """ - This function converts PyTorch's DropoutNd module to a DropoutNd module in the ONNX - model. It overrides PyTorch's default implementation to ignore the DropoutNd module - during the conversion. PyTorch assumes that ONNX models are only used for - inference and therefore DropoutNd modules are not required in the ONNX model. - However, CrypTen needs to convert ONNX models to trainable - CrypTen models, and so the DropoutNd module needs to be included in the - CrypTen-specific conversion. - """ - r, _ = g.op("DropoutNd", input, ratio_f=p, outputs=2) - return r diff --git a/test/test_nn.py b/test/test_nn.py index d267a663..f1b786fc 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -812,141 +812,6 @@ def forward(self, x): "loss has not decreased after training", ) - def test_from_pytorch_training_classification(self): - """Tests from_pytorch CrypTen training for classification models""" - import torch.nn as nn - import torch.nn.functional as F - - class CNN(nn.Module): - def __init__(self): - super(CNN, self).__init__() - self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=1) - self.fc1 = nn.Linear(16 * 13 * 13, 100) - self.fc2 = nn.Linear(100, 2) - - def forward(self, x): - out = self.conv1(x) - out = F.relu(out) - out = F.max_pool2d(out, 2) - out = out.view(out.size(0), -1) - out = self.fc1(out) - out = F.relu(out) - out = self.fc2(out) - out = F.softmax(out, dim=1) - return out - - model_plaintext = CNN() - batch_size = 5 - x_orig = get_random_test_tensor(size=(batch_size, 1, 28, 28), is_float=True) - y_orig = ( - get_random_test_tensor(size=(batch_size, 1), is_float=True).gt(0).long() - ) - y_one_hot = onehot(y_orig, num_targets=2) - - # encrypt training sample: - x_train = crypten.cryptensor(x_orig, requires_grad=True) - y_train = crypten.cryptensor(y_one_hot) - dummy_input = torch.empty((1, 1, 28, 28)) - - for loss_name in ["BCELoss", "CrossEntropyLoss"]: - # create encrypted model - model = crypten.nn.from_pytorch(model_plaintext, dummy_input) - model.train() - model.encrypt() - - self._check_training(model, x_train, y_train, loss_name) - - def test_from_pytorch_training_regression(self): - """Tests from_pytorch CrypTen training for regression models""" - import torch.nn as nn - import torch.nn.functional as F - - class FeedForward(nn.Module): - def __init__(self): - super(FeedForward, self).__init__() - self.fc1 = nn.Linear(3, 10) - self.fc2 = nn.Linear(10, 1) - - def forward(self, x): - out = self.fc1(x) - out = F.relu(out) - out = self.fc2(out) - return out - - model_plaintext 
= FeedForward() - batch_size = 5 - - x_orig = get_random_test_tensor(size=(batch_size, 3), is_float=True) - dummy_input = torch.empty((1, 3)) - # y is a linear combo of features 1 and 3 - y_orig = 2 * x_orig[:, 0] + 3 * x_orig[:, 2] - - x_train = crypten.cryptensor(x_orig, requires_grad=True) - y_train = crypten.cryptensor(y_orig.unsqueeze(-1)) - - # create encrypted model - model = crypten.nn.from_pytorch(model_plaintext, dummy_input) - model.train() - model.encrypt() - - self._check_training(model, x_train, y_train, "MSELoss") - - def _check_training( - self, model, x_train, y_train, loss_name, num_epochs=3, learning_rate=0.001 - ): - """Verifies gradient updates and loss decreases during training""" - # create loss function - loss = getattr(crypten.nn, loss_name)() - - for i in range(num_epochs): - output = model(x_train) - loss_value = loss(output, y_train) - - # set gradients to "zero" - model.zero_grad() - for param in model.parameters(): - self.assertIsNone(param.grad, "zero_grad did not reset gradients") - - # perform backward pass - loss_value.backward() - for param in model.parameters(): - if param.requires_grad: - self.assertIsNotNone( - param.grad, "required parameter gradient not created" - ) - - # update parameters - orig_parameters, upd_parameters = {}, {} - orig_parameters = self._compute_reference_parameters( - "", orig_parameters, model, 0 - ) - model.update_parameters(learning_rate) - upd_parameters = self._compute_reference_parameters( - "", upd_parameters, model, learning_rate - ) - - # check parameter update - parameter_changed = False - for name, value in orig_parameters.items(): - if param.requires_grad and param.grad is not None: - unchanged = torch.allclose(upd_parameters[name], value) - if unchanged is False: - parameter_changed = True - self.assertTrue( - parameter_changed, "no parameter changed in training step" - ) - - # record initial and current loss - if i == 0: - orig_loss = loss_value.get_plain_text() - curr_loss = loss_value.get_plain_text() - - # check that the loss has decreased after training - self.assertTrue( - curr_loss.item() < orig_loss.item(), - f"{loss_name} has not decreased after training", - ) - def test_batchnorm_module(self): """Test module correctly sets and updates running stats""" batchnorm_fn_and_size = ( @@ -1044,108 +909,6 @@ def _run_test(_sample, _target): _run_test(sample, target) _run_test(crypten.cryptensor(sample), crypten.cryptensor(target)) - @unittest.skipIf( - not crypten.nn.TF_AND_TF2ONNX, "Tensorflow and tf2onnx not installed" - ) - def test_tensorflow_model_conversion(self): - import tensorflow as tf - import tf2onnx - - # create simple model - model_tf1 = tf.keras.Sequential( - [ - tf.keras.layers.Dense( - 10, - activation=tf.nn.relu, - kernel_initializer="ones", - bias_initializer="ones", - input_shape=(4,), - ), - tf.keras.layers.Dense( - 10, - activation=tf.nn.relu, - kernel_initializer="ones", - bias_initializer="ones", - ), - tf.keras.layers.Dense(3, kernel_initializer="ones"), - ] - ) - - model_tf2 = tf.keras.Sequential( - [ - tf.keras.layers.Conv2D( - 32, - 3, - activation="relu", - strides=1, - kernel_initializer="ones", - bias_initializer="ones", - input_shape=(32, 32, 3), - ), - tf.keras.layers.MaxPooling2D(3), - tf.keras.layers.GlobalAveragePooling2D(), - tf.keras.layers.Dropout(0.5), - ] - ) - - model_tf3 = tf.keras.Sequential( - [ - tf.keras.layers.Conv1D( - 32, - 1, - activation="relu", - strides=1, - kernel_initializer="ones", - bias_initializer="ones", - input_shape=(6, 128), - ), - 
tf.keras.layers.AvgPool1D(1), - ] - ) - - feature_sizes = [(1, 4), (1, 32, 32, 3), (1, 6, 128)] - label_sizes = [(1, 3), (1, 32), (1, 6, 32)] - - for i, curr_model_tf in enumerate([model_tf1, model_tf2, model_tf3]): - # create a random feature vector - features = get_random_test_tensor( - size=feature_sizes[i], is_float=True, min_value=1, max_value=3 - ) - labels = get_random_test_tensor( - size=label_sizes[i], is_float=True, min_value=1 - ) - - # convert to a TF tensor via numpy - features_tf = tf.convert_to_tensor(features.numpy()) - labels_tf = tf.convert_to_tensor(labels.numpy()) - # compute the tensorflow predictions - curr_model_tf.compile("sgd", loss=tf.keras.losses.MeanSquaredError()) - curr_model_tf.fit(features_tf, labels_tf) - result_tf = curr_model_tf(features_tf, training=False) - - # convert TF model to CrypTen model - # write as a SavedModel, then load GraphDef from it - import tempfile - - saved_model_dir = tempfile.NamedTemporaryFile(delete=True).name - os.makedirs(saved_model_dir, exist_ok=True) - curr_model_tf.save(saved_model_dir) - graph_def, inputs, outputs = tf2onnx.tf_loader.from_saved_model( - saved_model_dir, None, None - ) - model_enc = crypten.nn.from_tensorflow( - graph_def, list(inputs.keys()), list(outputs.keys()) - ) - - # encrypt model and run it - model_enc.encrypt() - features_enc = crypten.cryptensor(features) - result_enc = model_enc(features_enc) - - # compare the results - result = torch.tensor(result_tf.numpy()) - self._check(result_enc, result, "nn.from_tensorflow failed") - def test_state_dict(self): """ Tests dumping and loading of state dicts. diff --git a/test/test_onnx_converter.py b/test/test_onnx_converter.py new file mode 100644 index 00000000..db06e2d6 --- /dev/null +++ b/test/test_onnx_converter.py @@ -0,0 +1,396 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import collections +import io +import logging +import os +import unittest +from test.multiprocess_test_case import ( + MultiProcessTestCase, + get_random_test_tensor, + onehot, +) + +import crypten +import torch +from crypten.common.tensor_types import is_float_tensor +from crypten.nn import onnx_converter + + +class TestOnnxConverter(object): + """Tests PyTorch and Tensorflow model imports""" + + def _check(self, encrypted_tensor, reference, msg, tolerance=None): + if tolerance is None: + tolerance = getattr(self, "default_tolerance", 0.05) + tensor = encrypted_tensor.get_plain_text() + + # Check sizes match + self.assertTrue(tensor.size() == reference.size(), msg) + + if is_float_tensor(reference): + diff = (tensor - reference).abs_() + norm_diff = diff.div(tensor.abs() + reference.abs()).abs_() + test_passed = norm_diff.le(tolerance) + diff.le(tolerance * 0.2) + test_passed = test_passed.gt(0).all().item() == 1 + else: + test_passed = (tensor == reference).all().item() == 1 + if not test_passed: + logging.info(msg) + logging.info("Result %s" % tensor) + logging.info("Result - Reference = %s" % (tensor - reference)) + self.assertTrue(test_passed, msg=msg) + + def _check_reference_parameters(self, init_name, reference, model): + for name, param in model.named_parameters(recurse=False): + local_name = init_name + "_" + name + self._check(param, reference[local_name], "parameter update failed") + for name, module in model._modules.items(): + local_name = init_name + "_" + name + self._check_reference_parameters(local_name, reference, module) + + def _compute_reference_parameters(self, init_name, reference, model, learning_rate): + for name, param in model.named_parameters(recurse=False): + local_name = init_name + "_" + name + reference[local_name] = ( + param.get_plain_text() - learning_rate * param.grad.get_plain_text() + ) + for name, module in model._modules.items(): + local_name = init_name + "_" + name + reference = self._compute_reference_parameters( + local_name, reference, module, learning_rate + ) + return reference + + def setUp(self): + super().setUp() + # We don't want the main process (rank -1) to initialize the communicator + if self.rank >= 0: + crypten.init() + + @unittest.skipIf( + not crypten.nn.TF_AND_TF2ONNX, "Tensorflow and tf2onnx not installed" + ) + def test_tensorflow_model_conversion(self): + import tensorflow as tf + import tf2onnx + + # create simple model + model_tf1 = tf.keras.Sequential( + [ + tf.keras.layers.Dense( + 10, + activation=tf.nn.relu, + kernel_initializer="ones", + bias_initializer="ones", + input_shape=(4,), + ), + tf.keras.layers.Dense( + 10, + activation=tf.nn.relu, + kernel_initializer="ones", + bias_initializer="ones", + ), + tf.keras.layers.Dense(3, kernel_initializer="ones"), + ] + ) + + model_tf2 = tf.keras.Sequential( + [ + tf.keras.layers.Conv2D( + 32, + 3, + activation="relu", + strides=1, + kernel_initializer="ones", + bias_initializer="ones", + input_shape=(32, 32, 3), + ), + tf.keras.layers.MaxPooling2D(3), + tf.keras.layers.GlobalAveragePooling2D(), + tf.keras.layers.Dropout(0.5), + ] + ) + + model_tf3 = tf.keras.Sequential( + [ + tf.keras.layers.Conv1D( + 32, + 1, + activation="relu", + strides=1, + kernel_initializer="ones", + bias_initializer="ones", + input_shape=(6, 128), + ), + tf.keras.layers.AvgPool1D(1), + ] + ) + + feature_sizes = [(1, 4), (1, 32, 32, 3), (1, 6, 128)] + label_sizes = [(1, 3), (1, 32), (1, 6, 32)] + + for i, curr_model_tf in enumerate([model_tf1, model_tf2, model_tf3]): + # create a random feature vector 
+ features = get_random_test_tensor( + size=feature_sizes[i], is_float=True, min_value=1, max_value=3 + ) + labels = get_random_test_tensor( + size=label_sizes[i], is_float=True, min_value=1 + ) + + # convert to a TF tensor via numpy + features_tf = tf.convert_to_tensor(features.numpy()) + labels_tf = tf.convert_to_tensor(labels.numpy()) + # compute the tensorflow predictions + curr_model_tf.compile("sgd", loss=tf.keras.losses.MeanSquaredError()) + curr_model_tf.fit(features_tf, labels_tf) + result_tf = curr_model_tf(features_tf, training=False) + + # convert TF model to CrypTen model + # write as a SavedModel, then load GraphDef from it + import tempfile + + saved_model_dir = tempfile.NamedTemporaryFile(delete=True).name + os.makedirs(saved_model_dir, exist_ok=True) + curr_model_tf.save(saved_model_dir) + graph_def, inputs, outputs = tf2onnx.tf_loader.from_saved_model( + saved_model_dir, None, None + ) + model_enc = crypten.nn.from_tensorflow( + graph_def, list(inputs.keys()), list(outputs.keys()) + ) + + # encrypt model and run it + model_enc.encrypt() + features_enc = crypten.cryptensor(features) + result_enc = model_enc(features_enc) + + # compare the results + result = torch.tensor(result_tf.numpy()) + self._check(result_enc, result, "nn.from_tensorflow failed") + + def test_from_pytorch_training_classification(self): + """Tests from_pytorch CrypTen training for classification models""" + import torch.nn as nn + import torch.nn.functional as F + + class CNN(nn.Module): + def __init__(self): + super(CNN, self).__init__() + self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=1) + self.fc1 = nn.Linear(16 * 13 * 13, 100) + self.fc2 = nn.Linear(100, 2) + + def forward(self, x): + out = self.conv1(x) + out = F.relu(out) + out = F.max_pool2d(out, 2) + out = out.view(out.size(0), -1) + out = self.fc1(out) + out = F.relu(out) + out = self.fc2(out) + out = F.softmax(out, dim=1) + return out + + model_plaintext = CNN() + batch_size = 5 + x_orig = get_random_test_tensor(size=(batch_size, 1, 28, 28), is_float=True) + y_orig = ( + get_random_test_tensor(size=(batch_size, 1), is_float=True).gt(0).long() + ) + y_one_hot = onehot(y_orig, num_targets=2) + + # encrypt training sample: + x_train = crypten.cryptensor(x_orig, requires_grad=True) + y_train = crypten.cryptensor(y_one_hot) + dummy_input = torch.empty((1, 1, 28, 28)) + + for loss_name in ["BCELoss", "CrossEntropyLoss"]: + # create encrypted model + model = crypten.nn.from_pytorch(model_plaintext, dummy_input) + model.train() + model.encrypt() + + self._check_training(model, x_train, y_train, loss_name) + + def test_from_pytorch_training_regression(self): + """Tests from_pytorch CrypTen training for regression models""" + import torch.nn as nn + import torch.nn.functional as F + + class FeedForward(nn.Module): + def __init__(self): + super(FeedForward, self).__init__() + self.fc1 = nn.Linear(3, 10) + self.fc2 = nn.Linear(10, 1) + + def forward(self, x): + out = self.fc1(x) + out = F.relu(out) + out = self.fc2(out) + return out + + model_plaintext = FeedForward() + batch_size = 5 + + x_orig = get_random_test_tensor(size=(batch_size, 3), is_float=True) + dummy_input = torch.empty((1, 3)) + # y is a linear combo of features 1 and 3 + y_orig = 2 * x_orig[:, 0] + 3 * x_orig[:, 2] + + x_train = crypten.cryptensor(x_orig, requires_grad=True) + y_train = crypten.cryptensor(y_orig.unsqueeze(-1)) + + # create encrypted model + model = crypten.nn.from_pytorch(model_plaintext, dummy_input) + model.train() + model.encrypt() + + self._check_training(model, 
x_train, y_train, "MSELoss") + + def _check_training( + self, model, x_train, y_train, loss_name, num_epochs=3, learning_rate=0.001 + ): + """Verifies gradient updates and loss decreases during training""" + # create loss function + loss = getattr(crypten.nn, loss_name)() + + for i in range(num_epochs): + output = model(x_train) + loss_value = loss(output, y_train) + + # set gradients to "zero" + model.zero_grad() + for param in model.parameters(): + self.assertIsNone(param.grad, "zero_grad did not reset gradients") + + # perform backward pass + loss_value.backward() + for param in model.parameters(): + if param.requires_grad: + self.assertIsNotNone( + param.grad, "required parameter gradient not created" + ) + + # update parameters + orig_parameters, upd_parameters = {}, {} + orig_parameters = self._compute_reference_parameters( + "", orig_parameters, model, 0 + ) + model.update_parameters(learning_rate) + upd_parameters = self._compute_reference_parameters( + "", upd_parameters, model, learning_rate + ) + + # check parameter update + parameter_changed = False + for name, value in orig_parameters.items(): + if param.requires_grad and param.grad is not None: + unchanged = torch.allclose(upd_parameters[name], value) + if unchanged is False: + parameter_changed = True + self.assertTrue( + parameter_changed, "no parameter changed in training step" + ) + + # record initial and current loss + if i == 0: + orig_loss = loss_value.get_plain_text() + curr_loss = loss_value.get_plain_text() + + # check that the loss has decreased after training + self.assertTrue( + curr_loss.item() < orig_loss.item(), + f"{loss_name} has not decreased after training", + ) + + def test_export_pytorch_model(self): + """Tests loading of onnx model from a file""" + pytorch_model = PyTorchLinear() + dummy_input = torch.empty(10, 10) + + with io.BytesIO() as f: + onnx_converter._export_pytorch_model(f, pytorch_model, dummy_input) + + def test_from_onnx(self): + """Tests construction of crypten model from onnx graph""" + pytorch_model = PyTorchLinear() + dummy_input = torch.empty(10, 10) + + with io.BytesIO() as f: + f = onnx_converter._export_pytorch_model(f, pytorch_model, dummy_input) + f.seek(0) + + crypten_model = onnx_converter.from_onnx(f) + + self.assertTrue(hasattr(crypten_model, "encrypt")) + + def test_get_operator_class(self): + """Checks operator is a valid crypten module""" + Node = collections.namedtuple("Node", "op_type") + + op_types = ["Sum", "AveragePool", "Mean"] + for op_type in op_types: + node = Node(op_type) + operator = onnx_converter._get_operator_class(node, {}) + self.assertTrue( + issubclass(operator, crypten.nn.Module), + f"{op_type} operator class {operator} is not a CrypTen module.", + ) + # check conv + kernel_shapes = [[1], [3, 3]] + node = Node("Conv") + for kernel_shape in kernel_shapes: + attributes = {"kernel_shape": kernel_shape} + operator = onnx_converter._get_operator_class(node, attributes) + + # check invalid op_types + invalid_types = [("Conv", {"kernel_shape": [3, 3, 3]}), ("Banana", {})] + for invalid_type, attr in invalid_types: + with self.assertRaises(ValueError): + node = Node(invalid_type) + operator = onnx_converter._get_operator_class(node, attr) + + +class PyTorchLinear(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(10, 1) + + def forward(self, x): + x = self.fc1(x) + return x + + +# Run all unit tests with both TFP and TTP providers +class TestTFP(MultiProcessTestCase, TestOnnxConverter): + def setUp(self): + 
self._original_provider = crypten.mpc.get_default_provider() + crypten.mpc.set_default_provider(crypten.mpc.provider.TrustedFirstParty) + super(TestTFP, self).setUp() + + def tearDown(self): + crypten.mpc.set_default_provider(self._original_provider) + super(TestTFP, self).tearDown() + + +class TestTTP(MultiProcessTestCase, TestOnnxConverter): + def setUp(self): + self._original_provider = crypten.mpc.get_default_provider() + crypten.mpc.set_default_provider(crypten.mpc.provider.TrustedThirdParty) + super(TestTTP, self).setUp() + + def tearDown(self): + crypten.mpc.set_default_provider(self._original_provider) + super(TestTTP, self).tearDown() + + +# This code only runs when executing the file outside the test harness +if __name__ == "__main__": + unittest.main()
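
--
Usage sketch (illustrative only; not part of the commit): the refactor keeps
the public entry points unchanged for callers. Below is a minimal end-to-end
example of the `from_pytorch` path exercised by the tests above. The
two-layer model and tensor shapes are assumptions for illustration, and
`crypten.init()` is assumed to run somewhere communication can be set up
(e.g. a single-process debug session).

    import crypten
    import torch

    crypten.init()

    # Plaintext PyTorch model (hypothetical; any ONNX-exportable module
    # built from supported ops works).
    pytorch_model = torch.nn.Sequential(
        torch.nn.Linear(10, 5),
        torch.nn.ReLU(),
        torch.nn.Linear(5, 1),
    )

    # from_pytorch exports the model to ONNX twice: the first export creates
    # the torch-to-ONNX symbolic registry, which is then patched with the
    # CrypTen-specific softmax / log_softmax / dropout handlers, and the
    # second export is rebuilt as a crypten.nn module via from_onnx.
    dummy_input = torch.empty(1, 10)
    crypten_model = crypten.nn.from_pytorch(pytorch_model, dummy_input)

    # Encrypt the converted model and run it on an encrypted input.
    crypten_model.encrypt()
    x_enc = crypten.cryptensor(torch.randn(1, 10))
    y_enc = crypten_model(x_enc)
    print(y_enc.get_plain_text())

The `from_tensorflow` path is analogous but takes a frozen GraphDef plus
input / output node names and returns the converted model in evaluation mode.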