Better Exception Handling for Operators (#9681)
* Add support for threaded engine

* Add support for threaded engine

* Remove on_start_callback for else

* Add support for global_ex_ptr

* Rethrow in waitall only once

* Run tests for GPU

* Add comments for exception_ptr

* Fix lint

* Push exc_handling tests

* Add comments for OnStart

* Fixes for exc handling

* Catch std::exception for all other exceptions

* Roll back std::move use

* Fix style

* Fix OnStart

* Fix debug_info

* Throw exception only once in an execution graph

* Make test naming consistent

* Fix symbolic test

* Remove unused code
anirudh2290 authored and piiswrong committed Feb 13, 2018
1 parent f57073e commit 7b24137
Showing 8 changed files with 231 additions and 39 deletions.
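At a glance, this change captures exceptions thrown inside asynchronously executed operators as a std::exception_ptr attached to the engine's variables and rethrows them on the calling thread at the next explicit synchronization point. Below is a minimal Python sketch of the intended user-visible behavior; it assumes that passing a negative scale to mx.nd.random.normal fails a runtime check inside the operator during asynchronous execution rather than at call time.

    import mxnet as mx
    from mxnet.base import MXNetError

    a = mx.nd.random.normal(0, 1, shape=(2, 2))
    b = mx.nd.random.normal(0, -1, shape=(2, 2))  # assumption: negative scale fails inside the operator
    c = mx.nd.dot(a, b)                           # c inherits the failure through its dependency on b

    try:
        c.asnumpy()    # synchronization point: the stored exception is rethrown here
    except MXNetError as err:
        print("caught:", err)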
8 changes: 6 additions & 2 deletions include/mxnet/engine.h
@@ -141,13 +141,15 @@ class MXNET_API Engine {
* \param mutable_vars The variables that current operation will mutate.
* \param prop Property of the function.
* \param opr_name The operator name.
* \param wait Whether this is a WaitForVar operation
* \return The new operator allocated.
*/
virtual OprHandle NewOperator(AsyncFn fn,
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
const char* opr_name = nullptr) = 0;
const char* opr_name = nullptr,
bool wait = false) = 0;
/*!
* \brief Delete the given operator.
* \param op The operator to delete.
@@ -176,13 +178,15 @@ class MXNET_API Engine {
* \param prop Property of the function.
* \param priority Priority of the action, as hint to the engine.
* \param opr_name The operator name.
* \param wait Whether this is a WaitForVar operation
*/
virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx,
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
int priority = 0,
const char* opr_name = nullptr) = 0;
const char* opr_name = nullptr,
bool wait = false) = 0;
/*!
* \brief Schedule the deletion of a variable.
*
6 changes: 4 additions & 2 deletions src/engine/naive_engine.cc
@@ -73,7 +73,8 @@ class NaiveEngine final : public Engine {
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
const char* opr_name = nullptr) override {
const char* opr_name = nullptr,
bool wait = false) override {
NaiveOpr *opr = new NaiveOpr();
opr->fn = fn;
opr->const_vars = const_vars;
@@ -125,7 +126,8 @@ class NaiveEngine final : public Engine {
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
int priority = 0,
const char* opr_name = nullptr) override {
const char* opr_name = nullptr,
bool wait = false) override {
CallbackOnComplete callback = CreateCallback(
NaiveEngine::OnComplete, nullptr);
this->req_completed_ = false;
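The NaiveEngine changes above are only signature updates to keep the Engine interface consistent; the naive engine runs every operation synchronously, which is why the error message later in this diff suggests it for debugging. A sketch of that workflow, under the same assumption about a negative scale failing inside the operator; note MXNET_ENGINE_TYPE must be set before mxnet is imported.

    import os
    os.environ["MXNET_ENGINE_TYPE"] = "NaiveEngine"  # force synchronous execution for debugging

    import mxnet as mx
    from mxnet.base import MXNetError

    try:
        # With the naive engine a failing operator raises at the call site itself,
        # so an ordinary backtrace points at the offending operation.
        a = mx.nd.random.normal(0, -1, shape=(2, 2))  # assumption: fails its runtime check
    except MXNetError as err:
        print("raised at the call site:", err)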
39 changes: 29 additions & 10 deletions src/engine/threaded_engine.cc
@@ -206,13 +206,15 @@ ThreadedOpr* ThreadedEngine::NewOperator(
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop,
const char* opr_name) {
const char* opr_name,
bool wait) {
auto ret = ThreadedOpr::New();
ret->opr_name = opr_name;
ret->fn = std::move(fn);
ret->prop = prop;
ret->const_vars.resize(const_vars.size());
ret->mutable_vars.resize(mutable_vars.size());
ret->wait = wait;
std::transform(const_vars.begin(), const_vars.end(),
ret->const_vars.begin(), ThreadedVar::CastFromBase);
std::transform(mutable_vars.begin(), mutable_vars.end(),
@@ -305,9 +307,10 @@ void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop,
int priority,
const char* opr_name) {
const char* opr_name,
bool wait) {
BulkFlush();
ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name);
ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name, wait);
opr->temporary = true;
#if MXNET_USE_PROFILER
Profiler *profiler = Profiler::Get();
@@ -356,7 +359,10 @@ void ThreadedEngine::DeleteVariable(SyncFn delete_fn,
void ThreadedEngine::WaitForVar(VarHandle var) {
BulkFlush();
ThreadedVar* threaded_var = ThreadedVar::CastFromBase(var);
if (threaded_var->ready_to_read()) return;
if (threaded_var->ready_to_read()) {
ThrowException(threaded_var);
return;
}
if (engine_info_) {
LOG(INFO) << "Wait for " << threaded_var;
debug_wait_var_ = threaded_var;
@@ -376,13 +382,15 @@ void ThreadedEngine::WaitForVar(VarHandle var) {
}
on_complete();
}, Context::CPU(), {var}, {}, FnProperty::kNormal, 0,
PROFILER_MESSAGE("WaitForVar"));
PROFILER_MESSAGE("WaitForVar"), true);
{
std::unique_lock<std::mutex> lock{finished_m_};
finished_cv_.wait(lock, [this, &done]() {
return done.load() || kill_.load();
});
}

ThrowException(threaded_var);
}

void ThreadedEngine::WaitForAll() {
@@ -397,18 +405,20 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
bool is_temporary_opr = threaded_opr->temporary;
// Mark complete for read variables
for (auto&& i : threaded_opr->const_vars) {
i->CompleteReadDependency([this](OprBlock* opr) {
this->PushToExecute(opr, false);
});
i->CompleteReadDependency(
[this](OprBlock* opr) { this->PushToExecute(opr, false); });
}
// Mark complete for write variables.
for (auto&& i : threaded_opr->mutable_vars) {
if (threaded_opr->opr_exception && *threaded_opr->opr_exception) {
i->var_exception = threaded_opr->opr_exception;
}
const bool debug_info = (engine_info_ && debug_wait_var_ == i);
if (debug_info) {
LOG(INFO) << "Complete write dep for " << i;
}
const bool to_delete = i->CompleteWriteDependency(
[this, debug_info](OprBlock* opr) {
const bool to_delete =
i->CompleteWriteDependency([this, debug_info](OprBlock* opr) {
if (debug_info) {
LOG(INFO) << "PushToExecute " << opr;
debug_push_opr_ = opr;
@@ -443,6 +453,15 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) {
}
}

inline void ThreadedEngine::ThrowException(ThreadedVar* threaded_var) {
if (threaded_var->var_exception && *threaded_var->var_exception) {
std::exception_ptr tmp = *threaded_var->var_exception;
*threaded_var->var_exception = nullptr;
std::rethrow_exception(tmp);
}
return;
}

void ThreadedEngine::OnCompleteStatic(
Engine *engine, void *opr_block_) {
OprBlock *opr_block = static_cast<OprBlock*>(opr_block_);
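Since NDArray.wait_to_read() is backed by Engine::WaitForVar, the ThrowException call added to WaitForVar above means a recorded failure is rethrown as soon as the variable is waited on, not only when the data is copied out with asnumpy(). A hypothetical Python sketch, under the same assumption about the failing operator:

    import mxnet as mx
    from mxnet.base import MXNetError

    bad = mx.nd.random.normal(0, -1, shape=(2, 2))  # assumption: fails during async execution

    try:
        bad.wait_to_read()   # maps to Engine::WaitForVar; rethrows the recorded exception_ptr
    except MXNetError as err:
        print("caught at wait_to_read:", err)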
94 changes: 72 additions & 22 deletions src/engine/threaded_engine.h
@@ -175,6 +175,8 @@ class ThreadedVar final
static std::atomic<std::size_t> counter;
~ThreadedVar() { LOG(INFO) << __func__ << " " << --counter; }
#endif // ENGINE_DEBUG
/*! \brief exception_ptr associated with the ThreadedVar */
std::shared_ptr<std::exception_ptr> var_exception;

private:
// TODO(hotpxl) change this to spinlock for faster runtime
@@ -236,6 +238,10 @@ struct ThreadedOpr final : public Opr,
* that can be deleted right after the operation completed.
*/
bool temporary{false};
/*!
* \brief Whether this is a WaitForVar operation
*/
bool wait{false};
/*!
* \brief Cast a Opr pointer to ThreadedOpr pointer
* \param ptr pointer from base.
@@ -246,6 +252,8 @@ struct ThreadedOpr final : public Opr,
}
// define possible debug information
DEFINE_ENGINE_DEBUG_INFO(ThreadedOpr);
/*! \brief exception_ptr associated with the ThreadedOpr */
std::shared_ptr<std::exception_ptr> opr_exception;
}; // struct ThreadedOpr

/*!
@@ -265,15 +273,17 @@ class ThreadedEngine : public Engine {
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
const char* opr_name = nullptr) override;
const char* opr_name = nullptr,
bool wait = false) override;
void DeleteOperator(OprHandle op) override;
void Push(OprHandle op, Context exec_ctx, int priority = 0, bool profiling = false) override;
void PushAsync(AsyncFn exec_fun, Context exec_ctx,
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
FnProperty prop = FnProperty::kNormal,
int priority = 0,
const char* opr_name = nullptr) override;
const char* opr_name = nullptr,
bool wait = false) override;
void PushSync(SyncFn exec_fn, Context exec_ctx,
std::vector<VarHandle> const& const_vars,
std::vector<VarHandle> const& mutable_vars,
@@ -321,50 +331,63 @@ class ThreadedEngine : public Engine {
* \param run_ctx runtime context used to execute the function.
* \param opr_block the opr_block to be executed and deleted.
*/
void ExecuteOprBlock(RunContext run_ctx, OprBlock *opr_block) {
void ExecuteOprBlock(RunContext run_ctx, OprBlock* opr_block) {
ThreadedOpr* threaded_opr = opr_block->opr;
#if MXNET_USE_PROFILER
if (opr_block->profiling && threaded_opr->opr_name) {
const Context& ctx = opr_block->ctx;
opr_block->opr_stat = Profiler::Get()->AddOprStat(ctx.dev_type, ctx.dev_id);
opr_block->opr_stat =
Profiler::Get()->AddOprStat(ctx.dev_type, ctx.dev_id);
uint64_t id = std::hash<std::thread::id>()(std::this_thread::get_id());
opr_block->opr_stat->thread_id = id;
strncpy(opr_block->opr_stat->opr_name,
threaded_opr->opr_name,
sizeof(opr_block->opr_stat->opr_name) - 1);
strncpy(opr_block->opr_stat->opr_name, threaded_opr->opr_name,
sizeof(opr_block->opr_stat->opr_name) - 1);
// record operator start timestamp
SetOprStart(opr_block->opr_stat);
}
#endif
CallbackOnComplete callback = this->CreateCallback(
ThreadedEngine::OnCompleteStatic, opr_block);
bool debug_info = (engine_info_ && debug_push_opr_ == opr_block);
CallbackOnComplete callback =
this->CreateCallback(ThreadedEngine::OnCompleteStatic, opr_block);
const bool debug_info = (engine_info_ && debug_push_opr_ == opr_block);
if (debug_info) {
LOG(INFO) << "ExecuteOprBlock " << opr_block
<< "shutdown_phase=" << shutdown_phase_;
}
if (!shutdown_phase_) {
try {
OnStart(threaded_opr);
if (debug_info) {
LOG(INFO) << "ExecuteOprFn ";
}
threaded_opr->fn(run_ctx, callback);
try {
if (!(threaded_opr->opr_exception && *threaded_opr->opr_exception) ||
threaded_opr->wait) {
threaded_opr->fn(run_ctx, callback);
} else {
callback();
}
} catch (dmlc::Error& e) {

[Inline comments on this change]

dtmoodie (Contributor) commented on Jun 22, 2018:
Hello @piiswrong, I believe this may have caused a bit of an issue with my training. A CHECK_EQ statement in my code throws an exception during the forward pass; it is caught here and then silently ignored. The backward pass then clearly fails, since the forward pass never completed. That CHECK_EQ should have been updated for the new code I was developing, but because it was silently handled, nothing was printed. Should the call to callback() handle any exceptions and print them? Is this supposed to allow Python handling of callbacks? I'm not sure of the purpose of this change, since a rethrow was not occurring.

dtmoodie (Contributor) commented on Jun 22, 2018:
Adding DMLC_LOG_BEFORE_THROW restores the previous behavior. Is it possible that this was set before and removed at some point? I'll start looking for any such change.

anirudh2290 (Author, Member) commented on Jun 22, 2018:
@dtmoodie, do you happen to use waitall in your Python code? waitall is only meant for benchmarking purposes and does not rethrow: https://mxnet.incubator.apache.org/architecture/exception_handling.html

anirudh2290 (Author, Member) commented on Jun 22, 2018:
Please provide a reproducible example.

[End of inline comments]

threaded_opr->opr_exception =
std::make_shared<std::exception_ptr>(std::current_exception());
callback();
}
if (debug_info) {
LOG(INFO) << "Fin ExecuteOprFn ";
}
} catch(dmlc::Error &e) {
} catch (std::exception& e) {
std::string what = e.what();
if (what.find("driver shutting down") == std::string::npos &&
!shutdown_phase_) {
LOG(FATAL) << e.what() << "\n" <<
"A fatal error occurred in asynchronous engine operation. "
"If you do not know what caused this error, "
"you can try set environment variable MXNET_ENGINE_TYPE "
"to NaiveEngine and run with debugger (i.e. gdb). "
"This will force all operations to be synchronous and "
"backtrace will give you the series of calls that lead "
"to this error. Remember to set MXNET_ENGINE_TYPE back to "
"empty after debugging.";
LOG(FATAL)
<< e.what() << "\n"
<< "A fatal error occurred in asynchronous engine operation. "
"If you do not know what caused this error, "
"you can try set environment variable MXNET_ENGINE_TYPE "
"to NaiveEngine and run with debugger (i.e. gdb). "
"This will force all operations to be synchronous and "
"backtrace will give you the series of calls that lead "
"to this error. Remember to set MXNET_ENGINE_TYPE back to "
"empty after debugging.";
}
}
} else {
@@ -414,7 +437,34 @@ class ThreadedEngine : public Engine {
* On operation completion, this will trigger subsequent operations.
*/
inline void OnComplete(ThreadedOpr* threaded_opr);
// callback to the threaded engine
/*!
* \brief rethrow caught exception in WaitForVar
* \param threaded_var the var that we are waiting to read
*/
inline void ThrowException(ThreadedVar* threaded_var);
/*!
* \brief Mark exceptions before operation execution.
*
* Will mark the operator as a failure and associate exception_ptr
* if any of the read dependencies have exception associated.
*/
inline void OnStart(ThreadedOpr* threaded_opr) {
for (auto&& i : threaded_opr->const_vars) {
if (i->var_exception && *i->var_exception) {
threaded_opr->opr_exception = i->var_exception;
break;
}
}
if (!(threaded_opr->opr_exception && *threaded_opr->opr_exception)) {
for (auto&& i : threaded_opr->mutable_vars) {
if (i->var_exception && *i->var_exception) {
threaded_opr->opr_exception = i->var_exception;
break;
}
}
}
}

static void OnCompleteStatic(Engine *engine, void *threaded_opr);
/*! \brief append an operator to bulk */
inline void BulkAppend(SyncFn exec_fn, Context exec_ctx,
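The OnStart hook above is what propagates a failure through the dependency graph: an operator whose input or output variables already carry an exception_ptr is skipped (its callback fires immediately in ExecuteOprBlock) and its outputs inherit the exception, so the error eventually surfaces on whichever dependent array the user reads. Whether mx.nd.waitall() rethrows in the same way is the subject of the comment thread earlier in this diff; reading a specific array is the reliable way to observe the failure. A hypothetical sketch, with the usual assumption about the failing operator:

    import mxnet as mx
    from mxnet.base import MXNetError

    a = mx.nd.random.normal(0, -1, shape=(2, 2))  # assumption: fails during async execution
    b = a + 1                                     # OnStart copies a's exception_ptr onto this operator
    c = b * 2                                     # and transitively onto this one

    try:
        c.asnumpy()    # the inherited exception surfaces on the dependent output
    except MXNetError as err:
        print("caught on a dependent array:", err)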
4 changes: 2 additions & 2 deletions src/storage/cpu_device_storage.h
@@ -61,10 +61,10 @@ inline void* CPUDeviceStorage::Alloc(size_t size) {
void* ptr;
#if _MSC_VER
ptr = _aligned_malloc(size, alignment_);
if (ptr == NULL) throw std::bad_alloc();
if (ptr == NULL) LOG(FATAL) << "Failed to allocate CPU Memory";
#else
int ret = posix_memalign(&ptr, alignment_, size);
if (ret != 0) throw std::bad_alloc();
if (ret != 0) LOG(FATAL) << "Failed to allocate CPU Memory";
#endif
return ptr;
}
2 changes: 1 addition & 1 deletion src/storage/gpu_device_storage.h
@@ -62,7 +62,7 @@ inline void* GPUDeviceStorage::Alloc(size_t size) {
#endif // MXNET_USE_NCCL
cudaError_t e = cudaMalloc(&ret, size);
if (e != cudaSuccess && e != cudaErrorCudartUnloading)
throw std::bad_alloc();
LOG(FATAL) << "CUDA: " << cudaGetErrorString(e);
#else // MXNET_USE_CUDA
LOG(FATAL) << "Please compile with CUDA enabled";
#endif // MXNET_USE_CUDA
1 change: 1 addition & 0 deletions tests/python/gpu/test_operator_gpu.py
@@ -34,6 +34,7 @@
from test_random import *
from test_gluon import *
from test_loss import *
from test_exc_handling import *
#from test_rnn import *
from test_gluon_rnn import *
from test_sparse_ndarray import test_create_csr, test_create_row_sparse, test_sparse_nd_slice
(One remaining changed file did not load in this view.)
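test_exc_handling is the new test module introduced by this commit; its diff did not load in this view. A hypothetical sketch of the kind of case such a module covers (the test and helper names here are illustrative, and the negative scale is again assumed to fail inside the operator rather than at call time):

    import mxnet as mx
    from mxnet.base import MXNetError
    from nose.tools import assert_raises

    def test_exc_imperative():                   # hypothetical test name
        def run(sync):
            a = mx.nd.random.normal(0, 1, shape=(2, 2))
            b = mx.nd.random.normal(0, -1, shape=(2, 2))  # assumption: fails asynchronously
            c = mx.nd.dot(a, b)
            if sync:
                c.asnumpy()
        run(False)                               # no synchronization: nothing is raised here
        assert_raises(MXNetError, run, True)     # the exception surfaces only at asnumpy()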