-
Notifications
You must be signed in to change notification settings - Fork 121
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Rust]onnxruntimeをrust版voicevox_coreに導入 (#135)
* Statusを仮実装した * get_supported_devices を仮実装した * エラーの名前とsourceの修正 * initialize関数を仮実装した * get_supported_devicesが動作するかどうか確認するためのテストを追加 * lint errorを修正した * version up onnxruntime-rs to 0.0.17 * TODOコメント追加 * Update crates/voicevox_core/Cargo.toml Co-authored-by: Hiroshiba <[email protected]> Co-authored-by: Hiroshiba <[email protected]>
- Loading branch information
1 parent
04b91c7
commit 2a8b209
Showing
6 changed files
with
284 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
use super::*; | ||
use once_cell::sync::Lazy; | ||
use onnxruntime::{ | ||
environment::Environment, session::Session, GraphOptimizationLevel, LoggingLevel, | ||
}; | ||
cfg_if! { | ||
if #[cfg(not(feature="directml"))]{ | ||
use onnxruntime::CudaProviderOptions; | ||
} | ||
} | ||
use std::collections::BTreeMap; | ||
|
||
pub struct Status { | ||
models: StatusModels, | ||
session_options: SessionOptions, | ||
} | ||
|
||
struct StatusModels { | ||
yukarin_s: BTreeMap<usize, Session<'static>>, | ||
yukarin_sa: BTreeMap<usize, Session<'static>>, | ||
decode: BTreeMap<usize, Session<'static>>, | ||
} | ||
|
||
#[derive(new, Getters)] | ||
struct SessionOptions { | ||
cpu_num_threads: usize, | ||
use_gpu: bool, | ||
} | ||
|
||
struct Model { | ||
yukarin_s_model: &'static [u8], | ||
yukarin_sa_model: &'static [u8], | ||
decode_model: &'static [u8], | ||
} | ||
|
||
static ENVIRONMENT: Lazy<Environment> = Lazy::new(|| { | ||
cfg_if! { | ||
if #[cfg(debug_assertions)]{ | ||
const LOGGING_LEVEL: LoggingLevel = LoggingLevel::Verbose; | ||
} else{ | ||
const LOGGING_LEVEL: LoggingLevel = LoggingLevel::Warning; | ||
} | ||
} | ||
Environment::builder() | ||
.with_name(env!("CARGO_PKG_NAME")) | ||
.with_log_level(LOGGING_LEVEL) | ||
.build() | ||
.unwrap() | ||
}); | ||
|
||
#[derive(Getters)] | ||
pub struct SupportedDevices { | ||
// TODO:supported_devices関数を実装したらこのattributeをはずす | ||
#[allow(dead_code)] | ||
cpu: bool, | ||
// TODO:supported_devices関数を実装したらこのattributeをはずす | ||
#[allow(dead_code)] | ||
cuda: bool, | ||
// TODO:supported_devices関数を実装したらこのattributeをはずす | ||
#[allow(dead_code)] | ||
dml: bool, | ||
} | ||
|
||
impl SupportedDevices { | ||
pub fn get_supported_devices() -> Result<Self> { | ||
let mut cuda_support = false; | ||
let mut dml_support = false; | ||
for provider in onnxruntime::session::get_available_providers() | ||
.map_err(|e| Error::GetSupportedDevices(e.into()))? | ||
.iter() | ||
{ | ||
match provider.as_str() { | ||
"CUDAExecutionProvider" => cuda_support = true, | ||
"DmlExecutionProvider" => dml_support = true, | ||
_ => {} | ||
} | ||
} | ||
|
||
Ok(SupportedDevices { | ||
cpu: true, | ||
cuda: cuda_support, | ||
dml: dml_support, | ||
}) | ||
} | ||
} | ||
|
||
unsafe impl Send for Status {} | ||
unsafe impl Sync for Status {} | ||
|
||
impl Status { | ||
const YUKARIN_S_MODEL: &'static [u8] = include_bytes!(concat!( | ||
env!("CARGO_WORKSPACE_DIR"), | ||
"/model/yukarin_s.onnx" | ||
)); | ||
const YUKARIN_SA_MODEL: &'static [u8] = include_bytes!(concat!( | ||
env!("CARGO_WORKSPACE_DIR"), | ||
"/model/yukarin_sa.onnx" | ||
)); | ||
|
||
const DECODE_MODEL: &'static [u8] = | ||
include_bytes!(concat!(env!("CARGO_WORKSPACE_DIR"), "/model/decode.onnx")); | ||
|
||
const MODELS: [Model; 1] = [Model { | ||
yukarin_s_model: Self::YUKARIN_S_MODEL, | ||
yukarin_sa_model: Self::YUKARIN_SA_MODEL, | ||
decode_model: Self::DECODE_MODEL, | ||
}]; | ||
pub const MODELS_COUNT: usize = Self::MODELS.len(); | ||
|
||
pub fn new(use_gpu: bool, cpu_num_threads: usize) -> Self { | ||
Self { | ||
models: StatusModels { | ||
yukarin_s: BTreeMap::new(), | ||
yukarin_sa: BTreeMap::new(), | ||
decode: BTreeMap::new(), | ||
}, | ||
session_options: SessionOptions::new(cpu_num_threads, use_gpu), | ||
} | ||
} | ||
|
||
pub fn load_model(&mut self, model_index: usize) -> Result<()> { | ||
let model = &Self::MODELS[model_index]; | ||
let yukarin_s_session = self | ||
.new_session(model.yukarin_s_model) | ||
.map_err(Error::LoadModel)?; | ||
let yukarin_sa_session = self | ||
.new_session(model.yukarin_sa_model) | ||
.map_err(Error::LoadModel)?; | ||
let decode_model = self | ||
.new_session(model.decode_model) | ||
.map_err(Error::LoadModel)?; | ||
|
||
self.models.yukarin_s.insert(model_index, yukarin_s_session); | ||
self.models | ||
.yukarin_sa | ||
.insert(model_index, yukarin_sa_session); | ||
|
||
self.models.decode.insert(model_index, decode_model); | ||
|
||
Ok(()) | ||
} | ||
|
||
fn new_session<B: AsRef<[u8]>>( | ||
&self, | ||
model_bytes: B, | ||
) -> std::result::Result<Session<'static>, anyhow::Error> { | ||
let session_builder = ENVIRONMENT | ||
.new_session_builder()? | ||
.with_optimization_level(GraphOptimizationLevel::Basic)? | ||
.with_intra_op_num_threads(*self.session_options.cpu_num_threads() as i32)? | ||
.with_inter_op_num_threads(*self.session_options.cpu_num_threads() as i32)?; | ||
|
||
let session_builder = if *self.session_options.use_gpu() { | ||
cfg_if! { | ||
if #[cfg(feature = "directml")]{ | ||
session_builder | ||
.with_disable_mem_pattern()? | ||
.with_execution_mode(onnxruntime::ExecutionMode::ORT_SEQUENTIAL)? | ||
} else { | ||
let options = CudaProviderOptions::default(); | ||
session_builder | ||
.with_disable_mem_pattern()? | ||
.with_append_execution_provider_cuda(options)? | ||
} | ||
} | ||
} else { | ||
session_builder | ||
}; | ||
|
||
Ok(session_builder.with_model_from_memory(model_bytes)?) | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
|
||
#[rstest] | ||
fn supported_devices_get_supported_devices_works() { | ||
let result = SupportedDevices::get_supported_devices(); | ||
// 環境によって結果が変わるので、関数呼び出しが成功するかどうかの確認のみ行う | ||
assert!(result.is_ok()); | ||
} | ||
} |