convert pnnx clone normal
nihui committed Dec 23, 2021
1 parent 42e7160 commit e0124db
Showing 25 changed files with 669 additions and 38 deletions.
12 changes: 3 additions & 9 deletions tools/pnnx/src/CMakeLists.txt
@@ -174,9 +174,11 @@ set(pnnx_pass_level2_SRCS
    pass_level2/torch_cat.cpp
    pass_level2/torch_chunk.cpp
    pass_level2/torch_clamp.cpp
    pass_level2/torch_clone.cpp
    pass_level2/torch_dequantize.cpp
    pass_level2/torch_flatten.cpp
    pass_level2/torch_mean.cpp
    pass_level2/torch_normal.cpp
    pass_level2/torch_quantize_per_tensor.cpp
    pass_level2/torch_sum.cpp
    pass_level2/torch_split.cpp
@@ -252,7 +254,6 @@ set(pnnx_pass_ncnn_SRCS
    pass_ncnn/F_adaptive_max_pool1d.cpp
    pass_ncnn/F_adaptive_max_pool2d.cpp
    pass_ncnn/F_adaptive_max_pool3d.cpp
    #pass_ncnn/F_alpha_dropout.cpp
    pass_ncnn/F_avg_pool1d.cpp
    pass_ncnn/F_avg_pool2d.cpp
    pass_ncnn/F_avg_pool3d.cpp
@@ -261,12 +262,8 @@ set(pnnx_pass_ncnn_SRCS
    pass_ncnn/F_conv1d.cpp
    pass_ncnn/F_conv2d.cpp
    pass_ncnn/F_conv3d.cpp
    #pass_ncnn/F_dropout.cpp
    #pass_ncnn/F_dropout2d.cpp
    #pass_ncnn/F_dropout3d.cpp
    pass_ncnn/F_elu.cpp
    pass_ncnn/F_embedding.cpp
    #pass_ncnn/F_feature_alpha_dropout.cpp
    pass_ncnn/F_gelu.cpp
    pass_ncnn/F_group_norm.cpp
    pass_ncnn/F_hardsigmoid.cpp
@@ -303,7 +300,6 @@ set(pnnx_pass_ncnn_SRCS
    pass_ncnn/nn_AdaptiveMaxPool1d.cpp
    pass_ncnn/nn_AdaptiveMaxPool2d.cpp
    pass_ncnn/nn_AdaptiveMaxPool3d.cpp
    #pass_ncnn/nn_AlphaDropout.cpp
    pass_ncnn/nn_AvgPool1d.cpp
    pass_ncnn/nn_AvgPool2d.cpp
    pass_ncnn/nn_AvgPool3d.cpp
@@ -318,9 +314,6 @@ set(pnnx_pass_ncnn_SRCS
    pass_ncnn/nn_Conv2d.cpp
    pass_ncnn/nn_Conv3d.cpp
    pass_ncnn/nn_ConvTranspose2d.cpp
    #pass_ncnn/nn_Dropout.cpp
    #pass_ncnn/nn_Dropout2d.cpp
    #pass_ncnn/nn_Dropout3d.cpp
    pass_ncnn/nn_ELU.cpp
    pass_ncnn/nn_Embedding.cpp
    pass_ncnn/nn_GELU.cpp
@@ -364,6 +357,7 @@ set(pnnx_pass_ncnn_SRCS
    pass_ncnn/Tensor_slice.cpp
    pass_ncnn/Tensor_view.cpp
    pass_ncnn/torch_clamp.cpp
    pass_ncnn/torch_clone.cpp
    pass_ncnn/torch_flatten.cpp
    pass_ncnn/torch_mean.cpp
    pass_ncnn/torch_permute.cpp
22 changes: 22 additions & 0 deletions tools/pnnx/src/pass_level2/F_leaky_relu.cpp
@@ -38,4 +38,26 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_leaky_relu, 10)

class F_leaky_relu_1 : public GraphRewriterPass
{
public:
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
4 3
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 negative_slope
aten::leaky_relu_ op_0 2 1 input negative_slope out
pnnx.Output output 1 0 out
)PNNXIR";
    }

    const char* type_str() const
    {
        return "F.leaky_relu";
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_leaky_relu_1, 10)

} // namespace pnnx
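
For reference, a minimal PyTorch-side sketch of code that traces to the in-place aten::leaky_relu_ op matched by the new F_leaky_relu_1 pattern (module, shapes and file name are illustrative, not part of this commit):

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def forward(self, x):
        # inplace=True traces to aten::leaky_relu_ instead of aten::leaky_relu
        return F.leaky_relu(x, negative_slope=0.1, inplace=True)

net = Model().eval()
x = torch.rand(1, 16, 32)
mod = torch.jit.trace(net, x)
mod.save("test_F_leaky_relu.pt")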
51 changes: 51 additions & 0 deletions tools/pnnx/src/pass_level2/torch_clone.cpp
@@ -0,0 +1,51 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class torch_clone : public GraphRewriterPass
{
public:
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
prim::Constant op_0 0 1 memory_format value=%memory_format
aten::clone op_1 2 1 input memory_format out
pnnx.Output output 1 0 out
)PNNXIR";
    }

    const char* type_str() const
    {
        return "torch.clone";
    }

    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
    {
        if (captured_params.at("memory_format").i == 0)
            op->params["memory_format"] = "torch.contiguous_format";
        if (captured_params.at("memory_format").i == 1)
            op->params["memory_format"] = "torch.preserve_format";
        if (captured_params.at("memory_format").i == 2)
            op->params["memory_format"] = "torch.channels_last";
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_clone, 20)

} // namespace pnnx
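
For reference, a minimal sketch of a traced module that produces the aten::clone node with an explicit memory_format constant captured by this pass (module, shape and file name are illustrative):

import torch
import torch.nn as nn

class Model(nn.Module):
    def forward(self, x):
        # clone() records a prim::Constant memory_format operand; the pass maps
        # its 0/1/2 value to torch.contiguous_format / preserve_format / channels_last
        return x.clone(memory_format=torch.contiguous_format)

net = Model().eval()
x = torch.rand(1, 3, 16, 16)
mod = torch.jit.trace(net, x)
mod.save("test_torch_clone.pt")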
67 changes: 67 additions & 0 deletions tools/pnnx/src/pass_level2/torch_normal.cpp
@@ -0,0 +1,67 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class torch_normal : public GraphRewriterPass
{
public:
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
6 5
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 mean
pnnx.Input input_2 0 1 std
pnnx.Input input_3 0 1 generator
aten::normal op_0 4 1 input mean std generator out
pnnx.Output output 1 0 out
)PNNXIR";
    }

    const char* type_str() const
    {
        return "torch.normal";
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_normal, 20)

class torch_normal_1 : public GraphRewriterPass
{
public:
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
6 5
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 mean
pnnx.Input input_2 0 1 std
pnnx.Input input_3 0 1 generator
aten::normal_ op_0 4 1 input mean std generator out
pnnx.Output output 1 0 out
)PNNXIR";
    }

    const char* type_str() const
    {
        return "torch.normal";
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_normal_1, 20)

} // namespace pnnx
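
For reference, a minimal sketch of in-place resampling that traces to the aten::normal_ op matched by torch_normal_1 (module, shape and file name are illustrative); check_trace is disabled because the output is random:

import torch
import torch.nn as nn

class Model(nn.Module):
    def forward(self, x):
        # in-place resampling traces to aten::normal_
        return x.normal_(mean=0.0, std=1.0)

net = Model().eval()
x = torch.rand(1, 3, 16)
mod = torch.jit.trace(net, x, check_trace=False)
mod.save("test_torch_normal.pt")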
12 changes: 11 additions & 1 deletion tools/pnnx/src/pass_ncnn/convert_torch_chunk.cpp
@@ -45,10 +45,20 @@ void convert_torch_chunk(Graph& graph)
            axis = input_rank + axis;
        }

        int chunks = op->params.at("chunks").i;

        if (!op->inputs[0]->shape.empty())
        {
            int size = op->inputs[0]->shape[axis];
            if (size % chunks != 0)
            {
                fprintf(stderr, "chunk with non-perfect divided size %d / %d is not supported\n", size, chunks);
            }
        }

        if (axis > batch_index)
            axis -= 1;

        int chunks = op->params.at("chunks").i;
        op->params["0"].type = 5;
        op->params["0"].ai.resize(chunks, -233);

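
The added check warns when the chunked axis is not evenly divisible by chunks, since the ncnn Slice op is filled with the equal-split placeholder -233 above. A quick illustration of such a case (shapes are illustrative):

import torch

x = torch.rand(1, 10, 16)
# 10 split into 3 chunks along dim=1 gives pieces of 4, 4 and 2 elements,
# which an equal split cannot express; the new check reports this case
pieces = torch.chunk(x, chunks=3, dim=1)
print([tuple(p.shape) for p in pieces])  # [(1, 4, 16), (1, 4, 16), (1, 2, 16)]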
31 changes: 30 additions & 1 deletion tools/pnnx/src/pass_ncnn/nn_Hardtanh.cpp
@@ -14,6 +14,8 @@

#include "pass_ncnn.h"

#include <float.h>

namespace pnnx {

namespace ncnn {
@@ -26,7 +28,7 @@ class nn_Hardtanh : public GraphRewriterPass
        return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.Hardtanh op_0 1 1 input out min_val=%0 max_val=%1
nn.Hardtanh op_0 1 1 input out min_val=%min_val max_val=%max_val
pnnx.Output output 1 0 out
)PNNXIR";
    }
@@ -40,6 +42,33 @@ pnnx.Output output 1 0 out
    {
        return "htanh";
    }

    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
    {
        float min = -FLT_MAX;
        float max = FLT_MAX;

        if (captured_params.at("min_val").type == 2)
        {
            min = captured_params.at("min_val").i;
        }
        if (captured_params.at("min_val").type == 3)
        {
            min = captured_params.at("min_val").f;
        }

        if (captured_params.at("max_val").type == 2)
        {
            max = captured_params.at("max_val").i;
        }
        if (captured_params.at("max_val").type == 3)
        {
            max = captured_params.at("max_val").f;
        }

        op->params["0"] = min;
        op->params["1"] = max;
    }
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Hardtanh, 20)
31 changes: 30 additions & 1 deletion tools/pnnx/src/pass_ncnn/torch_clamp.cpp
@@ -14,6 +14,8 @@

#include "pass_ncnn.h"

#include <float.h>

namespace pnnx {

namespace ncnn {
@@ -26,7 +28,7 @@ class torch_clamp : public GraphRewriterPass
        return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
torch.clamp op_0 1 1 input out min=%0 max=%1
torch.clamp op_0 1 1 input out min=%min max=%max
pnnx.Output output 1 0 out
)PNNXIR";
    }
@@ -40,6 +42,33 @@ pnnx.Output output 1 0 out
    {
        return "clamp";
    }

    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
    {
        float min = -FLT_MAX;
        float max = FLT_MAX;

        if (captured_params.at("min").type == 2)
        {
            min = captured_params.at("min").i;
        }
        if (captured_params.at("min").type == 3)
        {
            min = captured_params.at("min").f;
        }

        if (captured_params.at("max").type == 2)
        {
            max = captured_params.at("max").i;
        }
        if (captured_params.at("max").type == 3)
        {
            max = captured_params.at("max").f;
        }

        op->params["0"] = min;
        op->params["1"] = max;
    }
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_clamp, 20)
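
The %min/%max captures can arrive as ints or floats depending on how the bounds are written in Python; a minimal sketch of the integer-bound case the new write() handles (module, shape and file name are illustrative):

import torch
import torch.nn as nn

class Model(nn.Module):
    def forward(self, x):
        # integer literals are typically captured as int params (type 2),
        # float literals as float params (type 3); write() accepts either
        return torch.clamp(x, min=-1, max=1)

net = Model().eval()
mod = torch.jit.trace(net, torch.rand(1, 3, 16))
mod.save("test_torch_clamp.pt")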
53 changes: 53 additions & 0 deletions tools/pnnx/src/pass_ncnn/torch_clone.cpp
@@ -0,0 +1,53 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_ncnn.h"

namespace pnnx {

namespace ncnn {

class torch_clone : public GraphRewriterPass
{
public:
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
torch.clone op_0 1 1 input out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
    }

    const char* type_str() const
    {
        return "Noop";
    }

    const char* name_str() const
    {
        return "clone";
    }

    void write(Operator* /*op*/, const std::map<std::string, Parameter>& /*captured_params*/) const
    {
    }
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_clone, 20)

} // namespace ncnn

} // namespace pnnx
1 change: 1 addition & 0 deletions tools/pnnx/tests/CMakeLists.txt
@@ -166,6 +166,7 @@ pnnx_add_test(torch_argmin)
pnnx_add_test(torch_cat)
pnnx_add_test(torch_chunk)
pnnx_add_test(torch_clamp)
pnnx_add_test(torch_clone)
pnnx_add_test(torch_flatten)
pnnx_add_test(torch_mean)
pnnx_add_test(torch_permute)