Skip to content

Commit

Permalink
w
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Feb 10, 2025
1 parent a3c33b8 commit ca57ab1
Show file tree
Hide file tree
Showing 4 changed files with 295 additions and 0 deletions.
77 changes: 77 additions & 0 deletions tools/pnnx/src/pass_level2/F_adaptive_avg_pool2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,81 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_adaptive_avg_pool2d_onnx, 120)

class F_adaptive_avg_pool2d_tnn : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
tnn.Pooling op_0 1 1 input out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "F.adaptive_avg_pool2d";
}

bool match(const std::map<std::string, Parameter>& captured_params) const
{
const int pool_type = captured_params.at("op_0.arg0").i;
if (pool_type != 1)
return false;

const int pad_h = captured_params.at("op_0.arg5").i;
const int pad_w = captured_params.at("op_0.arg6").i;
if (pad_h != 0 || pad_w != 0)
return false;

const int kernel_h = captured_params.at("op_0.arg1").i;
const int kernel_w = captured_params.at("op_0.arg2").i;
if (kernel_h == 0 && kernel_w == 0)
return true;

if (captured_params.find("op_0.arg11") != captured_params.end())
{
const int is_adaptive = captured_params.at("op_0.arg11").i;
if (is_adaptive == 1)
return true;
}

return false;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
// captured_params.at("op_0.arg0"); // pool_type
const int kernel_h = captured_params.at("op_0.arg1").i;
const int kernel_w = captured_params.at("op_0.arg2").i;
if (kernel_h == 0 && kernel_w == 0)
{
// global pooling
op->params["output_size"] = {1, 1};
}

// captured_params.at("op_0.arg3"); // stride_h
// captured_params.at("op_0.arg4"); // stride_w
// captured_params.at("op_0.arg5"); // pad_h
// captured_params.at("op_0.arg6"); // pad_w
// captured_params.at("op_0.arg7"); // kernel_index_h
// captured_params.at("op_0.arg8"); // kernel_index_w
// captured_params.at("op_0.arg9"); // pad_type
// captured_params.at("op_0.arg10"); // ceil_mode

if (captured_params.find("op_0.arg11") != captured_params.end())
{
const int is_adaptive = captured_params.at("op_0.arg11").i;
if (is_adaptive == 1)
{
op->params["output_size"] = {captured_params.at("op_0.arg12").i, captured_params.at("op_0.arg13").i};
}
}
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_adaptive_avg_pool2d_tnn, 120)

} // namespace pnnx
77 changes: 77 additions & 0 deletions tools/pnnx/src/pass_level2/F_adaptive_max_pool2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,81 @@ pnnx.Output output 2 0 out indices

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_adaptive_max_pool2d, 120)

class F_adaptive_max_pool2d_tnn : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
tnn.Pooling op_0 1 1 input out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "F.adaptive_max_pool2d";
}

bool match(const std::map<std::string, Parameter>& captured_params) const
{
const int pool_type = captured_params.at("op_0.arg0").i;
if (pool_type != 0)
return false;

const int pad_h = captured_params.at("op_0.arg5").i;
const int pad_w = captured_params.at("op_0.arg6").i;
if (pad_h != 0 || pad_w != 0)
return false;

const int kernel_h = captured_params.at("op_0.arg1").i;
const int kernel_w = captured_params.at("op_0.arg2").i;
if (kernel_h == 0 && kernel_w == 0)
return true;

if (captured_params.find("op_0.arg11") != captured_params.end())
{
const int is_adaptive = captured_params.at("op_0.arg11").i;
if (is_adaptive == 1)
return true;
}

return false;
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
// captured_params.at("op_0.arg0"); // pool_type
const int kernel_h = captured_params.at("op_0.arg1").i;
const int kernel_w = captured_params.at("op_0.arg2").i;
if (kernel_h == 0 && kernel_w == 0)
{
// global pooling
op->params["output_size"] = {1, 1};
}

// captured_params.at("op_0.arg3"); // stride_h
// captured_params.at("op_0.arg4"); // stride_w
// captured_params.at("op_0.arg5"); // pad_h
// captured_params.at("op_0.arg6"); // pad_w
// captured_params.at("op_0.arg7"); // kernel_index_h
// captured_params.at("op_0.arg8"); // kernel_index_w
// captured_params.at("op_0.arg9"); // pad_type
// captured_params.at("op_0.arg10"); // ceil_mode

if (captured_params.find("op_0.arg11") != captured_params.end())
{
const int is_adaptive = captured_params.at("op_0.arg11").i;
if (is_adaptive == 1)
{
op->params["output_size"] = {captured_params.at("op_0.arg12").i, captured_params.at("op_0.arg13").i};
}
}
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_adaptive_max_pool2d_tnn, 120)

} // namespace pnnx
71 changes: 71 additions & 0 deletions tools/pnnx/src/pass_level2/F_avg_pool2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,4 +205,75 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_avg_pool2d_onnx, 120)

class F_avg_pool2d_tnn : public GraphRewriterPass
{
public:
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
tnn.Pooling op_0 1 1 input out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
    }

    const char* type_str() const
    {
        return "F.avg_pool2d";
    }

    bool match(const std::map<std::string, Parameter>& captured_params) const
    {
        // arg0 is the pool type; 1 selects average pooling
        if (captured_params.at("op_0.arg0").i != 1)
            return false;

        // a 0x0 kernel (arg1/arg2) is global pooling, handled by the adaptive pass
        if (captured_params.at("op_0.arg1").i == 0 && captured_params.at("op_0.arg2").i == 0)
            return false;

        // reject the adaptive form when the is_adaptive flag (arg11) is present and set
        const auto it = captured_params.find("op_0.arg11");
        if (it != captured_params.end() && it->second.i != 0)
            return false;

        return true;
    }

    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
    {
        // arg0 (pool_type) already checked in match()
        const int kernel_h = captured_params.at("op_0.arg1").i;
        const int kernel_w = captured_params.at("op_0.arg2").i;
        const int stride_h = captured_params.at("op_0.arg3").i;
        const int stride_w = captured_params.at("op_0.arg4").i;
        const int pad_h = captured_params.at("op_0.arg5").i;
        const int pad_w = captured_params.at("op_0.arg6").i;

        op->params["kernel_size"] = {kernel_h, kernel_w};
        op->params["stride"] = {stride_h, stride_w};
        op->params["padding"] = {pad_h, pad_w};

        // arg7/arg8 are kernel_index_h/kernel_index_w; -1 is the only supported value
        const int kernel_index_h = captured_params.at("op_0.arg7").i;
        const int kernel_index_w = captured_params.at("op_0.arg8").i;
        if (kernel_index_h != -1 || kernel_index_w != -1)
        {
            fprintf(stderr, "unsupported F.avg_pool2d kernel_index %d %d\n", kernel_index_h, kernel_index_w);
        }

        // arg9 is pad_type; only the default (<= 0) is supported
        const int pad_type = captured_params.at("op_0.arg9").i;
        if (pad_type > 0)
        {
            fprintf(stderr, "unsupported F.avg_pool2d pad_type %d\n", pad_type);
        }

        // arg10 is ceil_mode as an int flag
        op->params["ceil_mode"] = captured_params.at("op_0.arg10").i != 0;

        // arg11 (is_adaptive) / arg12 (output_h) / arg13 (output_w) do not
        // apply to the non-adaptive form

        // TNN Pooling has no equivalent of these torch options; emit defaults
        op->params["count_include_pad"] = false;
        op->params["divisor_override"] = Parameter();
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_avg_pool2d_tnn, 120)

} // namespace pnnx
70 changes: 70 additions & 0 deletions tools/pnnx/src/pass_level2/F_max_pool2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,4 +298,74 @@ pnnx.Output output 2 0 out indices

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_max_pool2d_onnx_1, 120)

class F_max_pool2d_tnn : public GraphRewriterPass
{
public:
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
tnn.Pooling op_0 1 1 input out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
    }

    const char* type_str() const
    {
        return "F.max_pool2d";
    }

    bool match(const std::map<std::string, Parameter>& captured_params) const
    {
        // arg0 is the pool type; 0 selects max pooling
        const int pool_type = captured_params.at("op_0.arg0").i;
        if (pool_type != 0)
            return false;

        // a 0x0 kernel (arg1/arg2) is global pooling, handled by the adaptive pass
        const int kernel_h = captured_params.at("op_0.arg1").i;
        const int kernel_w = captured_params.at("op_0.arg2").i;
        if (kernel_h == 0 && kernel_w == 0)
            return false;

        // reject the adaptive form when the is_adaptive flag (arg11) is present and set
        if (captured_params.find("op_0.arg11") != captured_params.end())
        {
            const int is_adaptive = captured_params.at("op_0.arg11").i;
            if (is_adaptive != 0)
                return false;
        }

        return true;
    }

    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
    {
        // captured_params.at("op_0.arg0"); // pool_type (checked in match)
        op->params["kernel_size"] = {captured_params.at("op_0.arg1").i, captured_params.at("op_0.arg2").i};
        op->params["stride"] = {captured_params.at("op_0.arg3").i, captured_params.at("op_0.arg4").i};
        op->params["padding"] = {captured_params.at("op_0.arg5").i, captured_params.at("op_0.arg6").i};

        // arg7/arg8 are kernel_index_h/kernel_index_w; -1 is the only supported value
        const int kernel_index_h = captured_params.at("op_0.arg7").i;
        const int kernel_index_w = captured_params.at("op_0.arg8").i;
        if (kernel_index_h != -1 || kernel_index_w != -1)
        {
            // BUGFIX: diagnostic previously said F.avg_pool2d (copy-paste error)
            fprintf(stderr, "unsupported F.max_pool2d kernel_index %d %d\n", kernel_index_h, kernel_index_w);
        }

        // arg9 is pad_type; only the default (<= 0) is supported
        const int pad_type = captured_params.at("op_0.arg9").i;
        if (pad_type > 0)
        {
            // BUGFIX: diagnostic previously said F.avg_pool2d (copy-paste error)
            fprintf(stderr, "unsupported F.max_pool2d pad_type %d\n", pad_type);
        }

        // arg10 is ceil_mode as an int flag
        op->params["ceil_mode"] = captured_params.at("op_0.arg10").i ? true : false;
        // captured_params.at("op_0.arg11"); // is_adaptive
        // captured_params.at("op_0.arg12"); // output_h
        // captured_params.at("op_0.arg13"); // output_w

        // TNN Pooling never returns indices
        op->params["return_indices"] = false;
    }
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_max_pool2d_tnn, 120)

} // namespace pnnx

0 comments on commit ca57ab1

Please sign in to comment.