Skip to content

Commit

Permalink
Use DynOrtTensor for model output tensors
Browse files Browse the repository at this point in the history
Outputs aren't all the same type for a single model, so returning model outputs as `DynOrtTensor` allows a different element type to be extracted for each output tensor.
  • Loading branch information
marshallpierce committed Feb 26, 2021
1 parent 0d34d23 commit 555bec7
Show file tree
Hide file tree
Showing 13 changed files with 295 additions and 79 deletions.
9 changes: 7 additions & 2 deletions onnxruntime/examples/issue22.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@ fn main() {
let input_ids = Array2::<i64>::from_shape_vec((1, 3), vec![1, 2, 3]).unwrap();
let attention_mask = Array2::<i64>::from_shape_vec((1, 3), vec![1, 1, 1]).unwrap();

let outputs: Vec<OrtOwnedTensor<f32, _>> =
session.run(vec![input_ids, attention_mask]).unwrap();
let outputs: Vec<OrtOwnedTensor<f32, _>> = session
.run(vec![input_ids, attention_mask])
.unwrap()
.into_iter()
.map(|dyn_tensor| dyn_tensor.try_extract())
.collect::<Result<_, _>>()
.unwrap();
print!("outputs: {:#?}", outputs);
}
13 changes: 8 additions & 5 deletions onnxruntime/examples/sample.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#![forbid(unsafe_code)]

use onnxruntime::{
environment::Environment, ndarray::Array, tensor::OrtOwnedTensor, GraphOptimizationLevel,
LoggingLevel,
environment::Environment,
ndarray::Array,
tensor::{DynOrtTensor, OrtOwnedTensor},
GraphOptimizationLevel, LoggingLevel,
};
use tracing::Level;
use tracing_subscriber::FmtSubscriber;
Expand Down Expand Up @@ -61,11 +63,12 @@ fn run() -> Result<(), Error> {
.unwrap();
let input_tensor_values = vec![array];

let outputs: Vec<OrtOwnedTensor<f32, _>> = session.run(input_tensor_values)?;
let outputs: Vec<DynOrtTensor<_>> = session.run(input_tensor_values)?;

assert_eq!(outputs[0].shape(), output0_shape.as_slice());
let output: OrtOwnedTensor<f32, _> = outputs[0].try_extract().unwrap();
assert_eq!(output.shape(), output0_shape.as_slice());
for i in 0..5 {
println!("Score for class [{}] = {}", i, outputs[0][[0, i, 0, 0]]);
println!("Score for class [{}] = {}", i, output[[0, i, 0, 0]]);
}

Ok(())
Expand Down
30 changes: 18 additions & 12 deletions onnxruntime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,10 @@ to download.
//! let array = ndarray::Array::linspace(0.0_f32, 1.0, 100);
//! // Multiple inputs and outputs are possible
//! let input_tensor = vec![array];
//! let outputs: Vec<OrtOwnedTensor<f32,_>> = session.run(input_tensor)?;
//! let outputs: Vec<OrtOwnedTensor<f32, _>> = session.run(input_tensor)?
//! .into_iter()
//! .map(|dyn_tensor| dyn_tensor.try_extract())
//! .collect::<Result<_, _>>()?;
//! # Ok(())
//! # }
//! ```
Expand All @@ -115,7 +118,10 @@ to download.
//! See the [`sample.rs`](https://github.com/nbigaouette/onnxruntime-rs/blob/master/onnxruntime/examples/sample.rs)
//! example for more details.
use std::sync::{atomic::AtomicPtr, Arc, Mutex};
use std::{
ffi, ptr,
sync::{atomic::AtomicPtr, Arc, Mutex},
};

use lazy_static::lazy_static;

Expand All @@ -142,7 +148,7 @@ lazy_static! {
// } as *mut sys::OrtApi)));
static ref G_ORT_API: Arc<Mutex<AtomicPtr<sys::OrtApi>>> = {
let base: *const sys::OrtApiBase = unsafe { sys::OrtGetApiBase() };
assert_ne!(base, std::ptr::null());
assert_ne!(base, ptr::null());
let get_api: unsafe extern "C" fn(u32) -> *const onnxruntime_sys::OrtApi =
unsafe { (*base).GetApi.unwrap() };
let api: *const sys::OrtApi = unsafe { get_api(sys::ORT_API_VERSION) };
Expand All @@ -157,13 +163,13 @@ fn g_ort() -> sys::OrtApi {
let api_ref_mut: &mut *mut sys::OrtApi = api_ref.get_mut();
let api_ptr_mut: *mut sys::OrtApi = *api_ref_mut;

assert_ne!(api_ptr_mut, std::ptr::null_mut());
assert_ne!(api_ptr_mut, ptr::null_mut());

unsafe { *api_ptr_mut }
}

fn char_p_to_string(raw: *const i8) -> Result<String> {
let c_string = unsafe { std::ffi::CStr::from_ptr(raw as *mut i8).to_owned() };
let c_string = unsafe { ffi::CStr::from_ptr(raw as *mut i8).to_owned() };

match c_string.into_string() {
Ok(string) => Ok(string),
Expand All @@ -176,7 +182,7 @@ mod onnxruntime {
//! Module containing a custom logger, used to catch the runtime's own logging and send it
//! to Rust's tracing logging instead.
use std::ffi::CStr;
use std::{ffi, ffi::CStr, ptr};
use tracing::{debug, error, info, span, trace, warn, Level};

use onnxruntime_sys as sys;
Expand Down Expand Up @@ -212,7 +218,7 @@ mod onnxruntime {

/// Callback from C that will handle the logging, forwarding the runtime's logs to the tracing crate.
pub(crate) extern "C" fn custom_logger(
_params: *mut std::ffi::c_void,
_params: *mut ffi::c_void,
severity: sys::OrtLoggingLevel,
category: *const i8,
logid: *const i8,
Expand All @@ -227,16 +233,16 @@ mod onnxruntime {
sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL => Level::ERROR,
};

assert_ne!(category, std::ptr::null());
assert_ne!(category, ptr::null());
let category = unsafe { CStr::from_ptr(category) };
assert_ne!(code_location, std::ptr::null());
assert_ne!(code_location, ptr::null());
let code_location = unsafe { CStr::from_ptr(code_location) }
.to_str()
.unwrap_or("unknown");
assert_ne!(message, std::ptr::null());
assert_ne!(message, ptr::null());
let message = unsafe { CStr::from_ptr(message) };

assert_ne!(logid, std::ptr::null());
assert_ne!(logid, ptr::null());
let logid = unsafe { CStr::from_ptr(logid) };

// Parse the code location
Expand Down Expand Up @@ -376,7 +382,7 @@ mod test {

#[test]
fn test_char_p_to_string() {
let s = std::ffi::CString::new("foo").unwrap();
let s = ffi::CString::new("foo").unwrap();
let ptr = s.as_c_str().as_ptr();
assert_eq!("foo", char_p_to_string(ptr).unwrap());
}
Expand Down
42 changes: 16 additions & 26 deletions onnxruntime/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@ use crate::{
error::{call_ort, status_to_result, NonMatchingDimensionsError, OrtError, Result},
g_ort,
memory::MemoryInfo,
tensor::{
ort_owned_tensor::OrtOwnedTensor, OrtTensor, TensorDataToType, TensorElementDataType,
TypeToTensorElementDataType,
},
tensor::{DynOrtTensor, OrtTensor, TensorElementDataType, TypeToTensorElementDataType},
AllocatorType, GraphOptimizationLevel, MemType,
};

Expand Down Expand Up @@ -364,13 +361,12 @@ impl<'a> Session<'a> {
///
/// Note that ONNX models can have multiple inputs; a `Vec<_>` is thus
/// used for the input data here.
pub fn run<'s, 't, 'm, TIn, TOut, D>(
pub fn run<'s, 't, 'm, TIn, D>(
&'s mut self,
input_arrays: Vec<Array<TIn, D>>,
) -> Result<Vec<OrtOwnedTensor<'t, 'm, TOut, ndarray::IxDyn>>>
) -> Result<Vec<DynOrtTensor<'m, ndarray::IxDyn>>>
where
TIn: TypeToTensorElementDataType + Debug + Clone,
TOut: TensorDataToType,
D: ndarray::Dimension,
'm: 't, // 'm outlives 't (memory info outlives tensor)
's: 'm, // 's outlives 'm (session outlives memory info)
Expand Down Expand Up @@ -404,7 +400,7 @@ impl<'a> Session<'a> {
.map(|n| n.as_ptr() as *const i8)
.collect();

let mut output_tensor_extractors_ptrs: Vec<*mut sys::OrtValue> =
let mut output_tensor_ptrs: Vec<*mut sys::OrtValue> =
vec![std::ptr::null_mut(); self.outputs.len()];

// The C API expects pointers for the arrays (pointers to C-arrays)
Expand All @@ -430,38 +426,32 @@ impl<'a> Session<'a> {
input_ort_values.len() as u64, // C API expects a u64, not isize
output_names_ptr.as_ptr(),
output_names_ptr.len() as u64, // C API expects a u64, not isize
output_tensor_extractors_ptrs.as_mut_ptr(),
output_tensor_ptrs.as_mut_ptr(),
)
};
status_to_result(status).map_err(OrtError::Run)?;

let memory_info_ref = &self.memory_info;
let outputs: Result<Vec<OrtOwnedTensor<TOut, ndarray::Dim<ndarray::IxDynImpl>>>> =
output_tensor_extractors_ptrs
let outputs: Result<Vec<DynOrtTensor<ndarray::Dim<ndarray::IxDynImpl>>>> =
output_tensor_ptrs
.into_iter()
.map(|tensor_ptr| {
let dims = unsafe {
let (dims, data_type) = unsafe {
call_with_tensor_info(tensor_ptr, |tensor_info_ptr| {
get_tensor_dimensions(tensor_info_ptr)
.map(|dims| dims.iter().map(|&n| n as usize).collect::<Vec<_>>())
.and_then(|dims| {
extract_data_type(tensor_info_ptr)
.map(|data_type| (dims, data_type))
})
})
}?;

// Note: Both tensor and array will point to the same data, nothing is copied.
// As such, there is no need to free the pointer used to create the ArrayView.
assert_ne!(tensor_ptr, std::ptr::null_mut());

let mut is_tensor = 0;
unsafe { call_ort(|ort| ort.IsTensor.unwrap()(tensor_ptr, &mut is_tensor)) }
.map_err(OrtError::IsTensor)?;
assert_eq!(is_tensor, 1);

let array_view = TOut::extract_array(ndarray::IxDyn(&dims), tensor_ptr)?;

Ok(OrtOwnedTensor::new(
Ok(DynOrtTensor::new(
tensor_ptr,
array_view,
&memory_info_ref,
memory_info_ref,
ndarray::IxDyn(&dims),
data_type,
))
})
.collect();
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub mod ndarray_tensor;
pub mod ort_owned_tensor;
pub mod ort_tensor;

pub use ort_owned_tensor::OrtOwnedTensor;
pub use ort_owned_tensor::{DynOrtTensor, OrtOwnedTensor};
pub use ort_tensor::OrtTensor;

use crate::{OrtError, Result};
Expand Down
Loading

0 comments on commit 555bec7

Please sign in to comment.