Skip to content

Commit

Permalink
tnn2pnnx (#5898)
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui authored Feb 12, 2025
1 parent 2389090 commit 5ea7681
Show file tree
Hide file tree
Showing 50 changed files with 3,195 additions and 3 deletions.
2 changes: 2 additions & 0 deletions tools/pnnx/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ else()
set(onnxruntime_FOUND FALSE)
endif()

option(PNNX_TNN2PNNX "build tnn2pnnx" ON)

add_subdirectory(src)

enable_testing()
Expand Down
28 changes: 28 additions & 0 deletions tools/pnnx/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,29 @@ else()
message(STATUS "Building without onnx2pnnx")
endif()

if(PNNX_TNN2PNNX)
set(pnnx_pass_tnn_SRCS
pass_tnn/fuse_shape_size.cpp
pass_tnn/fuse_shape_list_construct.cpp
pass_tnn/lower_concat.cpp
pass_tnn/lower_convolution_activation.cpp
pass_tnn/lower_power.cpp
)

set(tnn2pnnx_SRCS
${pnnx_pass_tnn_SRCS}
load_tnn.cpp
)

add_library(tnn2pnnx OBJECT ${tnn2pnnx_SRCS})
target_compile_definitions(tnn2pnnx PRIVATE BUILD_TNN2PNNX)
target_compile_options(tnn2pnnx PUBLIC "${TORCH_CXX_FLAGS}")

message(STATUS "Building with tnn2pnnx")
else()
message(STATUS "Building without tnn2pnnx")
endif()

if(NOT MSVC)
add_definitions(-Wall -Wextra)
endif()
Expand Down Expand Up @@ -765,6 +788,11 @@ if(onnxruntime_FOUND)
target_link_libraries(pnnx PRIVATE onnx2pnnx)
endif()

if(PNNX_TNN2PNNX)
set_property(SOURCE main.cpp APPEND PROPERTY COMPILE_DEFINITIONS BUILD_TNN2PNNX)
target_link_libraries(pnnx PRIVATE tnn2pnnx)
endif()

if(PNNX_COVERAGE)
target_compile_options(pnnx PUBLIC -coverage -fprofile-arcs -ftest-coverage)
target_link_libraries(pnnx PUBLIC -coverage -lgcov)
Expand Down
17 changes: 14 additions & 3 deletions tools/pnnx/src/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1030,6 +1030,7 @@ static std::string expand_expression(const Operator* op)
|| t == "ceil"
|| t == "cos"
|| t == "cosh"
|| t == "erf"
|| t == "exp"
|| t == "floor"
|| t == "log"
Expand Down Expand Up @@ -1062,6 +1063,7 @@ static std::string expand_expression(const Operator* op)
if (t == "ceil") unaryop = "torch.ceil";
if (t == "cos") unaryop = "torch.cos";
if (t == "cosh") unaryop = "torch.cosh";
if (t == "erf") unaryop = "torch.erf";
if (t == "exp") unaryop = "torch.exp";
if (t == "floor") unaryop = "torch.floor";
if (t == "log") unaryop = "torch.log";
Expand Down Expand Up @@ -2253,11 +2255,17 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
{
if (op->type == "Tensor.index_put" && it.first == "values")
{
fprintf(pyfp, "torch.tensor(%f)", param.f);
if (param.f == (int)param.f)
fprintf(pyfp, "torch.tensor(%.1f)", param.f);
else
fprintf(pyfp, "torch.tensor(%g)", param.f);
}
else
{
fprintf(pyfp, "%f", param.f);
if (param.f == (int)param.f)
fprintf(pyfp, "%.1f", param.f);
else
fprintf(pyfp, "%g", param.f);
}
}
if (param.type == 4)
Expand Down Expand Up @@ -2316,7 +2324,10 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
fprintf(pyfp, "(");
for (size_t i = 0; i < param.af.size(); i++)
{
fprintf(pyfp, "%f", param.af[i]);
if (param.af[i] == (int)param.af[i])
fprintf(pyfp, "%.1f", param.af[i]);
else
fprintf(pyfp, "%g", param.af[i]);
if (i + 1 != param.af.size() || param.af.size() == 1)
fprintf(pyfp, ",");
}
Expand Down
3 changes: 3 additions & 0 deletions tools/pnnx/src/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,9 @@ class Attribute
#if BUILD_ONNX2PNNX
Attribute(const onnx::TensorProto& t);
#endif
#if BUILD_TNN2PNNX
Attribute(FILE* bp);
#endif

Attribute(const std::initializer_list<int>& shape, const std::vector<float>& t);

Expand Down
Loading

0 comments on commit 5ea7681

Please sign in to comment.