Skip to content

Commit

Permalink
chore: support async embed (#1139)
Browse files Browse the repository at this point in the history
  • Loading branch information
appflowy authored Jan 8, 2025
1 parent 2bd6da2 commit b47a635
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 677 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions libs/indexer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@ redis = { workspace = true, features = [
] }
tokio-util = "0.7.12"
secrecy = { workspace = true, features = ["serde"] }
reqwest.workspace = true
4 changes: 2 additions & 2 deletions libs/indexer/src/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,12 @@ impl IndexerScheduler {
)))
}

pub fn create_search_embeddings(
pub async fn create_search_embeddings(
&self,
request: EmbeddingRequest,
) -> Result<OpenAIEmbeddingResponse, AppError> {
let embedder = self.create_embedder()?;
let embeddings = embedder.embed(request)?;
let embeddings = embedder.async_embed(request).await?;
Ok(embeddings)
}

Expand Down
8 changes: 8 additions & 0 deletions libs/indexer/src/vector/embedder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ impl Embedder {
Self::OpenAI(embedder) => embedder.embed(params),
}
}
/// Asynchronously generates embeddings for `params` by delegating to the
/// wrapped backend's own `async_embed`.
///
/// # Errors
///
/// Returns whatever `AppError` the backend's `async_embed` produces.
pub async fn async_embed(
    &self,
    params: EmbeddingRequest,
) -> Result<OpenAIEmbeddingResponse, AppError> {
    // `OpenAI` is the only variant, so the pattern is irrefutable. Adding a
    // variant later turns this into a compile error, exactly as the
    // equivalent one-arm `match` would.
    let Self::OpenAI(open_ai) = self;
    open_ai.async_embed(params).await
}

pub fn model(&self) -> EmbeddingModel {
EmbeddingModel::TextEmbedding3Small
Expand Down
51 changes: 43 additions & 8 deletions libs/indexer/src/vector/open_ai.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::vector::rest::check_response;
use crate::vector::rest::check_ureq_response;
use anyhow::anyhow;
use app_error::AppError;
use appflowy_ai_client::dto::{EmbeddingRequest, OpenAIEmbeddingResponse};
Expand All @@ -14,32 +14,39 @@ pub const REQUEST_PARALLELISM: usize = 40;
#[derive(Debug, Clone)]
pub struct Embedder {
bearer: String,
client: ureq::Agent,
sync_client: ureq::Agent,
async_client: reqwest::Client,
}

impl Embedder {
pub fn new(api_key: String) -> Self {
let bearer = format!("Bearer {api_key}");
let client = ureq::AgentBuilder::new()
let sync_client = ureq::AgentBuilder::new()
.max_idle_connections(REQUEST_PARALLELISM * 2)
.max_idle_connections_per_host(REQUEST_PARALLELISM * 2)
.build();

Self { bearer, client }
let async_client = reqwest::Client::builder().build().unwrap();

Self {
bearer,
sync_client,
async_client,
}
}

pub fn embed(&self, params: EmbeddingRequest) -> Result<OpenAIEmbeddingResponse, AppError> {
for attempt in 0..3 {
let request = self
.client
.sync_client
.post(OPENAI_EMBEDDINGS_URL)
.set("Authorization", &self.bearer)
.set("Content-Type", "application/json");

let result = check_response(request.send_json(&params));
let result = check_ureq_response(request.send_json(&params));
let retry_duration = match result {
Ok(response) => {
let data = from_response::<OpenAIEmbeddingResponse>(response)?;
let data = from_ureq_response::<OpenAIEmbeddingResponse>(response)?;
return Ok(data);
},
Err(retry) => retry.into_duration(attempt),
Expand All @@ -53,9 +60,24 @@ impl Embedder {
"Failed to generate embeddings after 3 attempts"
)))
}

/// Non-blocking counterpart of [`Embedder::embed`]: POSTs `params` as JSON to
/// `OPENAI_EMBEDDINGS_URL` through the `reqwest` client and decodes the reply
/// with [`from_response`].
///
/// The request is sent exactly once; there is no retry loop here.
///
/// # Errors
///
/// Propagates any error from sending the request, and any error reported by
/// [`from_response`] (non-200 status or a body that fails to deserialize).
pub async fn async_embed(
    &self,
    params: EmbeddingRequest,
) -> Result<OpenAIEmbeddingResponse, AppError> {
    let http_response = self
        .async_client
        .post(OPENAI_EMBEDDINGS_URL)
        .header("Authorization", &self.bearer)
        .header("Content-Type", "application/json")
        .json(&params)
        .send()
        .await?;

    let embeddings = from_response::<OpenAIEmbeddingResponse>(http_response).await?;
    Ok(embeddings)
}
}

pub fn from_response<T>(resp: ureq::Response) -> Result<T, anyhow::Error>
pub fn from_ureq_response<T>(resp: ureq::Response) -> Result<T, anyhow::Error>
where
T: DeserializeOwned,
{
Expand All @@ -69,6 +91,19 @@ where
Ok(resp)
}

pub async fn from_response<T>(resp: reqwest::Response) -> Result<T, anyhow::Error>
where
T: DeserializeOwned,
{
let status_code = resp.status();
if status_code != 200 {
let body = resp.text().await?;
anyhow::bail!("error code: {}, {}", status_code, body)
}

let resp = resp.json().await?;
Ok(resp)
}
/// ## Execution Time Comparison Results
///
/// The following results were observed when running `execution_time_comparison_tests`:
Expand Down
2 changes: 1 addition & 1 deletion libs/indexer/src/vector/rest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ impl Retry {
}

#[allow(clippy::result_large_err)]
pub(crate) fn check_response(
pub(crate) fn check_ureq_response(
response: Result<ureq::Response, ureq::Error>,
) -> Result<ureq::Response, Retry> {
match response {
Expand Down
Loading

0 comments on commit b47a635

Please sign in to comment.