Skip to content

常量折叠方案评审

zhangyuqin1998 edited this page Dec 13, 2023 · 3 revisions

需求介绍

常量折叠是比较基础的 pass 之一,paddle 在 pir/transmorm 模块中实现常量折叠 pass,它会遍历 program 中的 op,若判断出 op 能够进行折叠,则可以从传入的 scope 中获取常量,并进行计算,将计算后的结果存入 scope,并插入 get_prameter 替代折叠的 op。

竞品分析

OneFlow

oneflow 通过 mlir 的 fold 的技术路线实现的常量折叠。 oneflow 实现了 UnaryFold 和 BinaryFold 两种基本的折叠方式。上层 op 根据自己的参数数目直接复用 UnaryFold 和 BinaryFold,并且 op 在 fold 里硬编码自己的 kernel。 这种做法的好处就是设计简单,缺点是扩展性差,需要为 op 硬编码 kernel。

OpFoldResult UnaryFold(MLIRContext* ctx, ArrayRef<Attribute> operands,
                       const std::function<MaybeTensor(const TensorPtr&)>& f) {
  ::oneflow::LazyMode::Guard guard{false};
  if (!operands.front()) { return {}; }  // Important!

  const auto attr_dict = operands.front().cast<mlir::DictionaryAttr>();
  auto attrs = NamedAttrList(attr_dict);
  const auto tensor = support::DenseElementsAttrToTensor(
      attr_dict.get("value"), attr_dict.get(OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr()),
      attr_dict.get(OpTrait::IsOpConfCompatible<void>::getDeviceNameAttr()));
  const auto result = f(tensor).GetPtrOrThrow();
  attrs.set("value", support::TensorToDenseElementsAttr(result, ctx));
  attrs.set(OpTrait::IsOpConfCompatible<void>::getOpNameAttr(), GenNewVariableOpName(ctx));
  attrs.set(OpTrait::TensorSource<void>::getDataTypeAttrName(),
            attr_dict.get(OpTrait::TensorSource<void>::getDataTypeAttrName()));

  return attrs.getDictionary(ctx);
}

OpFoldResult BinaryFold(MLIRContext* ctx, ArrayRef<Attribute> operands,
                        const std::function<MaybeTensor(const TensorPtr&, const TensorPtr&)>& f) {
  ::oneflow::LazyMode::Guard guard{false};
  if (!(operands.front() && operands.back())) { return {}; }  // Important!
  auto lhs_attr_dict = operands.front().cast<mlir::DictionaryAttr>();
  auto rhs_attr_dict = operands.back().cast<mlir::DictionaryAttr>();
  if (!DictionaryAttrsHaveSameDataType({lhs_attr_dict, rhs_attr_dict})) {
    llvm::errs()
        << "Input tensors should have same data type in binary operation of constant folding."
        << "\n";
    return nullptr;
  }
}

//////////////////////////////////////////


OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
  auto operands = adaptor.getOperands();
  return UnaryFold(getContext(), operands, [this](const auto& tensor) {
    std::vector<int32_t> perm_;
    for (auto& x : getPerm().getValue()) { perm_.emplace_back(x.cast<IntegerAttr>().getSInt()); }
    return functional::Transpose(tensor, perm_);
  });
}

OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  auto operands = adaptor.getOperands();
  return UnaryFold(getContext(), operands, [this](const auto& tensor) {
    std::vector<int64_t> shape_vec;
    for (auto& x : getShape().getValue()) {
      shape_vec.emplace_back(x.cast<mlir::IntegerAttr>().getValue().getSExtValue());
    }
    return functional::Reshape(
        tensor, ::oneflow::Shape(::oneflow::DimVector(shape_vec.begin(), shape_vec.end())));
  });
}

OpFoldResult ScalarAddOp::fold(FoldAdaptor adaptor) {
  auto operands = adaptor.getOperands();
  return UnaryFold(getContext(), operands, [this](const auto& tensor) -> MaybeTensor {
    if (getHasIntOperand()) { return functional::ScalarAdd(tensor, getIntOperand(), 1, false); }
    if (getHasFloatOperand()) {
      return functional::ScalarAdd(tensor, getFloatOperand().convertToDouble(), 1, false);
    }
    emitError("Scalar op must has a int operand or a float operand.");
    return TensorPtr();
  });
}

OpFoldResult SqrtOp::fold(FoldAdaptor adaptor) {
  auto operands = adaptor.getOperands();
  return UnaryFold(getContext(), operands, functional::Sqrt);
}

OpFoldResult BroadcastMulOp::fold(FoldAdaptor adaptor) {
  auto operands = adaptor.getOperands();
  return BinaryFold(getContext(), operands, functional::Mul);
}

OpFoldResult BroadcastDivOp::fold(FoldAdaptor adaptor) {
  auto operands = adaptor.getOperands();
  return BinaryFold(getContext(), operands, functional::Div);
}

OpFoldResult BroadcastSubOp::fold(FoldAdaptor adaptor) {
  auto operands = adaptor.getOperands();
  return BinaryFold(getContext(), operands, [](const auto& lhs, const auto& rhs) -> MaybeTensor {
    return functional::Sub(lhs, rhs, /*alpha=*/1.0, false);
  });
}

TVM

tvm 构造了一个 dlcontext,并设置设备为 cpu,创建 llvm target。最后构造了一个 Interpreter,执行 Interpreter::Eval 完成计算。引入执行器的好处在于直接复用执行器,编码简单,扩展性好,但是需要引入对执行器的依赖。并且执行器逻辑复杂,会引入一些不必要的计算。

Expr FoldConstant(const Expr& expr, const IRModule& mod) {
	using tvm::transform::PassContext;
	DLContext ctx;
	ctx.device_type = kDLCPU;
	ctx.device_id = 0;
	Target target = Target::Create("llvm");
	// use a fresh build context
	// in case we are already in a build context.
	With<PassContext> fresh_build_ctx(PassContext::Create());

 return ConstantFolder(CreateInterpreter(mod, ctx, target), mod).Mutate(expr);
}

方案设计

基于执行器进行折叠

总体介绍

paddle 的执行器已经较为成熟。在常量折叠 pass 中,可以构造临时 program,然后直接引用执行器执行临时 program,并取得计算成果。

方案流程

  1. 根据待折叠 op 构建临时 program
  2. PdOpLowerToKernel,将原 program 替换成带有 kernel 信息的 program
  3. 创建 InterpreterCore
  4. 执行 InterpreterCore 并获取 output tensor
  5. 使用 get_parameter_op 替换原 op 代码如:
void Rewrite(pir::Operation* op,
               pir::PatternRewriter& rewriter) const override {  // NOLINT
    pir::Program new_program(ir_context());
    auto output_var_name = BuildProgramFromOperation(op, &new_program);
    
    // execute program
    exe_config_->skip_gc_vars.insert(output_var_name);
    auto kernel_program =
        paddle::dialect::PdOpLowerToKernelPass(&new_program, place_);
    paddle::framework::InterpreterCore core(
        place_, {}, kernel_program->block(), scope_, *exe_config_);
    
    core.Run({});
    
    // TODO(liuyuanle): support multiple output
    auto get_parameter_op = rewriter.Build<pir::GetParameterOp>(
        output_var_name, op->result(0).type());
    get_parameter_op->set_attribute(
        kAttrIsPersisable, rewriter.array_attr({rewriter.bool_attr(true)}));
    
    rewriter.ReplaceAllUsesWith(op->result(0), get_parameter_op->result(0));
    rewriter.EraseOp(op);
}

优缺点分析

好处:

  • 执行器比较成熟,不会踩坑
  • 复用执行器代码,易于维护
  • 编码简单

坏处:

  • 引用执行器,在逻辑上存在循环依赖的风险
  • 我们只需要执行一个 op,执行器过于重量级,会带来不必要的计算开销(如,只执行一个 op的话,就没有必要 pd_lower_to_kernel)

基于kernel进行折叠

总体介绍

paddle 将 op 的元信息编码在 yaml 文件中,并通过 codegen 来生成 op 的定义。其中,codegen 生成的 OpInfoTuple 结构体将 op 的元信息打包,提供给使用者:

OpInfoTuple {op_name}::GetOpInfo() {{
  std::vector<paddle::dialect::OpInputInfo> inputs = {{ {inputs} }};
  std::vector<paddle::dialect::OpAttributeInfo> attributes = {{ {attributes} }};
  std::vector<paddle::dialect::OpOutputInfo> outputs = {{ {outputs} }};
  paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, "{kernel_func}", {{"{kernel_param}"}}, {{{kernel_key_dtype}}}, {{{kernel_key_backend}}}, {{{inplace}}}, {{{view}}});
  return std::make_tuple(inputs, attributes, outputs, run_time_info, "{origin_op_name}");
}}

所以,我们可以不必引入执行器,而是通过 OpInfoTuple 拿到 op 的所有元信息,进而拿到 kernel 函数指针,进而直接进行计算。

方案流程

Rewrite 的主要流程:

  1. 设置 output tensor name,并在 scope 中创建对应的 Variable
  2. 获取 op_info_parser(持有 OpInfoTuple),并获取meta信息
  3. 根据上述信息,构建 kernel key
  4. 在 KernelFactory 中根据 kernel key 找到 kernel_fn(这个 kernel_fn 是使用宏来注册 kernel 时注册到 KernelFactory 中的,它是 kernel 的 warpper)
  5. 获取 InferMetaInterface,并构建 InferMetaContext
  6. 执行 InferMetaInterface::infer_meta(InferMetaContext),为 output tensor 设置 meta 信息
  7. 构建 KernelContext
  8. 执行 kernel_fn(&kernel_context),并获取 output tensor
  9. 使用 get_parameter_op 替换原 op 代码如:
void Rewrite(pir::Operation* op,
               pir::PatternRewriter& rewriter) const override {  // NOLINT
    pir::IrContext* ctx = pir::IrContext::Instance();
    ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
    ctx->GetOrRegisterDialect<paddle::dialect::KernelDialect>();
    
    // 0 设置 output name, build value
    std::string op_name = op_item->name();

    std::string output_name = "@constant_folding_pass@_" + std::to_string(suffix_++);
    paddle::framework::Variable* var = exec_info_->GetScope()->Var(output_name);
    phi::DenseTensor * ts = var->GetMutable<phi::DenseTensor>();
    exec_info_->Add(op_item->result(0), output_name);

    // 1 获取 op_info_parser
    paddle::dialect::OpYamlInfoParser op_info_parser(op_item->
        dyn_cast<paddle::dialect::OpYamlInfoInterface>().
        GetOpInfo());
    // 2 获取 kernel name
    auto kernel_fn_str = op_info_parser.OpRuntimeInfo().kernel_func;
    auto& data_type_info = op_info_parser.OpRuntimeInfo().kernel_key_dtype;

    // 3 获取 data type
    phi::DataType kernel_data_type = phi::DataType::FLOAT32;//phi::DataType::UNDEFINED;

    // 4 设置其它 kernel 信息
    phi::Backend kernel_backend = phi::Backend::CPU;
    phi::DataLayout kernel_layout = phi::DataLayout::UNDEFINED;

    // 5 获取 kernel fn
    phi::KernelKey kernel_key = phi::KernelKey(kernel_backend, kernel_layout, kernel_data_type);
    auto kernel_fn = phi::KernelFactory::Instance().SelectKernelOrThrowError(
       kernel_fn_str, kernel_key).kernel;
    
    // 6 build ctx
    pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name);
    auto* infer_meta_interface =
        op_info.GetInterfaceImpl<paddle::dialect::InferMetaInterface>();

    phi::InferMetaContext infer_meta_context;
    phi::KernelContext kernel_context;
    {
          // EmplaceBackInputs
          // EmplaceBackAttributes
          // EmplaceBackOutputs
          // ... 代码先省略了
    }

    // 7. 设置device ctx
    kernel_context.SetDeviceContext(phi::DeviceContextPool::Instance().Get(
        phi::TransToPhiPlace(kernel_key.backend())));

    // 8 执行
    infer_meta_interface->infer_meta_(&infer_meta_context); // 填充 tensor 的 meta 信息
    kernel_fn(&kernel_context);

    // 9 获取结果
    auto out_tensor = exec_info_->GetScope()->
            FindVar(output_name)->Get<phi::DenseTensor>();

    // 10 替换结果
    std::unique_ptr<pir::Parameter> parameter =
          std::make_unique<pir::Parameter>(
              reinterpret_cast<void*>(out_tensor.data()),
              out_tensor.numel() * phi::SizeOf(out_tensor.dtype()),
              op_item->result(0).type());
    std::cout << "4" << std::endl;
    op_item->GetParentProgram()->SetParameter(output_name, std::move(parameter));
    auto get_parameter_op =
          rewriter.Build<pir::GetParameterOp>(output_name, op_item->result(0).type());
    rewriter.ReplaceAllUsesWith(op_item->result(0), get_parameter_op->result(0));
    rewriter.EraseOp(op_item);
}

优缺点分析

好处:

  • 轻量级,没有冗余计算
  • 不必引入执行器,消除循环依赖风险

坏处:

  • 不能直接复用执行器,所以对于一些特殊算子(如 combine)需要进行特殊处理
  • 因为代码逻辑未经过大量测试,所以有踩坑的风险

其他问题

常量折叠 pass 需要传入 scope,如:

std::unique_ptr<Pass> CreateConstantFoldingPass(
    phi::Place& place, paddle::framework::Scope* scope) {
  return std::make_unique<ConstantFoldingPass>(place, scope);
}

使用这种方式时,由于与其他pass接口不统一,故无法从python直接根据pass名字来创建pass

pm = pir.PassManager()
pm.add_pass('dead_code_elimination_pass')
pm.run(new_program)