Skip to content

Commit

Permalink
[rust] Make Environment name a CString
Browse files Browse the repository at this point in the history
We are passing a CString ptr to onnxruntime. This CString must
be stored.
  • Loading branch information
boydjohnson committed Oct 3, 2022
1 parent 5b46fab commit 80144cc
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions rust/onnxruntime/src/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ lazy_static! {

#[derive(Debug)]
pub(crate) struct EnvironmentSingleton {
name: String,
name: Option<CString>,
pub(crate) env_ptr: *mut sys::OrtEnv,
}

Expand All @@ -38,7 +38,7 @@ impl Drop for EnvironmentSingleton {
impl Default for EnvironmentSingleton {
fn default() -> Self {
EnvironmentSingleton {
name: String::from("uninitialized"),
name: None,
env_ptr: std::ptr::null_mut(),
}
}
Expand Down Expand Up @@ -88,15 +88,23 @@ impl Environment {
/// Return the name of the current environment
#[must_use]
pub fn name(&self) -> String {
self.env.lock().unwrap().name.to_string()
self.env
.lock()
.unwrap()
.name
.as_ref()
.unwrap()
.to_str()
.unwrap()
.to_string()
}

pub(crate) fn env(&self) -> Arc<Mutex<EnvironmentSingleton>> {
Arc::clone(&self.env)
}

#[tracing::instrument]
fn new(name: String, log_level: LoggingLevel) -> Result<Environment> {
fn new(name: &str, log_level: LoggingLevel) -> Result<Environment> {
// NOTE: Because 'G_ENV' is a lazy_static, locking it will, initially, create
// a new Arc<Mutex<EnvironmentSingleton>> with a strong count of 1.
// Cloning it to embed it inside the 'Environment' to return
Expand All @@ -114,7 +122,7 @@ impl Environment {
// FIXME: What should go here?
let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();

let cname = CString::new(name.clone()).unwrap();
let cname = CString::new(name).unwrap();

let create_env_with_custom_logger = g_ort().CreateEnvWithCustomLogger.unwrap();
let status = {
Expand All @@ -137,7 +145,7 @@ impl Environment {
);

*g_env_ptr = env_ptr;
environment_guard.name = name;
environment_guard.name = Some(cname);

// NOTE: Cloning the lazy_static 'G_ENV' will increase its strong count by one.
// If this 'Environment' is the only one in the process, the strong count
Expand All @@ -147,7 +155,7 @@ impl Environment {
Ok(Environment { env: G_ENV.clone() })
} else {
warn!(
name = environment_guard.name.as_str(),
name = environment_guard.name.as_ref().unwrap().to_str().unwrap(),
env_ptr = format!("{:?}", environment_guard.env_ptr).as_str(),
"Environment already initialized, reusing it.",
);
Expand Down Expand Up @@ -239,7 +247,7 @@ impl EnvBuilder {

/// Commit the configuration to a new [`Environment`](environment/struct.Environment.html)
pub fn build(self) -> Result<Environment> {
Environment::new(self.name, self.log_level)
Environment::new(&self.name, self.log_level)
}
}

Expand Down Expand Up @@ -324,13 +332,12 @@ mod tests {
fn concurrent_environment_creations() {
let _concurrent_run_lock_guard = CONCURRENT_TEST_RUN.single_test_run();

let initial_name = String::from("concurrent_environment_creation");
let main_env = Environment::new(initial_name.clone(), LoggingLevel::Warning).unwrap();
let initial_name = "concurrent_environment_creation";
let main_env = Environment::new(initial_name, LoggingLevel::Warning).unwrap();
let main_env_ptr = main_env.env().lock().unwrap().env_ptr as usize;

let children: Vec<_> = (0..10)
.map(|t| {
let initial_name_cloned = initial_name.clone();
std::thread::spawn(move || {
let name = format!("concurrent_environment_creation: {}", t);
let env = Environment::builder()
Expand All @@ -339,7 +346,7 @@ mod tests {
.build()
.unwrap();

assert_eq!(env.name(), initial_name_cloned);
assert_eq!(env.name(), initial_name.to_string());
assert_eq!(env.env().lock().unwrap().env_ptr as usize, main_env_ptr);
})
})
Expand Down

0 comments on commit 80144cc

Please sign in to comment.