Commit 117ea4e
nihui committed Feb 12, 2025
1 parent fb159d9 commit 117ea4e
Showing 10 changed files with 159 additions and 1 deletion.
2 changes: 2 additions & 0 deletions tools/pnnx/src/ir.cpp
@@ -1030,6 +1030,7 @@ static std::string expand_expression(const Operator* op)
|| t == "ceil"
|| t == "cos"
|| t == "cosh"
|| t == "erf"
|| t == "exp"
|| t == "floor"
|| t == "log"
@@ -1062,6 +1063,7 @@ static std::string expand_expression(const Operator* op)
if (t == "ceil") unaryop = "torch.ceil";
if (t == "cos") unaryop = "torch.cos";
if (t == "cosh") unaryop = "torch.cosh";
if (t == "erf") unaryop = "torch.erf";
if (t == "exp") unaryop = "torch.exp";
if (t == "floor") unaryop = "torch.floor";
if (t == "log") unaryop = "torch.log";
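Note: expand_expression() turns a pnnx.Expression back into Python source, so "erf" must appear both in the recognizer list and in the token-to-function chain above for an expression like erf(@0) to emit a torch.erf(...) call. A minimal table-driven sketch of the same mapping (hypothetical refactor, not part of this commit):

    #include <map>
    #include <string>

    static const std::map<std::string, std::string> unaryop_map = {
        {"ceil", "torch.ceil"}, {"cos", "torch.cos"}, {"cosh", "torch.cosh"},
        {"erf", "torch.erf"}, {"exp", "torch.exp"}, {"floor", "torch.floor"},
        // ... remaining unary tokens
    };

    // usage: auto it = unaryop_map.find(t); if (it != unaryop_map.end()) unaryop = it->second;
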
10 changes: 10 additions & 0 deletions tools/pnnx/src/load_tnn.cpp
@@ -605,6 +605,13 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph)

fprintf(stderr, "model constant %s\n", name.c_str());

if (op_map.find(name) == op_map.end())
{
// FIXME
Attribute skip(bp);
continue;
}

Operator* op_constant = op_map.at(name);

op_constant->attrs["data"] = Attribute(bp);
@@ -616,10 +623,13 @@ int load_tnn(const std::string& tnnpath, Graph& pnnx_graph)
for (Operator* op : pnnx_graph.ops)
{
// unary
if (op->type == "tnn.Erf") op->type = "aten::erf";
if (op->type == "tnn.Log") op->type = "aten::log";
if (op->type == "tnn.ReLU") op->type = "aten::relu";
if (op->type == "tnn.ReLU6") op->type = "aten::relu6";
if (op->type == "tnn.Sigmoid") op->type = "aten::sigmoid";
if (op->type == "tnn.Sqrt") op->type = "aten::sqrt";
if (op->type == "tnn.Tanh") op->type = "aten::tanh";

// binary
if (op->type == "tnn.Add") op->type = "aten::add";
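Two notes on this hunk: constants whose name is absent from op_map are still parsed — Attribute skip(bp) consumes the blob so the read cursor stays aligned — and then dropped with a FIXME; and TNN layer types are renamed in place to their aten:: equivalents. A sketch of the same renaming driven by a lookup table (illustrative, assuming the Operator/Graph types used above):

    #include <string>
    #include <unordered_map>

    static const std::unordered_map<std::string, std::string> tnn_to_aten = {
        {"tnn.Erf", "aten::erf"}, {"tnn.Sqrt", "aten::sqrt"}, {"tnn.Tanh", "aten::tanh"},
        // ... the pre-existing unary and binary entries would join them here
    };

    // inside load_tnn():
    // for (Operator* op : pnnx_graph.ops)
    // {
    //     auto it = tnn_to_aten.find(op->type);
    //     if (it != tnn_to_aten.end())
    //         op->type = it->second; // e.g. tnn.Erf -> aten::erf
    // }
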
28 changes: 28 additions & 0 deletions tools/pnnx/src/pass_level2/F_embedding.cpp
@@ -69,4 +69,32 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_embedding_onnx, 140)

class F_embedding_tnn : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 weight
tnn.Gather op_0 2 1 input weight out arg0=0 arg1=1 arg2=0
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "F.embedding";
}

void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/) const
{
op->params["scale_grad_by_freq"] = false;
op->params["sparse"] = false;
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_embedding_tnn, 140)

} // namespace pnnx
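For reference, the pass matches a two-input tnn.Gather whose arguments are fixed at arg0=0 arg1=1 arg2=0 and rewrites the operator type in place, while write() fills in the two F.embedding defaults. After rewriting, the matched operator would read roughly like this (illustrative rendering; the exact parameter formatting may differ):

    F.embedding op_0 2 1 input weight out scale_grad_by_freq=False sparse=False
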
39 changes: 39 additions & 0 deletions tools/pnnx/src/pass_level2/Tensor_to.cpp
@@ -162,4 +162,43 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to_onnx, 60)

class Tensor_to_tnn : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
tnn.Cast op_0 1 1 input out arg0=%to
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Tensor.to";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
const int to = captured_params.at("to").i;

op->params["non_blocking"] = false;
op->params["copy"] = false;
op->params["memory_format"] = "torch.preserve_format";

if (to == 0) op->params["dtype"] = "torch.float";
if (to == 1) op->params["dtype"] = "torch.half";
if (to == 2) op->params["dtype"] = "torch.int8";
if (to == 3) op->params["dtype"] = "torch.int";
if (to == 4) op->params["dtype"] = "torch.bfloat16";
if (to == 5) op->params["dtype"] = "torch.long";
if (to == 6) op->params["dtype"] = "torch.uint32";
if (to == 8) op->params["dtype"] = "torch.uint8";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to_tnn, 60)

} // namespace pnnx
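The captured %to value is TNN's cast target code; note that no branch handles to == 7, so dtype is simply left unset for that code. A switch-based sketch of the same mapping (illustrative):

    static const char* tnn_cast_to_torch_dtype(int to)
    {
        switch (to)
        {
        case 0: return "torch.float";
        case 1: return "torch.half";
        case 2: return "torch.int8";
        case 3: return "torch.int";
        case 4: return "torch.bfloat16";
        case 5: return "torch.long";
        case 6: return "torch.uint32";
        case 8: return "torch.uint8";
        default: return nullptr; // to == 7 and unknown codes: leave dtype unset
        }
    }
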
2 changes: 1 addition & 1 deletion tools/pnnx/src/pass_level2/torch_full.cpp
@@ -112,7 +112,7 @@ pnnx.Output output 1 0 out
return "torch.full";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const
{
op->params["fill_value"] = ((const float*)captured_attrs.at("value.data").data.data())[0];
op->params["dtype"] = "torch.float";
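This is purely a warning fix: the parameter is unused in this overload, and commenting out its name keeps the virtual signature intact while silencing -Wunused-parameter. Equivalent alternatives (illustrative):

    // (void)captured_params;   // explicit discard in the function body
    // [[maybe_unused]] const std::map<std::string, Parameter>& captured_params   // C++17
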
29 changes: 29 additions & 0 deletions tools/pnnx/src/pass_level2/torch_full_like.cpp
@@ -67,4 +67,33 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_full_like, 20)

class torch_full_like_tnn : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input 0 1 input
tnn.Shape op_0 1 1 input shape
pnnx.Attribute value 0 1 value @data=(1)f32
tnn.ConstantOfShape op_2 2 1 shape value out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torch.full_like";
}

void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const
{
op->params["fill_value"] = ((const float*)captured_attrs.at("value.data").data.data())[0];
op->params["dtype"] = "torch.float";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_full_like_tnn, 20)

} // namespace pnnx
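The @data=(1)f32 constraint in the pattern guarantees the captured attribute carries exactly one float, which write() reads by reinterpreting the raw byte buffer. A slightly more defensive version of that read (a sketch under the same single-float assumption):

    #include <cstring> // std::memcpy

    const Attribute& value = captured_attrs.at("value.data");
    float fill_value = 0.f;
    if (value.data.size() >= sizeof(float))
        std::memcpy(&fill_value, value.data.data(), sizeof(float)); // avoids a potentially unaligned typed read
    op->params["fill_value"] = fill_value;
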
39 changes: 39 additions & 0 deletions tools/pnnx/src/pass_level2/torch_squeeze.cpp
@@ -137,4 +137,43 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_squeeze_onnx_1, 60)

class torch_squeeze_tnn : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
tnn.Squeeze op_0 1 1 input out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torch.squeeze";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
const int dims_count = captured_params.at("op_0.arg0").i;
if (dims_count == 1)
{
op->params["dim"] = captured_params.at("op_0.arg1").i;
}
else
{
std::vector<int> dims(dims_count);
for (int i = 0; i < dims_count; i++)
{
dims[i] = captured_params.at("op_0.arg" + std::to_string(i + 1)).i;
}
op->params["dim"] = dims;
}
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_squeeze_tnn, 60)

} // namespace pnnx
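TNN encodes squeeze axes variadically — arg0 is the axis count and arg1..argN the axes — which the %*=%* wildcard exposes as op_0.arg0, op_0.arg1, and so on. Two illustrative matches:

    tnn.Squeeze op_0 1 1 input out arg0=1 arg1=2           ->  torch.squeeze with dim=2
    tnn.Squeeze op_0 1 1 input out arg0=2 arg1=0 arg2=3    ->  torch.squeeze with dim=(0,3)
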
3 changes: 3 additions & 0 deletions tools/pnnx/src/pass_level3/fuse_expression.cpp
@@ -135,6 +135,7 @@ static bool operand_maybe_tensor(const Operand* operand)
|| op->type == "aten::ceil"
|| op->type == "aten::cos"
|| op->type == "aten::cosh"
|| op->type == "aten::erf"
|| op->type == "aten::exp"
|| op->type == "aten::floor"
|| op->type == "aten::log"
@@ -648,6 +649,7 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s
|| op->type == "aten::ceil"
|| op->type == "aten::cos"
|| op->type == "aten::cosh"
|| op->type == "aten::erf"
|| op->type == "aten::exp"
|| op->type == "aten::floor"
|| op->type == "aten::log"
@@ -888,6 +890,7 @@ void fuse_expression(Graph& graph, const std::set<std::string>& foldable_constan
|| op->type == "aten::cos"
|| op->type == "aten::cosh"
|| op->type == "aten::div"
|| op->type == "aten::erf"
|| op->type == "aten::exp"
|| op->type == "aten::floor"
|| op->type == "aten::floor_divide"
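aten::erf has to be registered in all three places — operand_maybe_tensor(), the recursive fuse_expression() walker, and the top-level foldable-op set — before erf can participate in expression fusion. Illustrative effect: a chain such as x.mul(2).erf(), previously left as separate aten ops, can now fuse into a single pnnx.Expression:

    pnnx.Expression expr_0 1 1 x out expr=erf(mul(@0,2))
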
6 changes: 6 additions & 0 deletions tools/pnnx/src/pass_level5/eval_expression.cpp
@@ -175,6 +175,7 @@ static std::string eval_expression(const Operator* op)
|| t == "ceil"
|| t == "cos"
|| t == "cosh"
|| t == "erf"
|| t == "exp"
|| t == "floor"
|| t == "log"
@@ -262,6 +263,11 @@ static std::string eval_expression(const Operator* op)
float r = cosh(af);
exprstack.push(std::to_string(r));
}
if (t == "erf")
{
float r = erf(af);
exprstack.push(std::to_string(r));
}
if (t == "exp")
{
float r = exp(af);
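With this branch, constant erf subexpressions now fold at conversion time via erf() from <math.h>, mirroring the neighboring cosh/exp branches. For instance:

    // the expression "erf(0.5)" folds to the literal 0.520500 (erf(0.5) ~= 0.5205),
    // since std::to_string renders floats with six decimal places
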
2 changes: 2 additions & 0 deletions tools/pnnx/src/pass_ncnn/expand_expression.cpp
@@ -123,6 +123,7 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx
|| t == "atan"
|| t == "ceil"
|| t == "cos"
|| t == "erf"
|| t == "exp"
|| t == "floor"
|| t == "log"
@@ -154,6 +155,7 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx
if (t == "ceil") op_unary->params["0"] = 3;
if (t == "cos") op_unary->params["0"] = 10;
if (t == "exp") op_unary->params["0"] = 7;
if (t == "erf") fprintf(stderr, "UnaryOp erf not supported yet\n"); // TODO
if (t == "floor") op_unary->params["0"] = 2;
if (t == "log") op_unary->params["0"] = 8;
if (t == "log10") op_unary->params["0"] = 17;
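ncnn's UnaryOp layer has no erf operation type yet, so the pass recognizes the token but only warns, and the emitted UnaryOp gets no explicit op-type param. Once ncnn gains one, the wiring would follow the existing pattern (hypothetical op-type code N, not yet assigned):

    // if (t == "erf") op_unary->params["0"] = N; // pending ncnn UnaryOp erf support
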
