From 22592374c18f2b4d512b0a38729653a5847213e4 Mon Sep 17 00:00:00 2001 From: Eric Date: Sun, 12 Nov 2023 01:50:16 +0800 Subject: [PATCH] refactor: add experimental-http feature (#750) * add experimental-http feature, update code * refactor: add experimental-http feature --- crates/tabby/Cargo.toml | 3 +- crates/tabby/src/serve/engine.rs | 55 ++++++++++++++++---------------- crates/tabby/src/serve/mod.rs | 30 ++++++++++------- 3 files changed, 49 insertions(+), 39 deletions(-) diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml index 0f670dc46439..beee03141eee 100644 --- a/crates/tabby/Cargo.toml +++ b/crates/tabby/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [features] cuda = ["llama-cpp-bindings/cuda"] +experimental-http = ["dep:http-api-bindings"] [dependencies] tabby-common = { path = "../tabby-common" } @@ -36,7 +37,7 @@ tantivy = { workspace = true } anyhow = { workspace = true } sysinfo = "0.29.8" nvml-wrapper = "0.9.0" -http-api-bindings = { path = "../http-api-bindings" } +http-api-bindings = { path = "../http-api-bindings", optional = true } # included when build with `experimental-http` feature async-stream = { workspace = true } axum-streams = { version = "0.9.1", features = ["json"] } minijinja = { version = "1.0.8", features = ["loader"] } diff --git a/crates/tabby/src/serve/engine.rs b/crates/tabby/src/serve/engine.rs index 9dc7a0a300df..8e28d1271006 100644 --- a/crates/tabby/src/serve/engine.rs +++ b/crates/tabby/src/serve/engine.rs @@ -10,39 +10,40 @@ pub async fn create_engine( model_id: &str, args: &crate::serve::ServeArgs, ) -> (Box, EngineInfo) { - if args.device != super::Device::ExperimentalHttp { - if fs::metadata(model_id).is_ok() { - let path = PathBuf::from(model_id); - let model_path = path.join(GGML_MODEL_RELATIVE_PATH); - let engine = create_ggml_engine( - &args.device, - model_path.display().to_string().as_str(), - args.parallelism, - ); - let engine_info = EngineInfo::read(path.join("tabby.json")); - (engine, engine_info) - } else { - let (registry, name) = parse_model_id(model_id); - let registry = ModelRegistry::new(registry).await; - let model_path = registry.get_model_path(name).display().to_string(); - let model_info = registry.get_model_info(name); - let engine = create_ggml_engine(&args.device, &model_path, args.parallelism); - ( - engine, - EngineInfo { - prompt_template: model_info.prompt_template.clone(), - chat_template: model_info.chat_template.clone(), - }, - ) - } - } else { + #[cfg(feature = "experimental-http")] + if args.device == crate::serve::Device::ExperimentalHttp { let (engine, prompt_template) = http_api_bindings::create(model_id); - ( + return ( engine, EngineInfo { prompt_template: Some(prompt_template), chat_template: None, }, + ); + } + + if fs::metadata(model_id).is_ok() { + let path = PathBuf::from(model_id); + let model_path = path.join(GGML_MODEL_RELATIVE_PATH); + let engine = create_ggml_engine( + &args.device, + model_path.display().to_string().as_str(), + args.parallelism, + ); + let engine_info = EngineInfo::read(path.join("tabby.json")); + (engine, engine_info) + } else { + let (registry, name) = parse_model_id(model_id); + let registry = ModelRegistry::new(registry).await; + let model_path = registry.get_model_path(name).display().to_string(); + let model_info = registry.get_model_info(name); + let engine = create_ggml_engine(&args.device, &model_path, args.parallelism); + ( + engine, + EngineInfo { + prompt_template: model_info.prompt_template.clone(), + chat_template: model_info.chat_template.clone(), + }, ) } } diff --git a/crates/tabby/src/serve/mod.rs b/crates/tabby/src/serve/mod.rs index 1d3bcc7a4f5b..7bee9df68538 100644 --- a/crates/tabby/src/serve/mod.rs +++ b/crates/tabby/src/serve/mod.rs @@ -24,7 +24,7 @@ use tabby_common::{ use tabby_download::download_model; use tokio::time::sleep; use tower_http::{cors::CorsLayer, timeout::TimeoutLayer}; -use tracing::{info, warn}; +use tracing::info; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; @@ -86,6 +86,7 @@ pub enum Device { #[strum(serialize = "metal")] Metal, + #[cfg(feature = "experimental-http")] #[strum(serialize = "experimental_http")] ExperimentalHttp, } @@ -131,18 +132,14 @@ pub struct ServeArgs { } pub async fn main(config: &Config, args: &ServeArgs) { - if args.device != Device::ExperimentalHttp { - if fs::metadata(&args.model).is_ok() { - info!("Loading model from local path {}", &args.model); - } else { - download_model(&args.model, true).await; - if let Some(chat_model) = &args.chat_model { - download_model(chat_model, true).await; - } - } + #[cfg(feature = "experimental-http")] + if args.device == Device::ExperimentalHttp { + tracing::warn!("HTTP device is unstable and does not comply with semver expectations."); } else { - warn!("HTTP device is unstable and does not comply with semver expectations.") + load_model(args).await; } + #[cfg(not(feature = "experimental-http"))] + load_model(args).await; info!("Starting server, this might takes a few minutes..."); @@ -172,6 +169,17 @@ pub async fn main(config: &Config, args: &ServeArgs) { .unwrap_or_else(|err| fatal!("Error happens during serving: {}", err)) } +async fn load_model(args: &ServeArgs) { + if fs::metadata(&args.model).is_ok() { + info!("Loading model from local path {}", &args.model); + } else { + download_model(&args.model, true).await; + if let Some(chat_model) = &args.chat_model { + download_model(chat_model, true).await; + } + } +} + async fn api_router(args: &ServeArgs, config: &Config) -> Router { let code = Arc::new(create_code_search()); let completion_state = {