diff --git a/tools/pnnx/src/pass_level2/F_pad.cpp b/tools/pnnx/src/pass_level2/F_pad.cpp index 795012ca516..fd37af41d41 100644 --- a/tools/pnnx/src/pass_level2/F_pad.cpp +++ b/tools/pnnx/src/pass_level2/F_pad.cpp @@ -480,4 +480,49 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_pad_onnx_1, 110) +class F_pad_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.PadV2 op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.pad"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const int ndim = captured_params.at("op_0.arg0").i; + + std::vector pads(ndim * 2); + for (int i = 0; i < ndim * 2; i++) + { + pads[i] = captured_params.at("op_0.arg" + std::to_string(i + 1)).i; + } + op->params["pad"] = pads; + + const int type = captured_params.at("op_0.arg" + std::to_string(ndim * 2 + 1)).i; + if (type == 0) + { + op->params["mode"] = "constant"; + op->params["value"] = captured_params.at("op_0.arg" + std::to_string(ndim * 2 + 2)); + } + if (type == 1) + { + op->params["mode"] = "reflect"; + op->params["value"] = Parameter(); + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_pad_tnn, 110) + } // namespace pnnx