Skip to content

Commit

Permalink
w
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Feb 11, 2025
1 parent 0af0a57 commit 4dc0e2c
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 8 deletions.
16 changes: 8 additions & 8 deletions tools/pnnx/src/pass_level2/torch_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class torch_matmul_tnn : public GraphRewriterPass
4 3
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 other
tnn.MatMul op_0 2 1 input other out arg0=%weight_position
tnn.MatMul op_0 2 1 input other out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
}
Expand All @@ -81,16 +81,16 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
const int weight_position = captured_params.at("weight_position").i;
if (weight_position == 0)
if (captured_params.find("op_0.arg0") != captured_params.end())
{
// swap input and weight
std::swap(op->inputs[0], op->inputs[1]);
const int weight_position = captured_params.at("op_0.arg0").i;
if (weight_position == 0)
{
// swap input and weight
std::swap(op->inputs[0], op->inputs[1]);
}
}
}

protected:
mutable int weight_position;
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_matmul_tnn, 90)
Expand Down
36 changes: 36 additions & 0 deletions tools/pnnx/src/pass_level2/torch_sum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,40 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_sum_onnx, 50)

class torch_sum_tnn : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
tnn.ReduceSum op_0 1 1 input out %*=%*
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "torch.sum";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
std::vector<int> dim;
for (int i = 1;; i++)
{
if (captured_params.find("op_0.arg" + std::to_string(i)) == captured_params.end())
break;

dim.push_back(captured_params.at("op_0.arg" + std::to_string(i)).i);
}

op->params["dim"] = dim;
op->params["keepdim"] = captured_params.at("op_0.arg0").i ? true : false;
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_sum_tnn, 50)

} // namespace pnnx

0 comments on commit 4dc0e2c

Please sign in to comment.