From 1c0187527270c3c0ce664ab59082241e14def132 Mon Sep 17 00:00:00 2001 From: Boyd Johnson Date: Wed, 21 Sep 2022 09:24:08 +0000 Subject: [PATCH] [rust] Change inputs and outputs to session.run Adds WithOrtOwnedTensor and WithOrtTensor that tie the drop of T to the underlying tensor. Changes the run function to take Vec where ConstructTensor is implemented for items that can construct a WithOrtTensor for a particular T. --- rust/onnxruntime/examples/issue22.rs | 16 +- rust/onnxruntime/examples/sample.rs | 16 +- rust/onnxruntime/src/lib.rs | 8 +- rust/onnxruntime/src/memory.rs | 2 +- rust/onnxruntime/src/session.rs | 145 ++++---- rust/onnxruntime/src/tensor.rs | 9 +- rust/onnxruntime/src/tensor/construct.rs | 27 ++ .../{ort_tensor.rs => ort_input_tensor.rs} | 155 ++++----- .../src/tensor/ort_output_tensor.rs | 312 ++++++++++++++++++ .../src/tensor/ort_owned_tensor.rs | 137 -------- rust/onnxruntime/tests/integration_tests.rs | 77 ++--- 11 files changed, 554 insertions(+), 350 deletions(-) create mode 100644 rust/onnxruntime/src/tensor/construct.rs rename rust/onnxruntime/src/tensor/{ort_tensor.rs => ort_input_tensor.rs} (71%) create mode 100644 rust/onnxruntime/src/tensor/ort_output_tensor.rs delete mode 100644 rust/onnxruntime/src/tensor/ort_owned_tensor.rs diff --git a/rust/onnxruntime/examples/issue22.rs b/rust/onnxruntime/examples/issue22.rs index 944940055549d..ca4b68ac4a8b9 100644 --- a/rust/onnxruntime/examples/issue22.rs +++ b/rust/onnxruntime/examples/issue22.rs @@ -4,7 +4,9 @@ //! 
https://drive.google.com/file/d/1FmL-Wpm06V-8wgRqvV3Skey_X98Ue4D_/view?usp=sharing use ndarray::Array2; -use onnxruntime::{environment::Environment, tensor::OrtOwnedTensor, GraphOptimizationLevel}; +use onnxruntime::{ + environment::Environment, tensor::construct::ConstructTensor, GraphOptimizationLevel, +}; use tracing::Level; use tracing_subscriber::FmtSubscriber; @@ -31,10 +33,12 @@ fn main() { println!("{:#?}", session.inputs); println!("{:#?}", session.outputs); - let input_ids = Array2::::from_shape_vec((1, 3), vec![1, 2, 3]).unwrap(); - let attention_mask = Array2::::from_shape_vec((1, 3), vec![1, 1, 1]).unwrap(); + let mut input_ids = Array2::::from_shape_vec((1, 3), vec![1, 2, 3]).unwrap(); + let mut attention_mask = Array2::::from_shape_vec((1, 3), vec![1, 1, 1]).unwrap(); - let outputs: Vec> = - session.run(vec![input_ids, attention_mask]).unwrap(); - print!("outputs: {:#?}", outputs); + let inputs: Vec<&mut dyn ConstructTensor> = vec![&mut input_ids, &mut attention_mask]; + + let outputs = session.run(inputs).unwrap(); + + print!("outputs: {:#?}", outputs[0].float_array().unwrap()); } diff --git a/rust/onnxruntime/examples/sample.rs b/rust/onnxruntime/examples/sample.rs index 6f877d77787ed..27605f713badb 100644 --- a/rust/onnxruntime/examples/sample.rs +++ b/rust/onnxruntime/examples/sample.rs @@ -1,8 +1,8 @@ #![forbid(unsafe_code)] use onnxruntime::{ - environment::Environment, ndarray::Array, tensor::OrtOwnedTensor, GraphOptimizationLevel, - LoggingLevel, + environment::Environment, ndarray::Array, tensor::construct::ConstructTensor, + GraphOptimizationLevel, LoggingLevel, }; use tracing::Level; use tracing_subscriber::FmtSubscriber; @@ -59,16 +59,18 @@ fn run() -> Result<(), Error> { .iter() .map(|d| d.unwrap()) .product(); - let array = Array::linspace(0.0_f32, 1.0, n as usize) + let mut array = Array::linspace(0.0_f32, 1.0, n as usize) .into_shape(input0_shape) .unwrap(); - let input_tensor_values = vec![array]; + let input_tensor_values: Vec<&mut 
dyn ConstructTensor> = vec![&mut array]; - let outputs: Vec> = session.run(input_tensor_values)?; + let outputs = session.run(input_tensor_values)?; - assert_eq!(outputs[0].shape(), output0_shape.as_slice()); + let output = outputs[0].float_array().unwrap(); + + assert_eq!(output.shape(), output0_shape.as_slice()); for i in 0..5 { - println!("Score for class [{}] = {}", i, outputs[0][[0, i, 0, 0]]); + println!("Score for class [{}] = {}", i, output[[0, i, 0, 0]]); } Ok(()) diff --git a/rust/onnxruntime/src/lib.rs b/rust/onnxruntime/src/lib.rs index bdf2a8586adef..6d06c09c821d3 100644 --- a/rust/onnxruntime/src/lib.rs +++ b/rust/onnxruntime/src/lib.rs @@ -90,7 +90,7 @@ to download. //! //! ```no_run //! # use std::error::Error; -//! # use onnxruntime::{environment::Environment, LoggingLevel, GraphOptimizationLevel, tensor::OrtOwnedTensor}; +//! # use onnxruntime::{environment::Environment, LoggingLevel, GraphOptimizationLevel, tensor::construct::ConstructTensor}; //! # fn main() -> Result<(), Box> { //! # let environment = Environment::builder() //! # .with_name("test") @@ -101,10 +101,10 @@ to download. //! # .with_graph_optimization_level(GraphOptimizationLevel::Basic)? //! # .with_intra_op_num_threads(1)? //! # .with_model_from_file("squeezenet.onnx")?; -//! let array = ndarray::Array::linspace(0.0_f32, 1.0, 100); +//! let mut array = ndarray::Array::linspace(0.0_f32, 1.0, 100); //! // Multiple inputs and outputs are possible -//! let input_tensor = vec![array]; -//! let outputs: Vec> = session.run(input_tensor)?; +//! let input_tensor: Vec<&mut dyn ConstructTensor> = vec![&mut array]; +//! let outputs = session.run(input_tensor)?; //! # Ok(()) //! # } //! 
``` diff --git a/rust/onnxruntime/src/memory.rs b/rust/onnxruntime/src/memory.rs index cf9873220f0c9..7449d5713133d 100644 --- a/rust/onnxruntime/src/memory.rs +++ b/rust/onnxruntime/src/memory.rs @@ -10,7 +10,7 @@ use crate::{ }; #[derive(Debug)] -pub(crate) struct MemoryInfo { +pub struct MemoryInfo { pub ptr: *mut sys::OrtMemoryInfo, } diff --git a/rust/onnxruntime/src/session.rs b/rust/onnxruntime/src/session.rs index ef0304968c40b..36493c9bc5574 100644 --- a/rust/onnxruntime/src/session.rs +++ b/rust/onnxruntime/src/session.rs @@ -1,6 +1,6 @@ //! Module containing session types -use std::{ffi::CString, fmt::Debug, path::Path}; +use std::{convert::TryFrom, ffi::CString, fmt::Debug, path::Path}; #[cfg(not(target_family = "windows"))] use std::os::unix::ffi::OsStrExt; @@ -10,7 +10,6 @@ use std::os::windows::ffi::OsStrExt; #[cfg(feature = "model-fetching")] use std::env; -use ndarray::Array; use tracing::{debug, error}; use onnxruntime_sys as sys; @@ -25,11 +24,11 @@ use crate::{ g_ort, memory::MemoryInfo, tensor::{ - ort_owned_tensor::{OrtOwnedTensor, OrtOwnedTensorExtractor}, - OrtTensor, + construct::ConstructTensor, + ort_output_tensor::{OrtOutput, OrtOwnedTensorExtractor}, + OrtOutputTensor, }, AllocatorType, GraphOptimizationLevel, MemType, TensorElementDataType, - TypeToTensorElementDataType, }; #[cfg(feature = "model-fetching")] @@ -375,18 +374,23 @@ impl Session { /// /// Note that ONNX models can have multiple inputs; a `Vec<_>` is thus /// used for the input data here. 
- pub fn run<'s, 't, 'm, TIn, TOut, D>( - &'s self, - input_arrays: Vec>, - ) -> Result>> - where - TIn: TypeToTensorElementDataType + Debug + Clone, - TOut: TypeToTensorElementDataType + Debug + Clone, - D: ndarray::Dimension, - 'm: 't, // 'm outlives 't (memory info outlives tensor) - 's: 'm, // 's outlives 'm (session outlives memory info) - { - self.validate_input_shapes(&input_arrays)?; + pub fn run<'a>( + &self, + input_arrays: Vec<&mut dyn ConstructTensor>, + ) -> Result>> { + let memory_info = &self.memory_info; + + let allocator = self.allocator_ptr; + + let input_arrays = input_arrays + .into_iter() + .map(|v| v.construct(memory_info, allocator)) + .collect::>>()?; + + let input_arrays_shapes: Vec> = + input_arrays.iter().map(|v| v.shape().to_vec()).collect(); + + self.validate_input_shapes(&input_arrays_shapes)?; // Build arguments to Run() @@ -412,16 +416,9 @@ impl Session { let mut output_tensor_extractors_ptrs: Vec<*mut sys::OrtValue> = vec![std::ptr::null_mut(); self.outputs.len()]; - // The C API expects pointers for the arrays (pointers to C-arrays) - let input_ort_tensors: Vec> = input_arrays - .into_iter() - .map(|input_array| { - OrtTensor::from_array(&self.memory_info, self.allocator_ptr, input_array) - }) - .collect::>>>()?; - let input_ort_values: Vec<*const sys::OrtValue> = input_ort_tensors + let input_ort_values: Vec<*const sys::OrtValue> = input_arrays .iter() - .map(|input_array_ort| input_array_ort.c_ptr as *const sys::OrtValue) + .map(|input_array_ort| input_array_ort.ptr() as *const sys::OrtValue) .collect(); let run_options_ptr: *const sys::OrtRunOptions = std::ptr::null(); @@ -440,27 +437,25 @@ impl Session { }; status_to_result(status).map_err(OrtError::Run)?; - let memory_info_ref = &self.memory_info; - let outputs: Result>>> = - output_tensor_extractors_ptrs - .into_iter() - .map(|ptr| { - let mut tensor_info_ptr: *mut sys::OrtTensorTypeAndShapeInfo = - std::ptr::null_mut(); - let status = unsafe { - 
g_ort().GetTensorTypeAndShape.unwrap()(ptr, &mut tensor_info_ptr as _) - }; - status_to_result(status).map_err(OrtError::GetTensorTypeAndShape)?; - let dims = unsafe { get_tensor_dimensions(tensor_info_ptr) }; - unsafe { g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info_ptr) }; - let dims: Vec<_> = dims?.iter().map(|&n| n as usize).collect(); - - let mut output_tensor_extractor = - OrtOwnedTensorExtractor::new(memory_info_ref, ndarray::IxDyn(&dims)); - output_tensor_extractor.tensor_ptr = ptr; - output_tensor_extractor.extract::() - }) - .collect(); + let outputs: Result> = output_tensor_extractors_ptrs + .into_iter() + .map(|ptr| { + let mut tensor_info_ptr: *mut sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); + let status = unsafe { + g_ort().GetTensorTypeAndShape.unwrap()(ptr, &mut tensor_info_ptr as _) + }; + status_to_result(status).map_err(OrtError::GetTensorTypeAndShape)?; + let dims = unsafe { get_tensor_dimensions(tensor_info_ptr) }; + + unsafe { g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info_ptr) }; + let dims: Vec<_> = dims?.iter().map(|&n| n as usize).collect(); + + let mut output_tensor_extractor = OrtOwnedTensorExtractor::new(dims); + output_tensor_extractor.tensor_ptr = ptr; + + output_tensor_extractor.extract() + }) + .collect(); // Reconvert to CString so drop impl is called and memory is freed let cstrings: Result> = input_names_ptr @@ -472,33 +467,29 @@ impl Session { .collect(); cstrings?; - outputs + outputs? 
+ .into_iter() + .map(|v| OrtOutput::try_from(v)) + .collect() } - fn validate_input_shapes(&self, input_arrays: &[Array]) -> Result<()> - where - TIn: TypeToTensorElementDataType + Debug + Clone, - D: ndarray::Dimension, - { + fn validate_input_shapes(&self, input_array_shapes: &[Vec]) -> Result<()> { // ****************************************************************** // FIXME: Properly handle errors here // Make sure all dimensions match (except dynamic ones) // Verify length of inputs - if input_arrays.len() != self.inputs.len() { + if input_array_shapes.len() != self.inputs.len() { error!( "Non-matching number of inputs: {} (inference) vs {} (model)", - input_arrays.len(), + input_array_shapes.len(), self.inputs.len() ); return Err(OrtError::NonMatchingDimensions( NonMatchingDimensionsError::InputsCount { inference_input_count: 0, model_input_count: 0, - inference_input: input_arrays - .iter() - .map(|input_array| input_array.shape().to_vec()) - .collect(), + inference_input: input_array_shapes.to_vec(), model_input: self .inputs .iter() @@ -509,20 +500,20 @@ impl Session { } // Verify length of each individual inputs - let inputs_different_length = input_arrays + let inputs_different_length = input_array_shapes .iter() .zip(self.inputs.iter()) - .any(|(l, r)| l.shape().len() != r.dimensions.len()); + .any(|(l, r)| l.len() != r.dimensions.len()); if inputs_different_length { error!( "Different input lengths: {:?} vs {:?}", - self.inputs, input_arrays + self.inputs, input_array_shapes ); return Err(OrtError::NonMatchingDimensions( NonMatchingDimensionsError::InputsLength { - inference_input: input_arrays + inference_input: input_array_shapes .iter() - .map(|input_array| input_array.shape().to_vec()) + .map(|input_array| input_array.to_vec()) .collect(), model_input: self .inputs @@ -534,24 +525,28 @@ impl Session { } // Verify shape of each individual inputs - let inputs_different_shape = input_arrays.iter().zip(self.inputs.iter()).any(|(l, r)| { - let l_shape 
= l.shape(); - let r_shape = r.dimensions.as_slice(); - l_shape.iter().zip(r_shape.iter()).any(|(l2, r2)| match r2 { - Some(r3) => *r3 as usize != *l2, - None => false, // None means dynamic size; in that case shape always match - }) - }); + let inputs_different_shape = + input_array_shapes + .iter() + .zip(self.inputs.iter()) + .any(|(l, r)| { + let l_shape = l; + let r_shape = r.dimensions.as_slice(); + l_shape.iter().zip(r_shape.iter()).any(|(l2, r2)| match r2 { + Some(r3) => *r3 as usize != *l2, + None => false, // None means dynamic size; in that case shape always match + }) + }); if inputs_different_shape { error!( "Different input lengths: {:?} vs {:?}", - self.inputs, input_arrays + self.inputs, input_array_shapes ); return Err(OrtError::NonMatchingDimensions( NonMatchingDimensionsError::InputsLength { - inference_input: input_arrays + inference_input: input_array_shapes .iter() - .map(|input_array| input_array.shape().to_vec()) + .map(|input_array| input_array.to_vec()) .collect(), model_input: self .inputs diff --git a/rust/onnxruntime/src/tensor.rs b/rust/onnxruntime/src/tensor.rs index 92404842477a8..64dae3f258364 100644 --- a/rust/onnxruntime/src/tensor.rs +++ b/rust/onnxruntime/src/tensor.rs @@ -23,9 +23,10 @@ //! will be returned by the method which can be derefed into its internal //! [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html). +pub mod construct; pub mod ndarray_tensor; -pub mod ort_owned_tensor; -pub mod ort_tensor; +pub mod ort_input_tensor; +pub mod ort_output_tensor; -pub use ort_owned_tensor::OrtOwnedTensor; -pub use ort_tensor::OrtTensor; +pub use ort_output_tensor::OrtOutputTensor; +pub use ort_output_tensor::WithOutputTensor; diff --git a/rust/onnxruntime/src/tensor/construct.rs b/rust/onnxruntime/src/tensor/construct.rs new file mode 100644 index 0000000000000..2a1e8001ee0f8 --- /dev/null +++ b/rust/onnxruntime/src/tensor/construct.rs @@ -0,0 +1,27 @@ +//! 
The construct module has the trait for conversion of inputs into a ConstructTensor. + +use crate::{memory::MemoryInfo, OrtError}; +use onnxruntime_sys::OrtAllocator; +use onnxruntime_sys::OrtValue; +use std::fmt::Debug; + +/// The Input type for Rust onnxruntime Session::run +/// Many types can construct a tensor. +pub trait ConstructTensor: Debug { + /// Construct an OrtTensor Input using the `MemoryInfo` and a raw pointer to the `OrtAllocator`. + fn construct<'a>( + &'a mut self, + memory_info: &MemoryInfo, + allocator: *mut OrtAllocator, + ) -> Result, OrtError>; +} + +/// Allows the return value of ConstructTensor::construct +/// to be generic. +pub trait InputTensor { + /// The input tensor's shape + fn shape(&self) -> &[usize]; + + /// The input tensor's ptr + fn ptr(&self) -> *mut OrtValue; +} diff --git a/rust/onnxruntime/src/tensor/ort_tensor.rs b/rust/onnxruntime/src/tensor/ort_input_tensor.rs similarity index 71% rename from rust/onnxruntime/src/tensor/ort_tensor.rs rename to rust/onnxruntime/src/tensor/ort_input_tensor.rs index 8d9b85bc40f3b..5f32f0fe8d289 100644 --- a/rust/onnxruntime/src/tensor/ort_tensor.rs +++ b/rust/onnxruntime/src/tensor/ort_input_tensor.rs @@ -1,59 +1,65 @@ //! Module containing tensor with memory owned by Rust -use std::{ffi, fmt::Debug, ops::Deref}; - -use ndarray::Array; -use tracing::{debug, error}; - -use onnxruntime_sys as sys; - +use super::construct::{ConstructTensor, InputTensor}; use crate::{ error::{assert_not_null_pointer, call_ort, status_to_result}, g_ort, memory::MemoryInfo, - tensor::ndarray_tensor::NdArrayTensor, OrtError, Result, TensorElementDataType, TypeToTensorElementDataType, }; +use ndarray::{Array, Dimension}; +use onnxruntime_sys as sys; +use std::{ffi, fmt::Debug}; +use sys::OrtAllocator; +use tracing::{debug, error}; -/// Owned tensor, backed by an [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html) +/// An Input tensor.
/// -/// This tensor bounds the ONNX Runtime to `ndarray`; it is used to copy an +/// This ties the lifetime of T to the OrtValue; it is used to copy an /// [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html) to the runtime's memory. /// /// **NOTE**: The type is not meant to be used directly, use an [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html) /// instead. #[derive(Debug)] -#[allow(dead_code)] // This is required to appease clipply as `memory_info` is not read from. -pub struct OrtTensor<'t, T, D> +pub struct OrtInputTensor where - T: TypeToTensorElementDataType + Debug + Clone, - D: ndarray::Dimension, + T: Debug, { pub(crate) c_ptr: *mut sys::OrtValue, - array: Array, - memory_info: &'t MemoryInfo, + pub(crate) shape: Vec, + #[allow(dead_code)] + item: T, +} + +impl OrtInputTensor +where + T: Debug, +{ + /// The shape of the OrtTensor. + pub fn shape(&self) -> &[usize] { + &self.shape + } } -impl<'t, T, D> OrtTensor<'t, T, D> +impl ConstructTensor for Array where - T: TypeToTensorElementDataType + Debug + Clone, - D: ndarray::Dimension, + T: TypeToTensorElementDataType + Debug, + D: Dimension, { - pub(crate) fn from_array<'m>( - memory_info: &'m MemoryInfo, - allocator_ptr: *mut sys::OrtAllocator, - mut array: Array, - ) -> Result> - where - 'm: 't, // 'm outlives 't - { + fn construct<'a>( + &'a mut self, + memory_info: &MemoryInfo, + allocator_ptr: *mut OrtAllocator, + ) -> Result> { // where onnxruntime will write the tensor data to let mut tensor_ptr: *mut sys::OrtValue = std::ptr::null_mut(); let tensor_ptr_ptr: *mut *mut sys::OrtValue = &mut tensor_ptr; - let shape: Vec = array.shape().iter().map(|d: &usize| *d as i64).collect(); + let sh = self.shape().to_vec(); + + let shape: Vec = self.shape().iter().map(|d: &usize| *d as i64).collect(); let shape_ptr: *const i64 = shape.as_ptr(); - let shape_len = array.shape().len(); + let shape_len = self.shape().len(); match T::tensor_element_data_type() { 
TensorElementDataType::Float @@ -66,10 +72,13 @@ where | TensorElementDataType::Double | TensorElementDataType::Uint32 | TensorElementDataType::Uint64 => { + let buffer_size = self.len() * std::mem::size_of::(); + // primitive data is already suitably laid out in memory; provide it to // onnxruntime as is let tensor_values_ptr: *mut std::ffi::c_void = - array.as_mut_ptr().cast::(); + self.as_mut_ptr().cast::(); + assert_not_null_pointer(tensor_values_ptr, "TensorValues")?; unsafe { @@ -77,7 +86,7 @@ where ort.CreateTensorWithDataAsOrtValue.unwrap()( memory_info.ptr, tensor_values_ptr, - array.len() * std::mem::size_of::(), + buffer_size, shape_ptr, shape_len, T::tensor_element_data_type().into(), @@ -108,7 +117,7 @@ where .map_err(OrtError::CreateTensor)?; // create null-terminated copies of each string, as per `FillStringTensor` docs - let null_terminated_copies: Vec = array + let null_terminated_copies: Vec = self .iter() .map(|elt| { let slice = elt @@ -139,30 +148,17 @@ where assert_not_null_pointer(tensor_ptr, "Tensor")?; - Ok(OrtTensor { + Ok(Box::new(OrtInputTensor { c_ptr: tensor_ptr, - array, - memory_info, - }) - } -} - -impl<'t, T, D> Deref for OrtTensor<'t, T, D> -where - T: TypeToTensorElementDataType + Debug + Clone, - D: ndarray::Dimension, -{ - type Target = Array; - - fn deref(&self) -> &Self::Target { - &self.array + shape: sh, + item: self, + })) } } -impl<'t, T, D> Drop for OrtTensor<'t, T, D> +impl Drop for OrtInputTensor where - T: TypeToTensorElementDataType + Debug + Clone, - D: ndarray::Dimension, + T: Debug, { #[tracing::instrument] fn drop(&mut self) { @@ -178,18 +174,17 @@ where } } -impl<'t, T, D> OrtTensor<'t, T, D> +impl InputTensor for OrtInputTensor<&mut Array> where - T: TypeToTensorElementDataType + Debug + Clone, - D: ndarray::Dimension, + T: TypeToTensorElementDataType + Debug, + D: Dimension, { - /// Apply a softmax on the specified axis - pub fn softmax(&self, axis: ndarray::Axis) -> Array - where - D: ndarray::RemoveAxis, - 
T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign, - { - self.array.softmax(axis) + fn ptr(&self) -> *mut sys::OrtValue { + self.c_ptr + } + + fn shape(&self) -> &[usize] { + &self.shape } } @@ -198,14 +193,16 @@ mod tests { use super::*; use crate::{AllocatorType, MemType}; use ndarray::{arr0, arr1, arr2, arr3}; - use std::ptr; + use test_log::test; #[test] fn orttensor_from_array_0d_i32() { let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default).unwrap(); - let array = arr0::(123); - let tensor = OrtTensor::from_array(&memory_info, ptr::null_mut(), array).unwrap(); + let mut array = arr0::(123); + let tensor = array + .construct(&memory_info, ort_default_allocator()) + .unwrap(); let expected_shape: &[usize] = &[]; assert_eq!(tensor.shape(), expected_shape); } @@ -213,8 +210,10 @@ mod tests { #[test] fn orttensor_from_array_1d_i32() { let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default).unwrap(); - let array = arr1(&[1_i32, 2, 3, 4, 5, 6]); - let tensor = OrtTensor::from_array(&memory_info, ptr::null_mut(), array).unwrap(); + let mut array = arr1(&[1_i32, 2, 3, 4, 5, 6]); + let tensor = array + .construct(&memory_info, ort_default_allocator()) + .unwrap(); let expected_shape: &[usize] = &[6]; assert_eq!(tensor.shape(), expected_shape); } @@ -222,43 +221,51 @@ mod tests { #[test] fn orttensor_from_array_2d_i32() { let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default).unwrap(); - let array = arr2(&[[1_i32, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]); - let tensor = OrtTensor::from_array(&memory_info, ptr::null_mut(), array).unwrap(); + let mut array = arr2(&[[1_i32, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]); + let tensor = array + .construct(&memory_info, ort_default_allocator()) + .unwrap(); assert_eq!(tensor.shape(), &[2, 6]); } #[test] fn orttensor_from_array_3d_i32() { let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default).unwrap(); - let array = arr3(&[ + let mut array = 
arr3(&[ [[1_i32, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]], [[13, 14, 15, 16, 17, 18], [19, 20, 21, 22, 23, 24]], [[25, 26, 27, 28, 29, 30], [31, 32, 33, 34, 35, 36]], ]); - let tensor = OrtTensor::from_array(&memory_info, ptr::null_mut(), array).unwrap(); + let tensor = array + .construct(&memory_info, ort_default_allocator()) + .unwrap(); assert_eq!(tensor.shape(), &[3, 2, 6]); } #[test] fn orttensor_from_array_1d_string() { let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default).unwrap(); - let array = arr1(&[ + let mut array = arr1(&[ String::from("foo"), String::from("bar"), String::from("baz"), ]); - let tensor = OrtTensor::from_array(&memory_info, ort_default_allocator(), array).unwrap(); + let tensor = array + .construct(&memory_info, ort_default_allocator()) + .unwrap(); assert_eq!(tensor.shape(), &[3]); } #[test] fn orttensor_from_array_3d_str() { let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default).unwrap(); - let array = arr3(&[ + let mut array = arr3(&[ [["1", "2", "3"], ["4", "5", "6"]], [["7", "8", "9"], ["10", "11", "12"]], ]); - let tensor = OrtTensor::from_array(&memory_info, ort_default_allocator(), array).unwrap(); + let tensor = array + .construct(&memory_info, ort_default_allocator()) + .unwrap(); assert_eq!(tensor.shape(), &[2, 2, 3]); } diff --git a/rust/onnxruntime/src/tensor/ort_output_tensor.rs b/rust/onnxruntime/src/tensor/ort_output_tensor.rs new file mode 100644 index 0000000000000..55bbe189c6c8f --- /dev/null +++ b/rust/onnxruntime/src/tensor/ort_output_tensor.rs @@ -0,0 +1,312 @@ +//! Module containing tensor with memory owned by the ONNX Runtime + +use crate::{error::status_to_result, g_ort, OrtError, Result, TypeToTensorElementDataType}; +use ndarray::ArrayView; +use onnxruntime_sys as sys; +use std::{convert::TryFrom, fmt::Debug}; + +use tracing::debug; + +/// Tensor containing data owned by the ONNX Runtime C library, used to return values from inference. 
+/// +/// This tensor type is returned by the [`Session::run()`](../session/struct.Session.html#method.run) method. +/// It is not meant to be created directly. +#[derive(Debug)] +pub struct OrtOutputTensor { +    pub(crate) tensor_ptr: *mut sys::OrtValue, +    pub(crate) shape: Vec, +} + +#[derive(Debug)] +pub(crate) struct OrtOwnedTensorExtractor { +    pub(crate) tensor_ptr: *mut sys::OrtValue, +    pub(crate) shape: Vec, +} + +impl OrtOwnedTensorExtractor { +    pub(crate) fn new(shape: Vec) -> OrtOwnedTensorExtractor { +        OrtOwnedTensorExtractor { +            tensor_ptr: std::ptr::null_mut(), +            shape, +        } +    } + +    pub(crate) fn extract(self) -> Result { +        // Note: Both tensor and array will point to the same data, nothing is copied. +        // As such, there is no need to free the pointer used to create the ArrayView. + +        assert_ne!(self.tensor_ptr, std::ptr::null_mut()); + +        let mut is_tensor = 0; +        let status = unsafe { g_ort().IsTensor.unwrap()(self.tensor_ptr, &mut is_tensor) }; +        status_to_result(status).map_err(OrtError::IsTensor)?; +        (is_tensor == 1) +            .then_some(()) +            .ok_or(OrtError::IsTensorCheck)?; + +        Ok(OrtOutputTensor { +            tensor_ptr: self.tensor_ptr, +            shape: self.shape, +        }) +    } +} + +impl Drop for OrtOutputTensor { +    #[tracing::instrument] +    fn drop(&mut self) { +        debug!("Dropping OrtOwnedTensor."); +        unsafe { g_ort().ReleaseValue.unwrap()(self.tensor_ptr) } + +        self.tensor_ptr = std::ptr::null_mut(); +    } +} + +/// An Output tensor with the ptr and the item that will copy from the ptr.
+#[derive(Debug)] +pub struct WithOutputTensor<'a, T> { + #[allow(dead_code)] + pub(crate) tensor: OrtOutputTensor, + item: ArrayView<'a, T, ndarray::IxDyn>, +} + +impl<'a, T> std::ops::Deref for WithOutputTensor<'a, T> { + type Target = ArrayView<'a, T, ndarray::IxDyn>; + + fn deref(&self) -> &Self::Target { + &self.item + } +} + +impl<'a, T> TryFrom for WithOutputTensor<'a, T> +where + T: TypeToTensorElementDataType, +{ + type Error = OrtError; + + fn try_from(value: OrtOutputTensor) -> Result { + // Get pointer to output tensor float values + let mut output_array_ptr: *mut T = std::ptr::null_mut(); + let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; + let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = + output_array_ptr_ptr.cast::<*mut std::ffi::c_void>(); + let status = unsafe { + g_ort().GetTensorMutableData.unwrap()(value.tensor_ptr, output_array_ptr_ptr_void) + }; + status_to_result(status).map_err(OrtError::IsTensor)?; + assert_ne!(output_array_ptr, std::ptr::null_mut()); + + let array_view = + unsafe { ArrayView::from_shape_ptr(ndarray::IxDyn(&value.shape), output_array_ptr) }; + + Ok(WithOutputTensor { + tensor: value, + item: array_view, + }) + } +} + +/// The onnxruntime Run output type. 
+pub enum OrtOutput<'a> { + /// Tensor of f32s + Float(WithOutputTensor<'a, f32>), + /// Tensor of f64s + Double(WithOutputTensor<'a, f64>), + /// Tensor of u8s + UInt8(WithOutputTensor<'a, u8>), + /// Tensor of u16s + UInt16(WithOutputTensor<'a, u16>), + /// Tensor of u32s + UInt32(WithOutputTensor<'a, u32>), + /// Tensor of u64s + UInt64(WithOutputTensor<'a, u64>), + /// Tensor of i8s + Int8(WithOutputTensor<'a, i8>), + /// Tensor of i16s + Int16(WithOutputTensor<'a, i16>), + /// Tensor of i32s + Int32(WithOutputTensor<'a, i32>), + /// Tensor of i64s + Int64(WithOutputTensor<'a, i64>), + /// Tensor of Strings + String(WithOutputTensor<'a, String>), +} + +impl<'a> OrtOutput<'a> { + /// Return `WithOutputTensor<'a, f32>` which derefs into an `ArrayView`. + pub fn float_array(&self) -> Option<&WithOutputTensor<'a, f32>> { + if let Self::Float(item) = self { + Some(item) + } else { + None + } + } + + /// Return `WithOutputTensor<'a, f64>` which derefs into an `ArrayView`. + pub fn double_array(&self) -> Option<&WithOutputTensor<'a, f64>> { + if let Self::Double(item) = self { + Some(item) + } else { + None + } + } + + /// Return `WithOutputTensor<'a, u8>` which derefs into an `ArrayView`. + pub fn uint8_array(&self) -> Option<&WithOutputTensor<'a, u8>> { + if let Self::UInt8(item) = self { + Some(item) + } else { + None + } + } + + /// Return `WithOutputTensor<'a, u16>` which derefs into an `ArrayView`. + pub fn uint16_array(&self) -> Option<&WithOutputTensor<'a, u16>> { + if let Self::UInt16(item) = self { + Some(item) + } else { + None + } + } + + /// Return `WithOutputTensor<'a, u32>` which derefs into an `ArrayView`. + pub fn uint32_array(&self) -> Option<&WithOutputTensor<'a, u32>> { + if let Self::UInt32(item) = self { + Some(item) + } else { + None + } + } + + /// Return `WithOutputTensor<'a, u64>` which derefs into an `ArrayView`. 
+ pub fn uint64_array(&self) -> Option<&WithOutputTensor<'a, u64>> { + if let Self::UInt64(item) = self { + Some(item) + } else { + None + } + } + + /// Return `WithOutputTensor<'a, i8>` which derefs into an `ArrayView`. + pub fn int8_array(&self) -> Option<&WithOutputTensor<'a, i8>> { + if let Self::Int8(item) = self { + Some(item) + } else { + None + } + } + + /// Return `WithOutputTensor<'a, i16>` which derefs into an `ArrayView`. + pub fn int16_array(&self) -> Option<&WithOutputTensor<'a, i16>> { + if let Self::Int16(item) = self { + Some(item) + } else { + None + } + } + + /// Return `WithOutputTensor<'a, i32>` which derefs into an `ArrayView`. + pub fn int32_array(&self) -> Option<&WithOutputTensor<'a, i32>> { + if let Self::Int32(item) = self { + Some(item) + } else { + None + } + } + + /// Return `WithOutputTensor<'a, i64>` which derefs into an `ArrayView`. + pub fn int64_array(&self) -> Option<&WithOutputTensor<'a, i64>> { + if let Self::Int64(item) = self { + Some(item) + } else { + None + } + } + + /// Return `WithOutputTensor<'a, String>` which derefs into an `ArrayView`. 
+ pub fn string_array(&self) -> Option<&WithOutputTensor<'a, String>> { + if let Self::String(item) = self { + Some(item) + } else { + None + } + } +} + +impl<'a> TryFrom for OrtOutput<'a> { + type Error = OrtError; + + fn try_from(value: OrtOutputTensor) -> Result> { + unsafe { + let mut shape_info = std::ptr::null_mut(); + + let status = g_ort().GetTensorTypeAndShape.unwrap()(value.tensor_ptr, &mut shape_info); + + status_to_result(status).map_err(OrtError::IsTensor)?; + + assert_ne!(shape_info, std::ptr::null_mut()); + + let mut element_type = + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + + let status = g_ort().GetTensorElementType.unwrap()(shape_info, &mut element_type); + + status_to_result(status).map_err(OrtError::IsTensor)?; + + g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(shape_info); + + match element_type { + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED => { + unimplemented!() + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT => { + WithOutputTensor::try_from(value).map(OrtOutput::Float) + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 => { + WithOutputTensor::try_from(value).map(OrtOutput::UInt8) + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 => { + WithOutputTensor::try_from(value).map(OrtOutput::Int8) + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 => { + WithOutputTensor::try_from(value).map(OrtOutput::UInt16) + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 => { + WithOutputTensor::try_from(value).map(OrtOutput::Int16) + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 => { + WithOutputTensor::try_from(value).map(OrtOutput::Int32) + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 => { + WithOutputTensor::try_from(value).map(OrtOutput::Int64) + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING => { 
+ WithOutputTensor::try_from(value).map(OrtOutput::String) + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL => { + unimplemented!() + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 => { + unimplemented!() + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE => { + WithOutputTensor::try_from(value).map(OrtOutput::Double) + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 => { + WithOutputTensor::try_from(value).map(OrtOutput::UInt32) + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 => { + WithOutputTensor::try_from(value).map(OrtOutput::UInt64) + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 => { + unimplemented!() + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 => { + unimplemented!() + } + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 => { + unimplemented!() + } + } + } + } +} diff --git a/rust/onnxruntime/src/tensor/ort_owned_tensor.rs b/rust/onnxruntime/src/tensor/ort_owned_tensor.rs deleted file mode 100644 index a95f6cefecfd3..0000000000000 --- a/rust/onnxruntime/src/tensor/ort_owned_tensor.rs +++ /dev/null @@ -1,137 +0,0 @@ -//! Module containing tensor with memory owned by the ONNX Runtime - -use std::{fmt::Debug, ops::Deref}; - -use ndarray::{Array, ArrayView}; -use tracing::debug; - -use onnxruntime_sys as sys; - -use crate::{ - error::status_to_result, g_ort, memory::MemoryInfo, tensor::ndarray_tensor::NdArrayTensor, - OrtError, Result, TypeToTensorElementDataType, -}; - -/// Tensor containing data owned by the ONNX Runtime C library, used to return values from inference. -/// -/// This tensor type is returned by the [`Session::run()`](../session/struct.Session.html#method.run) method. -/// It is not meant to be created directly. 
-/// -/// The tensor hosts an [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html) -/// of the data on the C side. This allows manipulation on the Rust side using `ndarray` without copying the data. -/// -/// `OrtOwnedTensor` implements the [`std::deref::Deref`](#impl-Deref) trait for ergonomic access to -/// the underlying [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html). -#[derive(Debug)] -#[allow(dead_code)] // This is to appease clippy as `memory_info` is not read. -pub struct OrtOwnedTensor<'t, 'm, T, D> -where - T: TypeToTensorElementDataType + Debug + Clone, - D: ndarray::Dimension, - 'm: 't, // 'm outlives 't -{ - pub(crate) tensor_ptr: *mut sys::OrtValue, - array_view: ArrayView<'t, T, D>, - memory_info: &'m MemoryInfo, -} - -impl<'t, 'm, T, D> Deref for OrtOwnedTensor<'t, 'm, T, D> -where - T: TypeToTensorElementDataType + Debug + Clone, - D: ndarray::Dimension, -{ - type Target = ArrayView<'t, T, D>; - - fn deref(&self) -> &Self::Target { - &self.array_view - } -} - -impl<'t, 'm, T, D> OrtOwnedTensor<'t, 'm, T, D> -where - T: TypeToTensorElementDataType + Debug + Clone, - D: ndarray::Dimension, -{ - /// Apply a softmax on the specified axis - pub fn softmax(&self, axis: ndarray::Axis) -> Array - where - D: ndarray::RemoveAxis, - T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign, - { - self.array_view.softmax(axis) - } -} - -#[derive(Debug)] -pub(crate) struct OrtOwnedTensorExtractor<'m, D> -where - D: ndarray::Dimension, -{ - pub(crate) tensor_ptr: *mut sys::OrtValue, - memory_info: &'m MemoryInfo, - shape: D, -} - -impl<'m, D> OrtOwnedTensorExtractor<'m, D> -where - D: ndarray::Dimension, -{ - pub(crate) fn new(memory_info: &'m MemoryInfo, shape: D) -> OrtOwnedTensorExtractor<'m, D> { - OrtOwnedTensorExtractor { - tensor_ptr: std::ptr::null_mut(), - memory_info, - shape, - } - } - - pub(crate) fn extract<'t, T>(self) -> Result> - where - T: TypeToTensorElementDataType + Debug 
+ Clone, - { - // Note: Both tensor and array will point to the same data, nothing is copied. - // As such, there is no need too free the pointer used to create the ArrayView. - - assert_ne!(self.tensor_ptr, std::ptr::null_mut()); - - let mut is_tensor = 0; - let status = unsafe { g_ort().IsTensor.unwrap()(self.tensor_ptr, &mut is_tensor) }; - status_to_result(status).map_err(OrtError::IsTensor)?; - (is_tensor == 1) - .then_some(()) - .ok_or(OrtError::IsTensorCheck)?; - - // Get pointer to output tensor float values - let mut output_array_ptr: *mut T = std::ptr::null_mut(); - let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; - let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = - output_array_ptr_ptr.cast::<*mut std::ffi::c_void>(); - let status = unsafe { - g_ort().GetTensorMutableData.unwrap()(self.tensor_ptr, output_array_ptr_ptr_void) - }; - status_to_result(status).map_err(OrtError::IsTensor)?; - assert_ne!(output_array_ptr, std::ptr::null_mut()); - - let array_view = unsafe { ArrayView::from_shape_ptr(self.shape, output_array_ptr) }; - - Ok(OrtOwnedTensor { - tensor_ptr: self.tensor_ptr, - array_view, - memory_info: self.memory_info, - }) - } -} - -impl<'t, 'm, T, D> Drop for OrtOwnedTensor<'t, 'm, T, D> -where - T: TypeToTensorElementDataType + Debug + Clone, - D: ndarray::Dimension, - 'm: 't, // 'm outlives 't -{ - #[tracing::instrument] - fn drop(&mut self) { - debug!("Dropping OrtOwnedTensor."); - unsafe { g_ort().ReleaseValue.unwrap()(self.tensor_ptr) } - - self.tensor_ptr = std::ptr::null_mut(); - } -} diff --git a/rust/onnxruntime/tests/integration_tests.rs b/rust/onnxruntime/tests/integration_tests.rs index 23b99c09ba912..1ddaf0ce72d23 100644 --- a/rust/onnxruntime/tests/integration_tests.rs +++ b/rust/onnxruntime/tests/integration_tests.rs @@ -1,3 +1,6 @@ +use onnxruntime::error::OrtDownloadError; +use onnxruntime::tensor::ndarray_tensor::NdArrayTensor; + use std::{ fs, io::{self, BufRead, BufReader}, @@ -5,9 +8,6 @@ use std::{ 
time::Duration, }; -use onnxruntime::error::OrtDownloadError; -use onnxruntime::tensor::OrtOwnedTensor; - mod download { use super::*; @@ -18,6 +18,7 @@ mod download { use onnxruntime::{ download::vision::{DomainBasedImageClassification, ImageClassification}, environment::Environment, + tensor::construct::ConstructTensor, GraphOptimizationLevel, LoggingLevel, }; @@ -93,16 +94,14 @@ mod download { } // Batch of 1 - let input_tensor_values = vec![array]; + let input_tensor_values: Vec<&mut dyn ConstructTensor> = vec![&mut array]; // Perform the inference - let outputs: Vec< - onnxruntime::tensor::OrtOwnedTensor>, - > = session.run(input_tensor_values).unwrap(); + let outputs = session.run(input_tensor_values).unwrap(); // Downloaded model does not have a softmax as final layer; call softmax on second axis // and iterate on resulting probabilities, creating an index to later access labels. - let output: &OrtOwnedTensor = &outputs[0]; + let output = outputs[0].float_array().unwrap(); let mut probabilities: Vec<(usize, f32)> = output .softmax(ndarray::Axis(1)) .iter() @@ -176,7 +175,7 @@ mod download { ) .to_luma8(); - let array = ndarray::Array::from_shape_fn((1, 1, 28, 28), |(_, c, j, i)| { + let mut array = ndarray::Array::from_shape_fn((1, 1, 28, 28), |(_, c, j, i)| { let pixel = image_buffer.get_pixel(i as u32, j as u32); let channels = pixel.channels(); @@ -185,14 +184,12 @@ mod download { }); // Batch of 1 - let input_tensor_values = vec![array]; + let input_tensor_values: Vec<&mut dyn ConstructTensor> = vec![&mut array]; // Perform the inference - let outputs: Vec< - onnxruntime::tensor::OrtOwnedTensor>, - > = session.run(input_tensor_values).unwrap(); + let outputs = session.run(input_tensor_values).unwrap(); - let output: &OrtOwnedTensor = &outputs[0]; + let output = outputs[0].float_array().unwrap(); let mut probabilities: Vec<(usize, f32)> = output .softmax(ndarray::Axis(1)) .iter() @@ -263,23 +260,22 @@ mod download { ) .to_luma8(); - let array = 
ndarray::Array::from_shape_fn((1, 1, 28, 28), |(_, c, j, i)| { - let pixel = image_buffer.get_pixel(i as u32, j as u32); - let channels = pixel.channels(); + let mut array = + ndarray::Array::from_shape_fn((1, 1, 28, 28), |(_, c, j, i)| { + let pixel = image_buffer.get_pixel(i as u32, j as u32); + let channels = pixel.channels(); - // range [0, 255] -> range [0, 1] - (channels[c] as f32) / 255.0 - }); + // range [0, 255] -> range [0, 1] + (channels[c] as f32) / 255.0 + }); // Batch of 1 - let input_tensor_values = vec![array]; + let input_tensor_values: Vec<&mut dyn ConstructTensor> = vec![&mut array]; // Perform the inference - let outputs: Vec< - onnxruntime::tensor::OrtOwnedTensor>, - > = session.run(input_tensor_values).unwrap(); + let outputs = session.run(input_tensor_values).unwrap(); - let output: &OrtOwnedTensor = &outputs[0]; + let output = &outputs[0].float_array().unwrap(); let mut probabilities: Vec<(usize, f32)> = output .softmax(ndarray::Axis(1)) .iter() @@ -353,23 +349,22 @@ mod download { ) .to_luma8(); - let array = ndarray::Array::from_shape_fn((1, 1, 28, 28), |(_, c, j, i)| { - let pixel = image_buffer.get_pixel(i as u32, j as u32); - let channels = pixel.channels(); + let mut array = + ndarray::Array::from_shape_fn((1, 1, 28, 28), |(_, c, j, i)| { + let pixel = image_buffer.get_pixel(i as u32, j as u32); + let channels = pixel.channels(); - // range [0, 255] -> range [0, 1] - (channels[c] as f32) / 255.0 - }); + // range [0, 255] -> range [0, 1] + (channels[c] as f32) / 255.0 + }); // Batch of 1 - let input_tensor_values = vec![array]; + let input_tensor_values: Vec<&mut dyn ConstructTensor> = vec![&mut array]; // Perform the inference - let outputs: Vec< - onnxruntime::tensor::OrtOwnedTensor>, - > = session.run(input_tensor_values).unwrap(); + let outputs = session.run(input_tensor_values).unwrap(); - let output: &OrtOwnedTensor = &outputs[0]; + let output = &outputs[0].float_array().unwrap(); let mut probabilities: Vec<(usize, f32)> = output 
.softmax(ndarray::Axis(1)) .iter() @@ -461,7 +456,7 @@ mod download { .unwrap() .to_rgb8(); - let array = ndarray::Array::from_shape_fn((1, 224, 224, 3), |(_, j, i, c)| { + let mut array = ndarray::Array::from_shape_fn((1, 224, 224, 3), |(_, j, i, c)| { let pixel = image_buffer.get_pixel(i as u32, j as u32); let channels = pixel.channels(); @@ -470,15 +465,13 @@ mod download { }); // Just one input - let input_tensor_values = vec![array]; + let input_tensor_values: Vec<&mut dyn ConstructTensor> = vec![&mut array]; // Perform the inference - let outputs: Vec< - onnxruntime::tensor::OrtOwnedTensor>, - > = session.run(input_tensor_values).unwrap(); + let outputs = session.run(input_tensor_values).unwrap(); assert_eq!(outputs.len(), 1); - let output = &outputs[0]; + let output = outputs[0].float_array().unwrap(); // The image should have doubled in size assert_eq!(output.shape(), [1, 448, 448, 3]);