Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
marsupialtail committed Apr 5, 2024
2 parents 6be6fa0 + 52aacd1 commit cd0c6fb
Showing 1 changed file with 365 additions and 0 deletions.
365 changes: 365 additions & 0 deletions src/vamana/vamana.rs
Original file line number Diff line number Diff line change
Expand Up @@ -576,3 +576,368 @@ pub fn build_index_par<T: Indexable, D: Distance<T>, V: VectorAccessMethod<T>>(
t: std::marker::PhantomData,
}
}


pub struct MergedAccessMethod<T : Indexable, V : VectorAccessMethod<T>>
{
underlying_access_method : (V, V),
t : std::marker::PhantomData<T>,
}

impl<T : Indexable, V : VectorAccessMethod<T>> VectorAccessMethod<T> for MergedAccessMethod<T, V>
{
fn get_vec<'b>(&'b self, ivec : usize) -> &'b [T]
{
let num_points_0 = self.underlying_access_method.0.num_points();
if ivec < num_points_0
{
self.underlying_access_method.0.get_vec(ivec)
}
else
{
self.underlying_access_method.1.get_vec(ivec - num_points_0)
}
}

fn dim(&self) -> usize
{
self.underlying_access_method.0.dim()
}

fn num_points(&self) -> usize
{
self.underlying_access_method.0.num_points() + self.underlying_access_method.1.num_points()
}

fn iter<'b>(&'b self) -> impl Iterator<Item = &'b [T]>
{
self.underlying_access_method.0.iter().chain(self.underlying_access_method.1.iter())
}

#[allow(unreachable_code)]
fn par_iter<'b>(&'b self) -> impl rayon::prelude::IndexedParallelIterator<Item = &'b [T]>
{
self.underlying_access_method.0.par_iter().chain(self.underlying_access_method.1.par_iter())
}
}

unsafe fn search_merge<T : Indexable, D : Distance<T>, V : VectorAccessMethod<T>>(
ctx : &mut BuildContext,
start : usize,
query : usize,
access_method : &V,
edgelist : *mut usize,
locks : &Vec<Mutex<()>>,
max_num_neighbors : usize,
search_frontier_size : usize,
)
{
ctx.reset();
let dim = max_num_neighbors + 1;
let query_vector = access_method.get_vec(query);
let start_vector = access_method.get_vec(start);
let start_distance = D::calculate(query_vector, start_vector);
let mut closest_unvisited_vertex = 0;
ctx.search_ctx.frontier.push((start, start_distance));
while closest_unvisited_vertex < ctx.search_ctx.frontier.len()
{
let closest = ctx.search_ctx.frontier[closest_unvisited_vertex];
let _guard = locks[closest.0].lock().unwrap();
ctx.cached_distances.push(closest);
ctx.search_ctx.visited.visit(closest.0);
let num_neighbors = *edgelist.add(closest.0 * dim);
for ineighbor in 0..num_neighbors
{
let n = *edgelist.add(closest.0 * dim + ineighbor + 1);
let neighbor_vector = access_method.get_vec(n);
let distance = D::calculate(query_vector, neighbor_vector);
ctx.search_ctx.frontier.push((n, distance));
}
ctx.search_ctx.frontier.sort_unstable_by(|x, y| x.1.partial_cmp(&y.1).unwrap());
let new_frontier_size = dedup_frontier(&mut ctx.search_ctx.frontier).min(search_frontier_size);
ctx.search_ctx.frontier.truncate(new_frontier_size);
closest_unvisited_vertex = ctx.search_ctx.frontier.len();
for i in 0..ctx.search_ctx.frontier.len()
{
let v = ctx.search_ctx.frontier[i].0;
if !ctx.search_ctx.visited.is_visited(v)
{
closest_unvisited_vertex = i;
break;
}
}
}
}

unsafe fn prune_merge<T : Indexable, D : Distance<T>, V : VectorAccessMethod<T>>(
ctx : &mut BuildContext,
query : usize,
access_method : &V,
edgelist : *mut usize,
locks : &Vec<Mutex<()>>,
pruning_threshold : f64,
max_num_neighbors : usize,
should_lock : bool,
)
{
let dim = max_num_neighbors + 1;
let query_vector = access_method.get_vec(query);
let _guard = if should_lock { Some(locks[query].lock().unwrap()) } else { None };
let num_neighbors = *edgelist.add(query * dim);
for ineighbor in 0..num_neighbors
{
let n = *edgelist.add(query * dim + ineighbor + 1);
ctx.search_ctx.visited.visit(n);
let neighbor_vector = access_method.get_vec(n);
let distance = D::calculate(neighbor_vector, query_vector);
ctx.cached_distances.push((n, distance));
}
ctx.search_ctx.visited.mark_unvisited(query);
ctx.cached_distances.sort_unstable_by(|x, y| x.1.partial_cmp(&y.1).unwrap());

*edgelist.add(query * dim) = 0;
for ivisited in 0..ctx.cached_distances.len()
{
let (v, d) = ctx.cached_distances[ivisited];
if !ctx.search_ctx.visited.is_visited(v)
{
continue;
}
let new_num_neighbors = *edgelist.add(query * dim) + 1;
*edgelist.add(query * dim) = new_num_neighbors;
*edgelist.add(query * dim + new_num_neighbors) = v;
if new_num_neighbors == max_num_neighbors
{
return;
}
let curr_vec = access_method.get_vec(v);
for ielim in (ivisited + 1)..ctx.cached_distances.len()
{
let elim_id = ctx.cached_distances[ielim].0;
if !ctx.search_ctx.visited.is_visited(elim_id)
{
continue;
}
let elim_vec = access_method.get_vec(elim_id);
let distance = D::calculate(curr_vec, elim_vec);
if pruning_threshold * distance < d
{
ctx.search_ctx.visited.mark_unvisited(elim_id);
}
}
}
}

unsafe fn insert_backwards_edges_merge<T : Indexable, D : Distance<T>, V : VectorAccessMethod<T>>(
ctx : &mut BuildContext,
query : usize,
access_method : &V,
edgelist : *mut usize,
locks : &Vec<Mutex<()>>,
pruning_threshold : f64,
max_num_neighbors : usize,
)
{
let dim = max_num_neighbors + 1;
let (num_neighbors, neighbors) =
{
let _guard = locks[query].lock().unwrap();
let num_neighbors = *edgelist.add(query * dim);
let neighbors = (0..num_neighbors).map(|x|
{
*edgelist.add(query * dim + x + 1)
})
.collect::<Vec<_>>();
(num_neighbors, neighbors)
};

for ineighbor in 0..num_neighbors
{
let neighbor = neighbors[ineighbor];
let _neighbor_guard = locks[neighbor].lock().unwrap();
let num_neighbor_neighbors = *edgelist.add(neighbor * dim);
if num_neighbor_neighbors < max_num_neighbors
{
let new_num_neighbor_neighbors = num_neighbor_neighbors + 1;
*edgelist.add(neighbor * dim) = new_num_neighbor_neighbors;
*edgelist.add(neighbor * dim + new_num_neighbor_neighbors) = query;
}
else
{
let query_vector = access_method.get_vec(query);
let neighbor_vector = access_method.get_vec(neighbor);
let distance = D::calculate(query_vector, neighbor_vector);
ctx.reset();
ctx.search_ctx.visited.visit(query);
ctx.cached_distances.push((query, distance));
for inn in 0..num_neighbor_neighbors
{
let nn = *edgelist.add(neighbor * dim + inn + 1);
let nn_vec = access_method.get_vec(nn);
let d = D::calculate(nn_vec, neighbor_vector);
ctx.search_ctx.visited.visit(nn);
ctx.cached_distances.push((nn, d));
}
prune_merge::<T, D, V>(
ctx,
neighbor,
access_method,
edgelist,
locks,
pruning_threshold,
max_num_neighbors,
false, /* should_lock */
);
}
}
}

unsafe fn insert_single<T : Indexable, D : Distance<T>, V : VectorAccessMethod<T>>(
build_ctx : &mut BuildContext,
start : usize,
to_insert : usize,
access_method : &V,
edgelist : *mut usize,
locks : &Vec<Mutex<()>>,
pruning_threshold : f64,
max_num_neighbors : usize,
search_frontier_size : usize,
)
{
search_merge::<T, D, V>(
build_ctx,
start,
to_insert,
access_method,
edgelist,
locks,
search_frontier_size,
max_num_neighbors,
);
prune_merge::<T, D, V>(
build_ctx,
to_insert,
access_method,
edgelist,
locks,
pruning_threshold,
max_num_neighbors,
true, /* should_lock */
);
insert_backwards_edges_merge::<T, D, V>(
build_ctx,
to_insert,
access_method,
edgelist,
locks,
pruning_threshold,
max_num_neighbors,
);
}

pub fn merge_indexes<T : Indexable, D : Distance<T>, V : VectorAccessMethod<T>>(
a : VamanaIndex<T, D, V>,
b : VamanaIndex<T, D, V>,
) -> VamanaIndex<T, D, MergedAccessMethod<T, V>>
{
let num_points = a.access_method.num_points() + b.access_method.num_points();
let b_offset = a.access_method.num_points();
let params = a.params;
let access_method = MergedAccessMethod
{
underlying_access_method : (a.access_method, b.access_method),
t : std::marker::PhantomData,
};
let big_graph = concatenate!(Axis(0), a.neighbors, b.neighbors);
let start = a.start;
let mut merged_index = VamanaIndex
{
params : params,
neighbors : big_graph,
access_method : access_method,
start : start,
metric : std::marker::PhantomData,
t : std::marker::PhantomData,
};

for i in b_offset..num_points
{
merged_index.neighbors_mut(i).iter_mut().for_each(|x| *x += b_offset)
}

let mut build_ctx = BuildContext::new(&merged_index);
let prune = merged_index.params.pruning_threshold;
let num_total_points = merged_index.num_points();
for b_vertex in (b_offset..num_total_points)
{
build_ctx.search(&merged_index, b_vertex);
build_ctx.prune_index(&mut merged_index, b_vertex, prune);
build_ctx.insert_backwards_edges(&mut merged_index, b_vertex, prune);
}
merged_index
}

struct PtrWrapper(*mut usize);
unsafe impl Sync for PtrWrapper {}

pub fn merge_indexes_par<T : Indexable, D : Distance<T>, V : VectorAccessMethod<T>>(
a : VamanaIndex<T, D, V>,
b : VamanaIndex<T, D, V>,
) -> VamanaIndex<T, D, MergedAccessMethod<T, V>>
{
let num_points = a.access_method.num_points() + b.access_method.num_points();
let b_offset = a.access_method.num_points();
let params = a.params;
let access_method = MergedAccessMethod
{
underlying_access_method : (a.access_method, b.access_method),
t : std::marker::PhantomData,
};
let big_graph = concatenate!(Axis(0), a.neighbors, b.neighbors);
let start = a.start;
let mut merged_index = VamanaIndex
{
params : params,
neighbors : big_graph,
access_method : access_method,
start : start,
metric : std::marker::PhantomData,
t : std::marker::PhantomData,
};

for i in b_offset..num_points
{
merged_index.neighbors_mut(i).iter_mut().for_each(|x| *x += b_offset)
}

let mut build_ctx = BuildContext::new(&merged_index);
let num_total_points = merged_index.num_points();
let mut locks = Vec::with_capacity(num_total_points);
for _ in 0..num_total_points
{
locks.push(Mutex::new(()));
}
let edgelist_ptr = PtrWrapper(merged_index.neighbors.as_mut_ptr());
(b_offset..num_total_points)
.into_par_iter()
.for_each_with(build_ctx,
|tl_ctx, b_vertex|
{
let _ = &edgelist_ptr;
unsafe
{
insert_single::<T, D, MergedAccessMethod<T, V>>(
tl_ctx,
start,
b_vertex,
&merged_index.access_method,
edgelist_ptr.0,
&locks,
merged_index.params.pruning_threshold,
merged_index.params.num_neighbors,
merged_index.params.search_frontier_size,
);
}
});
merged_index
}

0 comments on commit cd0c6fb

Please sign in to comment.