From 5ea7681e70f43750e2066deb19c282c8161001e1 Mon Sep 17 00:00:00 2001 From: nihui Date: Wed, 12 Feb 2025 18:22:05 +0800 Subject: [PATCH] tnn2pnnx (#5898) --- tools/pnnx/CMakeLists.txt | 2 + tools/pnnx/src/CMakeLists.txt | 28 + tools/pnnx/src/ir.cpp | 17 +- tools/pnnx/src/ir.h | 3 + tools/pnnx/src/load_tnn.cpp | 655 ++++++++++++++++++ tools/pnnx/src/load_tnn.h | 26 + tools/pnnx/src/main.cpp | 45 ++ .../src/pass_level2/F_adaptive_avg_pool2d.cpp | 77 ++ .../src/pass_level2/F_adaptive_max_pool2d.cpp | 77 ++ tools/pnnx/src/pass_level2/F_avg_pool2d.cpp | 71 ++ tools/pnnx/src/pass_level2/F_batch_norm.cpp | 47 ++ tools/pnnx/src/pass_level2/F_conv1d.cpp | 74 ++ tools/pnnx/src/pass_level2/F_conv2d.cpp | 75 ++ tools/pnnx/src/pass_level2/F_embedding.cpp | 28 + tools/pnnx/src/pass_level2/F_linear.cpp | 23 + tools/pnnx/src/pass_level2/F_max_pool2d.cpp | 70 ++ tools/pnnx/src/pass_level2/F_pad.cpp | 58 ++ tools/pnnx/src/pass_level2/F_softmax.cpp | 21 + tools/pnnx/src/pass_level2/Tensor_expand.cpp | 22 + .../pnnx/src/pass_level2/Tensor_expand_as.cpp | 28 + tools/pnnx/src/pass_level2/Tensor_permute.cpp | 32 + tools/pnnx/src/pass_level2/Tensor_reshape.cpp | 114 +++ tools/pnnx/src/pass_level2/Tensor_slice.cpp | 74 ++ tools/pnnx/src/pass_level2/Tensor_to.cpp | 39 ++ tools/pnnx/src/pass_level2/nn_LSTM.cpp | 280 ++++++++ tools/pnnx/src/pass_level2/torch_clamp.cpp | 21 + tools/pnnx/src/pass_level2/torch_full.cpp | 28 + .../pnnx/src/pass_level2/torch_full_like.cpp | 29 + .../src/pass_level2/torch_index_select.cpp | 58 ++ tools/pnnx/src/pass_level2/torch_matmul.cpp | 35 + tools/pnnx/src/pass_level2/torch_max.cpp | 43 ++ tools/pnnx/src/pass_level2/torch_mean.cpp | 36 + tools/pnnx/src/pass_level2/torch_min.cpp | 43 ++ tools/pnnx/src/pass_level2/torch_norm.cpp | 37 + tools/pnnx/src/pass_level2/torch_squeeze.cpp | 39 ++ tools/pnnx/src/pass_level2/torch_sum.cpp | 36 + .../pnnx/src/pass_level2/torch_unsqueeze.cpp | 39 ++ .../pnnx/src/pass_level3/fuse_expression.cpp | 3 + 
.../pnnx/src/pass_level5/eval_expression.cpp | 6 + .../pnnx/src/pass_ncnn/expand_expression.cpp | 2 + .../pass_tnn/fuse_shape_list_construct.cpp | 135 ++++ .../src/pass_tnn/fuse_shape_list_construct.h | 25 + tools/pnnx/src/pass_tnn/fuse_shape_size.cpp | 71 ++ tools/pnnx/src/pass_tnn/fuse_shape_size.h | 25 + tools/pnnx/src/pass_tnn/lower_concat.cpp | 63 ++ tools/pnnx/src/pass_tnn/lower_concat.h | 25 + .../pass_tnn/lower_convolution_activation.cpp | 301 ++++++++ .../pass_tnn/lower_convolution_activation.h | 25 + tools/pnnx/src/pass_tnn/lower_power.cpp | 62 ++ tools/pnnx/src/pass_tnn/lower_power.h | 25 + 50 files changed, 3195 insertions(+), 3 deletions(-) create mode 100644 tools/pnnx/src/load_tnn.cpp create mode 100644 tools/pnnx/src/load_tnn.h create mode 100644 tools/pnnx/src/pass_tnn/fuse_shape_list_construct.cpp create mode 100644 tools/pnnx/src/pass_tnn/fuse_shape_list_construct.h create mode 100644 tools/pnnx/src/pass_tnn/fuse_shape_size.cpp create mode 100644 tools/pnnx/src/pass_tnn/fuse_shape_size.h create mode 100644 tools/pnnx/src/pass_tnn/lower_concat.cpp create mode 100644 tools/pnnx/src/pass_tnn/lower_concat.h create mode 100644 tools/pnnx/src/pass_tnn/lower_convolution_activation.cpp create mode 100644 tools/pnnx/src/pass_tnn/lower_convolution_activation.h create mode 100644 tools/pnnx/src/pass_tnn/lower_power.cpp create mode 100644 tools/pnnx/src/pass_tnn/lower_power.h diff --git a/tools/pnnx/CMakeLists.txt b/tools/pnnx/CMakeLists.txt index b09f4758ead..e50ab4788c3 100644 --- a/tools/pnnx/CMakeLists.txt +++ b/tools/pnnx/CMakeLists.txt @@ -123,6 +123,8 @@ else() set(onnxruntime_FOUND FALSE) endif() +option(PNNX_TNN2PNNX "build tnn2pnnx" ON) + add_subdirectory(src) enable_testing() diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index b1ac6f5c024..e41c206abf5 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -715,6 +715,29 @@ else() message(STATUS "Building without onnx2pnnx") endif() 
+if(PNNX_TNN2PNNX) + set(pnnx_pass_tnn_SRCS + pass_tnn/fuse_shape_size.cpp + pass_tnn/fuse_shape_list_construct.cpp + pass_tnn/lower_concat.cpp + pass_tnn/lower_convolution_activation.cpp + pass_tnn/lower_power.cpp + ) + + set(tnn2pnnx_SRCS + ${pnnx_pass_tnn_SRCS} + load_tnn.cpp + ) + + add_library(tnn2pnnx OBJECT ${tnn2pnnx_SRCS}) + target_compile_definitions(tnn2pnnx PRIVATE BUILD_TNN2PNNX) + target_compile_options(tnn2pnnx PUBLIC "${TORCH_CXX_FLAGS}") + + message(STATUS "Building with tnn2pnnx") +else() + message(STATUS "Building without tnn2pnnx") +endif() + if(NOT MSVC) add_definitions(-Wall -Wextra) endif() @@ -765,6 +788,11 @@ if(onnxruntime_FOUND) target_link_libraries(pnnx PRIVATE onnx2pnnx) endif() +if(PNNX_TNN2PNNX) + set_property(SOURCE main.cpp APPEND PROPERTY COMPILE_DEFINITIONS BUILD_TNN2PNNX) + target_link_libraries(pnnx PRIVATE tnn2pnnx) +endif() + if(PNNX_COVERAGE) target_compile_options(pnnx PUBLIC -coverage -fprofile-arcs -ftest-coverage) target_link_libraries(pnnx PUBLIC -coverage -lgcov) diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 5fcf9916f4b..0fd83233175 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -1030,6 +1030,7 @@ static std::string expand_expression(const Operator* op) || t == "ceil" || t == "cos" || t == "cosh" + || t == "erf" || t == "exp" || t == "floor" || t == "log" @@ -1062,6 +1063,7 @@ static std::string expand_expression(const Operator* op) if (t == "ceil") unaryop = "torch.ceil"; if (t == "cos") unaryop = "torch.cos"; if (t == "cosh") unaryop = "torch.cosh"; + if (t == "erf") unaryop = "torch.erf"; if (t == "exp") unaryop = "torch.exp"; if (t == "floor") unaryop = "torch.floor"; if (t == "log") unaryop = "torch.log"; @@ -2253,11 +2255,17 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) { if (op->type == "Tensor.index_put" && it.first == "values") { - fprintf(pyfp, "torch.tensor(%f)", param.f); + if (param.f == (int)param.f) + fprintf(pyfp, 
"torch.tensor(%.1f)", param.f); + else + fprintf(pyfp, "torch.tensor(%g)", param.f); } else { - fprintf(pyfp, "%f", param.f); + if (param.f == (int)param.f) + fprintf(pyfp, "%.1f", param.f); + else + fprintf(pyfp, "%g", param.f); } } if (param.type == 4) @@ -2316,7 +2324,10 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) fprintf(pyfp, "("); for (size_t i = 0; i < param.af.size(); i++) { - fprintf(pyfp, "%f", param.af[i]); + if (param.af[i] == (int)param.af[i]) + fprintf(pyfp, "%.1f", param.af[i]); + else + fprintf(pyfp, "%g", param.af[i]); if (i + 1 != param.af.size() || param.af.size() == 1) fprintf(pyfp, ","); } diff --git a/tools/pnnx/src/ir.h b/tools/pnnx/src/ir.h index 779c2eec9f1..6ab52eb0a21 100644 --- a/tools/pnnx/src/ir.h +++ b/tools/pnnx/src/ir.h @@ -234,6 +234,9 @@ class Attribute #if BUILD_ONNX2PNNX Attribute(const onnx::TensorProto& t); #endif +#if BUILD_TNN2PNNX + Attribute(FILE* bp); +#endif Attribute(const std::initializer_list& shape, const std::vector& t); diff --git a/tools/pnnx/src/load_tnn.cpp b/tools/pnnx/src/load_tnn.cpp new file mode 100644 index 00000000000..2e0d8bb683c --- /dev/null +++ b/tools/pnnx/src/load_tnn.cpp @@ -0,0 +1,655 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "load_tnn.h" + +#include "ir.h" + +#include +#include +#include + +#include "pass_tnn/fuse_shape_size.h" +#include "pass_tnn/fuse_shape_list_construct.h" +#include "pass_tnn/lower_concat.h" +#include "pass_tnn/lower_convolution_activation.h" +#include "pass_tnn/lower_power.h" + +namespace pnnx { + +static bool vstr_is_float(const char vstr[16]) +{ + // look ahead for determine isfloat + for (int j = 0; j < 16; j++) + { + if (vstr[j] == '\0') + break; + + if (vstr[j] == '.' || tolower(vstr[j]) == 'e') + return true; + } + + return false; +} + +static float vstr_to_float(const char vstr[16]) +{ + double v = 0.0; + + const char* p = vstr; + + // sign + bool sign = *p != '-'; + if (*p == '+' || *p == '-') + { + p++; + } + + // digits before decimal point or exponent + unsigned int v1 = 0; + while (isdigit(*p)) + { + v1 = v1 * 10 + (*p - '0'); + p++; + } + + v = (double)v1; + + // digits after decimal point + if (*p == '.') + { + p++; + + unsigned int pow10 = 1; + unsigned int v2 = 0; + + while (isdigit(*p)) + { + v2 = v2 * 10 + (*p - '0'); + pow10 *= 10; + p++; + } + + v += v2 / (double)pow10; + } + + // exponent + if (*p == 'e' || *p == 'E') + { + p++; + + // sign of exponent + bool fact = *p != '-'; + if (*p == '+' || *p == '-') + { + p++; + } + + // digits of exponent + unsigned int expon = 0; + while (isdigit(*p)) + { + expon = expon * 10 + (*p - '0'); + p++; + } + + double scale = 1.0; + while (expon >= 8) + { + scale *= 1e8; + expon -= 8; + } + while (expon > 0) + { + scale *= 10.0; + expon -= 1; + } + + v = fact ? v * scale : v / scale; + } + + // fprintf(stderr, "v = %f\n", v); + return sign ? 
(float)v : (float)-v; +} + +static size_t type_to_elemsize(int type) +{ + if (type == 1) return 4; + if (type == 2) return 8; + if (type == 3) return 2; + if (type == 4) return 4; + if (type == 5) return 8; + if (type == 6) return 2; + if (type == 7) return 1; + if (type == 8) return 1; + if (type == 9) return 1; + if (type == 10) return 8; + if (type == 11) return 16; + if (type == 12) return 4; + return 0; // null +} + +static int get_tnn_tensor_type(int dt) +{ + if (dt == 0) return 1; // fp32 + if (dt == 1) return 3; // fp16 + if (dt == 2) return 7; // int8 + if (dt == 3) return 4; // int32 + if (dt == 4) return 13; // bf16 + + fprintf(stderr, "unsupported tnn tensor type %d\n", dt); + return 0; // unknown type +} + +Attribute::Attribute(FILE* bp) +{ + unsigned int magic; + int datatype; + int length; + int ndim; + fread(&magic, 1, sizeof(unsigned int), bp); + fread(&datatype, 1, sizeof(int), bp); + fread(&length, 1, sizeof(int), bp); + fread(&ndim, 1, sizeof(int), bp); + + type = get_tnn_tensor_type(datatype); + + if (ndim == 0) + { + shape = {1}; + + data.resize(type_to_elemsize(type)); + + // assert length == type_to_elemsize(type) + fread((void*)data.data(), 1, length, bp); + + return; + } + + shape.resize(ndim); + for (int i = 0; i < ndim; i++) + { + fread(&shape[i], 1, sizeof(int), bp); + } + + data.resize(elemcount() * type_to_elemsize(type)); + + // assert length == elemcount() * type_to_elemsize(type) + fread((void*)data.data(), 1, length, bp); +} + +int load_tnn(const std::string& tnnpath, Graph& pnnx_graph) +{ + fprintf(stderr, "############# pass_level0 tnn\n"); + + // generate proto and model path + std::string tnnprotopath = tnnpath; + std::string tnnmodelpath = tnnpath.substr(0, tnnpath.size() - 8) + "tnnmodel"; + + fprintf(stderr, "load_tnn %s %s\n", tnnprotopath.c_str(), tnnmodelpath.c_str()); + + FILE* pp = fopen(tnnprotopath.c_str(), "rb"); + if (!pp) + { + fprintf(stderr, "fopen %s failed\n", tnnprotopath.c_str()); + return -1; + } + + char 
line[4096]; + + // "1 57 1 4206624772 ," + fgets(line, 4096, pp); + unsigned int proto_magic = 4206624772; + { + // strip leading and tail double quote + line[strlen(line) - 2] = '\0'; + line[0] = '\0'; + const char* pline = line + 1; + + sscanf(pline, "%*d %*d %*d %u", &proto_magic); + if (proto_magic != 4206624772) + { + fprintf(stderr, "wrong magic %u\n", proto_magic); + } + } + + // "input 2 1 80000 0 ," + fgets(line, 4096, pp); + { + // strip leading and tail double quote + line[strlen(line) - 2] = '\0'; + line[0] = '\0'; + const char* pline = line + 1; + + // input operand name + // rank 2 + // shape (1, 80000) + // datatype 0=fp32 + + int ncomsumed = 0; + char blob_name[32]; + int rank = 0; + sscanf(pline, "%s %d%n", blob_name, &rank, &ncomsumed); + + pline += ncomsumed; + + std::vector shape(rank); + for (int i = 0; i < rank; i++) + { + sscanf(pline, "%d%n", &shape[i], &ncomsumed); + + pline += ncomsumed; + } + + int datatype = 0; + sscanf(pline, "%d%n", &datatype, &ncomsumed); + + Operator* op = pnnx_graph.new_operator("pnnx.Input", "input0"); + + Operand* r = pnnx_graph.new_operand(blob_name); + + r->producer = op; + + r->shape = shape; + r->type = get_tnn_tensor_type(datatype); + + op->outputs.push_back(r); + } + + // skip the very long operand names + // " 108 109 ........ 
clipwise_output embedding input ," + fscanf(pp, "%*[^,]"); + fgets(line, 4096, pp); + + // all output names + // "clipwise_output embedding ," + fgets(line, 4096, pp); + std::vector output_names; + { + // strip leading and tail double quote + line[strlen(line) - 2] = '\0'; + line[0] = '\0'; + const char* pline = line + 1; + + int ncomsumed = 0; + + while (1) + { + char blob_name[32]; + sscanf(pline, "%s%n", blob_name, &ncomsumed); + + pline += ncomsumed; + + if (strcmp(blob_name, ",") == 0) + break; + + fprintf(stderr, "blob %s\n", blob_name); + + output_names.push_back(blob_name); + } + } + + // layer count + // " 56 ," + fgets(line, 4096, pp); + int layer_count = 0; + { + // strip leading and tail double quote + line[strlen(line) - 2] = '\0'; + line[0] = '\0'; + const char* pline = line + 1; + + sscanf(pline, "%d", &layer_count); + + if (layer_count == 0) + { + fprintf(stderr, "wrong layer_count %d\n", layer_count); + } + } + + for (int i = 0; i < layer_count; i++) + { + // "Unsqueeze Unsqueeze_0 1 1 input 85 1 1 ," + fgets(line, 4096, pp); + + // strip leading and tail double quote + line[strlen(line) - 2] = '\0'; + line[0] = '\0'; + const char* pline = line + 1; + + int ncomsumed = 0; + + char layer_type[32]; + char layer_name[32]; + int bottom_count; + int top_count; + sscanf(pline, "%s %s %d %d%n", layer_type, layer_name, &bottom_count, &top_count, &ncomsumed); + + pline += ncomsumed; + + // fprintf(stderr, "%s %s %d %d\n", layer_type, layer_name, bottom_count, top_count); + + Operator* op = pnnx_graph.new_operator(std::string("tnn.") + layer_type, layer_name); + + for (int j = 0; j < bottom_count; j++) + { + char blob_name[32]; + sscanf(pline, "%s%n", blob_name, &ncomsumed); + + pline += ncomsumed; + + // fprintf(stderr, " bottom %s\n", blob_name); + + Operand* r = pnnx_graph.get_operand(blob_name); + if (!r) + { + // insert constant producer + Operator* op_constant = pnnx_graph.new_operator_before("pnnx.Attribute", blob_name, op); + + r = 
pnnx_graph.new_operand(blob_name); + + // op_constant->attrs["data"] = attrs[j]; + op_constant->outputs.push_back(r); + + r->producer = op_constant; + } + r->consumers.push_back(op); + op->inputs.push_back(r); + } + + for (int j = 0; j < top_count; j++) + { + char blob_name[32]; + sscanf(pline, "%s%n", blob_name, &ncomsumed); + + pline += ncomsumed; + + // fprintf(stderr, " top %s\n", blob_name); + + Operand* r = pnnx_graph.get_operand(blob_name); + if (!r) + { + r = pnnx_graph.new_operand(blob_name); + } + r->producer = op; + op->outputs.push_back(r); + } + + // layer specific data + // Unsqueeze 1 1 , + // Convolution1D 1 1 257 512 160 0 0 0 -1 1 0 , + + int param_id = 0; + while (1) + { + char vstr[16]; + sscanf(pline, "%s%n", vstr, &ncomsumed); + + pline += ncomsumed; + + if (strcmp(vstr, ",") == 0) + break; + + // fprintf(stderr, "vstr %s\n", vstr); + + bool is_float = vstr_is_float(vstr); + + if (is_float) + { + float v = vstr_to_float(vstr); + + op->params[std::string("arg") + std::to_string(param_id)] = v; + } + else + { + int v = 0; + int nscan = sscanf(vstr, "%d", &v); + if (nscan == 1) + { + op->params[std::string("arg") + std::to_string(param_id)] = v; + } + else + { + // fallback to string type + op->params[std::string("arg") + std::to_string(param_id)] = vstr; + } + } + + param_id++; + } + } + + // append output nodes + const int output_count = (int)output_names.size(); + for (int i = 0; i < output_count; i++) + { + Operator* op = pnnx_graph.new_operator("pnnx.Output", "output" + std::to_string(i)); + + Operand* r = pnnx_graph.get_operand(output_names[i]); + + r->consumers.push_back(op); + + // fprintf(stderr, "r->name = %s\n", r->name.c_str()); + + op->inputs.push_back(r); + } + + fclose(pp); + + FILE* bp = fopen(tnnmodelpath.c_str(), "rb"); + if (!bp) + { + fprintf(stderr, "fopen %s failed\n", tnnmodelpath.c_str()); + return -1; + } + + // magic 0xfabc0004 + unsigned int model_magic; + fread(&model_magic, 1, sizeof(unsigned int), bp); + if 
(model_magic != 0xfabc0004) + { + fprintf(stderr, "model_magic %x failed\n", model_magic); + return -1; + } + + int weight_count = 0; + fread(&weight_count, 1, sizeof(int), bp); + + fprintf(stderr, "weight_count = %d\n", weight_count); + + std::unordered_map op_map; + for (auto x : pnnx_graph.ops) + { + op_map[x->name] = x; + } + + for (int i = 0; i < weight_count; i++) + { + int opid; + fread(&opid, 1, sizeof(int), bp); + + int type_size; + std::string type; + fread(&type_size, 1, sizeof(int), bp); + type.resize(type_size); + fread((void*)type.data(), 1, type_size, bp); + + int name_size; + std::string name; + fread(&name_size, 1, sizeof(int), bp); + name.resize(name_size); + fread((void*)name.data(), 1, name_size, bp); + + fprintf(stderr, "model %d %s %s\n", opid, type.c_str(), name.c_str()); + + Operator* op = op_map.at(name); + + std::vector attrs; + + if (type == "Add" || type == "Sub" || type == "Mul" || type == "Div") + { + attrs.push_back(Attribute(bp)); + } + if (type == "BatchNormCxx") + { + attrs.push_back(Attribute(bp)); + attrs.push_back(Attribute(bp)); + } + if (type == "ConstantOfShape") + { + attrs.push_back(Attribute(bp)); + } + if (type == "Convolution1D" || type == "Convolution") + { + // skip name2 == name + int name2_size; + std::string name2; + fread(&name2_size, 1, sizeof(int), bp); + name2.resize(name2_size); + fread((void*)name2.data(), 1, name2_size, bp); + + // bias + int bias; + fread(&bias, 1, sizeof(int), bp); + + attrs.push_back(Attribute(bp)); + if (bias) + { + attrs.push_back(Attribute(bp)); + } + } + if (type == "Gather") + { + // data_in_resource + int data_in_resource; + fread(&data_in_resource, 1, sizeof(int), bp); + + if (data_in_resource) + { + attrs.push_back(Attribute(bp)); + } + + // indices_in_resource + int indices_in_resource; + fread(&indices_in_resource, 1, sizeof(int), bp); + + if (indices_in_resource) + { + attrs.push_back(Attribute(bp)); + } + } + if (type == "InnerProduct") + { + // skip name2 == name + int 
name2_size; + std::string name2; + fread(&name2_size, 1, sizeof(int), bp); + name2.resize(name2_size); + fread((void*)name2.data(), 1, name2_size, bp); + + attrs.push_back(Attribute(bp)); + attrs.push_back(Attribute(bp)); + } + if (type == "MatMul") + { + attrs.push_back(Attribute(bp)); + } + + const int attribute_count = (int)attrs.size(); + + for (int j = 0; j < attribute_count; j++) + { + Operator* op_constant = pnnx_graph.new_operator_before("pnnx.Attribute", name + "_attr" + std::to_string(j), op); + Operand* r0 = pnnx_graph.new_operand(name + "_attr" + std::to_string(j)); + op_constant->attrs["data"] = attrs[j]; + op_constant->outputs.push_back(r0); + r0->producer = op_constant; + r0->consumers.push_back(op); + op->inputs.push_back(r0); + } + } + + // magic 0xfabc0004 + // unsigned int model_magic; + fread(&model_magic, 1, sizeof(unsigned int), bp); + if (model_magic != 0xfabc0004) + { + fprintf(stderr, "model_magic %x failed\n", model_magic); + return -1; + } + + int constant_count = 0; + fread(&constant_count, 1, sizeof(int), bp); + + fprintf(stderr, "constant_count = %d\n", constant_count); + + // collect constants + for (int i = 0; i < constant_count; i++) + { + int name_size; + std::string name; + fread(&name_size, 1, sizeof(int), bp); + name.resize(name_size); + fread((void*)name.data(), 1, name_size, bp); + + fprintf(stderr, "model constant %s\n", name.c_str()); + + if (op_map.find(name) == op_map.end()) + { + // FIXME + Attribute skip(bp); + continue; + } + + Operator* op_constant = op_map.at(name); + + op_constant->attrs["data"] = Attribute(bp); + } + + fclose(bp); + + // replace simple operator + for (Operator* op : pnnx_graph.ops) + { + // unary + if (op->type == "tnn.Erf") op->type = "aten::erf"; + if (op->type == "tnn.Log") op->type = "aten::log"; + if (op->type == "tnn.ReLU") op->type = "aten::relu"; + if (op->type == "tnn.ReLU6") op->type = "aten::relu6"; + if (op->type == "tnn.Sigmoid") op->type = "aten::sigmoid"; + if (op->type == "tnn.Sqrt") 
op->type = "aten::sqrt"; + if (op->type == "tnn.Tanh") op->type = "aten::tanh"; + + // binary + if (op->type == "tnn.Add") op->type = "aten::add"; + if (op->type == "tnn.Sub") op->type = "aten::sub"; + if (op->type == "tnn.Mul") op->type = "aten::mul"; + if (op->type == "tnn.Div") op->type = "aten::div"; + + // misc + } + + tnn2pnnx::fuse_shape_size(pnnx_graph); + tnn2pnnx::fuse_shape_list_construct(pnnx_graph); + + tnn2pnnx::lower_convolution_activation(pnnx_graph); + + tnn2pnnx::lower_power(pnnx_graph); + + tnn2pnnx::lower_concat(pnnx_graph); + + return 0; +} + +} // namespace pnnx diff --git a/tools/pnnx/src/load_tnn.h b/tools/pnnx/src/load_tnn.h new file mode 100644 index 00000000000..6b7075e8555 --- /dev/null +++ b/tools/pnnx/src/load_tnn.h @@ -0,0 +1,26 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#ifndef PNNX_LOAD_TNN_H +#define PNNX_LOAD_TNN_H + +#include "ir.h" + +namespace pnnx { + +int load_tnn(const std::string& tnnpath, Graph& g); + +} // namespace pnnx + +#endif // PNNX_LOAD_TNN_H diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index 8c5c83b891c..35211a2fa15 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -32,6 +32,9 @@ #if BUILD_ONNX2PNNX #include "load_onnx.h" #endif +#if BUILD_TNN2PNNX +#include "load_tnn.h" +#endif #include "pass_ncnn.h" #include "save_ncnn.h" @@ -173,6 +176,34 @@ static bool model_file_maybe_torchscript(const std::string& path) return signature == 0x04034b50; } +static bool model_file_maybe_tnnproto(const std::string& path) +{ + FILE* fp = fopen(path.c_str(), "rb"); + if (!fp) + { + fprintf(stderr, "open failed %s\n", path.c_str()); + return false; + } + + char line[256]; + char* s = fgets(line, 256, fp); + if (!s) + { + fclose(fp); + return false; + } + + uint32_t signature = 0; + if (line[0] == '\"') + { + sscanf(line + 1, "%*d %*d %*d %d", &signature); + } + + fclose(fp); + + return signature == 4206624772; +} + static void show_usage() { fprintf(stderr, "Usage: pnnx [model.pt] [(key=value)...]\n"); @@ -313,6 +344,17 @@ int main(int argc, char** argv) std::string foldable_constants_zippath = ptbase + ".foldable_constants.zip"; pnnx::Graph pnnx_graph; + + // clang-format off + // *INDENT-OFF* + +#if BUILD_TNN2PNNX + if (model_file_maybe_tnnproto(ptpath)) + { + load_tnn(ptpath, pnnx_graph); + } + else +#endif #if BUILD_ONNX2PNNX if (!model_file_maybe_torchscript(ptpath)) { @@ -330,6 +372,9 @@ int main(int argc, char** argv) foldable_constants_zippath, foldable_constants); } + // *INDENT-ON* + // clang-format on + fprintf(stderr, "############# pass_level2\n"); pnnx::pass_level2(pnnx_graph); diff --git a/tools/pnnx/src/pass_level2/F_adaptive_avg_pool2d.cpp b/tools/pnnx/src/pass_level2/F_adaptive_avg_pool2d.cpp index 72fed050263..5a6a31909fc 100644 --- 
a/tools/pnnx/src/pass_level2/F_adaptive_avg_pool2d.cpp +++ b/tools/pnnx/src/pass_level2/F_adaptive_avg_pool2d.cpp @@ -64,4 +64,81 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_adaptive_avg_pool2d_onnx, 120) +class F_adaptive_avg_pool2d_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.Pooling op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.adaptive_avg_pool2d"; + } + + bool match(const std::map& captured_params) const + { + const int pool_type = captured_params.at("op_0.arg0").i; + if (pool_type != 1) + return false; + + const int pad_h = captured_params.at("op_0.arg5").i; + const int pad_w = captured_params.at("op_0.arg6").i; + if (pad_h != 0 || pad_w != 0) + return false; + + const int kernel_h = captured_params.at("op_0.arg1").i; + const int kernel_w = captured_params.at("op_0.arg2").i; + if (kernel_h == 0 && kernel_w == 0) + return true; + + if (captured_params.find("op_0.arg11") != captured_params.end()) + { + const int is_adaptive = captured_params.at("op_0.arg11").i; + if (is_adaptive == 1) + return true; + } + + return false; + } + + void write(Operator* op, const std::map& captured_params) const + { + // captured_params.at("op_0.arg0"); // pool_type + const int kernel_h = captured_params.at("op_0.arg1").i; + const int kernel_w = captured_params.at("op_0.arg2").i; + if (kernel_h == 0 && kernel_w == 0) + { + // global pooling + op->params["output_size"] = {1, 1}; + } + + // captured_params.at("op_0.arg3"); // stride_h + // captured_params.at("op_0.arg4"); // stride_w + // captured_params.at("op_0.arg5"); // pad_h + // captured_params.at("op_0.arg6"); // pad_w + // captured_params.at("op_0.arg7"); // kernel_index_h + // captured_params.at("op_0.arg8"); // kernel_index_w + // captured_params.at("op_0.arg9"); // pad_type + // captured_params.at("op_0.arg10"); // 
ceil_mode + + if (captured_params.find("op_0.arg11") != captured_params.end()) + { + const int is_adaptive = captured_params.at("op_0.arg11").i; + if (is_adaptive == 1) + { + op->params["output_size"] = {captured_params.at("op_0.arg12").i, captured_params.at("op_0.arg13").i}; + } + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_adaptive_avg_pool2d_tnn, 120) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_adaptive_max_pool2d.cpp b/tools/pnnx/src/pass_level2/F_adaptive_max_pool2d.cpp index 79776cae7f6..72bfc6aaba7 100644 --- a/tools/pnnx/src/pass_level2/F_adaptive_max_pool2d.cpp +++ b/tools/pnnx/src/pass_level2/F_adaptive_max_pool2d.cpp @@ -43,4 +43,81 @@ pnnx.Output output 2 0 out indices REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_adaptive_max_pool2d, 120) +class F_adaptive_max_pool2d_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.Pooling op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.adaptive_max_pool2d"; + } + + bool match(const std::map& captured_params) const + { + const int pool_type = captured_params.at("op_0.arg0").i; + if (pool_type != 0) + return false; + + const int pad_h = captured_params.at("op_0.arg5").i; + const int pad_w = captured_params.at("op_0.arg6").i; + if (pad_h != 0 || pad_w != 0) + return false; + + const int kernel_h = captured_params.at("op_0.arg1").i; + const int kernel_w = captured_params.at("op_0.arg2").i; + if (kernel_h == 0 && kernel_w == 0) + return true; + + if (captured_params.find("op_0.arg11") != captured_params.end()) + { + const int is_adaptive = captured_params.at("op_0.arg11").i; + if (is_adaptive == 1) + return true; + } + + return false; + } + + void write(Operator* op, const std::map& captured_params) const + { + // captured_params.at("op_0.arg0"); // pool_type + const int kernel_h = captured_params.at("op_0.arg1").i; + const 
int kernel_w = captured_params.at("op_0.arg2").i; + if (kernel_h == 0 && kernel_w == 0) + { + // global pooling + op->params["output_size"] = {1, 1}; + } + + // captured_params.at("op_0.arg3"); // stride_h + // captured_params.at("op_0.arg4"); // stride_w + // captured_params.at("op_0.arg5"); // pad_h + // captured_params.at("op_0.arg6"); // pad_w + // captured_params.at("op_0.arg7"); // kernel_index_h + // captured_params.at("op_0.arg8"); // kernel_index_w + // captured_params.at("op_0.arg9"); // pad_type + // captured_params.at("op_0.arg10"); // ceil_mode + + if (captured_params.find("op_0.arg11") != captured_params.end()) + { + const int is_adaptive = captured_params.at("op_0.arg11").i; + if (is_adaptive == 1) + { + op->params["output_size"] = {captured_params.at("op_0.arg12").i, captured_params.at("op_0.arg13").i}; + } + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_adaptive_max_pool2d_tnn, 120) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_avg_pool2d.cpp b/tools/pnnx/src/pass_level2/F_avg_pool2d.cpp index da03102c6d8..d9fb11ac983 100644 --- a/tools/pnnx/src/pass_level2/F_avg_pool2d.cpp +++ b/tools/pnnx/src/pass_level2/F_avg_pool2d.cpp @@ -205,4 +205,75 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_avg_pool2d_onnx, 120) +class F_avg_pool2d_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.Pooling op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.avg_pool2d"; + } + + bool match(const std::map& captured_params) const + { + const int pool_type = captured_params.at("op_0.arg0").i; + if (pool_type != 1) + return false; + + const int kernel_h = captured_params.at("op_0.arg1").i; + const int kernel_w = captured_params.at("op_0.arg2").i; + if (kernel_h == 0 && kernel_w == 0) + return false; + + if (captured_params.find("op_0.arg11") != 
captured_params.end()) + { + const int is_adaptive = captured_params.at("op_0.arg11").i; + if (is_adaptive != 0) + return false; + } + + return true; + } + + void write(Operator* op, const std::map& captured_params) const + { + // captured_params.at("op_0.arg0"); // pool_type + op->params["kernel_size"] = {captured_params.at("op_0.arg1").i, captured_params.at("op_0.arg2").i}; + op->params["stride"] = {captured_params.at("op_0.arg3").i, captured_params.at("op_0.arg4").i}; + op->params["padding"] = {captured_params.at("op_0.arg5").i, captured_params.at("op_0.arg6").i}; + + const int kernel_index_h = captured_params.at("op_0.arg7").i; + const int kernel_index_w = captured_params.at("op_0.arg8").i; + if (kernel_index_h != -1 || kernel_index_w != -1) + { + fprintf(stderr, "unsupported F.avg_pool2d kernel_index %d %d\n", kernel_index_h, kernel_index_w); + } + + const int pad_type = captured_params.at("op_0.arg9").i; + if (pad_type > 0) + { + fprintf(stderr, "unsupported F.avg_pool2d pad_type %d\n", pad_type); + } + + op->params["ceil_mode"] = captured_params.at("op_0.arg10").i ? 
true : false; + // captured_params.at("op_0.arg11"); // is_adaptive + // captured_params.at("op_0.arg12"); // output_h + // captured_params.at("op_0.arg13"); // output_w + + op->params["count_include_pad"] = false; + op->params["divisor_override"] = Parameter(); + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_avg_pool2d_tnn, 120) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_batch_norm.cpp b/tools/pnnx/src/pass_level2/F_batch_norm.cpp index f4ead03505b..4a82809ff23 100644 --- a/tools/pnnx/src/pass_level2/F_batch_norm.cpp +++ b/tools/pnnx/src/pass_level2/F_batch_norm.cpp @@ -121,4 +121,51 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_batch_norm_onnx, 130) +class F_batch_norm_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +pnnx.Attribute op_0 0 1 weight @data=(%num_features)f32 +pnnx.Attribute op_1 0 1 bias @data=(%num_features)f32 +tnn.BatchNormCxx op_2 3 1 input weight bias out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + return R"PNNXIR(7767517 +7 6 +pnnx.Input input 0 1 input +pnnx.Attribute mean 0 1 running_mean +pnnx.Attribute var 0 1 running_var +pnnx.Attribute weight 0 1 weight @data=%op_0.data +pnnx.Attribute bias 0 1 bias @data=%op_1.data +F.batch_norm bn 5 1 input running_mean running_var weight bias out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + void write(const std::map& ops, const std::map& captured_params) const + { + const int num_features = captured_params.at("num_features").i; + + Operator* op_mean = ops.at("mean"); + op_mean->attrs["data"] = Attribute({num_features}, std::vector(num_features, 0.f)); + + Operator* op_var = ops.at("var"); + op_var->attrs["data"] = Attribute({num_features}, std::vector(num_features, 1.f)); + + Operator* op_bn = ops.at("bn"); + op_bn->params["eps"] = 0.f; + op_bn->inputnames = {"input", "running_mean", "running_var", 
"weight", "bias"}; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_batch_norm_tnn, 130) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_conv1d.cpp b/tools/pnnx/src/pass_level2/F_conv1d.cpp index cdc503d5345..c7b2858173c 100644 --- a/tools/pnnx/src/pass_level2/F_conv1d.cpp +++ b/tools/pnnx/src/pass_level2/F_conv1d.cpp @@ -195,4 +195,78 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv1d_onnx_1, 140) +class F_conv1d_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution1D op_0 3 1 input weight bias out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.conv1d"; + } + + bool match(const std::map& captured_params) const + { + if (captured_params.find("op_0.arg9") == captured_params.end()) + return true; + + const int activation = captured_params.at("op_0.arg9").i; + return activation == 0; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["groups"] = captured_params.at("op_0.arg0"); + // captured_params.at("op_0.arg1"); // inch + // captured_params.at("op_0.arg2"); // outch + // captured_params.at("op_0.arg3"); // kernel_size + op->params["stride"] = {captured_params.at("op_0.arg4").i}; + op->params["padding"] = {captured_params.at("op_0.arg5").i}; + // captured_params.at("op_0.arg6"); // bias + // captured_params.at("op_0.arg7"); // pad_type + op->params["dilation"] = {captured_params.at("op_0.arg8").i}; + if (op->params["dilation"].ai == std::vector{-1}) + { + op->params["dilation"] = {1}; + } + // captured_params.at("op_0.arg9"); // activation + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv1d_tnn, 140) + +class F_conv1d_tnn_1 : public F_conv1d_tnn +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input 
input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution1D op_0 2 1 input weight out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + void write(Operator* op, const std::map& captured_params) const + { + F_conv1d_tnn::write(op, captured_params); + + op->params["bias"] = Parameter(); + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv1d_tnn_1, 140) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_conv2d.cpp b/tools/pnnx/src/pass_level2/F_conv2d.cpp index 806cc44af9e..dfee3b253e1 100644 --- a/tools/pnnx/src/pass_level2/F_conv2d.cpp +++ b/tools/pnnx/src/pass_level2/F_conv2d.cpp @@ -305,4 +305,79 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv2d_onnx_1, 140) +class F_conv2d_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution op_0 3 1 input weight bias out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.conv2d"; + } + + bool match(const std::map& captured_params) const + { + if (captured_params.find("op_0.arg13") == captured_params.end()) + return true; + + const int activation = captured_params.at("op_0.arg13").i; + return activation == 0; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["groups"] = captured_params.at("op_0.arg0"); + // captured_params.at("op_0.arg1"); // inch + // captured_params.at("op_0.arg2"); // outch + // captured_params.at("op_0.arg3"); // kernel_h + // captured_params.at("op_0.arg4"); // kernel_w + op->params["stride"] = {captured_params.at("op_0.arg5").i, captured_params.at("op_0.arg6").i}; + op->params["padding"] = {captured_params.at("op_0.arg7").i, captured_params.at("op_0.arg8").i}; + // captured_params.at("op_0.arg9"); // bias + // captured_params.at("op_0.arg10"); // pad_type + op->params["dilation"] = 
{captured_params.at("op_0.arg11").i, captured_params.at("op_0.arg12").i}; + if (op->params["dilation"].ai == std::vector{-1, -1}) + { + op->params["dilation"] = {1, 1}; + } + // captured_params.at("op_0.arg13"); // activation + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv2d_tnn, 140) + +class F_conv2d_tnn_1 : public F_conv2d_tnn +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution op_0 2 1 input weight out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + void write(Operator* op, const std::map& captured_params) const + { + F_conv2d_tnn::write(op, captured_params); + + op->params["bias"] = Parameter(); + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_conv2d_tnn_1, 140) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_embedding.cpp b/tools/pnnx/src/pass_level2/F_embedding.cpp index 810f6b98cac..d869d897a7d 100644 --- a/tools/pnnx/src/pass_level2/F_embedding.cpp +++ b/tools/pnnx/src/pass_level2/F_embedding.cpp @@ -69,4 +69,32 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_embedding_onnx, 140) +class F_embedding_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Gather op_0 2 1 input weight out arg0=0 arg1=1 arg2=0 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.embedding"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params["scale_grad_by_freq"] = false; + op->params["sparse"] = false; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_embedding_tnn, 140) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_linear.cpp b/tools/pnnx/src/pass_level2/F_linear.cpp index 2b43fda9979..d796cb47188 100644 --- a/tools/pnnx/src/pass_level2/F_linear.cpp +++ 
b/tools/pnnx/src/pass_level2/F_linear.cpp @@ -340,4 +340,27 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_linear_onnx_4, 110) +class F_linear_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight @data=(%in_features,%out_features)f32 +pnnx.Input input_2 0 1 bias @data=(%out_features)f32 +tnn.InnerProduct op_0 3 1 input weight bias out arg0=* arg1=* arg2=0 arg3=1 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.linear"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_linear_tnn, 140) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_max_pool2d.cpp b/tools/pnnx/src/pass_level2/F_max_pool2d.cpp index f3fbc681121..0fa98c79873 100644 --- a/tools/pnnx/src/pass_level2/F_max_pool2d.cpp +++ b/tools/pnnx/src/pass_level2/F_max_pool2d.cpp @@ -298,4 +298,74 @@ pnnx.Output output 2 0 out indices REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_max_pool2d_onnx_1, 120) +class F_max_pool2d_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.Pooling op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.max_pool2d"; + } + + bool match(const std::map& captured_params) const + { + const int pool_type = captured_params.at("op_0.arg0").i; + if (pool_type != 0) + return false; + + const int kernel_h = captured_params.at("op_0.arg1").i; + const int kernel_w = captured_params.at("op_0.arg2").i; + if (kernel_h == 0 && kernel_w == 0) + return false; + + if (captured_params.find("op_0.arg11") != captured_params.end()) + { + const int is_adaptive = captured_params.at("op_0.arg11").i; + if (is_adaptive != 0) + return false; + } + + return true; + } + + void write(Operator* op, const std::map& captured_params) const + { + // 
captured_params.at("op_0.arg0"); // pool_type + op->params["kernel_size"] = {captured_params.at("op_0.arg1").i, captured_params.at("op_0.arg2").i}; + op->params["stride"] = {captured_params.at("op_0.arg3").i, captured_params.at("op_0.arg4").i}; + op->params["padding"] = {captured_params.at("op_0.arg5").i, captured_params.at("op_0.arg6").i}; + + const int kernel_index_h = captured_params.at("op_0.arg7").i; + const int kernel_index_w = captured_params.at("op_0.arg8").i; + if (kernel_index_h != -1 || kernel_index_w != -1) + { + fprintf(stderr, "unsupported F.max_pool2d kernel_index %d %d\n", kernel_index_h, kernel_index_w); + } + + const int pad_type = captured_params.at("op_0.arg9").i; + if (pad_type > 0) + { + fprintf(stderr, "unsupported F.max_pool2d pad_type %d\n", pad_type); + } + + op->params["ceil_mode"] = captured_params.at("op_0.arg10").i ? true : false; + // captured_params.at("op_0.arg11"); // is_adaptive + // captured_params.at("op_0.arg12"); // output_h + // captured_params.at("op_0.arg13"); // output_w + + op->params["return_indices"] = false; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_max_pool2d_tnn, 120) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_pad.cpp b/tools/pnnx/src/pass_level2/F_pad.cpp index 795012ca516..5511a7d67c6 100644 --- a/tools/pnnx/src/pass_level2/F_pad.cpp +++ b/tools/pnnx/src/pass_level2/F_pad.cpp @@ -480,4 +480,62 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_pad_onnx_1, 110) +class F_pad_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.PadV2 op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.pad"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const int ndim = captured_params.at("op_0.arg0").i; + + std::vector pads(ndim * 2); + for (int i = 0; i < ndim; i++) + { + pads[(ndim - 1 
- i) * 2] = captured_params.at("op_0.arg" + std::to_string(i + 1)).i; + } + for (int i = 0; i < ndim; i++) + { + pads[(ndim - 1 - i) * 2 + 1] = captured_params.at("op_0.arg" + std::to_string(ndim + i + 1)).i; + } + + // strip zero pads for higher dims + // (3,3,0,0,0,0) to (3,3) + for (int i = ndim - 1; i >= 0; i--) + { + if (pads[i * 2] == 0 && pads[i * 2 + 1] == 0) + pads.resize(i * 2); + } + + op->params["pad"] = pads; + + const int type = captured_params.at("op_0.arg" + std::to_string(ndim * 2 + 1)).i; + if (type == 0) + { + op->params["mode"] = "constant"; + op->params["value"] = captured_params.at("op_0.arg" + std::to_string(ndim * 2 + 2)); + } + if (type == 1) + { + op->params["mode"] = "reflect"; + op->params["value"] = Parameter(); + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_pad_tnn, 110) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/F_softmax.cpp b/tools/pnnx/src/pass_level2/F_softmax.cpp index 5497b9e176f..d61213bf3bb 100644 --- a/tools/pnnx/src/pass_level2/F_softmax.cpp +++ b/tools/pnnx/src/pass_level2/F_softmax.cpp @@ -133,4 +133,25 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softmax_onnx_1, 100) +class F_softmax_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.SoftmaxCaffe op_0 1 1 input out arg0=%dim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "F.softmax"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_softmax_tnn, 100) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/Tensor_expand.cpp b/tools/pnnx/src/pass_level2/Tensor_expand.cpp index 0930bc51204..f3a7d0275e5 100644 --- a/tools/pnnx/src/pass_level2/Tensor_expand.cpp +++ b/tools/pnnx/src/pass_level2/Tensor_expand.cpp @@ -109,4 +109,26 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_expand_onnx, 60) +class Tensor_expand_tnn : public 
GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 shape +tnn.Expand op_0 2 1 input shape out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.expand"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_expand_tnn, 61) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/Tensor_expand_as.cpp b/tools/pnnx/src/pass_level2/Tensor_expand_as.cpp index ca726b89faf..85557468445 100644 --- a/tools/pnnx/src/pass_level2/Tensor_expand_as.cpp +++ b/tools/pnnx/src/pass_level2/Tensor_expand_as.cpp @@ -38,4 +38,32 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_expand_as, 60) +class Tensor_expand_as_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 other +tnn.Shape op_0 1 1 other shape +tnn.Expand op_1 2 1 input shape out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.expand_as"; + } + + void write(Operator* op, const std::map& /*captured_params*/) const + { + op->params.clear(); + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_expand_as_tnn, 60) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/Tensor_permute.cpp b/tools/pnnx/src/pass_level2/Tensor_permute.cpp index e53f5a45bbc..7f55d00f636 100644 --- a/tools/pnnx/src/pass_level2/Tensor_permute.cpp +++ b/tools/pnnx/src/pass_level2/Tensor_permute.cpp @@ -82,4 +82,36 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_permute_onnx, 60) +class Tensor_permute_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.Permute op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* 
type_str() const + { + return "Tensor.permute"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const int dims_count = captured_params.at("op_0.arg0").i; + std::vector dims(dims_count); + for (int i = 0; i < dims_count; i++) + { + dims[i] = captured_params.at("op_0.arg" + std::to_string(i + 1)).i; + } + op->params["dims"] = dims; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_permute_tnn, 60) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/Tensor_reshape.cpp b/tools/pnnx/src/pass_level2/Tensor_reshape.cpp index 4261ed0a467..a9238d3e441 100644 --- a/tools/pnnx/src/pass_level2/Tensor_reshape.cpp +++ b/tools/pnnx/src/pass_level2/Tensor_reshape.cpp @@ -123,4 +123,118 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_reshape_onnx_2, 61) +class Tensor_reshape_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.Reshape op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.reshape"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const int axis = captured_params.at("op_0.arg0").i; + const int num_axes = captured_params.at("op_0.arg1").i; + const int shape_rank = captured_params.at("op_0.arg2").i; + + std::vector shape(shape_rank); + for (int i = 0; i < shape_rank; i++) + { + shape[i] = captured_params.at("op_0.arg" + std::to_string(i + 3)).i; + } + + const int reshape_type = captured_params.at("op_0.arg" + std::to_string(shape_rank + 3)).i; + + // HACK + if (shape == std::vector{0, -1, 0, 0}) + { + shape = {-1}; + } + + op->params["shape"] = shape; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_reshape_tnn, 60) + +class Tensor_reshape_tnn_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input 
+pnnx.Attribute shape 0 1 shape @data +tnn.Reshape op_0 2 1 input shape out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.reshape"; + } + + bool match(const std::map& /*captured_params*/, const std::map& captured_attrs) const + { + // one dim i32 + const auto& shape_data = captured_attrs.at("shape.data"); + return shape_data.shape.size() == 1 && shape_data.type == 4; + } + + void write(Operator* op, const std::map& /*captured_params*/, const std::map& captured_attrs) const + { + const auto& shape_data = captured_attrs.at("shape.data"); + const int* p = (const int*)shape_data.data.data(); + const int ndim = shape_data.data.size() / 4; + + std::vector shape(ndim); + for (int i = 0; i < ndim; i++) + { + shape[i] = p[i]; + } + + op->params["shape"] = shape; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_reshape_tnn_1, 60) + +class Tensor_reshape_tnn_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +pnnx.Input shape 0 1 shape +tnn.Reshape op_0 2 1 input shape out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.reshape"; + } + + void write(Operator* /*op*/, const std::map& /*captured_params*/) const + { + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_reshape_tnn_2, 61) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/Tensor_slice.cpp b/tools/pnnx/src/pass_level2/Tensor_slice.cpp index b610b29a256..8fd684c9d84 100644 --- a/tools/pnnx/src/pass_level2/Tensor_slice.cpp +++ b/tools/pnnx/src/pass_level2/Tensor_slice.cpp @@ -150,4 +150,78 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_slice_onnx_1, 70) +class Tensor_slice_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.StridedSliceV2 op_0 1 1 input out 
%*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.slice"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const int nbegins = captured_params.at("op_0.arg0").i; + std::vector begins(nbegins); + for (int i = 0; i < nbegins; i++) + { + begins[i] = captured_params.at("op_0.arg" + std::to_string(i + 1)).i; + } + const int nends = captured_params.at("op_0.arg" + std::to_string(nbegins + 1)).i; + std::vector ends(nends); + for (int i = 0; i < nends; i++) + { + ends[i] = captured_params.at("op_0.arg" + std::to_string(i + 2 + nbegins)).i; + } + const int naxes = captured_params.at("op_0.arg" + std::to_string(nbegins + nends + 2)).i; + std::vector axes(naxes); + for (int i = 0; i < naxes; i++) + { + axes[i] = captured_params.at("op_0.arg" + std::to_string(i + 3 + nbegins + nends)).i; + } + + std::vector strides; + if (captured_params.find("op_0.arg" + std::to_string(nbegins + nends + naxes + 3)) != captured_params.end()) + { + const int nstrides = captured_params.at("op_0.arg" + std::to_string(nbegins + nends + naxes + 3)).i; + strides.resize(nstrides); + for (int i = 0; i < nstrides; i++) + { + strides[i] = captured_params.at("op_0.arg" + std::to_string(i + 4 + nbegins + nends + naxes)).i; + } + } + else + { + strides.resize(naxes, 1); + } + + if (axes.size() == 1) + { + op->params["dim"] = axes[0]; + op->params["start"] = begins[0]; + op->params["end"] = ends[0]; + op->params["step"] = strides[0]; + } + else + { + op->params["dims"] = axes; + op->params["starts"] = begins; + op->params["ends"] = ends; + op->params["steps"] = strides; + op->params["selects"] = std::vector(axes.size(), INT_MAX); + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_slice_tnn, 70) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/Tensor_to.cpp b/tools/pnnx/src/pass_level2/Tensor_to.cpp index 81da30bdd68..2780a529693 100644 --- a/tools/pnnx/src/pass_level2/Tensor_to.cpp +++ 
b/tools/pnnx/src/pass_level2/Tensor_to.cpp @@ -162,4 +162,43 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to_onnx, 60) +class Tensor_to_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.Cast op_0 1 1 input out arg0=%to +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.to"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const int to = captured_params.at("to").i; + + op->params["non_blocking"] = false; + op->params["copy"] = false; + op->params["memory_format"] = "torch.preserve_format"; + + if (to == 0) op->params["dtype"] = "torch.float"; + if (to == 1) op->params["dtype"] = "torch.half"; + if (to == 2) op->params["dtype"] = "torch.int8"; + if (to == 3) op->params["dtype"] = "torch.int"; + if (to == 4) op->params["dtype"] = "torch.bfloat16"; + if (to == 5) op->params["dtype"] = "torch.long"; + if (to == 6) op->params["dtype"] = "torch.uint32"; + if (to == 8) op->params["dtype"] = "torch.uint8"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to_tnn, 60) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/nn_LSTM.cpp b/tools/pnnx/src/pass_level2/nn_LSTM.cpp index 56ac7ab981c..6b1134f039b 100644 --- a/tools/pnnx/src/pass_level2/nn_LSTM.cpp +++ b/tools/pnnx/src/pass_level2/nn_LSTM.cpp @@ -629,4 +629,284 @@ pnnx.Output output 1 0 out2 REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_LSTM_onnx_B5, 140) +class nn_LSTM_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +8 9 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 initial_h +pnnx.Input input_2 0 1 initial_c +pnnx.Attribute W 0 1 W @data +pnnx.Attribute R 0 1 R @data +pnnx.Attribute B 0 1 B @data +tnn.LSTMONNX lstm 6 3 input W R B initial_h initial_c out outh outc %*=%* +pnnx.Output output 3 0 out outh outc +)PNNXIR"; + } 
+ + const char* type_str() const + { + return "nn.LSTM"; + } + + const char* name_str() const + { + return "lstm"; + } + + bool match(const std::map& captured_params, const std::map& captured_attrs) const + { + // arg0=0, arg1=512, arg2=2 + + // captured_params.at("lstm.arg0"); // clip_threshold + + const int hidden_size = captured_params.at("lstm.arg1").i; + const int direction = captured_params.at("lstm.arg2").i; + + const int num_directions = direction == 2 ? 2 : 1; + + const auto& W = captured_attrs.at("W.data"); // 2,2048,512 + const auto& R = captured_attrs.at("R.data"); // 2,2048,512 + const auto& B = captured_attrs.at("B.data"); // 2,4096 + + if (W.shape.size() != 3 || W.shape[0] != num_directions || W.shape[1] != 4 * hidden_size) + return false; + + if (R.shape.size() != 3 || R.shape[0] != num_directions || R.shape[1] != 4 * hidden_size || R.shape[2] != hidden_size) + return false; + + if (B.shape.size() != 2 || B.shape[0] != num_directions || B.shape[1] != 8 * hidden_size) + return false; + + return true; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + const int hidden_size = captured_params.at("lstm.arg1").i; + const int direction = captured_params.at("lstm.arg2").i; + + const int num_directions = direction == 2 ? 2 : 1; + + const auto& W = captured_attrs.at("W.data"); + const auto& R = captured_attrs.at("R.data"); + const auto& B = captured_attrs.at("B.data"); + + const int input_size = W.shape[2]; + + op->params["input_size"] = input_size; + op->params["hidden_size"] = hidden_size; + op->params["num_layers"] = 1; + op->params["bias"] = false; + op->params["batch_first"] = false; + op->params["bidirectional"] = direction == 2 ? 
true : false; + op->params["proj_size"] = 0; + + // split W R and reorder IOFG to IFGO + auto W_data = W.get_float32_data(); + auto R_data = R.get_float32_data(); + + std::vector W2(4 * hidden_size * input_size); + { + const int weight_data_size_g = hidden_size * input_size; + + const float* iptr = (const float*)W_data.data(); + const float* optr = (const float*)W_data.data() + weight_data_size_g; + const float* fptr = (const float*)W_data.data() + weight_data_size_g * 2; + const float* gptr = (const float*)W_data.data() + weight_data_size_g * 3; + + float* w_iptr = (float*)W2.data(); + float* w_fptr = (float*)W2.data() + weight_data_size_g; + float* w_gptr = (float*)W2.data() + weight_data_size_g * 2; + float* w_optr = (float*)W2.data() + weight_data_size_g * 3; + + memcpy(w_iptr, iptr, weight_data_size_g * sizeof(float)); + memcpy(w_fptr, fptr, weight_data_size_g * sizeof(float)); + memcpy(w_gptr, gptr, weight_data_size_g * sizeof(float)); + memcpy(w_optr, optr, weight_data_size_g * sizeof(float)); + } + + std::vector R2(4 * hidden_size * hidden_size); + { + const int weight_data_size_g = hidden_size * hidden_size; + + const float* iptr = (const float*)R_data.data(); + const float* optr = (const float*)R_data.data() + weight_data_size_g; + const float* fptr = (const float*)R_data.data() + weight_data_size_g * 2; + const float* gptr = (const float*)R_data.data() + weight_data_size_g * 3; + + float* w_iptr = (float*)R2.data(); + float* w_fptr = (float*)R2.data() + weight_data_size_g; + float* w_gptr = (float*)R2.data() + weight_data_size_g * 2; + float* w_optr = (float*)R2.data() + weight_data_size_g * 3; + + memcpy(w_iptr, iptr, weight_data_size_g * sizeof(float)); + memcpy(w_fptr, fptr, weight_data_size_g * sizeof(float)); + memcpy(w_gptr, gptr, weight_data_size_g * sizeof(float)); + memcpy(w_optr, optr, weight_data_size_g * sizeof(float)); + } + + if (direction == 2) + { + op->attrs["weight_ih_l0"] = Attribute({4 * hidden_size, input_size}, W2); + 
op->attrs["weight_hh_l0"] = Attribute({4 * hidden_size, hidden_size}, R2); + + std::vector W2R(4 * hidden_size * input_size); + { + const int weight_data_size_g = hidden_size * input_size; + + const float* iptr = (const float*)W_data.data() + weight_data_size_g * 4; + const float* optr = (const float*)W_data.data() + weight_data_size_g * 5; + const float* fptr = (const float*)W_data.data() + weight_data_size_g * 6; + const float* gptr = (const float*)W_data.data() + weight_data_size_g * 7; + + float* w_iptr = (float*)W2R.data(); + float* w_fptr = (float*)W2R.data() + weight_data_size_g; + float* w_gptr = (float*)W2R.data() + weight_data_size_g * 2; + float* w_optr = (float*)W2R.data() + weight_data_size_g * 3; + + memcpy(w_iptr, iptr, weight_data_size_g * sizeof(float)); + memcpy(w_fptr, fptr, weight_data_size_g * sizeof(float)); + memcpy(w_gptr, gptr, weight_data_size_g * sizeof(float)); + memcpy(w_optr, optr, weight_data_size_g * sizeof(float)); + } + + std::vector R2R(4 * hidden_size * hidden_size); + { + const int weight_data_size_g = hidden_size * hidden_size; + + const float* iptr = (const float*)R_data.data() + weight_data_size_g * 4; + const float* optr = (const float*)R_data.data() + weight_data_size_g * 5; + const float* fptr = (const float*)R_data.data() + weight_data_size_g * 6; + const float* gptr = (const float*)R_data.data() + weight_data_size_g * 7; + + float* w_iptr = (float*)R2R.data(); + float* w_fptr = (float*)R2R.data() + weight_data_size_g; + float* w_gptr = (float*)R2R.data() + weight_data_size_g * 2; + float* w_optr = (float*)R2R.data() + weight_data_size_g * 3; + + memcpy(w_iptr, iptr, weight_data_size_g * sizeof(float)); + memcpy(w_fptr, fptr, weight_data_size_g * sizeof(float)); + memcpy(w_gptr, gptr, weight_data_size_g * sizeof(float)); + memcpy(w_optr, optr, weight_data_size_g * sizeof(float)); + } + + op->attrs["weight_ih_l0_reverse"] = Attribute({4 * hidden_size, input_size}, W2R); + op->attrs["weight_hh_l0_reverse"] = Attribute({4 * 
hidden_size, hidden_size}, R2R); + } + else + { + op->attrs["weight_ih_l0"] = Attribute({4 * hidden_size, input_size}, W2); + op->attrs["weight_hh_l0"] = Attribute({4 * hidden_size, hidden_size}, R2); + } + + bool has_bias = false; + for (auto b : B.get_float32_data()) + { + if (b != 0.f) + { + has_bias = true; + break; + } + } + + op->params["bias"] = has_bias; + + if (has_bias) + { + // split B and reorder IOFG to IFGO + auto B_data = B.get_float32_data(); + + std::vector B2(4 * hidden_size); + std::vector B3(4 * hidden_size); + { + const float* iptr = (const float*)B_data.data(); + const float* optr = (const float*)B_data.data() + hidden_size; + const float* fptr = (const float*)B_data.data() + hidden_size * 2; + const float* gptr = (const float*)B_data.data() + hidden_size * 3; + + float* w_iptr = (float*)B2.data(); + float* w_fptr = (float*)B2.data() + hidden_size; + float* w_gptr = (float*)B2.data() + hidden_size * 2; + float* w_optr = (float*)B2.data() + hidden_size * 3; + + memcpy(w_iptr, iptr, hidden_size * sizeof(float)); + memcpy(w_fptr, fptr, hidden_size * sizeof(float)); + memcpy(w_gptr, gptr, hidden_size * sizeof(float)); + memcpy(w_optr, optr, hidden_size * sizeof(float)); + } + { + const float* iptr = (const float*)B_data.data() + hidden_size * 4; + const float* optr = (const float*)B_data.data() + hidden_size * 5; + const float* fptr = (const float*)B_data.data() + hidden_size * 6; + const float* gptr = (const float*)B_data.data() + hidden_size * 7; + + float* w_iptr = (float*)B3.data(); + float* w_fptr = (float*)B3.data() + hidden_size; + float* w_gptr = (float*)B3.data() + hidden_size * 2; + float* w_optr = (float*)B3.data() + hidden_size * 3; + + memcpy(w_iptr, iptr, hidden_size * sizeof(float)); + memcpy(w_fptr, fptr, hidden_size * sizeof(float)); + memcpy(w_gptr, gptr, hidden_size * sizeof(float)); + memcpy(w_optr, optr, hidden_size * sizeof(float)); + } + + if (direction == 2) + { + op->attrs["bias_ih_l0"] = Attribute({4 * hidden_size}, B2); 
+ op->attrs["bias_hh_l0"] = Attribute({4 * hidden_size}, B3); + + std::vector B2R(4 * hidden_size); + std::vector B3R(4 * hidden_size); + { + const float* iptr = (const float*)B_data.data() + hidden_size * 8; + const float* optr = (const float*)B_data.data() + hidden_size * 9; + const float* fptr = (const float*)B_data.data() + hidden_size * 10; + const float* gptr = (const float*)B_data.data() + hidden_size * 11; + + float* w_iptr = (float*)B2R.data(); + float* w_fptr = (float*)B2R.data() + hidden_size; + float* w_gptr = (float*)B2R.data() + hidden_size * 2; + float* w_optr = (float*)B2R.data() + hidden_size * 3; + + memcpy(w_iptr, iptr, hidden_size * sizeof(float)); + memcpy(w_fptr, fptr, hidden_size * sizeof(float)); + memcpy(w_gptr, gptr, hidden_size * sizeof(float)); + memcpy(w_optr, optr, hidden_size * sizeof(float)); + } + { + const float* iptr = (const float*)B_data.data() + hidden_size * 12; + const float* optr = (const float*)B_data.data() + hidden_size * 13; + const float* fptr = (const float*)B_data.data() + hidden_size * 14; + const float* gptr = (const float*)B_data.data() + hidden_size * 15; + + float* w_iptr = (float*)B3R.data(); + float* w_fptr = (float*)B3R.data() + hidden_size; + float* w_gptr = (float*)B3R.data() + hidden_size * 2; + float* w_optr = (float*)B3R.data() + hidden_size * 3; + + memcpy(w_iptr, iptr, hidden_size * sizeof(float)); + memcpy(w_fptr, fptr, hidden_size * sizeof(float)); + memcpy(w_gptr, gptr, hidden_size * sizeof(float)); + memcpy(w_optr, optr, hidden_size * sizeof(float)); + } + + op->attrs["bias_ih_l0_reverse"] = Attribute({4 * hidden_size}, B2R); + op->attrs["bias_hh_l0_reverse"] = Attribute({4 * hidden_size}, B3R); + } + else + { + op->attrs["bias_ih_l0"] = Attribute({4 * hidden_size}, B2); + op->attrs["bias_hh_l0"] = Attribute({4 * hidden_size}, B3); + } + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(nn_LSTM_tnn, 140) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_clamp.cpp 
b/tools/pnnx/src/pass_level2/torch_clamp.cpp index ffe241e6895..db7e94521eb 100644 --- a/tools/pnnx/src/pass_level2/torch_clamp.cpp +++ b/tools/pnnx/src/pass_level2/torch_clamp.cpp @@ -114,4 +114,25 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_clamp_onnx_2, 40) +class torch_clamp_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.Clip op_0 1 1 input out arg0=%min arg1=%max +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.clamp"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_clamp_tnn, 40) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_full.cpp b/tools/pnnx/src/pass_level2/torch_full.cpp index ffea79b38de..1e77c60a930 100644 --- a/tools/pnnx/src/pass_level2/torch_full.cpp +++ b/tools/pnnx/src/pass_level2/torch_full.cpp @@ -93,4 +93,32 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_full_onnx, 21) +class torch_full_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 size +pnnx.Attribute value 0 1 value @data=(1)f32 +tnn.ConstantOfShape op_0 2 1 size value out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.full"; + } + + void write(Operator* op, const std::map& /*captured_params*/, const std::map& captured_attrs) const + { + op->params["fill_value"] = ((const float*)captured_attrs.at("value.data").data.data())[0]; + op->params["dtype"] = "torch.float"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_full_tnn, 21) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_full_like.cpp b/tools/pnnx/src/pass_level2/torch_full_like.cpp index 3b016611763..20e2cec3004 100644 --- a/tools/pnnx/src/pass_level2/torch_full_like.cpp +++ b/tools/pnnx/src/pass_level2/torch_full_like.cpp 
@@ -67,4 +67,33 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_full_like, 20) +class torch_full_like_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +tnn.Shape op_0 1 1 input shape +pnnx.Attribute value 0 1 value @data=(1)f32 +tnn.ConstantOfShape op_2 2 1 shape value out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.full_like"; + } + + void write(Operator* op, const std::map& /*captured_params*/, const std::map& captured_attrs) const + { + op->params["fill_value"] = ((const float*)captured_attrs.at("value.data").data.data())[0]; + op->params["dtype"] = "torch.float"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_full_like_tnn, 20) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_index_select.cpp b/tools/pnnx/src/pass_level2/torch_index_select.cpp index 084554af8a4..2bfc10558d4 100644 --- a/tools/pnnx/src/pass_level2/torch_index_select.cpp +++ b/tools/pnnx/src/pass_level2/torch_index_select.cpp @@ -13,6 +13,7 @@ // specific language governing permissions and limitations under the License. 
#include "pass_level2.h" +#include <cstring> namespace pnnx { @@ -39,4 +40,61 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_index_select, 70) +class torch_index_select_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + // clang-format off + // *INDENT-OFF* + + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +pnnx.Attribute op_0 0 1 index @data=(?)i32 +tnn.Gather op_1 2 1 input index out arg0=%dim arg1=0 arg2=1 +pnnx.Output output 1 0 out +)PNNXIR"; + + // *INDENT-ON* + // clang-format on + } + + const char* replace_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +pnnx.Attribute index 0 1 index +torch.index_select select 2 1 input index out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const + { + const Attribute& index_data = captured_attrs.at("op_0.data"); + + // i32 to i64 + Operator* op_index = ops.at("index"); + const int* p = (const int*)index_data.data.data(); + const int n = index_data.data.size() / 4; + std::vector<int64_t> indices(n); + for (int i = 0; i < n; i++) + { + indices[i] = p[i]; + } + op_index->attrs["data"].type = 5; // i64 + op_index->attrs["data"].shape = {n}; + op_index->attrs["data"].data.resize(n * 8); + memcpy((void*)op_index->attrs["data"].data.data(), (const void*)indices.data(), n * 8); + + Operator* op_gather = ops.at("select"); + op_gather->params["dim"] = captured_params.at("dim"); + op_gather->inputnames = {"input", "index"}; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_index_select_tnn, 70) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_matmul.cpp b/tools/pnnx/src/pass_level2/torch_matmul.cpp index a12f2c84b07..62127a777a2 100644 --- a/tools/pnnx/src/pass_level2/torch_matmul.cpp +++ b/tools/pnnx/src/pass_level2/torch_matmul.cpp @@ -60,4 +60,39 @@ pnnx.Output output 1 0 out
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_matmul_onnx, 90) +class torch_matmul_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 other +tnn.MatMul op_0 2 1 input other out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.matmul"; + } + + void write(Operator* op, const std::map& captured_params) const + { + if (captured_params.find("op_0.arg0") != captured_params.end()) + { + const int weight_position = captured_params.at("op_0.arg0").i; + if (weight_position == 0) + { + // swap input and weight + std::swap(op->inputs[0], op->inputs[1]); + } + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_matmul_tnn, 90) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_max.cpp b/tools/pnnx/src/pass_level2/torch_max.cpp index 1338d58b88c..78dd72601b9 100644 --- a/tools/pnnx/src/pass_level2/torch_max.cpp +++ b/tools/pnnx/src/pass_level2/torch_max.cpp @@ -182,4 +182,47 @@ pnnx.Output output 2 0 out indices REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_max_onnx_1, 50) +class torch_max_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.ReduceMax op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.max"; + } + + void write(Operator* op, const std::map& captured_params) const + { + std::vector dim; + for (int i = 1;; i++) + { + if (captured_params.find("op_0.arg" + std::to_string(i)) == captured_params.end()) + break; + + dim.push_back(captured_params.at("op_0.arg" + std::to_string(i)).i); + } + + if (dim.size() == 1) + { + op->params["dim"] = dim[0]; + } + else + { + fprintf(stderr, "fallback to reduce max all\n"); + } + op->params["keepdim"] = captured_params.at("op_0.arg0").i ? 
true : false; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_max_tnn, 50) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_mean.cpp b/tools/pnnx/src/pass_level2/torch_mean.cpp index 6a579feeb2d..45fc54023b8 100644 --- a/tools/pnnx/src/pass_level2/torch_mean.cpp +++ b/tools/pnnx/src/pass_level2/torch_mean.cpp @@ -148,4 +148,40 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_mean_onnx, 50) +class torch_mean_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.ReduceMean op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.mean"; + } + + void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const + { + std::vector<int> dim; + for (int i = 1;; i++) + { + if (captured_params.find("op_0.arg" + std::to_string(i)) == captured_params.end()) + break; + + dim.push_back(captured_params.at("op_0.arg" + std::to_string(i)).i); + } + + op->params["dim"] = dim; + op->params["keepdim"] = captured_params.at("op_0.arg0").i ?
true : false; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_mean_tnn, 50) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_min.cpp b/tools/pnnx/src/pass_level2/torch_min.cpp index 07248927d2e..201604aff26 100644 --- a/tools/pnnx/src/pass_level2/torch_min.cpp +++ b/tools/pnnx/src/pass_level2/torch_min.cpp @@ -182,4 +182,47 @@ pnnx.Output output 2 0 out indices REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_min_onnx_1, 50) +class torch_min_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.ReduceMin op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.min"; + } + + void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const + { + std::vector<int> dim; + for (int i = 1;; i++) + { + if (captured_params.find("op_0.arg" + std::to_string(i)) == captured_params.end()) + break; + + dim.push_back(captured_params.at("op_0.arg" + std::to_string(i)).i); + } + + if (dim.size() == 1) + { + op->params["dim"] = dim[0]; + } + else + { + fprintf(stderr, "fallback to reduce min all\n"); + } + op->params["keepdim"] = captured_params.at("op_0.arg0").i ?
true : false; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_min_tnn, 50) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_norm.cpp b/tools/pnnx/src/pass_level2/torch_norm.cpp index 57b54cb1328..c1980d1c6b5 100644 --- a/tools/pnnx/src/pass_level2/torch_norm.cpp +++ b/tools/pnnx/src/pass_level2/torch_norm.cpp @@ -144,4 +144,41 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_norm_fro, 90) REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_norm_fro_dims, 90) +class torch_norm_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.ReduceL2 op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.norm"; + } + + void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const + { + std::vector<int> dim; + for (int i = 1;; i++) + { + if (captured_params.find("op_0.arg" + std::to_string(i)) == captured_params.end()) + break; + + dim.push_back(captured_params.at("op_0.arg" + std::to_string(i)).i); + } + + op->params["dim"] = dim; + op->params["keepdim"] = captured_params.at("op_0.arg0").i ?
true : false; + op->params["p"] = 2; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_norm_tnn, 90) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_squeeze.cpp b/tools/pnnx/src/pass_level2/torch_squeeze.cpp index e79a2849144..0fdabe82144 100644 --- a/tools/pnnx/src/pass_level2/torch_squeeze.cpp +++ b/tools/pnnx/src/pass_level2/torch_squeeze.cpp @@ -137,4 +137,43 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_squeeze_onnx_1, 60) +class torch_squeeze_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.Squeeze op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.squeeze"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const int dims_count = captured_params.at("op_0.arg0").i; + if (dims_count == 1) + { + op->params["dim"] = captured_params.at("op_0.arg1").i; + } + else + { + std::vector dims(dims_count); + for (int i = 0; i < dims_count; i++) + { + dims[i] = captured_params.at("op_0.arg" + std::to_string(i + 1)).i; + } + op->params["dim"] = dims; + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_squeeze_tnn, 60) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_sum.cpp b/tools/pnnx/src/pass_level2/torch_sum.cpp index 280cdff98aa..66ea74f3c7e 100644 --- a/tools/pnnx/src/pass_level2/torch_sum.cpp +++ b/tools/pnnx/src/pass_level2/torch_sum.cpp @@ -111,4 +111,40 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_sum_onnx, 50) +class torch_sum_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.ReduceSum op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.sum"; + } + + void write(Operator* op, 
const std::map& captured_params) const + { + std::vector dim; + for (int i = 1;; i++) + { + if (captured_params.find("op_0.arg" + std::to_string(i)) == captured_params.end()) + break; + + dim.push_back(captured_params.at("op_0.arg" + std::to_string(i)).i); + } + + op->params["dim"] = dim; + op->params["keepdim"] = captured_params.at("op_0.arg0").i ? true : false; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_sum_tnn, 50) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level2/torch_unsqueeze.cpp b/tools/pnnx/src/pass_level2/torch_unsqueeze.cpp index ea377b071aa..1ff74798a51 100644 --- a/tools/pnnx/src/pass_level2/torch_unsqueeze.cpp +++ b/tools/pnnx/src/pass_level2/torch_unsqueeze.cpp @@ -104,4 +104,43 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_unsqueeze_onnx_1, 60) +class torch_unsqueeze_tnn : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.Unsqueeze op_0 1 1 input out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.unsqueeze"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const int dims_count = captured_params.at("op_0.arg0").i; + if (dims_count == 1) + { + op->params["dim"] = captured_params.at("op_0.arg1").i; + } + else + { + std::vector dims(dims_count); + for (int i = 0; i < dims_count; i++) + { + dims[i] = captured_params.at("op_0.arg" + std::to_string(i + 1)).i; + } + op->params["dim"] = dims; + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_unsqueeze_tnn, 60) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_level3/fuse_expression.cpp b/tools/pnnx/src/pass_level3/fuse_expression.cpp index 4b9f3f83929..dcb3d1f1365 100644 --- a/tools/pnnx/src/pass_level3/fuse_expression.cpp +++ b/tools/pnnx/src/pass_level3/fuse_expression.cpp @@ -135,6 +135,7 @@ static bool operand_maybe_tensor(const Operand* operand) || 
op->type == "aten::ceil" || op->type == "aten::cos" || op->type == "aten::cosh" + || op->type == "aten::erf" || op->type == "aten::exp" || op->type == "aten::floor" || op->type == "aten::log" @@ -648,6 +649,7 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s || op->type == "aten::ceil" || op->type == "aten::cos" || op->type == "aten::cosh" + || op->type == "aten::erf" || op->type == "aten::exp" || op->type == "aten::floor" || op->type == "aten::log" @@ -888,6 +890,7 @@ void fuse_expression(Graph& graph, const std::set& foldable_constan || op->type == "aten::cos" || op->type == "aten::cosh" || op->type == "aten::div" + || op->type == "aten::erf" || op->type == "aten::exp" || op->type == "aten::floor" || op->type == "aten::floor_divide" diff --git a/tools/pnnx/src/pass_level5/eval_expression.cpp b/tools/pnnx/src/pass_level5/eval_expression.cpp index c7d5d5d0226..441cfca1c36 100644 --- a/tools/pnnx/src/pass_level5/eval_expression.cpp +++ b/tools/pnnx/src/pass_level5/eval_expression.cpp @@ -175,6 +175,7 @@ static std::string eval_expression(const Operator* op) || t == "ceil" || t == "cos" || t == "cosh" + || t == "erf" || t == "exp" || t == "floor" || t == "log" @@ -262,6 +263,11 @@ static std::string eval_expression(const Operator* op) float r = cosh(af); exprstack.push(std::to_string(r)); } + if (t == "erf") + { + float r = erf(af); + exprstack.push(std::to_string(r)); + } if (t == "exp") { float r = exp(af); diff --git a/tools/pnnx/src/pass_ncnn/expand_expression.cpp b/tools/pnnx/src/pass_ncnn/expand_expression.cpp index 2fdc6d77d62..63efb4f8e47 100644 --- a/tools/pnnx/src/pass_ncnn/expand_expression.cpp +++ b/tools/pnnx/src/pass_ncnn/expand_expression.cpp @@ -123,6 +123,7 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx || t == "atan" || t == "ceil" || t == "cos" + || t == "erf" || t == "exp" || t == "floor" || t == "log" @@ -154,6 +155,7 @@ static std::string expand_expression(Graph& graph, const 
Operator* op, int& pnnx if (t == "ceil") op_unary->params["0"] = 3; if (t == "cos") op_unary->params["0"] = 10; if (t == "exp") op_unary->params["0"] = 7; + if (t == "erf") fprintf(stderr, "UnaryOp erf not supported yet\n"); // TODO if (t == "floor") op_unary->params["0"] = 2; if (t == "log") op_unary->params["0"] = 8; if (t == "log10") op_unary->params["0"] = 17; diff --git a/tools/pnnx/src/pass_tnn/fuse_shape_list_construct.cpp b/tools/pnnx/src/pass_tnn/fuse_shape_list_construct.cpp new file mode 100644 index 00000000000..b2688852efa --- /dev/null +++ b/tools/pnnx/src/pass_tnn/fuse_shape_list_construct.cpp @@ -0,0 +1,135 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "fuse_shape_list_construct.h" + +#include + +namespace pnnx { + +namespace tnn2pnnx { + +void fuse_shape_list_construct(Graph& graph) +{ + // TODO unpool tnn.Unsqueeze + + // a0 = pnnx.Attribute @data=(1)i32 + // a1 = tnn.Unsqueeze(..., arg0=1, arg1=0) + // y = tnn.Concat(a0, a1, ..., arg0=0) + // tnn.Reshape(x, y, args=...) / tnn.ConstantOfShape(y) + + // prim::ListConstruct (a0, a1, ...) 
+ // tnn.Reshape(x, y) / tnn.ConstantOfShape(y) + + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "tnn.Concat") + continue; + + if (op->outputs[0]->consumers.size() != 1) + continue; + + Operator* op2 = op->outputs[0]->consumers[0]; + if (op2->type == "tnn.Reshape") + { + if (op2->inputs.size() != 2) + continue; + + if (op2->inputs[1] != op->outputs[0]) + continue; + } + else if (op2->type == "tnn.ConstantOfShape") + { + if (op2->inputs[0] != op->outputs[0]) + continue; + } + else if (op2->type == "tnn.Expand") + { + if (op2->inputs[1] != op->outputs[0]) + continue; + } + else + { + continue; + } + + matched = true; + + fprintf(stderr, "match concat + reshape/constantofshape/expand\n"); + + op->type = "prim::ListConstruct"; + + // drop tnn.Unsqueeze between aten::size and prim::ListConstruct + + const size_t count = op->inputs.size(); + for (size_t j = 0; j < count; j++) + { + Operand* r = op->inputs[j]; + + if (r->producer->type != "tnn.Unsqueeze") + continue; + + Operator* op_uqz = r->producer; + + Operand* r0 = op_uqz->inputs[0]; + + if (r0->producer->type != "aten::size") + continue; + + // drop tnn.Unsqueeze + + r0->remove_consumer(op_uqz); + r->remove_consumer(op); + + op->inputs[j] = r0; + r0->consumers.push_back(op); + + if (r->consumers.empty()) + { + graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), r)); + delete r; + + graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op_uqz)); + delete op_uqz; + } + } + + if (op2->type == "tnn.Reshape") + { + // drop tnn.Reshape args + op2->params.clear(); + } + if (op2->type == "tnn.Expand") + { + // drop tnn.Expand args + op2->params.clear(); + } + + break; + } + + if (!matched) + break; + } +} + +} // namespace tnn2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_tnn/fuse_shape_list_construct.h b/tools/pnnx/src/pass_tnn/fuse_shape_list_construct.h new file mode 100644 
index 00000000000..55d72e6305c --- /dev/null +++ b/tools/pnnx/src/pass_tnn/fuse_shape_list_construct.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +namespace tnn2pnnx { + +void fuse_shape_list_construct(Graph& graph); + +} // namespace tnn2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_tnn/fuse_shape_size.cpp b/tools/pnnx/src/pass_tnn/fuse_shape_size.cpp new file mode 100644 index 00000000000..062c8bb2b65 --- /dev/null +++ b/tools/pnnx/src/pass_tnn/fuse_shape_size.cpp @@ -0,0 +1,71 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "fuse_shape_size.h" + +#include "pass_level2.h" + +namespace pnnx { + +namespace tnn2pnnx { + +class fuse_shape_size_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +tnn.Shape op_0 1 1 input a +pnnx.Attribute op_1 0 1 index @data=(1)i32 +tnn.Gather op_2 2 1 a index out arg0=0 arg1=0 arg2=1 +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +prim::Constant index 0 1 index +aten::size size 2 1 input index out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + void write(const std::map& ops, const std::map& captured_params, const std::map& captured_attrs) const + { + const Attribute& index_data = captured_attrs.at("op_1.data"); + const int index = ((const int*)index_data.data.data())[0]; + + Operator* op_index = ops.at("index"); + op_index->params["value"] = index; + } +}; + +void fuse_shape_size(Graph& graph) +{ + // TODO unpool tnn.Shape + + fuse_shape_size_pass a; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); +} + +} // namespace tnn2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_tnn/fuse_shape_size.h b/tools/pnnx/src/pass_tnn/fuse_shape_size.h new file mode 100644 index 00000000000..2058af6438e --- /dev/null +++ b/tools/pnnx/src/pass_tnn/fuse_shape_size.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. 
You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +namespace tnn2pnnx { + +void fuse_shape_size(Graph& graph); + +} // namespace tnn2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_tnn/lower_concat.cpp b/tools/pnnx/src/pass_tnn/lower_concat.cpp new file mode 100644 index 00000000000..5a9cb1d662e --- /dev/null +++ b/tools/pnnx/src/pass_tnn/lower_concat.cpp @@ -0,0 +1,63 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "lower_concat.h" + +#include "pass_level2.h" + +namespace pnnx { + +namespace tnn2pnnx { + +void lower_concat(Graph& graph) +{ + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "tnn.Concat") + continue; + + const int dim = op->params["arg0"].i; + + op->type = "aten::cat"; + op->params.clear(); + op->params["dim"] = dim; + + // insert listconstruct for inputs + Operator* op0 = graph.new_operator_before("prim::ListConstruct", op->name + "_lc", op); + Operand* r = graph.new_operand(op->name + "_lc"); + + r->producer = op0; + r->consumers.push_back(op); + + op0->outputs.push_back(r); + + for (size_t j = 0; j < op->inputs.size(); j++) + { + Operand* x = op->inputs[j]; + + x->remove_consumer(op); + x->consumers.push_back(op0); + op0->inputs.push_back(x); + } + + op->inputs.clear(); + op->inputs.push_back(r); + } +} + +} // namespace tnn2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_tnn/lower_concat.h b/tools/pnnx/src/pass_tnn/lower_concat.h new file mode 100644 index 00000000000..84fd76a542b --- /dev/null +++ b/tools/pnnx/src/pass_tnn/lower_concat.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "ir.h" + +namespace pnnx { + +namespace tnn2pnnx { + +void lower_concat(Graph& graph); + +} // namespace tnn2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_tnn/lower_convolution_activation.cpp b/tools/pnnx/src/pass_tnn/lower_convolution_activation.cpp new file mode 100644 index 00000000000..a4d76c557ce --- /dev/null +++ b/tools/pnnx/src/pass_tnn/lower_convolution_activation.cpp @@ -0,0 +1,301 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "lower_convolution_activation.h" + +#include "pass_level2.h" + +namespace pnnx { + +namespace tnn2pnnx { + +class lower_convolution_activation_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution op_0 3 1 input weight bias out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + if (this->activation == 1) + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution conv2d 3 1 input weight bias a +aten::relu relu 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + else if (this->activation == 2) + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution conv2d 3 1 input weight bias a +aten::relu6 relu6 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + else // if (this->activation == 256) + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution conv2d 3 1 input weight bias a +aten::silu silu 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + } + + bool match(const std::map& captured_params) const + { + if (captured_params.find("op_0.arg13") == captured_params.end()) + return false; + + this->activation = captured_params.at("op_0.arg13").i; + return activation != 0; + } + + void write(const std::map& ops, const std::map& captured_params) const + { + for (int i = 0; i < 13; i++) + { + std::string argN = std::string("arg") + std::to_string(i); + ops.at("conv2d")->params[argN] = captured_params.at("op_0." 
+ argN); + } + + ops.at("conv2d")->params["arg13"] = 0; + } + +protected: + mutable int activation; +}; + +class lower_convolution_activation_pass_1 : public lower_convolution_activation_pass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution op_0 2 1 input weight out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + if (this->activation == 1) + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution conv2d 2 1 input weight a +aten::relu relu 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + else if (this->activation == 2) + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution conv2d 2 1 input weight a +aten::relu6 relu6 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + else // if (this->activation == 256) + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution conv2d 2 1 input weight a +aten::silu silu 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + } +}; + +class lower_convolution1d_activation_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution1D op_0 3 1 input weight bias out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + if (this->activation == 1) + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution1D conv1d 3 1 input weight bias a +aten::relu relu 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + else if (this->activation == 2) + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input 
+pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution1D conv1d 3 1 input weight bias a +aten::relu6 relu6 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + else // if (this->activation == 256) + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +pnnx.Input input_2 0 1 bias +tnn.Convolution1D conv1d 3 1 input weight bias a +aten::silu silu 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + } + + bool match(const std::map& captured_params) const + { + if (captured_params.find("op_0.arg9") == captured_params.end()) + return false; + + this->activation = captured_params.at("op_0.arg9").i; + return activation != 0; + } + + void write(const std::map& ops, const std::map& captured_params) const + { + for (int i = 0; i < 9; i++) + { + std::string argN = std::string("arg") + std::to_string(i); + ops.at("conv1d")->params[argN] = captured_params.at("op_0." + argN); + } + + ops.at("conv1d")->params["arg9"] = 0; + } + +protected: + mutable int activation; +}; + +class lower_convolution1d_activation_pass_1 : public lower_convolution1d_activation_pass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution1D op_0 2 1 input weight out %*=%* +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + if (this->activation == 1) + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution1D conv1d 2 1 input weight a +aten::relu relu 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + else if (this->activation == 2) + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution1D conv1d 2 1 input weight a +aten::relu6 relu6 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + else // if (this->activation == 256) + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input 
input_0 0 1 input +pnnx.Input input_1 0 1 weight +tnn.Convolution1D conv1d 2 1 input weight a +aten::silu silu 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + } +}; + +void lower_convolution_activation(Graph& graph) +{ + lower_convolution_activation_pass a; + lower_convolution_activation_pass_1 a1; + lower_convolution1d_activation_pass b; + lower_convolution1d_activation_pass_1 b1; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); + pnnx_graph_rewrite(graph, &a1, opindex); + pnnx_graph_rewrite(graph, &b, opindex); + pnnx_graph_rewrite(graph, &b1, opindex); +} + +} // namespace tnn2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_tnn/lower_convolution_activation.h b/tools/pnnx/src/pass_tnn/lower_convolution_activation.h new file mode 100644 index 00000000000..91d85f75500 --- /dev/null +++ b/tools/pnnx/src/pass_tnn/lower_convolution_activation.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "ir.h" + +namespace pnnx { + +namespace tnn2pnnx { + +void lower_convolution_activation(Graph& graph); + +} // namespace tnn2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_tnn/lower_power.cpp b/tools/pnnx/src/pass_tnn/lower_power.cpp new file mode 100644 index 00000000000..a35edaa5ffb --- /dev/null +++ b/tools/pnnx/src/pass_tnn/lower_power.cpp @@ -0,0 +1,62 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "lower_power.h" + +#include "pass_level2.h" + +namespace pnnx { + +namespace tnn2pnnx { + +class lower_power_pass : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +tnn.Power op_0 1 1 input out arg0=%exponent arg1=%alpha arg2=%beta +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* replace_pattern_graph() const + { + return R"PNNXIR(7767517 +8 7 +pnnx.Input input 0 1 input +prim::Constant alpha 0 1 alpha value=%alpha +prim::Constant beta 0 1 beta value=%beta +prim::Constant exponent 0 1 exponent value=%exponent +aten::mul scale 2 1 input alpha a +aten::add shift 2 1 a beta b +aten::pow pow 2 1 b exponent out +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +void lower_power(Graph& graph) +{ + lower_power_pass a; + int opindex = 0; + + pnnx_graph_rewrite(graph, &a, opindex); +} + +} // namespace tnn2pnnx + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_tnn/lower_power.h b/tools/pnnx/src/pass_tnn/lower_power.h new file mode 100644 index 00000000000..35bbedf6393 --- /dev/null +++ b/tools/pnnx/src/pass_tnn/lower_power.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "ir.h" + +namespace pnnx { + +namespace tnn2pnnx { + +void lower_power(Graph& graph); + +} // namespace tnn2pnnx + +} // namespace pnnx