add paddle sum op
taixiurong committed Aug 22, 2022
1 parent 56808c7 commit b8cf8b2
Showing 4 changed files with 100 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/core/tests/frontend/paddle/op_fuzzy.cpp
@@ -386,6 +386,9 @@ static const std::vector<std::string> models{
std::string("strided_slice_input2_3"),
std::string("strided_slice_input3_1"),
std::string("strided_slice_input3_2"),
std::string("sum_float_1"),
std::string("sum_float_2"),
std::string("sum_float_3"),
std::string("swish_default_params"),
std::string("swish_beta"),
std::string("tanh"),
65 changes: 65 additions & 0 deletions (new Paddle sum model generator script)
@@ -0,0 +1,65 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

#
# sum paddle model generator
#
import numpy as np
from save_model import saveModel
import sys


def sum(name: str, input1, input2, input3):
    import paddle
    paddle.enable_static()

    with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()):
        data1 = paddle.static.data('data1', shape=input1.shape, dtype=input1.dtype)
        data2 = paddle.static.data('data2', shape=input2.shape, dtype=input2.dtype)
        data3 = paddle.static.data('data3', shape=input3.shape, dtype=input3.dtype)

        # Elementwise sum of the three inputs.
        out = paddle.fluid.layers.sum([data1, data2, data3])
        cpu = paddle.static.cpu_places(1)
        exe = paddle.static.Executor(cpu[0])
        exe.run(paddle.static.default_startup_program())

        outs = exe.run(
            feed={"data1": input1, "data2": input2, "data3": input3},
            fetch_list=[out])
        # Serialize the model together with its reference inputs and outputs.
        saveModel(name, exe, feedkeys=['data1', 'data2', 'data3'], fetchlist=[out],
                  inputs=[input1, input2, input3], outputs=[outs[0]], target_dir=sys.argv[1])

    return outs[0]


def main():

    in_type = np.float32
    in_shape = [1, 5]
    input1 = np.random.random(in_shape).astype(in_type)
    input2 = np.random.random(in_shape).astype(in_type)
    input3 = np.random.random(in_shape).astype(in_type)
    sum("sum_float_1", input1, input2, input3)

    in_type = np.float32
    in_shape = [5, 5]
    input1 = np.random.random(in_shape).astype(in_type)
    input2 = np.random.random(in_shape).astype(in_type)
    input3 = np.random.random(in_shape).astype(in_type)
    sum("sum_float_2", input1, input2, input3)

    in_type = np.float32
    in_shape = [5, 10]
    input1 = np.random.random(in_shape).astype(in_type)
    input2 = np.random.random(in_shape).astype(in_type)
    input3 = np.random.random(in_shape).astype(in_type)
    sum("sum_float_3", input1, input2, input3)


if __name__ == "__main__":
    main()
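
The generator is meant to be run with the output directory as its first command-line argument (sys.argv[1]); saveModel then serializes each model together with its reference inputs and outputs. Below is a minimal NumPy-only sketch of the reference result the fuzzy tests should expect, under the assumption that the Paddle sum op performs plain elementwise addition (the serialization step is not reproduced here):

import numpy as np

# Three same-shape tensors, as in the "sum_float_1" case.
np.random.seed(0)
in_shape = (1, 5)
inputs = [np.random.random(in_shape).astype(np.float32) for _ in range(3)]

# Elementwise sum -- the value saveModel would record as the reference output.
expected = inputs[0] + inputs[1] + inputs[2]
assert expected.shape == in_shape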
30 changes: 30 additions & 0 deletions src/frontends/paddle/src/op/sum.cpp
@@ -0,0 +1,30 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "default_opset.hpp"
#include "openvino/frontend/paddle/node_context.hpp"

namespace ov {
namespace frontend {
namespace paddle {
namespace op {
NamedOutputs sum(const NodeContext& node) {
    auto datas = node.get_ng_inputs("X");
    auto data_type = datas[0].get_element_type();
    auto data_shape = datas[0].get_shape();
    // Fold the variadic Paddle "sum" into a chain of binary Add nodes.
    std::shared_ptr<Node> out_node = datas[0].get_node_shared_ptr();
    for (size_t i = 1; i < datas.size(); ++i) {
        PADDLE_OP_CHECK(node,
                        data_type == datas[i].get_element_type(),
                        "sum input tensors must have the same data type!");
        PADDLE_OP_CHECK(node,
                        data_shape == datas[i].get_shape(),
                        "sum input tensors must have the same shape!");
        out_node = std::make_shared<default_opset::Add>(datas[i], out_node);
    }
    return node.default_single_output_mapping({out_node}, {"Out"});
}
} // namespace op
} // namespace paddle
} // namespace frontend
} // namespace ov
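
The converter lowers Paddle's variadic sum into a left-to-right chain of binary Add nodes after checking that all inputs share one element type and shape. A small NumPy sketch of the equivalence the loop relies on; this is purely illustrative and does not use the OpenVINO API:

import numpy as np
from functools import reduce

def cascaded_sum(tensors):
    # Mirror the converter's loop: start from the first input and fold
    # each remaining input in with a single binary addition.
    return reduce(lambda acc, t: np.add(t, acc), tensors[1:], tensors[0])

xs = [np.random.random((5, 5)).astype(np.float32) for _ in range(3)]
assert np.allclose(cascaded_sum(xs), xs[0] + xs[1] + xs[2])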
2 changes: 2 additions & 0 deletions src/frontends/paddle/src/op_table.cpp
@@ -85,6 +85,7 @@ OP_CONVERTER(sqrt);
OP_CONVERTER(squeeze);
OP_CONVERTER(stack);
OP_CONVERTER(strided_slice);
OP_CONVERTER(sum);
OP_CONVERTER(swish);
OP_CONVERTER(tanh);
OP_CONVERTER(top_k_v2);
@@ -180,6 +181,7 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
{"squeeze2", op::squeeze},
{"stack", op::stack},
{"strided_slice", op::strided_slice},
{"sum", op::sum},
{"swish", op::swish},
{"sync_batch_norm", op::batch_norm},
{"tanh", op::tanh},
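Wiring up a new op thus takes two edits in this file: an OP_CONVERTER declaration and an entry mapping the Paddle op type string ("sum") to that converter in get_supported_ops(). A toy Python analogy of the dispatch, with hypothetical names, just to make the string-to-converter lookup explicit:

# Hypothetical sketch of the op-type -> converter dispatch the C++ table implements.
def convert_sum(inputs):
    out = inputs[0]
    for t in inputs[1:]:
        out = out + t          # stands in for default_opset::Add
    return {"Out": out}

SUPPORTED_OPS = {"sum": convert_sum}   # mirrors {"sum", op::sum}

result = SUPPORTED_OPS["sum"]([1.0, 2.0, 3.0])
assert result["Out"] == 6.0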
