Skip to content

Commit

Permalink
w
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Feb 11, 2025
1 parent 86d25a0 commit 51e99e8
Show file tree
Hide file tree
Showing 7 changed files with 320 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tools/pnnx/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,8 @@ endif()

if(PNNX_TNN2PNNX)
set(pnnx_pass_tnn_SRCS
pass_tnn/fuse_shape_size.cpp
pass_tnn/fuse_shape_list_construct.cpp
pass_tnn/lower_convolution_activation.cpp
pass_tnn/lower_power.cpp
)
Expand Down
5 changes: 5 additions & 0 deletions tools/pnnx/src/load_tnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include <string.h>
#include <unordered_map>

#include "pass_tnn/fuse_shape_size.h"
#include "pass_tnn/fuse_shape_list_construct.h"
#include "pass_tnn/lower_convolution_activation.h"
#include "pass_tnn/lower_power.h"

Expand Down Expand Up @@ -625,6 +627,9 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph)
if (op->type == "tnn.Div") op->type = "aten::div";
}

tnn2pnnx::fuse_shape_size(pnnx_graph);
tnn2pnnx::fuse_shape_list_construct(pnnx_graph);

tnn2pnnx::lower_convolution_activation(pnnx_graph);

tnn2pnnx::lower_power(pnnx_graph);
Expand Down
66 changes: 66 additions & 0 deletions tools/pnnx/src/pass_level2/Tensor_reshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,70 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_reshape_onnx_2, 61)

class Tensor_reshape_tnn : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
tnn.Reshape op_0 1 1 input out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Tensor.reshape";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
const int axis = captured_params.at("op_0.arg0").i;
const int num_axes = captured_params.at("op_0.arg1").i;
const int shape_rank = captured_params.at("op_0.arg2").i;

std::vector<int> shape(shape_rank);
for (int i = 0; i < shape_rank; i++)
{
shape[i] = captured_params.at("op_0.arg" + std::to_string(i + 3)).i;
}

const int reshape_type = captured_params.at("op_0.arg" + std::to_string(shape_rank + 3)).i;

// HACK
if (shape == std::vector{0, -1, 0, 0})
{
shape = {-1};
}

op->params["shape"] = shape;
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_reshape_tnn, 60)

class Tensor_reshape_tnn_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
pnnx.Input shape 0 1 shape
tnn.Reshape op_0 2 1 input shape out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Tensor.reshape";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_reshape_tnn_1, 60)

} // namespace pnnx
126 changes: 126 additions & 0 deletions tools/pnnx/src/pass_tnn/fuse_shape_list_construct.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_shape_list_construct.h"

#include <algorithm>

namespace pnnx {

namespace tnn2pnnx {

void fuse_shape_list_construct(Graph& graph)
{
// TODO unpool tnn.Unsqueeze

// a0 = pnnx.Attribute @data=(1)i32
// a1 = tnn.Unsqueeze(..., arg0=1, arg1=0)
// y = tnn.Concat(a0, a1, ..., arg0=0)
// tnn.Reshape(x, y, args=...) / tnn.ConstantOfShape(y)

// prim::ListConstruct (a0, a1, ...)
// tnn.Reshape(x, y) / tnn.ConstantOfShape(y)

while (1)
{
bool matched = false;

for (size_t i = 0; i < graph.ops.size(); i++)
{
Operator* op = graph.ops[i];

if (op->type != "tnn.Concat")
continue;

if (op->outputs[0]->consumers.size() != 1)
continue;

Operator* op2 = op->outputs[0]->consumers[0];
if (op2->type == "tnn.Reshape")
{
if (op2->inputs.size() != 2)
continue;

if (op2->inputs[1] != op->outputs[0])
continue;
}
else if (op2->type == "tnn.ConstantOfShape")
{
if (op2->inputs[0] != op->outputs[0])
continue;
}
else
{
continue;
}

matched = true;

fprintf(stderr, "match concat + reshape/constantofshape\n");

op->type = "prim::ListConstruct";

// drop tnn.Unsqueeze between aten::size and prim::ListConstruct

const size_t count = op->inputs.size();
for (size_t j = 0; j < count; j++)
{
Operand* r = op->inputs[j];

if (r->producer->type != "tnn.Unsqueeze")
continue;

Operator* op_uqz = r->producer;

Operand* r0 = op_uqz->inputs[0];

if (r0->producer->type != "aten::size")
continue;

// drop tnn.Unsqueeze

r0->remove_consumer(op_uqz);
r->remove_consumer(op);

op->inputs[j] = r0;
r0->consumers.push_back(op);

if (r->consumers.empty())
{
graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), r));
delete r;

graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op_uqz));
delete op_uqz;
}
}

if (op2->type == "tnn.Reshape")
{
// drop tnn.Reshape args
op2->params.clear();
}

break;
}

if (!matched)
break;
}

}

} // namespace tnn2pnnx

} // namespace pnnx
25 changes: 25 additions & 0 deletions tools/pnnx/src/pass_tnn/fuse_shape_list_construct.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

namespace tnn2pnnx {

void fuse_shape_list_construct(Graph& graph);

} // namespace tnn2pnnx

} // namespace pnnx
71 changes: 71 additions & 0 deletions tools/pnnx/src/pass_tnn/fuse_shape_size.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_shape_size.h"

#include "pass_level2.h"

namespace pnnx {

namespace tnn2pnnx {

class fuse_shape_size_pass : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
tnn.Shape op_0 1 1 input a
pnnx.Attribute op_1 0 1 index @data=(1)i32
tnn.Gather op_2 2 1 a index out arg0=0 arg1=0 arg2=1
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* replace_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
prim::Constant index 0 1 index
aten::size size 2 1 input index out
pnnx.Output output 1 0 out
)PNNXIR";
}

void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
const Attribute& index_data = captured_attrs.at("op_1.data");
const int index = ((const int*)index_data.data.data())[0];

Operator* op_index = ops.at("index");
op_index->params["value"] = index;
}
};

void fuse_shape_size(Graph& graph)
{
// TODO unpool tnn.Shape

fuse_shape_size_pass a;
int opindex = 0;

pnnx_graph_rewrite(graph, &a, opindex);
}

} // namespace tnn2pnnx

} // namespace pnnx
25 changes: 25 additions & 0 deletions tools/pnnx/src/pass_tnn/fuse_shape_size.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

namespace tnn2pnnx {

void fuse_shape_size(Graph& graph);

} // namespace tnn2pnnx

} // namespace pnnx

0 comments on commit 51e99e8

Please sign in to comment.