Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cherry-pick]C++ support register pass via PassDesc #36302

Conversation

Avin0323
Copy link
Contributor

@Avin0323 Avin0323 commented Oct 9, 2021

PR types

New features

PR changes

Others

Describe

(cherry picked from PR #36095)

PR主要功能:支持C++开发注册GeneratePass,简化针对fusion等子图优化场景开发方式。

背景

#35602 #35708 提供Python侧开发子图替换类Pass的方式:

  • 利用Paddle Python API或者辅助类型定义子图program用来匹配/替换图;
  • Python侧注册Pass时,将注册函数最终转换为protobuf定义的PassDesc数据形式,供C++侧进行解析完成Pass实例注册。

本PR在C++侧提供类似Python开发注册GeneratePass的API。

方案设计

定义匹配/替换子图

Python侧支持使用两种方式用于开发Pass,类似定义program的方式来定义匹配/替换子图:

  1. 直接使用Paddle Python API;
  2. 使用辅助类型;

Python侧定义子图通过定义一个函数(或lambda)完成,函数参数为子图的输入数据,返回值为输出数据,一个示例如下:

// 使用Paddle Python API方式
def pattern(x, y1, y2):
        mul1 = paddle.matmul(x, y1)
        mul2 = paddle.matmul(x, y2)
        return mul1, mul2

// 使用辅助类型方式
def pattern(x, y1, y2):
        mul1 = ir.PassDesc.OP.matmul_v2(X=x, Y=y1)
        mul2 = ir.PassDesc.OP.matmul_v2(X=x, Y=y2)
        return mul1, mul2

Python侧可以直接生成子图ProgramDesc,而C++中需要实现相同的功能使用的API较为复杂,因此方案采用如下方式实现类似Python侧的子图定义功能

  • 使用VarHelperOpHelperSubgraphHelper三个辅助类型完成定义子图转换ProgramDesc功能;
  • 使用VAR_OP_SUBGRAPH_宏创建对应上述三个辅助类型对象,从而尽可能简化代码量及内部细节;
  • 使用lambda表达式定义子图,并捕获表达式需要绑定的子图对象。

C++上定义子图示例如下:

SUBGRAPH_(pattern) = [subgraph = &pattern](VAR_(x), VAR_(y), VAR_(z)) {
  auto ewadd1 = OP_(elementwise_add)({{"X", x}, {"Y", y}}).Out("Out");
  auto ewadd2 = OP_(elementwise_add)({{"X", ewadd1}, {"Y", z}}).Out("Out");
  return ewadd2;
};
SUBGRAPH_(replace) = [subgraph = &replace](VAR_(x), VAR_(y), VAR_(z)) {
  return OP_(sum)({"X", {x, y, z}}).Out("Out");
};

注册GeneratePass

Python侧Pass注册使用装饰器RegisterPass完成,将注册函数传递到C++中,在获取Pass实例时调用获取protobuf序列化数据,注册代码示例如下:

@ir.RegisterPass
def generate_add_n():
    def pattern(x, y, z):
        return paddle.add(paddle.add(x, y), z)
    def replace(x, y, z):
        return paddle.add_n([x, y, z])
    return pattern, replace

C++中使用类似的注册方式,使用宏REGISTER_GENERATE_PASS完成Pass的注册,其参数表示该pass的类型,代码示例如下:

REGISTER_GENERATE_PASS(generate_multi_add_to_addn) {
  // pattern
  SUBGRAPH_(pattern) = [subgraph = &pattern](VAR_(x), VAR_(y), VAR_(z)) {
    auto ewadd1 = OP_(elementwise_add)({{"X", x}, {"Y", y}}).Out("Out");
    auto ewadd2 = OP_(elementwise_add)({{"X", ewadd1}, {"Y", z}}).Out("Out");
    return ewadd2;
  };
  // replace
  SUBGRAPH_(replace) = [subgraph = &replace](VAR_(x), VAR_(y), VAR_(z)) {
    return OP_(sum)({"X", {x, y, z}}).Out("Out");
  };
  return {pattern, replace};

@paddle-bot-old
Copy link

paddle-bot-old bot commented Oct 9, 2021

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

Copy link
Member

@zhhsplendid zhhsplendid left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@lanxianghit lanxianghit merged commit 21c65f6 into PaddlePaddle:release/2.2 Oct 11, 2021
@Avin0323 Avin0323 deleted the cherry-pick-cc-register-generate-pass branch October 15, 2021 03:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants