-
Notifications
You must be signed in to change notification settings - Fork 1
常量折叠方案评审
常量折叠是比较基础的 pass 之一,paddle 在 pir/transmorm 模块中实现常量折叠 pass,它会遍历 program 中的 op,若判断出 op 能够进行折叠,则可以从传入的 scope 中获取常量,并进行计算,将计算后的结果存入 scope,并插入 get_prameter 替代折叠的 op。
oneflow 通过 mlir 的 fold 的技术路线实现的常量折叠。 oneflow 实现了 UnaryFold 和 BinaryFold 两种基本的折叠方式。上层 op 根据自己的参数数目直接复用 UnaryFold 和 BinaryFold,并且 op 在 fold 里硬编码自己的 kernel。 这种做法的好处就是设计简单,缺点是扩展性差,需要为 op 硬编码 kernel。
OpFoldResult UnaryFold(MLIRContext* ctx, ArrayRef<Attribute> operands,
const std::function<MaybeTensor(const TensorPtr&)>& f) {
::oneflow::LazyMode::Guard guard{false};
if (!operands.front()) { return {}; } // Important!
const auto attr_dict = operands.front().cast<mlir::DictionaryAttr>();
auto attrs = NamedAttrList(attr_dict);
const auto tensor = support::DenseElementsAttrToTensor(
attr_dict.get("value"), attr_dict.get(OpTrait::IsOpConfCompatible<void>::getDeviceTagAttr()),
attr_dict.get(OpTrait::IsOpConfCompatible<void>::getDeviceNameAttr()));
const auto result = f(tensor).GetPtrOrThrow();
attrs.set("value", support::TensorToDenseElementsAttr(result, ctx));
attrs.set(OpTrait::IsOpConfCompatible<void>::getOpNameAttr(), GenNewVariableOpName(ctx));
attrs.set(OpTrait::TensorSource<void>::getDataTypeAttrName(),
attr_dict.get(OpTrait::TensorSource<void>::getDataTypeAttrName()));
return attrs.getDictionary(ctx);
}
OpFoldResult BinaryFold(MLIRContext* ctx, ArrayRef<Attribute> operands,
const std::function<MaybeTensor(const TensorPtr&, const TensorPtr&)>& f) {
::oneflow::LazyMode::Guard guard{false};
if (!(operands.front() && operands.back())) { return {}; } // Important!
auto lhs_attr_dict = operands.front().cast<mlir::DictionaryAttr>();
auto rhs_attr_dict = operands.back().cast<mlir::DictionaryAttr>();
if (!DictionaryAttrsHaveSameDataType({lhs_attr_dict, rhs_attr_dict})) {
llvm::errs()
<< "Input tensors should have same data type in binary operation of constant folding."
<< "\n";
return nullptr;
}
}
//////////////////////////////////////////
OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
auto operands = adaptor.getOperands();
return UnaryFold(getContext(), operands, [this](const auto& tensor) {
std::vector<int32_t> perm_;
for (auto& x : getPerm().getValue()) { perm_.emplace_back(x.cast<IntegerAttr>().getSInt()); }
return functional::Transpose(tensor, perm_);
});
}
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
auto operands = adaptor.getOperands();
return UnaryFold(getContext(), operands, [this](const auto& tensor) {
std::vector<int64_t> shape_vec;
for (auto& x : getShape().getValue()) {
shape_vec.emplace_back(x.cast<mlir::IntegerAttr>().getValue().getSExtValue());
}
return functional::Reshape(
tensor, ::oneflow::Shape(::oneflow::DimVector(shape_vec.begin(), shape_vec.end())));
});
}
OpFoldResult ScalarAddOp::fold(FoldAdaptor adaptor) {
auto operands = adaptor.getOperands();
return UnaryFold(getContext(), operands, [this](const auto& tensor) -> MaybeTensor {
if (getHasIntOperand()) { return functional::ScalarAdd(tensor, getIntOperand(), 1, false); }
if (getHasFloatOperand()) {
return functional::ScalarAdd(tensor, getFloatOperand().convertToDouble(), 1, false);
}
emitError("Scalar op must has a int operand or a float operand.");
return TensorPtr();
});
}
OpFoldResult SqrtOp::fold(FoldAdaptor adaptor) {
auto operands = adaptor.getOperands();
return UnaryFold(getContext(), operands, functional::Sqrt);
}
OpFoldResult BroadcastMulOp::fold(FoldAdaptor adaptor) {
auto operands = adaptor.getOperands();
return BinaryFold(getContext(), operands, functional::Mul);
}
OpFoldResult BroadcastDivOp::fold(FoldAdaptor adaptor) {
auto operands = adaptor.getOperands();
return BinaryFold(getContext(), operands, functional::Div);
}
OpFoldResult BroadcastSubOp::fold(FoldAdaptor adaptor) {
auto operands = adaptor.getOperands();
return BinaryFold(getContext(), operands, [](const auto& lhs, const auto& rhs) -> MaybeTensor {
return functional::Sub(lhs, rhs, /*alpha=*/1.0, false);
});
}
tvm 构造了一个 dlcontext,并设置设备为 cpu,创建 llvm target。最后构造了一个 Interpreter,执行 Interpreter::Eval 完成计算。引入执行器的好处在于直接复用执行器,编码简单,扩展性好,但是需要引入对执行器的依赖。并且执行器逻辑复杂,会引入一些不必要的计算。
Expr FoldConstant(const Expr& expr, const IRModule& mod) {
using tvm::transform::PassContext;
DLContext ctx;
ctx.device_type = kDLCPU;
ctx.device_id = 0;
Target target = Target::Create("llvm");
// use a fresh build context
// in case we are already in a build context.
With<PassContext> fresh_build_ctx(PassContext::Create());
return ConstantFolder(CreateInterpreter(mod, ctx, target), mod).Mutate(expr);
}
paddle 的执行器已经较为成熟。在常量折叠 pass 中,可以构造临时 program,然后直接引用执行器执行临时 program,并取得计算成果。
- 根据待折叠 op 构建临时 program
- PdOpLowerToKernel,将原 program 替换成带有 kernel 信息的 program
- 创建 InterpreterCore
- 执行 InterpreterCore 并获取 output tensor
- 使用 get_parameter_op 替换原 op 代码如:
void Rewrite(pir::Operation* op,
pir::PatternRewriter& rewriter) const override { // NOLINT
pir::Program new_program(ir_context());
auto output_var_name = BuildProgramFromOperation(op, &new_program);
// execute program
exe_config_->skip_gc_vars.insert(output_var_name);
auto kernel_program =
paddle::dialect::PdOpLowerToKernelPass(&new_program, place_);
paddle::framework::InterpreterCore core(
place_, {}, kernel_program->block(), scope_, *exe_config_);
core.Run({});
// TODO(liuyuanle): support multiple output
auto get_parameter_op = rewriter.Build<pir::GetParameterOp>(
output_var_name, op->result(0).type());
get_parameter_op->set_attribute(
kAttrIsPersisable, rewriter.array_attr({rewriter.bool_attr(true)}));
rewriter.ReplaceAllUsesWith(op->result(0), get_parameter_op->result(0));
rewriter.EraseOp(op);
}
好处:
- 执行器比较成熟,不会踩坑
- 复用执行器代码,易于维护
- 编码简单
坏处:
- 引用执行器,在逻辑上存在循环依赖的风险
- 我们只需要执行一个 op,执行器过于重量级,会带来不必要的计算开销(如,只执行一个 op的话,就没有必要 pd_lower_to_kernel)
paddle 将 op 的元信息编码在 yaml 文件中,并通过 codegen 来生成 op 的定义。其中,codegen 生成的 OpInfoTuple 结构体将 op 的元信息打包,提供给使用者:
OpInfoTuple {op_name}::GetOpInfo() {{
std::vector<paddle::dialect::OpInputInfo> inputs = {{ {inputs} }};
std::vector<paddle::dialect::OpAttributeInfo> attributes = {{ {attributes} }};
std::vector<paddle::dialect::OpOutputInfo> outputs = {{ {outputs} }};
paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, "{kernel_func}", {{"{kernel_param}"}}, {{{kernel_key_dtype}}}, {{{kernel_key_backend}}}, {{{inplace}}}, {{{view}}});
return std::make_tuple(inputs, attributes, outputs, run_time_info, "{origin_op_name}");
}}
所以,我们可以不必引入执行器,而是通过 OpInfoTuple 拿到 op 的所有元信息,进而拿到 kernel 函数指针,进而直接进行计算。
Rewrite 的主要流程:
- 设置 output tensor name,并在 scope 中创建对应的 Variable
- 获取 op_info_parser(持有 OpInfoTuple),并获取meta信息
- 根据上述信息,构建 kernel key
- 在 KernelFactory 中根据 kernel key 找到 kernel_fn(这个 kernel_fn 是使用宏来注册 kernel 时注册到 KernelFactory 中的,它是 kernel 的 warpper)
- 获取 InferMetaInterface,并构建 InferMetaContext
- 执行 InferMetaInterface::infer_meta(InferMetaContext),为 output tensor 设置 meta 信息
- 构建 KernelContext
- 执行 kernel_fn(&kernel_context),并获取 output tensor
- 使用 get_parameter_op 替换原 op 代码如:
void Rewrite(pir::Operation* op,
pir::PatternRewriter& rewriter) const override { // NOLINT
pir::IrContext* ctx = pir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
ctx->GetOrRegisterDialect<paddle::dialect::KernelDialect>();
// 0 设置 output name, build value
std::string op_name = op_item->name();
std::string output_name = "@constant_folding_pass@_" + std::to_string(suffix_++);
paddle::framework::Variable* var = exec_info_->GetScope()->Var(output_name);
phi::DenseTensor * ts = var->GetMutable<phi::DenseTensor>();
exec_info_->Add(op_item->result(0), output_name);
// 1 获取 op_info_parser
paddle::dialect::OpYamlInfoParser op_info_parser(op_item->
dyn_cast<paddle::dialect::OpYamlInfoInterface>().
GetOpInfo());
// 2 获取 kernel name
auto kernel_fn_str = op_info_parser.OpRuntimeInfo().kernel_func;
auto& data_type_info = op_info_parser.OpRuntimeInfo().kernel_key_dtype;
// 3 获取 data type
phi::DataType kernel_data_type = phi::DataType::FLOAT32;//phi::DataType::UNDEFINED;
// 4 设置其它 kernel 信息
phi::Backend kernel_backend = phi::Backend::CPU;
phi::DataLayout kernel_layout = phi::DataLayout::UNDEFINED;
// 5 获取 kernel fn
phi::KernelKey kernel_key = phi::KernelKey(kernel_backend, kernel_layout, kernel_data_type);
auto kernel_fn = phi::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_fn_str, kernel_key).kernel;
// 6 build ctx
pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name);
auto* infer_meta_interface =
op_info.GetInterfaceImpl<paddle::dialect::InferMetaInterface>();
phi::InferMetaContext infer_meta_context;
phi::KernelContext kernel_context;
{
// EmplaceBackInputs
// EmplaceBackAttributes
// EmplaceBackOutputs
// ... 代码先省略了
}
// 7. 设置device ctx
kernel_context.SetDeviceContext(phi::DeviceContextPool::Instance().Get(
phi::TransToPhiPlace(kernel_key.backend())));
// 8 执行
infer_meta_interface->infer_meta_(&infer_meta_context); // 填充 tensor 的 meta 信息
kernel_fn(&kernel_context);
// 9 获取结果
auto out_tensor = exec_info_->GetScope()->
FindVar(output_name)->Get<phi::DenseTensor>();
// 10 替换结果
std::unique_ptr<pir::Parameter> parameter =
std::make_unique<pir::Parameter>(
reinterpret_cast<void*>(out_tensor.data()),
out_tensor.numel() * phi::SizeOf(out_tensor.dtype()),
op_item->result(0).type());
std::cout << "4" << std::endl;
op_item->GetParentProgram()->SetParameter(output_name, std::move(parameter));
auto get_parameter_op =
rewriter.Build<pir::GetParameterOp>(output_name, op_item->result(0).type());
rewriter.ReplaceAllUsesWith(op_item->result(0), get_parameter_op->result(0));
rewriter.EraseOp(op_item);
}
好处:
- 轻量级,没有冗余计算
- 不必引入执行器,消除循环依赖风险
坏处:
- 不能直接复用执行器,所以对于一些特殊算子(如 combine)需要进行特殊处理
- 因为代码逻辑未经过大量测试,所以有踩坑的风险
常量折叠 pass 需要传入 scope,如:
std::unique_ptr<Pass> CreateConstantFoldingPass(
phi::Place& place, paddle::framework::Scope* scope) {
return std::make_unique<ConstantFoldingPass>(place, scope);
}
使用这种方式时,由于与其他pass接口不统一,故无法从python直接根据pass名字来创建pass
pm = pir.PassManager()
pm.add_pass('dead_code_elimination_pass')
pm.run(new_program)