Skip to content

Commit

Permalink
fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
RandyShuai committed Jul 13, 2023
1 parent 6c67e44 commit bc607f2
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 16 deletions.
5 changes: 2 additions & 3 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -4335,9 +4335,8 @@ struct OrtApi {
* \param[in] input_len Number of elements in the input_names and inputs arrays
* \param[in] output_names Array of null terminated UTF8 encoded strings of the output names
* \param[in] output_names_len Number of elements in the output_names and outputs array
* \param[out] outputs Array of ::OrtValue%s that the outputs are stored in. This can also be
* an array of nullptr values, in this case ::OrtValue objects will be allocated and pointers
* to them will be set into the `outputs` array.
* \param[out] outputs Array of OrtValue* owned by customers, size to output_names_len. It could simply be an array of nullptr
* The array will be passed back to run_async_callback
* \param[in] run_async_callback Callback function on model run completion
* \param[in] user_data User data that pass back to run_async_callback
*/
Expand Down
17 changes: 14 additions & 3 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -781,9 +781,20 @@ struct SessionImpl : ConstSessionImpl<T> {

void Run(const RunOptions& run_options, const IoBinding&); ///< Wraps OrtApi::RunWithBinding

/** \brief Run the model in a separate thread.
* Callback will be invoked on run completion, with output values as arguments,
* on error, a status could be returned.
/** \brief Run the model asynchronously in a thread owned by intra op thread pool
*
* Wraps OrtApi::RunAsync
*
* \param[in] run_options
* \param[in] input_names Array of null terminated UTF8 encoded strings of the input names
* \param[in] input_values Array of ::OrtValue%s of the input values
* \param[in] input_count Number of elements in the input_names and inputs arrays
* \param[in] output_names Array of null terminated UTF8 encoded strings of the output names
* \param[out] output_values Array of ::OrtValue%s owned by customers, size to output_count. It could simply be an array of nullptr
* The array will be passed back to the callback
* \param[in] output_count Number of elements in the output_names and outputs array
* \param[in] callback Callback function on model run completion
* \param[in] user_data User data that pass back to the callback
*/
void RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data);
Expand Down
5 changes: 1 addition & 4 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2364,8 +2364,6 @@ Status InferenceSession::Run(const OrtRunOptions* run_options,
}
}

ORT_ENFORCE(output_unique_ptrs.size() == output_names_len);

for (size_t i = 0; i != output_names_len; ++i) {
if (output[i] == nullptr) {
ORT_ENFORCE(output_unique_ptrs[i] != nullptr);
Expand All @@ -2383,10 +2381,9 @@ Status InferenceSession::RunAsync(const OrtRunOptions* run_options, const char*
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "intra op thread pool must have at least one thread for RunAsync");
}

InferenceSession* sess = this;
std::function<void()> run_fn = [=]() {
ORT_TRY {
auto status = sess->Run(run_options, input_names, inputs, input_len, output_name, output_names_len, outputs);
auto status = Run(run_options, input_names, inputs, input_len, output_name, output_names_len, outputs);
if (status.IsOK()) {
callback(user_data, outputs, output_names_len, {});
} else {
Expand Down
7 changes: 1 addition & 6 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -847,11 +847,7 @@ ORT_API_STATUS_IMPL(OrtApis::RunAsync, _Inout_ OrtSession* sess, _In_opt_ const
outputs,
run_async_callback,
user_data);
if (status.IsOK()) {
return nullptr;
} else {
return ToOrtStatus(status);
}
return ToOrtStatus(status);
API_IMPL_END
}

Expand Down Expand Up @@ -2695,7 +2691,6 @@ static constexpr OrtApi ort_api_1_to_16 = {
&OrtApis::GetROCMProviderOptionsAsString,
&OrtApis::ReleaseROCMProviderOptions,
&OrtApis::CreateAndRegisterAllocatorV2,

&OrtApis::RunAsync,
};

Expand Down

0 comments on commit bc607f2

Please sign in to comment.