Skip to content

Commit

Permalink
w
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Feb 11, 2025
1 parent 1d1c339 commit a360149
Show file tree
Hide file tree
Showing 11 changed files with 575 additions and 1 deletion.
1 change: 1 addition & 0 deletions tools/pnnx/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,7 @@ if(PNNX_TNN2PNNX)
set(pnnx_pass_tnn_SRCS
pass_tnn/fuse_shape_size.cpp
pass_tnn/fuse_shape_list_construct.cpp
pass_tnn/lower_concat.cpp
pass_tnn/lower_convolution_activation.cpp
pass_tnn/lower_power.cpp
)
Expand Down
5 changes: 5 additions & 0 deletions tools/pnnx/src/load_tnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

#include "pass_tnn/fuse_shape_size.h"
#include "pass_tnn/fuse_shape_list_construct.h"
#include "pass_tnn/lower_concat.h"
#include "pass_tnn/lower_convolution_activation.h"
#include "pass_tnn/lower_power.h"

Expand Down Expand Up @@ -625,6 +626,8 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph)
if (op->type == "tnn.Sub") op->type = "aten::sub";
if (op->type == "tnn.Mul") op->type = "aten::mul";
if (op->type == "tnn.Div") op->type = "aten::div";

// misc
}

tnn2pnnx::fuse_shape_size(pnnx_graph);
Expand All @@ -634,6 +637,8 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph)

tnn2pnnx::lower_power(pnnx_graph);

tnn2pnnx::lower_concat(pnnx_graph);

return 0;
}

Expand Down
22 changes: 22 additions & 0 deletions tools/pnnx/src/pass_level2/Tensor_expand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,26 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_expand_onnx, 60)

// Rewrite a TNN Expand operator (tensor + explicit shape input) into pnnx Tensor.expand.
class Tensor_expand_tnn : public GraphRewriterPass
{
public:
    const char* match_pattern_graph() const
    {
        // Match tnn.Expand taking a data tensor and a shape tensor;
        // %*=%* captures whatever operator parameters are attached.
        return R"PNNXIR(7767517
4 3
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 shape
tnn.Expand op_0 2 1 input shape out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
    }

    // Replacement operator type emitted into the pnnx graph.
    const char* type_str() const
    {
        return "Tensor.expand";
    }
};

// Priority 61: runs just after the onnx variant registered at 60 above.
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_expand_tnn, 61)

} // namespace pnnx
28 changes: 28 additions & 0 deletions tools/pnnx/src/pass_level2/Tensor_expand_as.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,32 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_expand_as, 60)

// Rewrite the TNN idiom "Shape(other) -> Expand(input, shape)" into pnnx Tensor.expand_as.
class Tensor_expand_as_tnn : public GraphRewriterPass
{
public:
    const char* match_pattern_graph() const
    {
        // The shape fed to tnn.Expand comes from tnn.Shape of a second tensor,
        // which is exactly the expand_as contract; %*=%* captures any attached params.
        return R"PNNXIR(7767517
5 4
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 other
tnn.Shape op_0 1 1 other shape
tnn.Expand op_1 2 1 input shape out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
    }

    // Replacement operator type emitted into the pnnx graph.
    const char* type_str() const
    {
        return "Tensor.expand_as";
    }

    void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/) const
    {
        // Drop the captured tnn parameters: expand_as derives the target
        // shape from the `other` operand at runtime, so none are needed.
        op->params.clear();
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_expand_as_tnn, 60)

} // namespace pnnx
74 changes: 74 additions & 0 deletions tools/pnnx/src/pass_level2/Tensor_slice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,4 +150,78 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_slice_onnx_1, 70)

// Rewrite a TNN StridedSliceV2 operator into pnnx Tensor.slice.
class Tensor_slice_tnn : public GraphRewriterPass
{
public:
    const char* match_pattern_graph() const
    {
        // %*=%* captures the flat argN parameter list attached to the op.
        return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
tnn.StridedSliceV2 op_0 1 1 input out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
    }

    // Replacement operator type emitted into the pnnx graph.
    const char* type_str() const
    {
        return "Tensor.slice";
    }

    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
    {
        // TNN packs StridedSliceV2 arguments as one flat sequence:
        //   [count, values...] for begins, then ends, then axes,
        //   then optionally [count, values...] for strides.
        // Consume it with a running cursor instead of the error-prone
        // hand-computed offsets (i + 2 + nbegins, i + 3 + nbegins + nends, ...).
        int cursor = 0;

        // Read one length-prefixed int list starting at op_0.arg<cursor>.
        auto read_list = [&](std::vector<int>& values) {
            const int count = captured_params.at("op_0.arg" + std::to_string(cursor++)).i;
            values.resize(count);
            for (int i = 0; i < count; i++)
            {
                values[i] = captured_params.at("op_0.arg" + std::to_string(cursor++)).i;
            }
        };

        std::vector<int> begins;
        std::vector<int> ends;
        std::vector<int> axes;
        std::vector<int> strides;

        read_list(begins);
        read_list(ends);
        read_list(axes);

        // The strides block is optional; default to step 1 on every sliced axis.
        if (captured_params.find("op_0.arg" + std::to_string(cursor)) != captured_params.end())
        {
            read_list(strides);
        }
        else
        {
            strides.resize(axes.size(), 1);
        }

        if (axes.size() == 1)
        {
            // Single-axis slice maps to the scalar-parameter form.
            op->params["dim"] = axes[0];
            op->params["start"] = begins[0];
            op->params["end"] = ends[0];
            op->params["step"] = strides[0];
        }
        else
        {
            // Multi-axis slice uses the list form; INT_MAX in selects
            // marks "no single-index select" on each axis.
            op->params["dims"] = axes;
            op->params["starts"] = begins;
            op->params["ends"] = ends;
            op->params["steps"] = strides;
            op->params["selects"] = std::vector<int>(axes.size(), INT_MAX);
        }
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_slice_tnn, 70)

} // namespace pnnx
Loading

0 comments on commit a360149

Please sign in to comment.