Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RunAsync C/CXX API #16613

Merged
merged 30 commits into from
Jul 16, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,10 @@ typedef void (*OrtCustomJoinThreadFn)(OrtCustomThreadHandle ort_custom_thread_ha

typedef OrtStatus*(ORT_API_CALL* RegisterCustomOpsFn)(OrtSessionOptions* options, const OrtApiBase* api);

/** \brief Callback invoked when an asynchronous model run finishes.
 *
 * Arguments, in order: the opaque user data pointer supplied when the run was started, the array of
 * output ::OrtValue pointers, the number of entries in that array, and the status of the run.
 * When the run fails, the outputs array is null and the count is 0.
 */
typedef void (*RunAsyncCallbackFn)(void*, OrtValue**, size_t, OrtStatusPtr);
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved

// void CallbackBridge(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status);

/** \brief The C API
*
* All C API functions are defined inside this structure as pointers to functions.
Expand Down Expand Up @@ -4316,6 +4320,26 @@ struct OrtApi {
*/
ORT_API2_STATUS(CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg,
_In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys);

/** \brief Run the model asynchronously in an ::OrtSession
*
* Returns immediately. The model runs on a separate thread and the callback is invoked on completion.
*
* The argument arrays are read from the worker thread, so run_options, the name arrays and the
* input values must remain valid until the callback has been invoked.
*
* \param[in] session
* \param[in] run_options If nullptr, will use a default ::OrtRunOptions
* \param[in] input_names Array of null terminated UTF8 encoded strings of the input names
* \param[in] inputs Array of ::OrtValue%s of the input values
* \param[in] input_len Number of elements in the input_names and inputs arrays
* \param[in] output_names Array of null terminated UTF8 encoded strings of the output names
* \param[in] output_names_len Number of elements in the output_names array; on success the callback
*            receives this many outputs
* \param[in] run_async_callback Callback function invoked when the model run completes
* \param[in] user_data User data that is passed back to run_async_callback
*/
ORT_API2_STATUS(RunAsync, _Inout_ OrtSession* session, _In_opt_ const OrtRunOptions* run_options,
_In_reads_(input_len) const char* const* input_names,
_In_reads_(input_len) const OrtValue* const* inputs, size_t input_len,
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
_In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data);
};

/*
Expand Down
10 changes: 10 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <unordered_map>
#include <utility>
#include <type_traits>
#include <functional>
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved

#ifdef ORT_NO_EXCEPTIONS
#include <iostream>
Expand Down Expand Up @@ -748,6 +749,8 @@ struct ConstSessionImpl : Base<T> {
TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo
};

/// Callback type accepted by SessionImpl::RunAsync: receives the output values and the status of the run.
using RunAsyncCallbackStdFn = std::function<void(std::vector<Value>&, Status)>;
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved

template <typename T>
struct SessionImpl : ConstSessionImpl<T> {
using B = ConstSessionImpl<T>;
Expand Down Expand Up @@ -781,6 +784,13 @@ struct SessionImpl : ConstSessionImpl<T> {

void Run(const RunOptions& run_options, const IoBinding&); ///< Wraps OrtApi::RunWithBinding

/** \brief Run the model in a separate thread. Wraps OrtApi::RunAsync
* The callback is invoked when the run completes, receiving the output values and a status
* that reports whether the run succeeded.
*
* \param callback Taken by reference: its address is handed to the C API as user data, so the
* object must stay alive at the same address until it has been invoked.
*/
void RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
const char* const* output_names, size_t output_count, RunAsyncCallbackStdFn& callback);

/** \brief End profiling and return a copy of the profiling file name.
*
* \param allocator to allocate memory for the copy of the string returned
Expand Down
16 changes: 16 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,22 @@ inline void SessionImpl<T>::Run(const RunOptions& run_options, const IoBinding&
ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding));
}

inline void CallbackBridge(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status) {
RunAsyncCallbackStdFn* callback = reinterpret_cast<RunAsyncCallbackStdFn*>(user_data);
std::vector<Ort::Value> output_values;
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved
for (size_t ith = 0; ith < num_outputs; ++ith) {
output_values.emplace_back(outputs[ith]);
}
(*callback)(output_values, Ort::Status{status});
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved
}

/// Starts an asynchronous run through OrtApi::RunAsync and throws (via ThrowOnError) if scheduling fails.
///
/// The address of `callback` is passed to the C API as user data and is dereferenced by
/// CallbackBridge when the run completes, so `callback` must outlive the asynchronous run.
template <typename T>
inline void SessionImpl<T>::RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
                                     const char* const* output_names, size_t output_count, RunAsyncCallbackStdFn& callback) {
  // The cast below reads an array of Value as an array of OrtValue*; make that size assumption explicit.
  static_assert(sizeof(Value) == sizeof(OrtValue*),
                "Value must have the size of OrtValue* to be reinterpreted as an array of OrtValue*");
  auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
  ThrowOnError(GetApi().RunAsync(this->p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, CallbackBridge, &callback));
}

template <typename T>
inline AllocatedStringPtr SessionImpl<T>::EndProfilingAllocated(OrtAllocator* allocator) {
char* out = nullptr;
Expand Down
106 changes: 106 additions & 0 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2299,6 +2299,112 @@ Status InferenceSession::Run(const RunOptions& run_options,
return retval;
}

// Runs the model from C-API style arguments (parallel arrays of names and OrtValue pointers).
//
// run_options      - options for this run; when null a default-constructed OrtRunOptions is used.
// input_names      - input_len names; each must be non-null and non-empty.
// input            - input_len input values; each must be non-null.
// output_names     - output_names_len names; each must be non-null and non-empty.
// output           - output_names_len slots. A non-null slot is passed to the run as a pre-existing
//                    fetch; a null slot receives a newly allocated OrtValue on success.
//
// Returns INVALID_ARGUMENT for malformed arguments, otherwise the status of the underlying run.
// On failure no entry of `output` is modified.
Status InferenceSession::Run(const OrtRunOptions* run_options,
                             const char* const* input_names,
                             const OrtValue* const* input, size_t input_len,
                             const char* const* output_names, size_t output_names_len,
                             OrtValue** output) {
  // Reject null arrays up front instead of dereferencing them in the loops below.
  if (input_len > 0 && (input_names == nullptr || input == nullptr)) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input names and input values cannot be null");
  }
  if (output_names_len > 0 && (output_names == nullptr || output == nullptr)) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output names and output values cannot be null");
  }

  InlinedVector<std::string> feed_names;
  feed_names.reserve(input_len);
  InlinedVector<OrtValue> feeds;
  feeds.reserve(input_len);

  for (size_t i = 0; i != input_len; ++i) {
    if (input_names[i] == nullptr || input_names[i][0] == '\0') {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input name cannot be empty");
    }

    if (!input[i]) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, MakeString("NULL input supplied for input ", input_names[i]).c_str());
    }

    feed_names.emplace_back(input_names[i]);
    feeds.emplace_back(*input[i]);
  }

  // Create output feed
  InlinedVector<std::string> output_name_vec;
  output_name_vec.reserve(output_names_len);
  for (size_t i = 0; i != output_names_len; ++i) {
    if (output_names[i] == nullptr || output_names[i][0] == '\0') {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output name cannot be empty");
    }
    output_name_vec.emplace_back(output_names[i]);
  }

  // Caller-provided outputs are forwarded as fetches; missing ones start out empty.
  std::vector<OrtValue> fetches;
  fetches.reserve(output_names_len);
  for (size_t i = 0; i != output_names_len; ++i) {
    if (output[i] != nullptr) {
      fetches.emplace_back(*output[i]);
    } else {
      fetches.emplace_back();
    }
  }

  Status status;
  if (run_options == nullptr) {
    OrtRunOptions op;
    status = Run(op, feed_names, feeds, output_name_vec, &fetches, nullptr);
  } else {
    status = Run(*run_options, feed_names, feeds, output_name_vec, &fetches, nullptr);
  }

  if (!status.IsOK())
    return status;

  // Two passes: first perform every allocation/copy that may throw while ownership is still
  // held by unique_ptrs, then hand the raw pointers to the caller in a pass that cannot throw.
  InlinedVector<std::unique_ptr<OrtValue>> output_unique_ptrs;
  output_unique_ptrs.reserve(output_names_len);
  for (size_t i = 0; i != output_names_len; ++i) {
    if (output[i] == nullptr) {
      output_unique_ptrs.emplace_back(std::make_unique<OrtValue>(fetches[i]));
    } else {
      output_unique_ptrs.emplace_back();
    }
  }

  ORT_ENFORCE(output_unique_ptrs.size() == output_names_len);

  for (size_t i = 0; i != output_names_len; ++i) {
    if (output[i] == nullptr) {
      ORT_ENFORCE(output_unique_ptrs[i] != nullptr);
      output[i] = output_unique_ptrs[i].release();
    }
  }
  return Status::OK();
}

// Schedules a run of the model on the session's thread pool and returns without waiting for it.
//
// The raw argument pointers are captured by value and dereferenced later on the worker, so the
// caller must keep run_options, the name arrays and the input values alive until `callback` fires.
//
// `callback` is invoked exactly once per scheduled run:
//   - on success with the output array, output_names_len and a null status;
//   - on failure (error status or exception during the run) with a null output array, 0 and an
//     OrtStatus describing the error.
// The output pointer array itself is only valid for the duration of the callback.
//
// Returns INVALID_ARGUMENT if `callback` is null, otherwise OK once the work has been handed to
// the thread pool.
// NOTE(review): thread_pool_.get() may be null depending on how the session was configured —
// confirm what ThreadPool::Schedule does in that case (it may run the work on the calling thread).
Status InferenceSession::RunAsync(const OrtRunOptions* run_options, const char* const* input_names,
                                  const OrtValue* const* input, size_t input_len,
                                  const char* const* output_names, size_t output_names_len,
                                  RunAsyncCallbackFn callback, void* user_data) {
  if (callback == nullptr) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "callback cannot be null");
  }

  InferenceSession* sess = this;
  std::function<void()> run_fn = [=]() {
    using OrtValuePtr = OrtValue*;
    std::unique_ptr<OrtValuePtr[]> outputs;
    Status status;

    // Only the run itself is guarded. The callback is invoked after the try/catch so that an
    // exception thrown by the callback cannot cause it to be invoked a second time.
    ORT_TRY {
      // make_unique<T[]> value-initializes, so every slot starts out as nullptr.
      outputs = std::make_unique<OrtValuePtr[]>(output_names_len);
      status = sess->Run(run_options, input_names, input, input_len, output_names, output_names_len, outputs.get());
    }
    ORT_CATCH(const std::exception& e) {
      std::string what = "unknown";
      ORT_HANDLE_EXCEPTION([&]() { what = e.what(); });
      status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, what.c_str());
    }
    ORT_CATCH(...) {
      status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION);
    }

    if (status.IsOK()) {
      callback(user_data, outputs.get(), output_names_len, nullptr);
    } else {
      callback(user_data, nullptr, 0, ToOrtStatus(status));
    }
  };
  concurrency::ThreadPool::Schedule(thread_pool_.get(), run_fn);
  return Status::OK();
}

common::Status InferenceSession::Run(const NameMLValMap& feeds, gsl::span<const std::string> output_names,
std::vector<OrtValue>* p_fetches) {
return Run(RunOptions(), feeds, output_names, p_fetches);
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/session/inference_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,16 @@ class InferenceSession {
std::vector<OrtValue>* p_fetches,
const std::vector<OrtDevice>* p_fetches_device_info = nullptr);

/**
 * Run the model using C-API style arguments (parallel arrays of names and OrtValue pointers).
 * @param run_options Options for this run; when null a default-constructed OrtRunOptions is used.
 * @param input_names Array of input_len input names; null or empty names are rejected.
 * @param input Array of input_len input values; null entries are rejected.
 * @param output_names Array of output_names_len output names; null or empty names are rejected.
 * @param output Array of output_names_len entries. Non-null entries are passed to the run as
 *               pre-existing fetches; null entries are replaced with newly allocated OrtValues on success.
 * @return INVALID_ARGUMENT for rejected arguments, otherwise the status of the run.
 */
[[nodiscard]] common::Status Run(const OrtRunOptions* run_options, const char* const* input_names,
const OrtValue* const* input, size_t input_len,
const char* const* output_names, size_t output_names_len,
OrtValue** output);

/**
 * Schedule a run of the model on the session's thread pool and return without waiting for it.
 * The argument arrays are read from the scheduled task, so they must remain valid until the
 * callback has been invoked.
 * @param run_options Options for this run; may be null.
 * @param input_names Array of input_len input names.
 * @param input Array of input_len input values.
 * @param output_names Array of output_names_len output names.
 * @param callback Invoked when the run finishes: with the outputs on success, or with no outputs
 *                 and an error status on failure.
 * @param user_data Opaque pointer passed back to the callback unchanged.
 * @return OK once the run has been scheduled.
 */
[[nodiscard]] common::Status RunAsync(const OrtRunOptions* run_options, const char* const* input_names,
                                      const OrtValue* const* input, size_t input_len,
                                      const char* const* output_names, size_t output_names_len,
                                      RunAsyncCallbackFn callback, void* user_data = nullptr);

/**
* Run a pre-loaded and pre-intialized model.
* Multiple threads are allowed to run this function; hence its thread-safe.
Expand Down
87 changes: 20 additions & 67 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -817,81 +817,32 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArray, _In_ const OrtEnv* env, _In
// C API entry point for a synchronous run. Argument validation and output allocation are
// performed by InferenceSession::Run; this wrapper only converts the resulting status.
ORT_API_STATUS_IMPL(OrtApis::Run, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
                    _In_reads_(input_len) const char* const* input_names,
                    _In_reads_(input_len) const OrtValue* const* input, size_t input_len,
                    _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
                    _Inout_updates_all_(output_names_len) OrtValue** output) {
  API_IMPL_BEGIN
  auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);
  auto status = session->Run(run_options, input_names, input, input_len, output_names, output_names_len, output);
  if (status.IsOK()) {
    return nullptr;
  } else {
    return ToOrtStatus(status);
  }
  API_IMPL_END
}
// C API entry point for an asynchronous run. Forwards to InferenceSession::RunAsync and converts
// the scheduling status: nullptr on success, otherwise an OrtStatus describing the failure.
// The result of the run itself is delivered through run_async_callback.
ORT_API_STATUS_IMPL(OrtApis::RunAsync, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
                    _In_reads_(input_len) const char* const* input_names,
                    _In_reads_(input_len) const OrtValue* const* input, size_t input_len,
                    _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
                    _In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data) {
  API_IMPL_BEGIN
  auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);
  auto status = session->RunAsync(run_options, input_names, input, input_len, output_names, output_names_len, run_async_callback, user_data);
  if (status.IsOK()) {
    return nullptr;
  } else {
    return ToOrtStatus(status);
  }
  API_IMPL_END
}

Expand Down Expand Up @@ -2735,6 +2686,8 @@ static constexpr OrtApi ort_api_1_to_16 = {
&OrtApis::GetROCMProviderOptionsAsString,
&OrtApis::ReleaseROCMProviderOptions,
&OrtApis::CreateAndRegisterAllocatorV2,

&OrtApis::RunAsync,
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved
};

// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
Expand Down
6 changes: 6 additions & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -478,4 +478,10 @@ ORT_API(void, ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProviderOptions

ORT_API_STATUS_IMPL(CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg,
_In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys);

// Asynchronous counterpart of Run: the result of the run is delivered through run_async_callback,
// which also receives user_data.
ORT_API_STATUS_IMPL(RunAsync, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options,
_In_reads_(input_len) const char* const* input_names,
_In_reads_(input_len) const OrtValue* const* input, size_t input_len,
_In_reads_(output_names_len) const char* const* output_names, size_t output_names_len,
_In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data);
} // namespace OrtApis
46 changes: 44 additions & 2 deletions onnxruntime/test/shared_lib/test_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3140,8 +3140,8 @@ TEST(MultiKernelSingleSchemaTest, valid) {
Ort::Value::CreateTensor<float>(memory_info, x_value, 10, x_dim, 1),
};

Ort::RunOptions run_optoins;
auto output_tensors = session.Run(run_optoins, input_names, input_tensors, 1, output_names, 2);
Ort::RunOptions run_options;
auto output_tensors = session.Run(run_options, input_names, input_tensors, 1, output_names, 2);
ASSERT_TRUE(*output_tensors[1].GetTensorData<int32_t>() == 72);
}

Expand Down Expand Up @@ -3219,3 +3219,45 @@ TEST(MultiKernelSingleSchemaTest, DuplicateKernel) {
}

#endif

// Id of the thread that performs static initialization of this file (presumably the main test
// thread — confirm); the RunAsync test compares it against the thread that invokes the callback.
const static std::thread::id caller_tid = std::this_thread::get_id();
// Set to true by the RunAsync test's callback so the test body can detect that it ran.
static std::atomic_bool atomic_wait{false};

// Verifies that Session::RunAsync invokes the callback on a different thread than the caller,
// with an OK status and the expected single output.
TEST(CApiTest, RunAsync) {
  // Reset the completion flag so a repeated run of this test cannot pass on a stale value.
  atomic_wait.store(false);

  Ort::SessionOptions session_options;
  session_options.SetIntraOpNumThreads(2);
  Ort::Session session(*ort_env, MODEL_URI, session_options);

  const char* input_names[] = {"X"};
  float x_value[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  int64_t x_dim[] = {3, 2};

  auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);

  Ort::Value input_tensors[1] = {
      Ort::Value::CreateTensor<float>(memory_info, x_value, 6, x_dim, 2),
  };

  const char* output_names[] = {"Y"};

  Ort::RunOptions run_options;

  Ort::detail::RunAsyncCallbackStdFn callback = [&](std::vector<Ort::Value>& outputs, Ort::Status status) {
    auto callee_tid = std::this_thread::get_id();
    EXPECT_NE(caller_tid, callee_tid);
    EXPECT_TRUE(status.IsOK());
    EXPECT_EQ(outputs.size(), 1UL);
    // Only read the output when it exists, so a failed run reports cleanly instead of crashing.
    if (!outputs.empty()) {
      EXPECT_EQ(outputs[0].At<float>({1, 0}), 9.f);
    }
    atomic_wait.store(true);
  };

  EXPECT_NO_THROW(session.RunAsync(run_options, input_names, input_tensors, 1, output_names, 1, callback));

  std::chrono::duration<double, std::milli> dur{100};
  // timeout in about 10 secs
  for (int i = 0; i < 100 && !atomic_wait.load(); ++i) {
    std::this_thread::sleep_for(dur);
  }

  EXPECT_TRUE(atomic_wait.load());
}
RandySheriffH marked this conversation as resolved.
Show resolved Hide resolved