Skip to content

Commit

Permalink
w
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Feb 10, 2025
1 parent ca57ab1 commit 560a51a
Showing 1 changed file with 45 additions and 0 deletions.
45 changes: 45 additions & 0 deletions tools/pnnx/src/pass_level2/F_pad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -480,4 +480,49 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_pad_onnx_1, 110)

class F_pad_tnn : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
tnn.PadV2 op_0 1 1 input out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "F.pad";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
const int ndim = captured_params.at("op_0.arg0").i;

std::vector<int> pads(ndim * 2);
for (int i = 0; i < ndim * 2; i++)
{
pads[i] = captured_params.at("op_0.arg" + std::to_string(i + 1)).i;
}
op->params["pad"] = pads;

const int type = captured_params.at("op_0.arg" + std::to_string(ndim * 2 + 1)).i;
if (type == 0)
{
op->params["mode"] = "constant";
op->params["value"] = captured_params.at("op_0.arg" + std::to_string(ndim * 2 + 2));
}
if (type == 1)
{
op->params["mode"] = "reflect";
op->params["value"] = Parameter();
}
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_pad_tnn, 110)

} // namespace pnnx

0 comments on commit 560a51a

Please sign in to comment.