tnn2pnnx #5898

Merged 32 commits on Feb 12, 2025
2 changes: 2 additions & 0 deletions tools/pnnx/CMakeLists.txt
@@ -123,6 +123,8 @@ else()
set(onnxruntime_FOUND FALSE)
endif()

+option(PNNX_TNN2PNNX "build tnn2pnnx" ON)

add_subdirectory(src)

enable_testing()
28 changes: 28 additions & 0 deletions tools/pnnx/src/CMakeLists.txt
@@ -715,6 +715,29 @@ else()
message(STATUS "Building without onnx2pnnx")
endif()

+if(PNNX_TNN2PNNX)
+    set(pnnx_pass_tnn_SRCS
+        pass_tnn/fuse_shape_size.cpp
+        pass_tnn/fuse_shape_list_construct.cpp
+        pass_tnn/lower_concat.cpp
+        pass_tnn/lower_convolution_activation.cpp
+        pass_tnn/lower_power.cpp
+    )
+
+    set(tnn2pnnx_SRCS
+        ${pnnx_pass_tnn_SRCS}
+        load_tnn.cpp
+    )
+
+    add_library(tnn2pnnx OBJECT ${tnn2pnnx_SRCS})
+    target_compile_definitions(tnn2pnnx PRIVATE BUILD_TNN2PNNX)
+    target_compile_options(tnn2pnnx PUBLIC "${TORCH_CXX_FLAGS}")
+
+    message(STATUS "Building with tnn2pnnx")
+else()
+    message(STATUS "Building without tnn2pnnx")
+endif()

if(NOT MSVC)
add_definitions(-Wall -Wextra)
endif()
@@ -765,6 +788,11 @@ if(onnxruntime_FOUND)
target_link_libraries(pnnx PRIVATE onnx2pnnx)
endif()

+if(PNNX_TNN2PNNX)
+    set_property(SOURCE main.cpp APPEND PROPERTY COMPILE_DEFINITIONS BUILD_TNN2PNNX)
+    target_link_libraries(pnnx PRIVATE tnn2pnnx)
+endif()

if(PNNX_COVERAGE)
target_compile_options(pnnx PUBLIC -coverage -fprofile-arcs -ftest-coverage)
target_link_libraries(pnnx PUBLIC -coverage -lgcov)
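Note: the per-source compile definition on main.cpp matches the one given to the tnn2pnnx object library above, so TNN support can be gated at compile time. Below is a minimal, hedged sketch of that preprocessor-guard pattern; it is illustrative only and does not reproduce the actual contents of main.cpp.

// Illustrative sketch (not from this patch): a translation unit built with
// BUILD_TNN2PNNX defined (CMake passes -DBUILD_TNN2PNNX, i.e. value 1)
// takes the first branch, otherwise the second.
#include <cstdio>

int main()
{
#if BUILD_TNN2PNNX
    fprintf(stderr, "tnn2pnnx support compiled in\n");
#else
    fprintf(stderr, "tnn2pnnx support disabled\n");
#endif
    return 0;
}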
17 changes: 14 additions & 3 deletions tools/pnnx/src/ir.cpp
@@ -1030,6 +1030,7 @@ static std::string expand_expression(const Operator* op)
|| t == "ceil"
|| t == "cos"
|| t == "cosh"
|| t == "erf"
|| t == "exp"
|| t == "floor"
|| t == "log"
@@ -1062,6 +1063,7 @@ static std::string expand_expression(const Operator* op)
if (t == "ceil") unaryop = "torch.ceil";
if (t == "cos") unaryop = "torch.cos";
if (t == "cosh") unaryop = "torch.cosh";
if (t == "erf") unaryop = "torch.erf";
if (t == "exp") unaryop = "torch.exp";
if (t == "floor") unaryop = "torch.floor";
if (t == "log") unaryop = "torch.log";
@@ -2253,11 +2255,17 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
{
if (op->type == "Tensor.index_put" && it.first == "values")
{
fprintf(pyfp, "torch.tensor(%f)", param.f);
if (param.f == (int)param.f)
fprintf(pyfp, "torch.tensor(%.1f)", param.f);
else
fprintf(pyfp, "torch.tensor(%g)", param.f);
}
else
{
fprintf(pyfp, "%f", param.f);
if (param.f == (int)param.f)
fprintf(pyfp, "%.1f", param.f);
else
fprintf(pyfp, "%g", param.f);
}
}
if (param.type == 4)
@@ -2316,7 +2324,10 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
fprintf(pyfp, "(");
for (size_t i = 0; i < param.af.size(); i++)
{
fprintf(pyfp, "%f", param.af[i]);
if (param.af[i] == (int)param.af[i])
fprintf(pyfp, "%.1f", param.af[i]);
else
fprintf(pyfp, "%g", param.af[i]);
if (i + 1 != param.af.size() || param.af.size() == 1)
fprintf(pyfp, ",");
}
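Note: the ir.cpp changes above normalize how float parameters are written into the generated pnnx.py: integral-valued floats keep an explicit decimal place, and everything else uses %g instead of the fixed six-decimal %f. A small self-contained sketch of that formatting rule follows; print_float_literal is a hypothetical helper for illustration, not code from the patch.

#include <cstdio>

// Same formatting rule as the diff above, extracted into a helper for illustration.
static void print_float_literal(FILE* fp, float v)
{
    if (v == (int)v)
        fprintf(fp, "%.1f", v); // 3.f -> "3.0": stays a float literal in Python
    else
        fprintf(fp, "%g", v);   // 0.25f -> "0.25": compact, no trailing zeros
}

int main()
{
    print_float_literal(stdout, 3.f);   // prints 3.0 (the old "%f" printed 3.000000)
    fputc('\n', stdout);
    print_float_literal(stdout, 0.25f); // prints 0.25
    fputc('\n', stdout);
    return 0;
}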
3 changes: 3 additions & 0 deletions tools/pnnx/src/ir.h
@@ -234,6 +234,9 @@ class Attribute
#if BUILD_ONNX2PNNX
Attribute(const onnx::TensorProto& t);
#endif
+#if BUILD_TNN2PNNX
+Attribute(FILE* bp);
+#endif

Attribute(const std::initializer_list<int>& shape, const std::vector<float>& t);

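Note: the new constructor lets the TNN path (load_tnn.cpp) build an Attribute directly from an open weight-file stream, in the same conditional-compilation style as the existing onnx::TensorProto overload. A hedged usage sketch follows; the surrounding function and the idea of reading one blob per call are assumptions for illustration, not taken from this PR.

#include <cstdio>

#include "ir.h" // pnnx IR header that declares pnnx::Attribute

// Hypothetical example: read one weight blob from an already-opened TNN model
// stream. Only compiles when BUILD_TNN2PNNX is defined, as in ir.h above.
#if BUILD_TNN2PNNX
static pnnx::Attribute read_one_weight(FILE* bp)
{
    return pnnx::Attribute(bp); // constructor consumes the blob from the stream
}
#endif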