Skip to content

Commit

Permalink
RunAsync C/CXX API (#16613)
Browse files Browse the repository at this point in the history
Implement RunAsync API - the session will run in a thread of intra-op
thread pool.

---------

Co-authored-by: Randy Shuai <[email protected]>
  • Loading branch information
RandySheriffH and RandyShuai authored Jul 16, 2023
1 parent 2cf31a2 commit e1ca8ee
Show file tree
Hide file tree
Showing 8 changed files with 308 additions and 66 deletions.
30 changes: 30 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,15 @@ typedef void (*OrtCustomJoinThreadFn)(OrtCustomThreadHandle ort_custom_thread_ha

typedef OrtStatus*(ORT_API_CALL* RegisterCustomOpsFn)(OrtSessionOptions* options, const OrtApiBase* api);

/** \brief Callback function for RunAsync
*
* \param[in] user_data User specific data that passed back to the callback
* \param[out] outputs On succeed, outputs host inference results, on error, the value will be nullptr
* \param[out] num_outputs Number of outputs, on error, the value will be zero
* \param[out] status On error, status will provide details
*/
typedef void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status);

/** \brief The C API
*
* All C API functions are defined inside this structure as pointers to functions.
Expand Down Expand Up @@ -4316,6 +4325,27 @@ struct OrtApi {
*/
ORT_API2_STATUS(CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg,
_In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys);

/** \brief Run the model asynchronously in a thread owned by intra op thread pool
*
* \param[in] session
* \param[in] run_options If nullptr, will use a default ::OrtRunOptions
* \param[in] input_names Array of null terminated UTF8 encoded strings of the input names
* \param[in] input Array of ::OrtValue%s of the input values
* \param[in] input_len Number of elements in the input_names and inputs arrays
* \param[in] output_names Array of null terminated UTF8 encoded strings of the output names
* \param[in] output_names_len Number of elements in the output_names and outputs array
* \param[out] output Array of OrtValue* owned by customers, size to output_names_len. It could simply be an array of nullptr
* The array will be passed back to run_async_callback
* \param[in] run_async_callback Callback function on model run completion
* \param[in] user_data User data that pass back to run_async_callback
*/
ORT_API2_STATUS(RunAsync, _Inout_ OrtSession* session, _In_opt_ const OrtRunOptions* run_options,
_In_reads_(input_len) const char* const* input_names,
_In_reads_(input_len) const OrtValue* const* input, size_t input_len,
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
_Inout_updates_all_(output_names_len) OrtValue** output,
_In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data);
};

/*
Expand Down
18 changes: 18 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,24 @@ struct SessionImpl : ConstSessionImpl<T> {

void Run(const RunOptions& run_options, const IoBinding&); ///< Wraps OrtApi::RunWithBinding

/** \brief Run the model asynchronously in a thread owned by intra op thread pool
*
* Wraps OrtApi::RunAsync
*
* \param[in] run_options
* \param[in] input_names Array of null terminated UTF8 encoded strings of the input names
* \param[in] input_values Array of ::OrtValue%s of the input values
* \param[in] input_count Number of elements in the input_names and inputs arrays
* \param[in] output_names Array of null terminated UTF8 encoded strings of the output names
* \param[out] output_values Array of ::OrtValue%s owned by customers, size to output_count. It could simply be an array of nullptr
* The array will be passed back to the callback
* \param[in] output_count Number of elements in the output_names and outputs array
* \param[in] callback Callback function on model run completion
* \param[in] user_data User data that pass back to the callback
*/
void RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data);

/** \brief End profiling and return a copy of the profiling file name.
*
* \param allocator to allocate memory for the copy of the string returned
Expand Down
10 changes: 10 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,16 @@ inline void SessionImpl<T>::Run(const RunOptions& run_options, const IoBinding&
ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding));
}

template <typename T>
inline void SessionImpl<T>::RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data) {
auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
ThrowOnError(GetApi().RunAsync(this->p_, run_options, input_names,
ort_input_values, input_count, output_names, output_count,
ort_output_values, callback, user_data));
}

template <typename T>
inline AllocatedStringPtr SessionImpl<T>::EndProfilingAllocated(OrtAllocator* allocator) {
char* out = nullptr;
Expand Down
110 changes: 110 additions & 0 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2300,6 +2300,116 @@ Status InferenceSession::Run(const RunOptions& run_options,
return retval;
}

Status InferenceSession::Run(const RunOptions& run_options,
gsl::span<const char* const> feed_names,
gsl::span<const OrtValue* const> feeds,
gsl::span<const char* const> fetch_names,
gsl::span<OrtValue*> fetches) {
size_t num_feeds = feed_names.size();
size_t num_fetches = fetch_names.size();
InlinedVector<std::string> feed_name_vec;
feed_name_vec.reserve(num_feeds);
InlinedVector<OrtValue> feed_vec;
feed_vec.reserve(num_feeds);

for (size_t i = 0; i != num_feeds; ++i) {
if (feed_names[i] == nullptr || feed_names[i][0] == '\0') {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input name cannot be empty");
}

if (!feeds[i]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, MakeString("NULL input supplied for input ", feed_names[i]).c_str());
}

feed_name_vec.emplace_back(feed_names[i]);
feed_vec.emplace_back(*feeds[i]);
}

// Create output feed
InlinedVector<std::string> fetch_name_vec;
fetch_name_vec.reserve(num_fetches);
for (size_t i = 0; i != num_fetches; ++i) {
if (fetch_names[i] == nullptr || fetch_names[i][0] == '\0') {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output name cannot be empty");
}
fetch_name_vec.emplace_back(fetch_names[i]);
}

std::vector<OrtValue> fetch_vec;
fetch_vec.reserve(num_fetches);
for (size_t i = 0; i != num_fetches; ++i) {
if (fetches[i] != nullptr) {
fetch_vec.emplace_back(*fetches[i]);
} else {
fetch_vec.emplace_back();
}
}

Status status;
status = Run(run_options, feed_name_vec, feed_vec, fetch_name_vec, &fetch_vec, nullptr);

if (!status.IsOK())
return status;

// We do it in two loops to make sure copy __ctors does not throw
InlinedVector<std::unique_ptr<OrtValue>> fetch_unique_ptrs;
fetch_unique_ptrs.reserve(num_fetches);
for (size_t i = 0; i != num_fetches; ++i) {
if (fetches[i] == nullptr) {
fetch_unique_ptrs.emplace_back(std::make_unique<OrtValue>(fetch_vec[i]));
} else {
fetch_unique_ptrs.emplace_back();
}
}

for (size_t i = 0; i != num_fetches; ++i) {
if (fetches[i] == nullptr) {
ORT_ENFORCE(fetch_unique_ptrs[i] != nullptr);
fetches[i] = fetch_unique_ptrs[i].release();
}
}
return Status::OK();
}

common::Status InferenceSession::RunAsync(const RunOptions* run_options,
gsl::span<const char* const> feed_names,
gsl::span<const OrtValue* const> feeds,
gsl::span<const char* const> fetch_names,
gsl::span<OrtValue*> fetches,
RunAsyncCallbackFn callback,
void* user_data) {
size_t num_fetches = fetch_names.size();
if (!thread_pool_.get() || concurrency::ThreadPool::DegreeOfParallelism(thread_pool_.get()) < 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "intra op thread pool must have at least one thread for RunAsync");
}
std::function<void()> run_fn = [=]() {
ORT_TRY {
Status status;
if (run_options) {
status = Run(*run_options, feed_names, feeds, fetch_names, fetches);
} else {
RunOptions default_run_options;
status = Run(default_run_options, feed_names, feeds, fetch_names, fetches);
}
if (status.IsOK()) {
callback(user_data, fetches.data(), num_fetches, ToOrtStatus(status));
} else {
callback(user_data, {}, 0, ToOrtStatus(status));
}
}
ORT_CATCH(const std::exception& ex) {
ORT_HANDLE_EXCEPTION([=]() {
callback(user_data, {}, 0, ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, ex.what())));
});
}
ORT_CATCH(...) {
callback(user_data, {}, 0, ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "unknown exception")));
}
}; // run_fn
concurrency::ThreadPool::Schedule(thread_pool_.get(), run_fn);
return Status::OK();
}

common::Status InferenceSession::Run(const NameMLValMap& feeds, gsl::span<const std::string> output_names,
std::vector<OrtValue>* p_fetches) {
return Run(RunOptions(), feeds, output_names, p_fetches);
Expand Down
14 changes: 14 additions & 0 deletions onnxruntime/core/session/inference_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,20 @@ class InferenceSession {
std::vector<OrtValue>* p_fetches,
const std::vector<OrtDevice>* p_fetches_device_info = nullptr);

[[nodiscard]] common::Status Run(const RunOptions& run_options,
gsl::span<const char* const> feed_names,
gsl::span<const OrtValue* const> feeds,
gsl::span<const char* const> fetch_names,
gsl::span<OrtValue*> fetches);

[[nodiscard]] common::Status RunAsync(const RunOptions* run_options,
gsl::span<const char* const> feed_names,
gsl::span<const OrtValue* const> feeds,
gsl::span<const char* const> fetch_names,
gsl::span<OrtValue*> fetches,
RunAsyncCallbackFn callback,
void* user_data = nullptr);

/**
* Run a pre-loaded and pre-intialized model.
* Multiple threads are allowed to run this function; hence its thread-safe.
Expand Down
104 changes: 40 additions & 64 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -817,81 +817,56 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArray, _In_ const OrtEnv* env, _In
ORT_API_STATUS_IMPL(OrtApis::Run, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
_In_reads_(input_len) const char* const* input_names,
_In_reads_(input_len) const OrtValue* const* input, size_t input_len,
_In_reads_(output_names_len) const char* const* output_names1, size_t output_names_len,
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
_Inout_updates_all_(output_names_len) OrtValue** output) {
API_IMPL_BEGIN
auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);

InlinedVector<std::string> feed_names;
feed_names.reserve(input_len);
InlinedVector<OrtValue> feeds;
feeds.reserve(input_len);

for (size_t i = 0; i != input_len; ++i) {
if (input_names[i] == nullptr || input_names[i][0] == '\0') {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "input name cannot be empty");
}

if (!input[i]) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
MakeString("NULL input supplied for input ", input_names[i]).c_str());
}

feed_names.emplace_back(input_names[i]);
feeds.emplace_back(*input[i]);
}

// Create output feed
InlinedVector<std::string> output_names;
output_names.reserve(output_names_len);
for (size_t i = 0; i != output_names_len; ++i) {
if (output_names1[i] == nullptr || output_names1[i][0] == '\0') {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "output name cannot be empty");
}
output_names.emplace_back(output_names1[i]);
}

std::vector<OrtValue> fetches;
fetches.reserve(output_names_len);
for (size_t i = 0; i != output_names_len; ++i) {
if (output[i] != nullptr) {
fetches.emplace_back(*output[i]);
} else {
fetches.emplace_back();
}
}
gsl::span<const char* const> input_names_span(input_names, input_len);
gsl::span<const OrtValue* const> input_span(input, input_len);
gsl::span<const char* const> output_name_span(output_names, output_names_len);
gsl::span<OrtValue*> output_span(output, output_names_len);

Status status;
if (run_options == nullptr) {
OrtRunOptions op;
status = session->Run(op, feed_names, feeds, output_names, &fetches, nullptr);
if (run_options) {
status = session->Run(*run_options,
input_names_span,
input_span,
output_name_span,
output_span);
} else {
status = session->Run(*run_options, feed_names, feeds, output_names, &fetches, nullptr);
const RunOptions default_run_options;
status = session->Run(default_run_options,
input_names_span,
input_span,
output_name_span,
output_span);
}
return ToOrtStatus(status);
API_IMPL_END
}

if (!status.IsOK())
return ToOrtStatus(status);

// We do it in two loops to make sure copy __ctors does not throw
InlinedVector<std::unique_ptr<OrtValue>> output_unique_ptrs;
output_unique_ptrs.reserve(output_names_len);
for (size_t i = 0; i != output_names_len; ++i) {
if (output[i] == nullptr) {
output_unique_ptrs.emplace_back(std::make_unique<OrtValue>(fetches[i]));
} else {
output_unique_ptrs.emplace_back();
}
}
ORT_API_STATUS_IMPL(OrtApis::RunAsync, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
_In_reads_(input_len) const char* const* input_names,
_In_reads_(input_len) const OrtValue* const* input, size_t input_len,
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
_Inout_updates_all_(output_names_len) OrtValue** output,
_In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data) {
API_IMPL_BEGIN
auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);

assert(output_unique_ptrs.size() == output_names_len);
gsl::span<const char* const> input_names_span(input_names, input_len);
gsl::span<const OrtValue* const> input_span(input, input_len);
gsl::span<const char* const> output_name_span(output_names, output_names_len);
gsl::span<OrtValue*> output_span(output, output_names_len);

for (size_t i = 0; i != output_names_len; ++i) {
if (output[i] == nullptr) {
assert(output_unique_ptrs[i] != nullptr);
output[i] = output_unique_ptrs[i].release();
}
}
return nullptr;
return ToOrtStatus(session->RunAsync(run_options,
input_names_span,
input_span,
output_name_span,
output_span,
run_async_callback,
user_data));
API_IMPL_END
}

Expand Down Expand Up @@ -2735,6 +2710,7 @@ static constexpr OrtApi ort_api_1_to_16 = {
&OrtApis::GetROCMProviderOptionsAsString,
&OrtApis::ReleaseROCMProviderOptions,
&OrtApis::CreateAndRegisterAllocatorV2,
&OrtApis::RunAsync,
};

// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -478,4 +478,11 @@ ORT_API(void, ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProviderOptions

ORT_API_STATUS_IMPL(CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg,
_In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys);

ORT_API_STATUS_IMPL(RunAsync, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
_In_reads_(input_len) const char* const* input_names,
_In_reads_(input_len) const OrtValue* const* input, size_t input_len,
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
_Inout_updates_all_(output_names_len) OrtValue** outputs,
_In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data);
} // namespace OrtApis
Loading

0 comments on commit e1ca8ee

Please sign in to comment.