Skip to content

Commit

Permalink
Use ep.context_file_path to get base path when creating session from …
Browse files Browse the repository at this point in the history
…memory
  • Loading branch information
javier-intel committed Feb 5, 2025
1 parent 48f060b commit fb9ea15
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 12 deletions.
20 changes: 14 additions & 6 deletions onnxruntime/core/providers/openvino/backend_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ BackendManager::BackendManager(SessionContext& session_context,
ptr_stream_t model_stream;
std::unique_ptr<onnx::ModelProto> model_proto;
if (subgraph_context_.is_ep_ctx_graph) {
model_stream = ep_ctx_handle_.GetModelBlobStream(subgraph);
model_stream = ep_ctx_handle_.GetModelBlobStream(session_context_.so_context_file_path, subgraph);
} else {
model_proto = GetModelProtoFromFusedNode(fused_node, subgraph, logger);
}
Expand Down Expand Up @@ -214,21 +214,29 @@ Status BackendManager::ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphVie
// If not embed_mode, dump the blob here and only pass on the path to the blob
std::string model_blob_str;
auto compiled_model = concrete_backend_->GetOVCompiledModel();
if (session_context_.so_context_embed_mode) {
// Internal blob
if (session_context_.so_context_embed_mode) { // Internal blob
std::ostringstream model_blob_stream;
compiled_model.export_model(model_blob_stream);
model_blob_str = std::move(model_blob_stream).str();
if (model_blob_str.empty()) {
ORT_THROW("Model blob stream is empty after exporting the compiled model.");
}
} else {
// External blob
} else { // External blob
// Build name by combining EpCtx model name (if available) and subgraph name. Model
// name is not available in when creating a session from memory
auto name = session_context_.so_context_file_path.stem().string();
if (!name.empty() && !graph_body_viewer.ModelPath().empty()) {
name = graph_body_viewer.ModelPath().stem().string();
}
if (!name.empty()) {
name += "_";
}
name += subgraph_context_.subgraph_name;

std::filesystem::path blob_filename = session_context_.so_context_file_path;
if (blob_filename.empty()) {
blob_filename = session_context_.onnx_model_path_name;
}
const auto name = graph_body_viewer.ModelPath().stem().string() + "_" + subgraph_context_.subgraph_name;
blob_filename = blob_filename.parent_path() / name;
blob_filename.replace_extension("blob");
std::ofstream blob_file(blob_filename,
Expand Down
8 changes: 6 additions & 2 deletions onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ Status EPCtxHandler::AddOVEPCtxNodeToGraph(const GraphViewer& graph_viewer,
return Status::OK();
}

std::unique_ptr<std::istream> EPCtxHandler::GetModelBlobStream(const GraphViewer& graph_viewer) const {
std::unique_ptr<std::istream> EPCtxHandler::GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const {
auto first_index = *graph_viewer.GetNodesInTopologicalOrder().begin();
auto node = graph_viewer.GetNode(first_index);
ORT_ENFORCE(node != nullptr);
Expand All @@ -115,7 +115,11 @@ std::unique_ptr<std::istream> EPCtxHandler::GetModelBlobStream(const GraphViewer
if (embed_mode) {
result.reset((std::istream*)new std::istringstream(ep_cache_context));
} else {
const auto& blob_filepath = graph_viewer.ModelPath().parent_path() / ep_cache_context;
auto blob_filepath = so_context_file_path;
if (blob_filepath.empty() && !graph_viewer.ModelPath().empty()) {
blob_filepath = graph_viewer.ModelPath();
}
blob_filepath = blob_filepath.parent_path() / ep_cache_context;
ORT_ENFORCE(std::filesystem::exists(blob_filepath), "Blob file not found: ", blob_filepath.string());
result.reset((std::istream*)new std::ifstream(blob_filepath, std::ios_base::binary | std::ios_base::in));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class EPCtxHandler {
const std::string& graph_name,
const bool embed_mode,
std::string&& model_blob_str) const;
std::unique_ptr<std::istream> GetModelBlobStream(const GraphViewer& graph_viewer) const;
std::unique_ptr<std::istream> GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const;
InlinedVector<const Node*> GetEPCtxNodes() const;

private:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ void ParseConfigOptions(ProviderInfo& pi, const ConfigOptions& config_options) {
pi.so_context_file_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
}

void* ParseUint64(const ProviderOptions& provider_options, [[maybe_unused]] std::string option_name) {
if (provider_options.contains("context")) {
uint64_t number = std::strtoull(provider_options.at("context").data(), nullptr, 16);
void* ParseUint64(const ProviderOptions& provider_options, std::string option_name) {
if (provider_options.contains(option_name)) {
uint64_t number = std::strtoull(provider_options.at(option_name).data(), nullptr, 16);
return reinterpret_cast<void*>(number);
} else {
return nullptr;
Expand Down

0 comments on commit fb9ea15

Please sign in to comment.