Skip to content

Commit

Permalink
[Rust]onnxruntimeをrust版voicevox_coreに導入 (#135)
Browse files Browse the repository at this point in the history
* Statusを仮実装した

* get_supported_devices を仮実装した

* エラーの名前とsourceの修正

* initialize関数を仮実装した

* get_supported_devicesが動作するかどうか確認するためのテストを追加

* lint errorを修正した

* version up onnxruntime-rs to 0.0.17

* TODOコメント追加

* Update crates/voicevox_core/Cargo.toml

Co-authored-by: Hiroshiba <[email protected]>

Co-authored-by: Hiroshiba <[email protected]>
  • Loading branch information
qwerty2501 and Hiroshiba authored May 25, 2022
1 parent 04b91c7 commit 2a8b209
Show file tree
Hide file tree
Showing 6 changed files with 284 additions and 6 deletions.
7 changes: 7 additions & 0 deletions crates/voicevox_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,21 @@ name = "voicevox_core"
version = "0.1.0"
edition = "2021"

[features]
default = []
directml = []

[lib]
name = "core"
crate-type = ["cdylib"]

[dependencies]
anyhow = "1.0.57"
cfg-if = "1.0.0"
derive-getters = "0.2.0"
derive-new = "0.5.9"
once_cell = "1.10.0"
onnxruntime = { git = "https://github.com/qwerty2501/onnxruntime-rs.git", version = "0.0.17" }
thiserror = "1.0.31"

[dev-dependencies]
Expand Down
24 changes: 23 additions & 1 deletion crates/voicevox_core/src/c_export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@ use std::sync::Mutex;
* これはC文脈の処理と実装をわけるためと、内部実装の変更がAPIに影響を与えにくくするためである
*/

/// Result codes returned across the C API boundary.
///
/// `#[repr(i32)]` pins the discriminant type so the values line up exactly
/// with the corresponding C enum. (A stray `#[repr(C)]` from the previous
/// revision was left alongside it; the two representation hints conflict,
/// so only `#[repr(i32)]` is kept.)
#[repr(i32)]
#[derive(Debug, PartialEq)]
#[allow(non_camel_case_types)]
pub enum VoicevoxResultCode {
    // Written in SCREAMING_CASE to match the C enum definition.
    // Changing the output format would allow Rust-style UpperCamelCase,
    // but the names are kept as-is to minimize the diff against the
    // actually generated C code.
    VOICEVOX_RESULT_SUCCEED = 0,
    VOICEVOX_RESULT_NOT_LOADED_OPENJTALK_DICT = 1,
    VOICEVOX_RESULT_FAILED_LOAD_MODEL = 2,
    VOICEVOX_RESULT_FAILED_GET_SUPPORTED_DEVICES = 3,
    VOICEVOX_RESULT_CANT_GPU_SUPPORT = 4,
}

fn convert_result<T>(result: Result<T>) -> (Option<T>, VoicevoxResultCode) {
Expand All @@ -31,6 +34,16 @@ fn convert_result<T>(result: Result<T>) -> (Option<T>, VoicevoxResultCode) {
None,
VoicevoxResultCode::VOICEVOX_RESULT_NOT_LOADED_OPENJTALK_DICT,
),
Error::CantGpuSupport => {
(None, VoicevoxResultCode::VOICEVOX_RESULT_CANT_GPU_SUPPORT)
}
Error::LoadModel(_) => {
(None, VoicevoxResultCode::VOICEVOX_RESULT_FAILED_LOAD_MODEL)
}
Error::GetSupportedDevices(_) => (
None,
VoicevoxResultCode::VOICEVOX_RESULT_FAILED_GET_SUPPORTED_DEVICES,
),
}
}
}
Expand Down Expand Up @@ -219,6 +232,7 @@ pub extern "C" fn voicevox_error_result_to_message(
#[cfg(test)]
mod tests {
use super::*;
use anyhow::anyhow;
use pretty_assertions::assert_eq;

#[rstest]
Expand All @@ -227,6 +241,14 @@ mod tests {
Err(Error::NotLoadedOpenjtalkDict),
VoicevoxResultCode::VOICEVOX_RESULT_NOT_LOADED_OPENJTALK_DICT
)]
#[case(
Err(Error::LoadModel(anyhow!("some load model error"))),
VoicevoxResultCode::VOICEVOX_RESULT_FAILED_LOAD_MODEL
)]
#[case(
Err(Error::GetSupportedDevices(anyhow!("some get supported devices error"))),
VoicevoxResultCode::VOICEVOX_RESULT_FAILED_GET_SUPPORTED_DEVICES
)]
fn convert_result_works(#[case] result: Result<()>, #[case] expected: VoicevoxResultCode) {
let (_, actual) = convert_result(result);
assert_eq!(expected, actual);
Expand Down
12 changes: 12 additions & 0 deletions crates/voicevox_core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,18 @@ pub enum Error {
// TODO:仮実装がlinterエラーにならないようにするための属性なのでこのenumが正式に使われる際にallow(dead_code)を取り除くこと
#[allow(dead_code)]
NotLoadedOpenjtalkDict,

#[error("{}", base_error_message(VOICEVOX_RESULT_CANT_GPU_SUPPORT))]
CantGpuSupport,

#[error("{},{0}", base_error_message(VOICEVOX_RESULT_FAILED_LOAD_MODEL))]
LoadModel(#[source] anyhow::Error),

#[error(
"{},{0}",
base_error_message(VOICEVOX_RESULT_FAILED_GET_SUPPORTED_DEVICES)
)]
GetSupportedDevices(#[source] anyhow::Error),
}

fn base_error_message(result_code: VoicevoxResultCode) -> &'static str {
Expand Down
58 changes: 53 additions & 5 deletions crates/voicevox_core/src/internal.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,51 @@
use super::*;
use c_export::VoicevoxResultCode;
use once_cell::sync::Lazy;
use std::ffi::CStr;
use std::os::raw::c_int;
use std::sync::Mutex;

use status::*;

static INITIALIZED: Lazy<Mutex<bool>> = Lazy::new(|| Mutex::new(false));
static STATUS: Lazy<Mutex<Option<Status>>> = Lazy::new(|| Mutex::new(None));

//TODO:仮実装がlinterエラーにならないようにするための属性なのでこの関数を正式に実装する際にallow(unused_variables)を取り除くこと
#[allow(unused_variables)]
pub fn initialize(use_gpu: bool, cpu_num_threads: usize, load_all_models: bool) -> Result<()> {
unimplemented!()
let mut initialized = INITIALIZED.lock().unwrap();
*initialized = false;
if !use_gpu || can_support_gpu_feature()? {
let mut status_opt = STATUS.lock().unwrap();
let mut status = Status::new(use_gpu, cpu_num_threads);

// TODO: ここに status.load_metas() を呼び出すようにする
// https://github.com/VOICEVOX/voicevox_core/blob/main/core/src/core.cpp#L199-L201

if load_all_models {
for model_index in 0..Status::MODELS_COUNT {
status.load_model(model_index)?;
}
// TODO: ここにGPUメモリを確保させる処理を実装する
// https://github.com/VOICEVOX/voicevox_core/blob/main/core/src/core.cpp#L210-L219
}

*status_opt = Some(status);
*initialized = true;
Ok(())
} else {
Err(Error::CantGpuSupport)
}
}

/// Reports whether the GPU execution provider required by this build is
/// available: DirectML when the `directml` feature is enabled, CUDA
/// otherwise.
///
/// # Errors
/// Propagates `Error::GetSupportedDevices` if the provider query fails.
fn can_support_gpu_feature() -> Result<bool> {
    let supported_devices = SupportedDevices::get_supported_devices()?;

    // Which flag matters depends on the compile-time feature selection.
    cfg_if! {
        if #[cfg(feature = "directml")]{
            Ok(*supported_devices.dml())
        } else{
            Ok(*supported_devices.cuda())
        }
    }
}

//TODO:仮実装がlinterエラーにならないようにするための属性なのでこの関数を正式に実装する際にallow(unused_variables)を取り除くこと
Expand Down Expand Up @@ -110,11 +149,20 @@ pub fn voicevox_wav_free(wav: *mut u8) -> Result<()> {

pub const fn voicevox_error_result_to_message(result_code: VoicevoxResultCode) -> &'static str {
// C APIのため、messageには必ず末尾にNULL文字を追加する
use VoicevoxResultCode::*;
match result_code {
VoicevoxResultCode::VOICEVOX_RESULT_NOT_LOADED_OPENJTALK_DICT => {
VOICEVOX_RESULT_NOT_LOADED_OPENJTALK_DICT => {
"voicevox_load_openjtalk_dict() を初めに呼んでください\0"
}
VOICEVOX_RESULT_FAILED_LOAD_MODEL => {
"modelデータ読み込み中にOnnxruntimeエラーが発生しました\0"
}

VOICEVOX_RESULT_CANT_GPU_SUPPORT => "GPU機能をサポートすることができません\0",
VOICEVOX_RESULT_FAILED_GET_SUPPORTED_DEVICES => {
"サポートされているデバイス情報取得中にエラーが発生しました\0"
}

VoicevoxResultCode::VOICEVOX_RESULT_SUCCEED => "エラーが発生しませんでした\0",
VOICEVOX_RESULT_SUCCEED => "エラーが発生しませんでした\0",
}
}
5 changes: 5 additions & 0 deletions crates/voicevox_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@ mod c_export;
mod error;
mod internal;
mod result;
mod status;

use error::*;
use result::*;

use derive_getters::*;
use derive_new::new;
#[cfg(test)]
use rstest::*;

use cfg_if::cfg_if;
184 changes: 184 additions & 0 deletions crates/voicevox_core/src/status.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
use super::*;
use once_cell::sync::Lazy;
use onnxruntime::{
environment::Environment, session::Session, GraphOptimizationLevel, LoggingLevel,
};
cfg_if! {
if #[cfg(not(feature="directml"))]{
use onnxruntime::CudaProviderOptions;
}
}
use std::collections::BTreeMap;

/// Holds the loaded ONNX Runtime sessions together with the options they
/// are created with. The single instance lives behind the global
/// `Mutex<Option<Status>>` in `internal.rs`.
pub struct Status {
    models: StatusModels,
    session_options: SessionOptions,
}

/// Loaded sessions keyed by model index (the index into `Status::MODELS`),
/// one map per model role.
struct StatusModels {
    yukarin_s: BTreeMap<usize, Session<'static>>,
    yukarin_sa: BTreeMap<usize, Session<'static>>,
    decode: BTreeMap<usize, Session<'static>>,
}

/// Options applied to every session built by `Status::new_session`.
#[derive(new, Getters)]
struct SessionOptions {
    // Used for both the intra-op and inter-op thread counts.
    cpu_num_threads: usize,
    // When true, a GPU execution path (DirectML or CUDA depending on the
    // `directml` feature) is configured for the session.
    use_gpu: bool,
}

/// The set of embedded ONNX model binaries that make up one voice model.
struct Model {
    yukarin_s_model: &'static [u8],
    yukarin_sa_model: &'static [u8],
    decode_model: &'static [u8],
}

static ENVIRONMENT: Lazy<Environment> = Lazy::new(|| {
cfg_if! {
if #[cfg(debug_assertions)]{
const LOGGING_LEVEL: LoggingLevel = LoggingLevel::Verbose;
} else{
const LOGGING_LEVEL: LoggingLevel = LoggingLevel::Warning;
}
}
Environment::builder()
.with_name(env!("CARGO_PKG_NAME"))
.with_log_level(LOGGING_LEVEL)
.build()
.unwrap()
});

/// Flags describing which execution devices this build/runtime can use.
#[derive(Getters)]
pub struct SupportedDevices {
    // TODO: remove this attribute once the supported_devices function is implemented
    #[allow(dead_code)]
    cpu: bool,
    // TODO: remove this attribute once the supported_devices function is implemented
    #[allow(dead_code)]
    cuda: bool,
    // TODO: remove this attribute once the supported_devices function is implemented
    #[allow(dead_code)]
    dml: bool,
}

impl SupportedDevices {
pub fn get_supported_devices() -> Result<Self> {
let mut cuda_support = false;
let mut dml_support = false;
for provider in onnxruntime::session::get_available_providers()
.map_err(|e| Error::GetSupportedDevices(e.into()))?
.iter()
{
match provider.as_str() {
"CUDAExecutionProvider" => cuda_support = true,
"DmlExecutionProvider" => dml_support = true,
_ => {}
}
}

Ok(SupportedDevices {
cpu: true,
cuda: cuda_support,
dml: dml_support,
})
}
}

// SAFETY(review): presumably sound because the only `Status` instance is
// accessed behind the `Mutex<Option<Status>>` in internal.rs, but whether
// `Session<'static>` itself may be moved/shared across threads is not
// visible from this file — TODO confirm against onnxruntime-rs before
// relying on these impls.
unsafe impl Send for Status {}
unsafe impl Sync for Status {}

impl Status {
    // Model binaries are embedded into the library at compile time from
    // the workspace-level `model/` directory.
    const YUKARIN_S_MODEL: &'static [u8] = include_bytes!(concat!(
        env!("CARGO_WORKSPACE_DIR"),
        "/model/yukarin_s.onnx"
    ));
    const YUKARIN_SA_MODEL: &'static [u8] = include_bytes!(concat!(
        env!("CARGO_WORKSPACE_DIR"),
        "/model/yukarin_sa.onnx"
    ));

    const DECODE_MODEL: &'static [u8] =
        include_bytes!(concat!(env!("CARGO_WORKSPACE_DIR"), "/model/decode.onnx"));

    // Currently a single voice model; `load_model` indexes into this table.
    const MODELS: [Model; 1] = [Model {
        yukarin_s_model: Self::YUKARIN_S_MODEL,
        yukarin_sa_model: Self::YUKARIN_SA_MODEL,
        decode_model: Self::DECODE_MODEL,
    }];
    pub const MODELS_COUNT: usize = Self::MODELS.len();

    /// Creates an empty `Status` with the given session options;
    /// no sessions are loaded until `load_model` is called.
    pub fn new(use_gpu: bool, cpu_num_threads: usize) -> Self {
        Self {
            models: StatusModels {
                yukarin_s: BTreeMap::new(),
                yukarin_sa: BTreeMap::new(),
                decode: BTreeMap::new(),
            },
            session_options: SessionOptions::new(cpu_num_threads, use_gpu),
        }
    }

    /// Builds the three sessions for `MODELS[model_index]` and stores
    /// them under that index, replacing any previously loaded entries.
    ///
    /// # Panics
    /// Panics if `model_index >= Self::MODELS_COUNT` (direct indexing).
    ///
    /// # Errors
    /// Returns `Error::LoadModel` if any session fails to build.
    pub fn load_model(&mut self, model_index: usize) -> Result<()> {
        let model = &Self::MODELS[model_index];
        let yukarin_s_session = self
            .new_session(model.yukarin_s_model)
            .map_err(Error::LoadModel)?;
        let yukarin_sa_session = self
            .new_session(model.yukarin_sa_model)
            .map_err(Error::LoadModel)?;
        let decode_model = self
            .new_session(model.decode_model)
            .map_err(Error::LoadModel)?;

        self.models.yukarin_s.insert(model_index, yukarin_s_session);
        self.models
            .yukarin_sa
            .insert(model_index, yukarin_sa_session);

        self.models.decode.insert(model_index, decode_model);

        Ok(())
    }

    /// Builds one session from in-memory model bytes, applying the stored
    /// thread counts; when `use_gpu` is set, configures the GPU path
    /// selected by the `directml` feature.
    ///
    /// NOTE(review): in the `directml` branch only mem-pattern/execution
    /// mode are set — no explicit DirectML provider append is visible
    /// here; confirm onnxruntime-rs supplies it elsewhere.
    fn new_session<B: AsRef<[u8]>>(
        &self,
        model_bytes: B,
    ) -> std::result::Result<Session<'static>, anyhow::Error> {
        let session_builder = ENVIRONMENT
            .new_session_builder()?
            .with_optimization_level(GraphOptimizationLevel::Basic)?
            .with_intra_op_num_threads(*self.session_options.cpu_num_threads() as i32)?
            .with_inter_op_num_threads(*self.session_options.cpu_num_threads() as i32)?;

        let session_builder = if *self.session_options.use_gpu() {
            cfg_if! {
                if #[cfg(feature = "directml")]{
                    session_builder
                        .with_disable_mem_pattern()?
                        .with_execution_mode(onnxruntime::ExecutionMode::ORT_SEQUENTIAL)?
                } else {
                    let options = CudaProviderOptions::default();
                    session_builder
                        .with_disable_mem_pattern()?
                        .with_append_execution_provider_cuda(options)?
                }
            }
        } else {
            session_builder
        };

        Ok(session_builder.with_model_from_memory(model_bytes)?)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[rstest]
    fn supported_devices_get_supported_devices_works() {
        // The concrete device flags differ per environment, so only the
        // success of the provider query itself is asserted.
        assert!(SupportedDevices::get_supported_devices().is_ok());
    }
}

0 comments on commit 2a8b209

Please sign in to comment.