Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

【PaddlePaddle Hackathon 69】add cast op #883

Merged
merged 7 commits into from
Aug 30, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions cinn/frontend/net_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ Variable NetBuilder::ReduceAny(const Variable& x, const std::vector<int>& dim, b
return Reduce(x, ReduceKind::kAny, dim, keep_dim);
}

// Appends a "cast" instruction that converts `operand`'s elements to the
// type named by `dtype` (e.g. "float") and returns the single output variable.
Variable NetBuilder::Cast(const Variable& operand, const std::string& dtype) {
  Instruction cast_instr("cast", {operand});
  cast_instr.SetAttr("dtype", dtype);
  // Shape inference runs before the instruction is recorded in the program.
  InferShape(cast_instr);
  AppendInstruction(cast_instr);
  return cast_instr.GetOutput(0);
}

Variable NetBuilder::Conv2d(const Variable& a,
const Variable& b,
const std::vector<int>& strides,
Expand Down
5 changes: 5 additions & 0 deletions cinn/frontend/net_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ class NetBuilder : public BaseBuilder {
*/
Variable ReduceAny(const Variable& x, const std::vector<int>& dim, bool keep_dim = false);

/**
 * Cast Variable `operand` to the element type named by `dtype`
 * (e.g. "float"); the shape is unchanged.
 */
Variable Cast(const Variable& operand, const std::string& dtype);

/**
* The convolution2D layer calculates the output based on the input, filter
* and strides, paddings, dilations, groups parameters.
Expand Down
60 changes: 60 additions & 0 deletions cinn/frontend/net_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,5 +172,65 @@ TEST(net_build, program_execute_reverse) {
runtime_program->Execute();
}

// Fills `tensor`'s host buffer with uniformly distributed random ints in
// [1, 128]. Writes directly into the buffer returned by mutable_data; the
// previous version materialized an intermediate std::vector and then copied
// it, doubling the memory traffic for no benefit.
void SetIntRandData(hlir::framework::Tensor tensor, Target target) {
  auto* data = tensor->mutable_data<int>(target);
  std::random_device seed;
  std::default_random_engine engine(seed());
  std::uniform_int_distribution<int> dist(1, 128);
  const size_t num_ele = tensor->shape().numel();
  for (size_t i = 0; i < num_ele; i++) {
    data[i] = dist(engine);  // write in place, no temporary buffer
  }
}

// End-to-end test: build a program with a single Cast(int32 -> float32) op,
// execute it on the host, and verify output type, shape, and values.
TEST(net_build, program_execute_cast) {
  const int B = 4;  // batch dim
  const int H = 7;  // feature dim

  NetBuilder builder("net_builder");
  Placeholder input = builder.CreateInput(Int(32), {B, H}, "In");
  Variable output   = builder.Cast(input, "float");
  auto program      = builder.Build();

  Target target = common::DefaultHostTarget();

  auto graph = std::make_shared<hlir::framework::Graph>(program, target);
  auto scope = BuildScope(target, graph);
  hlir::framework::GraphCompiler gc(target, scope, graph);
  auto runtime_program = gc.Build();

  scope->Var<hlir::framework::Tensor>(std::string(input.id()));
  scope->Var<hlir::framework::Tensor>(std::string(output->id));

  auto input_tensor = scope->GetTensor(std::string(input.id()));
  SetIntRandData(input_tensor, target);
  int* input_data = input_tensor->mutable_data<int>(target);

  runtime_program->Execute();

  // Cast must change only the element type, never the shape.
  auto output_tensor                   = scope->GetTensor(std::string(output->id));
  const std::vector<int>& output_shape = output_tensor->shape().data();
  EXPECT_EQ(output_tensor->type(), Float(32));
  EXPECT_EQ(output_shape.size(), 2UL);
  EXPECT_EQ(output_shape[0], B);
  EXPECT_EQ(output_shape[1], H);

  float* output_data = output_tensor->mutable_data<float>(target);
  VLOG(6) << "Visualize output_data";
  for (int b = 0; b < B; ++b) {
    // Accumulate one row per log line. The original declared `line` inside
    // the inner loop, so every VLOG printed a single element and the string
    // never actually accumulated a row.
    std::string line;
    for (int h = 0; h < H; ++h) {
      int index      = h + H * b;
      // Exact comparison is safe: inputs are ints in [1, 128], all exactly
      // representable as float.
      float in_data  = static_cast<float>(input_data[index]);
      float out_data = output_data[index];
      line += (std::to_string(out_data) + ", ");
      EXPECT_EQ(in_data, out_data);
    }
    VLOG(6) << line;
  }
}

} // namespace frontend
} // namespace cinn
2 changes: 2 additions & 0 deletions cinn/hlir/op/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
add_subdirectory(contrib)

core_gather_headers()

gather_srcs(cinnapi_src SRCS
Expand Down
7 changes: 7 additions & 0 deletions cinn/hlir/op/contrib/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Collect this directory's headers into the project-wide header list.
core_gather_headers()

# Add the cast op implementation to the main cinnapi source list.
gather_srcs(cinnapi_src SRCS
cast.cc
)

# Unit test for the standalone Cast compute + CPU codegen (cast_test.cc).
cc_test(test_cast SRCS cast_test.cc DEPS cinncore)
124 changes: 124 additions & 0 deletions cinn/hlir/op/contrib/cast.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/hlir/op/contrib/cast.h"

#include <gflags/gflags.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "cinn/common/cas.h"
#include "cinn/common/common.h"
#include "cinn/common/context.h"
#include "cinn/common/macros.h"
#include "cinn/hlir/framework/node.h"
#include "cinn/hlir/framework/op.h"
#include "cinn/hlir/framework/op_strategy.h"
#include "cinn/hlir/pe/elementwise.h"
#include "cinn/ir/ir.h"
#include "cinn/ir/ir_base.h"
#include "cinn/ir/tensor.h"
#include "cinn/lang/builtin.h"
#include "cinn/lang/compute.h"

DECLARE_bool(cinn_ir_schedule);

namespace cinn {
namespace hlir {
namespace op {

using common::CINNValue;
using common::CINNValuePack;

// Builds an elementwise compute whose every element is A's corresponding
// element converted to `dtype`; `name` names the resulting tensor.
ir::Tensor Cast(const ir::Tensor &A, const Type &dtype, const std::string &name) {
  return Compute(
      A->shape,
      [=](const std::vector<Expr> &indices) { return ir::Cast::Make(dtype, A(indices)); },
      name);
}

/**
 * Builds the op strategy (compute + schedule) for the cast op.
 * The compute converts the single input tensor to out_type[0]; the schedule
 * is a pass-through (no transformation is applied).
 */
std::shared_ptr<framework::OpStrategy> StrategyForCast(const framework::NodeAttr &attrs,
                                                       const std::vector<ir::Tensor> &inputs,
                                                       const std::vector<Type> &out_type,
                                                       const std::vector<std::vector<int>> &output_shapes,
                                                       const Target &target) {
  framework::CINNCompute cast_compute([=](lang::Args args, lang::RetValue *ret) {
    CHECK(!args.empty()) << "The input arguments of Cast compute is empty! Please check.\n";
    // Validate out_type BEFORE indexing into it; the original read
    // out_type[0] first and only checked emptiness afterwards, so an empty
    // out_type would be undefined behavior rather than a clean CHECK failure.
    CHECK(!out_type.empty()) << "Output type of Cast is empty! Please check.\n";
    CINNValuePack a = args[0];
    CHECK_GE(a.size(), 1U) << "at least 1 input tensors for Cast compute\n";
    Expr A = a[0];
    CHECK(A.as_tensor());
    CHECK(!output_shapes.empty());
    auto tensor_A = A.as_tensor_ref();
    auto stages   = CreateStages({tensor_A});
    VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
            << ", output_shapes: " << utils::Join(output_shapes[0], ", ");
    ir::Tensor out = Cast(tensor_A, out_type[0], UniqName("Cast_out"));
    std::vector<CINNValue> res;
    stages->InsertLazily(out);
    res.push_back(CINNValue(out));
    res.push_back(CINNValue(stages));
    *ret = CINNValuePack{res};
  });

  framework::CINNSchedule cast_schedule([=](lang::Args args, lang::RetValue *ret) {
    // Message previously said "reshape schedule" — a copy/paste slip from the
    // reshape op; fixed to name this op.
    CHECK(!args.empty()) << "The input argument of cast schedule is empty! Please check.\n";
    CINNValuePack arg_pack = args[0];
    Expr out               = arg_pack[0];
    CHECK(out.as_tensor());
    // Pass-through schedule: no loop transformation applied.
    *ret = arg_pack;
  });

  auto strategy = std::make_shared<framework::OpStrategy>();
  strategy->AddImpl(cast_compute, cast_schedule, "strategy.cast.x86", 1);
  return strategy;
}

// Cast is shape-preserving: the single output has exactly the input's shape.
std::vector<std::vector<int>> InferShapeForCast(const std::vector<std::vector<int>> &inputs_shape,
                                                const framework::AttrMapType &attrs) {
  CHECK_EQ(inputs_shape.size(), 1U) << "The input's shape size should be 1! Please check again.";
  return {inputs_shape[0]};
}

// The output dtype comes solely from the mandatory "dtype" string attribute;
// a missing attribute is a hard CHECK failure.
std::vector<Type> InferDtypeForCast(const std::vector<Type> &inputs_type, const framework::AttrMapType &attrs) {
  CHECK_EQ(inputs_type.size(), 1U) << "The input's type size should be 1! Please check again.";
  std::vector<Type> res;
  auto it = attrs.find("dtype");
  if (it != attrs.end()) {
    // Single lookup via the iterator; the original did find() then at().
    res.push_back(common::Str2Type(absl::get<std::string>(it->second)));
  }
  CHECK(!res.empty()) << "The cast should have an attr named 'dtype'.";
  return res;
}

} // namespace op
} // namespace hlir
} // namespace cinn

// Registers the "cast" op so it can be looked up by name; activated via
// CINN_USE_REGISTER(cast_ops) in use_ops.h.
CINN_REGISTER_HELPER(cast_ops) {
  CINN_REGISTER_OP(cast)
      .describe("Cast.")
      .set_num_inputs(1)
      .set_num_outputs(1)
      // Wires the strategy/infershape/inferdtype functions defined above.
      .set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForCast)
      .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForCast))
      .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForCast))
      .set_support_level(4);

  return true;
}
32 changes: 32 additions & 0 deletions cinn/hlir/op/contrib/cast.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <vector>

#include "cinn/ir/ir.h"
#include "cinn/ir/ir_base.h"
#include "cinn/ir/tensor.h"

namespace cinn {
namespace hlir {
namespace op {

/**
 * Creates an elementwise compute tensor whose elements are A's elements
 * converted to `dtype`; `name` names the output tensor.
 */
ir::Tensor Cast(const ir::Tensor& A, const Type& dtype, const std::string& name);

} // namespace op
} // namespace hlir
} // namespace cinn
67 changes: 67 additions & 0 deletions cinn/hlir/op/contrib/cast_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/hlir/op/contrib/cast.h"

#include <glog/logging.h>
#include <gtest/gtest.h>

#include <string>
#include <vector>

#include "cinn/backends/codegen_c.h"
#include "cinn/backends/codegen_c_x86.h"
#include "cinn/backends/codegen_cuda_dev.h"
#include "cinn/common/context.h"
#include "cinn/lang/lower.h"
#include "cinn/lang/placeholder.h"
#include "cinn/poly/stage.h"

namespace cinn {
namespace hlir {
namespace op {

// Lowers the standalone Cast compute to CPU C code and logs the result.
// This only checks that lowering and codegen run without error; the generated
// kernel is not executed here.
TEST(GenerateCode_Cpu, Cast) {
  common::Context::Global().ResetNameId();

  common::Target target = common::DefaultHostTarget();

  ir::Expr n(4);
  ir::Expr h(28);

  // int32 input placeholder of shape [4, 28], cast to float32.
  lang::Placeholder<int32_t> in("in", {n, h});
  ir::Tensor res = Cast(in, Float(32), "test_Cast_out");

  poly::StageMap stages = poly::CreateStages({res});
  std::vector<ir::LoweredFunc> funcs =
      lang::LowerVec("TestGenerateCodeCpu_Cast", stages, {res}, {}, {}, nullptr, target, true);

  VLOG(6) << "Expr before CPU codegen:";
  VLOG(6) << funcs[0]->body;

  ir::Module::Builder builder("Cast_Module", target);
  for (auto& f : funcs) {
    builder.AddFunction(f);
  }

  // Emit x86 C source (AVX512 feature set) for inspection via VLOG.
  backends::CodeGenCX86 codegen(target, backends::CodeGenCX86::Feature::AVX512);
  codegen.SetInlineBuiltinCodes(false);
  std::string code = codegen.Compile(builder.Build(), backends::CodeGenC::OutputKind::CImpl);
  VLOG(6) << "Cpu Codegen result:";
  VLOG(6) << code << std::endl;
}

} // namespace op
} // namespace hlir
} // namespace cinn
1 change: 1 addition & 0 deletions cinn/hlir/op/use_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@ CINN_USE_REGISTER(broadcast_ops)
CINN_USE_REGISTER(broadcast_grad_ops)
CINN_USE_REGISTER(elementwise_ops)
CINN_USE_REGISTER(transform_ops)
CINN_USE_REGISTER(cast_ops)
CINN_USE_REGISTER(reduce_ops)