forked from onnx/onnx
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ced6287
commit b54b3ef
Showing
2 changed files
with
234 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# Copyright (c) ONNX Project Contributors | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from __future__ import annotations | ||
|
||
import numpy as np | ||
|
||
import onnx | ||
from onnx.backend.test.case.base import Base | ||
from onnx.backend.test.case.node import expect | ||
from onnx.reference.ops.op_skip_layer_normalization import _skip_layer_normalization | ||
|
||
|
||
class SkipLayerNormalization(Base): | ||
@staticmethod | ||
def export_3d() -> None: | ||
x = np.random.randn(3, 4, 2).astype(np.float32) | ||
skip = np.random.randn(3, 4, 2).astype(np.float32) | ||
gamma = np.random.randn(2).astype(np.float32) | ||
beta = np.random.randn(2).astype(np.float32) | ||
bias = np.random.randn(2).astype(np.float32) | ||
y, input_skip_bias_sum = _skip_layer_normalization( | ||
x, skip, gamma, beta=beta, B=bias | ||
) | ||
y.astype(np.float32) | ||
input_skip_bias_sum.astype(np.float32) | ||
|
||
node = onnx.helper.make_node( | ||
"SkipLayerNormalization", | ||
inputs=["x", "skip", "gamma", "beta", "bias"], | ||
outputs=["y", "input_skip_bias_sum"], | ||
) | ||
|
||
expect( | ||
node, | ||
inputs=[x, skip, gamma, beta, bias], | ||
outputs=[y, input_skip_bias_sum], | ||
name="test_skip_layer_normalization_3d_example", | ||
) | ||
|
||
@staticmethod | ||
def export_2d() -> None: | ||
x = np.random.randn(4, 2).astype(np.float32) | ||
skip = np.random.randn(4, 2).astype(np.float32) | ||
gamma = np.random.randn(2).astype(np.float32) | ||
beta = np.random.randn(2).astype(np.float32) | ||
bias = np.random.randn(2).astype(np.float32) | ||
y, input_skip_bias_sum = _skip_layer_normalization( | ||
x, skip, gamma, beta=beta, B=bias | ||
) | ||
y.astype(np.float32) | ||
input_skip_bias_sum.astype(np.float32) | ||
|
||
node = onnx.helper.make_node( | ||
"SkipLayerNormalization", | ||
inputs=["x", "skip", "gamma", "beta", "bias"], | ||
outputs=["y", "input_skip_bias_sum"], | ||
) | ||
|
||
expect( | ||
node, | ||
inputs=[x, skip, gamma, beta, bias], | ||
outputs=[y, input_skip_bias_sum], | ||
name="test_skip_layer_normalization_2d_example", | ||
) | ||
|
||
@staticmethod | ||
def export_epsilon() -> None: | ||
x = np.random.randn(3, 4, 2).astype(np.float32) | ||
skip = np.random.randn(3, 4, 2).astype(np.float32) | ||
gamma = np.random.randn(2).astype(np.float32) | ||
beta = np.random.randn(2).astype(np.float32) | ||
bias = np.random.randn(2).astype(np.float32) | ||
epsilon = 1e-2 | ||
y, input_skip_bias_sum = _skip_layer_normalization( | ||
x, skip, gamma, beta=beta, B=bias, epsilon=epsilon | ||
) | ||
y.astype(np.float32) | ||
input_skip_bias_sum.astype(np.float32) | ||
|
||
node = onnx.helper.make_node( | ||
"SkipLayerNormalization", | ||
inputs=["x", "skip", "gamma", "beta", "bias"], | ||
outputs=["y", "input_skip_bias_sum"], | ||
epsilon=epsilon, | ||
) | ||
|
||
expect( | ||
node, | ||
inputs=[x, skip, gamma, beta, bias], | ||
outputs=[y, input_skip_bias_sum], | ||
name="test_skip_layer_normalization_epsilon_example", | ||
) | ||
|
||
@staticmethod | ||
def export_scaling_factor() -> None: | ||
x = np.random.randn(3, 4, 2).astype(np.float32) | ||
skip = np.random.randn(3, 4, 2).astype(np.float32) | ||
gamma = np.random.randn(2).astype(np.float32) | ||
beta = np.random.randn(2).astype(np.float32) | ||
bias = np.random.randn(2).astype(np.float32) | ||
scaling_factor = 3 | ||
y, input_skip_bias_sum = _skip_layer_normalization( | ||
x, skip, gamma, beta=beta, B=bias, scaling_factor=scaling_factor | ||
) | ||
y.astype(np.float32) | ||
input_skip_bias_sum.astype(np.float32) | ||
|
||
node = onnx.helper.make_node( | ||
"SkipLayerNormalization", | ||
inputs=["x", "skip", "gamma", "beta", "bias"], | ||
outputs=["y", "input_skip_bias_sum"], | ||
scaling_factor=scaling_factor, | ||
) | ||
|
||
expect( | ||
node, | ||
inputs=[x, skip, gamma, beta, bias], | ||
outputs=[y, input_skip_bias_sum], | ||
name="test_skip_layer_normalization_scaling_factor_example", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
# Copyright (c) ONNX Project Contributors | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from __future__ import annotations | ||
|
||
import numpy as np | ||
|
||
import onnx | ||
from onnx.backend.test.case.base import Base | ||
from onnx.backend.test.case.node import expect | ||
from onnx.reference.ops.op_skip_rms_normalization import _skip_rms_normalization | ||
|
||
|
||
class SkipRMSNormalization(Base): | ||
@staticmethod | ||
def export_3d() -> None: | ||
x = np.random.randn(3, 4, 2).astype(np.float32) | ||
skip = np.random.randn(3, 4, 2).astype(np.float32) | ||
gamma = np.random.randn(2).astype(np.float32) | ||
bias = np.random.randn(2).astype(np.float32) | ||
y, input_skip_bias_sum = _skip_rms_normalization(x, skip, gamma, B=bias) | ||
y.astype(np.float32) | ||
input_skip_bias_sum.astype(np.float32) | ||
|
||
node = onnx.helper.make_node( | ||
"SkipRMSNormalization", | ||
inputs=["x", "skip", "gamma", "bias"], | ||
outputs=["y", "input_skip_bias_sum"], | ||
) | ||
|
||
expect( | ||
node, | ||
inputs=[x, skip, gamma, bias], | ||
outputs=[y, input_skip_bias_sum], | ||
name="test_skip_rms_normalization_3d_example", | ||
) | ||
|
||
@staticmethod | ||
def export_2d() -> None: | ||
x = np.random.randn(4, 2).astype(np.float32) | ||
skip = np.random.randn(4, 2).astype(np.float32) | ||
gamma = np.random.randn(2).astype(np.float32) | ||
bias = np.random.randn(2).astype(np.float32) | ||
y, input_skip_bias_sum = _skip_rms_normalization(x, skip, gamma, B=bias) | ||
y.astype(np.float32) | ||
input_skip_bias_sum.astype(np.float32) | ||
|
||
node = onnx.helper.make_node( | ||
"SkipRMSNormalization", | ||
inputs=["x", "skip", "gamma", "bias"], | ||
outputs=["y", "input_skip_bias_sum"], | ||
) | ||
|
||
expect( | ||
node, | ||
inputs=[x, skip, gamma, bias], | ||
outputs=[y, input_skip_bias_sum], | ||
name="test_skip_rms_normalization_2d_example", | ||
) | ||
|
||
@staticmethod | ||
def export_epsilon() -> None: | ||
x = np.random.randn(3, 4, 2).astype(np.float32) | ||
skip = np.random.randn(3, 4, 2).astype(np.float32) | ||
gamma = np.random.randn(2).astype(np.float32) | ||
bias = np.random.randn(2).astype(np.float32) | ||
epsilon = 1e-2 | ||
y, input_skip_bias_sum = _skip_rms_normalization( | ||
x, skip, gamma, B=bias, epsilon=epsilon | ||
) | ||
y.astype(np.float32) | ||
input_skip_bias_sum.astype(np.float32) | ||
|
||
node = onnx.helper.make_node( | ||
"SkipRMSNormalization", | ||
inputs=["x", "skip", "gamma", "bias"], | ||
outputs=["y", "input_skip_bias_sum"], | ||
epsilon=epsilon, | ||
) | ||
|
||
expect( | ||
node, | ||
inputs=[x, skip, gamma, bias], | ||
outputs=[y, input_skip_bias_sum], | ||
name="test_skip_rms_normalization_epsilon_example", | ||
) | ||
|
||
@staticmethod | ||
def export_scaling_factor() -> None: | ||
x = np.random.randn(3, 4, 2).astype(np.float32) | ||
skip = np.random.randn(3, 4, 2).astype(np.float32) | ||
gamma = np.random.randn(2).astype(np.float32) | ||
bias = np.random.randn(2).astype(np.float32) | ||
scaling_factor = 3 | ||
y, input_skip_bias_sum = _skip_rms_normalization( | ||
x, skip, gamma, B=bias, scaling_factor=scaling_factor | ||
) | ||
y.astype(np.float32) | ||
input_skip_bias_sum.astype(np.float32) | ||
|
||
node = onnx.helper.make_node( | ||
"SkipRMSNormalization", | ||
inputs=["x", "skip", "gamma", "bias"], | ||
outputs=["y", "input_skip_bias_sum"], | ||
scaling_factor=scaling_factor, | ||
) | ||
|
||
expect( | ||
node, | ||
inputs=[x, skip, gamma, bias], | ||
outputs=[y, input_skip_bias_sum], | ||
name="test_skip_rms_normalization_scaling_factor_example", | ||
) |