Skip to content

Commit

Permalink
[ENH] Query merging
Browse files Browse the repository at this point in the history
  • Loading branch information
HammadB committed Apr 26, 2024
1 parent 6dbb4ef commit a316bf4
Show file tree
Hide file tree
Showing 10 changed files with 761 additions and 125 deletions.
4 changes: 2 additions & 2 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,9 +684,8 @@ def _query(
for embedding in query_embeddings:
self._validate_dimension(coll, len(embedding), update=False)

metadata_reader = self._manager.get_segment(collection_id, MetadataReader)

if where or where_document:
metadata_reader = self._manager.get_segment(collection_id, MetadataReader)
records = metadata_reader.get_metadata(
where=where, where_document=where_document
)
Expand Down Expand Up @@ -721,6 +720,7 @@ def _query(
all_ids: Set[str] = set()
for id_list in ids:
all_ids.update(id_list)
metadata_reader = self._manager.get_segment(collection_id, MetadataReader)
records = metadata_reader.get_metadata(ids=list(all_ids))
metadata_by_id = {r["id"]: r["metadata"] for r in records}
for id_list in ids:
Expand Down
2 changes: 1 addition & 1 deletion rust/worker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ FROM debian:bookworm-slim as query_service

COPY --from=query_service_builder /chroma/query_service .
COPY --from=query_service_builder /chroma/rust/worker/chroma_config.yaml .
RUN apt-get update && apt-get install -y libssl-dev
RUN apt-get update && apt-get install -y libssl-dev ca-certificates

ENTRYPOINT [ "./query_service" ]

Expand Down
36 changes: 36 additions & 0 deletions rust/worker/src/execution/operators/hnsw_knn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use crate::{
errors::ChromaError, execution::operator::Operator,
segment::distributed_hnsw_segment::DistributedHNSWSegment,
};
use async_trait::async_trait;

#[derive(Debug)]
pub struct HnswKnnOperator {}

#[derive(Debug)]
pub struct HnswKnnOperatorInput {
pub segment: Box<DistributedHNSWSegment>,
pub query: Vec<f32>,
pub k: usize,
}

#[derive(Debug)]
pub struct HnswKnnOperatorOutput {
pub offset_ids: Vec<usize>,
pub distances: Vec<f32>,
}

pub type HnswKnnOperatorResult = Result<HnswKnnOperatorOutput, Box<dyn ChromaError>>;

#[async_trait]
impl Operator<HnswKnnOperatorInput, HnswKnnOperatorOutput> for HnswKnnOperator {
type Error = Box<dyn ChromaError>;

async fn run(&self, input: &HnswKnnOperatorInput) -> HnswKnnOperatorResult {
let (offset_ids, distances) = input.segment.query(&input.query, input.k);
Ok(HnswKnnOperatorOutput {
offset_ids,
distances,
})
}
}
145 changes: 145 additions & 0 deletions rust/worker/src/execution/operators/merge_knn_results.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
use crate::{
blockstore::provider::BlockfileProvider, errors::ChromaError, execution::operator::Operator,
segment::record_segment::RecordSegmentReader, types::Segment,
};
use async_trait::async_trait;
use thiserror::Error;

#[derive(Debug)]
pub struct MergeKnnResultsOperator {}

#[derive(Debug)]
pub struct MergeKnnResultsOperatorInput {
hnsw_result_offset_ids: Vec<usize>,
hnsw_result_distances: Vec<f32>,
brute_force_result_user_ids: Vec<String>,
brute_force_result_distances: Vec<f32>,
k: usize,
record_segment_definition: Segment,
blockfile_provider: BlockfileProvider,
}

impl MergeKnnResultsOperatorInput {
pub fn new(
hnsw_result_offset_ids: Vec<usize>,
hnsw_result_distances: Vec<f32>,
brute_force_result_user_ids: Vec<String>,
brute_force_result_distances: Vec<f32>,
k: usize,
record_segment_definition: Segment,
blockfile_provider: BlockfileProvider,
) -> Self {
Self {
hnsw_result_offset_ids,
hnsw_result_distances,
brute_force_result_user_ids,
brute_force_result_distances,
k,
record_segment_definition,
blockfile_provider: blockfile_provider,
}
}
}

#[derive(Debug)]
pub struct MergeKnnResultsOperatorOutput {
pub user_ids: Vec<String>,
pub distances: Vec<f32>,
}

#[derive(Error, Debug)]
pub enum MergeKnnResultsOperatorError {
#[error("Input lengths do not match k")]
InputLengthMismatchError,
}

impl ChromaError for MergeKnnResultsOperatorError {
fn code(&self) -> crate::errors::ErrorCodes {
match self {
MergeKnnResultsOperatorError::InputLengthMismatchError => {
crate::errors::ErrorCodes::InvalidArgument
}
}
}
}

pub type MergeKnnResultsOperatorResult =
Result<MergeKnnResultsOperatorOutput, Box<dyn ChromaError>>;

#[async_trait]
impl Operator<MergeKnnResultsOperatorInput, MergeKnnResultsOperatorOutput>
for MergeKnnResultsOperator
{
type Error = Box<dyn ChromaError>;

async fn run(&self, input: &MergeKnnResultsOperatorInput) -> MergeKnnResultsOperatorResult {
// All inputs should be of length k
if input.hnsw_result_offset_ids.len() != input.k
|| input.hnsw_result_distances.len() != input.k
|| input.brute_force_result_user_ids.len() != input.k
|| input.brute_force_result_distances.len() != input.k
{
return Err(Box::new(
MergeKnnResultsOperatorError::InputLengthMismatchError,
));
}

// Convert the HNSW result offset IDs to user IDs
let mut hnsw_result_user_ids = Vec::new();

let record_segment_reader = match RecordSegmentReader::from_segment(
&input.record_segment_definition,
&input.blockfile_provider,
)
.await
{
Ok(reader) => reader,
Err(e) => {
println!("Error creating Record Segment Reader: {:?}", e);
return Err(e);
}
};

for offset_id in &input.hnsw_result_offset_ids {
let user_id = record_segment_reader
.get_user_id_for_offset_id(*offset_id as u32)
.await;
match user_id {
Ok(user_id) => {
hnsw_result_user_ids.push(user_id);
}
Err(e) => {
return Err(e);
}
}
}

let mut result_user_ids = Vec::with_capacity(input.k);
let mut result_distances = Vec::with_capacity(input.k);

// Merge the HNSW and brute force results together by the minimum distance top k
let mut hnsw_index = 0;
let mut brute_force_index = 0;

// We know that the input lengths are the same. For now this logic clones, but it could
// be optimized to avoid cloning.
for _ in 0..input.k {
if input.hnsw_result_distances[hnsw_index]
< input.brute_force_result_distances[brute_force_index]
{
result_user_ids.push(hnsw_result_user_ids[hnsw_index].to_string());
result_distances.push(input.hnsw_result_distances[hnsw_index]);
hnsw_index += 1;
} else {
result_user_ids.push(input.brute_force_result_user_ids[brute_force_index].clone());
result_distances.push(input.brute_force_result_distances[brute_force_index]);
brute_force_index += 1;
}
}

Ok(MergeKnnResultsOperatorOutput {
user_ids: result_user_ids,
distances: result_distances,
})
}
}
2 changes: 2 additions & 0 deletions rust/worker/src/execution/operators/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
pub(super) mod brute_force_knn;
pub(super) mod flush_s3;
pub(super) mod hnsw_knn;
pub(super) mod merge_knn_results;
pub(super) mod normalize_vectors;
pub(super) mod partition;
pub(super) mod pull_log;
Expand Down
Loading

0 comments on commit a316bf4

Please sign in to comment.