Skip to content

Commit

Permalink
w
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Feb 10, 2025
1 parent 4a2e8a5 commit ed0bdd0
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tools/pnnx/src/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2253,11 +2253,11 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
{
if (op->type == "Tensor.index_put" && it.first == "values")
{
fprintf(pyfp, "torch.tensor(%f)", param.f);
fprintf(pyfp, "torch.tensor(%g)", param.f);
}
else
{
fprintf(pyfp, "%f", param.f);
fprintf(pyfp, "%g", param.f);
}
}
if (param.type == 4)
Expand Down Expand Up @@ -2316,7 +2316,7 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
fprintf(pyfp, "(");
for (size_t i = 0; i < param.af.size(); i++)
{
fprintf(pyfp, "%f", param.af[i]);
fprintf(pyfp, "%g", param.af[i]);
if (i + 1 != param.af.size() || param.af.size() == 1)
fprintf(pyfp, ",");
}
Expand Down
47 changes: 47 additions & 0 deletions tools/pnnx/src/pass_level2/F_batch_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,51 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_batch_norm_onnx, 130)

class F_batch_norm_tnn : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
pnnx.Attribute op_0 0 1 weight @data=(%num_features)f32
pnnx.Attribute op_1 0 1 bias @data=(%num_features)f32
tnn.BatchNormCxx op_2 3 1 input weight bias out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* replace_pattern_graph() const
{
return R"PNNXIR(7767517
7 6
pnnx.Input input 0 1 input
pnnx.Attribute mean 0 1 running_mean
pnnx.Attribute var 0 1 running_var
pnnx.Attribute weight 0 1 weight @data=%op_0.data
pnnx.Attribute bias 0 1 bias @data=%op_1.data
F.batch_norm bn 5 1 input running_mean running_var weight bias out
pnnx.Output output 1 0 out
)PNNXIR";
}

void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params) const
{
const int num_features = captured_params.at("num_features").i;

Operator* op_mean = ops.at("mean");
op_mean->attrs["data"] = Attribute({num_features}, std::vector<float>(num_features, 0.f));

Operator* op_var = ops.at("var");
op_var->attrs["data"] = Attribute({num_features}, std::vector<float>(num_features, 1.f));

Operator* op_bn = ops.at("bn");
op_bn->params["eps"] = 0.f;
op_bn->inputnames = {"input", "running_mean", "running_var", "weight", "bias"};
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_batch_norm_tnn, 130)

} // namespace pnnx

0 comments on commit ed0bdd0

Please sign in to comment.