Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Tricky reduce fuse #1429

Open
wants to merge 13 commits into
base: develop
Choose a base branch
from
1 change: 1 addition & 0 deletions cinn/auto_schedule/cost_model/feature_extractor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ VisitForMultiOperandsDtypePattern(Product, mul);
VisitCountMemberPattern(And, bool_op);
VisitCountMemberPattern(Or, bool_op);
VisitCountMemberPattern(Not, bool_op);
VisitCountMemberPattern(GetReference, mem_read);
VisitCountMemberPattern(Max, select_op);
VisitCountMemberPattern(Min, select_op);
VisitCountMemberPattern(IfThenElse, select_op);
Expand Down
4 changes: 4 additions & 0 deletions cinn/backends/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,12 @@ std::string CodeGenC::GetTypeRepr(Type type) {
str += "*";
} else if (type.is_cpp_handle2()) {
str += "**";
} else if (type.is_cpp_reference()) {
str += "&";
}
return str;
}

void CodeGenC::Visit(const ir::IntImm *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::UIntImm *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::FloatImm *op) { IrPrinter::Visit(op); }
Expand Down Expand Up @@ -186,6 +189,7 @@ void CodeGenC::Visit(const ir::Not *op) {
IrPrinter::Print(op->v());
os() << ")";
}
// GetReference has no C-specific codegen; delegate to IrPrinter, which emits "&<operand>".
void CodeGenC::Visit(const ir::GetReference *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::Cast *op) { PrintCastExpr(op->type(), op->v()); }
void CodeGenC::Visit(const ir::For *op) {
Expr extent = op->extent;
Expand Down
5 changes: 5 additions & 0 deletions cinn/backends/codegen_cuda_dev.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ std::vector<Expr> CodeGenCUDA_Dev::GenerateBufferAliasExprs(const ir::_LoweredFu

void CodeGenCUDA_Dev::Visit(const ir::_LoweredFunc_ *op) {
// clear names valid within scope when enter a new function
// Emit a per-kernel "__device__" counter definition for every device-count
// expression. Note the trailing ';' — without it the generated CUDA source
// would fail to compile ("__device__ int x = 0" followed by "__global__...").
std::set<Expr> device_count_exprs = op->PrepareDeviceCountExprs();
for (const auto &dce : device_count_exprs) {
  os() << "__device__ int " << dce.As<ir::_Var_>()->name << " = 0;\n";
}

vectorized_tensor_names_.clear();
os() << "__global__\n";

Expand Down
5 changes: 5 additions & 0 deletions cinn/backends/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,11 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Minus *op) {
return (op->type().is_int() || op->type().is_uint()) ? Neg(v) : FNeg(v);
}

// GetReference lowering is not supported in the LLVM backend yet; fail loudly
// rather than silently emitting wrong IR. (Typo "Unimplementd" fixed.)
llvm::Value *CodeGenLLVM::Visit(const ir::GetReference *op) {
  LOG(FATAL) << "TODO: Unimplemented CodeGenLLVM::Visit(const ir::GetReference *op)";
  return nullptr;  // unreachable after LOG(FATAL); satisfies the return type
}

llvm::Value *CodeGenLLVM::Visit(const ir::Not *op) { return Not(Visit(&op->v())); }

llvm::Value *CodeGenLLVM::Visit(const ir::Cast *op) {
Expand Down
2 changes: 1 addition & 1 deletion cinn/common/bfloat16.h
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ __host__ __device__ inline bool(isfinite)(const bfloat16& a) { return !((isnan)(

__host__ __device__ inline bfloat16(abs)(const bfloat16& a) {
#if defined(CINN_CUDA_BF16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __habs(a.to_nv_bfloat16());
return bfloat16(__habs(a.to_nv_bfloat16()));
#else
return bfloat16(std::abs(static_cast<float>(a)));
#endif
Expand Down
18 changes: 18 additions & 0 deletions cinn/common/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,21 @@ Type &Type::set_cpp_handle2(bool x) {
return *this;
}

// Mark (x=true) or clear (x=false) the C++-reference bit on this type
// (rendered as "&" by the C backend). Mutually exclusive with the pointer
// bits: Handle (T*) and HandleHandle (T**) are always cleared first.
Type &Type::set_cpp_reference(bool x) {
  // View the cpp_type_ enum storage as raw bits so individual flags can be masked.
  auto &v = (*reinterpret_cast<uint8_t *>(&GetStorage().cpp_type_));

  // unset the other handle-related bits.
  v &= ~static_cast<uint8_t>(cpp_type_t::Handle);
  v &= ~static_cast<uint8_t>(cpp_type_t::HandleHandle);

  if (x)
    v |= static_cast<uint8_t>(cpp_type_t::Reference);
  else
    v &= ~static_cast<uint8_t>(cpp_type_t::Reference);

  return *this;
}

Type Type::VectorOf(int w) const {
CheckTypeValid();
return Type(type(), bits(), w, specific_type());
Expand Down Expand Up @@ -263,6 +278,9 @@ bool Type::is_cpp_handle() const {
bool Type::is_cpp_handle2() const {
return static_cast<uint8_t>(GetStorage().cpp_type_) & static_cast<uint8_t>(cpp_type_t::HandleHandle);
}
// Whether the cpp_type_t::Reference bit is set on this type.
// (Operand order mirrors is_cpp_const for consistency.)
bool Type::is_cpp_reference() const {
  return static_cast<uint8_t>(cpp_type_t::Reference) & static_cast<uint8_t>(GetStorage().cpp_type_);
}
bool Type::is_cpp_const() const {
return static_cast<uint8_t>(cpp_type_t::Const) & static_cast<uint8_t>(GetStorage().cpp_type_);
}
Expand Down
4 changes: 4 additions & 0 deletions cinn/common/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ struct Type {
Const = 1, // const.
Handle = 1 << 1, // pointer type, such as `cinn_buffer_t*`.
HandleHandle = 1 << 2, // pointer of pointer, such as `cinn_buffer_t**`.
Reference = 1 << 4, // reference type, such as `cinn_buffer_t&`.
};

Type();
Expand Down Expand Up @@ -100,6 +101,9 @@ struct Type {
Type& set_cpp_handle2(bool x = true);
CINN_NODISCARD bool is_cpp_handle2() const;

Type& set_cpp_reference(bool x = true);
CINN_NODISCARD bool is_cpp_reference() const;

Type& set_cpp_const(bool is_const = true);
CINN_NODISCARD bool is_cpp_const() const;

Expand Down
15 changes: 14 additions & 1 deletion cinn/hlir/framework/op_lowering.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,13 @@ std::vector<ir::LoweredFunc> OpLowerer::IRLowerOp(IRComputeFunction compute,
ast_exprs = (this->*compute)(stages, arg_tensors, tensor_map, group, group, /*apply_impl_schedule = */ true);
} else {
for (auto& sub_group : group->fused_sub_groups) {
VLOG(4) << "sub_group->group_id = " << sub_group->group_id;
auto exprs = (this->*compute)(stages, arg_tensors, tensor_map, group, sub_group, /*apply_impl_schedule = */ true);
VLOG(4) << "==== Exprs are ====";
for (auto& e : exprs) {
VLOG(4) << e;
}
VLOG(4) << "==== End of Exprs ====";
ast_exprs.insert(ast_exprs.end(), exprs.begin(), exprs.end());
}
}
Expand Down Expand Up @@ -1275,10 +1281,17 @@ void OpLowerer::IRSchedule(ir::IRSchedule& ir_sch,
ir_sch.Split(loops[0], splits);
}
}

VLOG(3) << "Before loop fusion, ir is: \n" << ir_sch.GetModule().GetExprs().at(0);
VLOG(4) << " FUSION " << node->op()->name;
// do loop fuse.
LoopComputeAt(ir_sch, node, master ? master : nodes_in_order.front(), group, this->shape_dict_, tensor_map);
Node* fusion_master = master ? master : nodes_in_order.front();

if (CanFuseReduceByBlockSync(ir_sch, node, fusion_master, group, this->shape_dict_, tensor_map)) {
SyncGpuBlocks(ir_sch, node, fusion_master, group, this->shape_dict_, tensor_map);
} else {
LoopComputeAt(ir_sch, node, master ? master : nodes_in_order.front(), group, this->shape_dict_, tensor_map);
}
VLOG(3) << "After loop fusion, ir is: \n" << ir_sch.GetModule().GetExprs().at(0);
}

Expand Down
51 changes: 51 additions & 0 deletions cinn/hlir/framework/op_lowering_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "cinn/hlir/framework/op_lowering_util.h"

#include "cinn/hlir/pe/nn_util.h"
#include "cinn/utils/string.h"
#ifdef CINN_WITH_CUDA
#include "cinn/common/bfloat16.h"
#include "cinn/common/float16.h"
Expand Down Expand Up @@ -1096,6 +1097,7 @@ void MergeReduceToReduce(ir::IRSchedule& ir_sch,
const Node* master,
const absl::flat_hash_map<std::string, shape_t>& shape_dict,
const std::unordered_map<std::string, ir::Tensor>& tensor_map) {
VLOG(6) << "Calling MergeReduceToReduceLoop, node->id() = " << node->id();
auto node_data = GetNodeData(node);
auto master_data = GetNodeData(master);

Expand Down Expand Up @@ -1280,6 +1282,7 @@ void MergeReduceLoop(ir::IRSchedule& ir_sch,
const Node* master,
const absl::flat_hash_map<std::string, shape_t>& shape_dict,
const std::unordered_map<std::string, ir::Tensor>& tensor_map) {
VLOG(6) << "Calling MergeReduceLoop, node->id() = " << node->id();
auto& op_pattern_dict = Operator::GetAttrs<OpPatternKind>("OpPattern");
if (op_pattern_dict[master->op()] == kReduction && node != master) {
MergeReduceToReduce(ir_sch, node, master, shape_dict, tensor_map);
Expand Down Expand Up @@ -1429,6 +1432,54 @@ void LoopComputeAt(ir::IRSchedule& ir_sch,
} while (--index >= 0);
}

// Returns true when `node` and `master` are two distinct reduction nodes that
// can be fused by synchronizing GPU blocks instead of the usual loop
// compute-at. Currently restricted to "<reduce_op>_split" node pairs.
bool CanFuseReduceByBlockSync(ir::IRSchedule& ir_sch,
                              Node* node,
                              const Node* master,
                              const GroupPtr& group,
                              const absl::flat_hash_map<std::string, shape_t>& shape_dict,
                              const std::unordered_map<std::string, ir::Tensor>& tensor_map) {
  auto& op_pattern_dict = Operator::GetAttrs<OpPatternKind>("OpPattern");
  // Guard: only distinct reduce/reduce pairs qualify.
  const bool both_reduction = op_pattern_dict[node->op()] == framework::kReduction &&
                              op_pattern_dict[master->op()] == framework::kReduction;
  if (!both_reduction || node == master) {
    return false;
  }

  auto node_shape   = shape_dict.at(node->inlinks_in_order()[0]->source()->id());
  auto master_shape = shape_dict.at(master->inlinks_in_order()[0]->source()->id());

  VLOG(6) << "Checking CanFuseReduceByBlockSync";
  VLOG(6) << "node->id() = " << node->id() << ", node_shape.size() = " << node_shape.size();
  VLOG(6) << "master->id() = " << master->id() << ", master_shape.size() = " << master_shape.size();

  static std::unordered_set<std::string> reduce_op_type = {
      "reduce_sum", "reduce_mean", "reduce_max", "reduce_min", "reduce_all", "reduce_any"};
  // TODO: this may speed up not only reduce_xxx_split nodes, but we limit it to reduce_xxx_split nodes for accuracy
  // safety
  for (const std::string& base_op : reduce_op_type) {
    const std::string split_prefix = base_op + "_split";
    if (cinn::utils::Startswith(master->id(), split_prefix) && cinn::utils::Startswith(node->id(), split_prefix)) {
      return true;
    }
  }
  return false;
}

// Fuse `node` into `master`'s schedule by synchronizing GPU blocks rather
// than merging loops. If `node` is not a group output, its buffer is first
// demoted to "local" memory since the result never escapes the kernel.
// NOTE(review): node_block is looked up via node->id() while master_block
// uses the node-data id (master_data->id()) — confirm this asymmetry is
// intended.
void SyncGpuBlocks(ir::IRSchedule& ir_sch,
                   Node* node,
                   const Node* master,
                   const GroupPtr& group,
                   const absl::flat_hash_map<std::string, shape_t>& shape_dict,
                   const std::unordered_map<std::string, ir::Tensor>& tensor_map) {
  VLOG(6) << "Calling SyncGpuBlocks";
  if (!group->output_nodes.count(node)) {
    auto block = ir_sch.GetBlock(GetNodeData(node)->id());
    ir_sch.SetBuffer(block, "local", true);
  }
  // (removed unused local `node_data`)
  auto master_data  = GetNodeData(master);
  auto node_block   = ir_sch.GetBlock(node->id());
  auto master_block = ir_sch.GetBlock(master_data->id());
  ir_sch.SyncGpuBlocks(master_block, node_block);
}

std::unordered_map<std::string, NodeData*> GetNodeDataSet(const std::unordered_set<Node*>& nodes_set) {
std::unordered_map<std::string, NodeData*> node_data_set;
for (auto node : nodes_set) {
Expand Down
14 changes: 14 additions & 0 deletions cinn/hlir/framework/op_lowering_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,20 @@ void LoopComputeAt(ir::IRSchedule& ir_sch,
const absl::flat_hash_map<std::string, shape_t>& shape_dict,
const std::unordered_map<std::string, ir::Tensor>& tensor_map);

bool CanFuseReduceByBlockSync(ir::IRSchedule& ir_sch,
Node* node,
const Node* master,
const GroupPtr& group,
const absl::flat_hash_map<std::string, shape_t>& shape_dict,
const std::unordered_map<std::string, ir::Tensor>& tensor_map);

void SyncGpuBlocks(ir::IRSchedule& ir_sch,
Node* node,
const Node* master,
const GroupPtr& group,
const absl::flat_hash_map<std::string, shape_t>& shape_dict,
const std::unordered_map<std::string, ir::Tensor>& tensor_map);

void SyncThreadWithShared(ir::IRSchedule& ir_sch,
const GroupPtr& group,
const std::unordered_set<Node*>& nodes_inline,
Expand Down
31 changes: 26 additions & 5 deletions cinn/hlir/pass/fusion_merge_pass_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,20 @@ CONDITION_FUNC(injective_horizontal_with_reduce) {
return elementwise_fuse_reduce(helper, first, second);
}

// True iff both `producer` and `reducer` are "<reduce_op>_split" nodes for
// one of the known reduce ops; such pairs are eligible for fusion.
inline bool ReduceSplitCanFuse(const Node* producer, const Node* reducer) {
  static std::unordered_set<std::string> reduce_op_type = {
      "reduce_sum", "reduce_mean", "reduce_max", "reduce_min", "reduce_all", "reduce_any"};
  VLOG(6) << "Checking ReduceSplitCanFuse";
  VLOG(6) << "producer->id() = " << producer->id();
  VLOG(6) << "reducer->id() = " << reducer->id();
  for (const std::string& base_op : reduce_op_type) {
    // Both ids must share the same "<op>_split" prefix.
    const std::string prefix = base_op + "_split";
    if (utils::Startswith(producer->id(), prefix) && utils::Startswith(reducer->id(), prefix)) {
      return true;
    }
  }
  return false;
}

CONDITION_FUNC(reduce_fuse_broadcast) {
// if same shape with horizontal relation
if (is_same_size(helper, first, second)) {
Expand All @@ -294,8 +308,8 @@ CONDITION_FUNC(reduce_fuse_broadcast) {
// Traversing all reducers in all producers requires two types of conditions to be met.
// The first type is the condition that the reducer itself needs to meet,
// and the second type is the condition that the relationship between each reducer and its consumers with type of
// Broadcast needs to meet. It is required that each consumer of type Broadcast meet the same shape after broadcast as
// before reduce.
// Broadcast needs to meet. It is required that each consumer of type Broadcast meet the same shape after
// broadcast as before reduce.
for (auto& node_in_master : first->master_nodes) {
if (helper->GetOpKind(node_in_master) != OpPatternKind::kReduction) {
continue;
Expand Down Expand Up @@ -388,9 +402,7 @@ CONDITION_FUNC(reduce_fuse_broadcast) {
}

CONDITION_FUNC(reduce_fuse_reduce) {
if (!limit_args(helper, first, second)) {
return false;
}
VLOG(6) << "In reduce_fuse_reduce";
Node* reducer_0 = nullptr;
for (auto& reducer : first->master_nodes) {
if (helper->GetOpKind(reducer) == OpPatternKind::kReduction) {
Expand All @@ -409,6 +421,15 @@ CONDITION_FUNC(reduce_fuse_reduce) {
}
CHECK(reducer_1) << "Can't find reduce op in group " << second->group_id;

if (!horizontal_relation(helper, first, second, framework::OpPatternKind::kReduction)) {
return ReduceSplitCanFuse(reducer_0, reducer_1);
}

// reduce relation is horizontal with reduce.
if (!limit_args(helper, first, second)) {
return false;
}

// check reduce has same input shape and output shape
auto reducer_0_input_shape = helper->shape_dict_.at(reducer_0->inlinks_in_order()[0]->source()->id());
auto reducer_0_output_shape = helper->shape_dict_.at(reducer_0->outlinks_in_order()[0]->sink()->id());
Expand Down
20 changes: 20 additions & 0 deletions cinn/hlir/pass/op_fusion_pass_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,34 @@ CONDITION_FUNC(without_last_dimension_in_reduce) {
return helper->WithoutLastDimInReduce(in_shape, reduce_axes);
}

// Checks whether both nodes are "<reduce_op>_split" variants of the same set
// of reduce ops; only such producer/reducer pairs may fuse via this rule.
inline bool ReduceSplitCanFuse(const Node* producer, const Node* reducer) {
  static std::unordered_set<std::string> reduce_op_type = {
      "reduce_sum", "reduce_mean", "reduce_max", "reduce_min", "reduce_all", "reduce_any"};
  VLOG(6) << "Checking ReduceSplitCanFuse";
  VLOG(6) << "producer->id() = " << producer->id();
  VLOG(6) << "reducer->id() = " << reducer->id();
  for (const std::string& op_name : reduce_op_type) {
    const std::string split_id_prefix = op_name + "_split";
    const bool producer_matches = utils::Startswith(producer->id(), split_id_prefix);
    if (producer_matches && utils::Startswith(reducer->id(), split_id_prefix)) {
      return true;
    }
  }
  return false;
}

CONDITION_FUNC(reduce_fuse_reduce) {
VLOG(6) << "In reduce_fuse_reduce";
Node* reducer = NULL;
for (auto* master : consumer->master_nodes) {
if (helper->GetOpKind(master) == framework::kReduction) {
reducer = master;
break;
}
}

if (ReduceSplitCanFuse(producer, reducer)) {
return true;
}

// check reduce has same input shape and output shape
auto producer_input_shape = helper->shape_dict_.at(producer->inlinks_in_order()[0]->source()->id());
auto producer_output_shape = helper->shape_dict_.at(producer->outlinks_in_order()[0]->sink()->id());
Expand Down
9 changes: 9 additions & 0 deletions cinn/ir/ir.cc
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,15 @@ void Not::Verify() const { CHECK_EQ(v().type(), type_of<bool>()); }

Type Not::type() const { return type_; }

// Factory: build a GetReference node (C++ address-of / reference) around `v`.
Expr GetReference::Make(Expr v) { return Expr(make_shared<GetReference>(v)); }

// A GetReference node is valid as long as its operand expression is defined.
void GetReference::Verify() const { CHECK(v().defined()); }

// Returns the type recorded at construction (see GetReference's constructor in ir.h).
Type GetReference::type() const { return type_; }

Expr Let::Make(Expr symbol, Expr body) {
auto *n = make_shared<Let>();
CHECK(symbol.type().valid());
Expand Down
16 changes: 16 additions & 0 deletions cinn/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,22 @@ struct Not : public UnaryOpNode<Not> {
static const IrNodeTy _node_type_ = IrNodeTy::Not;
};

/**
 * Get reference, such as C++ `&x`.
 *
 * TODO: transforms/passes for this node are not complete. Be careful when using it.
 */
struct GetReference : public UnaryOpNode<GetReference> {
  // NOTE(review): the node type is hard-coded to int32-with-reference-bit
  // regardless of v's type — confirm this is intended (current users emit
  // references to `__device__ int` counters) or whether it should derive
  // from v's own type.
  explicit GetReference(Expr v) : UnaryOpNode<GetReference>(common::Int(32).set_cpp_reference(), v) {}

  static Expr Make(Expr v);

  Type type() const override;
  void Verify() const override;

  static const IrNodeTy _node_type_ = IrNodeTy::GetReference;
};

struct Let : public ExprNode<Let> {
Expr symbol;
Expr body;
Expand Down
1 change: 1 addition & 0 deletions cinn/ir/ir_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class ScheduleBlockRealize;
#define NODETY_UNARY_OP_FOR_EACH(macro__) \
macro__(Minus) \
macro__(Not) \
macro__(GetReference) \

#define NODETY_OP_FOR_EACH(macro__) NODETY_BINARY_OP_FOR_EACH(macro__) NODETY_UNARY_OP_FOR_EACH(macro__)

Expand Down
4 changes: 4 additions & 0 deletions cinn/ir/ir_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ void IrPrinter::Visit(const Minus *x) {
Print(x->v());
os_ << ")";
}
// Print a GetReference node as the C++ address-of form: "&<operand>".
void IrPrinter::Visit(const GetReference *x) {
  os_ << "&";
  Print(x->v());
}
void IrPrinter::Visit(const For *x) {
if (x->is_parallel()) {
os() << "parallel for (";
Expand Down
Loading