Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RunAsync C/CXX API #16613

Merged
merged 30 commits into from
Jul 16, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,9 @@ typedef void (*OrtCustomJoinThreadFn)(OrtCustomThreadHandle ort_custom_thread_ha

typedef OrtStatus*(ORT_API_CALL* RegisterCustomOpsFn)(OrtSessionOptions* options, const OrtApiBase* api);

typedef void (*RunAsyncCallbackFn)(OrtValue**, size_t, OrtStatusPtr);
typedef void (*RunAsyncCallbackFn)(void*, OrtValue**, size_t, OrtStatusPtr);
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved

// void CallbackBridge(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status);

/** \brief The C API
*
Expand Down Expand Up @@ -4331,12 +4333,13 @@ struct OrtApi {
* \param[in] output_names Array of null terminated UTF8 encoded strings of the output names
* \param[in] output_names_len Number of elements in the output_names and outputs array
* \param[in] run_async_callback Callback function on model run completion
* \param[in] user_data User data that pass back to run_async_callback
*/
ORT_API2_STATUS(RunAsync, _Inout_ OrtSession* session, _In_opt_ const OrtRunOptions* run_options,
_In_reads_(input_len) const char* const* input_names,
_In_reads_(input_len) const OrtValue* const* inputs, size_t input_len,
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
_In_ RunAsyncCallbackFn run_async_callback);
_In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data);
};

/*
Expand Down
7 changes: 5 additions & 2 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <unordered_map>
#include <utility>
#include <type_traits>
#include <functional>
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved

#ifdef ORT_NO_EXCEPTIONS
#include <iostream>
Expand Down Expand Up @@ -748,6 +749,8 @@ struct ConstSessionImpl : Base<T> {
TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo
};

using RunAsyncCallbackStdFn = std::function<void(std::vector<Value>&, Status)>;
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved

template <typename T>
struct SessionImpl : ConstSessionImpl<T> {
using B = ConstSessionImpl<T>;
Expand Down Expand Up @@ -783,10 +786,10 @@ struct SessionImpl : ConstSessionImpl<T> {

/** \brief Run the model in a separate thread.
* Callback will be invoked on run completion, with output values as arguments,
* on error, status could be used to see detail.
* on error, a status could be returned.
*/
void RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
const char* const* output_names, size_t output_count, RunAsyncCallbackFn callback);
const char* const* output_names, size_t output_count, RunAsyncCallbackStdFn& callback);

/** \brief End profiling and return a copy of the profiling file name.
*
Expand Down
23 changes: 16 additions & 7 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -915,13 +915,6 @@ inline std::vector<Value> SessionImpl<T>::Run(const RunOptions& run_options, con
return output_values;
}

template <typename T>
inline void SessionImpl<T>::RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
const char* const* output_names, size_t output_count, RunAsyncCallbackFn callback) {
auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
ThrowOnError(GetApi().RunAsync(this->p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, callback));
}

template <typename T>
inline void SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
const char* const* output_names, Value* output_values, size_t output_count) {
Expand All @@ -936,6 +929,22 @@ inline void SessionImpl<T>::Run(const RunOptions& run_options, const IoBinding&
ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding));
}

inline void CallbackBridge(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status) {
RunAsyncCallbackStdFn* callback = reinterpret_cast<RunAsyncCallbackStdFn*>(user_data);
std::vector<Ort::Value> output_values;
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved
for (size_t ith = 0; ith < num_outputs; ++ith) {
output_values.emplace_back(outputs[ith]);
}
(*callback)(output_values, Ort::Status{status});
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved
}

template <typename T>
inline void SessionImpl<T>::RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
const char* const* output_names, size_t output_count, RunAsyncCallbackStdFn& callback) {
auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
ThrowOnError(GetApi().RunAsync(this->p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, CallbackBridge, &callback));
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved
}

template <typename T>
inline AllocatedStringPtr SessionImpl<T>::EndProfilingAllocated(OrtAllocator* allocator) {
char* out = nullptr;
Expand Down
11 changes: 5 additions & 6 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2304,7 +2304,6 @@ Status InferenceSession::Run(const OrtRunOptions* run_options,
const OrtValue* const* input, size_t input_len,
const char* const* output_names, size_t output_names_len,
OrtValue** output) {

InlinedVector<std::string> feed_names;
feed_names.reserve(input_len);
InlinedVector<OrtValue> feeds;
Expand Down Expand Up @@ -2379,7 +2378,7 @@ Status InferenceSession::Run(const OrtRunOptions* run_options,
Status InferenceSession::RunAsync(const OrtRunOptions* run_options, const char* const* input_names,
const OrtValue* const* input, size_t input_len,
const char* const* output_name, size_t output_names_len,
RunAsyncCallbackFn callback) {
RunAsyncCallbackFn callback, void* user_data) {
InferenceSession* sess = this;
std::function<void()> run_fn = [=]() {
ORT_TRY {
Expand All @@ -2388,18 +2387,18 @@ Status InferenceSession::RunAsync(const OrtRunOptions* run_options, const char*
memset(outputs.get(), 0x0, sizeof(OrtValuePtr) * output_names_len);
auto status = sess->Run(run_options, input_names, input, input_len, output_name, output_names_len, outputs.get());
if (status.IsOK()) {
callback(outputs.get(), output_names_len, {});
callback(user_data, outputs.get(), output_names_len, {});
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved
} else {
callback({}, 0, ToOrtStatus(status));
callback(user_data, {}, 0, ToOrtStatus(status));
}
}
ORT_CATCH(const std::exception& e) {
std::string what = "unknown";
ORT_HANDLE_EXCEPTION([&]() { what = e.what(); });
callback({}, 0, ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, what.c_str())));
callback(user_data, {}, 0, ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, what.c_str())));
}
ORT_CATCH(...) {
callback({}, 0, ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION)));
callback(user_data, {}, 0, ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION)));
}
};
concurrency::ThreadPool::Schedule(thread_pool_.get(), run_fn);
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/session/inference_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ class InferenceSession {
[[nodiscard]] common::Status RunAsync(const OrtRunOptions* run_options, const char* const* input_names,
const OrtValue* const* input, size_t input_len,
const char* const* output_names, size_t output_names_len,
RunAsyncCallbackFn callback);
RunAsyncCallbackFn callback, void* user_data = nullptr);

/**
* Run a pre-loaded and pre-intialized model.
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -834,10 +834,10 @@ ORT_API_STATUS_IMPL(OrtApis::RunAsync, _Inout_ OrtSession* sess, _In_opt_ const
_In_reads_(input_len) const char* const* input_names,
_In_reads_(input_len) const OrtValue* const* input, size_t input_len,
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
_In_ RunAsyncCallbackFn run_async_callback) {
_In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data) {
API_IMPL_BEGIN
auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);
auto status = session->RunAsync(run_options, input_names, input, input_len, output_names, output_names_len, run_async_callback);
auto status = session->RunAsync(run_options, input_names, input, input_len, output_names, output_names_len, run_async_callback, user_data);
if (status.IsOK()) {
return nullptr;
} else {
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -483,5 +483,5 @@ ORT_API_STATUS_IMPL(RunAsync, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOpt
_In_reads_(input_len) const char* const* input_names,
_In_reads_(input_len) const OrtValue* const* input, size_t input_len,
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
_In_ RunAsyncCallbackFn run_async_callback);
_In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data);
} // namespace OrtApis
28 changes: 15 additions & 13 deletions onnxruntime/test/shared_lib/test_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3223,16 +3223,6 @@ TEST(MultiKernelSingleSchemaTest, DuplicateKernel) {
const static std::thread::id caller_tid = std::this_thread::get_id();
static std::atomic_bool atomic_wait{false};

void RunAsyncCallBack(OrtValue** outputs, size_t num_outputs, OrtStatusPtr status) {
auto callee_tid = std::this_thread::get_id();
EXPECT_EQ(status, nullptr);
EXPECT_EQ(num_outputs, 1UL);
EXPECT_NE(caller_tid, callee_tid);
Ort::Value output_value(outputs[0]);
EXPECT_EQ(output_value.At<float>({1, 0}), 9.f);
atomic_wait.store(true);
}

TEST(CApiTest, RunAsync) {
Ort::SessionOptions session_options;
session_options.SetIntraOpNumThreads(2);
Expand All @@ -3251,10 +3241,22 @@ TEST(CApiTest, RunAsync) {
const char* output_names[] = {"Y"};

Ort::RunOptions run_options;
EXPECT_NO_THROW(session.RunAsync(run_options, input_names, input_tensors, 1, output_names, 1, RunAsyncCallBack));

while (!atomic_wait.load()) {
std::this_thread::yield();
Ort::detail::RunAsyncCallbackStdFn callback = [&](std::vector<Ort::Value>& outputs, Ort::Status status) {
auto callee_tid = std::this_thread::get_id();
EXPECT_NE(caller_tid, callee_tid);
EXPECT_TRUE(status.IsOK());
EXPECT_EQ(outputs.size(), 1UL);
EXPECT_EQ(outputs[0].At<float>({1, 0}), 9.f);
atomic_wait.store(true);
};

EXPECT_NO_THROW(session.RunAsync(run_options, input_names, input_tensors, 1, output_names, 1, callback));

std::chrono::duration<double, std::milli> dur{100};
// timeout in about 10 secs
for (int i = 0; i < 100 && !atomic_wait.load(); ++i) {
std::this_thread::sleep_for(dur);
}

EXPECT_EQ(atomic_wait.load(), true);
Expand Down