diff --git a/rust/onnxruntime/src/environment.rs b/rust/onnxruntime/src/environment.rs index fcc5d90d6da25..68baf44cc5f37 100644 --- a/rust/onnxruntime/src/environment.rs +++ b/rust/onnxruntime/src/environment.rs @@ -22,7 +22,7 @@ lazy_static! { #[derive(Debug)] pub(crate) struct EnvironmentSingleton { - name: String, + name: Option, pub(crate) env_ptr: *mut sys::OrtEnv, } @@ -38,7 +38,7 @@ impl Drop for EnvironmentSingleton { impl Default for EnvironmentSingleton { fn default() -> Self { EnvironmentSingleton { - name: String::from("uninitialized"), + name: None, env_ptr: std::ptr::null_mut(), } } @@ -88,7 +88,15 @@ impl Environment { /// Return the name of the current environment #[must_use] pub fn name(&self) -> String { - self.env.lock().unwrap().name.to_string() + self.env + .lock() + .unwrap() + .name + .as_ref() + .unwrap() + .to_str() + .unwrap() + .to_string() } pub(crate) fn env(&self) -> Arc> { @@ -96,7 +104,7 @@ impl Environment { } #[tracing::instrument] - fn new(name: String, log_level: LoggingLevel) -> Result { + fn new(name: &str, log_level: LoggingLevel) -> Result { // NOTE: Because 'G_ENV' is a lazy_static, locking it will, initially, create // a new Arc> with a strong count of 1. // Cloning it to embed it inside the 'Environment' to return @@ -114,7 +122,7 @@ impl Environment { // FIXME: What should go here? let logger_param: *mut std::ffi::c_void = std::ptr::null_mut(); - let cname = CString::new(name.clone()).unwrap(); + let cname = CString::new(name).unwrap(); let create_env_with_custom_logger = g_ort().CreateEnvWithCustomLogger.unwrap(); let status = { @@ -137,7 +145,7 @@ impl Environment { ); *g_env_ptr = env_ptr; - environment_guard.name = name; + environment_guard.name = Some(cname); // NOTE: Cloning the lazy_static 'G_ENV' will increase its strong count by one. // If this 'Environment' is the only one in the process, the strong count @@ -147,7 +155,7 @@ impl Environment { Ok(Environment { env: G_ENV.clone() }) } else { warn!( - name = environment_guard.name.as_str(), + name = environment_guard.name.as_ref().unwrap().to_str().unwrap(), env_ptr = format!("{:?}", environment_guard.env_ptr).as_str(), "Environment already initialized, reusing it.", ); @@ -239,7 +247,7 @@ impl EnvBuilder { /// Commit the configuration to a new [`Environment`](environment/struct.Environment.html) pub fn build(self) -> Result { - Environment::new(self.name, self.log_level) + Environment::new(&self.name, self.log_level) } } @@ -324,13 +332,12 @@ mod tests { fn concurrent_environment_creations() { let _concurrent_run_lock_guard = CONCURRENT_TEST_RUN.single_test_run(); - let initial_name = String::from("concurrent_environment_creation"); - let main_env = Environment::new(initial_name.clone(), LoggingLevel::Warning).unwrap(); + let initial_name = "concurrent_environment_creation"; + let main_env = Environment::new(initial_name, LoggingLevel::Warning).unwrap(); let main_env_ptr = main_env.env().lock().unwrap().env_ptr as usize; let children: Vec<_> = (0..10) .map(|t| { - let initial_name_cloned = initial_name.clone(); std::thread::spawn(move || { let name = format!("concurrent_environment_creation: {}", t); let env = Environment::builder() @@ -339,7 +346,7 @@ mod tests { .build() .unwrap(); - assert_eq!(env.name(), initial_name_cloned); + assert_eq!(env.name(), initial_name.to_string()); assert_eq!(env.env().lock().unwrap().env_ptr as usize, main_env_ptr); }) })