Skip to content

Commit

Permalink
add CINN cast rfc docs (#205)
Browse files Browse the repository at this point in the history
zrr1999 authored Aug 16, 2022
1 parent 50abd7d commit f0f80cc
Showing 1 changed file with 241 additions and 0 deletions.
241 changes: 241 additions & 0 deletions rfcs/CINN/APIs/20220811_api_design_for_cast.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
# CINN cast 设计文档

| API名称 | cast |
| ---------------------------------------------------------- | ------------------------------------------------ |
| 提交作者<input type="checkbox" class="rowselector hidden"> | 六个骨头 |
| 提交时间<input type="checkbox" class="rowselector hidden"> | 2022-08-11 |
| 版本号 | V1.0 |
| 依赖CINN版本<input type="checkbox" class="rowselector hidden"> | develop |
| 文件名 | 20220729_api_design_for_cast.md<br> |

# 一、概述

## 1、相关背景

`cast` 是众多神经网络编译器中基础的算子。转化输入的数据类型,例如将Int(32)变为Float(32)。
为了提升 CINN API 丰富度,需要扩充 API `cast`

## 2、名词解释

- 张量/Tensor:指高维数组。
- cast:转化输入的数据类型。
- dtype:数据类型。

## 3、功能目标

实现cast功能,将输入转换为指定数据类型。例如,对于张量 $A$ = [1, 2, 3]
cast( $A$, dtype = Float(32) 结果为 $[1.0f, 2.0f, 3.0f]$。

## 4、意义

为神经网络编译器 CINN 增加基础算子`cast`

# 二、CINN现状

对CINN框架目前可以调用底层ir实现,即ir::Cast::Make,但没有相应的hlir实现,因此有必要实现 `cast` hlir API。

# 三、业内方案调研

- [TVM](https://github.com/apache/tvm/blob/main/src/relay/transforms/canonicalize_cast.cc)
对张量中每个元素使用更底层的ir进行类型转化。

```cpp
class CastCanonicalizer : public ExprMutator {
public:
CastCanonicalizer() : cast_op_(Op::Get("cast")) {}

Expr VisitExpr_(const CallNode* call) {
static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");

if (const OpNode* opnode = call->op.as<OpNode>()) {
auto pattern = fpattern[GetRef<Op>(opnode)];
if (pattern <= kBroadcast) {
Array<Expr> call_args = call->args;
bool unchanged = true;
for (size_t i = 0; i < call_args.size(); ++i) {
Expr arg = call_args[i];
Expr new_arg = GetNewCallArg(arg);
if (!arg.same_as(new_arg)) {
call_args.Set(i, new_arg);
unchanged = false;
}
}
if (unchanged) {
return GetRef<Expr>(call);
}
return Call(call->op, call_args, call->attrs, call->type_args);
}
}

Expr new_expr = ExprMutator::VisitExpr_(call);
return new_expr;
}

private:
std::unordered_map<const Object*, size_t> ref_counter_;
// cast op is frequently checked for equivalence. Therefore, we cache it to
// reduce lookup overhead.
const Op& cast_op_;

Expr GetNewCallArg(const Expr& e) {
// if e is a upcast and ref count > 1, create an copy; otherwise call the default visitor
Expr new_expr = this->VisitExpr(e);

if (const CallNode* call = e.as<CallNode>()) {
if (call->op == cast_op_) {
auto attrs = call->attrs.as<CastAttrs>();
const auto* from_type = call->args[0]->type_as<TensorTypeNode>();
ICHECK(from_type);

if (from_type->dtype.bits() < attrs->dtype.bits()) {
if (++ref_counter_[call] > 1) {
const CallNode* new_call = new_expr.as<CallNode>();
ICHECK(new_call);
ICHECK(new_call->op == cast_op_);
return Call(new_call->op, new_call->args, new_call->attrs, new_call->type_args);
}
}
}
}
return new_expr;
}
};

Expr CanonicalizeCast(const Expr& e) { return CastCanonicalizer().Mutate(e); }

namespace transform {

Pass CanonicalizeCast() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CanonicalizeCast(f));
};
return CreateFunctionPass(pass_func, 3, "CanonicalizeCast", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast").set_body_typed(CanonicalizeCast);

} // namespace transform

```
- [XLA](https://github.com/pytorch/xla/blob/3d24d955b6121289a3c8bb86eda541fca7a0d69f/torch_xla/csrc/ops/cast.cpp):与TVM类似。
```cpp
namespace torch_xla {
namespace {
xla::Shape NodeOutputShape(const torch::lazy::Value& input,
xla::PrimitiveType type) {
xla::Shape shape = GetXlaShape(input);
shape.set_element_type(type);
return shape;
}
} // namespace
Cast::Cast(const torch::lazy::Value& input, xla::PrimitiveType type)
: XlaNode(xla_cast, {input}, NodeOutputShape(input, type),
/*num_outputs=*/1, torch::lazy::MHash(static_cast<int>(type))),
type_(type) {}
Cast::Cast(const torch::lazy::Value& input, at::ScalarType dtype,
c10::optional<at::ScalarType> stype)
: XlaNode(xla_cast, {input},
NodeOutputShape(input,
MakeXlaPrimitiveType(dtype, /*device=*/nullptr)),
/*num_outputs=*/1,
torch::lazy::MHash(101, static_cast<int>(dtype),
torch::lazy::OptionalOr<int>(stype, -1))),
type_(MakeXlaPrimitiveType(dtype, /*device=*/nullptr)),
dtype_(dtype),
stype_(stype) {}
torch::lazy::NodePtr Cast::Clone(torch::lazy::OpList operands) const {
return dtype_ ? torch::lazy::MakeNode<Cast>(operands.at(0), *dtype_, stype_)
: torch::lazy::MakeNode<Cast>(operands.at(0), type_);
}
XlaOpVector Cast::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
xla::PrimitiveType raw_from =
stype_ ? TensorTypeToRawXlaType(*stype_) : input_shape.element_type();
xla::PrimitiveType raw_to = dtype_ ? TensorTypeToRawXlaType(*dtype_) : type_;
xla::XlaOp output =
ConvertToRaw(input, input_shape.element_type(), raw_from, type_, raw_to,
/*device=*/nullptr);
return ReturnOp(output, loctx);
}
std::string Cast::ToString() const {
std::stringstream ss;
ss << XlaNode::ToString()
<< ", type=" << xla::primitive_util::LowercasePrimitiveTypeName(type_);
if (dtype_) {
ss << ", dtype=" << *dtype_;
}
if (stype_) {
ss << ", stype=" << *stype_;
}
return ss.str();
}
} // namespace torch_xla
```

# 四、对比分析

TVM 与 XLA 实现方案类似。

# 五、设计思路与实现方案

## 命名与参数设计

- A:输入张量
- dtype:指定数据类型
- name:输出名称

## 底层OP设计

1.`cinn/hlir/op/contrib/cast.h` 里声明`cast`算子。
2.`cinn/hlir/op/contrib/cast.cc` 里实现`cast`算子和 `strategy`

## API实现方案

例如,对于张量 A = [1, 1, 1]
cast( A, dtype = Float(32)) 结果为[1.0f, 1.0f, 1.0f]
cast( A, dtype = bool) 结果为[True, True, True]

1.`cinn/frontend/net_build.h` 里声明 `BaseBuilder::Cast`
2.`cinn/frontend/net_build.cc` 里实现 `BaseBuilder::Cast`
3.`cinn/pybind/frontend` 对 Python 类 `BaseBuilder` 添加 `cast` 接口,并绑定到`BaseBuilder::Cast`
4. 上层 `load_paddle_model` 调用提交到 `cinn/frontend/paddle_model_to_program.h``.cc` 文件下。

通过使用 Builder 类的方法调用 Cast。

```python
builder = NetBuilder("test_basic")
a = builder.create_input(Int(32), (8, 24, 124), "A1")
b = builder.cast(a, Float(32)) # 输出类型变为Float(32)
```

# 六、测试和验收的考量

1.`cinn/hlir/op/contrib/cast_test.cc``cinn/hlir/op/contrib/cast_test.cc`中添加对底层OP进行测试的代码,在`cinn/frontend/net_builder_test.cc`中添加对前端的测试。
2. 提交 API 使用方法到相应的文档中。

# 七、可行性分析和排期规划

- 可行性分析:非常可行
- 排期规划:预计8月25日前完成

# 八、影响面

对其他模块无影响。

# 附件及参考资料

[TVM文档](https://github.com/apache/tvm/blob/main/src/relay/transforms/canonicalize_cast.cc)
[XLA文档](https://github.com/pytorch/xla/blob/3d24d955b6121289a3c8bb86eda541fca7a0d69f/torch_xla/csrc/ops/cast.cpp)
[CINN文档](https://paddlepaddle.github.io/CINN/)

0 comments on commit f0f80cc

Please sign in to comment.