forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_fx_to_onnx.py
81 lines (66 loc) · 2.85 KB
/
test_fx_to_onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# Owner(s): ["module: onnx"]
import unittest
import pytorch_test_common
import torch
from torch import nn
from torch.nn import functional as F
from torch.onnx._internal import fx as fx_onnx
from torch.testing._internal import common_utils
class TestFxToOnnx(pytorch_test_common.ExportTestCase):
    """Smoke tests for ``fx_onnx.export`` on small functions and modules.

    Each test only checks that the export call completes without raising;
    the returned value is discarded and not otherwise inspected.
    """

    def setUp(self):
        super().setUp()
        # Every export in this suite targets the default ONNX opset of this build.
        self.opset_version = torch.onnx._constants.ONNX_DEFAULT_OPSET

    def test_simple_function(self):
        """Export a plain Python function that returns a tuple of two tensors."""

        def func(x):
            y = x + 1
            z = y.relu()
            return (y, z)

        _ = fx_onnx.export(func, torch.randn(1, 1, 2), opset_version=self.opset_version)

    @unittest.skip(
        "Conv Op is not supported at the time. https://github.com/microsoft/onnx-script/issues/397"
    )
    def test_mnist(self):
        """Export a small MNIST-style CNN (skipped while Conv is unsupported)."""

        class MNISTModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
                self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
                # 9216 = 64 channels * 12 * 12: for the 28x28 input used below,
                # two unpadded 3x3 stride-1 convs give 28 -> 26 -> 24, and the
                # 2x2 max-pool in forward() halves that to 12.
                self.fc1 = nn.Linear(9216, 128, bias=False)
                self.fc2 = nn.Linear(128, 10, bias=False)

            def forward(self, tensor_x: torch.Tensor):
                tensor_x = self.conv1(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = self.conv2(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = F.max_pool2d(tensor_x, 2)
                tensor_x = torch.flatten(tensor_x, 1)
                tensor_x = self.fc1(tensor_x)
                tensor_x = F.sigmoid(tensor_x)
                tensor_x = self.fc2(tensor_x)
                output = F.log_softmax(tensor_x, dim=1)
                return output

        # Batch of 64 single-channel 28x28 images.
        tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32)
        _ = fx_onnx.export(MNISTModel(), tensor_x, opset_version=self.opset_version)

    def test_trace_only_op_with_evaluator(self):
        """Export argmin/argmax under several ``dim``/``keepdim`` combinations."""
        model_input = torch.tensor([[1.0, 2.0, 3.0], [1.0, 1.0, 2.0]])

        class ArgminArgmaxModel(torch.nn.Module):
            # The parameter is named ``x`` rather than ``input`` so it does not
            # shadow the Python builtin; it is only ever passed positionally.
            def forward(self, x):
                return (
                    torch.argmin(x),
                    torch.argmax(x),
                    torch.argmin(x, keepdim=True),
                    torch.argmax(x, keepdim=True),
                    torch.argmin(x, dim=0, keepdim=True),
                    torch.argmax(x, dim=1, keepdim=True),
                )

        _ = fx_onnx.export(
            ArgminArgmaxModel(), model_input, opset_version=self.opset_version
        )

    def test_multiple_outputs_op_with_evaluator(self):
        """Export ``torch.topk``, an op that produces multiple outputs."""

        class TopKModel(torch.nn.Module):
            def forward(self, x):
                return torch.topk(x, 3)

        x = torch.arange(1.0, 6.0, requires_grad=True)
        _ = fx_onnx.export(TopKModel(), x, opset_version=self.opset_version)
if __name__ == "__main__":
    # When executed directly as a script, hand control to PyTorch's shared
    # test runner instead of invoking unittest.main() here.
    common_utils.run_tests()