This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Better Exception Handling for Operators (#9681)
* Add support for threaded engine * Add support for threaded engine * Remove on_start_callback for else * Add support for global_ex_ptr * Rethrow in waitall only once * run tests for gpu * Add comments for exception_ptr * Fix lint * Push exc_handling tests * Add comments for OnStart * Fixes for exc handling * Catch std::exception for all other exceptions * Rollback std::move use * Fix style * Fix onstart * Fix debug_info * Throw exception only once in an execution graph * make test naming consistent * Fix symbolic test * Remove unused code
- Loading branch information
1 parent
f57073e
commit 7b24137
Showing
8 changed files
with
231 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -175,6 +175,8 @@ class ThreadedVar final | |
static std::atomic<std::size_t> counter; | ||
~ThreadedVar() { LOG(INFO) << __func__ << " " << --counter; } | ||
#endif // ENGINE_DEBUG | ||
/*! \brief exception_ptr associated with the ThreadedVar */ | ||
std::shared_ptr<std::exception_ptr> var_exception; | ||
|
||
private: | ||
// TODO(hotpxl) change this to spinlock for faster runtime | ||
|
@@ -236,6 +238,10 @@ struct ThreadedOpr final : public Opr, | |
* that can be deleted right after the operation completed. | ||
*/ | ||
bool temporary{false}; | ||
/*! | ||
* \brief Whether this is a WaitForVar operation | ||
*/ | ||
bool wait{false}; | ||
/*! | ||
* \brief Cast a Opr pointer to ThreadedOpr pointer | ||
* \param ptr pointer from base. | ||
|
@@ -246,6 +252,8 @@ struct ThreadedOpr final : public Opr, | |
} | ||
// define possible debug information | ||
DEFINE_ENGINE_DEBUG_INFO(ThreadedOpr); | ||
/*! \brief exception_ptr associated with the ThreadedOpr */ | ||
std::shared_ptr<std::exception_ptr> opr_exception; | ||
}; // struct ThreadedOpr | ||
|
||
/*! | ||
|
@@ -265,15 +273,17 @@ class ThreadedEngine : public Engine { | |
std::vector<VarHandle> const& const_vars, | ||
std::vector<VarHandle> const& mutable_vars, | ||
FnProperty prop = FnProperty::kNormal, | ||
const char* opr_name = nullptr) override; | ||
const char* opr_name = nullptr, | ||
bool wait = false) override; | ||
void DeleteOperator(OprHandle op) override; | ||
void Push(OprHandle op, Context exec_ctx, int priority = 0, bool profiling = false) override; | ||
void PushAsync(AsyncFn exec_fun, Context exec_ctx, | ||
std::vector<VarHandle> const& const_vars, | ||
std::vector<VarHandle> const& mutable_vars, | ||
FnProperty prop = FnProperty::kNormal, | ||
int priority = 0, | ||
const char* opr_name = nullptr) override; | ||
const char* opr_name = nullptr, | ||
bool wait = false) override; | ||
void PushSync(SyncFn exec_fn, Context exec_ctx, | ||
std::vector<VarHandle> const& const_vars, | ||
std::vector<VarHandle> const& mutable_vars, | ||
|
@@ -321,50 +331,63 @@ class ThreadedEngine : public Engine { | |
* \param run_ctx runtime context used to execute the function. | ||
* \param opr_block the opr_block to be executed and deleted. | ||
*/ | ||
void ExecuteOprBlock(RunContext run_ctx, OprBlock *opr_block) { | ||
void ExecuteOprBlock(RunContext run_ctx, OprBlock* opr_block) { | ||
ThreadedOpr* threaded_opr = opr_block->opr; | ||
#if MXNET_USE_PROFILER | ||
if (opr_block->profiling && threaded_opr->opr_name) { | ||
const Context& ctx = opr_block->ctx; | ||
opr_block->opr_stat = Profiler::Get()->AddOprStat(ctx.dev_type, ctx.dev_id); | ||
opr_block->opr_stat = | ||
Profiler::Get()->AddOprStat(ctx.dev_type, ctx.dev_id); | ||
uint64_t id = std::hash<std::thread::id>()(std::this_thread::get_id()); | ||
opr_block->opr_stat->thread_id = id; | ||
strncpy(opr_block->opr_stat->opr_name, | ||
threaded_opr->opr_name, | ||
sizeof(opr_block->opr_stat->opr_name) - 1); | ||
strncpy(opr_block->opr_stat->opr_name, threaded_opr->opr_name, | ||
sizeof(opr_block->opr_stat->opr_name) - 1); | ||
// record operator start timestamp | ||
SetOprStart(opr_block->opr_stat); | ||
} | ||
#endif | ||
CallbackOnComplete callback = this->CreateCallback( | ||
ThreadedEngine::OnCompleteStatic, opr_block); | ||
bool debug_info = (engine_info_ && debug_push_opr_ == opr_block); | ||
CallbackOnComplete callback = | ||
this->CreateCallback(ThreadedEngine::OnCompleteStatic, opr_block); | ||
const bool debug_info = (engine_info_ && debug_push_opr_ == opr_block); | ||
if (debug_info) { | ||
LOG(INFO) << "ExecuteOprBlock " << opr_block | ||
<< "shutdown_phase=" << shutdown_phase_; | ||
} | ||
if (!shutdown_phase_) { | ||
try { | ||
OnStart(threaded_opr); | ||
if (debug_info) { | ||
LOG(INFO) << "ExecuteOprFn "; | ||
} | ||
threaded_opr->fn(run_ctx, callback); | ||
try { | ||
if (!(threaded_opr->opr_exception && *threaded_opr->opr_exception) || | ||
threaded_opr->wait) { | ||
threaded_opr->fn(run_ctx, callback); | ||
} else { | ||
callback(); | ||
} | ||
} catch (dmlc::Error& e) { | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
dtmoodie
Contributor
|
||
threaded_opr->opr_exception = | ||
std::make_shared<std::exception_ptr>(std::current_exception()); | ||
callback(); | ||
} | ||
if (debug_info) { | ||
LOG(INFO) << "Fin ExecuteOprFn "; | ||
} | ||
} catch(dmlc::Error &e) { | ||
} catch (std::exception& e) { | ||
std::string what = e.what(); | ||
if (what.find("driver shutting down") == std::string::npos && | ||
!shutdown_phase_) { | ||
LOG(FATAL) << e.what() << "\n" << | ||
"A fatal error occurred in asynchronous engine operation. " | ||
"If you do not know what caused this error, " | ||
"you can try set environment variable MXNET_ENGINE_TYPE " | ||
"to NaiveEngine and run with debugger (i.e. gdb). " | ||
"This will force all operations to be synchronous and " | ||
"backtrace will give you the series of calls that lead " | ||
"to this error. Remember to set MXNET_ENGINE_TYPE back to " | ||
"empty after debugging."; | ||
LOG(FATAL) | ||
<< e.what() << "\n" | ||
<< "A fatal error occurred in asynchronous engine operation. " | ||
"If you do not know what caused this error, " | ||
"you can try set environment variable MXNET_ENGINE_TYPE " | ||
"to NaiveEngine and run with debugger (i.e. gdb). " | ||
"This will force all operations to be synchronous and " | ||
"backtrace will give you the series of calls that lead " | ||
"to this error. Remember to set MXNET_ENGINE_TYPE back to " | ||
"empty after debugging."; | ||
} | ||
} | ||
} else { | ||
|
@@ -414,7 +437,34 @@ class ThreadedEngine : public Engine { | |
* On operation completion, this will trigger subsequent operations. | ||
*/ | ||
inline void OnComplete(ThreadedOpr* threaded_opr); | ||
// callback to the threaded engine | ||
/*! | ||
* \brief rethrow caught exception in WaitForVar | ||
* \param threaded_var the var that we are waiting to read | ||
*/ | ||
inline void ThrowException(ThreadedVar* threaded_var); | ||
/*! | ||
* \brief Mark exceptions before operation execution. | ||
* | ||
* Will mark the operator as a failure and associate exception_ptr | ||
* if any of the read dependencies have exception associated. | ||
*/ | ||
inline void OnStart(ThreadedOpr* threaded_opr) { | ||
for (auto&& i : threaded_opr->const_vars) { | ||
if (i->var_exception && *i->var_exception) { | ||
threaded_opr->opr_exception = i->var_exception; | ||
break; | ||
} | ||
} | ||
if (!(threaded_opr->opr_exception && *threaded_opr->opr_exception)) { | ||
for (auto&& i : threaded_opr->mutable_vars) { | ||
if (i->var_exception && *i->var_exception) { | ||
threaded_opr->opr_exception = i->var_exception; | ||
break; | ||
} | ||
} | ||
} | ||
} | ||
|
||
static void OnCompleteStatic(Engine *engine, void *threaded_opr); | ||
/*! \brief append an operator to bulk */ | ||
inline void BulkAppend(SyncFn exec_fn, Context exec_ctx, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Hello @piiswrong
I believe this may have caused a bit of an issue with my training, a CHECK_EQ statement in my code is throwing an exception during the forward pass which is caught here, it is then silently ignored. The backwards pass then clearly fails since the full forward pass was not complete. This CHECK_EQ should have been updated accordingly for the new code I was developing, but since it was silently handled I did not see anything printed.
Should the call to callback() handle any exceptions and print them? Is this supposed to allow python handling of callbacks? I'm not sure the purpose of this change since a rethrow was not occurring.