Skip to content

Commit

Permalink
【PaddlePaddle Hackathon 3】Add Paddle where_index operator (#12437)
Browse files Browse the repository at this point in the history
* add paddle where_index op

* add more test cases

* add more test cases

* add more test cases and support boolen

* fix clang format

* use paddle.nonzero instead of LayerHelper

* remove opset6

Co-authored-by: Ilya Churaev <[email protected]>

Co-authored-by: cecilia peng <[email protected]>
Co-authored-by: Ilya Churaev <[email protected]>
  • Loading branch information
3 people authored Sep 16, 2022
1 parent 8ad3d0f commit 5e977fc
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/core/tests/frontend/paddle/op_fuzzy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,12 @@ static const std::vector<std::string> models{
std::string("where_1"),
std::string("where_2"),
std::string("where_3"),
std::string("where_index_1"),
std::string("where_index_2"),
std::string("where_index_3"),
std::string("where_index_4"),
std::string("where_index_5"),
std::string("where_index_6"),
// Temporily disable them until root caused to secure CI stable.
// CVS-66703 to track this.
// std::string("yolo_box_clip_box"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

#
# where paddle model generator
#
import numpy as np
from save_model import saveModel
import sys
import paddle

paddle.enable_static()


def where_index(name: str, x, force_boolean=False):
with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()):
node_x = paddle.static.data(name='x', shape=x.shape, dtype=x.dtype)
if force_boolean:
node_x_bl = paddle.fluid.layers.cast(node_x, "bool")
out = paddle.nonzero(node_x_bl)
else:
out = paddle.nonzero(node_x)

cpu = paddle.static.cpu_places(1)
exe = paddle.static.Executor(cpu[0])

# startup program will call initializer to initialize the parameters.
exe.run(paddle.static.default_startup_program())
outs = exe.run(
feed={'x': x},
fetch_list=[out])
saveModel(name, exe, feedkeys=['x'], fetchlist=[out], inputs=[
x], outputs=[outs[0]], target_dir=sys.argv[1])

return outs[0]


def main():
# case of int32
datatype = "int32"
condition = np.random.randint(0, 5, size=[5, 8, 2], dtype=datatype)
paddle_out = where_index("where_index_1", condition)

# case of float32
datatype = "float32"
condition = (np.random.randint(
0, 5, size=[8, 3, 2]) * 1.1).astype(datatype)
paddle_out = where_index("where_index_2", condition)

# case of dimension 4
condition = (np.random.randint(
0, 5, size=[8, 3, 2, 6]) * 1.1).astype(datatype)
paddle_out = where_index("where_index_3", condition)

# case of dimension 5
condition = (np.random.randint(
0, 5, size=[4, 6, 8, 2, 5]) * 1.1).astype(datatype)
paddle_out = where_index("where_index_4", condition)

# case of rank 1
condition = np.ones(10).astype(datatype)
paddle_out = where_index("where_index_5", condition, force_boolean=True)

# case of rank 1 and boolean zeros
condition = np.array([1, 0, 1]).astype(datatype)
paddle_out = where_index("where_index_6", condition, force_boolean=True)


if __name__ == "__main__":
main()
22 changes: 22 additions & 0 deletions src/frontends/paddle/src/op/where_index.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "default_opset.hpp"
#include "openvino/frontend/paddle/node_context.hpp"

namespace ov {
namespace frontend {
namespace paddle {
namespace op {
NamedOutputs where_index(const NodeContext& node) {
const auto condition = node.get_input("Condition");
const auto perm = default_opset::Constant::create(element::i64, Shape{2}, {1, 0});
const auto out = std::make_shared<default_opset::NonZero>(condition, element::i64);
return node.default_single_output_mapping({std::make_shared<default_opset::Transpose>(out, perm)}, {"Out"});
}

} // namespace op
} // namespace paddle
} // namespace frontend
} // namespace ov
2 changes: 2 additions & 0 deletions src/frontends/paddle/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ OP_CONVERTER(transpose2);
OP_CONVERTER(trilinear_interp_v2);
OP_CONVERTER(unsqueeze);
OP_CONVERTER(where);
OP_CONVERTER(where_index);
OP_CONVERTER(yolo_box);
OP_CONVERTER(generate_proposals_v2);
} // namespace op
Expand Down Expand Up @@ -198,6 +199,7 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
{"trilinear_interp_v2", op::trilinear_interp_v2},
{"unsqueeze2", op::unsqueeze},
{"where", op::where},
{"where_index", op::where_index},
{"yolo_box", op::yolo_box},
{"generate_proposals_v2", op::generate_proposals_v2}};
};
Expand Down

0 comments on commit 5e977fc

Please sign in to comment.