Skip to content

Commit

Permalink
[rust] Change inputs and outputs to session.run
Browse files Browse the repository at this point in the history
Changes the run function to take
impl RefMut<Box<dyn ConstructTensor>>
where ConstructTensor is implemented for items that
can construct a OrtInputTensor<T> for a particular T.
  • Loading branch information
boydjohnson committed Oct 3, 2022
1 parent 80144cc commit b9efaa2
Show file tree
Hide file tree
Showing 11 changed files with 567 additions and 359 deletions.
10 changes: 6 additions & 4 deletions rust/onnxruntime/examples/issue22.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
//! https://drive.google.com/file/d/1FmL-Wpm06V-8wgRqvV3Skey_X98Ue4D_/view?usp=sharing
use ndarray::Array2;
use onnxruntime::{environment::Environment, tensor::OrtOwnedTensor, GraphOptimizationLevel};
use onnxruntime::{environment::Environment, GraphOptimizationLevel};
use tracing::Level;
use tracing_subscriber::FmtSubscriber;

Expand Down Expand Up @@ -34,7 +34,9 @@ fn main() {
let input_ids = Array2::<i64>::from_shape_vec((1, 3), vec![1, 2, 3]).unwrap();
let attention_mask = Array2::<i64>::from_shape_vec((1, 3), vec![1, 1, 1]).unwrap();

let outputs: Vec<OrtOwnedTensor<f32, _>> =
session.run(vec![input_ids, attention_mask]).unwrap();
print!("outputs: {:#?}", outputs);
let inputs = vec![input_ids.into(), attention_mask.into()];

let outputs = session.run(inputs).unwrap();

print!("outputs: {:#?}", outputs[0].float_array().unwrap());
}
15 changes: 7 additions & 8 deletions rust/onnxruntime/examples/sample.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
#![forbid(unsafe_code)]

use onnxruntime::{
environment::Environment, ndarray::Array, tensor::OrtOwnedTensor, GraphOptimizationLevel,
LoggingLevel,
};
use onnxruntime::{environment::Environment, ndarray::Array, GraphOptimizationLevel, LoggingLevel};
use tracing::Level;
use tracing_subscriber::FmtSubscriber;

Expand Down Expand Up @@ -62,13 +59,15 @@ fn run() -> Result<(), Error> {
let array = Array::linspace(0.0_f32, 1.0, n as usize)
.into_shape(input0_shape)
.unwrap();
let input_tensor_values = vec![array];
let input_tensor_values = vec![array.into()];

let outputs: Vec<OrtOwnedTensor<f32, _>> = session.run(input_tensor_values)?;
let outputs = session.run(input_tensor_values)?;

assert_eq!(outputs[0].shape(), output0_shape.as_slice());
let output = outputs[0].float_array().unwrap();

assert_eq!(output.shape(), output0_shape.as_slice());
for i in 0..5 {
println!("Score for class [{}] = {}", i, outputs[0][[0, i, 0, 0]]);
println!("Score for class [{}] = {}", i, output[[0, i, 0, 0]]);
}

Ok(())
Expand Down
6 changes: 3 additions & 3 deletions rust/onnxruntime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ to download.
//!
//! ```no_run
//! # use std::error::Error;
//! # use onnxruntime::{environment::Environment, LoggingLevel, GraphOptimizationLevel, tensor::OrtOwnedTensor};
//! # use onnxruntime::{environment::Environment, LoggingLevel, GraphOptimizationLevel, tensor::construct::ConstructTensor};
//! # fn main() -> Result<(), Box<dyn Error>> {
//! # let environment = Environment::builder()
//! # .with_name("test")
Expand All @@ -103,8 +103,8 @@ to download.
//! # .with_model_from_file("squeezenet.onnx")?;
//! let array = ndarray::Array::linspace(0.0_f32, 1.0, 100);
//! // Multiple inputs and outputs are possible
//! let input_tensor = vec![array];
//! let outputs: Vec<OrtOwnedTensor<f32,_>> = session.run(input_tensor)?;
//! let input_tensor = vec![array.into()];
//! let outputs = session.run(input_tensor)?;
//! # Ok(())
//! # }
//! ```
Expand Down
2 changes: 1 addition & 1 deletion rust/onnxruntime/src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
};

#[derive(Debug)]
pub(crate) struct MemoryInfo {
pub struct MemoryInfo {
pub ptr: *mut sys::OrtMemoryInfo,
}

Expand Down
199 changes: 99 additions & 100 deletions rust/onnxruntime/src/session.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Module containing session types
use std::{ffi::CString, fmt::Debug, path::Path};
use std::{convert::TryFrom, ffi::CString, fmt::Debug, path::Path};

#[cfg(not(target_family = "windows"))]
use std::os::unix::ffi::OsStrExt;
Expand All @@ -10,7 +10,6 @@ use std::os::windows::ffi::OsStrExt;
#[cfg(feature = "model-fetching")]
use std::env;

use ndarray::Array;
use tracing::{debug, error};

use onnxruntime_sys as sys;
Expand All @@ -25,11 +24,11 @@ use crate::{
g_ort,
memory::MemoryInfo,
tensor::{
ort_owned_tensor::{OrtOwnedTensor, OrtOwnedTensorExtractor},
OrtTensor,
construct::ConstructTensor,
ort_output_tensor::{OrtOutput, OrtOwnedTensorExtractor},
OrtOutputTensor,
},
AllocatorType, GraphOptimizationLevel, MemType, TensorElementDataType,
TypeToTensorElementDataType,
};

#[cfg(feature = "model-fetching")]
Expand Down Expand Up @@ -375,28 +374,12 @@ impl Session {
///
/// Note that ONNX models can have multiple inputs; a `Vec<_>` is thus
/// used for the input data here.
pub fn run<'s, 't, 'm, TIn, TOut, D>(
&'s self,
input_arrays: Vec<Array<TIn, D>>,
) -> Result<Vec<OrtOwnedTensor<'t, 'm, TOut, ndarray::IxDyn>>>
where
TIn: TypeToTensorElementDataType + Debug + Clone,
TOut: TypeToTensorElementDataType + Debug + Clone,
D: ndarray::Dimension,
'm: 't, // 'm outlives 't (memory info outlives tensor)
's: 'm, // 's outlives 'm (session outlives memory info)
{
self.validate_input_shapes(&input_arrays)?;

// Build arguments to Run()

let input_names_ptr: Vec<*const i8> = self
.inputs
.iter()
.map(|input| input.name.clone())
.map(|n| CString::new(n).unwrap())
.map(|n| n.into_raw() as *const i8)
.collect();
pub fn run<'input, 'output>(
&'output self,
mut input_arrays: impl AsMut<[Box<dyn ConstructTensor + 'input>]> + 'input,
) -> Result<Vec<OrtOutput<'output>>> {
let mut output_tensor_extractors_ptrs: Vec<*mut sys::OrtValue> =
vec![std::ptr::null_mut(); self.outputs.len()];

let output_names_cstring: Vec<CString> = self
.outputs
Expand All @@ -409,59 +392,75 @@ impl Session {
.map(|n| n.as_ptr().cast::<i8>())
.collect();

let mut output_tensor_extractors_ptrs: Vec<*mut sys::OrtValue> =
vec![std::ptr::null_mut(); self.outputs.len()];

// The C API expects pointers for the arrays (pointers to C-arrays)
let input_ort_tensors: Vec<OrtTensor<TIn, D>> = input_arrays
.into_iter()
.map(|input_array| {
OrtTensor::from_array(&self.memory_info, self.allocator_ptr, input_array)
})
.collect::<Result<Vec<OrtTensor<TIn, D>>>>()?;
let input_ort_values: Vec<*const sys::OrtValue> = input_ort_tensors
let input_names_ptr: Vec<*const i8> = self
.inputs
.iter()
.map(|input_array_ort| input_array_ort.c_ptr as *const sys::OrtValue)
.map(|input| input.name.clone())
.map(|n| CString::new(n).unwrap())
.map(|n| n.into_raw() as *const i8)
.collect();

let run_options_ptr: *const sys::OrtRunOptions = std::ptr::null();
{
let memory_info = &self.memory_info;

let status = unsafe {
g_ort().Run.unwrap()(
self.session_ptr,
run_options_ptr,
input_names_ptr.as_ptr(),
input_ort_values.as_ptr(),
input_ort_values.len(),
output_names_ptr.as_ptr(),
output_names_ptr.len(),
output_tensor_extractors_ptrs.as_mut_ptr(),
)
};
status_to_result(status).map_err(OrtError::Run)?;
let allocator = self.allocator_ptr;

let memory_info_ref = &self.memory_info;
let outputs: Result<Vec<OrtOwnedTensor<TOut, ndarray::Dim<ndarray::IxDynImpl>>>> =
output_tensor_extractors_ptrs
let arr = input_arrays.as_mut();

let input_tensors = arr
.into_iter()
.map(|ptr| {
let mut tensor_info_ptr: *mut sys::OrtTensorTypeAndShapeInfo =
std::ptr::null_mut();
let status = unsafe {
g_ort().GetTensorTypeAndShape.unwrap()(ptr, &mut tensor_info_ptr as _)
};
status_to_result(status).map_err(OrtError::GetTensorTypeAndShape)?;
let dims = unsafe { get_tensor_dimensions(tensor_info_ptr) };
unsafe { g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info_ptr) };
let dims: Vec<_> = dims?.iter().map(|&n| n as usize).collect();

let mut output_tensor_extractor =
OrtOwnedTensorExtractor::new(memory_info_ref, ndarray::IxDyn(&dims));
output_tensor_extractor.tensor_ptr = ptr;
output_tensor_extractor.extract::<TOut>()
})
.map(|v| v.construct(memory_info, allocator))
.collect::<Result<Vec<_>>>()?;

let input_arrays_shapes: Vec<Vec<usize>> =
input_tensors.iter().map(|v| v.shape().to_vec()).collect();

self.validate_input_shapes(&input_arrays_shapes)?;

// Build arguments to Run()

let input_ort_values: Vec<*const sys::OrtValue> = input_tensors
.iter()
.map(|input_array_ort| input_array_ort.ptr() as *const sys::OrtValue)
.collect();

let run_options_ptr: *const sys::OrtRunOptions = std::ptr::null();

let status = unsafe {
g_ort().Run.unwrap()(
self.session_ptr,
run_options_ptr,
input_names_ptr.as_ptr(),
input_ort_values.as_ptr(),
input_ort_values.len(),
output_names_ptr.as_ptr(),
output_names_ptr.len(),
output_tensor_extractors_ptrs.as_mut_ptr(),
)
};
status_to_result(status).map_err(OrtError::Run)?;
}

let outputs: Result<Vec<OrtOutputTensor>> = output_tensor_extractors_ptrs
.into_iter()
.map(|ptr| {
let mut tensor_info_ptr: *mut sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
let status = unsafe {
g_ort().GetTensorTypeAndShape.unwrap()(ptr, &mut tensor_info_ptr as _)
};
status_to_result(status).map_err(OrtError::GetTensorTypeAndShape)?;
let dims = unsafe { get_tensor_dimensions(tensor_info_ptr) };

unsafe { g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info_ptr) };
let dims: Vec<_> = dims?.iter().map(|&n| n as usize).collect();

let mut output_tensor_extractor = OrtOwnedTensorExtractor::new(dims);
output_tensor_extractor.tensor_ptr = ptr;

output_tensor_extractor.extract()
})
.collect();

// Reconvert to CString so drop impl is called and memory is freed
let cstrings: Result<Vec<CString>> = input_names_ptr
.into_iter()
Expand All @@ -472,33 +471,29 @@ impl Session {
.collect();
cstrings?;

outputs
outputs?
.into_iter()
.map(|v| OrtOutput::try_from(v))
.collect()
}

fn validate_input_shapes<TIn, D>(&self, input_arrays: &[Array<TIn, D>]) -> Result<()>
where
TIn: TypeToTensorElementDataType + Debug + Clone,
D: ndarray::Dimension,
{
fn validate_input_shapes(&self, input_array_shapes: &[Vec<usize>]) -> Result<()> {
// ******************************************************************
// FIXME: Properly handle errors here
// Make sure all dimensions match (except dynamic ones)

// Verify length of inputs
if input_arrays.len() != self.inputs.len() {
if input_array_shapes.len() != self.inputs.len() {
error!(
"Non-matching number of inputs: {} (inference) vs {} (model)",
input_arrays.len(),
input_array_shapes.len(),
self.inputs.len()
);
return Err(OrtError::NonMatchingDimensions(
NonMatchingDimensionsError::InputsCount {
inference_input_count: 0,
model_input_count: 0,
inference_input: input_arrays
.iter()
.map(|input_array| input_array.shape().to_vec())
.collect(),
inference_input: input_array_shapes.to_vec(),
model_input: self
.inputs
.iter()
Expand All @@ -509,20 +504,20 @@ impl Session {
}

// Verify length of each individual inputs
let inputs_different_length = input_arrays
let inputs_different_length = input_array_shapes
.iter()
.zip(self.inputs.iter())
.any(|(l, r)| l.shape().len() != r.dimensions.len());
.any(|(l, r)| l.len() != r.dimensions.len());
if inputs_different_length {
error!(
"Different input lengths: {:?} vs {:?}",
self.inputs, input_arrays
self.inputs, input_array_shapes
);
return Err(OrtError::NonMatchingDimensions(
NonMatchingDimensionsError::InputsLength {
inference_input: input_arrays
inference_input: input_array_shapes
.iter()
.map(|input_array| input_array.shape().to_vec())
.map(|input_array| input_array.to_vec())
.collect(),
model_input: self
.inputs
Expand All @@ -534,24 +529,28 @@ impl Session {
}

// Verify shape of each individual inputs
let inputs_different_shape = input_arrays.iter().zip(self.inputs.iter()).any(|(l, r)| {
let l_shape = l.shape();
let r_shape = r.dimensions.as_slice();
l_shape.iter().zip(r_shape.iter()).any(|(l2, r2)| match r2 {
Some(r3) => *r3 as usize != *l2,
None => false, // None means dynamic size; in that case shape always match
})
});
let inputs_different_shape =
input_array_shapes
.iter()
.zip(self.inputs.iter())
.any(|(l, r)| {
let l_shape = l;
let r_shape = r.dimensions.as_slice();
l_shape.iter().zip(r_shape.iter()).any(|(l2, r2)| match r2 {
Some(r3) => *r3 as usize != *l2,
None => false, // None means dynamic size; in that case shape always match
})
});
if inputs_different_shape {
error!(
"Different input lengths: {:?} vs {:?}",
self.inputs, input_arrays
self.inputs, input_array_shapes
);
return Err(OrtError::NonMatchingDimensions(
NonMatchingDimensionsError::InputsLength {
inference_input: input_arrays
inference_input: input_array_shapes
.iter()
.map(|input_array| input_array.shape().to_vec())
.map(|input_array| input_array.to_vec())
.collect(),
model_input: self
.inputs
Expand Down
9 changes: 5 additions & 4 deletions rust/onnxruntime/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
//! will be returned by the method which can be derefed into its internal
//! [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html).
pub mod construct;
pub mod ndarray_tensor;
pub mod ort_owned_tensor;
pub mod ort_tensor;
pub mod ort_input_tensor;
pub mod ort_output_tensor;

pub use ort_owned_tensor::OrtOwnedTensor;
pub use ort_tensor::OrtTensor;
pub use ort_output_tensor::OrtOutputTensor;
pub use ort_output_tensor::WithOutputTensor;
Loading

0 comments on commit b9efaa2

Please sign in to comment.