Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Schedule error #1535

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cinn/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ gather_srcs(cinnapi_src SRCS
ir_base.cc
ir_schedule.cc
ir_schedule_util.cc
ir_schedule_error.cc
ir_visitor.cc
ir_printer.cc
ir_mutator.cc
Expand Down
43 changes: 37 additions & 6 deletions cinn/ir/ir_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ namespace ir {
class ScheduleImpl {
public:
ScheduleImpl() = default;
explicit ScheduleImpl(const ModuleExpr& module_expr, bool debug_flag = false)
: module_expr_(module_expr), debug_flag_(debug_flag) {}
explicit ScheduleImpl(const ModuleExpr& module_expr,
bool debug_flag = false,
ScheduleErrorMessageLevel err_msg_level = ScheduleErrorMessageLevel::kBlank)
: module_expr_(module_expr), debug_flag_(debug_flag), err_msg_level_(err_msg_level) {}
explicit ScheduleImpl(ModuleExpr&& module_expr) : module_expr_(std::move(module_expr)) {}

//! Set the debug flag.
Expand Down Expand Up @@ -114,8 +116,32 @@ class ScheduleImpl {

ModuleExpr module_expr_;
bool debug_flag_{false};
ScheduleErrorMessageLevel err_msg_level_;
};

/** \brief A macro that guards the beginning of each implementation of schedule */
#define CINN_IR_SCHEDULE_BEGIN() try {
/**
* \brief A macro that pairs with `CINN_IR_SCHEDULE_BEGIN`, handling potential errors and error
* message printing
* \param primitive A string representing the kind of schedule primitive
* \param err_msg_level A ScheduleErrorMessageLevel enum, level of error message printing
*/
#define CINN_IR_SCHEDULE_END(primitive, err_msg_level) \
} \
catch (const IRScheduleErrorHandler& err_hanlder) { \
switch (err_msg_level) { \
case ScheduleErrorMessageLevel::kDetailed: \
throw std::runtime_error(err_hanlder.FormatErrorMessage(primitive)); \
case ScheduleErrorMessageLevel::kGenearl: \
throw std::runtime_error(err_hanlder.GeneralErrorMessage()); \
case ScheduleErrorMessageLevel::kBlank: \
throw std::runtime_error("IRScheduleError occurred! (No more error message)"); \
default: \
throw std::runtime_error("IRScheduleError occurred! (No more error message)"); \
} \
}

std::vector<Expr> ScheduleImpl::Split(const Expr& loop, const std::vector<int>& factors) {
CHECK(loop.As<ir::For>()) << "Expr param of Split must be For node! Please check.";
auto* for_node = loop.As<ir::For>();
Expand All @@ -126,8 +152,10 @@ std::vector<Expr> ScheduleImpl::Split(const Expr& loop, const std::vector<int>&
VLOG(3) << "Try Split loop from (" << for_node->loop_var->name << ", 0, " << tot_extent << ") to ("
<< cinn::utils::Join(factors, ", ") << ") at loop:\n"
<< loop;

auto processed_factors = ValidateFactors(factors, tot_extent);
std::vector<int> processed_factors;
CINN_IR_SCHEDULE_BEGIN();
processed_factors = ValidateFactors(factors, tot_extent);
CINN_IR_SCHEDULE_END("split", this->err_msg_level_);
int prod_size = std::accumulate(processed_factors.begin(), processed_factors.end(), 1, std::multiplies<int>());
std::vector<Var> new_loop_vars;
Expr substitute_value(0);
Expand Down Expand Up @@ -1971,8 +1999,11 @@ Expr ScheduleImpl::SampleCategorical(utils::LinearRandomEngine::StateType* rand_

IRSchedule::IRSchedule() {}

IRSchedule::IRSchedule(const ModuleExpr& module_expr, utils::LinearRandomEngine::StateType rand_seed, bool debug_flag) {
impl_ = std::make_unique<ScheduleImpl>(module_expr, debug_flag);
IRSchedule::IRSchedule(const ModuleExpr& module_expr,
utils::LinearRandomEngine::StateType rand_seed,
bool debug_flag,
ScheduleErrorMessageLevel err_msg_level) {
impl_ = std::make_unique<ScheduleImpl>(module_expr, debug_flag, err_msg_level);
this->InitSeed(rand_seed);
}

Expand Down
4 changes: 3 additions & 1 deletion cinn/ir/ir_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "cinn/ir/ir.h"
#include "cinn/ir/ir_base.h"
#include "cinn/ir/ir_mutator.h"
#include "cinn/ir/ir_schedule_error.h"
#include "cinn/ir/schedule_desc.h"
#include "cinn/ir/tensor.h"
#include "cinn/utils/random_engine.h"
Expand Down Expand Up @@ -67,7 +68,8 @@ class IRSchedule {
IRSchedule();
explicit IRSchedule(const ModuleExpr& modexpr,
utils::LinearRandomEngine::StateType rand_seed = -1,
bool debug_flag = false);
bool debug_flag = false,
ScheduleErrorMessageLevel err_msg_level = ScheduleErrorMessageLevel::kBlank);
IRSchedule(ir::ModuleExpr&& mod_expr, ScheduleDesc&& trace, utils::LinearRandomEngine::StateType rand_seed = -1);
IRSchedule(const IRSchedule& other);
IRSchedule& operator=(const IRSchedule& src);
Expand Down
30 changes: 30 additions & 0 deletions cinn/ir/ir_schedule_error.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/ir/ir_schedule_error.h"

namespace cinn {
namespace ir {

std::string IRScheduleErrorHandler::FormatErrorMessage(const std::string &primitive) const {
std::ostringstream os;
std::string err_msg = DetailedErrorMessage();

os << "[IRScheduleError] An error occurred in the scheduel primitive <" << primitive << ">. " << std::endl;
os << "Error info: " << err_msg;
return os.str();
}

} // namespace ir
} // namespace cinn
67 changes: 67 additions & 0 deletions cinn/ir/ir_schedule_error.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

namespace cinn {
namespace ir {

/**
* \brief Indicates the level of printing error message in the current Schedule
*/
enum class ScheduleErrorMessageLevel : int32_t {
/** \brief No error message*/
kBlank = 0,
/** \brief Print an error message in short mode*/
kGenearl = 1,
/** \brief Print an error message in detailed mode*/
kDetailed = 2,
};

/**
* This handler is dealing with the errors happen in in the current Scheduling.
*/
class IRScheduleErrorHandler : public std::runtime_error {
public:
IRScheduleErrorHandler() : std::runtime_error("") {}
/**
* \brief constructor
* \param s the error message
*/
explicit IRScheduleErrorHandler(const std::string &s) : std::runtime_error(s) {}

/**
* \brief Returns a detailed error message corresponding to the kDetailed error level.
*/
std::string FormatErrorMessage(const std::string &primitive) const;

/**
* \brief Returns a short error message corresponding to the kGeneral error level.
*/
virtual std::string GeneralErrorMessage() const = 0;

/**
* \brief Returns a detailed error message corresponding to the kDetailed error level.
*/
virtual std::string DetailedErrorMessage() const = 0;
};

} // namespace ir
} // namespace cinn
71 changes: 63 additions & 8 deletions cinn/ir/ir_schedule_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "cinn/ir/ir.h"
#include "cinn/ir/ir_operators.h"
#include "cinn/ir/ir_printer.h"
#include "cinn/ir/ir_schedule_error.h"
#include "cinn/ir/ir_visitor.h"
#include "cinn/lang/compute.h"
#include "cinn/optim/ir_copy.h"
Expand Down Expand Up @@ -196,27 +197,81 @@ void ReplaceExpr(Expr* source, const std::vector<Var>& replaced, const std::vect
}

std::vector<int> ValidateFactors(const std::vector<int>& factors, int total_extent) {
class NegativeFactorErrorHandler : public IRScheduleErrorHandler {
public:
explicit NegativeFactorErrorHandler(int64_t factor, size_t idx) : factor_(factor), idx_(idx) {}

std::string GeneralErrorMessage() const final {
return "[IRScheduleError]: The params in factors of Split should be positive. However, some "
"factor is zero or negative.";
}

std::string DetailedErrorMessage() const final {
std::ostringstream os;
os << "The params in factors of Split should be positive. However, the factor at position " << idx_ << " is "
<< factor_;
return os.str();
}

private:
int64_t factor_;
size_t idx_;
};

class InferFactorErrorHandler : public IRScheduleErrorHandler {
public:
std::string GeneralErrorMessage() const final {
return "[IRScheduleError]: The params in factors of Split should not be less than -1 or have more than one -1!";
}

std::string DetailedErrorMessage() const final {
std::ostringstream os;
os << "The params in factors of Split should not be less than -1 or have more than one -1!";
return os.str();
}
};

class FactorProductErrorHandler : public IRScheduleErrorHandler {
public:
std::string GeneralErrorMessage() const final {
return "[IRScheduleError]: In Split, the factors' product should be not larger than or equal to original loop's "
"extent!";
}

std::string DetailedErrorMessage() const final {
std::ostringstream os;
os << "In Split, the factors' product should be not larger than or equal to original loop's extent!";
return os.str();
}
};

CHECK(!factors.empty()) << "The factors param of Split should not be empty! Please check.";
bool has_minus_one = false;
int product = 1;
int idx = -1;
for (auto& i : factors) {
CHECK(i != 0) << "The params in factors of Split should not be 0! Please check.";
CHECK(i >= -1) << "The params in factors of Split should not be less than -1! Please check.";
if (i == -1) {
CHECK(!has_minus_one) << "The params in factors of Split should not have more than one -1! Please check.";
idx++;
if (i == 0 || i < -1) {
throw NegativeFactorErrorHandler(i, idx);
} else if (i == -1) {
if (has_minus_one) {
throw InferFactorErrorHandler();
}
has_minus_one = true;
} else {
product *= i;
}
}
std::vector<int> validated_factors = factors;
if (!has_minus_one) {
CHECK_GE(product, total_extent)
<< "In Split, the factors' product should be equal to original loop's extent! Please check.";
if (product < total_extent) {
throw FactorProductErrorHandler();
}
return validated_factors;
} else {
CHECK_LE(product, total_extent) << "In Split, when there is -1 in factors, the other factors' product should be <= "
"original loop's extent! Please check.";
if (product > total_extent) {
throw FactorProductErrorHandler();
}
int minus_one_candidate = (int)ceil((double)total_extent / (double)product);
for (int i = 0; i < validated_factors.size(); ++i) {
if (validated_factors[i] == -1) {
Expand Down