From f0f80cc90788a69799ff61a0f51c0280611b1aef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=85=AD=E4=B8=AA=E9=AA=A8=E5=A4=B4?= <46243324+zrr1999@users.noreply.github.com> Date: Tue, 16 Aug 2022 11:16:55 +0800 Subject: [PATCH] add CINN cast rfc docs (#205) --- .../CINN/APIs/20220811_api_design_for_cast.md | 241 ++++++++++++++++++ 1 file changed, 241 insertions(+) create mode 100644 rfcs/CINN/APIs/20220811_api_design_for_cast.md diff --git a/rfcs/CINN/APIs/20220811_api_design_for_cast.md b/rfcs/CINN/APIs/20220811_api_design_for_cast.md new file mode 100644 index 000000000..568169075 --- /dev/null +++ b/rfcs/CINN/APIs/20220811_api_design_for_cast.md @@ -0,0 +1,241 @@ +# CINN cast 设计文档 + +| API名称 | cast | +| ---------------------------------------------------------- | ------------------------------------------------ | +| 提交作者 | 六个骨头 | +| 提交时间 | 2022-08-11 | +| 版本号 | V1.0 | +| 依赖CINN版本 | develop | +| 文件名 | 20220729_api_design_for_cast.md
| + +# 一、概述 + +## 1、相关背景 + +`cast` 是众多神经网络编译器中基础的算子。转化输入的数据类型,例如将Int(32)变为Float(32)。 +为了提升 CINN API 丰富度,需要扩充 API `cast`。 + +## 2、名词解释 + +- 张量/Tensor:指高维数组。 +- cast:转化输入的数据类型。 +- dtype:数据类型。 + +## 3、功能目标 + +实现cast功能,将输入转换为指定数据类型。例如,对于张量 $A$ = [1, 2, 3], +cast( $A$, dtype = Float(32) 结果为 $[1.0f, 2.0f, 3.0f]$。 + +## 4、意义 + +为神经网络编译器 CINN 增加基础算子`cast`。 + +# 二、CINN现状 + +对CINN框架目前可以调用底层ir实现,即ir::Cast::Make,但没有相应的hlir实现,因此有必要实现 `cast` hlir API。 + +# 三、业内方案调研 + +- [TVM](https://github.com/apache/tvm/blob/main/src/relay/transforms/canonicalize_cast.cc): +对张量中每个元素使用更底层的ir进行类型转化。 + + ```cpp + class CastCanonicalizer : public ExprMutator { + public: + CastCanonicalizer() : cast_op_(Op::Get("cast")) {} + + Expr VisitExpr_(const CallNode* call) { + static auto fpattern = Op::GetAttrMap("TOpPattern"); + + if (const OpNode* opnode = call->op.as()) { + auto pattern = fpattern[GetRef(opnode)]; + if (pattern <= kBroadcast) { + Array call_args = call->args; + bool unchanged = true; + for (size_t i = 0; i < call_args.size(); ++i) { + Expr arg = call_args[i]; + Expr new_arg = GetNewCallArg(arg); + if (!arg.same_as(new_arg)) { + call_args.Set(i, new_arg); + unchanged = false; + } + } + if (unchanged) { + return GetRef(call); + } + return Call(call->op, call_args, call->attrs, call->type_args); + } + } + + Expr new_expr = ExprMutator::VisitExpr_(call); + return new_expr; + } + + private: + std::unordered_map ref_counter_; + // cast op is frequently checked for equivalence. Therefore, we cache it to + // reduce lookup overhead. + const Op& cast_op_; + + Expr GetNewCallArg(const Expr& e) { + // if e is a upcast and ref count > 1, create an copy; otherwise call the default visitor + Expr new_expr = this->VisitExpr(e); + + if (const CallNode* call = e.as()) { + if (call->op == cast_op_) { + auto attrs = call->attrs.as(); + const auto* from_type = call->args[0]->type_as(); + ICHECK(from_type); + + if (from_type->dtype.bits() < attrs->dtype.bits()) { + if (++ref_counter_[call] > 1) { + const CallNode* new_call = new_expr.as(); + ICHECK(new_call); + ICHECK(new_call->op == cast_op_); + return Call(new_call->op, new_call->args, new_call->attrs, new_call->type_args); + } + } + } + } + return new_expr; + } +}; + +Expr CanonicalizeCast(const Expr& e) { return CastCanonicalizer().Mutate(e); } + +namespace transform { + +Pass CanonicalizeCast() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(CanonicalizeCast(f)); + }; + return CreateFunctionPass(pass_func, 3, "CanonicalizeCast", {"InferType"}); +} + +TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast").set_body_typed(CanonicalizeCast); + +} // namespace transform + + ``` + +- [XLA](https://github.com/pytorch/xla/blob/3d24d955b6121289a3c8bb86eda541fca7a0d69f/torch_xla/csrc/ops/cast.cpp):与TVM类似。 + +```cpp +namespace torch_xla { +namespace { + +xla::Shape NodeOutputShape(const torch::lazy::Value& input, + xla::PrimitiveType type) { + xla::Shape shape = GetXlaShape(input); + shape.set_element_type(type); + return shape; +} + +} // namespace + +Cast::Cast(const torch::lazy::Value& input, xla::PrimitiveType type) + : XlaNode(xla_cast, {input}, NodeOutputShape(input, type), + /*num_outputs=*/1, torch::lazy::MHash(static_cast(type))), + type_(type) {} + +Cast::Cast(const torch::lazy::Value& input, at::ScalarType dtype, + c10::optional stype) + : XlaNode(xla_cast, {input}, + NodeOutputShape(input, + MakeXlaPrimitiveType(dtype, /*device=*/nullptr)), + /*num_outputs=*/1, + torch::lazy::MHash(101, static_cast(dtype), + torch::lazy::OptionalOr(stype, -1))), + type_(MakeXlaPrimitiveType(dtype, /*device=*/nullptr)), + dtype_(dtype), + stype_(stype) {} + +torch::lazy::NodePtr Cast::Clone(torch::lazy::OpList operands) const { + return dtype_ ? torch::lazy::MakeNode(operands.at(0), *dtype_, stype_) + : torch::lazy::MakeNode(operands.at(0), type_); +} + +XlaOpVector Cast::Lower(LoweringContext* loctx) const { + xla::XlaOp input = loctx->GetOutputOp(operand(0)); + const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); + xla::PrimitiveType raw_from = + stype_ ? TensorTypeToRawXlaType(*stype_) : input_shape.element_type(); + xla::PrimitiveType raw_to = dtype_ ? TensorTypeToRawXlaType(*dtype_) : type_; + xla::XlaOp output = + ConvertToRaw(input, input_shape.element_type(), raw_from, type_, raw_to, + /*device=*/nullptr); + return ReturnOp(output, loctx); +} + +std::string Cast::ToString() const { + std::stringstream ss; + ss << XlaNode::ToString() + << ", type=" << xla::primitive_util::LowercasePrimitiveTypeName(type_); + if (dtype_) { + ss << ", dtype=" << *dtype_; + } + if (stype_) { + ss << ", stype=" << *stype_; + } + return ss.str(); +} + +} // namespace torch_xla +``` + +# 四、对比分析 + +TVM 与 XLA 实现方案类似。 + +# 五、设计思路与实现方案 + +## 命名与参数设计 + +- A:输入张量 +- dtype:指定数据类型 +- name:输出名称 + +## 底层OP设计 + +1. 在 `cinn/hlir/op/contrib/cast.h` 里声明`cast`算子。 +2. 在 `cinn/hlir/op/contrib/cast.cc` 里实现`cast`算子和 `strategy`。 + +## API实现方案 + +例如,对于张量 A = [1, 1, 1], +cast( A, dtype = Float(32)) 结果为[1.0f, 1.0f, 1.0f], +cast( A, dtype = bool) 结果为[True, True, True]。 + +1. 在 `cinn/frontend/net_build.h` 里声明 `BaseBuilder::Cast`。 +2. 在 `cinn/frontend/net_build.cc` 里实现 `BaseBuilder::Cast`。 +3. 在 `cinn/pybind/frontend` 对 Python 类 `BaseBuilder` 添加 `cast` 接口,并绑定到`BaseBuilder::Cast`。 +4. 上层 `load_paddle_model` 调用提交到 `cinn/frontend/paddle_model_to_program.h` 和 `.cc` 文件下。 + +通过使用 Builder 类的方法调用 Cast。 + +```python +builder = NetBuilder("test_basic") +a = builder.create_input(Int(32), (8, 24, 124), "A1") +b = builder.cast(a, Float(32)) # 输出类型变为Float(32) +``` + +# 六、测试和验收的考量 + +1. 在`cinn/hlir/op/contrib/cast_test.cc`和`cinn/hlir/op/contrib/cast_test.cc`中添加对底层OP进行测试的代码,在`cinn/frontend/net_builder_test.cc`中添加对前端的测试。 +2. 提交 API 使用方法到相应的文档中。 + +# 七、可行性分析和排期规划 + +- 可行性分析:非常可行 +- 排期规划:预计8月25日前完成 + +# 八、影响面 + +对其他模块无影响。 + +# 附件及参考资料 + +[TVM文档](https://github.com/apache/tvm/blob/main/src/relay/transforms/canonicalize_cast.cc) +[XLA文档](https://github.com/pytorch/xla/blob/3d24d955b6121289a3c8bb86eda541fca7a0d69f/torch_xla/csrc/ops/cast.cpp) +[CINN文档](https://paddlepaddle.github.io/CINN/)