Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: model config, with custom remote tokenizer #35

Merged
merged 2 commits into from
Jun 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 2 additions & 56 deletions apps/desktop/src-tauri/src/inference_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,18 @@ use actix_web::web::{Bytes, Json};

use actix_web::{get, post, App, HttpResponse, HttpServer, Responder};
use parking_lot::{Mutex, RwLock};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use tauri::AppHandle;
use serde::Serialize;

use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};

use crate::abort_stream::AbortStream;
use crate::config::ConfigKey;
use crate::inference::thread::{
start_inference, CompletionRequest, InferenceThreadRequest,
};
use crate::model_pool::{self, spawn_pool};
use crate::model_stats;
use crate::path::get_app_dir_path_buf;
use llm::VocabularySource;
use crate::model_pool;

#[derive(Default)]
pub struct State {
Expand Down Expand Up @@ -152,51 +146,3 @@ pub async fn stop_server<'a>(

Ok(())
}

/// Vocabulary (tokenizer) selection for a model: a local tokenizer file,
/// a remote HuggingFace repository, or the vocabulary embedded in the
/// model file itself.
#[derive(Serialize, Deserialize)]
pub struct ModelVocabulary {
    /// Local path to vocabulary
    pub vocabulary_path: Option<PathBuf>,

    /// Remote HuggingFace repository containing vocabulary
    pub vocabulary_repository: Option<String>,
}
impl ModelVocabulary {
    /// Resolve the configured fields into an `llm::VocabularySource`.
    ///
    /// Exactly one of `vocabulary_path` / `vocabulary_repository` must be
    /// set to select a file or remote source; any other combination (both
    /// set, or neither) falls back to `VocabularySource::Model`.
    pub fn to_source(&self) -> VocabularySource {
        match (&self.vocabulary_path, &self.vocabulary_repository) {
            (Some(path), None) => {
                VocabularySource::HuggingFaceTokenizerFile(path.to_owned())
            }
            (None, Some(repo)) => {
                VocabularySource::HuggingFaceRemote(repo.to_owned())
            }
            (_, _) => VocabularySource::Model,
        }
    }
}

/// Tauri command: record a model load and spawn the inference worker pool
/// for the model file at `path`.
///
/// Side effects, in order: marks onboarding as done in the config bucket,
/// increments the per-model load counter, then spawns `concurrency`
/// workers via `spawn_pool` using the app's `inference_cache` directory.
#[tauri::command]
pub async fn load_model<'a>(
    model_stats_bucket_state: tauri::State<'_, model_stats::State>,
    config_state: tauri::State<'_, crate::config::State>,
    app_handle: AppHandle,
    path: &str,
    model_type: &str,
    model_vocabulary: ModelVocabulary,
    concurrency: usize,
) -> Result<(), String> {
    // Loading any model completes onboarding.
    config_state.set(ConfigKey::OnboardState, format!("done"))?;
    model_stats::increment_load_count(model_stats_bucket_state, path)?;

    let cache_dir =
        get_app_dir_path_buf(app_handle, String::from("inference_cache"))?;

    spawn_pool(
        path,
        model_type,
        &model_vocabulary.to_source(),
        concurrency,
        &cache_dir,
    )
    .await
}
2 changes: 1 addition & 1 deletion apps/desktop/src-tauri/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ fn main() {
model_config::set_model_config,
inference_server::start_server,
inference_server::stop_server,
inference_server::load_model,
model_pool::load_model,
model_type::get_model_type,
model_type::set_model_type,
test::test_model,
Expand Down
53 changes: 32 additions & 21 deletions apps/desktop/src-tauri/src/model_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ use std::sync::Arc;

use tauri::Manager;

#[derive(Serialize, Deserialize, PartialEq, Clone)]
#[derive(Serialize, Deserialize, PartialEq, Clone, Default)]
pub struct ModelConfig {
pub vocabulary: String,
pub tokenizer: String,

#[serde(rename = "defaultPromptTemplate")]
pub default_prompt_template: String,
}

#[derive(Clone)]
Expand All @@ -26,37 +29,45 @@ impl State {
app.manage(State(Arc::new(Mutex::new(bucket))));
Ok(())
}

/// Look up the stored `ModelConfig` for the model at `path`.
///
/// Returns `ModelConfig::default()` when no config has been saved yet;
/// only a storage-layer failure produces an `Err`.
pub fn get(&self, path: &str) -> Result<ModelConfig, String> {
    let file_path = String::from(path);
    let bucket = self.0.lock();

    match bucket.get(&file_path) {
        // No `return` needed — the match is the tail expression.
        Ok(Some(value)) => Ok(value.0),
        Ok(None) => Ok(ModelConfig::default()),
        // Fixed copy-paste in the message: this bucket holds the model
        // *config*, not the model type.
        Err(e) => Err(format!("Error retrieving model config for {}: {}", path, e)),
    }
}

/// Persist `data` as the model config for `path`, flushing the bucket so
/// the value survives process exit.
pub fn set(&self, path: &str, data: ModelConfig) -> Result<(), String> {
    let key = String::from(path);
    let bucket = self.0.lock();

    // Write the entry first, then flush it to disk.
    if let Err(e) = bucket.set(&key, &Json(data)) {
        return Err(format!("{}", e));
    }
    bucket.flush().map_err(|e| format!("{}", e))?;

    Ok(())
}
}

#[tauri::command]
pub fn get_model_config(
state: tauri::State<'_, State>,
path: &str,
) -> Result<ModelConfig, String> {
let file_path = String::from(path);
let bucket = state.0.lock();

match bucket.get(&file_path) {
Ok(Some(value)) => return Ok(value.0),
Ok(None) => Err(format!("No model config for {}", path)),
Err(e) => Err(format!("Error retrieving model type for {}: {}", path, e)),
}
state.get(path)
}

#[tauri::command]
pub async fn set_model_config(
state: tauri::State<'_, State>,
path: &str,
data: ModelConfig,
config: ModelConfig,
) -> Result<(), String> {
let file_path = String::from(path);
let bucket = state.0.lock();

bucket
.set(&file_path, &Json(data))
.map_err(|e| format!("{}", e))?;

bucket.flush().map_err(|e| format!("{}", e))?;

Ok(())
state.set(path, config)
}
50 changes: 49 additions & 1 deletion apps/desktop/src-tauri/src/model_pool.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
use once_cell::sync::Lazy;
use parking_lot::Mutex;
use tauri::AppHandle;

use std::{fs, path::PathBuf, sync::Arc};

use llm::{load_progress_callback_stdout, ModelArchitecture, VocabularySource};

use std::path::Path;

use crate::inference::thread::ModelGuard;
use crate::{inference::thread::ModelGuard, model_stats};
use std::collections::VecDeque;

use crate::config::ConfigKey;
use crate::path::get_app_dir_path_buf;

pub static LOADED_MODEL_POOL: Lazy<Mutex<VecDeque<Option<ModelGuard>>>> =
Lazy::new(|| Mutex::new(VecDeque::new()));

Expand Down Expand Up @@ -114,3 +118,47 @@ pub async fn spawn_pool(

Ok(())
}

/// Map the user-supplied tokenizer string to an `llm::VocabularySource`:
/// - absolute path to an existing regular file → load that tokenizer file;
/// - any other non-empty string → treat it as a HuggingFace repo id;
/// - empty string → use the vocabulary embedded in the model itself.
fn get_vocab_source(vocab: String) -> VocabularySource {
    let path = Path::new(&vocab);
    // `is_file()` already implies the path exists, so the previous
    // separate `exists()` check was redundant.
    if path.is_absolute() && path.is_file() {
        VocabularySource::HuggingFaceTokenizerFile(PathBuf::from(vocab))
    } else if !vocab.is_empty() {
        VocabularySource::HuggingFaceRemote(vocab)
    } else {
        VocabularySource::Model
    }
}

/// Tauri command (replaces the old `inference_server::load_model`):
/// gathers the model's type and config from their persisted buckets,
/// records the load, and spawns the inference pool.
///
/// Side effects, in order: marks onboarding as done, increments the
/// per-model load counter, then spawns `concurrency` workers.
#[tauri::command]
pub async fn load_model<'a>(
    model_stats_bucket_state: tauri::State<'_, model_stats::State>,
    config_state: tauri::State<'_, crate::config::State>,
    model_type_bucket_state: tauri::State<'_, crate::model_type::State>,
    model_config_bucket_state: tauri::State<'_, crate::model_config::State>,
    app_handle: AppHandle,
    path: &str,
    concurrency: usize,
) -> Result<(), String> {
    config_state.set(ConfigKey::OnboardState, format!("done"))?;
    model_stats_bucket_state.increment_load_count(path)?;

    let cache_dir =
        get_app_dir_path_buf(app_handle, String::from("inference_cache"))?;

    // Both getters default when nothing is stored for this path:
    // model type falls back to "llama", config to `ModelConfig::default()`.
    let model_type = model_type_bucket_state.get(path)?;

    let model_config = model_config_bucket_state.get(path)?;

    spawn_pool(
        path,
        model_type.as_str(),
        &get_vocab_source(model_config.tokenizer),
        concurrency,
        &cache_dir,
    )
    .await
}
61 changes: 30 additions & 31 deletions apps/desktop/src-tauri/src/model_stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,46 +27,45 @@ impl State {
app.manage(State(Arc::new(Mutex::new(bucket))));
Ok(())
}
}

pub fn increment_load_count(
state: tauri::State<'_, State>,
path: &str,
) -> Result<(), String> {
let current_value = get_model_stats(state.clone(), path)?;
/// Bump the persisted load counter for the model at `path` by one.
pub fn increment_load_count(&self, path: &str) -> Result<(), String> {
    // Read current stats before taking the lock ourselves — `get`
    // acquires the same mutex internally.
    let stats = self.get(path)?;
    let next = ModelStats {
        load_count: stats.load_count + 1,
    };

    let key = String::from(path);
    let bucket = self.0.lock();
    bucket
        .set(&key, &Json(next))
        .map_err(|e| format!("{}", e))?;
    bucket.flush().map_err(|e| format!("{}", e))?;
    Ok(())
}

let bucket = state.0.lock();
let file_path = String::from(path);
pub fn get(&self, path: &str) -> Result<ModelStats, String> {
let bucket = self.0.lock();

bucket
.set(
&file_path,
&Json(ModelStats {
load_count: current_value.load_count + 1,
}),
)
.map_err(|e| format!("{}", e))?;
let file_path = String::from(path);

bucket.flush().map_err(|e| format!("{}", e))?;
Ok(())
match bucket.get(&file_path) {
Ok(Some(value)) => Ok(value.0),
Ok(None) => Ok(Default::default()),
Err(e) => {
println!("Error retrieving model stats for {}: {}", path, e);
return Ok(Default::default());
}
}
}
}

#[tauri::command]
pub fn get_model_stats(
state: tauri::State<'_, State>,
path: &str,
) -> Result<ModelStats, String> {
let bucket = state.0.lock();

let file_path = String::from(path);

match bucket.get(&file_path) {
Ok(Some(value)) => Ok(value.0),
Ok(None) => Ok(Default::default()),
Err(e) => {
println!("Error retrieving model stats for {}: {}", path, e);
bucket.clear().unwrap();
return Ok(Default::default());
}
}
state.get(path)
}
47 changes: 27 additions & 20 deletions apps/desktop/src-tauri/src/model_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,39 @@ impl State {
app.manage(State(Arc::new(Mutex::new(bucket))));
Ok(())
}

/// Fetch the stored model type for `path`; a missing entry or a storage
/// error both silently fall back to "llama".
pub fn get(&self, path: &str) -> Result<String, String> {
    let key = String::from(path);
    let bucket = self.0.lock();

    // Result<Option<_>, _> -> Option<_>: errors collapse into the default.
    let stored = bucket.get(&key).ok().flatten();
    Ok(stored.unwrap_or_else(|| "llama".to_string()))
}

/// Store `model_type` for the model at `path` and flush the bucket so the
/// value is persisted immediately.
pub fn set(&self, path: &str, model_type: &str) -> Result<(), String> {
    let bucket = self.0.lock();
    let key = String::from(path);
    let value = String::from(model_type);

    bucket.set(&key, &value).map_err(|e| format!("{}", e))?;
    bucket.flush().map_err(|e| format!("{}", e))?;

    Ok(())
}
}

#[tauri::command]
pub fn get_model_type(
state: tauri::State<'_, State>,
path: &str,
) -> Result<String, String> {
let model_type_bucket = state.0.lock();

let file_path = String::from(path);

match model_type_bucket.get(&file_path) {
Ok(Some(value)) => return Ok(value),
Ok(None) => Err(format!("No cached model type for {}", path)),
Err(e) => Err(format!("Error retrieving model type for {}: {}", path, e)),
}
state.get(path)
}

#[tauri::command]
Expand All @@ -43,15 +60,5 @@ pub async fn set_model_type(
path: &str,
model_type: &str,
) -> Result<(), String> {
let model_type_bucket = state.0.lock();

let file_path = String::from(path);

model_type_bucket
.set(&file_path, &String::from(model_type))
.map_err(|e| format!("{}", e))?;

model_type_bucket.flush().map_err(|e| format!("{}", e))?;

Ok(())
state.set(path, model_type)
}
Loading