Skip to content

Commit

Permalink
Better Exception Handling for Operators (apache#9681)
Browse files Browse the repository at this point in the history
* Add support for threaded engine

* Add support for threaded engine

* Remove on_start_callback for else

* Add support for global_ex_ptr

* Rethrow in waitall only once

* run tests for gpu

* Add comments for exception_ptr

* Fix lint

* Push exc_handling tests

* Add comments for OnStart

* Fixes for exc handling

* Catch std::exception for all other exceptions

* Rollback std::move use

* Fix style

* Fix onstart

* Fix debug_info

* Throw exception only once in an execution graph

* make test naming consistent

* Fix symbolic test

* Remove unused code
  • Loading branch information
anirudh2290 authored and zheng-da committed Jun 28, 2018
1 parent f01bf9c commit 6da2eec
Show file tree
Hide file tree
Showing 8 changed files with 231 additions and 39 deletions.
8 changes: 6 additions & 2 deletions include/mxnet/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,15 @@ class MXNET_API Engine {
* \param mutable_vars The variables that current operation will mutate.
* \param prop Property of the function.
* \param opr_name The operator name.
* \param wait Whether this is a WaitForVar operation
* \return The new operator allocated.
*/
virtual OprHandle NewOperator(AsyncFn fn,
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
const char* opr_name = nullptr) = 0;
const char* opr_name = nullptr,
bool wait = false) = 0;
/*!
* \brief Delete the given operator.
* \param op The operator to delete.
Expand Down Expand Up @@ -176,13 +178,15 @@ class MXNET_API Engine {
* \param prop Property of the function.
* \param priority Priority of the action, as hint to the engine.
* \param opr_name The operator name.
* \param wait Whether this is a WaitForVar operation
*/
virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx,
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
int priority = 0,
const char* opr_name = nullptr) = 0;
const char* opr_name = nullptr,
bool wait = false) = 0;
/*!
* \brief Schedule the deletion of a variable.
*
Expand Down
6 changes: 4 additions & 2 deletions src/engine/naive_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ class NaiveEngine final : public Engine {
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
const char* opr_name = nullptr) override {
const char* opr_name = nullptr,
bool wait = false) override {
NaiveOpr *opr = new NaiveOpr();
opr->fn = fn;
opr->const_vars = const_vars;
Expand Down Expand Up @@ -125,7 +126,8 @@ class NaiveEngine final : public Engine {
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
int priority = 0,
const char* opr_name = nullptr) override {
const char* opr_name = nullptr,
bool wait = false) override {
CallbackOnComplete callback = CreateCallback(
NaiveEngine::OnComplete, nullptr);
this->req_completed_ = false;
Expand Down
39 changes: 29 additions & 10 deletions src/engine/threaded_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -206,13 +206,15 @@ ThreadedOpr* ThreadedEngine::NewOperator(
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop,
const char* opr_name) {
const char* opr_name,
bool wait) {
auto ret = ThreadedOpr::New();
ret->opr_name = opr_name;
ret->fn = std::move(fn);
ret->prop = prop;
ret->const_vars.resize(const_vars.size());
ret->mutable_vars.resize(mutable_vars.size());
ret->wait = wait;
std::transform(const_vars.begin(), const_vars.end(),
ret->const_vars.begin(), ThreadedVar::CastFromBase);
std::transform(mutable_vars.begin(), mutable_vars.end(),
Expand Down Expand Up @@ -305,9 +307,10 @@ void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop,
int priority,
const char* opr_name) {
const char* opr_name,
bool wait) {
BulkFlush();
ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name);
ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name, wait);
opr->temporary = true;
#if MXNET_USE_PROFILER
Profiler *profiler = Profiler::Get();
Expand Down Expand Up @@ -356,7 +359,10 @@ void ThreadedEngine::DeleteVariable(SyncFn delete_fn,
void ThreadedEngine::WaitForVar(VarHandle var) {
BulkFlush();
ThreadedVar* threaded_var = ThreadedVar::CastFromBase(var);
if (threaded_var->ready_to_read()) return;
if (threaded_var->ready_to_read()) {
ThrowException(threaded_var);
return;
}
if (engine_info_) {
LOG(INFO) << "Wait for " << threaded_var;
debug_wait_var_ = threaded_var;
Expand All @@ -376,13 +382,15 @@ void ThreadedEngine::WaitForVar(VarHandle var) {
}
on_complete();
}, Context::CPU(), {var}, {}, FnProperty::kNormal, 0,
PROFILER_MESSAGE("WaitForVar"));
PROFILER_MESSAGE("WaitForVar"), true);
{
std::unique_lock<std::mutex> lock{finished_m_};
finished_cv_.wait(lock, [this, &done]() {
return done.load() || kill_.load();
});
}

ThrowException(threaded_var);
}

void ThreadedEngine::WaitForAll() {
Expand All @@ -397,18 +405,20 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
bool is_temporary_opr = threaded_opr->temporary;
// Mark complete for read variables
for (auto&& i : threaded_opr->const_vars) {
i->CompleteReadDependency([this](OprBlock* opr) {
this->PushToExecute(opr, false);
});
i->CompleteReadDependency(
[this](OprBlock* opr) { this->PushToExecute(opr, false); });
}
// Mark complete for write variables.
for (auto&& i : threaded_opr->mutable_vars) {
if (threaded_opr->opr_exception && *threaded_opr->opr_exception) {
i->var_exception = threaded_opr->opr_exception;
}
const bool debug_info = (engine_info_ && debug_wait_var_ == i);
if (debug_info) {
LOG(INFO) << "Complete write dep for " << i;
}
const bool to_delete = i->CompleteWriteDependency(
[this, debug_info](OprBlock* opr) {
const bool to_delete =
i->CompleteWriteDependency([this, debug_info](OprBlock* opr) {
if (debug_info) {
LOG(INFO) << "PushToExecute " << opr;
debug_push_opr_ = opr;
Expand Down Expand Up @@ -443,6 +453,15 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
}
}

inline void ThreadedEngine::ThrowException(ThreadedVar* threaded_var) {
if (threaded_var->var_exception && *threaded_var->var_exception) {
std::exception_ptr tmp = *threaded_var->var_exception;
*threaded_var->var_exception = nullptr;
std::rethrow_exception(tmp);
}
return;
}

void ThreadedEngine::OnCompleteStatic(
Engine *engine, void *opr_block_) {
OprBlock *opr_block = static_cast<OprBlock*>(opr_block_);
Expand Down
94 changes: 72 additions & 22 deletions src/engine/threaded_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ class ThreadedVar final
static std::atomic<std::size_t> counter;
~ThreadedVar() { LOG(INFO) << __func__ << " " << --counter; }
#endif // ENGINE_DEBUG
/*! \brief exception_ptr associated with the ThreadedVar */
std::shared_ptr<std::exception_ptr> var_exception;

private:
// TODO(hotpxl) change this to spinlock for faster runtime
Expand Down Expand Up @@ -236,6 +238,10 @@ struct ThreadedOpr final : public Opr,
* that can be deleted right after the operation completed.
*/
bool temporary{false};
/*!
* \brief Whether this is a WaitForVar operation
*/
bool wait{false};
/*!
* \brief Cast a Opr pointer to ThreadedOpr pointer
* \param ptr pointer from base.
Expand All @@ -246,6 +252,8 @@ struct ThreadedOpr final : public Opr,
}
// define possible debug information
DEFINE_ENGINE_DEBUG_INFO(ThreadedOpr);
/*! \brief exception_ptr associated with the ThreadedOpr */
std::shared_ptr<std::exception_ptr> opr_exception;
}; // struct ThreadedOpr

/*!
Expand All @@ -265,15 +273,17 @@ class ThreadedEngine : public Engine {
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
const char* opr_name = nullptr) override;
const char* opr_name = nullptr,
bool wait = false) override;
void DeleteOperator(OprHandle op) override;
void Push(OprHandle op, Context exec_ctx, int priority = 0, bool profiling = false) override;
void PushAsync(AsyncFn exec_fun, Context exec_ctx,
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
int priority = 0,
const char* opr_name = nullptr) override;
const char* opr_name = nullptr,
bool wait = false) override;
void PushSync(SyncFn exec_fn, Context exec_ctx,
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
Expand Down Expand Up @@ -321,50 +331,63 @@ class ThreadedEngine : public Engine {
* \param run_ctx runtime context used to execute the function.
* \param opr_block the opr_block to be executed and deleted.
*/
void ExecuteOprBlock(RunContext run_ctx, OprBlock *opr_block) {
void ExecuteOprBlock(RunContext run_ctx, OprBlock* opr_block) {
ThreadedOpr* threaded_opr = opr_block->opr;
#if MXNET_USE_PROFILER
if (opr_block->profiling && threaded_opr->opr_name) {
const Context& ctx = opr_block->ctx;
opr_block->opr_stat = Profiler::Get()->AddOprStat(ctx.dev_type, ctx.dev_id);
opr_block->opr_stat =
Profiler::Get()->AddOprStat(ctx.dev_type, ctx.dev_id);
uint64_t id = std::hash<std::thread::id>()(std::this_thread::get_id());
opr_block->opr_stat->thread_id = id;
strncpy(opr_block->opr_stat->opr_name,
threaded_opr->opr_name,
sizeof(opr_block->opr_stat->opr_name) - 1);
strncpy(opr_block->opr_stat->opr_name, threaded_opr->opr_name,
sizeof(opr_block->opr_stat->opr_name) - 1);
// record operator start timestamp
SetOprStart(opr_block->opr_stat);
}
#endif
CallbackOnComplete callback = this->CreateCallback(
ThreadedEngine::OnCompleteStatic, opr_block);
bool debug_info = (engine_info_ && debug_push_opr_ == opr_block);
CallbackOnComplete callback =
this->CreateCallback(ThreadedEngine::OnCompleteStatic, opr_block);
const bool debug_info = (engine_info_ && debug_push_opr_ == opr_block);
if (debug_info) {
LOG(INFO) << "ExecuteOprBlock " << opr_block
<< "shutdown_phase=" << shutdown_phase_;
}
if (!shutdown_phase_) {
try {
OnStart(threaded_opr);
if (debug_info) {
LOG(INFO) << "ExecuteOprFn ";
}
threaded_opr->fn(run_ctx, callback);
try {
if (!(threaded_opr->opr_exception && *threaded_opr->opr_exception) ||
threaded_opr->wait) {
threaded_opr->fn(run_ctx, callback);
} else {
callback();
}
} catch (dmlc::Error& e) {
threaded_opr->opr_exception =
std::make_shared<std::exception_ptr>(std::current_exception());
callback();
}
if (debug_info) {
LOG(INFO) << "Fin ExecuteOprFn ";
}
} catch(dmlc::Error &e) {
} catch (std::exception& e) {
std::string what = e.what();
if (what.find("driver shutting down") == std::string::npos &&
!shutdown_phase_) {
LOG(FATAL) << e.what() << "\n" <<
"A fatal error occurred in asynchronous engine operation. "
"If you do not know what caused this error, "
"you can try set environment variable MXNET_ENGINE_TYPE "
"to NaiveEngine and run with debugger (i.e. gdb). "
"This will force all operations to be synchronous and "
"backtrace will give you the series of calls that lead "
"to this error. Remember to set MXNET_ENGINE_TYPE back to "
"empty after debugging.";
LOG(FATAL)
<< e.what() << "\n"
<< "A fatal error occurred in asynchronous engine operation. "
"If you do not know what caused this error, "
"you can try set environment variable MXNET_ENGINE_TYPE "
"to NaiveEngine and run with debugger (i.e. gdb). "
"This will force all operations to be synchronous and "
"backtrace will give you the series of calls that lead "
"to this error. Remember to set MXNET_ENGINE_TYPE back to "
"empty after debugging.";
}
}
} else {
Expand Down Expand Up @@ -414,7 +437,34 @@ class ThreadedEngine : public Engine {
* On operation completion, this will trigger subsequent operations.
*/
inline void OnComplete(ThreadedOpr* threaded_opr);
// callback to the threaded engine
/*!
* \brief rethrow caught exception in WaitForVar
* \param threaded_var the var that we are waiting to read
*/
inline void ThrowException(ThreadedVar* threaded_var);
/*!
* \brief Mark exceptions before operation execution.
*
* Will mark the operator as a failure and associate exception_ptr
* if any of the read dependencies have exception associated.
*/
inline void OnStart(ThreadedOpr* threaded_opr) {
for (auto&& i : threaded_opr->const_vars) {
if (i->var_exception && *i->var_exception) {
threaded_opr->opr_exception = i->var_exception;
break;
}
}
if (!(threaded_opr->opr_exception && *threaded_opr->opr_exception)) {
for (auto&& i : threaded_opr->mutable_vars) {
if (i->var_exception && *i->var_exception) {
threaded_opr->opr_exception = i->var_exception;
break;
}
}
}
}

static void OnCompleteStatic(Engine *engine, void *threaded_opr);
/*! \brief append an operator to bulk */
inline void BulkAppend(SyncFn exec_fn, Context exec_ctx,
Expand Down
4 changes: 2 additions & 2 deletions src/storage/cpu_device_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ inline void* CPUDeviceStorage::Alloc(size_t size) {
void* ptr;
#if _MSC_VER
ptr = _aligned_malloc(size, alignment_);
if (ptr == NULL) throw std::bad_alloc();
if (ptr == NULL) LOG(FATAL) << "Failed to allocate CPU Memory";
#else
int ret = posix_memalign(&ptr, alignment_, size);
if (ret != 0) throw std::bad_alloc();
if (ret != 0) LOG(FATAL) << "Failed to allocate CPU Memory";
#endif
return ptr;
}
Expand Down
2 changes: 1 addition & 1 deletion src/storage/gpu_device_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ inline void* GPUDeviceStorage::Alloc(size_t size) {
#endif // MXNET_USE_NCCL
cudaError_t e = cudaMalloc(&ret, size);
if (e != cudaSuccess && e != cudaErrorCudartUnloading)
throw std::bad_alloc();
LOG(FATAL) << "CUDA: " << cudaGetErrorString(e);
#else // MXNET_USE_CUDA
LOG(FATAL) << "Please compile with CUDA enabled";
#endif // MXNET_USE_CUDA
Expand Down
1 change: 1 addition & 0 deletions tests/python/gpu/test_operator_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from test_random import *
from test_gluon import *
from test_loss import *
from test_exc_handling import *
#from test_rnn import *
from test_gluon_rnn import *
from test_sparse_ndarray import test_create_csr, test_create_row_sparse, test_sparse_nd_slice
Expand Down
Loading

0 comments on commit 6da2eec

Please sign in to comment.