Skip to content

Commit

Permalink
Merge branch 'remove-panics-from-session'
Browse files Browse the repository at this point in the history
  • Loading branch information
nbigaouette committed Aug 4, 2021
2 parents a9358fd + b3e1a26 commit 88c6bab
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 46 deletions.
8 changes: 6 additions & 2 deletions onnxruntime/src/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{
};

use lazy_static::lazy_static;
use tracing::{debug, warn};
use tracing::{debug, error, warn};

use onnxruntime_sys as sys;

Expand Down Expand Up @@ -182,7 +182,11 @@ impl Drop for Environment {
);

assert_ne!(env_ptr, std::ptr::null_mut());
unsafe { release_env(env_ptr) };
if env_ptr.is_null() {
error!("Environment pointer is null, not dropping!");
} else {
unsafe { release_env(env_ptr) };
}

environment_guard.env_ptr = AtomicPtr::new(std::ptr::null_mut());
environment_guard.name = String::from("uninitialized");
Expand Down
27 changes: 27 additions & 0 deletions onnxruntime/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,21 @@ pub enum OrtError {
/// Attempt to build a Rust `CString` from a null pointer
#[error("Failed to build CString when original contains null: {0}")]
CStringNulError(#[from] std::ffi::NulError),
#[error("{0} pointer should be null")]
/// Ort Pointer should have been null
PointerShouldBeNull(String),
/// Ort pointer should not have been null
#[error("{0} pointer should not be null")]
PointerShouldNotBeNull(String),
/// ONNX Model has invalid dimensions
#[error("Invalid dimensions")]
InvalidDimensions,
/// The runtime type was undefined
#[error("Undefined Tensor Element Type")]
UndefinedTensorElementType,
/// Error occurred when checking if ONNX tensor was properly initialized
#[error("Failed to check if tensor")]
IsTensorCheck,
}

/// Error used when dimensions of input (from model and from inference call)
Expand Down Expand Up @@ -176,6 +191,18 @@ impl From<*const sys::OrtStatus> for OrtStatusWrapper {
}
}

pub(crate) fn assert_null_pointer<T>(ptr: *const T, name: &str) -> Result<()> {
ptr.is_null()
.then(|| ())
.ok_or_else(|| OrtError::PointerShouldBeNull(name.to_owned()))
}

pub(crate) fn assert_not_null_pointer<T>(ptr: *const T, name: &str) -> Result<()> {
(!ptr.is_null())
.then(|| ())
.ok_or_else(|| OrtError::PointerShouldBeNull(name.to_owned()))
}

impl From<OrtStatusWrapper> for std::result::Result<(), OrtApiError> {
fn from(status: OrtStatusWrapper) -> Self {
if status.0.is_null() {
Expand Down
16 changes: 10 additions & 6 deletions onnxruntime/src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ use tracing::debug;

use onnxruntime_sys as sys;

use tracing::error;

use crate::{
error::{status_to_result, OrtError, Result},
error::{assert_not_null_pointer, status_to_result, OrtError, Result},
g_ort, AllocatorType, MemType,
};

Expand All @@ -25,7 +27,7 @@ impl MemoryInfo {
)
};
status_to_result(status).map_err(OrtError::CreateCpuMemoryInfo)?;
assert_ne!(memory_info_ptr, std::ptr::null_mut());
assert_not_null_pointer(memory_info_ptr, "MemoryInfo")?;

Ok(Self {
ptr: memory_info_ptr,
Expand All @@ -36,10 +38,12 @@ impl MemoryInfo {
impl Drop for MemoryInfo {
#[tracing::instrument]
fn drop(&mut self) {
debug!("Dropping the memory information.");
assert_ne!(self.ptr, std::ptr::null_mut());

unsafe { g_ort().ReleaseMemoryInfo.unwrap()(self.ptr) };
if self.ptr.is_null() {
error!("MemoryInfo pointer is null, not dropping.");
} else {
debug!("Dropping the memory information.");
unsafe { g_ort().ReleaseMemoryInfo.unwrap()(self.ptr) };
}

self.ptr = std::ptr::null_mut();
}
Expand Down
72 changes: 43 additions & 29 deletions onnxruntime/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ use onnxruntime_sys as sys;
use crate::{
char_p_to_string,
environment::Environment,
error::{status_to_result, NonMatchingDimensionsError, OrtError, Result},
error::{
assert_not_null_pointer, assert_null_pointer, status_to_result, NonMatchingDimensionsError,
OrtApiError, OrtError, Result,
},
g_ort,
memory::MemoryInfo,
tensor::{
Expand Down Expand Up @@ -73,9 +76,12 @@ pub struct SessionBuilder<'a> {
impl<'a> Drop for SessionBuilder<'a> {
#[tracing::instrument]
fn drop(&mut self) {
debug!("Dropping the session options.");
assert_ne!(self.session_options_ptr, std::ptr::null_mut());
unsafe { g_ort().ReleaseSessionOptions.unwrap()(self.session_options_ptr) };
if self.session_options_ptr.is_null() {
error!("Session options pointer is null, not dropping");
} else {
debug!("Dropping the session options.");
unsafe { g_ort().ReleaseSessionOptions.unwrap()(self.session_options_ptr) };
}
}
}

Expand All @@ -85,8 +91,8 @@ impl<'a> SessionBuilder<'a> {
let status = unsafe { g_ort().CreateSessionOptions.unwrap()(&mut session_options_ptr) };

status_to_result(status).map_err(OrtError::SessionOptions)?;
assert_eq!(status, std::ptr::null_mut());
assert_ne!(session_options_ptr, std::ptr::null_mut());
assert_null_pointer(status, "SessionStatus")?;
assert_not_null_pointer(session_options_ptr, "SessionOptions")?;

Ok(SessionBuilder {
env,
Expand All @@ -105,7 +111,7 @@ impl<'a> SessionBuilder<'a> {
let status =
unsafe { g_ort().SetIntraOpNumThreads.unwrap()(self.session_options_ptr, num_threads) };
status_to_result(status).map_err(OrtError::SessionOptions)?;
assert_eq!(status, std::ptr::null_mut());
assert_null_pointer(status, "SessionStatus")?;
Ok(self)
}

Expand Down Expand Up @@ -199,14 +205,14 @@ impl<'a> SessionBuilder<'a> {
)
};
status_to_result(status).map_err(OrtError::Session)?;
assert_eq!(status, std::ptr::null_mut());
assert_ne!(session_ptr, std::ptr::null_mut());
assert_null_pointer(status, "SessionStatus")?;
assert_not_null_pointer(session_ptr, "Session")?;

let mut allocator_ptr: *mut sys::OrtAllocator = std::ptr::null_mut();
let status = unsafe { g_ort().GetAllocatorWithDefaultOptions.unwrap()(&mut allocator_ptr) };
status_to_result(status).map_err(OrtError::Allocator)?;
assert_eq!(status, std::ptr::null_mut());
assert_ne!(allocator_ptr, std::ptr::null_mut());
assert_null_pointer(status, "SessionStatus")?;
assert_not_null_pointer(allocator_ptr, "Allocator")?;

let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default)?;

Expand Down Expand Up @@ -255,14 +261,14 @@ impl<'a> SessionBuilder<'a> {
)
};
status_to_result(status).map_err(OrtError::Session)?;
assert_eq!(status, std::ptr::null_mut());
assert_ne!(session_ptr, std::ptr::null_mut());
assert_null_pointer(status, "SessionStatus")?;
assert_not_null_pointer(session_ptr, "Session")?;

let mut allocator_ptr: *mut sys::OrtAllocator = std::ptr::null_mut();
let status = unsafe { g_ort().GetAllocatorWithDefaultOptions.unwrap()(&mut allocator_ptr) };
status_to_result(status).map_err(OrtError::Allocator)?;
assert_eq!(status, std::ptr::null_mut());
assert_ne!(allocator_ptr, std::ptr::null_mut());
assert_null_pointer(status, "SessionStatus")?;
assert_not_null_pointer(allocator_ptr, "Allocator")?;

let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default)?;

Expand Down Expand Up @@ -352,7 +358,11 @@ impl<'a> Drop for Session<'a> {
#[tracing::instrument]
fn drop(&mut self) {
debug!("Dropping the session.");
unsafe { g_ort().ReleaseSession.unwrap()(self.session_ptr) };
if self.session_ptr.is_null() {
error!("Session pointer is null, not dropping.");
} else {
unsafe { g_ort().ReleaseSession.unwrap()(self.session_ptr) };
}
// FIXME: There is no C function to release the allocator?

self.session_ptr = std::ptr::null_mut();
Expand Down Expand Up @@ -453,13 +463,14 @@ impl<'a> Session<'a> {
.collect();

// Reconvert to CString so drop impl is called and memory is freed
let _: Vec<CString> = input_names_ptr
let cstrings: Result<Vec<CString>> = input_names_ptr
.into_iter()
.map(|p| {
assert_ne!(p, std::ptr::null());
unsafe { CString::from_raw(p as *mut i8) }
assert_not_null_pointer(p, "i8 for CString")?;
unsafe { Ok(CString::from_raw(p as *mut i8)) }
})
.collect();
cstrings?;

outputs
}
Expand Down Expand Up @@ -568,7 +579,9 @@ unsafe fn get_tensor_dimensions(
let mut num_dims = 0;
let status = g_ort().GetDimensionsCount.unwrap()(tensor_info_ptr, &mut num_dims);
status_to_result(status).map_err(OrtError::GetDimensionsCount)?;
assert_ne!(num_dims, 0);
(num_dims != 0)
.then(|| ())
.ok_or(OrtError::InvalidDimensions)?;

let mut node_dims: Vec<i64> = vec![0; num_dims as usize];
let status = g_ort().GetDimensions.unwrap()(
Expand Down Expand Up @@ -603,8 +616,10 @@ mod dangerous {
let mut num_nodes: usize = 0;
let status = unsafe { f(session_ptr, &mut num_nodes) };
status_to_result(status).map_err(OrtError::InOutCount)?;
assert_eq!(status, std::ptr::null_mut());
assert_ne!(num_nodes, 0);
assert_null_pointer(status, "SessionStatus")?;
(num_nodes != 0).then(|| ()).ok_or_else(|| {
OrtError::InOutCount(OrtApiError::Msg("No nodes in model".to_owned()))
})?;
Ok(num_nodes)
}

Expand Down Expand Up @@ -641,7 +656,7 @@ mod dangerous {

let status = unsafe { f(session_ptr, i, allocator_ptr, &mut name_bytes) };
status_to_result(status).map_err(OrtError::InputName)?;
assert_ne!(name_bytes, std::ptr::null_mut());
assert_not_null_pointer(name_bytes, "InputName")?;

// FIXME: Is it safe to keep ownership of the memory?
let name = char_p_to_string(name_bytes)?;
Expand Down Expand Up @@ -692,23 +707,22 @@ mod dangerous {

let status = unsafe { f(session_ptr, i, &mut typeinfo_ptr) };
status_to_result(status).map_err(OrtError::GetTypeInfo)?;
assert_ne!(typeinfo_ptr, std::ptr::null_mut());
assert_not_null_pointer(typeinfo_ptr, "TypeInfo")?;

let mut tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
let status = unsafe {
g_ort().CastTypeInfoToTensorInfo.unwrap()(typeinfo_ptr, &mut tensor_info_ptr)
};
status_to_result(status).map_err(OrtError::CastTypeInfoToTensorInfo)?;
assert_ne!(tensor_info_ptr, std::ptr::null_mut());
assert_not_null_pointer(tensor_info_ptr, "TensorInfo")?;

let mut type_sys = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
let status =
unsafe { g_ort().GetTensorElementType.unwrap()(tensor_info_ptr, &mut type_sys) };
status_to_result(status).map_err(OrtError::TensorElementType)?;
assert_ne!(
type_sys,
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
);
(type_sys != sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED)
.then(|| ())
.ok_or(OrtError::UndefinedTensorElementType)?;
// This transmute should be safe since its value is read from GetTensorElementType which we must trust.
let io_type: TensorElementDataType = unsafe { std::mem::transmute(type_sys) };

Expand Down
4 changes: 3 additions & 1 deletion onnxruntime/src/tensor/ort_owned_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ where
let mut is_tensor = 0;
let status = unsafe { g_ort().IsTensor.unwrap()(self.tensor_ptr, &mut is_tensor) };
status_to_result(status).map_err(OrtError::IsTensor)?;
assert_eq!(is_tensor, 1);
(is_tensor == 1)
.then(|| ())
.ok_or(OrtError::IsTensorCheck)?;

// Get pointer to output tensor float values
let mut output_array_ptr: *mut T = std::ptr::null_mut();
Expand Down
14 changes: 8 additions & 6 deletions onnxruntime/src/tensor/ort_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ use tracing::{debug, error};
use onnxruntime_sys as sys;

use crate::{
error::call_ort, error::status_to_result, g_ort, memory::MemoryInfo,
tensor::ndarray_tensor::NdArrayTensor, OrtError, Result, TensorElementDataType,
TypeToTensorElementDataType,
error::{assert_not_null_pointer, call_ort, status_to_result},
g_ort,
memory::MemoryInfo,
tensor::ndarray_tensor::NdArrayTensor,
OrtError, Result, TensorElementDataType, TypeToTensorElementDataType,
};

/// Owned tensor, backed by an [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html)
Expand Down Expand Up @@ -67,7 +69,7 @@ where
// onnxruntime as is
let tensor_values_ptr: *mut std::ffi::c_void =
array.as_mut_ptr() as *mut std::ffi::c_void;
assert_ne!(tensor_values_ptr, std::ptr::null_mut());
assert_not_null_pointer(tensor_values_ptr, "TensorValues")?;

unsafe {
call_ort(|ort| {
Expand All @@ -83,7 +85,7 @@ where
})
}
.map_err(OrtError::CreateTensorWithData)?;
assert_ne!(tensor_ptr, std::ptr::null_mut());
assert_not_null_pointer(tensor_ptr, "Tensor")?;

let mut is_tensor = 0;
let status = unsafe { g_ort().IsTensor.unwrap()(tensor_ptr, &mut is_tensor) };
Expand Down Expand Up @@ -134,7 +136,7 @@ where
}
}

assert_ne!(tensor_ptr, std::ptr::null_mut());
assert_not_null_pointer(tensor_ptr, "Tensor")?;

Ok(OrtTensor {
c_ptr: tensor_ptr,
Expand Down
7 changes: 5 additions & 2 deletions onnxruntime/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::{
};

use onnxruntime::error::OrtDownloadError;
use onnxruntime::tensor::OrtOwnedTensor;

mod download {
use super::*;
Expand Down Expand Up @@ -101,7 +102,8 @@ mod download {

// Downloaded model does not have a softmax as final layer; call softmax on second axis
// and iterate on resulting probabilities, creating an index to later access labels.
let mut probabilities: Vec<(usize, f32)> = outputs[0]
let output: &OrtOwnedTensor<f32, _> = &outputs[0];
let mut probabilities: Vec<(usize, f32)> = output
.softmax(ndarray::Axis(1))
.iter()
.copied()
Expand Down Expand Up @@ -190,7 +192,8 @@ mod download {
onnxruntime::tensor::OrtOwnedTensor<f32, ndarray::Dim<ndarray::IxDynImpl>>,
> = session.run(input_tensor_values).unwrap();

let mut probabilities: Vec<(usize, f32)> = outputs[0]
let output: &OrtOwnedTensor<f32, _> = &outputs[0];
let mut probabilities: Vec<(usize, f32)> = output
.softmax(ndarray::Axis(1))
.iter()
.copied()
Expand Down

0 comments on commit 88c6bab

Please sign in to comment.