forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_fx_to_onnx_with_onnxruntime.py
378 lines (313 loc) · 14.2 KB
/
test_fx_to_onnx_with_onnxruntime.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
# Owner(s): ["module: onnx"]
from __future__ import annotations
import inspect
import io
import os
import tempfile
from typing import Any, Callable, Sequence, Tuple, Union
import onnx.reference
import onnx_test_common
import onnxruntime # type: ignore[import]
import torch
import transformers # type: ignore[import]
from torch import nn
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.onnx._internal import diagnostics, fx as fx_onnx
from torch.testing._internal import common_utils
from torch.utils import _pytree as pytree
def _run_onnx_reference_runtime(
    onnx_model: Union[str, io.BytesIO],
    pytorch_inputs: Tuple[Any, ...],
    verbose: int = 10,
) -> Sequence[Any]:
    """Evaluate ``onnx_model`` with ONNX's pure-Python reference evaluator.

    Args:
        onnx_model: Path to, or in-memory buffer holding, a serialized ONNX model.
        pytorch_inputs: PyTorch tensors paired, in order, with the evaluator's
            ``input_names``. Pairing stops at the shorter of the two sequences.
        verbose: Verbosity level forwarded to ``ReferenceEvaluator``.

    Returns:
        Whatever ``ReferenceEvaluator.run`` returns when asked for all outputs
        (``None`` is passed as the list of requested output names).
    """
    evaluator = onnx.reference.ReferenceEvaluator(onnx_model, verbose=verbose)
    feeds = {}
    for name, tensor in zip(evaluator.input_names, pytorch_inputs):
        # The reference evaluator consumes NumPy arrays, so move to host first.
        feeds[name] = tensor.cpu().numpy()
    return evaluator.run(None, feeds)
def _run_ort(
    onnx_model: Union[str, io.BytesIO], pytorch_inputs: Tuple[Any, ...]
) -> Sequence[Any]:
    """Run ``onnx_model`` with ONNX Runtime on the CPU execution provider.

    Args:
        onnx_model: Model accepted by ``onnxruntime.InferenceSession`` (a file
            path or serialized model).
        pytorch_inputs: PyTorch tensors paired, in order, with the session's
            graph inputs as reported by ``get_inputs()``. Pairing stops at the
            shorter of the two sequences.

    Returns:
        The outputs of ``InferenceSession.run`` for all graph outputs.
    """
    ort_session = onnxruntime.InferenceSession(
        onnx_model, providers=["CPUExecutionProvider"]
    )
    graph_input_names = [node_arg.name for node_arg in ort_session.get_inputs()]
    # ORT takes NumPy arrays keyed by graph input name.
    numpy_inputs = (tensor.cpu().numpy() for tensor in pytorch_inputs)
    feed_dict = dict(zip(graph_input_names, numpy_inputs))
    return ort_session.run(None, feed_dict)
def _run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
    model: Union[torch.nn.Module, Callable],
    input_args,
    rtol: float = 1e-3,
    atol: float = 1e-7,
    opset_version: int = 17,
    **input_kwargs,
):
    """Export ``model`` with the FX-to-ONNX exporter and compare ORT with PyTorch.

    Args:
        model: The ``nn.Module`` or plain callable to export and execute.
        input_args: Positional arguments given to both the exporter and ``model``.
        rtol: Relative tolerance for ``torch.testing.assert_close``.
        atol: Absolute tolerance for ``torch.testing.assert_close``.
        opset_version: ONNX opset version requested from the exporter.
        **input_kwargs: Keyword arguments given to the exporter and ``model``.
            They are bound to the model signature and flattened into
            positional inputs before calling ONNX Runtime.

    Raises:
        AssertionError: If PyTorch and ONNX Runtime produce a different number
            of outputs, or if any pair of outputs is not close.
    """
    # Feed args and kwargs into the exporter.
    # Note that the exporter should flatten kwargs into positional args of the
    # exported model, since ONNX doesn't represent kwargs.
    onnx_model = fx_onnx.export_after_normalizing_args_and_kwargs(
        model,
        *input_args,
        opset_version=opset_version,
        use_binary_format=True,
        **input_kwargs,
    )

    # Inspect the model's signature; it is used to flatten kwargs.
    if isinstance(model, torch.nn.Module):
        signature = inspect.signature(model.forward)
    else:
        signature = inspect.signature(model)

    # Bind args and kwargs to the model's signature to flatten kwargs into
    # positional args, since the ONNX model cannot be called with kwargs.
    bound = signature.bind(*input_args, **input_kwargs)
    # Fill omitted optional inputs with their declared defaults.
    bound.apply_defaults()
    assert not bound.kwargs

    ref_outputs, _ = pytree.tree_flatten(model(*input_args, **input_kwargs))
    ort_outputs = _run_ort(onnx_model, bound.args)
    # zip() silently truncates. Without this check, a missing ORT output would
    # let the comparison loop below pass vacuously.
    assert len(ref_outputs) == len(ort_outputs), (
        f"PyTorch produced {len(ref_outputs)} outputs but ONNX Runtime "
        f"produced {len(ort_outputs)}."
    )
    for ref_output, ort_output in zip(ref_outputs, ort_outputs):
        torch.testing.assert_close(
            ref_output, torch.tensor(ort_output), rtol=rtol, atol=atol
        )
class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
    """Export models with the FX-to-ONNX exporter and check ORT against PyTorch.

    Each test builds a model (or plain function), exports it to ONNX, runs the
    exported model with ONNX Runtime and compares the outputs with eager
    PyTorch.
    """

    def setUp(self):
        super().setUp()
        self.diag_ctx = diagnostics.engine.create_diagnostic_context(
            "test_fx_export", version=torch.__version__
        )
        # Opset passed to every export performed by the tests in this class.
        self.opset_version = 17

    def tearDown(self):
        # Dump the collected diagnostics into a per-test SARIF report.
        diagnostics.engine.dump(
            f"test_report_{self._testMethodName}.sarif", compress=False
        )
        super().tearDown()

    def test_simple_function(self):
        def func(x):
            # TODO(justinchuby): Replicate torch's type casting policy
            # in the exporter for type promotion support
            y = x + 1.0
            z = y.relu()
            return (y, z)

        tensor_x = torch.randn(1, 1, 2, dtype=torch.float32)
        _run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
            func, (tensor_x,), opset_version=self.opset_version
        )

    def test_func_with_args_and_kwargs(self):
        # Non-tensor optional kwargs are always folded into constants and
        # removed from the input list in the Dynamo-traced graph, so we can't
        # define a function like
        #   def func(x, b=1.0)
        # here. E.g., if you change the `b` to 1.0 below, it will complain
        # somewhere that the model is called with extra args because the
        # modified function is traced into
        #   def forward(self, x : torch.Tensor):
        #       add = x + 1.0; x = None
        #       relu = add.relu()
        #       return (add, relu)
        # To summarize, optional kwargs must be tensors; otherwise, they are
        # treated as in-graph constants in Dynamo.
        def func(x, b=torch.tensor(1.0)):
            y = x + b
            z = y.relu()
            return (y, z)

        tensor_x = torch.randn(1, 1, 2, dtype=torch.float32)

        # Test without providing optional kwarg.
        _run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
            func, (tensor_x,), opset_version=self.opset_version
        )
        # Test with only positional args.
        _run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
            func, (tensor_x, torch.tensor(8.0)), opset_version=self.opset_version
        )
        # Test while specifying optional kwarg.
        _run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
            func,
            (tensor_x,),
            opset_version=self.opset_version,
            b=torch.tensor(5.0),
        )

    def test_mnist(self):
        class MNISTModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=True)
                self.conv2 = nn.Conv2d(32, 64, 3, 2, bias=True)
                self.fc1 = nn.Linear(9216, 128, bias=True)
                self.fc2 = nn.Linear(128, 10, bias=True)

            def forward(self, tensor_x: torch.Tensor):
                tensor_x = self.conv1(tensor_x)
                tensor_x = torch.sigmoid(tensor_x)
                tensor_x = self.conv2(tensor_x)
                tensor_x = torch.sigmoid(tensor_x)
                tensor_x = torch.flatten(tensor_x, 1)
                tensor_x = self.fc1(tensor_x)
                tensor_x = torch.sigmoid(tensor_x)
                output = self.fc2(tensor_x)
                return output

        tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32)
        _run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
            MNISTModel(), (tensor_x,), opset_version=self.opset_version
        )

    # test single op with no kwargs
    def test_sigmoid(self):
        x = torch.randn(1, 4, 2, 3)

        class SigmoidModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sigmoid = torch.nn.Sigmoid()

            def forward(self, x):
                return self.sigmoid(x)

        _run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
            SigmoidModel(), (x,), opset_version=self.opset_version
        )

    # test single op with kwargs (aten.add called with `alpha`)
    def test_sigmoid_add(self):
        # TODO(titaiwang): change to randn once it's ready
        x = torch.tensor([1.0, 2.0], dtype=torch.float)

        class SigmoidAddModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sigmoid = torch.nn.Sigmoid()

            def forward(self, x):
                x = torch.ops.aten.add(x, 1.0, alpha=2.0)
                return self.sigmoid(x)

        _run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
            SigmoidAddModel(), (x,), opset_version=self.opset_version
        )

    def test_gpt2_tiny(self):
        model_name = "sshleifer/tiny-gpt2"
        # Download pytorch model
        model = transformers.AutoModel.from_pretrained(model_name)
        tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

        # Transform input tokens
        inputs = tokenizer("Hello world!", return_tensors="pt")
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]

        onnx_model = fx_onnx.export_after_normalizing_args_and_kwargs(
            model, use_binary_format=True, opset_version=self.opset_version, **inputs
        )

        ref_outputs, _ = pytree.tree_flatten(model(**inputs, return_dict=False))
        ort_outputs = _run_ort(onnx_model, (input_ids, attention_mask))
        assert len(ref_outputs) == len(ort_outputs)
        assert len(ref_outputs) == 5
        for ref_output, ort_output in zip(ref_outputs, ort_outputs):
            torch.testing.assert_close(ref_output, torch.tensor(ort_output))

    def _test_large_scale_exporter(
        self,
        model_name,
        create_model: Callable,
        create_args: Callable,
        create_pytorch_only_kwargs: Callable,
    ):
        """Test helper for the large-scale exporter.

        Arguments:
            model_name: Name of the model. It is used to name temporary files.
            create_model: A function that creates a model. It should always
                create the same model.
            create_args: A function that creates random input arguments for
                the model.
            create_pytorch_only_kwargs: A function that creates kwargs for
                calling the PyTorch model with real tensors.

        This test contains several steps.

        1. Create a toy model.
        2. Save the toy's state (parameters) to a file. This is for simulating
           a checkpoint file.
        3. Load it back and export it to ONNX with the large-scale exporter.
           All operations (including model loading) are done under
           FakeTensorMode so no real tensor is created and no real
           computation happens.
        4. The ONNX model generated in step 3 doesn't contain parameters,
           and this step adds them as external data and saves a new ONNX model.
        5. Run PyTorch and ONNX models and compare their results.
        """

        # Create the toy model.
        model = create_model()

        # The checkpoint is written inside the temporary directory rather than
        # via NamedTemporaryFile: torch.save/torch.load reopen the file by name,
        # which is not supported for a still-open NamedTemporaryFile on Windows.
        with tempfile.TemporaryDirectory(suffix="large_scale_export") as tmp_folder:
            checkpoint_path = os.path.join(tmp_folder, model_name + ".pt")
            # Dump state_dict to a file to simulate how a HuggingFace model is
            # initialized. The file will be loaded via .load_state_dict(...)
            torch.save(model.state_dict(), checkpoint_path)

            ftm = FakeTensorMode(
                allow_non_fake_inputs=True, allow_fallback_kernels=False
            )
            ctx = fx_onnx.FxToOnnxContext()
            # The following code block does several things.
            #  1. Create a model whose parameters and buffers are all FakeTensor's.
            #  2. Convert nn.Module into an ONNX model without initializers.
            #  3. Record the file paths to find real initializers.
            with ftm, ctx:
                # Toy model with parameters and buffers as FakeTensor's.
                fake_model = create_model()
                fake_model.load_state_dict(torch.load(checkpoint_path))
                # Toy inputs as FakeTensor's.
                fake_args = create_args()
                # Export the ONNX model without initializers while ctx.paths
                # records all files that contain real initializers.
                (onnx_model, _, _, _) = fx_onnx.export_without_parameters_and_buffers(
                    fake_model,
                    *fake_args,
                    use_binary_format=False,
                    opset_version=self.opset_version,
                )

            # Tasks done by the following block.
            #  1. Iterate through all tensors stored in ctx.paths (the file
            #     content is loaded via torch.load).
            #  2. If a tensor's name matches an input name of "onnx_model", an
            #     initializer is created and saved to a separate folder.
            #  3. A new ONNX model is saved to a file with the initializers
            #     saved in the previous step.
            #  4. ORT executes the new ONNX model and compares the results with
            #     the original PyTorch model.

            # Model saved to tmp_folder/onnx_model_location
            # Initializers are saved to tmp_folder/onnx_initializer_location/*.onnx
            onnx_model_location = model_name + "_external_data.onnx"
            onnx_initializer_location = model_name + "_initializers"
            fx_onnx.save_model_with_external_data(
                tmp_folder,
                onnx_model_location,
                onnx_initializer_location,
                tuple(ctx.paths),
                onnx_model,
            )

            # Generate random inputs.
            args = create_args()
            kwargs = create_pytorch_only_kwargs()
            # Original outputs.
            ref_outputs, _ = pytree.tree_flatten(model(*args, **kwargs))
            # ORT outputs. `None` positional args are not graph inputs, so they
            # are dropped before feeding ORT.
            ort_outputs = _run_ort(
                os.path.join(tmp_folder, onnx_model_location),
                tuple(arg for arg in args if arg is not None),
            )

            assert len(ref_outputs) == len(ort_outputs)

            for ref_output, ort_output in zip(ref_outputs, ort_outputs):
                torch.testing.assert_close(ref_output, torch.tensor(ort_output))

    def test_large_scale_exporter_with_toy_mlp(self):
        class MLPModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc0 = nn.Linear(8, 8, bias=True)
                self.fc1 = nn.Linear(8, 4, bias=True)
                self.fc2 = nn.Linear(4, 2, bias=True)
                self.fc3 = nn.Linear(2, 2, bias=True)

            def forward(self, tensor_x: torch.Tensor):
                tensor_x = self.fc0(tensor_x)
                tensor_x = torch.sigmoid(tensor_x)
                tensor_x = self.fc1(tensor_x)
                tensor_x = torch.sigmoid(tensor_x)
                tensor_x = self.fc2(tensor_x)
                tensor_x = torch.sigmoid(tensor_x)
                output = self.fc3(tensor_x)
                return output

        def create_model():
            return MLPModel()

        def create_args():
            return (torch.rand((97, 8), dtype=torch.float32),)

        def create_pytorch_only_extra_kwargs():
            return {}

        self._test_large_scale_exporter(
            "toy_mlp1", create_model, create_args, create_pytorch_only_extra_kwargs
        )

    def test_large_scale_exporter_with_tiny_gpt2(self):
        model_name = "sshleifer/tiny-gpt2"

        def create_model():
            return transformers.AutoModel.from_pretrained(model_name)

        def create_args():
            tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
            kwargs = tokenizer("Hello world!", return_tensors="pt")
            input_ids = kwargs["input_ids"]
            attention_mask = kwargs["attention_mask"]
            return input_ids, None, attention_mask

        def create_pytorch_only_extra_kwargs():
            return {"return_dict": False}

        self._test_large_scale_exporter(
            "tiny_gpt2", create_model, create_args, create_pytorch_only_extra_kwargs
        )
if __name__ == "__main__":
    # Delegate to PyTorch's shared test entry point when run as a script.
    common_utils.run_tests()