diff --git a/demo.py b/demo.py index 94f038a..9d2d419 100755 --- a/demo.py +++ b/demo.py @@ -8,12 +8,12 @@ # print(result) def substring_test(): - internal.index_files_substring(["example_data/a.parquet"], "text", "index0", token_skip_factor = 10) - internal.index_files_substring(["example_data/b.parquet"], "text", "index1", token_skip_factor = 10) - internal.merge_index_substring("merged_index", ["index0", "index1"]) + # internal.index_files_substring(["example_data/a.parquet"], "text", "index0", token_skip_factor = 10) + # internal.index_files_substring(["example_data/b.parquet"], "text", "index1", token_skip_factor = 10) + # internal.merge_index_substring("merged_index", ["index0", "index1"]) result = internal.search_index_substring(["index0"], "One step you have to remember not to skip is to use Disk Utility to partition the SSD as GUID partition scheme HFS+ before doing the clone.", - K = 10, sample_factor = 10) + K = 10, token_viable_limit= 1, sample_factor = 10) print(result) # table1 = polars.read_parquet("uuid_data/a.parquet") @@ -23,16 +23,18 @@ def substring_test(): # write_deltalake("uuid_data_delta", table2.to_arrow(), mode = "append", engine = 'rust', writer_properties = WriterProperties(data_page_size_limit=1000000, compression = 'ZSTD')) # rottnest.index_delta("uuid_data_delta", "hashes", "uuid_rottnest_index", "uuid") - def uuid_test(): - internal.index_files_uuid(["uuid_data/a.parquet"], "hashes", "index0") - internal.index_files_uuid(["uuid_data/b.parquet"], "hashes", "index1") - internal.merge_index_uuid("merged_index_new", ["index0", "index1"]) - result = internal.search_index_uuid(["merged_index_new"], "93b9f88dd22cb168cbc45000fcb05042cd1fc4b5602a56e70383fa26d33d21b08d004d78a7c97a463331da2da64e88f5546367e16e5fd2539bb9b8796ffffc7f", K = 10) - print(result) + # internal.index_files_uuid(["uuid_data/a.parquet"], "hashes", "index0") + # internal.index_files_uuid(["uuid_data/b.parquet"], "hashes", "index1") + + internal.index_files_uuid(["s3://txhashesbenchmark/0.parquet"], "hashes", "index0") + internal.index_files_uuid(["s3://txhashesbenchmark/1.parquet"], "hashes", "index1") + + internal.merge_index_uuid("merged_index", ["index0", "index1"]) result = internal.search_index_uuid(["merged_index"], "93b9f88dd22cb168cbc45000fcb05042cd1fc4b5602a56e70383fa26d33d21b08d004d78a7c97a463331da2da64e88f5546367e16e5fd2539bb9b8796ffffc7f", K = 10) print(result) +substring_test() # result = rottnest.search_index_uuid(["merged_index"], "650243a9024fe6595fa953e309c722c225cb2fae1f70c74364917eb901bcdce1f9a878d22345a8576a201646b6da815ebd6397cfd313447ee3a548259f63825a", K = 10) # print(result) diff --git a/python/rottnest/internal.py b/python/rottnest/internal.py index 2b95ee6..568d495 100644 --- a/python/rottnest/internal.py +++ b/python/rottnest/internal.py @@ -27,9 +27,9 @@ def index_files_bm25(file_paths: list[str], column_name: str, name = uuid.uuid4( file_data = file_data.replace_schema_metadata({"cache_ranges": json.dumps(cache_ranges)}) pq.write_table(file_data, f"{name}.meta", write_statistics = False, compression = 'zstd') -def index_files_substring(file_paths: list[str], column_name: str, name = uuid.uuid4().hex, index_mode = "physical", tokenizer_file = None, token_skip_factor = None): +def index_files_substring(file_paths: list[str], column_name: str, name = uuid.uuid4().hex, index_mode = "physical", tokenizer_file = None, token_skip_factor = None, remote = None): - arr, uid, file_data = get_physical_layout(file_paths, column_name) if index_mode == "physical" else 
get_virtual_layout(file_paths, column_name, "uid") + arr, uid, file_data = get_physical_layout(file_paths, column_name, remote = remote) if index_mode == "physical" else get_virtual_layout(file_paths, column_name, "uid", remote = remote) cache_ranges = rottnest.build_lava_substring(f"{name}.lava", arr, uid, tokenizer_file, token_skip_factor) @@ -37,9 +37,9 @@ def index_files_substring(file_paths: list[str], column_name: str, name = uuid.u file_data = file_data.replace_schema_metadata({"cache_ranges": json.dumps(cache_ranges)}) pq.write_table(file_data, f"{name}.meta", write_statistics = False, compression = 'zstd') -def index_files_uuid(file_paths: list[str], column_name: str, name = uuid.uuid4().hex, index_mode = "physical"): +def index_files_uuid(file_paths: list[str], column_name: str, name = uuid.uuid4().hex, index_mode = "physical", remote = None): - arr, uid, file_data = get_physical_layout(file_paths, column_name) if index_mode == "physical" else get_virtual_layout(file_paths, column_name, "uid") + arr, uid, file_data = get_physical_layout(file_paths, column_name, remote = remote) if index_mode == "physical" else get_virtual_layout(file_paths, column_name, "uid", remote = remote) idx = pac.sort_indices(arr) arr = arr.take(idx) @@ -51,7 +51,9 @@ def index_files_uuid(file_paths: list[str], column_name: str, name = uuid.uuid4( file_data = file_data.replace_schema_metadata({"cache_ranges": json.dumps(cache_ranges)}) pq.write_table(file_data, f"{name}.meta", write_statistics = False, compression = 'zstd') -def index_files_vector(file_paths: list[str], column_name: str, name = uuid.uuid4().hex, dtype = 'f32', index_mode = "physical", gpu = False): + + +def index_files_vector(file_paths: list[str], column_name: str, name = uuid.uuid4().hex, dtype = 'f32', index_mode = "physical", gpu = False, remote = None): try: import faiss @@ -64,7 +66,7 @@ def index_files_vector(file_paths: list[str], column_name: str, name = uuid.uuid assert dtype == 'f32' dtype_size = 4 - arr, uid, file_data = get_physical_layout(file_paths, column_name, type = "binary") if index_mode == "physical" else get_virtual_layout(file_paths, column_name, "uid", type = "binary") + arr, uid, file_data = get_physical_layout(file_paths, column_name, type = "binary", remote = remote) if index_mode == "physical" else get_virtual_layout(file_paths, column_name, "uid", type = "binary", remote = remote) uid = uid.to_numpy() # arr will be a array of largebinary, we need to convert it into numpy, time for some arrow ninja @@ -75,9 +77,8 @@ def index_files_vector(file_paths: list[str], column_name: str, name = uuid.uuid dim = diffs.item() // dtype_size x = np.frombuffer(buffers[2], dtype = np.float32).reshape(len(arr), dim) - num_centroids = 1000 # len(arr) // 10_000 + num_centroids = len(arr) // 10_000 - # kmeans = faiss.Kmeans(128, len(arr) // 10_000, niter=30, verbose=True, gpu = gpu) kmeans = faiss.Kmeans(128,num_centroids, niter=30, verbose=True, gpu = gpu) kmeans.train(x) centroids = kmeans.centroids @@ -91,21 +92,45 @@ def index_files_vector(file_paths: list[str], column_name: str, name = uuid.uuid posting_lists = [[] for _ in range(num_centroids)] codes_lists = [[] for _ in range(num_centroids)] - for i in tqdm(range(len(arr) // batch_size)): - batch = x[i * batch_size:(i + 1) * batch_size] + if gpu: + + res = faiss.StandardGpuResources() + d = centroids.shape[1] + index = faiss.GpuIndexFlatL2(res, d) + index.add(centroids.astype('float32')) + + # Process batches + for i in tqdm(range(len(arr) // batch_size)): + batch = x[i * 
batch_size:(i + 1) * batch_size].astype('float32')
+            k = 20
+            distances, indices = index.search(batch, k)
+
+            # The indices are already sorted by distance, so we don't need to sort again
+            closest_centroids = indices[:, 0]
 
-        distances = -np.sum(centroids ** 2, axis=1, keepdims=True).T + 2 * np.dot(batch, centroids.T)
-        indices = np.argpartition(-distances, kth=20, axis=1)[:, :20]
-        sorted_indices = np.argsort(-distances[np.arange(distances.shape[0])[:, None], indices], axis=1)
-        indices = indices[np.arange(indices.shape[0])[:, None], sorted_indices]
+            for k in range(batch_size):
+                # TODO: this uses UID! Just a warning. because gemv is fast even on lowly CPUs for final reranking.
+                posting_lists[closest_centroids[k]].append(uid[i * batch_size + k])
+                codes_lists[closest_centroids[k]].append(codes[i * batch_size + k])
 
-        closest_centroids = list(indices[:,0])
-        # closest2_centroids = list(indices[:,1])
+
+    else:
+        for i in tqdm(range(len(arr) // batch_size)):
+            batch = x[i * batch_size:(i + 1) * batch_size]
+
+            distances = -np.sum(centroids ** 2, axis=1, keepdims=True).T + 2 * np.dot(batch, centroids.T)
+            indices = np.argpartition(-distances, kth=20, axis=1)[:, :20]
+            sorted_indices = np.argsort(-distances[np.arange(distances.shape[0])[:, None], indices], axis=1)
+            indices = indices[np.arange(indices.shape[0])[:, None], sorted_indices]
+
+            closest_centroids = list(indices[:,0])
+            # closest2_centroids = list(indices[:,1])
+
+            for k in range(batch_size):
+                # TODO: this uses UID! Just a warning. because gemv is fast even on lowly CPUs for final reranking.
+                posting_lists[closest_centroids[k]].append(uid[i * batch_size + k])
+                codes_lists[closest_centroids[k]].append(codes[i * batch_size + k])
 
-        for k in range(batch_size):
-            # TODO: this uses UID! Just a warning. because gemv is fast even on lowly CPUs for final reranking.
-            posting_lists[closest_centroids[k]].append(uid[i * batch_size + k])
-            codes_lists[closest_centroids[k]].append(codes[i * batch_size + k])
 
     f = open(f"{name}.lava", "wb")
     centroid_offsets = [0]
@@ -241,15 +266,18 @@ def search_index_uuid(indices: List[str], query: str, K: int, columns = []):
 
     return return_full_result(result, metadata, column_name, columns)
 
-def search_index_substring(indices: List[str], query: str, K: int, sample_factor = None, columns = []):
+def search_index_substring(indices: List[str], query: str, K: int, sample_factor = None, token_viable_limit = 1, columns = []):
 
     metadata = get_metadata_and_populate_cache(indices)
 
-    index_search_results = rottnest.search_lava_substring([f"{index_name}.lava" for index_name in indices], query, K, "aws", sample_factor = sample_factor)
+    index_search_results = rottnest.search_lava_substring([f"{index_name}.lava" for index_name in indices], query, K, "aws", sample_factor = sample_factor, token_viable_limit = token_viable_limit)
     print(index_search_results)
 
     if len(index_search_results) == 0:
         return None
+
+    if len(index_search_results) > 10000:
+        return "Brute Force Please"
 
     result, column_name, metadata = get_result_from_index_result(metadata, index_search_results)
     result = polars.from_arrow(result).filter(polars.col(column_name).str.to_lowercase().str.contains(query.lower(), literal=True))
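The CPU fallback kept above relies on the standard matmul trick: since ‖x − c‖² = ‖x‖² − 2x·c + ‖c‖² and ‖x‖² is constant per query row, ranking centroids by 2x·c − ‖c‖² is equivalent to ranking by true squared distance. A self-contained NumPy check of that identity (toy shapes, not the real data):

```python
import numpy as np

# Toy stand-ins for the real batch and k-means centroids.
rng = np.random.default_rng(0)
batch = rng.standard_normal((8, 128)).astype(np.float32)        # (B, d)
centroids = rng.standard_normal((100, 128)).astype(np.float32)  # (C, d)

# Maximizing 2 x.c - ||c||^2 is the same as minimizing ||x - c||^2,
# because ||x||^2 is constant per row and drops out of the argmin.
scores = -np.sum(centroids ** 2, axis=1)[None, :] + 2.0 * (batch @ centroids.T)
nearest = np.argmax(scores, axis=1)

# Sanity check against the naive O(B*C*d) distance computation.
naive = np.argmin(((batch[:, None, :] - centroids[None, :, :]) ** 2).sum(-1), axis=1)
assert (nearest == naive).all()
```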
diff --git a/python/rottnest/table.py b/python/rottnest/table.py
index c055d0c..aeb8c78 100644
--- a/python/rottnest/table.py
+++ b/python/rottnest/table.py
@@ -1,10 +1,5 @@
-from deltalake import DeltaTable, write_deltalake
-import duckdb
-import pyarrow
 import polars
-from deltalake._internal import TableNotFoundError
 import uuid
-
 from . import internal
 
 """
 the schema of the metadata table will be:
 """
 
@@ -16,6 +11,9 @@
 
 def index_delta(table: str, column: str, index_dir: str, type: str, index_impl = 'physical', extra_configs = {}):
 
+    from deltalake import DeltaTable, write_deltalake
+    from deltalake._internal import TableNotFoundError
+
     assert type in {"uuid", "substring", "vector"}
     assert index_impl in {"physical", "virtual"}
     if index_impl == "virtual":
@@ -62,30 +60,45 @@ def index_delta(table: str, column: str, index_dir: str, type: str, index_impl =
 
 # we should deprecate the type argument, and figure out the type automatically.
 def search_delta(table: str, index_dir: str, query, type: str, K: int, snapshot : int | None = None, extra_configs = {}):
 
+    from deltalake import DeltaTable, write_deltalake
+    from deltalake._internal import TableNotFoundError
+
     assert type in {"uuid", "substring", "vector"}
 
     main_table = DeltaTable(table)
     if snapshot is not None:
         main_table.load_as_version(snapshot)
 
-    existing_parquet_files = polars.from_dict({"covered_parquet_files": main_table.file_uris()})
+    existing_parquet_files = set(main_table.file_uris())
 
     index_dir = index_dir.rstrip("/")
    metadata_table_dir = f"{index_dir}/metadata_table"
 
+    selected_indices = []
+    uncovered_parquets = set(existing_parquet_files)
+
     try:
         metadata = polars.from_arrow(DeltaTable(metadata_table_dir).to_pyarrow_table())
-        metadata = metadata.explode('covered_parquet_files')
+        # convert it to a dictionary
+        metadata = {row['index_file']: row['covered_parquet_files'] for row in metadata.iter_rows(named=True)}
+
+        while True:
+            overlap_to_index_file = {len(uncovered_parquets.intersection(v)): k for k, v in metadata.items()}
+            max_overlap = max(overlap_to_index_file.keys())
+            if max_overlap == 0:
+                break
+            index_file = overlap_to_index_file[max_overlap]
+            selected_indices.append(index_file)
+            uncovered_parquets = uncovered_parquets.difference(set(metadata[index_file]))
+            del metadata[index_file]
+
     except TableNotFoundError:
-        # brute force
         pass
-
-
-    # we want to figure out the minimal number of index files that cover the existing parquet files to the best of our abilities.
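The `while True` loop added to `search_delta` is a greedy set cover: repeatedly pick the index file that covers the most still-uncovered parquet files, and stop once nothing new is covered. A standalone sketch of the same idea (hypothetical file names, plain dicts and sets):

```python
def greedy_cover(indices: dict[str, set[str]], targets: set[str]) -> list[str]:
    # indices maps index_file -> set of parquet files it covers
    selected, uncovered = [], set(targets)
    candidates = dict(indices)
    while uncovered and candidates:
        # pick the index file with the largest overlap with what is left
        best = max(candidates, key=lambda k: len(uncovered & candidates[k]))
        if not uncovered & candidates[best]:
            break  # nothing covers anything new: fall back to brute force
        selected.append(best)
        uncovered -= candidates.pop(best)
    return selected

assert greedy_cover({"i0": {"a", "b"}, "i1": {"b", "c"}}, {"a", "b", "c"}) == ["i0", "i1"]
```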
- - - - # if type == "uuid": - # index_search_results = internal.search_lava_uuid([f"{index_name}.lava" for index_name in indices], query, K, "aws") \ No newline at end of file + if type == "uuid": + index_search_results = internal.search_lava_uuid([f"{index_name}.lava" for index_name in selected_indices], query, K, "aws") + elif type == "substring": + index_search_results = internal.search_lava_substring([f"{index_name}.lava" for index_name in selected_indices], query, K, "aws") + elif type == "vector": + index_search_results = internal.search_lava_vector([f"{index_name}.lava" for index_name in selected_indices], query, K, "aws") \ No newline at end of file diff --git a/python/rottnest/utils.py b/python/rottnest/utils.py index a909ce4..cace90f 100644 --- a/python/rottnest/utils.py +++ b/python/rottnest/utils.py @@ -58,7 +58,7 @@ def read_parquet_file(file, row_group, row_nr): return pyarrow.concat_tables(results) -def get_physical_layout(file_paths: list[str], column_name: str, type = "str"): +def get_physical_layout(file_paths: list[str], column_name: str, type = "str", remote = None): assert type in {"str", "binary"} @@ -80,7 +80,7 @@ def get_physical_layout(file_paths: list[str], column_name: str, type = "str"): metadata = polars.from_dict({ "uid": np.arange(len(data_page_num_rows) + 1), - "file_path": [file_path] * (len(data_page_num_rows) + 1), + "file_path": [file_path if remote is None else remote + file_path] * (len(data_page_num_rows) + 1), "column_name": [column_name] * (len(data_page_num_rows) + 1), # TODO: figure out a better way to handle this. Currently this is definitely not a bottleneck. Write ampl factor is almost 10x # writing just one row followed by a bunch of Nones don't help, likely because it's already smart enough to do dict encoding. 
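The `remote` argument threaded through `get_physical_layout` above only changes what gets recorded as `file_path` in the index metadata, so an index built against local files can later point at the same files in object storage. A toy illustration (the `s3://bucket/` prefix is an assumed example; nothing is validated):

```python
# Mirrors the conditional expression in get_physical_layout; `remote` is
# expected to be a URI prefix such as "s3://my-bucket/" (assumption).
def resolve_path(file_path: str, remote: str | None) -> str:
    return file_path if remote is None else remote + file_path

assert resolve_path("data/a.parquet", None) == "data/a.parquet"
assert resolve_path("data/a.parquet", "s3://bucket/") == "s3://bucket/data/a.parquet"
```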
@@ -105,7 +105,7 @@ def get_physical_layout(file_paths: list[str], column_name: str, type = "str"):
 
     return pyarrow.concat_arrays(all_arrs), pyarrow.array(all_uids.astype(np.uint64)), polars.concat(metadatas)
 
-def get_virtual_layout(file_paths: list[str], column_name: str, key_column_name: str, type = "str", stride = 500):
+def get_virtual_layout(file_paths: list[str], column_name: str, key_column_name: str, type = "str", stride = 500, remote = None):
 
     fs = get_fs_from_file_path(file_paths[0])
     metadatas = []
@@ -119,7 +119,7 @@ def get_virtual_layout(file_paths: list[str], column_name: str, key_column_name:
     arr = table[column_name].to_arrow().cast(pyarrow.large_string() if type == 'str' else pyarrow.large_binary())
     uid = table['__uid__'].to_arrow().cast(pyarrow.uint64())
 
-    metadata = table.groupby("__uid__").agg([polars.col(key_column_name).min().alias("min"), polars.col(key_column_name).max().alias("max")]).sort("__uid__")
+    metadata = table.group_by("__uid__").agg([polars.col(key_column_name).min().alias("min"), polars.col(key_column_name).max().alias("max")]).sort("__uid__")
 
     return arr, uid, metadata
 
@@ -164,7 +164,7 @@ def get_result_from_index_result(metadata: polars.DataFrame, index_search_result
 def return_full_result(result: polars.DataFrame, metadata: polars.DataFrame, column_name: str, columns: List[str]):
     if columns != []:
         result = result.join(metadata.select(["__metadata_key__", "file_path", "row_groups"]), on = "__metadata_key__", how = "left")
-    grouped = result.groupby(["file_path", "row_groups"]).agg([polars.col('__metadata_key__'), polars.col('__row_group_rownr__')])
+    grouped = result.group_by(["file_path", "row_groups"]).agg([polars.col('__metadata_key__'), polars.col('__row_group_rownr__')])
     collected_results = polars.from_arrow(read_columns(grouped["file_path"].to_list(), grouped["row_groups"].to_list(), grouped["__row_group_rownr__"].to_list()))
     unnested_metadata_key = grouped['__metadata_key__'].explode()
     unnested_row_group_rownr = grouped['__row_group_rownr__'].explode()
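The `groupby` → `group_by` changes track polars' rename of `DataFrame.groupby` (deprecated in favor of `group_by` in newer releases); semantics are unchanged. A quick check of the same min/max aggregation on toy data:

```python
import polars as pl

df = pl.DataFrame({"__uid__": [0, 0, 1], "key": [3, 1, 7]})
# Same aggregation shape as get_virtual_layout, under the renamed API:
out = df.group_by("__uid__").agg(
    pl.col("key").min().alias("min"),
    pl.col("key").max().alias("max"),
).sort("__uid__")
print(out)  # uid 0 -> (1, 3), uid 1 -> (7, 7)
```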
diff --git a/src/lava/build.rs b/src/lava/build.rs
index bb2d33f..8bc8328 100644
--- a/src/lava/build.rs
+++ b/src/lava/build.rs
@@ -523,8 +523,6 @@ pub async fn build_lava_substring(
         })
         .collect::<Vec<(Vec<u64>, Vec<u32>)>>();
 
-    println!("{:?}", named_encodings[1]);
-
     let uids: Vec<u64> = named_encodings
         .iter()
         .map(|(uid, _)| uid)
@@ -558,8 +556,6 @@ pub async fn build_lava_substring(
         (encodings, uids)
     };
 
-    println!("{:?}", &encodings[..200]);
-
     for i in 10..encodings.len() {
         suffices.push(encodings[i - 10..i].to_vec());
     }
@@ -598,16 +594,16 @@ pub async fn build_lava_substring(
 
     // write out the bwt to a numpy array
 
-    let file = File::create("output.bin")?;
-    let mut writer = BufWriter::new(file);
+    // let file = File::create("output.bin")?;
+    // let mut writer = BufWriter::new(file);
 
-    // Write each u32 to the file as bytes
-    for number in bwt.iter() {
-        writer.write_u32::<LittleEndian>(*number)?;
-    }
+    // // Write each u32 to the file as bytes
+    // for number in bwt.iter() {
+    //     writer.write_u32::<LittleEndian>(*number)?;
+    // }
 
     // Flush the buffer to ensure all data is written to the file
-    writer.flush()?;
+    // writer.flush()?;
 
     let mut file = File::create(output_file_name)?;
     file.write_all(&(compressed_tokenizer.len() as u64).to_le_bytes())?;
diff --git a/src/lava/search.rs b/src/lava/search.rs
index 35f296c..09bfd8b 100644
--- a/src/lava/search.rs
+++ b/src/lava/search.rs
@@ -66,6 +66,105 @@ async fn get_tokenizer_async(
     Ok((tokenizer, result))
 }
 
+async fn process_substring_query(
+    query: Vec<u32>,
+    n: u64,
+    fm_chunk_offsets: &[u64],
+    cumulative_counts: &[u64],
+    posting_list_offsets: &[u64],
+    reader: &mut AsyncReader,
+    file_id: u64,
+) -> Vec<(u64, u64)> {
+    let mut res: Vec<(u64, u64)> = vec![];
+    let mut start: usize = 0;
+    let mut end: usize = n as usize;
+
+    for i in (0..query.len()).rev() {
+        let current_token = query[i];
+
+        let start_byte = fm_chunk_offsets[start / FM_CHUNK_TOKS];
+        let end_byte = fm_chunk_offsets[start / FM_CHUNK_TOKS + 1];
+        let start_chunk = reader.read_range(start_byte, end_byte).await.unwrap();
+
+        let start_byte = fm_chunk_offsets[end / FM_CHUNK_TOKS];
+        let end_byte = fm_chunk_offsets[end / FM_CHUNK_TOKS + 1];
+        let end_chunk = reader.read_range(start_byte, end_byte).await.unwrap();
+
+        start = cumulative_counts[current_token as usize] as usize
+            + FMChunk::new(start_chunk)
+                .unwrap()
+                .search(current_token, start % FM_CHUNK_TOKS)
+                .unwrap() as usize;
+        end = cumulative_counts[current_token as usize] as usize
+            + FMChunk::new(end_chunk)
+                .unwrap()
+                .search(current_token, end % FM_CHUNK_TOKS)
+                .unwrap() as usize;
+
+        if start >= end {
+            return res;
+        }
+
+        if end <= start + 2 {
+            break;
+        }
+    }
+
+    let start_offset = posting_list_offsets[start / FM_CHUNK_TOKS];
+    let end_offset = posting_list_offsets[end / FM_CHUNK_TOKS + 1];
+    let total_chunks = end / FM_CHUNK_TOKS - start / FM_CHUNK_TOKS + 1;
+
+    let plist_chunks = reader.read_range(start_offset, end_offset).await.unwrap();
+
+    let mut chunk_set = JoinSet::new();
+
+    for i in 0..total_chunks {
+        let this_start = posting_list_offsets[start / FM_CHUNK_TOKS + i];
+        let this_end = posting_list_offsets[start / FM_CHUNK_TOKS + i + 1];
+        let this_chunk = plist_chunks
+            [(this_start - start_offset) as usize..(this_end - start_offset) as usize]
+            .to_vec();
+
+        chunk_set.spawn(async move {
+            let mut decompressor = Decoder::new(&this_chunk[..]).unwrap();
+            let mut serialized_plist_chunk: Vec<u8> = Vec::with_capacity(this_chunk.len());
+            decompressor
+                .read_to_end(&mut serialized_plist_chunk)
+                .unwrap();
+            let plist_chunk: Vec<u64> = bincode::deserialize(&serialized_plist_chunk).unwrap();
+
+            let chunk_res: Vec<(u64, u64)> = if i == 0 {
+                if total_chunks == 1 {
+                    plist_chunk[start % FM_CHUNK_TOKS..end % FM_CHUNK_TOKS]
+                        .iter()
+                        .map(|&uid| (file_id, uid))
+                        .collect()
+                } else {
+                    plist_chunk[start % FM_CHUNK_TOKS..]
+                        .iter()
+                        .map(|&uid| (file_id, uid))
+                        .collect()
+                }
+            } else if i == total_chunks - 1 {
+                plist_chunk[..end % FM_CHUNK_TOKS]
+                    .iter()
+                    .map(|&uid| (file_id, uid))
+                    .collect()
+            } else {
+                plist_chunk.iter().map(|&uid| (file_id, uid)).collect()
+            };
+
+            chunk_res
+        });
+    }
+
+    while let Some(chunk_res) = chunk_set.join_next().await {
+        res.extend(chunk_res.unwrap());
+    }
+
+    res
+}
+
 async fn search_substring_one_file(
     file_id: u64,
     mut reader: AsyncReader,
@@ -92,86 +191,33 @@ async fn search_substring_one_file(
         .read_range_and_decompress(total_counts_offset, (file_size - 32) as u64)
         .await?;
 
-    // let previous_range: u64 = u64::MAX;
-
-    let mut res: Vec<(u64, u64)> = vec![];
+    let mut query_set = JoinSet::new();
 
     for query in queries {
-        let mut start: usize = 0;
-        let mut end: usize = n as usize;
-        for i in (0..query.len()).rev() {
-            let current_token = query[i];
-
-            let start_byte = fm_chunk_offsets[start / FM_CHUNK_TOKS];
-            let end_byte = fm_chunk_offsets[start / FM_CHUNK_TOKS + 1];
-            let start_chunk = reader.read_range(start_byte, end_byte).await?;
-
-            let start_byte = fm_chunk_offsets[end / FM_CHUNK_TOKS];
-            let end_byte = fm_chunk_offsets[end / FM_CHUNK_TOKS + 1];
-            let end_chunk = reader.read_range(start_byte, end_byte).await?;
-
-            start = cumulative_counts[current_token as usize] as usize
-                + FMChunk::new(start_chunk)?
-                    .search(current_token, start % FM_CHUNK_TOKS)
-                    .unwrap() as usize;
-            end = cumulative_counts[current_token as usize] as usize
-                + FMChunk::new(end_chunk)?
-                    .search(current_token, end % FM_CHUNK_TOKS)
-                    .unwrap() as usize;
-
-            if start >= end {
-                break;
-            }
-        }
-
-        if start >= end {
-            continue;
-        }
-
-        let start_offset = posting_list_offsets[start / FM_CHUNK_TOKS];
-        let end_offset = posting_list_offsets[end / FM_CHUNK_TOKS + 1];
-        let total_chunks = end / FM_CHUNK_TOKS - start / FM_CHUNK_TOKS + 1;
-
-        // println!("total chunks: {}", total_chunks);
-
-        let plist_chunks = reader.read_range(start_offset, end_offset).await?;
-        for i in 0..total_chunks {
-            let this_start = posting_list_offsets[start / FM_CHUNK_TOKS + i];
-            let this_end = posting_list_offsets[start / FM_CHUNK_TOKS + i + 1];
-            let this_chunk = &plist_chunks
-                [(this_start - start_offset) as usize..(this_end - start_offset) as usize];
-
-            // decompress this chunk
-            let mut decompressor = Decoder::new(&this_chunk[..])?;
-            let mut serialized_plist_chunk: Vec<u8> = Vec::with_capacity(this_chunk.len() as usize);
-            decompressor.read_to_end(&mut serialized_plist_chunk)?;
-            let plist_chunk: Vec<u64> = bincode::deserialize(&serialized_plist_chunk)?;
+        let fm_chunk_offsets = fm_chunk_offsets.clone();
+        let cumulative_counts = cumulative_counts.clone();
+        let posting_list_offsets = posting_list_offsets.clone();
+        let mut reader = reader.clone();
+
+        query_set.spawn(async move {
+            process_substring_query(
+                query,
+                n,
+                &fm_chunk_offsets,
+                &cumulative_counts,
+                &posting_list_offsets,
+                &mut reader,
+                file_id,
+            )
+            .await
+        });
+    }
 
-            if i == 0 {
-                if total_chunks == 1 {
-                    for uid in &plist_chunk[start % FM_CHUNK_TOKS..end % FM_CHUNK_TOKS] {
-                        // println!("push file_id {}", file_id);
-                        res.push((file_id as u64, *uid));
-                    }
-                } else {
-                    for uid in &plist_chunk[start % FM_CHUNK_TOKS..] {
-                        // println!("push file_id {}", file_id);
-                        res.push((file_id as u64, *uid));
-                    }
-                }
-            } else if i == total_chunks - 1 {
-                for uid in &plist_chunk[..end % FM_CHUNK_TOKS] {
-                    // println!("push file_id {}", file_id);
-                    res.push((file_id as u64, *uid));
-                }
-            } else {
-                for uid in &plist_chunk[..] {
-                    // println!("push file_id {}", file_id);
-                    res.push((file_id as u64, *uid));
-                }
-            }
-        }
+    let mut res = Vec::new();
+    while let Some(query_res) = query_set.join_next().await {
+        res.extend(query_res.unwrap());
     }
+
     Ok(res)
 }
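`search_substring_one_file` now fans each query out onto its own task via a `JoinSet` and merges results as tasks complete, instead of walking queries sequentially. A rough asyncio analogue of the pattern (stub I/O, hypothetical helper names):

```python
import asyncio

# process_query is a stub standing in for the chunked FM-index probing;
# each query becomes its own task, like JoinSet::spawn in the Rust code.
async def process_query(file_id: int, query: list[int]) -> list[tuple[int, int]]:
    await asyncio.sleep(0)  # stand-in for the actual index reads
    return [(file_id, tok) for tok in query]  # stand-in for matching uids

async def search_one_file(file_id: int, queries: list[list[int]]) -> list[tuple[int, int]]:
    tasks = [asyncio.create_task(process_query(file_id, q)) for q in queries]
    res: list[tuple[int, int]] = []
    for fut in asyncio.as_completed(tasks):  # merge as tasks finish
        res.extend(await fut)
    return res

print(asyncio.run(search_one_file(0, [[1, 2], [3]])))
```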
@@ -229,16 +275,22 @@ async fn search_generic_async(
     while let Some(res) = join_set.join_next().await {
         let res = res.unwrap().unwrap();
         result.extend(res);
-        if result.len() >= k {
-            break;
-        }
+        /*
+        This is not safe. This is because the index might raise false positives, such that the top K only contains false positives.
+        We should support doing this if the index is guaranteed not to have false positives.
+        E.g. SSA index will have false positives with skip_factor > 1
+        E.g. Trie index will have false positives since values on the intermediate nodes, due to the merge process.
+        */
+        // if result.len() >= k {
+        //     break;
+        // }
     }
 
     join_set.shutdown().await;
 
     // keep only k elements in the result
-    let mut result: Vec<(u64, u64)> = result.into_iter().collect_vec();
-    result.truncate(k);
+    let result: Vec<(u64, u64)> = result.into_iter().collect_vec();
+    // result.truncate(k);
 
     Ok(result)
 }
@@ -431,6 +483,7 @@ pub async fn search_lava_substring(
     query: String,
     k: usize,
     reader_type: ReaderType,
+    token_viable_limit: Option<usize>,
     sample_factor: Option<usize>,
 ) -> Result<Vec<(u64, u64)>, LavaError> {
     let (_file_sizes, readers) = get_file_sizes_and_readers(&files, reader_type.clone()).await?;
@@ -471,9 +524,7 @@ pub async fn search_lava_substring(
         .cloned()
         .collect();
 
-    println!("{:?}", result);
-
-    let query: Vec<Vec<u32>> = if let Some(sample_factor) = sample_factor {
+    let mut query: Vec<Vec<u32>> = if let Some(sample_factor) = sample_factor {
         (0..sample_factor)
             .map(|offset| {
                 result
@@ -489,7 +540,23 @@ pub async fn search_lava_substring(
         vec![result]
     };
 
-    // println!("{:?}", result);
+    // query = [i[-token_viable_limit:] for i in query]
+
+    if let Some(token_viable_limit) = token_viable_limit {
+        query.iter_mut().for_each(|vec| {
+            if vec.len() > token_viable_limit {
+                *vec = vec
+                    .iter()
+                    .rev()
+                    .take(token_viable_limit)
+                    .rev()
+                    .cloned()
+                    .collect();
+            }
+        });
+    }
+
+    println!("{:?}", query);
 
     let (file_sizes, readers) = get_file_sizes_and_readers(&files, reader_type).await?;
     search_generic_async(file_sizes, readers, QueryParam::Substring(query), k).await
@@ -551,13 +618,15 @@ pub async fn search_lava_vector_async(
             decompressor.read_to_end(&mut centroid_vectors).unwrap();
             let centroid_vectors = bytes_to_f32_vec(&centroid_vectors);
 
-            let array2 = Array2::<f32>::from_shape_vec((1000, 128), centroid_vectors).unwrap();
+            let num_vectors = centroid_vectors.len() / 128;
+            let array2 =
+                Array2::<f32>::from_shape_vec((num_vectors, 128), centroid_vectors).unwrap();
 
-            (results, array2)
+            (num_vectors, array2)
         }));
     }
 
-    let result: Vec<Result<(Vec<u64>, Array2<f32>), tokio::task::JoinError>> =
+    let result: Vec<Result<(usize, Array2<f32>), tokio::task::JoinError>> =
         futures::future::join_all(futures).await;
 
     let end = Instant::now();
@@ -565,6 +634,19 @@ pub async fn search_lava_vector_async(
 
     let start = Instant::now();
 
+    let arr_lens = result
+        .iter()
+        .map(|x| x.as_ref().unwrap().0)
+        .collect::<Vec<_>>();
+    // get cumulative arr len starting from 0
+    let cumsum = arr_lens
+        .iter()
+        .scan(0, |acc, &x| {
+            *acc += x;
+            Some(*acc)
+        })
+        .collect::<Vec<_>>();
+
     let arrays: Vec<Array2<f32>> = result.into_iter().map(|x| x.unwrap().1).collect();
     let centroids = concatenate(
         Axis(0),
@@ -595,7 +677,21 @@ pub async fn search_lava_vector_async(
 
     let mut file_indices: Vec<Vec<usize>> = vec![vec![]; files.len()];
     for idx in smallest_indices.iter() {
-        file_indices[*idx / 1000].push(*idx % 1000 as usize);
+        // figure out which file idx based on cumsum. need to find the index of the thing that is just bigger than idx
+        let file_idx = cumsum
+            .iter()
+            .enumerate()
+            .find(|(_, &val)| val > *idx)
+            .unwrap()
+            .0;
+        let last_cumsum = if file_idx == 0 {
+            0
+        } else {
+            cumsum[file_idx - 1]
+        };
+        let remainder = idx - last_cumsum;
+        file_indices[file_idx].push(remainder);
     }
 
     let end = Instant::now();
@@ -749,6 +845,7 @@ mod tests {
             "Samsung Galaxy Note".to_string(),
             10,
             ReaderType::default(),
+            Some(10),
            None,
         );
         println!("{:?}", result.unwrap());
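With per-file centroid counts no longer hard-coded to 1000, a global centroid index must be mapped back to a (file, local index) pair through the cumulative counts; the `find(val > idx)` above is a linear upper-bound search over `cumsum`. An equivalent sketch with `bisect` (toy counts):

```python
from bisect import bisect_right
from itertools import accumulate

# Each file contributes a variable number of centroids, so a global
# centroid index maps to (file, local index) via cumulative counts.
arr_lens = [3, 5, 2]                 # centroids per file (toy values)
cumsum = list(accumulate(arr_lens))  # [3, 8, 10]

def locate(idx: int) -> tuple[int, int]:
    file_idx = bisect_right(cumsum, idx)  # first cumsum entry > idx
    prev = cumsum[file_idx - 1] if file_idx else 0
    return file_idx, idx - prev

assert locate(0) == (0, 0)
assert locate(3) == (1, 0)
assert locate(9) == (2, 1)
```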
diff --git a/src/lava_py/lava.rs b/src/lava_py/lava.rs
index 465b74d..24e8637 100644
--- a/src/lava_py/lava.rs
+++ b/src/lava_py/lava.rs
@@ -33,12 +33,20 @@ pub fn search_lava_substring(
     query: String,
     k: usize,
     reader_type: Option<&PyString>,
+    token_viable_limit: Option<usize>,
     sample_factor: Option<usize>,
 ) -> Result<Vec<(u64, u64)>, LavaError> {
     let reader_type = reader_type.map(|x| x.to_string()).unwrap_or_default();
 
     py.allow_threads(|| {
-        lava::search_lava_substring(files, query, k, reader_type.into(), sample_factor)
+        lava::search_lava_substring(
+            files,
+            query,
+            k,
+            reader_type.into(),
+            token_viable_limit,
+            sample_factor,
+        )
     })
 }
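For reference, `process_substring_query` is classic FM-index backward search: it narrows a suffix-array interval one query token at a time with `start = C[c] + rank(c, start)` and `end = C[c] + rank(c, end)`, where `C` is the cumulative-counts table (`cumulative_counts` in the Rust code) and `rank(c, i)` counts occurrences of `c` in `bwt[:i]`; the Rust code merely evaluates `rank` chunk-by-chunk from the on-disk FM chunks. A minimal in-memory sketch (naive O(n) rank, toy token stream with a `0` sentinel):

```python
def backward_search(bwt: list[int], C: dict[int, int], query: list[int]):
    rank = lambda c, i: sum(1 for x in bwt[:i] if x == c)
    start, end = 0, len(bwt)
    for c in reversed(query):
        start = C[c] + rank(c, start)
        end = C[c] + rank(c, end)
        if start >= end:
            return None  # no occurrence
    return start, end  # suffix-array interval of the match

text = [1, 2, 1, 2, 3, 0]  # token stream, 0 as the unique sentinel
sa = sorted(range(len(text)), key=lambda i: text[i:])  # naive suffix array
bwt = [text[i - 1] for i in sa]
C = {c: sum(t < c for t in text) for c in set(text)}   # tokens smaller than c
print(backward_search(bwt, C, [1, 2]))  # (1, 3): the pattern occurs twice
```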