Skip to content

Commit

Permalink
[rust] Change inputs and outputs to session.run
Browse files Browse the repository at this point in the history
Adds WithOrtOwnedTensor<T> and WithOrtTensor<T> that
tie the drop of T to the underlying tensor. Changes
the run function to take Vec<impl ConstructTensor>
where ConstructTensor is implemented for items that
can construct a WithOrtTensor<T> for a particular T.
  • Loading branch information
boydjohnson committed Sep 29, 2022
1 parent 03333c6 commit 1c01875
Show file tree
Hide file tree
Showing 11 changed files with 554 additions and 350 deletions.
16 changes: 10 additions & 6 deletions rust/onnxruntime/examples/issue22.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
//! https://drive.google.com/file/d/1FmL-Wpm06V-8wgRqvV3Skey_X98Ue4D_/view?usp=sharing
use ndarray::Array2;
use onnxruntime::{environment::Environment, tensor::OrtOwnedTensor, GraphOptimizationLevel};
use onnxruntime::{
environment::Environment, tensor::construct::ConstructTensor, GraphOptimizationLevel,
};
use tracing::Level;
use tracing_subscriber::FmtSubscriber;

Expand All @@ -31,10 +33,12 @@ fn main() {
println!("{:#?}", session.inputs);
println!("{:#?}", session.outputs);

let input_ids = Array2::<i64>::from_shape_vec((1, 3), vec![1, 2, 3]).unwrap();
let attention_mask = Array2::<i64>::from_shape_vec((1, 3), vec![1, 1, 1]).unwrap();
let mut input_ids = Array2::<i64>::from_shape_vec((1, 3), vec![1, 2, 3]).unwrap();
let mut attention_mask = Array2::<i64>::from_shape_vec((1, 3), vec![1, 1, 1]).unwrap();

let outputs: Vec<OrtOwnedTensor<f32, _>> =
session.run(vec![input_ids, attention_mask]).unwrap();
print!("outputs: {:#?}", outputs);
let inputs: Vec<&mut dyn ConstructTensor> = vec![&mut input_ids, &mut attention_mask];

let outputs = session.run(inputs).unwrap();

print!("outputs: {:#?}", outputs[0].float_array().unwrap());
}
16 changes: 9 additions & 7 deletions rust/onnxruntime/examples/sample.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#![forbid(unsafe_code)]

use onnxruntime::{
environment::Environment, ndarray::Array, tensor::OrtOwnedTensor, GraphOptimizationLevel,
LoggingLevel,
environment::Environment, ndarray::Array, tensor::construct::ConstructTensor,
GraphOptimizationLevel, LoggingLevel,
};
use tracing::Level;
use tracing_subscriber::FmtSubscriber;
Expand Down Expand Up @@ -59,16 +59,18 @@ fn run() -> Result<(), Error> {
.iter()
.map(|d| d.unwrap())
.product();
let array = Array::linspace(0.0_f32, 1.0, n as usize)
let mut array = Array::linspace(0.0_f32, 1.0, n as usize)
.into_shape(input0_shape)
.unwrap();
let input_tensor_values = vec![array];
let input_tensor_values: Vec<&mut dyn ConstructTensor> = vec![&mut array];

let outputs: Vec<OrtOwnedTensor<f32, _>> = session.run(input_tensor_values)?;
let outputs = session.run(input_tensor_values)?;

assert_eq!(outputs[0].shape(), output0_shape.as_slice());
let output = outputs[0].float_array().unwrap();

assert_eq!(output.shape(), output0_shape.as_slice());
for i in 0..5 {
println!("Score for class [{}] = {}", i, outputs[0][[0, i, 0, 0]]);
println!("Score for class [{}] = {}", i, output[[0, i, 0, 0]]);
}

Ok(())
Expand Down
8 changes: 4 additions & 4 deletions rust/onnxruntime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ to download.
//!
//! ```no_run
//! # use std::error::Error;
//! # use onnxruntime::{environment::Environment, LoggingLevel, GraphOptimizationLevel, tensor::OrtOwnedTensor};
//! # use onnxruntime::{environment::Environment, LoggingLevel, GraphOptimizationLevel, tensor::construct::ConstructTensor};
//! # fn main() -> Result<(), Box<dyn Error>> {
//! # let environment = Environment::builder()
//! # .with_name("test")
Expand All @@ -101,10 +101,10 @@ to download.
//! # .with_graph_optimization_level(GraphOptimizationLevel::Basic)?
//! # .with_intra_op_num_threads(1)?
//! # .with_model_from_file("squeezenet.onnx")?;
//! let array = ndarray::Array::linspace(0.0_f32, 1.0, 100);
//! let mut array = ndarray::Array::linspace(0.0_f32, 1.0, 100);
//! // Multiple inputs and outputs are possible
//! let input_tensor = vec![array];
//! let outputs: Vec<OrtOwnedTensor<f32,_>> = session.run(input_tensor)?;
//! let input_tensor: Vec<&mut dyn ConstructTensor> = vec![&mut array];
//! let outputs = session.run(input_tensor)?;
//! # Ok(())
//! # }
//! ```
Expand Down
2 changes: 1 addition & 1 deletion rust/onnxruntime/src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
};

#[derive(Debug)]
pub(crate) struct MemoryInfo {
pub struct MemoryInfo {
pub ptr: *mut sys::OrtMemoryInfo,
}

Expand Down
145 changes: 70 additions & 75 deletions rust/onnxruntime/src/session.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Module containing session types
use std::{ffi::CString, fmt::Debug, path::Path};
use std::{convert::TryFrom, ffi::CString, fmt::Debug, path::Path};

#[cfg(not(target_family = "windows"))]
use std::os::unix::ffi::OsStrExt;
Expand All @@ -10,7 +10,6 @@ use std::os::windows::ffi::OsStrExt;
#[cfg(feature = "model-fetching")]
use std::env;

use ndarray::Array;
use tracing::{debug, error};

use onnxruntime_sys as sys;
Expand All @@ -25,11 +24,11 @@ use crate::{
g_ort,
memory::MemoryInfo,
tensor::{
ort_owned_tensor::{OrtOwnedTensor, OrtOwnedTensorExtractor},
OrtTensor,
construct::ConstructTensor,
ort_output_tensor::{OrtOutput, OrtOwnedTensorExtractor},
OrtOutputTensor,
},
AllocatorType, GraphOptimizationLevel, MemType, TensorElementDataType,
TypeToTensorElementDataType,
};

#[cfg(feature = "model-fetching")]
Expand Down Expand Up @@ -375,18 +374,23 @@ impl Session {
///
/// Note that ONNX models can have multiple inputs; a `Vec<_>` is thus
/// used for the input data here.
pub fn run<'s, 't, 'm, TIn, TOut, D>(
&'s self,
input_arrays: Vec<Array<TIn, D>>,
) -> Result<Vec<OrtOwnedTensor<'t, 'm, TOut, ndarray::IxDyn>>>
where
TIn: TypeToTensorElementDataType + Debug + Clone,
TOut: TypeToTensorElementDataType + Debug + Clone,
D: ndarray::Dimension,
'm: 't, // 'm outlives 't (memory info outlives tensor)
's: 'm, // 's outlives 'm (session outlives memory info)
{
self.validate_input_shapes(&input_arrays)?;
pub fn run<'a>(
&self,
input_arrays: Vec<&mut dyn ConstructTensor>,
) -> Result<Vec<OrtOutput<'a>>> {
let memory_info = &self.memory_info;

let allocator = self.allocator_ptr;

let input_arrays = input_arrays
.into_iter()
.map(|v| v.construct(memory_info, allocator))
.collect::<Result<Vec<_>>>()?;

let input_arrays_shapes: Vec<Vec<usize>> =
input_arrays.iter().map(|v| v.shape().to_vec()).collect();

self.validate_input_shapes(&input_arrays_shapes)?;

// Build arguments to Run()

Expand All @@ -412,16 +416,9 @@ impl Session {
let mut output_tensor_extractors_ptrs: Vec<*mut sys::OrtValue> =
vec![std::ptr::null_mut(); self.outputs.len()];

// The C API expects pointers for the arrays (pointers to C-arrays)
let input_ort_tensors: Vec<OrtTensor<TIn, D>> = input_arrays
.into_iter()
.map(|input_array| {
OrtTensor::from_array(&self.memory_info, self.allocator_ptr, input_array)
})
.collect::<Result<Vec<OrtTensor<TIn, D>>>>()?;
let input_ort_values: Vec<*const sys::OrtValue> = input_ort_tensors
let input_ort_values: Vec<*const sys::OrtValue> = input_arrays
.iter()
.map(|input_array_ort| input_array_ort.c_ptr as *const sys::OrtValue)
.map(|input_array_ort| input_array_ort.ptr() as *const sys::OrtValue)
.collect();

let run_options_ptr: *const sys::OrtRunOptions = std::ptr::null();
Expand All @@ -440,27 +437,25 @@ impl Session {
};
status_to_result(status).map_err(OrtError::Run)?;

let memory_info_ref = &self.memory_info;
let outputs: Result<Vec<OrtOwnedTensor<TOut, ndarray::Dim<ndarray::IxDynImpl>>>> =
output_tensor_extractors_ptrs
.into_iter()
.map(|ptr| {
let mut tensor_info_ptr: *mut sys::OrtTensorTypeAndShapeInfo =
std::ptr::null_mut();
let status = unsafe {
g_ort().GetTensorTypeAndShape.unwrap()(ptr, &mut tensor_info_ptr as _)
};
status_to_result(status).map_err(OrtError::GetTensorTypeAndShape)?;
let dims = unsafe { get_tensor_dimensions(tensor_info_ptr) };
unsafe { g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info_ptr) };
let dims: Vec<_> = dims?.iter().map(|&n| n as usize).collect();

let mut output_tensor_extractor =
OrtOwnedTensorExtractor::new(memory_info_ref, ndarray::IxDyn(&dims));
output_tensor_extractor.tensor_ptr = ptr;
output_tensor_extractor.extract::<TOut>()
})
.collect();
let outputs: Result<Vec<OrtOutputTensor>> = output_tensor_extractors_ptrs
.into_iter()
.map(|ptr| {
let mut tensor_info_ptr: *mut sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
let status = unsafe {
g_ort().GetTensorTypeAndShape.unwrap()(ptr, &mut tensor_info_ptr as _)
};
status_to_result(status).map_err(OrtError::GetTensorTypeAndShape)?;
let dims = unsafe { get_tensor_dimensions(tensor_info_ptr) };

unsafe { g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info_ptr) };
let dims: Vec<_> = dims?.iter().map(|&n| n as usize).collect();

let mut output_tensor_extractor = OrtOwnedTensorExtractor::new(dims);
output_tensor_extractor.tensor_ptr = ptr;

output_tensor_extractor.extract()
})
.collect();

// Reconvert to CString so drop impl is called and memory is freed
let cstrings: Result<Vec<CString>> = input_names_ptr
Expand All @@ -472,33 +467,29 @@ impl Session {
.collect();
cstrings?;

outputs
outputs?
.into_iter()
.map(|v| OrtOutput::try_from(v))
.collect()
}

fn validate_input_shapes<TIn, D>(&self, input_arrays: &[Array<TIn, D>]) -> Result<()>
where
TIn: TypeToTensorElementDataType + Debug + Clone,
D: ndarray::Dimension,
{
fn validate_input_shapes(&self, input_array_shapes: &[Vec<usize>]) -> Result<()> {
// ******************************************************************
// FIXME: Properly handle errors here
// Make sure all dimensions match (except dynamic ones)

// Verify length of inputs
if input_arrays.len() != self.inputs.len() {
if input_array_shapes.len() != self.inputs.len() {
error!(
"Non-matching number of inputs: {} (inference) vs {} (model)",
input_arrays.len(),
input_array_shapes.len(),
self.inputs.len()
);
return Err(OrtError::NonMatchingDimensions(
NonMatchingDimensionsError::InputsCount {
inference_input_count: 0,
model_input_count: 0,
inference_input: input_arrays
.iter()
.map(|input_array| input_array.shape().to_vec())
.collect(),
inference_input: input_array_shapes.to_vec(),
model_input: self
.inputs
.iter()
Expand All @@ -509,20 +500,20 @@ impl Session {
}

// Verify the length of each individual input
let inputs_different_length = input_arrays
let inputs_different_length = input_array_shapes
.iter()
.zip(self.inputs.iter())
.any(|(l, r)| l.shape().len() != r.dimensions.len());
.any(|(l, r)| l.len() != r.dimensions.len());
if inputs_different_length {
error!(
"Different input lengths: {:?} vs {:?}",
self.inputs, input_arrays
self.inputs, input_array_shapes
);
return Err(OrtError::NonMatchingDimensions(
NonMatchingDimensionsError::InputsLength {
inference_input: input_arrays
inference_input: input_array_shapes
.iter()
.map(|input_array| input_array.shape().to_vec())
.map(|input_array| input_array.to_vec())
.collect(),
model_input: self
.inputs
Expand All @@ -534,24 +525,28 @@ impl Session {
}

// Verify the shape of each individual input
let inputs_different_shape = input_arrays.iter().zip(self.inputs.iter()).any(|(l, r)| {
let l_shape = l.shape();
let r_shape = r.dimensions.as_slice();
l_shape.iter().zip(r_shape.iter()).any(|(l2, r2)| match r2 {
Some(r3) => *r3 as usize != *l2,
None => false, // None means dynamic size; in that case shape always match
})
});
let inputs_different_shape =
input_array_shapes
.iter()
.zip(self.inputs.iter())
.any(|(l, r)| {
let l_shape = l;
let r_shape = r.dimensions.as_slice();
l_shape.iter().zip(r_shape.iter()).any(|(l2, r2)| match r2 {
Some(r3) => *r3 as usize != *l2,
None => false, // None means dynamic size; in that case shape always match
})
});
if inputs_different_shape {
error!(
"Different input lengths: {:?} vs {:?}",
self.inputs, input_arrays
self.inputs, input_array_shapes
);
return Err(OrtError::NonMatchingDimensions(
NonMatchingDimensionsError::InputsLength {
inference_input: input_arrays
inference_input: input_array_shapes
.iter()
.map(|input_array| input_array.shape().to_vec())
.map(|input_array| input_array.to_vec())
.collect(),
model_input: self
.inputs
Expand Down
9 changes: 5 additions & 4 deletions rust/onnxruntime/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
//! will be returned by the method which can be derefed into its internal
//! [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html).
pub mod construct;
pub mod ndarray_tensor;
pub mod ort_owned_tensor;
pub mod ort_tensor;
pub mod ort_input_tensor;
pub mod ort_output_tensor;

pub use ort_owned_tensor::OrtOwnedTensor;
pub use ort_tensor::OrtTensor;
pub use ort_output_tensor::OrtOutputTensor;
pub use ort_output_tensor::WithOutputTensor;
27 changes: 27 additions & 0 deletions rust/onnxruntime/src/tensor/construct.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//! The `construct` module defines the `ConstructTensor` trait, which converts inputs into input tensors.
use crate::{memory::MemoryInfo, OrtError};
use onnxruntime_sys::OrtAllocator;
use onnxruntime_sys::OrtValue;
use std::fmt::Debug;

/// The input type accepted by the Rust onnxruntime `Session::run`.
///
/// Implementors are values that can construct an [`InputTensor`] from
/// themselves. `Debug` is a supertrait, so every implementor must also be
/// debug-printable.
pub trait ConstructTensor: Debug {
/// Construct an input tensor using the `MemoryInfo` and a raw pointer to the `OrtAllocator`.
///
/// The returned boxed [`InputTensor`] is bounded by `'a`, the lifetime of the
/// mutable borrow of `self`, so it may borrow from `self`.
///
/// # Errors
///
/// Returns an [`OrtError`] when the tensor cannot be constructed.
fn construct<'a>(
&'a mut self,
memory_info: &MemoryInfo,
allocator: *mut OrtAllocator,
) -> Result<Box<dyn InputTensor + 'a>, OrtError>;
}

/// A constructed input tensor, used as a trait object.
///
/// Allows the return value of `ConstructTensor::construct`
/// to be generic: callers receive a `Box<dyn InputTensor>` and only need
/// the tensor's shape and its raw `OrtValue` pointer.
pub trait InputTensor {
/// The input tensor's shape, as a slice of dimension sizes.
fn shape(&self) -> &[usize];

/// The raw pointer to the input tensor's underlying `OrtValue`.
fn ptr(&self) -> *mut OrtValue;
}
Loading

0 comments on commit 1c01875

Please sign in to comment.