Skip to content

Commit

Permalink
Re-work GetAvailableProviders API (#14486)
Browse files Browse the repository at this point in the history
### Description
Re-work `OrtApi::GetAvailableProviders` in a way that the data is
returned in a single allocation.
Fix exception safety issues and fix `Release` function. 
Remove warning suppressions.
Fix exception safety issue in C++ API.
Fix exception safety issue in C# API.
Move EP name length enforcement to the implementation.

### Motivation and Context
The original motivation comes from
#14378.
However, the API is already implemented.

Cc: @prabhat00155
  • Loading branch information
yuslepukhin authored Feb 1, 2023
1 parent d9e675a commit 61e7636
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 53 deletions.
13 changes: 4 additions & 9 deletions csharp/src/Microsoft.ML.OnnxRuntime/OnnxRuntime.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -130,26 +130,21 @@ public string[] GetAvailableProviders()
int numProviders;

NativeApiStatus.VerifySuccess(NativeMethods.OrtGetAvailableProviders(out availableProvidersHandle, out numProviders));

var availableProviders = new string[numProviders];

try
{
for(int i=0; i<numProviders; ++i)
var availableProviders = new string[numProviders];
for (int i=0; i<numProviders; ++i)
{
availableProviders[i] = NativeOnnxValueHelper.StringFromNativeUtf8(Marshal.ReadIntPtr(availableProvidersHandle, IntPtr.Size * i));
}
return availableProviders;
}

finally
{
// Looks a bit weird that we might throw in finally(...)
// But the native method OrtReleaseAvailableProviders actually doesn't return a failure status
// This should never throw. The original C API should have never returned status in the first place.
// If it does, it is BUG and we would like to propagate that to the user in the form of an exception
NativeApiStatus.VerifySuccess(NativeMethods.OrtReleaseAvailableProviders(availableProvidersHandle, numProviders));
}

return availableProviders;
}


Expand Down
6 changes: 6 additions & 0 deletions include/onnxruntime/core/graph/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

#pragma once

#include <stddef.h> // needed for size_t on some platforms

namespace onnxruntime {

constexpr const char* kNoOp = "NoOp";
Expand All @@ -23,6 +25,10 @@ constexpr const char* kNGraphDomain = "com.intel.ai";
constexpr const char* kMIGraphXDomain = "";
constexpr const char* kVitisAIDomain = "com.xilinx";

// This is moved from the OrtApis::GetAvailableProviders implementation
// where it is enforced
constexpr size_t kMaxExecutionProviderNameLen = 30;

constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider";
constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider";
constexpr const char* kDnnlExecutionProvider = "DnnlExecutionProvider";
Expand Down
3 changes: 2 additions & 1 deletion include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2095,7 +2095,8 @@ struct OrtApi {
*/
ORT_API2_STATUS(GetAvailableProviders, _Outptr_ char*** out_ptr, _Out_ int* provider_length);

/** \brief Release data from OrtApi::GetAvailableProviders
/** \brief Release data from OrtApi::GetAvailableProviders. This API will never fail
* so you can rely on it in a noexcept code.
*
* \param[in] ptr The `out_ptr` result from OrtApi::GetAvailableProviders.
* \param[in] providers_length The `provider_length` result from OrtApi::GetAvailableProviders
Expand Down
18 changes: 13 additions & 5 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -1849,16 +1849,24 @@ inline void CustomOpApi::ReleaseKernelInfo(_Frees_ptr_opt_ OrtKernelInfo* info_c
}

inline std::vector<std::string> GetAvailableProviders() {
int len;
char** providers;
int len;

auto release_fn = [&len](char** providers) {
// This should always return nullptr.
ThrowOnError(GetApi().ReleaseAvailableProviders(providers, len));
};

ThrowOnError(GetApi().GetAvailableProviders(&providers, &len));
std::vector<std::string> available_providers(providers, providers + len);
ThrowOnError(GetApi().ReleaseAvailableProviders(providers, len));
std::unique_ptr<char*, decltype(release_fn)> guard(providers, release_fn);
std::vector<std::string> available_providers;
available_providers.reserve(static_cast<size_t>(len));
for (int i = 0; i < len; ++i) {
available_providers.emplace_back(providers[i]);
}
return available_providers;
}

SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val);

template <typename TOp, typename TKernel>
void CustomOpBase<TOp, TKernel>::GetSessionConfigs(std::unordered_map<std::string, std::string>& out,
ConstSessionOptions options) const {
Expand Down
15 changes: 12 additions & 3 deletions onnxruntime/core/providers/get_execution_providers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
#include "core/providers/get_execution_providers.h"

#include "core/graph/constants.h"
#include "core/common/common.h"

#include <string_view>

namespace onnxruntime {

namespace {
struct ProviderInfo {
const char* name;
std::string_view name;
bool available;
};

Expand Down Expand Up @@ -155,13 +158,18 @@ constexpr ProviderInfo kProvidersInPriorityOrder[] =
},
{kCpuExecutionProvider, true}, // kCpuExecutionProvider is always last
};

constexpr size_t kAllExecutionProvidersCount = sizeof(kProvidersInPriorityOrder) / sizeof(ProviderInfo);

} // namespace

const std::vector<std::string>& GetAllExecutionProviderNames() {
static const auto all_execution_providers = []() {
std::vector<std::string> result{};
result.reserve(kAllExecutionProvidersCount);
for (const auto& provider : kProvidersInPriorityOrder) {
result.push_back(provider.name);
ORT_ENFORCE(provider.name.size() <= kMaxExecutionProviderNameLen, "Make the EP:", provider.name , " name shorter");
result.push_back(std::string(provider.name));
}
return result;
}();
Expand All @@ -173,8 +181,9 @@ const std::vector<std::string>& GetAvailableExecutionProviderNames() {
static const auto available_execution_providers = []() {
std::vector<std::string> result{};
for (const auto& provider : kProvidersInPriorityOrder) {
ORT_ENFORCE(provider.name.size() <= kMaxExecutionProviderNameLen, "Make the EP:", provider.name, " name shorter");
if (provider.available) {
result.push_back(provider.name);
result.push_back(std::string(provider.name));
}
}
return result;
Expand Down
95 changes: 60 additions & 35 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2052,53 +2052,78 @@ ORT_API_STATUS_IMPL(OrtApis::GetOpaqueValue, _In_ const char* domain_name, _In_
return nullptr;
}

GSL_SUPPRESS(r .11)
ORT_API_STATUS_IMPL(OrtApis::GetAvailableProviders, _Outptr_ char*** out_ptr,
_In_ int* providers_length) {
API_IMPL_BEGIN
// TODO: there is no need to manually malloc/free these memory, it is insecure
// and inefficient. Instead, the implementation could scan the array twice,
// and use a single string object to hold all the names.
constexpr size_t MAX_LEN = 30;
const auto& available_providers = GetAvailableExecutionProviderNames();
const int available_count = narrow<int>(available_providers.size());
GSL_SUPPRESS(r .11)
char** const out = new char*[available_count];
if (out) {
for (int i = 0; i < available_count; i++) {
GSL_SUPPRESS(r .11)
out[i] = new char[MAX_LEN + 1];
namespace {

struct ProviderBuffer {
char** buffer_;
char* next_write_;

ProviderBuffer(char** buf, size_t p_count) {
buffer_ = buf;
next_write_ = DataStart(p_count);
}

char* DataStart(size_t p_count) { return reinterpret_cast<char*>(buffer_ + p_count); }
// Return next buffer ptr
void Append(const std::string& provider, size_t p_index) {
// Maximum provider name length is now enforced at GetAvailableExecutionProviderNames()
const size_t to_copy = provider.size();
#ifdef _MSC_VER
strncpy_s(out[i], MAX_LEN, available_providers[i].c_str(), MAX_LEN);
out[i][MAX_LEN] = '\0';
memcpy_s(next_write_, to_copy, provider.data(), to_copy);
#elif defined(__APPLE__)
strlcpy(out[i], available_providers[i].c_str(), MAX_LEN);
memcpy(next_write_, provider.data(), to_copy);
#else
strncpy(out[i], available_providers[i].c_str(), MAX_LEN);
out[i][MAX_LEN] = '\0';
memcpy(next_write_, provider.data(), to_copy);
#endif
}
next_write_[to_copy] = 0;
buffer_[p_index] = next_write_;
next_write_ += to_copy + 1;
}
*providers_length = available_count;
*out_ptr = out;
};
} // namespace

ORT_API_STATUS_IMPL(OrtApis::GetAvailableProviders, _Outptr_ char*** out_ptr,
_In_ int* providers_length) {
API_IMPL_BEGIN
const auto& available_providers = GetAvailableExecutionProviderNames();
const size_t available_count = available_providers.size();

if (available_count == 0) {
out_ptr = nullptr;
*providers_length = 0;
return OrtApis::CreateStatus(ORT_FAIL, "Invalid build with no providers available");
}

size_t output_len = 0;
for (const auto& p : available_providers) {
output_len += p.size() + 1;
}

// We allocate and construct the buffer in char* to hold all the string pointers
// followed by the actual string data. We allocate in terms of char* to make it convinient and avoid casts.
const size_t ptrs_num = (sizeof(char*) * available_count + output_len + (sizeof(char*) - 1)) / sizeof(char*);
auto total_buffer = std::make_unique<char*[]>(ptrs_num);
ProviderBuffer provider_buffer(total_buffer.get(), available_count);

for (size_t p_index = 0; p_index < available_count; p_index++) {
provider_buffer.Append(available_providers[p_index], p_index);
}

*providers_length = narrow<int>(available_count);
*out_ptr = total_buffer.release();
API_IMPL_END
return nullptr;
}

// TODO: we don't really need the second parameter
// This is a cleanup API, it should never return any failure
// so any no-throw code can rely on it.
ORT_API_STATUS_IMPL(OrtApis::ReleaseAvailableProviders, _In_ char** ptr,
_In_ int providers_length) {
_In_ int /* providers_length */) {
API_IMPL_BEGIN
if (ptr) {
for (int i = 0; i < providers_length; i++) {
GSL_SUPPRESS(r .11)
delete[] ptr[i];
}
GSL_SUPPRESS(r .11)
delete[] ptr;
}
// take possession of the memory and deallocate it
std::unique_ptr<char*[]> g(ptr);
API_IMPL_END
return NULL;
return nullptr;
}

ORT_API_STATUS_IMPL(OrtApis::GetExecutionProviderApi,
Expand Down

0 comments on commit 61e7636

Please sign in to comment.