From f15f050c84b51b068d7ec951dd808619755321a4 Mon Sep 17 00:00:00 2001
From: nihuini <nihuini@tencent.com>
Date: Mon, 10 Feb 2025 16:06:29 +0800
Subject: [PATCH] w

---
 tools/pnnx/src/pass_level2/torch_max.cpp  | 22 ++++++++++++++++++---
 tools/pnnx/src/pass_level2/torch_mean.cpp | 15 ++++++++++++---
 tools/pnnx/src/pass_level2/torch_min.cpp  | 43 +++++++++++++++++++++++++++++++
 3 files changed, 74 insertions(+), 6 deletions(-)

diff --git a/tools/pnnx/src/pass_level2/torch_max.cpp b/tools/pnnx/src/pass_level2/torch_max.cpp
index 280b7b371dc..deac519205d 100644
--- a/tools/pnnx/src/pass_level2/torch_max.cpp
+++ b/tools/pnnx/src/pass_level2/torch_max.cpp
@@ -190,7 +190,7 @@ class torch_max_tnn : public GraphRewriterPass
         return R"PNNXIR(7767517
 3 2
 pnnx.Input              input       0 1 input
-tnn.ReduceMax           op_0        1 1 input out arg0=%keepdims arg1=%dim
+tnn.ReduceMax           op_0        1 1 input out %*=%*
 pnnx.Output             output      1 0 out
 )PNNXIR";
     }
@@ -202,8 +202,24 @@ pnnx.Output             output      1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
     {
-        op->params["dim"] = captured_params.at("dim");
-        op->params["keepdim"] = captured_params.at("keepdims").i ? true : false;
+        std::vector<int> dim;
+        for (int i = 1; ; i++)
+        {
+            if (captured_params.find("op_0.arg" + std::to_string(i)) == captured_params.end())
+                break;
+
+            dim.push_back(captured_params.at("op_0.arg" + std::to_string(i)).i);
+        }
+
+        if (dim.size() == 1)
+        {
+            op->params["dim"] = dim[0];
+        }
+        else
+        {
+            fprintf(stderr, "fallback to reduce max all\n");
+        }
+        op->params["keepdim"] = captured_params.at("op_0.arg0").i ? true : false;
     }
 };
 
diff --git a/tools/pnnx/src/pass_level2/torch_mean.cpp b/tools/pnnx/src/pass_level2/torch_mean.cpp
index 38afb2a760d..e0aaf7e21cf 100644
--- a/tools/pnnx/src/pass_level2/torch_mean.cpp
+++ b/tools/pnnx/src/pass_level2/torch_mean.cpp
@@ -156,7 +156,7 @@ class torch_mean_tnn : public GraphRewriterPass
         return R"PNNXIR(7767517
 3 2
 pnnx.Input              input       0 1 input
-tnn.ReduceMean          op_0        1 1 input out arg0=%keepdims arg1=%dim
+tnn.ReduceMean          op_0        1 1 input out %*=%*
 pnnx.Output             output      1 0 out
 )PNNXIR";
     }
@@ -168,8 +168,17 @@ pnnx.Output             output      1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
     {
-        op->params["dim"] = captured_params.at("dim");
-        op->params["keepdim"] = captured_params.at("keepdims").i ? true : false;
+        std::vector<int> dim;
+        for (int i = 1; ; i++)
+        {
+            if (captured_params.find("op_0.arg" + std::to_string(i)) == captured_params.end())
+                break;
+
+            dim.push_back(captured_params.at("op_0.arg" + std::to_string(i)).i);
+        }
+
+        op->params["dim"] = dim;
+        op->params["keepdim"] = captured_params.at("op_0.arg0").i ? true : false;
     }
 };
 
diff --git a/tools/pnnx/src/pass_level2/torch_min.cpp b/tools/pnnx/src/pass_level2/torch_min.cpp
index 07248927d2e..ad8dbc6ac28 100644
--- a/tools/pnnx/src/pass_level2/torch_min.cpp
+++ b/tools/pnnx/src/pass_level2/torch_min.cpp
@@ -182,4 +182,47 @@ pnnx.Output             output      2 0 out indices
 
 REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_min_onnx_1, 50)
 
+class torch_min_tnn : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+3 2
+pnnx.Input              input       0 1 input
+tnn.ReduceMin           op_0        1 1 input out %*=%*
+pnnx.Output             output      1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "torch.min";
+    }
+
+    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
+    {
+        std::vector<int> dim;
+        for (int i = 1; ; i++)
+        {
+            if (captured_params.find("op_0.arg" + std::to_string(i)) == captured_params.end())
+                break;
+
+            dim.push_back(captured_params.at("op_0.arg" + std::to_string(i)).i);
+        }
+
+        if (dim.size() == 1)
+        {
+            op->params["dim"] = dim[0];
+        }
+        else
+        {
+            fprintf(stderr, "fallback to reduce min all\n");
+        }
+        op->params["keepdim"] = captured_params.at("op_0.arg0").i ? true : false;
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_min_tnn, 50)
+
 } // namespace pnnx