From 825a8aae7b9bb14a2681c30c444624cc515322ce Mon Sep 17 00:00:00 2001 From: nihuini Date: Sat, 8 Feb 2025 14:37:24 +0800 Subject: [PATCH 01/29] the tnn2pnnx infrastructure --- tools/pnnx/CMakeLists.txt | 2 ++ tools/pnnx/src/CMakeLists.txt | 18 ++++++++++++++++++ tools/pnnx/src/load_tnn.cpp | 33 +++++++++++++++++++++++++++++++++ tools/pnnx/src/load_tnn.h | 26 ++++++++++++++++++++++++++ tools/pnnx/src/main.cpp | 11 +++++++++++ 5 files changed, 90 insertions(+) create mode 100644 tools/pnnx/src/load_tnn.cpp create mode 100644 tools/pnnx/src/load_tnn.h diff --git a/tools/pnnx/CMakeLists.txt b/tools/pnnx/CMakeLists.txt index b09f4758ead..e50ab4788c3 100644 --- a/tools/pnnx/CMakeLists.txt +++ b/tools/pnnx/CMakeLists.txt @@ -123,6 +123,8 @@ else() set(onnxruntime_FOUND FALSE) endif() +option(PNNX_TNN2PNNX "build tnn2pnnx" ON) + add_subdirectory(src) enable_testing() diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index b1ac6f5c024..99dcf7d84ca 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -715,6 +715,19 @@ else() message(STATUS "Building without onnx2pnnx") endif() +if(PNNX_TNN2PNNX) + set(tnn2pnnx_SRCS + load_tnn.cpp + ) + + add_library(tnn2pnnx OBJECT ${tnn2pnnx_SRCS}) + target_compile_definitions(tnn2pnnx PRIVATE BUILD_TNN2PNNX) + + message(STATUS "Building with tnn2pnnx") +else() + message(STATUS "Building without tnn2pnnx") +endif() + if(NOT MSVC) add_definitions(-Wall -Wextra) endif() @@ -765,6 +778,11 @@ if(onnxruntime_FOUND) target_link_libraries(pnnx PRIVATE onnx2pnnx) endif() +if(PNNX_TNN2PNNX) + set_property(SOURCE main.cpp APPEND PROPERTY COMPILE_DEFINITIONS BUILD_TNN2PNNX) + target_link_libraries(pnnx PRIVATE tnn2pnnx) +endif() + if(PNNX_COVERAGE) target_compile_options(pnnx PUBLIC -coverage -fprofile-arcs -ftest-coverage) target_link_libraries(pnnx PUBLIC -coverage -lgcov) diff --git a/tools/pnnx/src/load_tnn.cpp b/tools/pnnx/src/load_tnn.cpp new file mode 100644 index 00000000000..d13f3a130e3 --- /dev/null +++ b/tools/pnnx/src/load_tnn.cpp @@ -0,0 +1,33 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "load_tnn.h" + +#include "ir.h" + +namespace pnnx { + +int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) +{ + fprintf(stderr, "############# pass_level0 tnn\n"); + + fprintf(stderr, "load_tnn %s\n", tnnpath.c_str()); + + // TODO + exit(0); + + return 0; +} + +} // namespace pnnx diff --git a/tools/pnnx/src/load_tnn.h b/tools/pnnx/src/load_tnn.h new file mode 100644 index 00000000000..6b7075e8555 --- /dev/null +++ b/tools/pnnx/src/load_tnn.h @@ -0,0 +1,26 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. 
+// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef PNNX_LOAD_TNN_H +#define PNNX_LOAD_TNN_H + +#include "ir.h" + +namespace pnnx { + +int load_tnn(const std::string& tnnpath, Graph& g); + +} // namespace pnnx + +#endif // PNNX_LOAD_TNN_H diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index 8c5c83b891c..ba5e1a7bc76 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -32,6 +32,9 @@ #if BUILD_ONNX2PNNX #include "load_onnx.h" #endif +#if BUILD_TNN2PNNX +#include "load_tnn.h" +#endif #include "pass_ncnn.h" #include "save_ncnn.h" @@ -313,6 +316,14 @@ int main(int argc, char** argv) std::string foldable_constants_zippath = ptbase + ".foldable_constants.zip"; pnnx::Graph pnnx_graph; +#if BUILD_TNN2PNNX + if (1) + { + fprintf(stderr, "TODO distinguish tnnproto file\n"); + load_tnn(ptpath, pnnx_graph); + } + else +#endif #if BUILD_ONNX2PNNX if (!model_file_maybe_torchscript(ptpath)) { From 8e56113b139111984fbc4733b826c062277614e8 Mon Sep 17 00:00:00 2001 From: nihuini Date: Sat, 8 Feb 2025 14:40:34 +0800 Subject: [PATCH 02/29] w --- tools/pnnx/src/main.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index ba5e1a7bc76..07dd9c5f033 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -322,7 +322,6 @@ int main(int argc, char** argv) fprintf(stderr, "TODO distinguish tnnproto file\n"); load_tnn(ptpath, pnnx_graph); } - else #endif #if BUILD_ONNX2PNNX if (!model_file_maybe_torchscript(ptpath)) From a8851d920940498b8a98a5c3226393fe9619decd Mon Sep 17 00:00:00 2001 From: nihuini Date: Sat, 8 Feb 2025 16:23:18 +0800 Subject: [PATCH 03/29] w --- tools/pnnx/src/load_tnn.cpp | 354 +++++++++++++++++- tools/pnnx/src/pass_level2/Tensor_permute.cpp | 32 ++ tools/pnnx/src/pass_level2/torch_max.cpp | 27 ++ tools/pnnx/src/pass_level2/torch_mean.cpp | 27 ++ 4 files changed, 438 insertions(+), 2 deletions(-) diff --git a/tools/pnnx/src/load_tnn.cpp b/tools/pnnx/src/load_tnn.cpp index d13f3a130e3..d14ae6f5eb9 100644 --- a/tools/pnnx/src/load_tnn.cpp +++ b/tools/pnnx/src/load_tnn.cpp @@ -16,16 +16,366 @@ #include "ir.h" +#include +#include + namespace pnnx { +static bool vstr_is_float(const char vstr[16]) +{ + // look ahead for determine isfloat + for (int j = 0; j < 16; j++) + { + if (vstr[j] == '\0') + break; + + if (vstr[j] == '.' 
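/* a decimal point or an exponent marker means the token is a float literal */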
|| tolower(vstr[j]) == 'e') + return true; + } + + return false; +} + +static float vstr_to_float(const char vstr[16]) +{ + double v = 0.0; + + const char* p = vstr; + + // sign + bool sign = *p != '-'; + if (*p == '+' || *p == '-') + { + p++; + } + + // digits before decimal point or exponent + unsigned int v1 = 0; + while (isdigit(*p)) + { + v1 = v1 * 10 + (*p - '0'); + p++; + } + + v = (double)v1; + + // digits after decimal point + if (*p == '.') + { + p++; + + unsigned int pow10 = 1; + unsigned int v2 = 0; + + while (isdigit(*p)) + { + v2 = v2 * 10 + (*p - '0'); + pow10 *= 10; + p++; + } + + v += v2 / (double)pow10; + } + + // exponent + if (*p == 'e' || *p == 'E') + { + p++; + + // sign of exponent + bool fact = *p != '-'; + if (*p == '+' || *p == '-') + { + p++; + } + + // digits of exponent + unsigned int expon = 0; + while (isdigit(*p)) + { + expon = expon * 10 + (*p - '0'); + p++; + } + + double scale = 1.0; + while (expon >= 8) + { + scale *= 1e8; + expon -= 8; + } + while (expon > 0) + { + scale *= 10.0; + expon -= 1; + } + + v = fact ? v * scale : v / scale; + } + + // fprintf(stderr, "v = %f\n", v); + return sign ? (float)v : (float)-v; +} + int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) { fprintf(stderr, "############# pass_level0 tnn\n"); fprintf(stderr, "load_tnn %s\n", tnnpath.c_str()); - // TODO - exit(0); + FILE* fp = fopen(tnnpath.c_str(), "rb"); + if (!fp) + { + fprintf(stderr, "fopen %s failed\n", tnnpath.c_str()); + return -1; + } + + char line[4096]; + + // "1 57 1 4206624772 ," + fgets(line, 4096, fp); + int blob_count = 57; + unsigned int magic = 4206624772; + + // "input 2 1 80000 0 ," + fgets(line, 4096, fp); + if (magic == 4206624772) + { + // strip leading and tail double quote + line[strlen(line) - 2] = '\0'; + line[0] = '\0'; + const char* pline = line + 1; + + // input operand name + // rank 2 + // shape (1, 80000) + // datatype 0=fp32 + + int ncomsumed = 0; + char blob_name[32]; + int rank = 0; + sscanf(pline, "%s %d%n", blob_name, &rank, &ncomsumed); + + pline += ncomsumed; + + std::vector shape(rank); + for (int i = 0; i < rank; i++) + { + sscanf(pline, "%d%n", &shape[i], &ncomsumed); + + pline += ncomsumed; + } + + int datatype = 0; + sscanf(pline, "%d%n", &datatype, &ncomsumed); + + Operator* op = pnnx_graph.new_operator("pnnx.Input", "input0"); + + Operand* r = pnnx_graph.new_operand(blob_name); + + r->producer = op; + + r->shape = shape; + + if (datatype == 0) + r->type = 1; + + op->outputs.push_back(r); + } + + // all operand names + // " 108 109 110 111 112 113 114 116 118 119 120 125 126 128 130 131 132 133 135 136 138 139 142 144 145 147 148 151 153 154 156 157 160 162 163 165 166 169 171 172 174 175 178 180 181 183 184 188 189 190 191 192 194 85 clipwise_output embedding input ," + fgets(line, 4096, fp); + { + // strip leading and tail double quote + line[strlen(line) - 2] = '\0'; + line[0] = '\0'; + const char* pline = line + 1; + + int ncomsumed = 0; + + for (int i = 0; i < blob_count; i++) + { + char blob_name[32]; + sscanf(pline, "%s%n", blob_name, &ncomsumed); + + pline += ncomsumed; + + // fprintf(stderr, "blob %s\n", blob_name); + + if (!pnnx_graph.get_operand(blob_name)) + { + pnnx_graph.new_operand(blob_name); + } + } + } + + // all output names + // "clipwise_output embedding ," + fgets(line, 4096, fp); + + std::vector output_names; + { + // strip leading and tail double quote + line[strlen(line) - 2] = '\0'; + line[0] = '\0'; + const char* pline = line + 1; + + int ncomsumed = 0; + + while (1) + { + char blob_name[32]; + 
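/* consume whitespace-separated names until the lone "," terminator */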
sscanf(pline, "%s%n", blob_name, &ncomsumed); + + pline += ncomsumed; + + if (strcmp(blob_name, ",") == 0) + break; + + // fprintf(stderr, "blob %s\n", blob_name); + + output_names.push_back(blob_name); + } + } + + // layer count + // " 56 ," + fgets(line, 4096, fp); + int layer_count = 56; + + for (int i = 0; i < layer_count; i++) + { + // "Unsqueeze Unsqueeze_0 1 1 input 85 1 1 ," + fgets(line, 4096, fp); + + // strip leading and tail double quote + line[strlen(line) - 2] = '\0'; + line[0] = '\0'; + const char* pline = line + 1; + + int ncomsumed = 0; + + char layer_type[32]; + char layer_name[32]; + int bottom_count; + int top_count; + sscanf(pline, "%s %s %d %d%n", layer_type, layer_name, &bottom_count, &top_count, &ncomsumed); + + pline += ncomsumed; + + // fprintf(stderr, "%s %s %d %d\n", layer_type, layer_name, bottom_count, top_count); + + Operator* op = pnnx_graph.new_operator(std::string("tnn.") + layer_type, layer_name); + + for (int j = 0; j < bottom_count; j++) + { + char blob_name[32]; + sscanf(pline, "%s%n", blob_name, &ncomsumed); + + pline += ncomsumed; + + // fprintf(stderr, " bottom %s\n", blob_name); + + Operand* r = pnnx_graph.get_operand(blob_name); + if (!r) + { + fprintf(stderr, "%s bottom %s not found\n", layer_name, blob_name); + } + r->consumers.push_back(op); + op->inputs.push_back(r); + } + + for (int j = 0; j < top_count; j++) + { + char blob_name[32]; + sscanf(pline, "%s%n", blob_name, &ncomsumed); + + pline += ncomsumed; + + // fprintf(stderr, " top %s\n", blob_name); + + Operand* r = pnnx_graph.get_operand(blob_name); + if (!r) + { + fprintf(stderr, "%s top %s not found\n", layer_name, blob_name); + } + r->producer = op; + op->outputs.push_back(r); + } + + // layer specific data + // Unsqueeze 1 1 , + // Convolution1D 1 1 257 512 160 0 0 0 -1 1 0 , + + int param_id = 0; + while (1) + { + char vstr[16]; + sscanf(pline, "%s%n", vstr, &ncomsumed); + + pline += ncomsumed; + + if (strcmp(vstr, ",") == 0) + break; + + // fprintf(stderr, "vstr %s\n", vstr); + + bool is_float = vstr_is_float(vstr); + + if (is_float) + { + float v = vstr_to_float(vstr); + + op->params[std::string("arg") + std::to_string(param_id)] = v; + } + else + { + int v = 0; + int nscan = sscanf(vstr, "%d", &v); + if (nscan == 1) + { + op->params[std::string("arg") + std::to_string(param_id)] = v; + } + else + { + // fallback to string type + op->params[std::string("arg") + std::to_string(param_id)] = vstr; + } + } + + param_id++; + } + } + + // append output nodes + const int output_count = (int)output_names.size(); + for (int i = 0; i < output_count; i++) + { + Operator* op = pnnx_graph.new_operator("pnnx.Output", "output" + std::to_string(i)); + + Operand* r = pnnx_graph.get_operand(output_names[i]); + + r->consumers.push_back(op); + + // fprintf(stderr, "r->name = %s\n", r->name.c_str()); + + op->inputs.push_back(r); + } + + + fclose(fp); + + // replace simple operator + for (Operator* op : pnnx_graph.ops) + { + // unary + if (op->type == "tnn.Log") op->type = "aten::log"; + if (op->type == "tnn.ReLU") op->type = "aten::relu"; + if (op->type == "tnn.Sigmoid") op->type = "aten::sigmoid"; + + // binary + if (op->type == "tnn.Add") op->type = "aten::add"; + } return 0; } diff --git a/tools/pnnx/src/pass_level2/Tensor_permute.cpp b/tools/pnnx/src/pass_level2/Tensor_permute.cpp index e53f5a45bbc..7f55d00f636 100644 --- a/tools/pnnx/src/pass_level2/Tensor_permute.cpp +++ b/tools/pnnx/src/pass_level2/Tensor_permute.cpp @@ -82,4 +82,36 @@ pnnx.Output output 1 0 out 
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_permute_onnx, 60) +class Tensor_permute_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.Permute op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.permute"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const int dims_count = captured_params.at("op_0.arg0").i; + std::vector dims(dims_count); + for (int i = 0; i < dims_count; i++) + { + dims[i] = captured_params.at("op_0.arg" + std::to_string(i + 1)).i; + } + op->params["dims"] = dims; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_permute_tnn, 60) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_max.cpp b/tools/pnnx/src/pass_level2/torch_max.cpp index 1338d58b88c..280b7b371dc 100644 --- a/tools/pnnx/src/pass_level2/torch_max.cpp +++ b/tools/pnnx/src/pass_level2/torch_max.cpp @@ -182,4 +182,31 @@ pnnx.Output output 2 0 out indices REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_max_onnx_1, 50) +class torch_max_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.ReduceMax op_0 1 1 input out arg0=%keepdims arg1=%dim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.max"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["dim"] = captured_params.at("dim"); + op->params["keepdim"] = captured_params.at("keepdims").i ? true : false; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_max_tnn, 50) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_mean.cpp b/tools/pnnx/src/pass_level2/torch_mean.cpp index 6a579feeb2d..38afb2a760d 100644 --- a/tools/pnnx/src/pass_level2/torch_mean.cpp +++ b/tools/pnnx/src/pass_level2/torch_mean.cpp @@ -148,4 +148,31 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_mean_onnx, 50) +class torch_mean_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.ReduceMean op_0 1 1 input out arg0=%keepdims arg1=%dim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.mean"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["dim"] = captured_params.at("dim"); + op->params["keepdim"] = captured_params.at("keepdims").i ? 
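/* TNN stores keepdims as an int flag */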
true : false; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_mean_tnn, 50) + } // namespace pnnx From 598c9f5f57ddcb1a10348004e5d77f34a39a2314 Mon Sep 17 00:00:00 2001 From: nihui <171016+nihui@users.noreply.github.com> Date: Sat, 8 Feb 2025 08:25:07 +0000 Subject: [PATCH 04/29] apply code-format changes --- tools/pnnx/src/load_tnn.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/tools/pnnx/src/load_tnn.cpp b/tools/pnnx/src/load_tnn.cpp index d14ae6f5eb9..0a5debf2c4c 100644 --- a/tools/pnnx/src/load_tnn.cpp +++ b/tools/pnnx/src/load_tnn.cpp @@ -362,7 +362,6 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) op->inputs.push_back(r); } - fclose(fp); // replace simple operator From f134c661804212fb277928d845f6cb187f35a0b4 Mon Sep 17 00:00:00 2001 From: nihuini Date: Sat, 8 Feb 2025 16:28:30 +0800 Subject: [PATCH 05/29] w --- .../pnnx/src/pass_level2/torch_unsqueeze.cpp | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tools/pnnx/src/pass_level2/torch_unsqueeze.cpp b/tools/pnnx/src/pass_level2/torch_unsqueeze.cpp index ea377b071aa..1ff74798a51 100644 --- a/tools/pnnx/src/pass_level2/torch_unsqueeze.cpp +++ b/tools/pnnx/src/pass_level2/torch_unsqueeze.cpp @@ -104,4 +104,43 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_unsqueeze_onnx_1, 60) +class torch_unsqueeze_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.Unsqueeze op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.unsqueeze"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const int dims_count = captured_params.at("op_0.arg0").i; + if (dims_count == 1) + { + op->params["dim"] = captured_params.at("op_0.arg1").i; + } + else + { + std::vector dims(dims_count); + for (int i = 0; i < dims_count; i++) + { + dims[i] = captured_params.at("op_0.arg" + std::to_string(i + 1)).i; + } + op->params["dim"] = dims; + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_unsqueeze_tnn, 60) + } // namespace pnnx From 8409c77f781f84dc20fc5d6aa063c59b9145e094 Mon Sep 17 00:00:00 2001 From: nihuini Date: Sat, 8 Feb 2025 16:33:13 +0800 Subject: [PATCH 06/29] w --- tools/pnnx/src/pass_level2/torch_clamp.cpp | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tools/pnnx/src/pass_level2/torch_clamp.cpp b/tools/pnnx/src/pass_level2/torch_clamp.cpp index ffe241e6895..db7e94521eb 100644 --- a/tools/pnnx/src/pass_level2/torch_clamp.cpp +++ b/tools/pnnx/src/pass_level2/torch_clamp.cpp @@ -114,4 +114,25 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_clamp_onnx_2, 40) +class torch_clamp_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.Clip op_0 1 1 input out arg0=%min arg1=%max +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.clamp"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_clamp_tnn, 40) + } // namespace pnnx From 4b3d8f32df8ec77599e65708689674681310ec5a Mon Sep 17 00:00:00 2001 From: nihuini Date: Sat, 8 Feb 2025 19:18:42 +0800 Subject: [PATCH 07/29] load model binary --- tools/pnnx/src/ir.h | 3 + tools/pnnx/src/load_tnn.cpp | 200 +++++++++++++++++++++++++++++++++--- 2 files changed, 187 insertions(+), 16 deletions(-) diff --git a/tools/pnnx/src/ir.h 
b/tools/pnnx/src/ir.h index 779c2eec9f1..6ab52eb0a21 100644 --- a/tools/pnnx/src/ir.h +++ b/tools/pnnx/src/ir.h @@ -234,6 +234,9 @@ class Attribute #if BUILD_ONNX2PNNX Attribute(const onnx::TensorProto& t); #endif +#if BUILD_TNN2PNNX + Attribute(FILE* bp); +#endif Attribute(const std::initializer_list& shape, const std::vector& t); diff --git a/tools/pnnx/src/load_tnn.cpp b/tools/pnnx/src/load_tnn.cpp index 0a5debf2c4c..aa185c984b9 100644 --- a/tools/pnnx/src/load_tnn.cpp +++ b/tools/pnnx/src/load_tnn.cpp @@ -18,6 +18,7 @@ #include #include +#include namespace pnnx { @@ -116,29 +117,95 @@ static float vstr_to_float(const char vstr[16]) return sign ? (float)v : (float)-v; } +static size_t type_to_elemsize(int type) +{ + if (type == 1) return 4; + if (type == 2) return 8; + if (type == 3) return 2; + if (type == 4) return 4; + if (type == 5) return 8; + if (type == 6) return 2; + if (type == 7) return 1; + if (type == 8) return 1; + if (type == 9) return 1; + if (type == 10) return 8; + if (type == 11) return 16; + if (type == 12) return 4; + return 0; // null +} + +static int get_tnn_tensor_type(int dt) +{ + if (dt == 0) return 1; + + fprintf(stderr, "unsupported tnn tensor type %d\n", dt); + return 0; // unknown type +} + +Attribute::Attribute(FILE* bp) +{ + unsigned int magic; + int datatype; + int length; + int ndim; + fread(&magic, 1, sizeof(unsigned int), bp); + fread(&datatype, 1, sizeof(int), bp); + fread(&length, 1, sizeof(int), bp); + fread(&ndim, 1, sizeof(int), bp); + + type = get_tnn_tensor_type(datatype); + + if (ndim == 0) + { + shape = {1}; + + data.resize(type_to_elemsize(type)); + + // assert length == type_to_elemsize(type) + fread((void*)data.data(), 1, length, bp); + + return; + } + + shape.resize(ndim); + for (int i = 0; i < ndim; i++) + { + fread(&shape[i], 1, sizeof(int), bp); + } + + data.resize(elemcount() * type_to_elemsize(type)); + + // assert length == elemcount() * type_to_elemsize(type) + fread((void*)data.data(), 1, length, bp); +} + int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) { fprintf(stderr, "############# pass_level0 tnn\n"); - fprintf(stderr, "load_tnn %s\n", tnnpath.c_str()); + // generate proto and model path + std::string tnnprotopath = tnnpath; + std::string tnnmodelpath = tnnpath.substr(0, tnnpath.size() - 8) + "tnnmodel"; + + fprintf(stderr, "load_tnn %s %s\n", tnnprotopath.c_str(), tnnmodelpath.c_str()); - FILE* fp = fopen(tnnpath.c_str(), "rb"); - if (!fp) + FILE* pp = fopen(tnnprotopath.c_str(), "rb"); + if (!pp) { - fprintf(stderr, "fopen %s failed\n", tnnpath.c_str()); + fprintf(stderr, "fopen %s failed\n", tnnprotopath.c_str()); return -1; } char line[4096]; // "1 57 1 4206624772 ," - fgets(line, 4096, fp); + fgets(line, 4096, pp); int blob_count = 57; - unsigned int magic = 4206624772; + unsigned int proto_magic = 4206624772; // "input 2 1 80000 0 ," - fgets(line, 4096, fp); - if (magic == 4206624772) + fgets(line, 4096, pp); + if (proto_magic == 4206624772) { // strip leading and tail double quote line[strlen(line) - 2] = '\0'; @@ -175,16 +242,14 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) r->producer = op; r->shape = shape; - - if (datatype == 0) - r->type = 1; + r->type = get_tnn_tensor_type(datatype); op->outputs.push_back(r); } // all operand names // " 108 109 110 111 112 113 114 116 118 119 120 125 126 128 130 131 132 133 135 136 138 139 142 144 145 147 148 151 153 154 156 157 160 162 163 165 166 169 171 172 174 175 178 180 181 183 184 188 189 190 191 192 194 85 clipwise_output embedding input ," - 
fgets(line, 4096, fp); + fgets(line, 4096, pp); { // strip leading and tail double quote line[strlen(line) - 2] = '\0'; @@ -211,7 +276,7 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) // all output names // "clipwise_output embedding ," - fgets(line, 4096, fp); + fgets(line, 4096, pp); std::vector output_names; { @@ -240,13 +305,13 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) // layer count // " 56 ," - fgets(line, 4096, fp); + fgets(line, 4096, pp); int layer_count = 56; for (int i = 0; i < layer_count; i++) { // "Unsqueeze Unsqueeze_0 1 1 input 85 1 1 ," - fgets(line, 4096, fp); + fgets(line, 4096, pp); // strip leading and tail double quote line[strlen(line) - 2] = '\0'; @@ -362,7 +427,110 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) op->inputs.push_back(r); } - fclose(fp); + fclose(pp); + + FILE* bp = fopen(tnnmodelpath.c_str(), "rb"); + if (!bp) + { + fprintf(stderr, "fopen %s failed\n", tnnmodelpath.c_str()); + return -1; + } + + // magic 0xfabc0004 + unsigned int model_magic; + fread(&model_magic, 1, sizeof(unsigned int), bp); + if (model_magic != 0xfabc0004) + { + fprintf(stderr, "model_magic %x failed\n", model_magic); + return -1; + } + + int weight_count = 0; + fread(&weight_count, 1, sizeof(int), bp); + + fprintf(stderr, "weight_count = %d\n", weight_count); + + std::unordered_map op_map; + for (auto x : pnnx_graph.ops) + { + op_map[x->name] = x; + } + + for (int i = 0; i < weight_count; i++) + { + int opid; + fread(&opid, 1, sizeof(int), bp); + + int type_size; + std::string type; + fread(&type_size, 1, sizeof(int), bp); + type.resize(type_size); + fread((void*)type.data(), 1, type_size, bp); + + int name_size; + std::string name; + fread(&name_size, 1, sizeof(int), bp); + name.resize(name_size); + fread((void*)name.data(), 1, name_size, bp); + + fprintf(stderr, "model %d %s %s\n", opid, type.c_str(), name.c_str()); + + Operator* op = op_map.at(name); + + int attribute_count = 0; + + if (type == "Convolution1D" || type == "Convolution") + { + // skip name2 == name + int name2_size; + std::string name2; + fread(&name2_size, 1, sizeof(int), bp); + name2.resize(name2_size); + fread((void*)name2.data(), 1, name2_size, bp); + + // bias + int bias; + fread(&bias, 1, sizeof(int), bp); + + attribute_count = bias ? 
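/* weight blob is always present, bias blob only when the bias flag is set */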
2 : 1; + } + if (type == "InnerProduct") + { + // skip name2 == name + int name2_size; + std::string name2; + fread(&name2_size, 1, sizeof(int), bp); + name2.resize(name2_size); + fread((void*)name2.data(), 1, name2_size, bp); + + attribute_count = 2; + } + if (type == "MatMul") + { + attribute_count = 1; + } + if (type == "Add" || type == "Sub" || type == "Mul" || type == "Div") + { + attribute_count = 1; + } + if (type == "BatchNormCxx") + { + attribute_count = 2; + } + + for (int j = 0; j < attribute_count; j++) + { + Operator* op_scale = pnnx_graph.new_operator_before("pnnx.Attribute", name + "_attr" + std::to_string(j), op); + Operand* r0 = pnnx_graph.new_operand(name + "_attr" + std::to_string(j)); + op_scale->attrs["data"] = Attribute(bp); + op_scale->outputs.push_back(r0); + r0->producer = op_scale; + r0->consumers.push_back(op); + op->inputs.push_back(r0); + } + } + + fclose(bp); // replace simple operator for (Operator* op : pnnx_graph.ops) From a3c33b8aae626fcb87fbe132889b5fc1d8f9ceb2 Mon Sep 17 00:00:00 2001 From: nihuini Date: Sat, 8 Feb 2025 20:07:03 +0800 Subject: [PATCH 08/29] w --- tools/pnnx/src/load_tnn.cpp | 3 + tools/pnnx/src/pass_level2/F_conv1d.cpp | 199 +++++++++++++++++++++++ tools/pnnx/src/pass_level2/F_conv2d.cpp | 200 ++++++++++++++++++++++++ tools/pnnx/src/pass_level2/F_linear.cpp | 23 +++ 4 files changed, 425 insertions(+) diff --git a/tools/pnnx/src/load_tnn.cpp b/tools/pnnx/src/load_tnn.cpp index aa185c984b9..3e8263084dc 100644 --- a/tools/pnnx/src/load_tnn.cpp +++ b/tools/pnnx/src/load_tnn.cpp @@ -542,6 +542,9 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) // binary if (op->type == "tnn.Add") op->type = "aten::add"; + if (op->type == "tnn.Sub") op->type = "aten::sub"; + if (op->type == "tnn.Mul") op->type = "aten::mul"; + if (op->type == "tnn.Div") op->type = "aten::div"; } return 0; diff --git a/tools/pnnx/src/pass_level2/F_conv1d.cpp b/tools/pnnx/src/pass_level2/F_conv1d.cpp index cdc503d5345..25a52ebccd3 100644 --- a/tools/pnnx/src/pass_level2/F_conv1d.cpp +++ b/tools/pnnx/src/pass_level2/F_conv1d.cpp @@ -195,4 +195,203 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv1d_onnx_1, 140) +class F_conv1d_activation_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution1D op_0 3 1 input weight bias out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + if (this->activation == 1) + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution1D conv1d 3 1 input weight bias a +F.relu relu 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + else if (this->activation == 2) + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution1D conv1d 3 1 input weight bias a +F.relu6 relu6 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + else // if (this->activation == 256) + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution1D conv1d 3 1 input weight bias a +F.silu silu 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + } + + bool match(const std::map& captured_params) const + { + this->activation = captured_params.at("op_0.arg9").i; + return activation != 
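/* 0 means no fused activation */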
0; + } + + void write(const std::map& ops, const std::map& captured_params) const + { + for (int i = 0; i < 9; i++) + { + std::string argN = std::string("arg") + std::to_string(i); + ops.at("conv1d")->params[argN] = captured_params.at("op_0." + argN); + } + + ops.at("conv1d")->params["arg9"] = 0; + } + +protected: + mutable int activation; +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv1d_activation_tnn, 139) + +class F_conv1d_activation_tnn_1 : public F_conv1d_activation_tnn +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution1D op_0 2 1 input weight out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + if (this->activation == 1) + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution1D conv1d 2 1 input weight a +F.relu relu 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + else if (this->activation == 2) + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution1D conv1d 2 1 input weight a +F.relu6 relu6 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + else // if (this->activation == 256) + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution1D conv1d 2 1 input weight a +F.silu silu 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv1d_activation_tnn_1, 139) + +class F_conv1d_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution1D op_0 3 1 input weight bias out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.conv1d"; + } + + bool match(const std::map& captured_params) const + { + const int activation = captured_params.at("op_0.arg9").i; + return activation == 0; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["groups"] = captured_params.at("op_0.arg0"); + // captured_params.at("op_0.arg1"); // inch + // captured_params.at("op_0.arg2"); // outch + // captured_params.at("op_0.arg3"); // kernel_size + op->params["stride"] = {captured_params.at("op_0.arg4").i}; + op->params["padding"] = {captured_params.at("op_0.arg5").i}; + // captured_params.at("op_0.arg6"); // bias + // captured_params.at("op_0.arg7"); // pad_type + op->params["dilation"] = {captured_params.at("op_0.arg8").i}; + // captured_params.at("op_0.arg9"); // activation + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv1d_tnn, 140) + +class F_conv1d_tnn_1 : public F_conv1d_tnn +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution1D op_0 2 1 input weight out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + void write(Operator* op, const std::map& captured_params) const + { + F_conv1d_tnn::write(op, captured_params); + + op->params["bias"] = Parameter(); + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv1d_tnn_1, 140) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_conv2d.cpp b/tools/pnnx/src/pass_level2/F_conv2d.cpp index 806cc44af9e..940e6473fa9 100644 --- a/tools/pnnx/src/pass_level2/F_conv2d.cpp +++ 
b/tools/pnnx/src/pass_level2/F_conv2d.cpp @@ -305,4 +305,204 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv2d_onnx_1, 140) +class F_conv2d_activation_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution op_0 3 1 input weight bias out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + if (this->activation == 1) + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution conv2d 3 1 input weight bias a +F.relu relu 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + else if (this->activation == 2) + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution conv2d 3 1 input weight bias a +F.relu6 relu6 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + else // if (this->activation == 256) + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution conv2d 3 1 input weight bias a +F.silu silu 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + } + + bool match(const std::map& captured_params) const + { + this->activation = captured_params.at("op_0.arg13").i; + return activation != 0; + } + + void write(const std::map& ops, const std::map& captured_params) const + { + for (int i = 0; i < 13; i++) + { + std::string argN = std::string("arg") + std::to_string(i); + ops.at("conv2d")->params[argN] = captured_params.at("op_0." + argN); + } + + ops.at("conv2d")->params["arg13"] = 0; + } + +protected: + mutable int activation; +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv2d_activation_tnn, 139) + +class F_conv2d_activation_tnn_1 : public F_conv2d_activation_tnn +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution op_0 2 1 input weight out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + if (this->activation == 1) + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution conv2d 2 1 input weight a +F.relu relu 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + else if (this->activation == 2) + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution conv2d 2 1 input weight a +F.relu6 relu6 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + else // if (this->activation == 256) + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution conv2d 2 1 input weight a +F.silu silu 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv2d_activation_tnn_1, 139) + +class F_conv2d_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution op_0 3 1 input weight bias out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.conv2d"; + } + + bool match(const std::map& captured_params) const + { + const int activation 
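/* arg13 carries the fused activation type */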
= captured_params.at("op_0.arg13").i; + return activation == 0; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["groups"] = captured_params.at("op_0.arg0"); + // captured_params.at("op_0.arg1"); // inch + // captured_params.at("op_0.arg2"); // outch + // captured_params.at("op_0.arg3"); // kernel_h + // captured_params.at("op_0.arg4"); // kernel_w + op->params["stride"] = {captured_params.at("op_0.arg5").i, captured_params.at("op_0.arg6").i}; + op->params["padding"] = {captured_params.at("op_0.arg7").i, captured_params.at("op_0.arg8").i}; + // captured_params.at("op_0.arg9"); // bias + // captured_params.at("op_0.arg10"); // pad_type + op->params["dilation"] = {captured_params.at("op_0.arg11").i, captured_params.at("op_0.arg12").i}; + // captured_params.at("op_0.arg13"); // activation + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv2d_tnn, 140) + +class F_conv2d_tnn_1 : public F_conv2d_tnn +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution op_0 2 1 input weight out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + void write(Operator* op, const std::map& captured_params) const + { + F_conv2d_tnn::write(op, captured_params); + + op->params["bias"] = Parameter(); + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv2d_tnn_1, 140) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_linear.cpp b/tools/pnnx/src/pass_level2/F_linear.cpp index 2b43fda9979..d796cb47188 100644 --- a/tools/pnnx/src/pass_level2/F_linear.cpp +++ b/tools/pnnx/src/pass_level2/F_linear.cpp @@ -340,4 +340,27 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_linear_onnx_4, 110) +class F_linear_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight @data=(%in_features,%out_features)f32 +pnnx.Input input_2 0 1 bias @data=(%out_features)f32 +tnn.InnerProduct op_0 3 1 input weight bias out arg0=* arg1=* arg2=0 arg3=1 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.linear"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_linear_tnn, 140) + } // namespace pnnx From ca57ab1b0f8f64b6d95fe1a2f8ae20b7cecad945 Mon Sep 17 00:00:00 2001 From: nihuini Date: Mon, 10 Feb 2025 11:37:40 +0800 Subject: [PATCH 09/29] w --- .../src/pass_level2/F_adaptive_avg_pool2d.cpp | 77 +++++++++++++++++++ .../src/pass_level2/F_adaptive_max_pool2d.cpp | 77 +++++++++++++++++++ tools/pnnx/src/pass_level2/F_avg_pool2d.cpp | 71 +++++++++++++++++ tools/pnnx/src/pass_level2/F_max_pool2d.cpp | 70 +++++++++++++++++ 4 files changed, 295 insertions(+) diff --git a/tools/pnnx/src/pass_level2/F_adaptive_avg_pool2d.cpp b/tools/pnnx/src/pass_level2/F_adaptive_avg_pool2d.cpp index 72fed050263..5a6a31909fc 100644 --- a/tools/pnnx/src/pass_level2/F_adaptive_avg_pool2d.cpp +++ b/tools/pnnx/src/pass_level2/F_adaptive_avg_pool2d.cpp @@ -64,4 +64,81 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_adaptive_avg_pool2d_onnx, 120) +class F_adaptive_avg_pool2d_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.Pooling op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.adaptive_avg_pool2d"; + } + + bool match(const 
std::map& captured_params) const + { + const int pool_type = captured_params.at("op_0.arg0").i; + if (pool_type != 1) + return false; + + const int pad_h = captured_params.at("op_0.arg5").i; + const int pad_w = captured_params.at("op_0.arg6").i; + if (pad_h != 0 || pad_w != 0) + return false; + + const int kernel_h = captured_params.at("op_0.arg1").i; + const int kernel_w = captured_params.at("op_0.arg2").i; + if (kernel_h == 0 && kernel_w == 0) + return true; + + if (captured_params.find("op_0.arg11") != captured_params.end()) + { + const int is_adaptive = captured_params.at("op_0.arg11").i; + if (is_adaptive == 1) + return true; + } + + return false; + } + + void write(Operator* op, const std::map& captured_params) const + { + // captured_params.at("op_0.arg0"); // pool_type + const int kernel_h = captured_params.at("op_0.arg1").i; + const int kernel_w = captured_params.at("op_0.arg2").i; + if (kernel_h == 0 && kernel_w == 0) + { + // global pooling + op->params["output_size"] = {1, 1}; + } + + // captured_params.at("op_0.arg3"); // stride_h + // captured_params.at("op_0.arg4"); // stride_w + // captured_params.at("op_0.arg5"); // pad_h + // captured_params.at("op_0.arg6"); // pad_w + // captured_params.at("op_0.arg7"); // kernel_index_h + // captured_params.at("op_0.arg8"); // kernel_index_w + // captured_params.at("op_0.arg9"); // pad_type + // captured_params.at("op_0.arg10"); // ceil_mode + + if (captured_params.find("op_0.arg11") != captured_params.end()) + { + const int is_adaptive = captured_params.at("op_0.arg11").i; + if (is_adaptive == 1) + { + op->params["output_size"] = {captured_params.at("op_0.arg12").i, captured_params.at("op_0.arg13").i}; + } + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_adaptive_avg_pool2d_tnn, 120) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_adaptive_max_pool2d.cpp b/tools/pnnx/src/pass_level2/F_adaptive_max_pool2d.cpp index 79776cae7f6..72bfc6aaba7 100644 --- a/tools/pnnx/src/pass_level2/F_adaptive_max_pool2d.cpp +++ b/tools/pnnx/src/pass_level2/F_adaptive_max_pool2d.cpp @@ -43,4 +43,81 @@ pnnx.Output output 2 0 out indices REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_adaptive_max_pool2d, 120) +class F_adaptive_max_pool2d_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.Pooling op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.adaptive_max_pool2d"; + } + + bool match(const std::map& captured_params) const + { + const int pool_type = captured_params.at("op_0.arg0").i; + if (pool_type != 0) + return false; + + const int pad_h = captured_params.at("op_0.arg5").i; + const int pad_w = captured_params.at("op_0.arg6").i; + if (pad_h != 0 || pad_w != 0) + return false; + + const int kernel_h = captured_params.at("op_0.arg1").i; + const int kernel_w = captured_params.at("op_0.arg2").i; + if (kernel_h == 0 && kernel_w == 0) + return true; + + if (captured_params.find("op_0.arg11") != captured_params.end()) + { + const int is_adaptive = captured_params.at("op_0.arg11").i; + if (is_adaptive == 1) + return true; + } + + return false; + } + + void write(Operator* op, const std::map& captured_params) const + { + // captured_params.at("op_0.arg0"); // pool_type + const int kernel_h = captured_params.at("op_0.arg1").i; + const int kernel_w = captured_params.at("op_0.arg2").i; + if (kernel_h == 0 && kernel_w == 0) + { + // global pooling + op->params["output_size"] = {1, 
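/* kernel 0x0 means TNN global pooling, i.e. adaptive output 1x1 */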
1}; + } + + // captured_params.at("op_0.arg3"); // stride_h + // captured_params.at("op_0.arg4"); // stride_w + // captured_params.at("op_0.arg5"); // pad_h + // captured_params.at("op_0.arg6"); // pad_w + // captured_params.at("op_0.arg7"); // kernel_index_h + // captured_params.at("op_0.arg8"); // kernel_index_w + // captured_params.at("op_0.arg9"); // pad_type + // captured_params.at("op_0.arg10"); // ceil_mode + + if (captured_params.find("op_0.arg11") != captured_params.end()) + { + const int is_adaptive = captured_params.at("op_0.arg11").i; + if (is_adaptive == 1) + { + op->params["output_size"] = {captured_params.at("op_0.arg12").i, captured_params.at("op_0.arg13").i}; + } + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_adaptive_max_pool2d_tnn, 120) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_avg_pool2d.cpp b/tools/pnnx/src/pass_level2/F_avg_pool2d.cpp index da03102c6d8..d9fb11ac983 100644 --- a/tools/pnnx/src/pass_level2/F_avg_pool2d.cpp +++ b/tools/pnnx/src/pass_level2/F_avg_pool2d.cpp @@ -205,4 +205,75 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_avg_pool2d_onnx, 120) +class F_avg_pool2d_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.Pooling op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.avg_pool2d"; + } + + bool match(const std::map& captured_params) const + { + const int pool_type = captured_params.at("op_0.arg0").i; + if (pool_type != 1) + return false; + + const int kernel_h = captured_params.at("op_0.arg1").i; + const int kernel_w = captured_params.at("op_0.arg2").i; + if (kernel_h == 0 && kernel_w == 0) + return false; + + if (captured_params.find("op_0.arg11") != captured_params.end()) + { + const int is_adaptive = captured_params.at("op_0.arg11").i; + if (is_adaptive != 0) + return false; + } + + return true; + } + + void write(Operator* op, const std::map& captured_params) const + { + // captured_params.at("op_0.arg0"); // pool_type + op->params["kernel_size"] = {captured_params.at("op_0.arg1").i, captured_params.at("op_0.arg2").i}; + op->params["stride"] = {captured_params.at("op_0.arg3").i, captured_params.at("op_0.arg4").i}; + op->params["padding"] = {captured_params.at("op_0.arg5").i, captured_params.at("op_0.arg6").i}; + + const int kernel_index_h = captured_params.at("op_0.arg7").i; + const int kernel_index_w = captured_params.at("op_0.arg8").i; + if (kernel_index_h != -1 || kernel_index_w != -1) + { + fprintf(stderr, "unsupported F.avg_pool2d kernel_index %d %d\n", kernel_index_h, kernel_index_w); + } + + const int pad_type = captured_params.at("op_0.arg9").i; + if (pad_type > 0) + { + fprintf(stderr, "unsupported F.avg_pool2d pad_type %d\n", pad_type); + } + + op->params["ceil_mode"] = captured_params.at("op_0.arg10").i ? 
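/* arg10 is the ceil_mode flag */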
true : false;
+        // captured_params.at("op_0.arg11"); // is_adaptive
+        // captured_params.at("op_0.arg12"); // output_h
+        // captured_params.at("op_0.arg13"); // output_w
+
+        op->params["count_include_pad"] = false;
+        op->params["divisor_override"] = Parameter();
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_avg_pool2d_tnn, 120)
+
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level2/F_max_pool2d.cpp b/tools/pnnx/src/pass_level2/F_max_pool2d.cpp
index f3fbc681121..0fa98c79873 100644
--- a/tools/pnnx/src/pass_level2/F_max_pool2d.cpp
+++ b/tools/pnnx/src/pass_level2/F_max_pool2d.cpp
@@ -298,4 +298,74 @@ pnnx.Output output 2 0 out indices

 REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_max_pool2d_onnx_1, 120)

+class F_max_pool2d_tnn : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+3 2
+pnnx.Input input 0 1 input
+tnn.Pooling op_0 1 1 input out %*=%*
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "F.max_pool2d";
+    }
+
+    bool match(const std::map<std::string, Parameter>& captured_params) const
+    {
+        const int pool_type = captured_params.at("op_0.arg0").i;
+        if (pool_type != 0)
+            return false;
+
+        const int kernel_h = captured_params.at("op_0.arg1").i;
+        const int kernel_w = captured_params.at("op_0.arg2").i;
+        if (kernel_h == 0 && kernel_w == 0)
+            return false;
+
+        if (captured_params.find("op_0.arg11") != captured_params.end())
+        {
+            const int is_adaptive = captured_params.at("op_0.arg11").i;
+            if (is_adaptive != 0)
+                return false;
+        }
+
+        return true;
+    }
+
+    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
+    {
+        // captured_params.at("op_0.arg0"); // pool_type
+        op->params["kernel_size"] = {captured_params.at("op_0.arg1").i, captured_params.at("op_0.arg2").i};
+        op->params["stride"] = {captured_params.at("op_0.arg3").i, captured_params.at("op_0.arg4").i};
+        op->params["padding"] = {captured_params.at("op_0.arg5").i, captured_params.at("op_0.arg6").i};
+
+        const int kernel_index_h = captured_params.at("op_0.arg7").i;
+        const int kernel_index_w = captured_params.at("op_0.arg8").i;
+        if (kernel_index_h != -1 || kernel_index_w != -1)
+        {
+            fprintf(stderr, "unsupported F.max_pool2d kernel_index %d %d\n", kernel_index_h, kernel_index_w);
+        }
+
+        const int pad_type = captured_params.at("op_0.arg9").i;
+        if (pad_type > 0)
+        {
+            fprintf(stderr, "unsupported F.max_pool2d pad_type %d\n", pad_type);
+        }
+
+        op->params["ceil_mode"] = captured_params.at("op_0.arg10").i ?
true : false; + // captured_params.at("op_0.arg11"); // is_adaptive + // captured_params.at("op_0.arg12"); // output_h + // captured_params.at("op_0.arg13"); // output_w + + op->params["return_indices"] = false; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_max_pool2d_tnn, 120) + } // namespace pnnx From 560a51ac97cdab41358c84ad87ac1944ae150cda Mon Sep 17 00:00:00 2001 From: nihuini Date: Mon, 10 Feb 2025 11:50:35 +0800 Subject: [PATCH 10/29] w --- tools/pnnx/src/pass_level2/F_pad.cpp | 45 ++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tools/pnnx/src/pass_level2/F_pad.cpp b/tools/pnnx/src/pass_level2/F_pad.cpp index 795012ca516..fd37af41d41 100644 --- a/tools/pnnx/src/pass_level2/F_pad.cpp +++ b/tools/pnnx/src/pass_level2/F_pad.cpp @@ -480,4 +480,49 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_pad_onnx_1, 110) +class F_pad_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.PadV2 op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.pad"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const int ndim = captured_params.at("op_0.arg0").i; + + std::vector pads(ndim * 2); + for (int i = 0; i < ndim * 2; i++) + { + pads[i] = captured_params.at("op_0.arg" + std::to_string(i + 1)).i; + } + op->params["pad"] = pads; + + const int type = captured_params.at("op_0.arg" + std::to_string(ndim * 2 + 1)).i; + if (type == 0) + { + op->params["mode"] = "constant"; + op->params["value"] = captured_params.at("op_0.arg" + std::to_string(ndim * 2 + 2)); + } + if (type == 1) + { + op->params["mode"] = "reflect"; + op->params["value"] = Parameter(); + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_pad_tnn, 110) + } // namespace pnnx From faa4b03954703939b48dff1ed892c6baae90e230 Mon Sep 17 00:00:00 2001 From: nihuini Date: Mon, 10 Feb 2025 14:41:26 +0800 Subject: [PATCH 11/29] lower conv-act, lower power --- tools/pnnx/src/CMakeLists.txt | 6 + tools/pnnx/src/load_tnn.cpp | 7 + tools/pnnx/src/pass_level2/F_conv1d.cpp | 132 -------- tools/pnnx/src/pass_level2/F_conv2d.cpp | 132 -------- .../pass_tnn/lower_convolution_activation.cpp | 295 ++++++++++++++++++ .../pass_tnn/lower_convolution_activation.h | 25 ++ tools/pnnx/src/pass_tnn/lower_power.cpp | 62 ++++ tools/pnnx/src/pass_tnn/lower_power.h | 25 ++ 8 files changed, 420 insertions(+), 264 deletions(-) create mode 100644 tools/pnnx/src/pass_tnn/lower_convolution_activation.cpp create mode 100644 tools/pnnx/src/pass_tnn/lower_convolution_activation.h create mode 100644 tools/pnnx/src/pass_tnn/lower_power.cpp create mode 100644 tools/pnnx/src/pass_tnn/lower_power.h diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 99dcf7d84ca..3cb849ff213 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -716,7 +716,13 @@ else() endif() if(PNNX_TNN2PNNX) + set(pnnx_pass_tnn_SRCS + pass_tnn/lower_convolution_activation.cpp + pass_tnn/lower_power.cpp + ) + set(tnn2pnnx_SRCS + ${pnnx_pass_tnn_SRCS} load_tnn.cpp ) diff --git a/tools/pnnx/src/load_tnn.cpp b/tools/pnnx/src/load_tnn.cpp index 3e8263084dc..7b6d699f2eb 100644 --- a/tools/pnnx/src/load_tnn.cpp +++ b/tools/pnnx/src/load_tnn.cpp @@ -20,6 +20,9 @@ #include #include +#include "pass_tnn/lower_convolution_activation.h" +#include "pass_tnn/lower_power.h" + namespace pnnx { static bool vstr_is_float(const char 
vstr[16]) @@ -547,6 +550,10 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) if (op->type == "tnn.Div") op->type = "aten::div"; } + tnn2pnnx::lower_convolution_activation(pnnx_graph); + + tnn2pnnx::lower_power(pnnx_graph); + return 0; } diff --git a/tools/pnnx/src/pass_level2/F_conv1d.cpp b/tools/pnnx/src/pass_level2/F_conv1d.cpp index 25a52ebccd3..6f8eabae2f6 100644 --- a/tools/pnnx/src/pass_level2/F_conv1d.cpp +++ b/tools/pnnx/src/pass_level2/F_conv1d.cpp @@ -195,138 +195,6 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv1d_onnx_1, 140) -class F_conv1d_activation_tnn : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -5 4 -pnnx.Input input_0 0 1 input -pnnx.Input input_1 0 1 weight -pnnx.Input input_2 0 1 bias -tnn.Convolution1D op_0 3 1 input weight bias out %*=%* -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* replace_pattern_graph() const - { - if (this->activation == 1) - { - return R"PNNXIR(7767517 -6 5 -pnnx.Input input_0 0 1 input -pnnx.Input input_1 0 1 weight -pnnx.Input input_2 0 1 bias -tnn.Convolution1D conv1d 3 1 input weight bias a -F.relu relu 1 1 a out -pnnx.Output output 1 0 out -)PNNXIR"; - } - else if (this->activation == 2) - { - return R"PNNXIR(7767517 -6 5 -pnnx.Input input_0 0 1 input -pnnx.Input input_1 0 1 weight -pnnx.Input input_2 0 1 bias -tnn.Convolution1D conv1d 3 1 input weight bias a -F.relu6 relu6 1 1 a out -pnnx.Output output 1 0 out -)PNNXIR"; - } - else // if (this->activation == 256) - { - return R"PNNXIR(7767517 -6 5 -pnnx.Input input_0 0 1 input -pnnx.Input input_1 0 1 weight -pnnx.Input input_2 0 1 bias -tnn.Convolution1D conv1d 3 1 input weight bias a -F.silu silu 1 1 a out -pnnx.Output output 1 0 out -)PNNXIR"; - } - } - - bool match(const std::map& captured_params) const - { - this->activation = captured_params.at("op_0.arg9").i; - return activation != 0; - } - - void write(const std::map& ops, const std::map& captured_params) const - { - for (int i = 0; i < 9; i++) - { - std::string argN = std::string("arg") + std::to_string(i); - ops.at("conv1d")->params[argN] = captured_params.at("op_0." 
+ argN); - } - - ops.at("conv1d")->params["arg9"] = 0; - } - -protected: - mutable int activation; -}; - -REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv1d_activation_tnn, 139) - -class F_conv1d_activation_tnn_1 : public F_conv1d_activation_tnn -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -4 3 -pnnx.Input input_0 0 1 input -pnnx.Input input_1 0 1 weight -tnn.Convolution1D op_0 2 1 input weight out %*=%* -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* replace_pattern_graph() const - { - if (this->activation == 1) - { - return R"PNNXIR(7767517 -5 4 -pnnx.Input input_0 0 1 input -pnnx.Input input_1 0 1 weight -tnn.Convolution1D conv1d 2 1 input weight a -F.relu relu 1 1 a out -pnnx.Output output 1 0 out -)PNNXIR"; - } - else if (this->activation == 2) - { - return R"PNNXIR(7767517 -5 4 -pnnx.Input input_0 0 1 input -pnnx.Input input_1 0 1 weight -tnn.Convolution1D conv1d 2 1 input weight a -F.relu6 relu6 1 1 a out -pnnx.Output output 1 0 out -)PNNXIR"; - } - else // if (this->activation == 256) - { - return R"PNNXIR(7767517 -5 4 -pnnx.Input input_0 0 1 input -pnnx.Input input_1 0 1 weight -tnn.Convolution1D conv1d 2 1 input weight a -F.silu silu 1 1 a out -pnnx.Output output 1 0 out -)PNNXIR"; - } - } -}; - -REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv1d_activation_tnn_1, 139) - class F_conv1d_tnn : public GraphRewriterPass { public: diff --git a/tools/pnnx/src/pass_level2/F_conv2d.cpp b/tools/pnnx/src/pass_level2/F_conv2d.cpp index 940e6473fa9..368096520ff 100644 --- a/tools/pnnx/src/pass_level2/F_conv2d.cpp +++ b/tools/pnnx/src/pass_level2/F_conv2d.cpp @@ -305,138 +305,6 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv2d_onnx_1, 140) -class F_conv2d_activation_tnn : public GraphRewriterPass -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -5 4 -pnnx.Input input_0 0 1 input -pnnx.Input input_1 0 1 weight -pnnx.Input input_2 0 1 bias -tnn.Convolution op_0 3 1 input weight bias out %*=%* -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* replace_pattern_graph() const - { - if (this->activation == 1) - { - return R"PNNXIR(7767517 -6 5 -pnnx.Input input_0 0 1 input -pnnx.Input input_1 0 1 weight -pnnx.Input input_2 0 1 bias -tnn.Convolution conv2d 3 1 input weight bias a -F.relu relu 1 1 a out -pnnx.Output output 1 0 out -)PNNXIR"; - } - else if (this->activation == 2) - { - return R"PNNXIR(7767517 -6 5 -pnnx.Input input_0 0 1 input -pnnx.Input input_1 0 1 weight -pnnx.Input input_2 0 1 bias -tnn.Convolution conv2d 3 1 input weight bias a -F.relu6 relu6 1 1 a out -pnnx.Output output 1 0 out -)PNNXIR"; - } - else // if (this->activation == 256) - { - return R"PNNXIR(7767517 -6 5 -pnnx.Input input_0 0 1 input -pnnx.Input input_1 0 1 weight -pnnx.Input input_2 0 1 bias -tnn.Convolution conv2d 3 1 input weight bias a -F.silu silu 1 1 a out -pnnx.Output output 1 0 out -)PNNXIR"; - } - } - - bool match(const std::map& captured_params) const - { - this->activation = captured_params.at("op_0.arg13").i; - return activation != 0; - } - - void write(const std::map& ops, const std::map& captured_params) const - { - for (int i = 0; i < 13; i++) - { - std::string argN = std::string("arg") + std::to_string(i); - ops.at("conv2d")->params[argN] = captured_params.at("op_0." 
+ argN); - } - - ops.at("conv2d")->params["arg13"] = 0; - } - -protected: - mutable int activation; -}; - -REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv2d_activation_tnn, 139) - -class F_conv2d_activation_tnn_1 : public F_conv2d_activation_tnn -{ -public: - const char* match_pattern_graph() const - { - return R"PNNXIR(7767517 -4 3 -pnnx.Input input_0 0 1 input -pnnx.Input input_1 0 1 weight -tnn.Convolution op_0 2 1 input weight out %*=%* -pnnx.Output output 1 0 out -)PNNXIR"; - } - - const char* replace_pattern_graph() const - { - if (this->activation == 1) - { - return R"PNNXIR(7767517 -5 4 -pnnx.Input input_0 0 1 input -pnnx.Input input_1 0 1 weight -tnn.Convolution conv2d 2 1 input weight a -F.relu relu 1 1 a out -pnnx.Output output 1 0 out -)PNNXIR"; - } - else if (this->activation == 2) - { - return R"PNNXIR(7767517 -5 4 -pnnx.Input input_0 0 1 input -pnnx.Input input_1 0 1 weight -tnn.Convolution conv2d 2 1 input weight a -F.relu6 relu6 1 1 a out -pnnx.Output output 1 0 out -)PNNXIR"; - } - else // if (this->activation == 256) - { - return R"PNNXIR(7767517 -5 4 -pnnx.Input input_0 0 1 input -pnnx.Input input_1 0 1 weight -tnn.Convolution conv2d 2 1 input weight a -F.silu silu 1 1 a out -pnnx.Output output 1 0 out -)PNNXIR"; - } - } -}; - -REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv2d_activation_tnn_1, 139) - class F_conv2d_tnn : public GraphRewriterPass { public: diff --git a/tools/pnnx/src/pass_tnn/lower_convolution_activation.cpp b/tools/pnnx/src/pass_tnn/lower_convolution_activation.cpp new file mode 100644 index 00000000000..05da8f8e4d6 --- /dev/null +++ b/tools/pnnx/src/pass_tnn/lower_convolution_activation.cpp @@ -0,0 +1,295 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
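+
+// Note on the convention handled in this file: TNN passes layer options as
+// positional integer arguments (arg0, arg1, ...), and a convolution carries
+// its fused activation in the last slot (arg13 for Convolution, arg9 for
+// Convolution1D), with 0 = linear, 1 = ReLU, 2 = ReLU6 and 256 = SiLU/swish
+// (other non-zero values currently fall through to the SiLU branch). Each
+// pass below re-emits the convolution with the flag cleared and appends the
+// matching aten:: activation op, so later pass_level2 rewriters only ever
+// see plain convolutions.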
+ +#include "lower_convolution_activation.h" + +#include "pass_level2.h" + +namespace pnnx { + +namespace tnn2pnnx { + +class lower_convolution_activation_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution op_0 3 1 input weight bias out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + if (this->activation == 1) + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution conv2d 3 1 input weight bias a +aten::relu relu 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + else if (this->activation == 2) + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution conv2d 3 1 input weight bias a +aten::relu6 relu6 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + else // if (this->activation == 256) + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution conv2d 3 1 input weight bias a +aten::silu silu 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + } + + bool match(const std::map& captured_params) const + { + this->activation = captured_params.at("op_0.arg13").i; + return activation != 0; + } + + void write(const std::map& ops, const std::map& captured_params) const + { + for (int i = 0; i < 13; i++) + { + std::string argN = std::string("arg") + std::to_string(i); + ops.at("conv2d")->params[argN] = captured_params.at("op_0." + argN); + } + + ops.at("conv2d")->params["arg13"] = 0; + } + +protected: + mutable int activation; +}; + +class lower_convolution_activation_pass_1 : public lower_convolution_activation_pass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution op_0 2 1 input weight out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + if (this->activation == 1) + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution conv2d 2 1 input weight a +aten::relu relu 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + else if (this->activation == 2) + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution conv2d 2 1 input weight a +aten::relu6 relu6 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + else // if (this->activation == 256) + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution conv2d 2 1 input weight a +aten::silu silu 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + } +}; + +class lower_convolution1d_activation_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution1D op_0 3 1 input weight bias out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + if (this->activation == 1) + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution1D conv1d 3 1 input weight 
bias a +aten::relu relu 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + else if (this->activation == 2) + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution1D conv1d 3 1 input weight bias a +aten::relu6 relu6 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + else // if (this->activation == 256) + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution1D conv1d 3 1 input weight bias a +aten::silu silu 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + } + + bool match(const std::map& captured_params) const + { + this->activation = captured_params.at("op_0.arg9").i; + return activation != 0; + } + + void write(const std::map& ops, const std::map& captured_params) const + { + for (int i = 0; i < 9; i++) + { + std::string argN = std::string("arg") + std::to_string(i); + ops.at("conv1d")->params[argN] = captured_params.at("op_0." + argN); + } + + ops.at("conv1d")->params["arg9"] = 0; + } + +protected: + mutable int activation; +}; + +class lower_convolution1d_activation_pass_1 : public lower_convolution1d_activation_pass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution1D op_0 2 1 input weight out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + if (this->activation == 1) + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution1D conv1d 2 1 input weight a +aten::relu relu 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + else if (this->activation == 2) + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution1D conv1d 2 1 input weight a +aten::relu6 relu6 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + else // if (this->activation == 256) + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution1D conv1d 2 1 input weight a +aten::silu silu 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + } +}; + +void lower_convolution_activation(Graph& graph) +{ + lower_convolution_activation_pass a; + lower_convolution_activation_pass_1 a1; + lower_convolution1d_activation_pass b; + lower_convolution1d_activation_pass_1 b1; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); + pnnx_graph_rewrite(graph, &a1, opindex); + pnnx_graph_rewrite(graph, &b, opindex); + pnnx_graph_rewrite(graph, &b1, opindex); +} + +} // namespace tnn2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_tnn/lower_convolution_activation.h b/tools/pnnx/src/pass_tnn/lower_convolution_activation.h new file mode 100644 index 00000000000..91d85f75500 --- /dev/null +++ b/tools/pnnx/src/pass_tnn/lower_convolution_activation.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. 
You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +namespace tnn2pnnx { + +void lower_convolution_activation(Graph& graph); + +} // namespace tnn2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_tnn/lower_power.cpp b/tools/pnnx/src/pass_tnn/lower_power.cpp new file mode 100644 index 00000000000..a35edaa5ffb --- /dev/null +++ b/tools/pnnx/src/pass_tnn/lower_power.cpp @@ -0,0 +1,62 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "lower_power.h" + +#include "pass_level2.h" + +namespace pnnx { + +namespace tnn2pnnx { + +class lower_power_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.Power op_0 1 1 input out arg0=%exponent arg1=%alpha arg2=%beta +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + return R"PNNXIR(7767517 +8 7 +pnnx.Input input 0 1 input +prim::Constant alpha 0 1 alpha value=%alpha +prim::Constant beta 0 1 beta value=%beta +prim::Constant exponent 0 1 exponent value=%exponent +aten::mul scale 2 1 input alpha a +aten::add shift 2 1 a beta b +aten::pow pow 2 1 b exponent out +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +void lower_power(Graph& graph) +{ + lower_power_pass a; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); +} + +} // namespace tnn2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_tnn/lower_power.h b/tools/pnnx/src/pass_tnn/lower_power.h new file mode 100644 index 00000000000..35bbedf6393 --- /dev/null +++ b/tools/pnnx/src/pass_tnn/lower_power.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
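+
+// tnn.Power carries three scalars (arg0 = exponent, arg1 = alpha,
+// arg2 = beta). lower_power() expands it into the equivalent of
+// pow(alpha * x + beta, exponent): an aten::mul -> aten::add -> aten::pow
+// chain fed by prim::Constant operands, so no dedicated torch operator is
+// needed for it downstream.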
+ +#include "ir.h" + +namespace pnnx { + +namespace tnn2pnnx { + +void lower_power(Graph& graph); + +} // namespace tnn2pnnx + +} // namespace pnnx From 4a2e8a54e9ac11ccf83daa47df4869fa01571158 Mon Sep 17 00:00:00 2001 From: nihuini Date: Mon, 10 Feb 2025 15:09:20 +0800 Subject: [PATCH 12/29] w --- tools/pnnx/src/pass_level2/F_conv1d.cpp | 4 +++ tools/pnnx/src/pass_level2/F_conv2d.cpp | 4 +++ tools/pnnx/src/pass_level2/F_pad.cpp | 17 ++++++++-- tools/pnnx/src/pass_level2/torch_matmul.cpp | 35 +++++++++++++++++++++ 4 files changed, 58 insertions(+), 2 deletions(-) diff --git a/tools/pnnx/src/pass_level2/F_conv1d.cpp b/tools/pnnx/src/pass_level2/F_conv1d.cpp index 6f8eabae2f6..78df32b5d60 100644 --- a/tools/pnnx/src/pass_level2/F_conv1d.cpp +++ b/tools/pnnx/src/pass_level2/F_conv1d.cpp @@ -232,6 +232,10 @@ pnnx.Output output 1 0 out // captured_params.at("op_0.arg6"); // bias // captured_params.at("op_0.arg7"); // pad_type op->params["dilation"] = {captured_params.at("op_0.arg8").i}; + if (op->params["dilation"].ai == std::vector{-1}) + { + op->params["dilation"] = {1}; + } // captured_params.at("op_0.arg9"); // activation } }; diff --git a/tools/pnnx/src/pass_level2/F_conv2d.cpp b/tools/pnnx/src/pass_level2/F_conv2d.cpp index 368096520ff..f650b55f414 100644 --- a/tools/pnnx/src/pass_level2/F_conv2d.cpp +++ b/tools/pnnx/src/pass_level2/F_conv2d.cpp @@ -343,6 +343,10 @@ pnnx.Output output 1 0 out // captured_params.at("op_0.arg9"); // bias // captured_params.at("op_0.arg10"); // pad_type op->params["dilation"] = {captured_params.at("op_0.arg11").i, captured_params.at("op_0.arg12").i}; + if (op->params["dilation"].ai == std::vector{-1, -1}) + { + op->params["dilation"] = {1, 1}; + } // captured_params.at("op_0.arg13"); // activation } }; diff --git a/tools/pnnx/src/pass_level2/F_pad.cpp b/tools/pnnx/src/pass_level2/F_pad.cpp index fd37af41d41..5511a7d67c6 100644 --- a/tools/pnnx/src/pass_level2/F_pad.cpp +++ b/tools/pnnx/src/pass_level2/F_pad.cpp @@ -503,10 +503,23 @@ pnnx.Output output 1 0 out const int ndim = captured_params.at("op_0.arg0").i; std::vector pads(ndim * 2); - for (int i = 0; i < ndim * 2; i++) + for (int i = 0; i < ndim; i++) { - pads[i] = captured_params.at("op_0.arg" + std::to_string(i + 1)).i; + pads[(ndim - 1 - i) * 2] = captured_params.at("op_0.arg" + std::to_string(i + 1)).i; } + for (int i = 0; i < ndim; i++) + { + pads[(ndim - 1 - i) * 2 + 1] = captured_params.at("op_0.arg" + std::to_string(ndim + i + 1)).i; + } + + // strip zero pads for higher dims + // (3,3,0,0,0,0) to (3,3) + for (int i = ndim - 1; i >= 0; i--) + { + if (pads[i * 2] == 0 && pads[i * 2 + 1] == 0) + pads.resize(i * 2); + } + op->params["pad"] = pads; const int type = captured_params.at("op_0.arg" + std::to_string(ndim * 2 + 1)).i; diff --git a/tools/pnnx/src/pass_level2/torch_matmul.cpp b/tools/pnnx/src/pass_level2/torch_matmul.cpp index a12f2c84b07..c1388ad7d4d 100644 --- a/tools/pnnx/src/pass_level2/torch_matmul.cpp +++ b/tools/pnnx/src/pass_level2/torch_matmul.cpp @@ -60,4 +60,39 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_matmul_onnx, 90) +class torch_matmul_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 other +tnn.MatMul op_0 2 1 input other out arg0=%weight_position +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.matmul"; + } + + void write(Operator* op, const std::map& captured_params) 
const + { + const int weight_position = captured_params.at("weight_position").i; + if (weight_position == 0) + { + // swap input and weight + std::swap(op->inputs[0], op->inputs[1]); + } + } + +protected: + mutable int weight_position; +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_matmul_tnn, 90) + } // namespace pnnx From ed0bdd01caae2cd22d3f3c8bc7c270ce4a6819b2 Mon Sep 17 00:00:00 2001 From: nihuini Date: Mon, 10 Feb 2025 15:43:53 +0800 Subject: [PATCH 13/29] w --- tools/pnnx/src/ir.cpp | 6 +-- tools/pnnx/src/pass_level2/F_batch_norm.cpp | 47 +++++++++++++++++++++ 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 5fcf9916f4b..09a30f66979 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -2253,11 +2253,11 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) { if (op->type == "Tensor.index_put" && it.first == "values") { - fprintf(pyfp, "torch.tensor(%f)", param.f); + fprintf(pyfp, "torch.tensor(%g)", param.f); } else { - fprintf(pyfp, "%f", param.f); + fprintf(pyfp, "%g", param.f); } } if (param.type == 4) @@ -2316,7 +2316,7 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) fprintf(pyfp, "("); for (size_t i = 0; i < param.af.size(); i++) { - fprintf(pyfp, "%f", param.af[i]); + fprintf(pyfp, "%g", param.af[i]); if (i + 1 != param.af.size() || param.af.size() == 1) fprintf(pyfp, ","); } diff --git a/tools/pnnx/src/pass_level2/F_batch_norm.cpp b/tools/pnnx/src/pass_level2/F_batch_norm.cpp index f4ead03505b..4a82809ff23 100644 --- a/tools/pnnx/src/pass_level2/F_batch_norm.cpp +++ b/tools/pnnx/src/pass_level2/F_batch_norm.cpp @@ -121,4 +121,51 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_batch_norm_onnx, 130) +class F_batch_norm_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +pnnx.Attribute op_0 0 1 weight @data=(%num_features)f32 +pnnx.Attribute op_1 0 1 bias @data=(%num_features)f32 +tnn.BatchNormCxx op_2 3 1 input weight bias out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input 0 1 input +pnnx.Attribute mean 0 1 running_mean +pnnx.Attribute var 0 1 running_var +pnnx.Attribute weight 0 1 weight @data=%op_0.data +pnnx.Attribute bias 0 1 bias @data=%op_1.data +F.batch_norm bn 5 1 input running_mean running_var weight bias out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + void write(const std::map& ops, const std::map& captured_params) const + { + const int num_features = captured_params.at("num_features").i; + + Operator* op_mean = ops.at("mean"); + op_mean->attrs["data"] = Attribute({num_features}, std::vector(num_features, 0.f)); + + Operator* op_var = ops.at("var"); + op_var->attrs["data"] = Attribute({num_features}, std::vector(num_features, 1.f)); + + Operator* op_bn = ops.at("bn"); + op_bn->params["eps"] = 0.f; + op_bn->inputnames = {"input", "running_mean", "running_var", "weight", "bias"}; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_batch_norm_tnn, 130) + } // namespace pnnx From f15f050c84b51b068d7ec951dd808619755321a4 Mon Sep 17 00:00:00 2001 From: nihuini Date: Mon, 10 Feb 2025 16:06:29 +0800 Subject: [PATCH 14/29] w --- tools/pnnx/src/pass_level2/torch_max.cpp | 22 ++++++++++-- tools/pnnx/src/pass_level2/torch_mean.cpp | 15 ++++++-- tools/pnnx/src/pass_level2/torch_min.cpp | 43 +++++++++++++++++++++++ 3 files 
changed, 74 insertions(+), 6 deletions(-) diff --git a/tools/pnnx/src/pass_level2/torch_max.cpp b/tools/pnnx/src/pass_level2/torch_max.cpp index 280b7b371dc..deac519205d 100644 --- a/tools/pnnx/src/pass_level2/torch_max.cpp +++ b/tools/pnnx/src/pass_level2/torch_max.cpp @@ -190,7 +190,7 @@ class torch_max_tnn : public GraphRewriterPass return R"PNNXIR(7767517 3 2 pnnx.Input input 0 1 input -tnn.ReduceMax op_0 1 1 input out arg0=%keepdims arg1=%dim +tnn.ReduceMax op_0 1 1 input out %*=%* pnnx.Output output 1 0 out )PNNXIR"; } @@ -202,8 +202,24 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& captured_params) const { - op->params["dim"] = captured_params.at("dim"); - op->params["keepdim"] = captured_params.at("keepdims").i ? true : false; + std::vector dim; + for (int i = 1; ; i++) + { + if (captured_params.find("op_0.arg" + std::to_string(i)) == captured_params.end()) + break; + + dim.push_back(captured_params.at("op_0.arg" + std::to_string(i)).i); + } + + if (dim.size() == 1) + { + op->params["dim"] = dim[0]; + } + else + { + fprintf(stderr, "fallback to reduce max all\n"); + } + op->params["keepdim"] = captured_params.at("op_0.arg0").i ? true : false; } }; diff --git a/tools/pnnx/src/pass_level2/torch_mean.cpp b/tools/pnnx/src/pass_level2/torch_mean.cpp index 38afb2a760d..e0aaf7e21cf 100644 --- a/tools/pnnx/src/pass_level2/torch_mean.cpp +++ b/tools/pnnx/src/pass_level2/torch_mean.cpp @@ -156,7 +156,7 @@ class torch_mean_tnn : public GraphRewriterPass return R"PNNXIR(7767517 3 2 pnnx.Input input 0 1 input -tnn.ReduceMean op_0 1 1 input out arg0=%keepdims arg1=%dim +tnn.ReduceMean op_0 1 1 input out %*=%* pnnx.Output output 1 0 out )PNNXIR"; } @@ -168,8 +168,17 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& captured_params) const { - op->params["dim"] = captured_params.at("dim"); - op->params["keepdim"] = captured_params.at("keepdims").i ? true : false; + std::vector dim; + for (int i = 1; ; i++) + { + if (captured_params.find("op_0.arg" + std::to_string(i)) == captured_params.end()) + break; + + dim.push_back(captured_params.at("op_0.arg" + std::to_string(i)).i); + } + + op->params["dim"] = dim; + op->params["keepdim"] = captured_params.at("op_0.arg0").i ? true : false; } }; diff --git a/tools/pnnx/src/pass_level2/torch_min.cpp b/tools/pnnx/src/pass_level2/torch_min.cpp index 07248927d2e..ad8dbc6ac28 100644 --- a/tools/pnnx/src/pass_level2/torch_min.cpp +++ b/tools/pnnx/src/pass_level2/torch_min.cpp @@ -182,4 +182,47 @@ pnnx.Output output 2 0 out indices REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_min_onnx_1, 50) +class torch_min_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.ReduceMin op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.min"; + } + + void write(Operator* op, const std::map& captured_params) const + { + std::vector dim; + for (int i = 1; ; i++) + { + if (captured_params.find("op_0.arg" + std::to_string(i)) == captured_params.end()) + break; + + dim.push_back(captured_params.at("op_0.arg" + std::to_string(i)).i); + } + + if (dim.size() == 1) + { + op->params["dim"] = dim[0]; + } + else + { + fprintf(stderr, "fallback to reduce min all\n"); + } + op->params["keepdim"] = captured_params.at("op_0.arg0").i ? 
true : false; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_min_tnn, 50) + } // namespace pnnx From 549e83f7ec27a8f1a213881d216145bbe18d04e2 Mon Sep 17 00:00:00 2001 From: nihui <171016+nihui@users.noreply.github.com> Date: Mon, 10 Feb 2025 08:08:22 +0000 Subject: [PATCH 15/29] apply code-format changes --- tools/pnnx/src/pass_level2/torch_max.cpp | 2 +- tools/pnnx/src/pass_level2/torch_mean.cpp | 2 +- tools/pnnx/src/pass_level2/torch_min.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tools/pnnx/src/pass_level2/torch_max.cpp b/tools/pnnx/src/pass_level2/torch_max.cpp index deac519205d..78dd72601b9 100644 --- a/tools/pnnx/src/pass_level2/torch_max.cpp +++ b/tools/pnnx/src/pass_level2/torch_max.cpp @@ -203,7 +203,7 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& captured_params) const { std::vector dim; - for (int i = 1; ; i++) + for (int i = 1;; i++) { if (captured_params.find("op_0.arg" + std::to_string(i)) == captured_params.end()) break; diff --git a/tools/pnnx/src/pass_level2/torch_mean.cpp b/tools/pnnx/src/pass_level2/torch_mean.cpp index e0aaf7e21cf..45fc54023b8 100644 --- a/tools/pnnx/src/pass_level2/torch_mean.cpp +++ b/tools/pnnx/src/pass_level2/torch_mean.cpp @@ -169,7 +169,7 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& captured_params) const { std::vector dim; - for (int i = 1; ; i++) + for (int i = 1;; i++) { if (captured_params.find("op_0.arg" + std::to_string(i)) == captured_params.end()) break; diff --git a/tools/pnnx/src/pass_level2/torch_min.cpp b/tools/pnnx/src/pass_level2/torch_min.cpp index ad8dbc6ac28..201604aff26 100644 --- a/tools/pnnx/src/pass_level2/torch_min.cpp +++ b/tools/pnnx/src/pass_level2/torch_min.cpp @@ -203,7 +203,7 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& captured_params) const { std::vector dim; - for (int i = 1; ; i++) + for (int i = 1;; i++) { if (captured_params.find("op_0.arg" + std::to_string(i)) == captured_params.end()) break; From 4564fb36211b6830a7fe1aabf2a68d88acd8571a Mon Sep 17 00:00:00 2001 From: nihuini Date: Mon, 10 Feb 2025 16:47:15 +0800 Subject: [PATCH 16/29] w --- tools/pnnx/src/load_tnn.cpp | 84 ++++++++++++++++--- tools/pnnx/src/pass_level2/F_softmax.cpp | 21 +++++ .../src/pass_level2/torch_index_select.cpp | 53 ++++++++++++ 3 files changed, 147 insertions(+), 11 deletions(-) diff --git a/tools/pnnx/src/load_tnn.cpp b/tools/pnnx/src/load_tnn.cpp index 7b6d699f2eb..aaf2950d733 100644 --- a/tools/pnnx/src/load_tnn.cpp +++ b/tools/pnnx/src/load_tnn.cpp @@ -139,7 +139,11 @@ static size_t type_to_elemsize(int type) static int get_tnn_tensor_type(int dt) { - if (dt == 0) return 1; + if (dt == 0) return 1;// fp32 + if (dt == 1) return 3;// fp16 + if (dt == 2) return 7;// int8 + if (dt == 3) return 4;// int32 + if (dt == 4) return 13;// bf16 fprintf(stderr, "unsupported tnn tensor type %d\n", dt); return 0; // unknown type @@ -203,12 +207,28 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) // "1 57 1 4206624772 ," fgets(line, 4096, pp); - int blob_count = 57; + int blob_count = 0; unsigned int proto_magic = 4206624772; + { + // strip leading and tail double quote + line[strlen(line) - 2] = '\0'; + line[0] = '\0'; + const char* pline = line + 1; + + sscanf(pline, "%*d %d %*d %u", &blob_count, &proto_magic); + if (proto_magic != 4206624772) + { + fprintf(stderr, "wrong magic %u\n", proto_magic); + } + + if (blob_count == 0) + { + fprintf(stderr, "wrong blob_count %d\n", blob_count); + } + } // "input 2 1 80000 0 ," 
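+    // Each tnnproto record is one double-quoted, comma-terminated line,
+    // hence the recurring idiom in this parser: line[strlen(line) - 2] = '\0'
+    // overwrites the closing quote (the byte just before '\n'), and parsing
+    // then starts at line + 1 to skip the opening quote. The input record
+    // above appears to encode name, rank, dims and datatype, i.e.
+    // "input 2 1 80000 0" would be a 2-D fp32 input of shape (1, 80000).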
fgets(line, 4096, pp); - if (proto_magic == 4206624772) { // strip leading and tail double quote line[strlen(line) - 2] = '\0'; @@ -309,7 +329,20 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) // layer count // " 56 ," fgets(line, 4096, pp); - int layer_count = 56; + int layer_count = 0; + { + // strip leading and tail double quote + line[strlen(line) - 2] = '\0'; + line[0] = '\0'; + const char* pline = line + 1; + + sscanf(pline, "%d", &layer_count); + + if (layer_count == 0) + { + fprintf(stderr, "wrong layer_count %d\n", layer_count); + } + } for (int i = 0; i < layer_count; i++) { @@ -480,7 +513,7 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) Operator* op = op_map.at(name); - int attribute_count = 0; + std::vector attrs; if (type == "Convolution1D" || type == "Convolution") { @@ -495,7 +528,11 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) int bias; fread(&bias, 1, sizeof(int), bp); - attribute_count = bias ? 2 : 1; + attrs.push_back(Attribute(bp)); + if (bias) + { + attrs.push_back(Attribute(bp)); + } } if (type == "InnerProduct") { @@ -506,26 +543,50 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) name2.resize(name2_size); fread((void*)name2.data(), 1, name2_size, bp); - attribute_count = 2; + attrs.push_back(Attribute(bp)); + attrs.push_back(Attribute(bp)); } if (type == "MatMul") { - attribute_count = 1; + attrs.push_back(Attribute(bp)); } if (type == "Add" || type == "Sub" || type == "Mul" || type == "Div") { - attribute_count = 1; + attrs.push_back(Attribute(bp)); } if (type == "BatchNormCxx") { - attribute_count = 2; + attrs.push_back(Attribute(bp)); + attrs.push_back(Attribute(bp)); } + if (type == "Gather") + { + // data_in_resource + int data_in_resource; + fread(&data_in_resource, 1, sizeof(int), bp); + + if (data_in_resource) + { + attrs.push_back(Attribute(bp)); + } + + // indices_in_resource + int indices_in_resource; + fread(&indices_in_resource, 1, sizeof(int), bp); + + if (indices_in_resource) + { + attrs.push_back(Attribute(bp)); + } + } + + const int attribute_count = (int)attrs.size(); for (int j = 0; j < attribute_count; j++) { Operator* op_scale = pnnx_graph.new_operator_before("pnnx.Attribute", name + "_attr" + std::to_string(j), op); Operand* r0 = pnnx_graph.new_operand(name + "_attr" + std::to_string(j)); - op_scale->attrs["data"] = Attribute(bp); + op_scale->attrs["data"] = attrs[j]; op_scale->outputs.push_back(r0); r0->producer = op_scale; r0->consumers.push_back(op); @@ -541,6 +602,7 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) // unary if (op->type == "tnn.Log") op->type = "aten::log"; if (op->type == "tnn.ReLU") op->type = "aten::relu"; + if (op->type == "tnn.ReLU6") op->type = "aten::relu6"; if (op->type == "tnn.Sigmoid") op->type = "aten::sigmoid"; // binary diff --git a/tools/pnnx/src/pass_level2/F_softmax.cpp b/tools/pnnx/src/pass_level2/F_softmax.cpp index 5497b9e176f..d61213bf3bb 100644 --- a/tools/pnnx/src/pass_level2/F_softmax.cpp +++ b/tools/pnnx/src/pass_level2/F_softmax.cpp @@ -133,4 +133,25 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softmax_onnx_1, 100) +class F_softmax_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.SoftmaxCaffe op_0 1 1 input out arg0=%dim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.softmax"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softmax_tnn, 100) + 
} // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_index_select.cpp b/tools/pnnx/src/pass_level2/torch_index_select.cpp index 084554af8a4..d1c23d43f26 100644 --- a/tools/pnnx/src/pass_level2/torch_index_select.cpp +++ b/tools/pnnx/src/pass_level2/torch_index_select.cpp @@ -13,6 +13,7 @@ // specific language governing permissions and limitations under the License. #include "pass_level2.h" +#include namespace pnnx { @@ -39,4 +40,56 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_index_select, 70) +class torch_index_select_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +pnnx.Attribute op_0 0 1 index @data=(?)i32 +tnn.Gather op_1 2 1 input index out arg0=%dim arg1=0 arg2=1 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +pnnx.Attribute index 0 1 index +torch.index_select select 2 1 input index out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const + { + const Attribute& index_data = captured_attrs.at("op_0.data"); + + // i32 to i64 + Operator* op_index = ops.at("index"); + const int* p = (const int*)index_data.data.data(); + const int n = index_data.data.size() / 4; + std::vector indices(n); + for (int i = 0; i < n; i++) + { + indices[i] = p[i]; + } + op_index->attrs["data"].type = 5;// i64 + op_index->attrs["data"].shape = {n}; + op_index->attrs["data"].data.resize(n * 8); + memcpy((void*)op_index->attrs["data"].data.data(), (const void*)indices.data(), n * 8); + + + Operator* op_gather = ops.at("select"); + op_gather->params["dim"] = captured_params.at("dim"); + op_gather->inputnames = {"input", "index"}; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_index_select_tnn, 70) + } // namespace pnnx From 28302af6500ace8e042acb7d9443b85843733165 Mon Sep 17 00:00:00 2001 From: nihuini Date: Mon, 10 Feb 2025 16:54:33 +0800 Subject: [PATCH 17/29] w --- tools/pnnx/src/pass_level2/torch_index_select.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tools/pnnx/src/pass_level2/torch_index_select.cpp b/tools/pnnx/src/pass_level2/torch_index_select.cpp index d1c23d43f26..f9fe8eec67d 100644 --- a/tools/pnnx/src/pass_level2/torch_index_select.cpp +++ b/tools/pnnx/src/pass_level2/torch_index_select.cpp @@ -45,6 +45,9 @@ class torch_index_select_tnn : public GraphRewriterPass public: const char* match_pattern_graph() const { + // clang-format off + // *INDENT-OFF* + return R"PNNXIR(7767517 4 3 pnnx.Input input 0 1 input @@ -52,6 +55,9 @@ pnnx.Attribute op_0 0 1 index @data=(?)i32 tnn.Gather op_1 2 1 input index out arg0=%dim arg1=0 arg2=1 pnnx.Output output 1 0 out )PNNXIR"; + + // *INDENT-ON* + // clang-format on } const char* replace_pattern_graph() const From 72b78d5a4c3ab015a0f4d637fdd03af868a95933 Mon Sep 17 00:00:00 2001 From: nihui <171016+nihui@users.noreply.github.com> Date: Mon, 10 Feb 2025 08:56:24 +0000 Subject: [PATCH 18/29] apply code-format changes --- tools/pnnx/src/load_tnn.cpp | 10 +++++----- tools/pnnx/src/pass_level2/torch_index_select.cpp | 3 +-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/tools/pnnx/src/load_tnn.cpp b/tools/pnnx/src/load_tnn.cpp index aaf2950d733..634abc1c87c 100644 --- a/tools/pnnx/src/load_tnn.cpp +++ b/tools/pnnx/src/load_tnn.cpp @@ -139,11 +139,11 @@ static size_t type_to_elemsize(int 
type) static int get_tnn_tensor_type(int dt) { - if (dt == 0) return 1;// fp32 - if (dt == 1) return 3;// fp16 - if (dt == 2) return 7;// int8 - if (dt == 3) return 4;// int32 - if (dt == 4) return 13;// bf16 + if (dt == 0) return 1; // fp32 + if (dt == 1) return 3; // fp16 + if (dt == 2) return 7; // int8 + if (dt == 3) return 4; // int32 + if (dt == 4) return 13; // bf16 fprintf(stderr, "unsupported tnn tensor type %d\n", dt); return 0; // unknown type diff --git a/tools/pnnx/src/pass_level2/torch_index_select.cpp b/tools/pnnx/src/pass_level2/torch_index_select.cpp index f9fe8eec67d..2bfc10558d4 100644 --- a/tools/pnnx/src/pass_level2/torch_index_select.cpp +++ b/tools/pnnx/src/pass_level2/torch_index_select.cpp @@ -84,12 +84,11 @@ pnnx.Output output 1 0 out { indices[i] = p[i]; } - op_index->attrs["data"].type = 5;// i64 + op_index->attrs["data"].type = 5; // i64 op_index->attrs["data"].shape = {n}; op_index->attrs["data"].data.resize(n * 8); memcpy((void*)op_index->attrs["data"].data.data(), (const void*)indices.data(), n * 8); - Operator* op_gather = ops.at("select"); op_gather->params["dim"] = captured_params.at("dim"); op_gather->inputnames = {"input", "index"}; From 8dd61993103c32ef0a4decb5064621d1e1b80ad1 Mon Sep 17 00:00:00 2001 From: nihuini Date: Mon, 10 Feb 2025 17:50:07 +0800 Subject: [PATCH 19/29] w --- tools/pnnx/src/load_tnn.cpp | 143 ++++++++++++++++++++---------------- 1 file changed, 78 insertions(+), 65 deletions(-) diff --git a/tools/pnnx/src/load_tnn.cpp b/tools/pnnx/src/load_tnn.cpp index aaf2950d733..8b2fd6fd205 100644 --- a/tools/pnnx/src/load_tnn.cpp +++ b/tools/pnnx/src/load_tnn.cpp @@ -207,7 +207,6 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) // "1 57 1 4206624772 ," fgets(line, 4096, pp); - int blob_count = 0; unsigned int proto_magic = 4206624772; { // strip leading and tail double quote @@ -215,16 +214,11 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) line[0] = '\0'; const char* pline = line + 1; - sscanf(pline, "%*d %d %*d %u", &blob_count, &proto_magic); + sscanf(pline, "%*d %*d %*d %u", &proto_magic); if (proto_magic != 4206624772) { fprintf(stderr, "wrong magic %u\n", proto_magic); } - - if (blob_count == 0) - { - fprintf(stderr, "wrong blob_count %d\n", blob_count); - } } // "input 2 1 80000 0 ," @@ -270,37 +264,14 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) op->outputs.push_back(r); } - // all operand names - // " 108 109 110 111 112 113 114 116 118 119 120 125 126 128 130 131 132 133 135 136 138 139 142 144 145 147 148 151 153 154 156 157 160 162 163 165 166 169 171 172 174 175 178 180 181 183 184 188 189 190 191 192 194 85 clipwise_output embedding input ," + // skip the very long operand names + // " 108 109 ........ 
clipwise_output embedding input ," + fscanf(pp, "%*[^,]"); fgets(line, 4096, pp); - { - // strip leading and tail double quote - line[strlen(line) - 2] = '\0'; - line[0] = '\0'; - const char* pline = line + 1; - - int ncomsumed = 0; - - for (int i = 0; i < blob_count; i++) - { - char blob_name[32]; - sscanf(pline, "%s%n", blob_name, &ncomsumed); - - pline += ncomsumed; - - // fprintf(stderr, "blob %s\n", blob_name); - - if (!pnnx_graph.get_operand(blob_name)) - { - pnnx_graph.new_operand(blob_name); - } - } - } // all output names // "clipwise_output embedding ," fgets(line, 4096, pp); - std::vector output_names; { // strip leading and tail double quote @@ -320,7 +291,7 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) if (strcmp(blob_name, ",") == 0) break; - // fprintf(stderr, "blob %s\n", blob_name); + fprintf(stderr, "blob %s\n", blob_name); output_names.push_back(blob_name); } @@ -380,7 +351,15 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) Operand* r = pnnx_graph.get_operand(blob_name); if (!r) { - fprintf(stderr, "%s bottom %s not found\n", layer_name, blob_name); + // insert constant producer + Operator* op_constant = pnnx_graph.new_operator_before("pnnx.Attribute", blob_name, op); + + r = pnnx_graph.new_operand(blob_name); + + // op_constant->attrs["data"] = attrs[j]; + op_constant->outputs.push_back(r); + + r->producer = op_constant; } r->consumers.push_back(op); op->inputs.push_back(r); @@ -398,7 +377,7 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) Operand* r = pnnx_graph.get_operand(blob_name); if (!r) { - fprintf(stderr, "%s top %s not found\n", layer_name, blob_name); + r = pnnx_graph.new_operand(blob_name); } r->producer = op; op->outputs.push_back(r); @@ -515,6 +494,19 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) std::vector attrs; + if (type == "Add" || type == "Sub" || type == "Mul" || type == "Div") + { + attrs.push_back(Attribute(bp)); + } + if (type == "BatchNormCxx") + { + attrs.push_back(Attribute(bp)); + attrs.push_back(Attribute(bp)); + } + if (type == "ConstantOfShape") + { + attrs.push_back(Attribute(bp)); + } if (type == "Convolution1D" || type == "Convolution") { // skip name2 == name @@ -534,31 +526,6 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) attrs.push_back(Attribute(bp)); } } - if (type == "InnerProduct") - { - // skip name2 == name - int name2_size; - std::string name2; - fread(&name2_size, 1, sizeof(int), bp); - name2.resize(name2_size); - fread((void*)name2.data(), 1, name2_size, bp); - - attrs.push_back(Attribute(bp)); - attrs.push_back(Attribute(bp)); - } - if (type == "MatMul") - { - attrs.push_back(Attribute(bp)); - } - if (type == "Add" || type == "Sub" || type == "Mul" || type == "Div") - { - attrs.push_back(Attribute(bp)); - } - if (type == "BatchNormCxx") - { - attrs.push_back(Attribute(bp)); - attrs.push_back(Attribute(bp)); - } if (type == "Gather") { // data_in_resource @@ -579,21 +546,67 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) attrs.push_back(Attribute(bp)); } } + if (type == "InnerProduct") + { + // skip name2 == name + int name2_size; + std::string name2; + fread(&name2_size, 1, sizeof(int), bp); + name2.resize(name2_size); + fread((void*)name2.data(), 1, name2_size, bp); + + attrs.push_back(Attribute(bp)); + attrs.push_back(Attribute(bp)); + } + if (type == "MatMul") + { + attrs.push_back(Attribute(bp)); + } const int attribute_count = (int)attrs.size(); for (int j = 0; j < attribute_count; j++) { - Operator* op_scale = 
pnnx_graph.new_operator_before("pnnx.Attribute", name + "_attr" + std::to_string(j), op); + Operator* op_constant = pnnx_graph.new_operator_before("pnnx.Attribute", name + "_attr" + std::to_string(j), op); Operand* r0 = pnnx_graph.new_operand(name + "_attr" + std::to_string(j)); - op_scale->attrs["data"] = attrs[j]; - op_scale->outputs.push_back(r0); - r0->producer = op_scale; + op_constant->attrs["data"] = attrs[j]; + op_constant->outputs.push_back(r0); + r0->producer = op_constant; r0->consumers.push_back(op); op->inputs.push_back(r0); } } + // magic 0xfabc0004 + // unsigned int model_magic; + fread(&model_magic, 1, sizeof(unsigned int), bp); + if (model_magic != 0xfabc0004) + { + fprintf(stderr, "model_magic %x failed\n", model_magic); + return -1; + } + + int constant_count = 0; + fread(&constant_count, 1, sizeof(int), bp); + + fprintf(stderr, "constant_count = %d\n", constant_count); + + // collect constants + for (int i = 0; i < constant_count; i++) + { + int name_size; + std::string name; + fread(&name_size, 1, sizeof(int), bp); + name.resize(name_size); + fread((void*)name.data(), 1, name_size, bp); + + fprintf(stderr, "model constant %s\n", name.c_str()); + + Operator* op_constant = op_map.at(name); + + op_constant->attrs["data"] = Attribute(bp); + } + fclose(bp); // replace simple operator From 51e99e8c6a5a83c875a8c4d7e035bc8ce4764207 Mon Sep 17 00:00:00 2001 From: nihuini Date: Tue, 11 Feb 2025 11:08:21 +0800 Subject: [PATCH 20/29] w --- tools/pnnx/src/CMakeLists.txt | 2 + tools/pnnx/src/load_tnn.cpp | 5 + tools/pnnx/src/pass_level2/Tensor_reshape.cpp | 66 +++++++++ .../pass_tnn/fuse_shape_list_construct.cpp | 126 ++++++++++++++++++ .../src/pass_tnn/fuse_shape_list_construct.h | 25 ++++ tools/pnnx/src/pass_tnn/fuse_shape_size.cpp | 71 ++++++++++ tools/pnnx/src/pass_tnn/fuse_shape_size.h | 25 ++++ 7 files changed, 320 insertions(+) create mode 100644 tools/pnnx/src/pass_tnn/fuse_shape_list_construct.cpp create mode 100644 tools/pnnx/src/pass_tnn/fuse_shape_list_construct.h create mode 100644 tools/pnnx/src/pass_tnn/fuse_shape_size.cpp create mode 100644 tools/pnnx/src/pass_tnn/fuse_shape_size.h diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 3cb849ff213..5b7c8c7bc0b 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -717,6 +717,8 @@ endif() if(PNNX_TNN2PNNX) set(pnnx_pass_tnn_SRCS + pass_tnn/fuse_shape_size.cpp + pass_tnn/fuse_shape_list_construct.cpp pass_tnn/lower_convolution_activation.cpp pass_tnn/lower_power.cpp ) diff --git a/tools/pnnx/src/load_tnn.cpp b/tools/pnnx/src/load_tnn.cpp index 70b95a1d7f9..786c523cc9d 100644 --- a/tools/pnnx/src/load_tnn.cpp +++ b/tools/pnnx/src/load_tnn.cpp @@ -20,6 +20,8 @@ #include #include +#include "pass_tnn/fuse_shape_size.h" +#include "pass_tnn/fuse_shape_list_construct.h" #include "pass_tnn/lower_convolution_activation.h" #include "pass_tnn/lower_power.h" @@ -625,6 +627,9 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) if (op->type == "tnn.Div") op->type = "aten::div"; } + tnn2pnnx::fuse_shape_size(pnnx_graph); + tnn2pnnx::fuse_shape_list_construct(pnnx_graph); + tnn2pnnx::lower_convolution_activation(pnnx_graph); tnn2pnnx::lower_power(pnnx_graph); diff --git a/tools/pnnx/src/pass_level2/Tensor_reshape.cpp b/tools/pnnx/src/pass_level2/Tensor_reshape.cpp index 4261ed0a467..c197173968e 100644 --- a/tools/pnnx/src/pass_level2/Tensor_reshape.cpp +++ b/tools/pnnx/src/pass_level2/Tensor_reshape.cpp @@ -123,4 +123,70 @@ pnnx.Output output 1 0 out 
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_reshape_onnx_2, 61)
 
+class Tensor_reshape_tnn : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+3 2
+pnnx.Input input 0 1 input
+tnn.Reshape op_0 1 1 input out %*=%*
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "Tensor.reshape";
+    }
+
+    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
+    {
+        const int axis = captured_params.at("op_0.arg0").i;
+        const int num_axes = captured_params.at("op_0.arg1").i;
+        const int shape_rank = captured_params.at("op_0.arg2").i;
+
+        std::vector<int> shape(shape_rank);
+        for (int i = 0; i < shape_rank; i++)
+        {
+            shape[i] = captured_params.at("op_0.arg" + std::to_string(i + 3)).i;
+        }
+
+        const int reshape_type = captured_params.at("op_0.arg" + std::to_string(shape_rank + 3)).i;
+
+        // HACK
+        if (shape == std::vector<int>{0, -1, 0, 0})
+        {
+            shape = {-1};
+        }
+
+        op->params["shape"] = shape;
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_reshape_tnn, 60)
+
+class Tensor_reshape_tnn_1 : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+4 3
+pnnx.Input input 0 1 input
+pnnx.Input shape 0 1 shape
+tnn.Reshape op_0 2 1 input shape out
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "Tensor.reshape";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_reshape_tnn_1, 60)
+
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_tnn/fuse_shape_list_construct.cpp b/tools/pnnx/src/pass_tnn/fuse_shape_list_construct.cpp
new file mode 100644
index 00000000000..c8db926033f
--- /dev/null
+++ b/tools/pnnx/src/pass_tnn/fuse_shape_list_construct.cpp
@@ -0,0 +1,126 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "fuse_shape_list_construct.h"
+
+#include <algorithm>
+
+namespace pnnx {
+
+namespace tnn2pnnx {
+
+void fuse_shape_list_construct(Graph& graph)
+{
+    // TODO unpool tnn.Unsqueeze
+
+    // a0 = pnnx.Attribute @data=(1)i32
+    // a1 = tnn.Unsqueeze(..., arg0=1, arg1=0)
+    // y = tnn.Concat(a0, a1, ..., arg0=0)
+    // tnn.Reshape(x, y, args=...) / tnn.ConstantOfShape(y)
+
+    // prim::ListConstruct (a0, a1, ...)
+ // tnn.Reshape(x, y) / tnn.ConstantOfShape(y) + + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "tnn.Concat") + continue; + + if (op->outputs[0]->consumers.size() != 1) + continue; + + Operator* op2 = op->outputs[0]->consumers[0]; + if (op2->type == "tnn.Reshape") + { + if (op2->inputs.size() != 2) + continue; + + if (op2->inputs[1] != op->outputs[0]) + continue; + } + else if (op2->type == "tnn.ConstantOfShape") + { + if (op2->inputs[0] != op->outputs[0]) + continue; + } + else + { + continue; + } + + matched = true; + + fprintf(stderr, "match concat + reshape/constantofshape\n"); + + op->type = "prim::ListConstruct"; + + // drop tnn.Unsqueeze between aten::size and prim::ListConstruct + + const size_t count = op->inputs.size(); + for (size_t j = 0; j < count; j++) + { + Operand* r = op->inputs[j]; + + if (r->producer->type != "tnn.Unsqueeze") + continue; + + Operator* op_uqz = r->producer; + + Operand* r0 = op_uqz->inputs[0]; + + if (r0->producer->type != "aten::size") + continue; + + // drop tnn.Unsqueeze + + r0->remove_consumer(op_uqz); + r->remove_consumer(op); + + op->inputs[j] = r0; + r0->consumers.push_back(op); + + if (r->consumers.empty()) + { + graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), r)); + delete r; + + graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op_uqz)); + delete op_uqz; + } + } + + if (op2->type == "tnn.Reshape") + { + // drop tnn.Reshape args + op2->params.clear(); + } + + break; + } + + if (!matched) + break; + } + +} + +} // namespace tnn2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_tnn/fuse_shape_list_construct.h b/tools/pnnx/src/pass_tnn/fuse_shape_list_construct.h new file mode 100644 index 00000000000..55d72e6305c --- /dev/null +++ b/tools/pnnx/src/pass_tnn/fuse_shape_list_construct.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +namespace tnn2pnnx { + +void fuse_shape_list_construct(Graph& graph); + +} // namespace tnn2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_tnn/fuse_shape_size.cpp b/tools/pnnx/src/pass_tnn/fuse_shape_size.cpp new file mode 100644 index 00000000000..062c8bb2b65 --- /dev/null +++ b/tools/pnnx/src/pass_tnn/fuse_shape_size.cpp @@ -0,0 +1,71 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. 
You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_shape_size.h" + +#include "pass_level2.h" + +namespace pnnx { + +namespace tnn2pnnx { + +class fuse_shape_size_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +tnn.Shape op_0 1 1 input a +pnnx.Attribute op_1 0 1 index @data=(1)i32 +tnn.Gather op_2 2 1 a index out arg0=0 arg1=0 arg2=1 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +prim::Constant index 0 1 index +aten::size size 2 1 input index out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const + { + const Attribute& index_data = captured_attrs.at("op_1.data"); + const int index = ((const int*)index_data.data.data())[0]; + + Operator* op_index = ops.at("index"); + op_index->params["value"] = index; + } +}; + +void fuse_shape_size(Graph& graph) +{ + // TODO unpool tnn.Shape + + fuse_shape_size_pass a; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); +} + +} // namespace tnn2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_tnn/fuse_shape_size.h b/tools/pnnx/src/pass_tnn/fuse_shape_size.h new file mode 100644 index 00000000000..2058af6438e --- /dev/null +++ b/tools/pnnx/src/pass_tnn/fuse_shape_size.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
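+
+// Folds the TNN dimension-query idiom back into a torch-style size call.
+// Illustrative before/after (the gathered index value is assumed):
+//   tnn.Shape op_0 1 1 input a
+//   pnnx.Attribute op_1 0 1 index @data=(1)i32
+//   tnn.Gather op_2 2 1 a index out arg0=0 arg1=0 arg2=1
+// becomes
+//   prim::Constant index 0 1 index value=1
+//   aten::size size 2 1 input index out
+// fuse_shape_list_construct() then relies on these aten::size results when
+// collapsing a Concat of sizes into prim::ListConstruct.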
+ +#include "ir.h" + +namespace pnnx { + +namespace tnn2pnnx { + +void fuse_shape_size(Graph& graph); + +} // namespace tnn2pnnx + +} // namespace pnnx From 0af0a57c3c62186f095c568813b65f49eb2e4c0c Mon Sep 17 00:00:00 2001 From: nihuini Date: Tue, 11 Feb 2025 11:26:34 +0800 Subject: [PATCH 21/29] w --- tools/pnnx/src/pass_level2/F_conv1d.cpp | 3 ++ tools/pnnx/src/pass_level2/F_conv2d.cpp | 3 ++ tools/pnnx/src/pass_level2/Tensor_reshape.cpp | 52 ++++++++++++++++++- .../pass_tnn/lower_convolution_activation.cpp | 6 +++ 4 files changed, 62 insertions(+), 2 deletions(-) diff --git a/tools/pnnx/src/pass_level2/F_conv1d.cpp b/tools/pnnx/src/pass_level2/F_conv1d.cpp index 78df32b5d60..c7b2858173c 100644 --- a/tools/pnnx/src/pass_level2/F_conv1d.cpp +++ b/tools/pnnx/src/pass_level2/F_conv1d.cpp @@ -217,6 +217,9 @@ pnnx.Output output 1 0 out bool match(const std::map& captured_params) const { + if (captured_params.find("op_0.arg9") == captured_params.end()) + return true; + const int activation = captured_params.at("op_0.arg9").i; return activation == 0; } diff --git a/tools/pnnx/src/pass_level2/F_conv2d.cpp b/tools/pnnx/src/pass_level2/F_conv2d.cpp index f650b55f414..dfee3b253e1 100644 --- a/tools/pnnx/src/pass_level2/F_conv2d.cpp +++ b/tools/pnnx/src/pass_level2/F_conv2d.cpp @@ -327,6 +327,9 @@ pnnx.Output output 1 0 out bool match(const std::map& captured_params) const { + if (captured_params.find("op_0.arg13") == captured_params.end()) + return true; + const int activation = captured_params.at("op_0.arg13").i; return activation == 0; } diff --git a/tools/pnnx/src/pass_level2/Tensor_reshape.cpp b/tools/pnnx/src/pass_level2/Tensor_reshape.cpp index c197173968e..a9238d3e441 100644 --- a/tools/pnnx/src/pass_level2/Tensor_reshape.cpp +++ b/tools/pnnx/src/pass_level2/Tensor_reshape.cpp @@ -175,8 +175,8 @@ class Tensor_reshape_tnn_1 : public GraphRewriterPass return R"PNNXIR(7767517 4 3 pnnx.Input input 0 1 input -pnnx.Input shape 0 1 shape -tnn.Reshape op_0 2 1 input shape out +pnnx.Attribute shape 0 1 shape @data +tnn.Reshape op_0 2 1 input shape out %*=%* pnnx.Output output 1 0 out )PNNXIR"; } @@ -185,8 +185,56 @@ pnnx.Output output 1 0 out { return "Tensor.reshape"; } + + bool match(const std::map& /*captured_params*/, const std::map& captured_attrs) const + { + // one dim i32 + const auto& shape_data = captured_attrs.at("shape.data"); + return shape_data.shape.size() == 1 && shape_data.type == 4; + } + + void write(Operator* op, const std::map& /*captured_params*/, const std::map& captured_attrs) const + { + const auto& shape_data = captured_attrs.at("shape.data"); + const int* p = (const int*)shape_data.data.data(); + const int ndim = shape_data.data.size() / 4; + + std::vector shape(ndim); + for (int i = 0; i < ndim; i++) + { + shape[i] = p[i]; + } + + op->params["shape"] = shape; + } }; REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_reshape_tnn_1, 60) +class Tensor_reshape_tnn_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +pnnx.Input shape 0 1 shape +tnn.Reshape op_0 2 1 input shape out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.reshape"; + } + + void write(Operator* /*op*/, const std::map& /*captured_params*/) const + { + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_reshape_tnn_2, 61) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_tnn/lower_convolution_activation.cpp 
b/tools/pnnx/src/pass_tnn/lower_convolution_activation.cpp index 05da8f8e4d6..a4d76c557ce 100644 --- a/tools/pnnx/src/pass_tnn/lower_convolution_activation.cpp +++ b/tools/pnnx/src/pass_tnn/lower_convolution_activation.cpp @@ -77,6 +77,9 @@ pnnx.Output output 1 0 out bool match(const std::map& captured_params) const { + if (captured_params.find("op_0.arg13") == captured_params.end()) + return false; + this->activation = captured_params.at("op_0.arg13").i; return activation != 0; } @@ -205,6 +208,9 @@ pnnx.Output output 1 0 out bool match(const std::map& captured_params) const { + if (captured_params.find("op_0.arg9") == captured_params.end()) + return false; + this->activation = captured_params.at("op_0.arg9").i; return activation != 0; } From 4dc0e2c21257de0cad9f3e6e8160760bb310a01f Mon Sep 17 00:00:00 2001 From: nihuini Date: Tue, 11 Feb 2025 11:38:19 +0800 Subject: [PATCH 22/29] w --- tools/pnnx/src/pass_level2/torch_matmul.cpp | 16 ++++----- tools/pnnx/src/pass_level2/torch_sum.cpp | 36 +++++++++++++++++++++ 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/tools/pnnx/src/pass_level2/torch_matmul.cpp b/tools/pnnx/src/pass_level2/torch_matmul.cpp index c1388ad7d4d..62127a777a2 100644 --- a/tools/pnnx/src/pass_level2/torch_matmul.cpp +++ b/tools/pnnx/src/pass_level2/torch_matmul.cpp @@ -69,7 +69,7 @@ class torch_matmul_tnn : public GraphRewriterPass 4 3 pnnx.Input input_0 0 1 input pnnx.Input input_1 0 1 other -tnn.MatMul op_0 2 1 input other out arg0=%weight_position +tnn.MatMul op_0 2 1 input other out %*=%* pnnx.Output output 1 0 out )PNNXIR"; } @@ -81,16 +81,16 @@ pnnx.Output output 1 0 out void write(Operator* op, const std::map& captured_params) const { - const int weight_position = captured_params.at("weight_position").i; - if (weight_position == 0) + if (captured_params.find("op_0.arg0") != captured_params.end()) { - // swap input and weight - std::swap(op->inputs[0], op->inputs[1]); + const int weight_position = captured_params.at("op_0.arg0").i; + if (weight_position == 0) + { + // swap input and weight + std::swap(op->inputs[0], op->inputs[1]); + } } } - -protected: - mutable int weight_position; }; REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_matmul_tnn, 90) diff --git a/tools/pnnx/src/pass_level2/torch_sum.cpp b/tools/pnnx/src/pass_level2/torch_sum.cpp index 280cdff98aa..66ea74f3c7e 100644 --- a/tools/pnnx/src/pass_level2/torch_sum.cpp +++ b/tools/pnnx/src/pass_level2/torch_sum.cpp @@ -111,4 +111,40 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_sum_onnx, 50) +class torch_sum_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.ReduceSum op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.sum"; + } + + void write(Operator* op, const std::map& captured_params) const + { + std::vector dim; + for (int i = 1;; i++) + { + if (captured_params.find("op_0.arg" + std::to_string(i)) == captured_params.end()) + break; + + dim.push_back(captured_params.at("op_0.arg" + std::to_string(i)).i); + } + + op->params["dim"] = dim; + op->params["keepdim"] = captured_params.at("op_0.arg0").i ? 
true : false; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_sum_tnn, 50) + } // namespace pnnx From 1d1c3394915e12cc1fe38a7ce8796e3ef7f25d6e Mon Sep 17 00:00:00 2001 From: nihui <171016+nihui@users.noreply.github.com> Date: Tue, 11 Feb 2025 03:56:30 +0000 Subject: [PATCH 23/29] apply code-format changes --- tools/pnnx/src/pass_tnn/fuse_shape_list_construct.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/tools/pnnx/src/pass_tnn/fuse_shape_list_construct.cpp b/tools/pnnx/src/pass_tnn/fuse_shape_list_construct.cpp index c8db926033f..ad50c472079 100644 --- a/tools/pnnx/src/pass_tnn/fuse_shape_list_construct.cpp +++ b/tools/pnnx/src/pass_tnn/fuse_shape_list_construct.cpp @@ -118,7 +118,6 @@ void fuse_shape_list_construct(Graph& graph) if (!matched) break; } - } } // namespace tnn2pnnx From a360149a628ca925eb326be27bda6df257a30d81 Mon Sep 17 00:00:00 2001 From: nihuini Date: Tue, 11 Feb 2025 20:17:33 +0800 Subject: [PATCH 24/29] w --- tools/pnnx/src/CMakeLists.txt | 1 + tools/pnnx/src/load_tnn.cpp | 5 + tools/pnnx/src/pass_level2/Tensor_expand.cpp | 22 ++ .../pnnx/src/pass_level2/Tensor_expand_as.cpp | 28 ++ tools/pnnx/src/pass_level2/Tensor_slice.cpp | 74 +++++ tools/pnnx/src/pass_level2/nn_LSTM.cpp | 281 ++++++++++++++++++ tools/pnnx/src/pass_level2/torch_full.cpp | 28 ++ tools/pnnx/src/pass_level2/torch_norm.cpp | 37 +++ .../pass_tnn/fuse_shape_list_construct.cpp | 12 +- tools/pnnx/src/pass_tnn/lower_concat.cpp | 63 ++++ tools/pnnx/src/pass_tnn/lower_concat.h | 25 ++ 11 files changed, 575 insertions(+), 1 deletion(-) create mode 100644 tools/pnnx/src/pass_tnn/lower_concat.cpp create mode 100644 tools/pnnx/src/pass_tnn/lower_concat.h diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 5b7c8c7bc0b..3e84a4e9abf 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -719,6 +719,7 @@ if(PNNX_TNN2PNNX) set(pnnx_pass_tnn_SRCS pass_tnn/fuse_shape_size.cpp pass_tnn/fuse_shape_list_construct.cpp + pass_tnn/lower_concat.cpp pass_tnn/lower_convolution_activation.cpp pass_tnn/lower_power.cpp ) diff --git a/tools/pnnx/src/load_tnn.cpp b/tools/pnnx/src/load_tnn.cpp index 786c523cc9d..4363d55786e 100644 --- a/tools/pnnx/src/load_tnn.cpp +++ b/tools/pnnx/src/load_tnn.cpp @@ -22,6 +22,7 @@ #include "pass_tnn/fuse_shape_size.h" #include "pass_tnn/fuse_shape_list_construct.h" +#include "pass_tnn/lower_concat.h" #include "pass_tnn/lower_convolution_activation.h" #include "pass_tnn/lower_power.h" @@ -625,6 +626,8 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) if (op->type == "tnn.Sub") op->type = "aten::sub"; if (op->type == "tnn.Mul") op->type = "aten::mul"; if (op->type == "tnn.Div") op->type = "aten::div"; + + // misc } tnn2pnnx::fuse_shape_size(pnnx_graph); @@ -634,6 +637,8 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) tnn2pnnx::lower_power(pnnx_graph); + tnn2pnnx::lower_concat(pnnx_graph); + return 0; } diff --git a/tools/pnnx/src/pass_level2/Tensor_expand.cpp b/tools/pnnx/src/pass_level2/Tensor_expand.cpp index 0930bc51204..f3a7d0275e5 100644 --- a/tools/pnnx/src/pass_level2/Tensor_expand.cpp +++ b/tools/pnnx/src/pass_level2/Tensor_expand.cpp @@ -109,4 +109,26 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_expand_onnx, 60) +class Tensor_expand_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 shape +tnn.Expand op_0 2 1 input shape out %*=%* 
+pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.expand"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_expand_tnn, 61) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/Tensor_expand_as.cpp b/tools/pnnx/src/pass_level2/Tensor_expand_as.cpp index ca726b89faf..85557468445 100644 --- a/tools/pnnx/src/pass_level2/Tensor_expand_as.cpp +++ b/tools/pnnx/src/pass_level2/Tensor_expand_as.cpp @@ -38,4 +38,32 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_expand_as, 60) +class Tensor_expand_as_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 other +tnn.Shape op_0 1 1 other shape +tnn.Expand op_1 2 1 input shape out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.expand_as"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params.clear(); + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_expand_as_tnn, 60) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/Tensor_slice.cpp b/tools/pnnx/src/pass_level2/Tensor_slice.cpp index b610b29a256..8fd684c9d84 100644 --- a/tools/pnnx/src/pass_level2/Tensor_slice.cpp +++ b/tools/pnnx/src/pass_level2/Tensor_slice.cpp @@ -150,4 +150,78 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_slice_onnx_1, 70) +class Tensor_slice_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.StridedSliceV2 op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.slice"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const int nbegins = captured_params.at("op_0.arg0").i; + std::vector begins(nbegins); + for (int i = 0; i < nbegins; i++) + { + begins[i] = captured_params.at("op_0.arg" + std::to_string(i + 1)).i; + } + const int nends = captured_params.at("op_0.arg" + std::to_string(nbegins + 1)).i; + std::vector ends(nends); + for (int i = 0; i < nends; i++) + { + ends[i] = captured_params.at("op_0.arg" + std::to_string(i + 2 + nbegins)).i; + } + const int naxes = captured_params.at("op_0.arg" + std::to_string(nbegins + nends + 2)).i; + std::vector axes(naxes); + for (int i = 0; i < naxes; i++) + { + axes[i] = captured_params.at("op_0.arg" + std::to_string(i + 3 + nbegins + nends)).i; + } + + std::vector strides; + if (captured_params.find("op_0.arg" + std::to_string(nbegins + nends + naxes + 3)) != captured_params.end()) + { + const int nstrides = captured_params.at("op_0.arg" + std::to_string(nbegins + nends + naxes + 3)).i; + strides.resize(nstrides); + for (int i = 0; i < nstrides; i++) + { + strides[i] = captured_params.at("op_0.arg" + std::to_string(i + 4 + nbegins + nends + naxes)).i; + } + } + else + { + strides.resize(naxes, 1); + } + + if (axes.size() == 1) + { + op->params["dim"] = axes[0]; + op->params["start"] = begins[0]; + op->params["end"] = ends[0]; + op->params["step"] = strides[0]; + } + else + { + op->params["dims"] = axes; + op->params["starts"] = begins; + op->params["ends"] = ends; + op->params["steps"] = strides; + op->params["selects"] = std::vector(axes.size(), INT_MAX); + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_slice_tnn, 70) + } // namespace pnnx diff --git 
a/tools/pnnx/src/pass_level2/nn_LSTM.cpp b/tools/pnnx/src/pass_level2/nn_LSTM.cpp index 56ac7ab981c..e89347696a9 100644 --- a/tools/pnnx/src/pass_level2/nn_LSTM.cpp +++ b/tools/pnnx/src/pass_level2/nn_LSTM.cpp @@ -629,4 +629,285 @@ pnnx.Output output 1 0 out2 REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_LSTM_onnx_B5, 140) +class nn_LSTM_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +8 9 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 initial_h +pnnx.Input input_2 0 1 initial_c +pnnx.Attribute W 0 1 W @data +pnnx.Attribute R 0 1 R @data +pnnx.Attribute B 0 1 B @data +tnn.LSTMONNX lstm 6 3 input W R B initial_h initial_c out outh outc %*=%* +pnnx.Output output 3 0 out outh outc +)PNNXIR"; + } + + const char* type_str() const + { + return "nn.LSTM"; + } + + const char* name_str() const + { + return "lstm"; + } + + bool match(const std::map& captured_params, const std::map& captured_attrs) const + { + // arg0=0, arg1=512, arg2=2 + + // captured_params.at("lstm.arg0"); // clip_threshold + + const int hidden_size = captured_params.at("lstm.arg1").i; + const int direction = captured_params.at("lstm.arg2").i; + + const int num_directions = direction == 2 ? 2 : 1; + + const auto& W = captured_attrs.at("W.data"); // 2,2048,512 + const auto& R = captured_attrs.at("R.data"); // 2,2048,512 + const auto& B = captured_attrs.at("B.data"); // 2,4096 + + if (W.shape.size() != 3 || W.shape[0] != num_directions || W.shape[1] != 4 * hidden_size) + return false; + + if (R.shape.size() != 3 || R.shape[0] != num_directions || R.shape[1] != 4 * hidden_size || R.shape[2] != hidden_size) + return false; + + if (B.shape.size() != 2 || B.shape[0] != num_directions || B.shape[1] != 8 * hidden_size) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + const int hidden_size = captured_params.at("lstm.arg1").i; + const int direction = captured_params.at("lstm.arg2").i; + + const int num_directions = direction == 2 ? 2 : 1; + + const auto& W = captured_attrs.at("W.data"); + const auto& R = captured_attrs.at("R.data"); + const auto& B = captured_attrs.at("B.data"); + + const int input_size = W.shape[2]; + + op->params["input_size"] = input_size; + op->params["hidden_size"] = hidden_size; + op->params["num_layers"] = 1; + op->params["bias"] = false; + op->params["batch_first"] = false; + op->params["bidirectional"] = direction == 2 ? 
true : false;
+        op->params["proj_size"] = 0;
+
+        // split W R and reorder IOFG to IFGO
+        auto W_data = W.get_float32_data();
+        auto R_data = R.get_float32_data();
+
+        std::vector<float> W2(4 * hidden_size * input_size);
+        {
+            const int weight_data_size_g = hidden_size * input_size;
+
+            const float* iptr = (const float*)W_data.data();
+            const float* optr = (const float*)W_data.data() + weight_data_size_g;
+            const float* fptr = (const float*)W_data.data() + weight_data_size_g * 2;
+            const float* gptr = (const float*)W_data.data() + weight_data_size_g * 3;
+
+            float* w_iptr = (float*)W2.data();
+            float* w_fptr = (float*)W2.data() + weight_data_size_g;
+            float* w_gptr = (float*)W2.data() + weight_data_size_g * 2;
+            float* w_optr = (float*)W2.data() + weight_data_size_g * 3;
+
+            memcpy(w_iptr, iptr, weight_data_size_g * sizeof(float));
+            memcpy(w_fptr, fptr, weight_data_size_g * sizeof(float));
+            memcpy(w_gptr, gptr, weight_data_size_g * sizeof(float));
+            memcpy(w_optr, optr, weight_data_size_g * sizeof(float));
+        }
+
+        std::vector<float> R2(4 * hidden_size * hidden_size);
+        {
+            const int weight_data_size_g = hidden_size * hidden_size;
+
+            const float* iptr = (const float*)R_data.data();
+            const float* optr = (const float*)R_data.data() + weight_data_size_g;
+            const float* fptr = (const float*)R_data.data() + weight_data_size_g * 2;
+            const float* gptr = (const float*)R_data.data() + weight_data_size_g * 3;
+
+            float* w_iptr = (float*)R2.data();
+            float* w_fptr = (float*)R2.data() + weight_data_size_g;
+            float* w_gptr = (float*)R2.data() + weight_data_size_g * 2;
+            float* w_optr = (float*)R2.data() + weight_data_size_g * 3;
+
+            memcpy(w_iptr, iptr, weight_data_size_g * sizeof(float));
+            memcpy(w_fptr, fptr, weight_data_size_g * sizeof(float));
+            memcpy(w_gptr, gptr, weight_data_size_g * sizeof(float));
+            memcpy(w_optr, optr, weight_data_size_g * sizeof(float));
+        }
+
+        if (direction == 2)
+        {
+            op->attrs["weight_ih_l0"] = Attribute({4 * hidden_size, input_size}, W2);
+            op->attrs["weight_hh_l0"] = Attribute({4 * hidden_size, hidden_size}, R2);
+
+            std::vector<float> W2R(4 * hidden_size * input_size);
+            {
+                const int weight_data_size_g = hidden_size * input_size;
+
+                const float* iptr = (const float*)W_data.data() + weight_data_size_g * 4;
+                const float* optr = (const float*)W_data.data() + weight_data_size_g * 5;
+                const float* fptr = (const float*)W_data.data() + weight_data_size_g * 6;
+                const float* gptr = (const float*)W_data.data() + weight_data_size_g * 7;
+
+                float* w_iptr = (float*)W2R.data();
+                float* w_fptr = (float*)W2R.data() + weight_data_size_g;
+                float* w_gptr = (float*)W2R.data() + weight_data_size_g * 2;
+                float* w_optr = (float*)W2R.data() + weight_data_size_g * 3;
+
+                memcpy(w_iptr, iptr, weight_data_size_g * sizeof(float));
+                memcpy(w_fptr, fptr, weight_data_size_g * sizeof(float));
+                memcpy(w_gptr, gptr, weight_data_size_g * sizeof(float));
+                memcpy(w_optr, optr, weight_data_size_g * sizeof(float));
+            }
+
+            std::vector<float> R2R(4 * hidden_size * hidden_size);
+            {
+                const int weight_data_size_g = hidden_size * hidden_size;
+
+                const float* iptr = (const float*)R_data.data() + weight_data_size_g * 4;
+                const float* optr = (const float*)R_data.data() + weight_data_size_g * 5;
+                const float* fptr = (const float*)R_data.data() + weight_data_size_g * 6;
+                const float* gptr = (const float*)R_data.data() + weight_data_size_g * 7;
+
+                float* w_iptr = (float*)R2R.data();
+                float* w_fptr = (float*)R2R.data() + weight_data_size_g;
+                float* w_gptr = (float*)R2R.data() + weight_data_size_g * 2;
+                float* w_optr = (float*)R2R.data() + weight_data_size_g * 3;
+
+                memcpy(w_iptr, iptr, weight_data_size_g * sizeof(float));
+                memcpy(w_fptr, fptr, weight_data_size_g * sizeof(float));
+                memcpy(w_gptr, gptr, weight_data_size_g * sizeof(float));
+                memcpy(w_optr, optr, weight_data_size_g * sizeof(float));
+            }
+
+            op->attrs["weight_ih_l0_reverse"] = Attribute({4 * hidden_size, input_size}, W2R);
+            op->attrs["weight_hh_l0_reverse"] = Attribute({4 * hidden_size, hidden_size}, R2R);
+        }
+        else
+        {
+            op->attrs["weight_ih_l0"] = Attribute({4 * hidden_size, input_size}, W2);
+            op->attrs["weight_hh_l0"] = Attribute({4 * hidden_size, hidden_size}, R2);
+        }
+
+
+        bool has_bias = false;
+        for (auto b : B.get_float32_data())
+        {
+            if (b != 0.f)
+            {
+                has_bias = true;
+                break;
+            }
+        }
+
+        op->params["bias"] = has_bias;
+
+        if (has_bias)
+        {
+            // split B and reorder IOFG to IFGO
+            auto B_data = B.get_float32_data();
+
+            std::vector<float> B2(4 * hidden_size);
+            std::vector<float> B3(4 * hidden_size);
+            {
+                const float* iptr = (const float*)B_data.data();
+                const float* optr = (const float*)B_data.data() + hidden_size;
+                const float* fptr = (const float*)B_data.data() + hidden_size * 2;
+                const float* gptr = (const float*)B_data.data() + hidden_size * 3;
+
+                float* w_iptr = (float*)B2.data();
+                float* w_fptr = (float*)B2.data() + hidden_size;
+                float* w_gptr = (float*)B2.data() + hidden_size * 2;
+                float* w_optr = (float*)B2.data() + hidden_size * 3;
+
+                memcpy(w_iptr, iptr, hidden_size * sizeof(float));
+                memcpy(w_fptr, fptr, hidden_size * sizeof(float));
+                memcpy(w_gptr, gptr, hidden_size * sizeof(float));
+                memcpy(w_optr, optr, hidden_size * sizeof(float));
+            }
+            {
+                const float* iptr = (const float*)B_data.data() + hidden_size * 4;
+                const float* optr = (const float*)B_data.data() + hidden_size * 5;
+                const float* fptr = (const float*)B_data.data() + hidden_size * 6;
+                const float* gptr = (const float*)B_data.data() + hidden_size * 7;
+
+                float* w_iptr = (float*)B3.data();
+                float* w_fptr = (float*)B3.data() + hidden_size;
+                float* w_gptr = (float*)B3.data() + hidden_size * 2;
+                float* w_optr = (float*)B3.data() + hidden_size * 3;
+
+                memcpy(w_iptr, iptr, hidden_size * sizeof(float));
+                memcpy(w_fptr, fptr, hidden_size * sizeof(float));
+                memcpy(w_gptr, gptr, hidden_size * sizeof(float));
+                memcpy(w_optr, optr, hidden_size * sizeof(float));
+            }
+
+            if (direction == 2)
+            {
+                op->attrs["bias_ih_l0"] = Attribute({4 * hidden_size}, B2);
+                op->attrs["bias_hh_l0"] = Attribute({4 * hidden_size}, B3);
+
+                std::vector<float> B2R(4 * hidden_size);
+                std::vector<float> B3R(4 * hidden_size);
+                {
+                    const float* iptr = (const float*)B_data.data() + hidden_size * 8;
+                    const float* optr = (const float*)B_data.data() + hidden_size * 9;
+                    const float* fptr = (const float*)B_data.data() + hidden_size * 10;
+                    const float* gptr = (const float*)B_data.data() + hidden_size * 11;
+
+                    float* w_iptr = (float*)B2R.data();
+                    float* w_fptr = (float*)B2R.data() + hidden_size;
+                    float* w_gptr = (float*)B2R.data() + hidden_size * 2;
+                    float* w_optr = (float*)B2R.data() + hidden_size * 3;
+
+                    memcpy(w_iptr, iptr, hidden_size * sizeof(float));
+                    memcpy(w_fptr, fptr, hidden_size * sizeof(float));
+                    memcpy(w_gptr, gptr, hidden_size * sizeof(float));
+                    memcpy(w_optr, optr, hidden_size * sizeof(float));
+                }
+                {
+                    const float* iptr = (const float*)B_data.data() + hidden_size * 12;
+                    const float* optr = (const float*)B_data.data() + hidden_size * 13;
+                    const float* fptr = (const float*)B_data.data() + hidden_size * 14;
+                    const float* gptr = (const float*)B_data.data() + hidden_size *
15; + + float* w_iptr = (float*)B3R.data(); + float* w_fptr = (float*)B3R.data() + hidden_size; + float* w_gptr = (float*)B3R.data() + hidden_size * 2; + float* w_optr = (float*)B3R.data() + hidden_size * 3; + + memcpy(w_iptr, iptr, hidden_size * sizeof(float)); + memcpy(w_fptr, fptr, hidden_size * sizeof(float)); + memcpy(w_gptr, gptr, hidden_size * sizeof(float)); + memcpy(w_optr, optr, hidden_size * sizeof(float)); + } + + op->attrs["bias_ih_l0_reverse"] = Attribute({4 * hidden_size}, B2R); + op->attrs["bias_hh_l0_reverse"] = Attribute({4 * hidden_size}, B3R); + } + else + { + op->attrs["bias_ih_l0"] = Attribute({4 * hidden_size}, B2); + op->attrs["bias_hh_l0"] = Attribute({4 * hidden_size}, B3); + } + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_LSTM_tnn, 140) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_full.cpp b/tools/pnnx/src/pass_level2/torch_full.cpp index ffea79b38de..e8cfa7f43df 100644 --- a/tools/pnnx/src/pass_level2/torch_full.cpp +++ b/tools/pnnx/src/pass_level2/torch_full.cpp @@ -93,4 +93,32 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_full_onnx, 21) +class torch_full_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 size +pnnx.Attribute value 0 1 value @data=(1)f32 +tnn.ConstantOfShape op_0 2 1 size value out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.full"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + op->params["fill_value"] = ((const float*)captured_attrs.at("value.data").data.data())[0]; + op->params["dtype"] = "torch.float"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_full_tnn, 21) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_norm.cpp b/tools/pnnx/src/pass_level2/torch_norm.cpp index 57b54cb1328..c1980d1c6b5 100644 --- a/tools/pnnx/src/pass_level2/torch_norm.cpp +++ b/tools/pnnx/src/pass_level2/torch_norm.cpp @@ -144,4 +144,41 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_norm_fro, 90) REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_norm_fro_dims, 90) +class torch_norm_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.ReduceL2 op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.norm"; + } + + void write(Operator* op, const std::map& captured_params) const + { + std::vector dim; + for (int i = 1;; i++) + { + if (captured_params.find("op_0.arg" + std::to_string(i)) == captured_params.end()) + break; + + dim.push_back(captured_params.at("op_0.arg" + std::to_string(i)).i); + } + + op->params["dim"] = dim; + op->params["keepdim"] = captured_params.at("op_0.arg0").i ? 
true : false; + op->params["p"] = 2; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_norm_tnn, 90) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_tnn/fuse_shape_list_construct.cpp b/tools/pnnx/src/pass_tnn/fuse_shape_list_construct.cpp index ad50c472079..b2688852efa 100644 --- a/tools/pnnx/src/pass_tnn/fuse_shape_list_construct.cpp +++ b/tools/pnnx/src/pass_tnn/fuse_shape_list_construct.cpp @@ -60,6 +60,11 @@ void fuse_shape_list_construct(Graph& graph) if (op2->inputs[0] != op->outputs[0]) continue; } + else if (op2->type == "tnn.Expand") + { + if (op2->inputs[1] != op->outputs[0]) + continue; + } else { continue; @@ -67,7 +72,7 @@ void fuse_shape_list_construct(Graph& graph) matched = true; - fprintf(stderr, "match concat + reshape/constantofshape\n"); + fprintf(stderr, "match concat + reshape/constantofshape/expand\n"); op->type = "prim::ListConstruct"; @@ -111,6 +116,11 @@ void fuse_shape_list_construct(Graph& graph) // drop tnn.Reshape args op2->params.clear(); } + if (op2->type == "tnn.Expand") + { + // drop tnn.Expand args + op2->params.clear(); + } break; } diff --git a/tools/pnnx/src/pass_tnn/lower_concat.cpp b/tools/pnnx/src/pass_tnn/lower_concat.cpp new file mode 100644 index 00000000000..5a9cb1d662e --- /dev/null +++ b/tools/pnnx/src/pass_tnn/lower_concat.cpp @@ -0,0 +1,63 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "lower_concat.h" + +#include "pass_level2.h" + +namespace pnnx { + +namespace tnn2pnnx { + +void lower_concat(Graph& graph) +{ + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "tnn.Concat") + continue; + + const int dim = op->params["arg0"].i; + + op->type = "aten::cat"; + op->params.clear(); + op->params["dim"] = dim; + + // insert listconstruct for inputs + Operator* op0 = graph.new_operator_before("prim::ListConstruct", op->name + "_lc", op); + Operand* r = graph.new_operand(op->name + "_lc"); + + r->producer = op0; + r->consumers.push_back(op); + + op0->outputs.push_back(r); + + for (size_t j = 0; j < op->inputs.size(); j++) + { + Operand* x = op->inputs[j]; + + x->remove_consumer(op); + x->consumers.push_back(op0); + op0->inputs.push_back(x); + } + + op->inputs.clear(); + op->inputs.push_back(r); + } +} + +} // namespace tnn2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_tnn/lower_concat.h b/tools/pnnx/src/pass_tnn/lower_concat.h new file mode 100644 index 00000000000..84fd76a542b --- /dev/null +++ b/tools/pnnx/src/pass_tnn/lower_concat.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. 
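// (A small before/after sketch of the lower_concat rewrite declared here,
//  with hypothetical operand names: tnn.Concat consumes its input tensors
//  directly, while aten::cat expects a single list-typed operand, so the
//  pass synthesizes a prim::ListConstruct feeding the retyped operator:
//
//      tnn.Concat          cat0    3 1 a b c out arg0=1
//
//  becomes
//
//      prim::ListConstruct cat0_lc 3 1 a b c cat0_lc
//      aten::cat           cat0    1 1 cat0_lc out dim=1
// )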
You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +namespace tnn2pnnx { + +void lower_concat(Graph& graph); + +} // namespace tnn2pnnx + +} // namespace pnnx From fb159d9d96303a1b1ee42cff0e3e47ee71785f9a Mon Sep 17 00:00:00 2001 From: nihui <171016+nihui@users.noreply.github.com> Date: Tue, 11 Feb 2025 12:19:35 +0000 Subject: [PATCH 25/29] apply code-format changes --- tools/pnnx/src/pass_level2/nn_LSTM.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/tools/pnnx/src/pass_level2/nn_LSTM.cpp b/tools/pnnx/src/pass_level2/nn_LSTM.cpp index e89347696a9..6b1134f039b 100644 --- a/tools/pnnx/src/pass_level2/nn_LSTM.cpp +++ b/tools/pnnx/src/pass_level2/nn_LSTM.cpp @@ -803,7 +803,6 @@ pnnx.Output output 3 0 out outh outc op->attrs["weight_hh_l0"] = Attribute({4 * hidden_size, hidden_size}, R2); } - bool has_bias = false; for (auto b : B.get_float32_data()) { From 117ea4e40635c41cfbe48c9a602c567ff3a3c6b8 Mon Sep 17 00:00:00 2001 From: nihuini Date: Wed, 12 Feb 2025 14:06:40 +0800 Subject: [PATCH 26/29] w --- tools/pnnx/src/ir.cpp | 2 + tools/pnnx/src/load_tnn.cpp | 10 +++++ tools/pnnx/src/pass_level2/F_embedding.cpp | 28 +++++++++++++ tools/pnnx/src/pass_level2/Tensor_to.cpp | 39 +++++++++++++++++++ tools/pnnx/src/pass_level2/torch_full.cpp | 2 +- .../pnnx/src/pass_level2/torch_full_like.cpp | 29 ++++++++++++++ tools/pnnx/src/pass_level2/torch_squeeze.cpp | 39 +++++++++++++++++++ .../pnnx/src/pass_level3/fuse_expression.cpp | 3 ++ .../pnnx/src/pass_level5/eval_expression.cpp | 6 +++ .../pnnx/src/pass_ncnn/expand_expression.cpp | 2 + 10 files changed, 159 insertions(+), 1 deletion(-) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 09a30f66979..e79a09826d6 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -1030,6 +1030,7 @@ static std::string expand_expression(const Operator* op) || t == "ceil" || t == "cos" || t == "cosh" + || t == "erf" || t == "exp" || t == "floor" || t == "log" @@ -1062,6 +1063,7 @@ static std::string expand_expression(const Operator* op) if (t == "ceil") unaryop = "torch.ceil"; if (t == "cos") unaryop = "torch.cos"; if (t == "cosh") unaryop = "torch.cosh"; + if (t == "erf") unaryop = "torch.erf"; if (t == "exp") unaryop = "torch.exp"; if (t == "floor") unaryop = "torch.floor"; if (t == "log") unaryop = "torch.log"; diff --git a/tools/pnnx/src/load_tnn.cpp b/tools/pnnx/src/load_tnn.cpp index 4363d55786e..2e0d8bb683c 100644 --- a/tools/pnnx/src/load_tnn.cpp +++ b/tools/pnnx/src/load_tnn.cpp @@ -605,6 +605,13 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) fprintf(stderr, "model constant %s\n", name.c_str()); + if (op_map.find(name) == op_map.end()) + { + // FIXME + Attribute skip(bp); + continue; + } + Operator* op_constant = op_map.at(name); op_constant->attrs["data"] = Attribute(bp); @@ -616,10 +623,13 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) for (Operator* op : pnnx_graph.ops) { // unary + if (op->type == "tnn.Erf") op->type = "aten::erf"; if (op->type == "tnn.Log") op->type = "aten::log"; if (op->type == "tnn.ReLU") op->type = "aten::relu"; if (op->type == "tnn.ReLU6") op->type = "aten::relu6"; if 
(op->type == "tnn.Sigmoid") op->type = "aten::sigmoid"; + if (op->type == "tnn.Sqrt") op->type = "aten::sqrt"; + if (op->type == "tnn.Tanh") op->type = "aten::tanh"; // binary if (op->type == "tnn.Add") op->type = "aten::add"; diff --git a/tools/pnnx/src/pass_level2/F_embedding.cpp b/tools/pnnx/src/pass_level2/F_embedding.cpp index 810f6b98cac..d869d897a7d 100644 --- a/tools/pnnx/src/pass_level2/F_embedding.cpp +++ b/tools/pnnx/src/pass_level2/F_embedding.cpp @@ -69,4 +69,32 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_embedding_onnx, 140) +class F_embedding_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Gather op_0 2 1 input weight out arg0=0 arg1=1 arg2=0 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.embedding"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["scale_grad_by_freq"] = false; + op->params["sparse"] = false; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_embedding_tnn, 140) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/Tensor_to.cpp b/tools/pnnx/src/pass_level2/Tensor_to.cpp index 81da30bdd68..2780a529693 100644 --- a/tools/pnnx/src/pass_level2/Tensor_to.cpp +++ b/tools/pnnx/src/pass_level2/Tensor_to.cpp @@ -162,4 +162,43 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to_onnx, 60) +class Tensor_to_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.Cast op_0 1 1 input out arg0=%to +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.to"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const int to = captured_params.at("to").i; + + op->params["non_blocking"] = false; + op->params["copy"] = false; + op->params["memory_format"] = "torch.preserve_format"; + + if (to == 0) op->params["dtype"] = "torch.float"; + if (to == 1) op->params["dtype"] = "torch.half"; + if (to == 2) op->params["dtype"] = "torch.int8"; + if (to == 3) op->params["dtype"] = "torch.int"; + if (to == 4) op->params["dtype"] = "torch.bfloat16"; + if (to == 5) op->params["dtype"] = "torch.long"; + if (to == 6) op->params["dtype"] = "torch.uint32"; + if (to == 8) op->params["dtype"] = "torch.uint8"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to_tnn, 60) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_full.cpp b/tools/pnnx/src/pass_level2/torch_full.cpp index e8cfa7f43df..1e77c60a930 100644 --- a/tools/pnnx/src/pass_level2/torch_full.cpp +++ b/tools/pnnx/src/pass_level2/torch_full.cpp @@ -112,7 +112,7 @@ pnnx.Output output 1 0 out return "torch.full"; } - void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + void write(Operator* op, const std::map& /*captured_params*/, const std::map& captured_attrs) const { op->params["fill_value"] = ((const float*)captured_attrs.at("value.data").data.data())[0]; op->params["dtype"] = "torch.float"; diff --git a/tools/pnnx/src/pass_level2/torch_full_like.cpp b/tools/pnnx/src/pass_level2/torch_full_like.cpp index 3b016611763..20e2cec3004 100644 --- a/tools/pnnx/src/pass_level2/torch_full_like.cpp +++ b/tools/pnnx/src/pass_level2/torch_full_like.cpp @@ -67,4 +67,33 @@ pnnx.Output output 1 0 out 
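// (The tnn variant added below recognizes full_like structurally: a
//  tnn.ConstantOfShape whose size operand is produced by tnn.Shape of the
//  same tensor is equivalent to torch.full_like, since
//      torch.full_like(x, v) == torch.full(x.shape, v)
//  The fill value is read out of the one-element f32 attribute, which is
//  why dtype is pinned to torch.float here.)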
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_full_like, 20) +class torch_full_like_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +tnn.Shape op_0 1 1 input shape +pnnx.Attribute value 0 1 value @data=(1)f32 +tnn.ConstantOfShape op_2 2 1 shape value out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.full_like"; + } + + void write(Operator* op, const std::map& /*captured_params*/, const std::map& captured_attrs) const + { + op->params["fill_value"] = ((const float*)captured_attrs.at("value.data").data.data())[0]; + op->params["dtype"] = "torch.float"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_full_like_tnn, 20) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_squeeze.cpp b/tools/pnnx/src/pass_level2/torch_squeeze.cpp index e79a2849144..0fdabe82144 100644 --- a/tools/pnnx/src/pass_level2/torch_squeeze.cpp +++ b/tools/pnnx/src/pass_level2/torch_squeeze.cpp @@ -137,4 +137,43 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_squeeze_onnx_1, 60) +class torch_squeeze_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.Squeeze op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.squeeze"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const int dims_count = captured_params.at("op_0.arg0").i; + if (dims_count == 1) + { + op->params["dim"] = captured_params.at("op_0.arg1").i; + } + else + { + std::vector dims(dims_count); + for (int i = 0; i < dims_count; i++) + { + dims[i] = captured_params.at("op_0.arg" + std::to_string(i + 1)).i; + } + op->params["dim"] = dims; + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_squeeze_tnn, 60) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_expression.cpp b/tools/pnnx/src/pass_level3/fuse_expression.cpp index 4b9f3f83929..dcb3d1f1365 100644 --- a/tools/pnnx/src/pass_level3/fuse_expression.cpp +++ b/tools/pnnx/src/pass_level3/fuse_expression.cpp @@ -135,6 +135,7 @@ static bool operand_maybe_tensor(const Operand* operand) || op->type == "aten::ceil" || op->type == "aten::cos" || op->type == "aten::cosh" + || op->type == "aten::erf" || op->type == "aten::exp" || op->type == "aten::floor" || op->type == "aten::log" @@ -648,6 +649,7 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s || op->type == "aten::ceil" || op->type == "aten::cos" || op->type == "aten::cosh" + || op->type == "aten::erf" || op->type == "aten::exp" || op->type == "aten::floor" || op->type == "aten::log" @@ -888,6 +890,7 @@ void fuse_expression(Graph& graph, const std::set& foldable_constan || op->type == "aten::cos" || op->type == "aten::cosh" || op->type == "aten::div" + || op->type == "aten::erf" || op->type == "aten::exp" || op->type == "aten::floor" || op->type == "aten::floor_divide" diff --git a/tools/pnnx/src/pass_level5/eval_expression.cpp b/tools/pnnx/src/pass_level5/eval_expression.cpp index c7d5d5d0226..441cfca1c36 100644 --- a/tools/pnnx/src/pass_level5/eval_expression.cpp +++ b/tools/pnnx/src/pass_level5/eval_expression.cpp @@ -175,6 +175,7 @@ static std::string eval_expression(const Operator* op) || t == "ceil" || t == "cos" || t == "cosh" + || t == "erf" || t == "exp" || t == "floor" || t == "log" @@ -262,6 +263,11 
@@ static std::string eval_expression(const Operator* op) float r = cosh(af); exprstack.push(std::to_string(r)); } + if (t == "erf") + { + float r = erf(af); + exprstack.push(std::to_string(r)); + } if (t == "exp") { float r = exp(af); diff --git a/tools/pnnx/src/pass_ncnn/expand_expression.cpp b/tools/pnnx/src/pass_ncnn/expand_expression.cpp index 2fdc6d77d62..63efb4f8e47 100644 --- a/tools/pnnx/src/pass_ncnn/expand_expression.cpp +++ b/tools/pnnx/src/pass_ncnn/expand_expression.cpp @@ -123,6 +123,7 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx || t == "atan" || t == "ceil" || t == "cos" + || t == "erf" || t == "exp" || t == "floor" || t == "log" @@ -154,6 +155,7 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx if (t == "ceil") op_unary->params["0"] = 3; if (t == "cos") op_unary->params["0"] = 10; if (t == "exp") op_unary->params["0"] = 7; + if (t == "erf") fprintf(stderr, "UnaryOp erf not supported yet\n"); // TODO if (t == "floor") op_unary->params["0"] = 2; if (t == "log") op_unary->params["0"] = 8; if (t == "log10") op_unary->params["0"] = 17; From 2a49b41c639d16fcfde6716c87b7c528c41f440a Mon Sep 17 00:00:00 2001 From: nihuini Date: Wed, 12 Feb 2025 14:23:22 +0800 Subject: [PATCH 27/29] w --- tools/pnnx/src/main.cpp | 39 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index 07dd9c5f033..35211a2fa15 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -176,6 +176,34 @@ static bool model_file_maybe_torchscript(const std::string& path) return signature == 0x04034b50; } +static bool model_file_maybe_tnnproto(const std::string& path) +{ + FILE* fp = fopen(path.c_str(), "rb"); + if (!fp) + { + fprintf(stderr, "open failed %s\n", path.c_str()); + return false; + } + + char line[256]; + char* s = fgets(line, 256, fp); + if (!s) + { + fclose(fp); + return false; + } + + uint32_t signature = 0; + if (line[0] == '\"') + { + sscanf(line + 1, "%*d %*d %*d %d", &signature); + } + + fclose(fp); + + return signature == 4206624772; +} + static void show_usage() { fprintf(stderr, "Usage: pnnx [model.pt] [(key=value)...]\n"); @@ -316,12 +344,16 @@ int main(int argc, char** argv) std::string foldable_constants_zippath = ptbase + ".foldable_constants.zip"; pnnx::Graph pnnx_graph; + + // clang-format off + // *INDENT-OFF* + #if BUILD_TNN2PNNX - if (1) + if (model_file_maybe_tnnproto(ptpath)) { - fprintf(stderr, "TODO distinguish tnnproto file\n"); load_tnn(ptpath, pnnx_graph); } + else #endif #if BUILD_ONNX2PNNX if (!model_file_maybe_torchscript(ptpath)) @@ -340,6 +372,9 @@ int main(int argc, char** argv) foldable_constants_zippath, foldable_constants); } + // *INDENT-ON* + // clang-format on + fprintf(stderr, "############# pass_level2\n"); pnnx::pass_level2(pnnx_graph); From f0b2c3d26d47b5f16e249a4523f83da65c78204b Mon Sep 17 00:00:00 2001 From: nihuini Date: Wed, 12 Feb 2025 14:47:57 +0800 Subject: [PATCH 28/29] build --- tools/pnnx/src/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 3e84a4e9abf..e41c206abf5 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -731,6 +731,7 @@ if(PNNX_TNN2PNNX) add_library(tnn2pnnx OBJECT ${tnn2pnnx_SRCS}) target_compile_definitions(tnn2pnnx PRIVATE BUILD_TNN2PNNX) + target_compile_options(tnn2pnnx PUBLIC "${TORCH_CXX_FLAGS}") message(STATUS "Building with tnn2pnnx") 
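# (Note on the target_compile_options fix above: the tnn2pnnx object library
#  compiles PNNX IR sources that must match the C++ ABI of the linked
#  libtorch, so it inherits TORCH_CXX_FLAGS as well; with a typical libtorch
#  build this expands to something like -D_GLIBCXX_USE_CXX11_ABI=1, and
#  mixing ABIs across object libraries fails at link time.)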
else() From b11db185534471aca5a2ddb1bb97f512f301ef63 Mon Sep 17 00:00:00 2001 From: nihuini Date: Wed, 12 Feb 2025 16:33:30 +0800 Subject: [PATCH 29/29] w --- tools/pnnx/src/ir.cpp | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index e79a09826d6..0fd83233175 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -2255,11 +2255,17 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) { if (op->type == "Tensor.index_put" && it.first == "values") { - fprintf(pyfp, "torch.tensor(%g)", param.f); + if (param.f == (int)param.f) + fprintf(pyfp, "torch.tensor(%.1f)", param.f); + else + fprintf(pyfp, "torch.tensor(%g)", param.f); } else { - fprintf(pyfp, "%g", param.f); + if (param.f == (int)param.f) + fprintf(pyfp, "%.1f", param.f); + else + fprintf(pyfp, "%g", param.f); } } if (param.type == 4) @@ -2318,7 +2324,10 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) fprintf(pyfp, "("); for (size_t i = 0; i < param.af.size(); i++) { - fprintf(pyfp, "%g", param.af[i]); + if (param.af[i] == (int)param.af[i]) + fprintf(pyfp, "%.1f", param.af[i]); + else + fprintf(pyfp, "%g", param.af[i]); if (i + 1 != param.af.size() || param.af.size() == 1) fprintf(pyfp, ","); }
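
The motivation for the %.1f special case in this last hunk, shown as a
standalone sketch outside the patch: %g drops the trailing ".0" from
whole-valued floats, so the generated pnnx.py would otherwise round-trip
them as Python int literals.

#include <cstdio>

int main()
{
    const float f = 2.f;
    printf("%g\n", f);       // prints "2"   -> reads back as a python int literal
    if (f == (int)f)
        printf("%.1f\n", f); // prints "2.0" -> stays a float literal
    const float g = 2.5f;
    printf("%g\n", g);       // prints "2.5" -> non-integral values keep %g
    return 0;
}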