Skip to content

Commit

Permalink
Add support for PowConstVar and PowVarVar
Browse files Browse the repository at this point in the history
  • Loading branch information
shinh committed Sep 26, 2019
1 parent 68b403f commit aea7d5e
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 2 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ onnx_chainer.export(model, x, filename='vgg16.onnx')

## Supported Functions

Currently 82 Chainer Functions are supported to export in ONNX format.
Currently 84 Chainer Functions are supported to export in ONNX format.

### Activation

Expand Down Expand Up @@ -159,7 +159,9 @@ Currently 82 Chainer Functions are supported to export in ONNX format.
- Mul
- MulConstant
- Neg
- PowConstVar
- PowVarConst
- PowVarVar
- Prod
- RsqrtGPU
- Sin
Expand Down
2 changes: 2 additions & 0 deletions docs/source/introduction/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ Currently 82 Chainer Functions are supported to export in ONNX format.
* Mul
* MulConstant
* Neg
* PowConstVar
* PowVarConst
* PowVarVar
* Prod
* RsqrtGPU
* Sqrt
Expand Down
2 changes: 2 additions & 0 deletions onnx_chainer/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@
from onnx_chainer.functions.math import convert_Mul # NOQA
from onnx_chainer.functions.math import convert_MulConstant # NOQA
from onnx_chainer.functions.math import convert_Neg # NOQA
from onnx_chainer.functions.math import convert_PowConstVar # NOQA
from onnx_chainer.functions.math import convert_PowVarConst # NOQA
from onnx_chainer.functions.math import convert_PowVarVar # NOQA
from onnx_chainer.functions.math import convert_Prod # NOQA
from onnx_chainer.functions.math import convert_RsqrtGPU # NOQA
from onnx_chainer.functions.math import convert_Sin # NOQA
Expand Down
18 changes: 18 additions & 0 deletions onnx_chainer/functions/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,17 @@ def convert_Arctan(func, opset_version, input_names, output_names, context):
return onnx_helper.make_node('Atan', input_names, output_names),


@support((1, 7))
def convert_PowConstVar(
func, opset_version, input_names, output_names, context):
value_name = context.add_const(
np.array(func.value, dtype=func.inputs[0].dtype), 'value')
input_names.insert(0, value_name)

if opset_version == 1 or opset_version == 7:
return onnx_helper.make_node('Pow', input_names, output_names),


@support((1, 7))
def convert_PowVarConst(
func, opset_version, input_names, output_names, context):
Expand All @@ -141,6 +152,13 @@ def convert_PowVarConst(
return onnx_helper.make_node('Pow', input_names, output_names),


@support((1, 7))
def convert_PowVarVar(
func, opset_version, input_names, output_names, context):
if opset_version == 1 or opset_version == 7:
return onnx_helper.make_node('Pow', input_names, output_names),


@support((1, 6))
def convert_Clip(func, opset_version, input_names, output_names, context):
if opset_version == 1:
Expand Down
2 changes: 2 additions & 0 deletions onnx_chainer/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@
'Mul',
'MulConstant',
'Neg',
'PowConstVar',
'PowVarConst',
'PowVarVar',
'Prod',
'RsqrtGPU',
'Sin',
Expand Down
5 changes: 4 additions & 1 deletion tests/functions_tests/test_maths.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
{'op_name': 'Sqrt', 'ops': 'chainer.functions.sqrt(a)'},
{'op_name': 'RSqrt', 'ops': 'chainer.functions.rsqrt(a)'},
{'op_name': 'PowVarConst',
'ops': 'chainer.functions.math.basic_math.pow(a, 2)'},
'ops': 'a ** 2.3'},
{'op_name': 'PowConstVar',
'ops': '2.3 ** a'},
{'op_name': 'Sum', 'ops': 'chainer.functions.sum(a)'},
{'op_name': 'Sum', 'ops': 'chainer.functions.sum(a, axis=1)',
'condition': 'axis1'},
Expand Down Expand Up @@ -121,6 +123,7 @@ def test_output_gpu(self):
'ops': 'chainer.functions.matmul(a, b, transb=True)'},
{'op_name': 'Maximum', 'ops': 'chainer.functions.maximum(a, b)'},
{'op_name': 'Minimum', 'ops': 'chainer.functions.minimum(a, b)'},
{'op_name': 'PowVarVar', 'ops': 'a ** b'},
)
class TestBinaryMathOperators(ONNXModelTest):

Expand Down

0 comments on commit aea7d5e

Please sign in to comment.