This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

【Hackathon No.74】Add repeat op #949

Merged 9 commits on Sep 28, 2022
4 changes: 4 additions & 0 deletions cinn/frontend/net_builder.cc
@@ -514,6 +514,10 @@ Variable NetBuilder::Pool2d(const Variable& a,
.front();
}

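// NetBuilder::Repeat is routed through the generic CustomInstr path below: the
// attribute map carries `repeats` and `axis` to the "repeat" op registered in
// cinn/hlir/op/contrib/repeat.cc.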
Variable NetBuilder::Repeat(const Variable& x, int repeats, int axis) {
return CustomInstr("repeat", {x}, {{"repeats", repeats}, {"axis", axis}}).front();
}

std::vector<Variable> NetBuilder::BatchNorm(const Variable& a,
const Variable& scale,
const Variable& bias,
9 changes: 9 additions & 0 deletions cinn/frontend/net_builder.h
@@ -501,6 +501,15 @@ class NetBuilder {
bool adaptive = false,
const std::string& padding_algorithm = "EXPLICIT");

/**
* @brief Repeat elements of an array `repeats` times along axis `axis`.
* @param x An input N-D variable.
* @param repeats The number of times each element is repeated.
* @param axis The axis along which elements are repeated.
* @return The result variable of the repeat operation.
*/
Variable Repeat(const Variable& x, int repeats, int axis);

// *******************************************
// Broadcast operator
/**
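For orientation, a minimal usage sketch of the new builder API (illustrative only; it mirrors the test setup below). Repeating a [4, 4] float input 3 times along axis 0 yields a [12, 4] output, analogous to numpy.repeat(x, 3, axis=0):

NetBuilder builder("repeat_example");
Placeholder x = builder.CreateInput(Float(32), {4, 4}, "x");
Variable y = builder.Repeat(x, /*repeats=*/3, /*axis=*/0);  // y has shape [12, 4]
auto program = builder.Build();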
98 changes: 98 additions & 0 deletions cinn/frontend/net_builder_test.cc
@@ -1253,5 +1253,103 @@ TEST(net_build, program_argmin_case2) {
}
}

TEST(net_build, program_execute_repeat_axis_0) {
const int M = 4;
const int N = 4;
const int repeats = 3;
const int axis = 0;

NetBuilder builder("net_builder");
Placeholder input = builder.CreateInput(Float(32), {M, N}, "In");
Variable output = builder.Repeat(input, repeats, axis);
auto program = builder.Build();

Target target = common::DefaultHostTarget();

auto graph = std::make_shared<hlir::framework::Graph>(program, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
auto runtime_program = gc.Build();

scope->Var<hlir::framework::Tensor>(std::string(input.id()));
scope->Var<hlir::framework::Tensor>(std::string(output->id));

auto input_tensor = scope->GetTensor(std::string(input.id()));
SetRandData<float>(input_tensor, target);
float* input_data = input_tensor->mutable_data<float>(target);

runtime_program->Execute();

auto output_tensor = scope->GetTensor(std::string(output->id));
const std::vector<int>& output_shape = output_tensor->shape().data();

const int new_M = M * repeats;
const int new_N = N;
EXPECT_EQ(output_tensor->type(), Float(32));
EXPECT_EQ(output_shape.size(), 2UL);
EXPECT_EQ(output_shape[0], new_M);
EXPECT_EQ(output_shape[1], new_N);

float* output_data = output_tensor->mutable_data<float>(target);
for (int m = 0; m < new_M; ++m) {
for (int n = 0; n < new_N; ++n) {
int in_index = n + N * (m / repeats);  // integer division: each output row maps back to its source row
int out_index = n + new_N * m;
float in_data = input_data[in_index];
float out_data = output_data[out_index];
EXPECT_EQ(in_data, out_data);
}
}
}

TEST(net_build, program_execute_repeat_axis_1) {
const int M = 4;
const int N = 4;
const int repeats = 3;
const int axis = 1;

NetBuilder builder("net_builder");
Placeholder input = builder.CreateInput(Float(32), {M, N}, "In");
Variable output = builder.Repeat(input, repeats, axis);
auto program = builder.Build();

Target target = common::DefaultHostTarget();

auto graph = std::make_shared<hlir::framework::Graph>(program, target);
auto scope = BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
auto runtime_program = gc.Build();

scope->Var<hlir::framework::Tensor>(std::string(input.id()));
scope->Var<hlir::framework::Tensor>(std::string(output->id));

auto input_tensor = scope->GetTensor(std::string(input.id()));
SetRandData<float>(input_tensor, target);
float* input_data = input_tensor->mutable_data<float>(target);

runtime_program->Execute();

auto output_tensor = scope->GetTensor(std::string(output->id));
const std::vector<int>& output_shape = output_tensor->shape().data();

const int new_M = M;
const int new_N = N * repeats;
EXPECT_EQ(output_tensor->type(), Float(32));
EXPECT_EQ(output_shape.size(), 2UL);
EXPECT_EQ(output_shape[0], new_M);
EXPECT_EQ(output_shape[1], new_N);

float* output_data = output_tensor->mutable_data<float>(target);
for (int m = 0; m < new_M; ++m) {
for (int n = 0; n < new_N; ++n) {
int in_index = N * m + n / repeats;  // integer division: each output column maps back to its source column
int out_index = n + new_N * m;
float in_data = input_data[in_index];
float out_data = output_data[out_index];
EXPECT_EQ(in_data, out_data);
}
}
}

} // namespace frontend
} // namespace cinn
2 changes: 2 additions & 0 deletions cinn/hlir/op/contrib/CMakeLists.txt
@@ -12,6 +12,7 @@ gather_srcs(cinnapi_src SRCS
argmin.cc
argmax.cc
squeeze.cc
repeat.cc
)

cc_test(test_cast SRCS cast_test.cc DEPS cinncore)
@@ -24,3 +25,4 @@ cc_test(test_argmin SRCS argmin_test.cc DEPS cinncore)
cc_test(test_argmax SRCS argmax_test.cc DEPS cinncore)
cc_test(test_arange SRCS arange_test.cc DEPS cinncore)
cc_test(test_flip SRCS flip_test.cc DEPS cinncore)
cc_test(test_repeat SRCS repeat_test.cc DEPS cinncore)
231 changes: 231 additions & 0 deletions cinn/hlir/op/contrib/repeat.cc
@@ -0,0 +1,231 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/hlir/op/contrib/repeat.h"

#include <gflags/gflags.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "cinn/common/cas.h"
#include "cinn/common/common.h"
#include "cinn/common/context.h"
#include "cinn/common/macros.h"
#include "cinn/hlir/framework/node.h"
#include "cinn/hlir/framework/op.h"
#include "cinn/hlir/framework/op_strategy.h"
#include "cinn/hlir/pe/ir_schedule_pe.h"
#include "cinn/hlir/pe/nn.h"
#include "cinn/hlir/pe/transform.h"
#include "cinn/ir/ir.h"
#include "cinn/ir/ir_base.h"
#include "cinn/ir/ir_schedule.h"
#include "cinn/ir/tensor.h"
#include "cinn/lang/builtin.h"
#include "cinn/lang/compute.h"

DECLARE_bool(cinn_ir_schedule);

namespace cinn {
namespace hlir {
namespace op {

using common::CINNValuePack;

std::vector<ir::Tensor> Repeat(const ir::Tensor &tensor, int repeats, int axis, const std::string &output_name) {
int ndim = static_cast<int>(tensor->shape.size());
// `axis` must address an existing dimension (negative values count from the end);
// after normalization it is used to index `tensor->shape`.
CHECK(-ndim <= axis && axis < ndim) << "repeat only accepts `axis` in [-data.ndim, data.ndim)"
<< ", but got axis = " << axis << ", and data.ndim = " << ndim;
CHECK(repeats >= 1) << "repeat only accepts `repeats >= 1`"
<< ", but got repeats = " << repeats;

if (axis < 0) {
// Calculate offset from last dimension
axis += ndim;
}
std::vector<Expr> new_shape;
for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
new_shape.push_back(tensor->shape[i]);
}
new_shape.push_back(repeats * tensor->shape[axis]);
for (size_t i = axis + 1; i < tensor->shape.size(); ++i) {
new_shape.push_back(tensor->shape[i]);
}

ir::Tensor res = lang::Compute(
{new_shape},
[=](const std::vector<ir::Expr> &indices) {
std::vector<Expr> idx;
for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
idx.push_back(indices[i]);
}
// Indices are non-negative, so integer division by `repeats` is floor
// division; this avoids the float round-trip, which loses precision for
// large indices.
idx.push_back(indices[axis] / Expr(repeats));
for (size_t i = axis + 1; i < indices.size(); ++i) {
idx.push_back(indices[i]);
}
return tensor(idx);
},
common::UniqName(output_name));
return {res};
}
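// Note on the index mapping: each output coordinate k along `axis` reads the
// input coordinate k / repeats, i.e.
//   out[i0, ..., k, ..., i(n-1)] == in[i0, ..., k / repeats, ..., i(n-1)].
// For repeats = 3 along an axis of size 4, output coordinates 0..11 read
// input coordinates 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3.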

std::vector<std::vector<int>> InferShapeForRepeat(const std::vector<std::vector<int>> &inputs_shape,
const framework::AttrMapType &attrs) {
CHECK_EQ(inputs_shape.size(), 1U) << "The input's shape size should be 1! Please check again.";

int repeats = 0;
int axis = 0;
std::vector<int> new_shape;
const std::vector<int> &tensor_shape = inputs_shape[0];
int ndim = static_cast<int>(tensor_shape.size());

if (attrs.find("repeats") != attrs.end()) {
repeats = absl::get<int>(attrs.at("repeats"));
}
if (attrs.find("axis") != attrs.end()) {
axis = absl::get<int>(attrs.at("axis"));
}

if (axis < 0) {
// Calculate offset from last dimension
axis += ndim;
}
CHECK(axis >= 0 && axis < ndim) << "repeat only accepts `axis` in [-data.ndim, data.ndim)"
<< ", but got axis = " << axis << ", and data.ndim = " << ndim;

for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
new_shape.push_back(tensor_shape[i]);
}
new_shape.push_back(repeats * tensor_shape[axis]);
for (size_t i = axis + 1; i < tensor_shape.size(); ++i) {
new_shape.push_back(tensor_shape[i]);
}

std::vector<std::vector<int>> res{new_shape};
return res;
}
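// Worked example of the shape rule above: inputs_shape = {{4, 4}} with
// repeats = 3 gives {{12, 4}} for axis = 0, and {{4, 12}} for axis = 1
// (axis = -1 normalizes to 1).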

std::vector<Type> InferDtypeForRepeat(const std::vector<Type> &inputs_type, const framework::AttrMapType &attrs) {
CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again.";
std::vector<Type> res{inputs_type[0]};
return res;
}

std::shared_ptr<framework::OpStrategy> StrategyForRepeat(const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
const std::vector<Type> &out_type,
const std::vector<std::vector<int>> &output_shapes,
const Target &target) {
int repeats = 0;
int axis = 0;
for (auto &iter : attrs.attr_store) {
if (iter.first == "repeats") {
repeats = absl::get<int>(iter.second);
} else if (iter.first == "axis") {
axis = absl::get<int>(iter.second);
}
}

CHECK(repeats >= 1) << "repeat only accepts `repeats >= 1`"
<< ", but got repeats = " << repeats;

framework::CINNCompute repeat_compute([=](lang::Args args, lang::RetValue *ret) {
CHECK(!args.empty()) << "The input arguments of Repeat compute are empty! Please check.\n";
CINNValuePack pack_args = args[0];
CHECK_GE(pack_args.size(), 1U) << "at least 1 input tensor for Repeat compute\n";
Expr A = pack_args[0];
CHECK(A.as_tensor());
CHECK(!output_shapes.empty());
auto tensor_A = A.as_tensor_ref();
VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
<< ", output_shapes: " << utils::Join(output_shapes[0], ", ");
std::string tensor_name = common::UniqName("T_Repeat_out");

if (FLAGS_cinn_ir_schedule) {
// With the IR schedule enabled, the output tensor name is passed as the
// second pack argument.
CHECK_EQ(pack_args.size(), 2U);
tensor_name = pack_args[1].operator std::string();
}

std::vector<ir::Tensor> out = Repeat(tensor_A, repeats, axis, tensor_name);
CHECK(out.size() == 1U) << "The size of Repeat's output should be 1";

std::vector<common::CINNValue> res;
auto stages = CreateStages({tensor_A});
for (auto &t : out) {
stages->InsertLazily(t);
res.push_back(common::CINNValue(t));
}

res.push_back(common::CINNValue(stages));
*ret = common::CINNValuePack{res};
});

framework::CINNSchedule repeat_schedule([=](lang::Args args, lang::RetValue *ret) {
if (FLAGS_cinn_ir_schedule) {
CHECK(!args.empty()) << "The input argument of repeat schedule is empty! Please check.\n";
common::CINNValuePack arg_pack = args[0];
std::vector<Expr> vec_ast;
for (int i = 0; i < arg_pack.size(); i++) {
if (arg_pack[i].is_expr()) {
Expr temp = arg_pack[i];
vec_ast.emplace_back(temp);
}
}
CHECK(!vec_ast.empty());
ir::ModuleExpr mod_expr(vec_ast);
ir::IRSchedule ir_sch(mod_expr);
ir_sch.MergeExprs();
long prod_size = std::accumulate(output_shapes[0].begin(), output_shapes[0].end(), 1, std::multiplies<int>());
if (prod_size > 1) {
if (target.arch == Target::Arch::NVGPU) {
pe::IRCudaScheduleInjective(ir_sch, output_shapes.front(), target);
} else if (target.arch == Target::Arch::X86) {
pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true);
}
}
std::vector<common::CINNValue> res{common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
*ret = common::CINNValuePack{res};
} else {
CHECK(!args.empty()) << "The input argument of repeat schedule is empty! Please check.\n";
CINNValuePack arg_pack = args[0];
Expr out = arg_pack[0];
CHECK(out.as_tensor());
*ret = arg_pack;
}
});

auto strategy = std::make_shared<framework::OpStrategy>();
strategy->AddImpl(repeat_compute, repeat_schedule, "strategy.repeat.x86", 1);

return strategy;
}

} // namespace op
} // namespace hlir
} // namespace cinn

CINN_REGISTER_HELPER(repeat_ops) {
CINN_REGISTER_OP(repeat)
.describe("Repeat elements of an array `repeats` times along axis `axis`")
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForRepeat)
.set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForRepeat))
.set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForRepeat))
.set_support_level(4);

return true;
}