[Hackathon No.69] Add a cast operator to the neural network compiler CINN #205

Merged 1 commit on Aug 16, 2022
241 changes: 241 additions & 0 deletions rfcs/CINN/APIs/20220811_api_design_for_cast.md
@@ -0,0 +1,241 @@
# CINN cast Design Document

| API name                | cast                            |
| ----------------------- | ------------------------------- |
| Author                  | 六个骨头                        |
| Submission date         | 2022-08-11                      |
| Version                 | V1.0                            |
| CINN version dependency | develop                         |
| File name               | 20220811_api_design_for_cast.md |

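# I. Overview

## 1. Background

`cast` is a basic operator found in many neural network compilers. It converts the data type of its input, for example from Int(32) to Float(32).
To broaden the CINN API, a `cast` API needs to be added.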

## 2. Terminology

- Tensor: a multi-dimensional array.
- cast: converts the data type of its input.
- dtype: data type.

## 3. Functional Goal

Implement cast, which converts the input to a specified data type. For example, for tensor $A$ = [1, 2, 3],
cast( $A$, dtype = Float(32)) yields $[1.0f, 2.0f, 3.0f]$.
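
To make the intended semantics concrete, here is a minimal pure-Python sketch of an element-wise cast. It is not CINN code: the `cast` helper, the `_CONVERTERS` table, and the string dtype names are assumptions chosen for illustration, and Python's 64-bit `float` stands in for Float(32).

```python
# Toy reference for element-wise cast; names and dtype strings are hypothetical.
_CONVERTERS = {
    "int32": int,      # truncates toward zero, e.g. 2.9 -> 2
    "float32": float,  # Python floats are 64-bit; precision is not modelled
    "bool": bool,      # any non-zero value becomes True
}


def cast(values, dtype):
    """Return a new list with every element converted to `dtype`."""
    try:
        convert = _CONVERTERS[dtype]
    except KeyError:
        raise ValueError(f"unsupported dtype: {dtype}") from None
    return [convert(v) for v in values]


a = [1, 2, 3]
result = cast(a, "float32")
print(result)  # [1.0, 2.0, 3.0]
```

The input list is left untouched and a new list is returned, which mirrors how an operator produces a new output tensor rather than modifying its input.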

## 4. Significance

Adds the basic operator `cast` to the neural network compiler CINN.

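# II. Current State of CINN

CINN can currently perform a cast only by calling the low-level IR directly, i.e. `ir::Cast::Make`. There is no corresponding HLIR implementation, so a `cast` HLIR API needs to be added.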

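# III. Industry Survey

- [TVM](https://github.com/apache/tvm/blob/main/src/relay/transforms/canonicalize_cast.cc):
Casts each element of the tensor using lower-level IR. The excerpt below is the Relay `CanonicalizeCast` pass from the linked file.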

```cpp
class CastCanonicalizer : public ExprMutator {
 public:
  CastCanonicalizer() : cast_op_(Op::Get("cast")) {}

  Expr VisitExpr_(const CallNode* call) {
    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");

    if (const OpNode* opnode = call->op.as<OpNode>()) {
      auto pattern = fpattern[GetRef<Op>(opnode)];
      if (pattern <= kBroadcast) {
        Array<Expr> call_args = call->args;
        bool unchanged = true;
        for (size_t i = 0; i < call_args.size(); ++i) {
          Expr arg = call_args[i];
          Expr new_arg = GetNewCallArg(arg);
          if (!arg.same_as(new_arg)) {
            call_args.Set(i, new_arg);
            unchanged = false;
          }
        }
        if (unchanged) {
          return GetRef<Expr>(call);
        }
        return Call(call->op, call_args, call->attrs, call->type_args);
      }
    }

    Expr new_expr = ExprMutator::VisitExpr_(call);
    return new_expr;
  }

 private:
  std::unordered_map<const Object*, size_t> ref_counter_;
  // cast op is frequently checked for equivalence. Therefore, we cache it to
  // reduce lookup overhead.
  const Op& cast_op_;

  Expr GetNewCallArg(const Expr& e) {
    // if e is a upcast and ref count > 1, create an copy; otherwise call the default visitor
    Expr new_expr = this->VisitExpr(e);

    if (const CallNode* call = e.as<CallNode>()) {
      if (call->op == cast_op_) {
        auto attrs = call->attrs.as<CastAttrs>();
        const auto* from_type = call->args[0]->type_as<TensorTypeNode>();
        ICHECK(from_type);

        if (from_type->dtype.bits() < attrs->dtype.bits()) {
          if (++ref_counter_[call] > 1) {
            const CallNode* new_call = new_expr.as<CallNode>();
            ICHECK(new_call);
            ICHECK(new_call->op == cast_op_);
            return Call(new_call->op, new_call->args, new_call->attrs, new_call->type_args);
          }
        }
      }
    }
    return new_expr;
  }
};

Expr CanonicalizeCast(const Expr& e) { return CastCanonicalizer().Mutate(e); }

namespace transform {

Pass CanonicalizeCast() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(CanonicalizeCast(f));
      };
  return CreateFunctionPass(pass_func, 3, "CanonicalizeCast", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast").set_body_typed(CanonicalizeCast);

}  // namespace transform

```
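
The excerpt above rewrites a shared upcast so that every consumer after the first receives its own copy of the cast node. The toy below sketches that idea on a tiny expression graph. The `Var`, `Cast`, and `Add` classes, the `BITS` table, and `canonicalize` are hypothetical stand-ins, not TVM APIs, and the sketch ignores the operator-pattern check performed by the real pass.

```python
from dataclasses import dataclass

# Bit widths used to decide whether a cast widens its input (assumed subset).
BITS = {"int8": 8, "int32": 32, "float32": 32, "float64": 64}


@dataclass(eq=False)
class Var:
    name: str
    dtype: str


@dataclass(eq=False)
class Cast:
    arg: object  # in this toy the argument is always a Var or a Cast
    dtype: str


@dataclass(eq=False)
class Add:
    lhs: object
    rhs: object


def is_upcast(node):
    """True when node is a Cast to a wider type than its argument."""
    return isinstance(node, Cast) and BITS[node.arg.dtype] < BITS[node.dtype]


def canonicalize(consumers):
    """Give every consumer after the first its own copy of a shared upcast."""
    uses = {}  # id(cast node) -> number of consumers seen so far
    rewritten = []
    for op in consumers:
        args = []
        for arg in (op.lhs, op.rhs):
            if is_upcast(arg):
                uses[id(arg)] = uses.get(id(arg), 0) + 1
                if uses[id(arg)] > 1:
                    arg = Cast(arg.arg, arg.dtype)  # fresh node, same meaning
            args.append(arg)
        rewritten.append(Add(*args))
    return rewritten


x = Var("x", "int8")
y = Var("y", "int32")
shared = Cast(x, "int32")
first, second = canonicalize([Add(shared, y), Add(shared, y)])
print(first.lhs is shared, second.lhs is shared)  # True False
```

Casts that do not widen their input are left shared in this sketch, matching the bit-width comparison in the excerpt.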

- [XLA](https://github.com/pytorch/xla/blob/3d24d955b6121289a3c8bb86eda541fca7a0d69f/torch_xla/csrc/ops/cast.cpp): similar to TVM. The excerpt below is the `Cast` node from PyTorch/XLA.

```cpp
namespace torch_xla {
namespace {

xla::Shape NodeOutputShape(const torch::lazy::Value& input,
                           xla::PrimitiveType type) {
  xla::Shape shape = GetXlaShape(input);
  shape.set_element_type(type);
  return shape;
}

}  // namespace

Cast::Cast(const torch::lazy::Value& input, xla::PrimitiveType type)
    : XlaNode(xla_cast, {input}, NodeOutputShape(input, type),
              /*num_outputs=*/1, torch::lazy::MHash(static_cast<int>(type))),
      type_(type) {}

Cast::Cast(const torch::lazy::Value& input, at::ScalarType dtype,
           c10::optional<at::ScalarType> stype)
    : XlaNode(xla_cast, {input},
              NodeOutputShape(input,
                              MakeXlaPrimitiveType(dtype, /*device=*/nullptr)),
              /*num_outputs=*/1,
              torch::lazy::MHash(101, static_cast<int>(dtype),
                                 torch::lazy::OptionalOr<int>(stype, -1))),
      type_(MakeXlaPrimitiveType(dtype, /*device=*/nullptr)),
      dtype_(dtype),
      stype_(stype) {}

torch::lazy::NodePtr Cast::Clone(torch::lazy::OpList operands) const {
  return dtype_ ? torch::lazy::MakeNode<Cast>(operands.at(0), *dtype_, stype_)
                : torch::lazy::MakeNode<Cast>(operands.at(0), type_);
}

XlaOpVector Cast::Lower(LoweringContext* loctx) const {
  xla::XlaOp input = loctx->GetOutputOp(operand(0));
  const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
  xla::PrimitiveType raw_from =
      stype_ ? TensorTypeToRawXlaType(*stype_) : input_shape.element_type();
  xla::PrimitiveType raw_to = dtype_ ? TensorTypeToRawXlaType(*dtype_) : type_;
  xla::XlaOp output =
      ConvertToRaw(input, input_shape.element_type(), raw_from, type_, raw_to,
                   /*device=*/nullptr);
  return ReturnOp(output, loctx);
}

std::string Cast::ToString() const {
  std::stringstream ss;
  ss << XlaNode::ToString()
     << ", type=" << xla::primitive_util::LowercasePrimitiveTypeName(type_);
  if (dtype_) {
    ss << ", dtype=" << *dtype_;
  }
  if (stype_) {
    ss << ", stype=" << *stype_;
  }
  return ss.str();
}

}  // namespace torch_xla
```
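
`NodeOutputShape` in the excerpt above captures the type rule for cast: the dimensions are kept and only the element type changes. Below is a minimal Python sketch of that rule. `TensorType` and `infer_cast_type` are hypothetical names introduced for illustration and do not belong to XLA or CINN.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class TensorType:
    """Hypothetical tensor type: dimensions plus an element type name."""
    shape: tuple
    dtype: str


def infer_cast_type(input_type, dtype):
    """Output type of cast: dimensions unchanged, element type replaced."""
    return replace(input_type, dtype=dtype)


src = TensorType(shape=(8, 24, 124), dtype="int32")
dst = infer_cast_type(src, "float32")
print(dst.shape, dst.dtype)
```

Because the rule depends only on the input type and the target dtype, it can be evaluated before any data is available, which is what a compiler needs when inferring output types.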

# IV. Comparative Analysis

The TVM and XLA implementations are similar.

# V. Design and Implementation

## Naming and Parameter Design

- A: the input tensor
- dtype: the target data type
- name: the output name

## Low-level OP Design

1. Declare the `cast` operator in `cinn/hlir/op/contrib/cast.h`.
2. Implement the `cast` operator and its `strategy` in `cinn/hlir/op/contrib/cast.cc`.

## API Implementation

For example, for tensor A = [1, 1, 1],
cast( A, dtype = Float(32)) yields [1.0f, 1.0f, 1.0f], and
cast( A, dtype = bool) yields [True, True, True].

1. Declare `BaseBuilder::Cast` in `cinn/frontend/net_build.h`.
2. Implement `BaseBuilder::Cast` in `cinn/frontend/net_build.cc`.
3. In `cinn/pybind/frontend`, add a `cast` method to the Python class `BaseBuilder` and bind it to `BaseBuilder::Cast`.
4. Add the upper-level `load_paddle_model` support to `cinn/frontend/paddle_model_to_program.h` and the corresponding `.cc` file.

Cast is invoked through a method of the Builder class.

```python
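# Import paths are assumed from CINN's Python tests.
from cinn.common import Float, Int
from cinn.frontend import NetBuilder
builder = NetBuilder("test_basic")
a = builder.create_input(Int(32), (8, 24, 124), "A1")
b = builder.cast(a, Float(32))  # the output type becomes Float(32)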
```
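
For readers without a CINN build, the following toy stand-in sketches what a builder-style `cast` might record. `ToyBuilder`, `Variable`, the tuple instruction format, and the string dtype names are all assumptions for illustration. The real builder is implemented in C++ and may differ.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class Variable:
    """Hypothetical handle to a tensor in the program being built."""
    name: str
    dtype: str
    shape: tuple


@dataclass
class ToyBuilder:
    name: str
    instructions: list = field(default_factory=list)
    _counter: int = 0

    def create_input(self, dtype, shape, name):
        """Declare a program input with the given type and shape."""
        return Variable(name, dtype, tuple(shape))

    def cast(self, x, dtype):
        """Record a cast instruction and return the output variable."""
        self._counter += 1
        out = Variable(f"var_{self._counter}", dtype, x.shape)
        self.instructions.append(("cast", x.name, out.name, {"dtype": dtype}))
        return out


builder = ToyBuilder("test_basic")
a = builder.create_input("int32", (8, 24, 124), "A1")
b = builder.cast(a, "float32")
print(b.dtype, b.shape)
```

The builder only records the operation. Actual conversion would happen later, when the recorded program is lowered and executed.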

# VI. Testing and Acceptance Considerations

1. Add tests for the low-level OP in `cinn/hlir/op/contrib/cast_test.cc`, and add tests for the frontend in `cinn/frontend/net_builder_test.cc`.
2. Add the API usage to the corresponding documentation.

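# VII. Feasibility Analysis and Schedule

- Feasibility: high, since the low-level IR (`ir::Cast::Make`) already exists and only the HLIR and frontend layers need to be added.
- Schedule: expected to be completed by August 25.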

# VIII. Impact

No impact on other modules.

# Attachments and References

- [TVM source](https://github.com/apache/tvm/blob/main/src/relay/transforms/canonicalize_cast.cc)
- [XLA source](https://github.com/pytorch/xla/blob/3d24d955b6121289a3c8bb86eda541fca7a0d69f/torch_xla/csrc/ops/cast.cpp)
- [CINN documentation](https://paddlepaddle.github.io/CINN/)