Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RunAsync C/CXX API #16613

Merged
merged 30 commits into from
Jul 16, 2023
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,15 @@ typedef void (*OrtCustomJoinThreadFn)(OrtCustomThreadHandle ort_custom_thread_ha

/* Function pointer type: takes the session options and the API base, and returns an OrtStatus*.
 * NOTE(review): presumably the signature of a custom-op library's registration entry point - confirm against callers. */
typedef OrtStatus*(ORT_API_CALL* RegisterCustomOpsFn)(OrtSessionOptions* options, const OrtApiBase* api);

/** \brief Callback function for RunAsync
 *
 * Called when an asynchronous run started with OrtApi::RunAsync finishes. The call is made from the
 * thread-pool task that executed the run, not from the thread that called RunAsync.
 *
 * \param[in] user_data User-specific data that is passed back to the callback unchanged.
 * \param[in] outputs On success, the output array that was supplied to RunAsync, now holding the
 *                    inference results. On error, nullptr.
 * \param[in] num_outputs Number of entries in outputs. On error, zero.
 * \param[in] status On success, nullptr. On error, a status that provides the details.
 *                   NOTE(review): ownership of a non-null status is not visible here - confirm whether
 *                   the callback must release it.
 */
typedef void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status);
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved

/** \brief The C API
*
* All C API functions are defined inside this structure as pointers to functions.
Expand Down Expand Up @@ -4316,6 +4325,27 @@ struct OrtApi {
*/
ORT_API2_STATUS(CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg,
_In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys);

/** \brief Run the model asynchronously in a thread owned by the intra op thread pool
 *
 * The call only schedules the run and returns. Completion (success or failure of the run itself)
 * is reported through run_async_callback. Because the run happens after this function has returned,
 * run_options and the input_names, input, output_names and output arrays must stay valid until the
 * callback has been invoked.
 *
 * \param[in] session
 * \param[in] run_options If nullptr, will use a default ::OrtRunOptions
 * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names
 * \param[in] input Array of ::OrtValue%s of the input values
 * \param[in] input_len Number of elements in the input_names and input arrays
 * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names
 * \param[in] output_names_len Number of elements in the output_names and output arrays
 * \param[out] output Array of OrtValue* owned by the caller, of size output_names_len. It may simply be an
 *                    array of nullptr; entries that are nullptr are filled in with newly allocated values.
 *                    The array is passed back to run_async_callback.
 * \param[in] run_async_callback Callback function invoked on model run completion
 * \param[in] user_data User data that is passed back to run_async_callback
 *
 * \return An error status if the run could not be scheduled, e.g. when the session has no intra op thread
 *         pool or the pool does not provide enough threads. Errors raised by the run itself are delivered
 *         to run_async_callback instead.
 */
ORT_API2_STATUS(RunAsync, _Inout_ OrtSession* session, _In_opt_ const OrtRunOptions* run_options,
_In_reads_(input_len) const char* const* input_names,
_In_reads_(input_len) const OrtValue* const* input, size_t input_len,
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
_Inout_updates_all_(output_names_len) OrtValue** output,
_In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data);
};

/*
Expand Down
18 changes: 18 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,24 @@ struct SessionImpl : ConstSessionImpl<T> {

void Run(const RunOptions& run_options, const IoBinding&); ///< Wraps OrtApi::RunWithBinding

/** \brief Run the model asynchronously in a thread owned by the intra op thread pool
 *
 * Wraps OrtApi::RunAsync. The call only schedules the run; the result is delivered through the callback.
 * The name arrays and the input_values / output_values arrays are handed to the C API as-is, so they
 * must stay valid until the callback has been invoked.
 *
 * \param[in] run_options
 * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names
 * \param[in] input_values Array of Value objects holding the input values
 * \param[in] input_count Number of elements in the input_names and input_values arrays
 * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names
 * \param[out] output_values Array of Value objects owned by the caller, of size output_count. It may simply be
 *                           an array of empty (nullptr) values. The array is passed back to the callback.
 * \param[in] output_count Number of elements in the output_names and output_values arrays
 * \param[in] callback Callback function invoked on model run completion
 * \param[in] user_data User data that is passed back to the callback
 */
void RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data);

/** \brief End profiling and return a copy of the profiling file name.
*
* \param allocator to allocate memory for the copy of the string returned
Expand Down
10 changes: 10 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,16 @@ inline void SessionImpl<T>::Run(const RunOptions& run_options, const IoBinding&
ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding));
}

template <typename T>
inline void SessionImpl<T>::RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
                                     const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data) {
  // The C API works on arrays of raw OrtValue pointers, so the Value arrays are reinterpreted in place.
  // NOTE(review): this assumes an array of Value is layout-compatible with an array of OrtValue* - confirm.
  const OrtValue* const* raw_inputs = reinterpret_cast<const OrtValue* const*>(input_values);
  OrtValue** raw_outputs = reinterpret_cast<OrtValue**>(output_values);

  // Schedule the run; any failing status returned by the C API goes to ThrowOnError.
  OrtStatus* schedule_status = GetApi().RunAsync(this->p_, run_options,
                                                 input_names, raw_inputs, input_count,
                                                 output_names, output_count, raw_outputs,
                                                 callback, user_data);
  ThrowOnError(schedule_status);
}

template <typename T>
inline AllocatedStringPtr SessionImpl<T>::EndProfilingAllocated(OrtAllocator* allocator) {
char* out = nullptr;
Expand Down
111 changes: 111 additions & 0 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2299,6 +2299,117 @@ Status InferenceSession::Run(const RunOptions& run_options,
return retval;
}

// Runs the model from C-style arrays of names and OrtValue pointers.
//
// feed_names / feeds describe the inputs and must have the same length; every name must be a
// non-empty string and every feed must be non-null.
// fetch_names / fetches describe the outputs and must have the same length; every name must be a
// non-empty string. A non-null fetches[i] is handed to the run as an existing value; a null
// fetches[i] is replaced, on success only, by a newly allocated OrtValue whose ownership passes
// to the caller.
//
// Returns INVALID_ARGUMENT for malformed arguments, otherwise the status of the underlying run.
// On failure the fetches array is left untouched.
Status InferenceSession::Run(const RunOptions& run_options,
                             gsl::span<const char*> feed_names,
                             gsl::span<const OrtValue*> feeds,
                             gsl::span<const char*> fetch_names,
                             gsl::span<OrtValue*> fetches) {
  const size_t num_feeds = feed_names.size();
  const size_t num_fetches = fetch_names.size();

  // The value spans are indexed by the name counts below, so reject mismatched lengths up front
  // instead of reading out of range.
  if (feeds.size() != num_feeds) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "number of input names does not match number of inputs");
  }
  if (fetches.size() != num_fetches) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "number of output names does not match number of outputs");
  }

  // Validate and copy the inputs.
  InlinedVector<std::string> feed_name_vec;
  feed_name_vec.reserve(num_feeds);
  InlinedVector<OrtValue> feed_vec;
  feed_vec.reserve(num_feeds);
  for (size_t i = 0; i != num_feeds; ++i) {
    if (feed_names[i] == nullptr || feed_names[i][0] == '\0') {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input name cannot be empty");
    }

    if (!feeds[i]) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, MakeString("NULL input supplied for input ", feed_names[i]).c_str());
    }

    feed_name_vec.emplace_back(feed_names[i]);
    feed_vec.emplace_back(*feeds[i]);
  }

  // Validate the output names and seed the fetch vector: caller-supplied values are copied in,
  // missing ones start out as empty OrtValues.
  InlinedVector<std::string> fetch_name_vec;
  fetch_name_vec.reserve(num_fetches);
  std::vector<OrtValue> fetch_vec;
  fetch_vec.reserve(num_fetches);
  for (size_t i = 0; i != num_fetches; ++i) {
    if (fetch_names[i] == nullptr || fetch_names[i][0] == '\0') {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output name cannot be empty");
    }
    fetch_name_vec.emplace_back(fetch_names[i]);

    if (fetches[i] != nullptr) {
      fetch_vec.emplace_back(*fetches[i]);
    } else {
      fetch_vec.emplace_back();
    }
  }

  const Status status = Run(run_options, feed_name_vec, feed_vec, fetch_name_vec, &fetch_vec, nullptr);
  if (!status.IsOK()) {
    return status;
  }

  // Two passes on purpose: every allocation/copy that may throw happens in the first pass, so the
  // second pass (which transfers ownership into the caller's array) cannot leave it half-filled.
  InlinedVector<std::unique_ptr<OrtValue>> fetch_unique_ptrs;
  fetch_unique_ptrs.reserve(num_fetches);
  for (size_t i = 0; i != num_fetches; ++i) {
    if (fetches[i] == nullptr) {
      fetch_unique_ptrs.emplace_back(std::make_unique<OrtValue>(fetch_vec[i]));
    } else {
      fetch_unique_ptrs.emplace_back();
    }
  }

  for (size_t i = 0; i != num_fetches; ++i) {
    if (fetches[i] == nullptr) {
      ORT_ENFORCE(fetch_unique_ptrs[i] != nullptr);
      fetches[i] = fetch_unique_ptrs[i].release();
    }
  }
  return Status::OK();
}

// Schedules a run of the model on the session's thread pool and returns immediately.
//
// run_options may be null, in which case default RunOptions are used for the run.
// The spans and run_options are not copied: the scheduled task reads them when it executes, so
// the memory they refer to must stay valid until the callback has been invoked.
//
// The callback is invoked exactly once from the scheduled task:
//   - on success with (user_data, fetches.data(), number of fetches, nullptr);
//   - on failure (error status or exception during the run) with (user_data, nullptr, 0, status).
//
// Returns INVALID_ARGUMENT without scheduling anything when callback is null or when the session
// has no thread pool / the pool's degree of parallelism is below 2.
common::Status InferenceSession::RunAsync(const RunOptions* run_options,
                                          gsl::span<const char*> feed_names,
                                          gsl::span<const OrtValue*> feeds,
                                          gsl::span<const char*> fetch_names,
                                          gsl::span<OrtValue*> fetches,
                                          RunAsyncCallbackFn callback,
                                          void* user_data) {
  // Without a callback the result could never be reported; fail here rather than on a worker thread.
  if (callback == nullptr) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "callback cannot be null for RunAsync");
  }

  // NOTE(review): the "< 2" threshold presumably accounts for the calling thread being counted in
  // the degree of parallelism - confirm against ThreadPool::DegreeOfParallelism.
  if (!thread_pool_.get() || concurrency::ThreadPool::DegreeOfParallelism(thread_pool_.get()) < 2) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "intra op thread pool must have at least one thread for RunAsync");
  }

  const size_t num_fetches = fetch_names.size();

  std::function<void()> run_fn = [this, run_options, feed_names, feeds, fetch_names, fetches,
                                  num_fetches, callback, user_data]() {
    // Only the run itself is guarded. The callback is invoked after the try/catch so that an
    // exception escaping the callback cannot cause it to be called a second time.
    Status status;
    ORT_TRY {
      if (run_options) {
        status = Run(*run_options, feed_names, feeds, fetch_names, fetches);
      } else {
        RunOptions default_run_options;
        status = Run(default_run_options, feed_names, feeds, fetch_names, fetches);
      }
    }
    ORT_CATCH(const std::exception& e) {
      std::string what = "unknown exception";
      ORT_HANDLE_EXCEPTION([&]() { what = e.what(); });
      status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, what.c_str());
    }
    ORT_CATCH(...) {
      status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "unknown exception");
    }

    if (status.IsOK()) {
      callback(user_data, fetches.data(), num_fetches, nullptr);
    } else {
      callback(user_data, nullptr, 0, ToOrtStatus(status));
    }
  };

  concurrency::ThreadPool::Schedule(thread_pool_.get(), run_fn);
  return Status::OK();
}

common::Status InferenceSession::Run(const NameMLValMap& feeds, gsl::span<const std::string> output_names,
std::vector<OrtValue>* p_fetches) {
return Run(RunOptions(), feeds, output_names, p_fetches);
Expand Down
14 changes: 14 additions & 0 deletions onnxruntime/core/session/inference_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,20 @@ class InferenceSession {
std::vector<OrtValue>* p_fetches,
const std::vector<OrtDevice>* p_fetches_device_info = nullptr);

/**
 * Run the model using C-style arrays of names and OrtValue pointers.
 * @param run_options options for this run.
 * @param feed_names input names; null or empty names are rejected with INVALID_ARGUMENT.
 * @param feeds input values, one per entry of feed_names; null entries are rejected with INVALID_ARGUMENT.
 * @param fetch_names output names; null or empty names are rejected with INVALID_ARGUMENT.
 * @param fetches output values, one per entry of fetch_names. Entries that are null are replaced,
 *        on success, by newly allocated OrtValues that the caller then owns.
 * @return OK on success, otherwise the error.
 */
[[nodiscard]] common::Status Run(const RunOptions& run_options,
gsl::span<const char*> feed_names,
gsl::span<const OrtValue*> feeds,
gsl::span<const char*> fetch_names,
gsl::span<OrtValue*> fetches);

/**
 * Schedule a run of the model on the session's thread pool and return without waiting for it.
 * The outcome of the run is reported through the callback.
 * The memory referenced by the spans and by run_options is used when the scheduled task executes,
 * so it must remain valid until the callback has been invoked.
 * @param run_options options for the run; if null, default RunOptions are used.
 * @param feed_names input names.
 * @param feeds input values, one per entry of feed_names.
 * @param fetch_names output names.
 * @param fetches output values, one per entry of fetch_names; passed back to the callback on success.
 * @param callback function invoked when the run has finished.
 * @param user_data opaque pointer passed back to the callback unchanged.
 * @return INVALID_ARGUMENT if the session's thread pool cannot be used for the run, OK once the run
 *         has been scheduled.
 */
[[nodiscard]] common::Status RunAsync(const RunOptions* run_options,
gsl::span<const char*> feed_names,
gsl::span<const OrtValue*> feeds,
gsl::span<const char*> fetch_names,
gsl::span<OrtValue*> fetches,
RunAsyncCallbackFn callback,
void* user_data = nullptr);

/**
* Run a pre-loaded and pre-intialized model.
* Multiple threads are allowed to run this function; hence its thread-safe.
Expand Down
104 changes: 40 additions & 64 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -817,81 +817,56 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArray, _In_ const OrtEnv* env, _In
// C API entry point for a synchronous run.
// Wraps the caller's raw arrays in spans and forwards to InferenceSession::Run, which performs
// argument validation and fills the null entries of `output` on success.
// If run_options is null, default RunOptions are used.
ORT_API_STATUS_IMPL(OrtApis::Run, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
                    _In_reads_(input_len) const char* const* input_names,
                    _In_reads_(input_len) const OrtValue* const* input, size_t input_len,
                    _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
                    _Inout_updates_all_(output_names_len) OrtValue** output) {
  API_IMPL_BEGIN
  auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);

  gsl::span<const char*> input_names_span(const_cast<const char**>(input_names), input_len);
  gsl::span<const OrtValue*> input_span(const_cast<const OrtValue**>(input), input_len);
  // The output name array holds output_names_len entries; sizing it by input_len would truncate
  // or over-read it whenever the model has a different number of inputs and outputs.
  gsl::span<const char*> output_name_span(const_cast<const char**>(output_names), output_names_len);
  gsl::span<OrtValue*> output_span(output, output_names_len);

  Status status;
  if (run_options) {
    status = session->Run(*run_options,
                          input_names_span,
                          input_span,
                          output_name_span,
                          output_span);
  } else {
    const RunOptions default_run_options;
    status = session->Run(default_run_options,
                          input_names_span,
                          input_span,
                          output_name_span,
                          output_span);
  }
  return ToOrtStatus(status);
  API_IMPL_END
}

// C API entry point for an asynchronous run.
// Wraps the caller's raw arrays in spans and forwards to InferenceSession::RunAsync. The spans do
// not own the arrays, and the run executes after this function has returned, so the caller's
// arrays must stay valid until run_async_callback has been invoked.
// The returned status only reflects whether the run could be scheduled.
ORT_API_STATUS_IMPL(OrtApis::RunAsync, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
                    _In_reads_(input_len) const char* const* input_names,
                    _In_reads_(input_len) const OrtValue* const* input, size_t input_len,
                    _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
                    _Inout_updates_all_(output_names_len) OrtValue** output,
                    _In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data) {
  API_IMPL_BEGIN
  auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);

  gsl::span<const char*> input_names_span(const_cast<const char**>(input_names), input_len);
  gsl::span<const OrtValue*> input_span(const_cast<const OrtValue**>(input), input_len);
  // Sized by output_names_len (not input_len): see the matching note in OrtApis::Run.
  gsl::span<const char*> output_name_span(const_cast<const char**>(output_names), output_names_len);
  gsl::span<OrtValue*> output_span(output, output_names_len);

  return ToOrtStatus(session->RunAsync(run_options,
                                       input_names_span,
                                       input_span,
                                       output_name_span,
                                       output_span,
                                       run_async_callback,
                                       user_data));
  API_IMPL_END
}

Expand Down Expand Up @@ -2735,6 +2710,7 @@ static constexpr OrtApi ort_api_1_to_16 = {
&OrtApis::GetROCMProviderOptionsAsString,
&OrtApis::ReleaseROCMProviderOptions,
&OrtApis::CreateAndRegisterAllocatorV2,
&OrtApis::RunAsync,
};

// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -478,4 +478,11 @@ ORT_API(void, ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProviderOptions

// Implementation-side declaration of CreateAndRegisterAllocatorV2.
// provider_options_keys and provider_options_values are parallel arrays of num_keys entries each.
ORT_API_STATUS_IMPL(CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg,
_In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys);

// Implementation-side declaration of RunAsync.
// input_names / input hold input_len entries each; output_names / outputs hold output_names_len
// entries each. run_async_callback receives user_data back when the run completes.
ORT_API_STATUS_IMPL(RunAsync, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
_In_reads_(input_len) const char* const* input_names,
_In_reads_(input_len) const OrtValue* const* input, size_t input_len,
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
_Inout_updates_all_(output_names_len) OrtValue** outputs,
_In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data);
} // namespace OrtApis
Loading