diff --git a/Cargo.toml b/Cargo.toml index a7f0bcf..ee13130 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rottnest" -version = "1.4.0" +version = "1.5.0" edition = "2021" build = "build.rs" @@ -13,6 +13,7 @@ crate-type = ["cdylib"] default = [] #['py'] py = ["dep:pyo3", "pyarrow", "dep:pyo3-log"] pyarrow = ["arrow/pyarrow"] +logcloud = ["dep:libc"] [dependencies] @@ -86,7 +87,7 @@ ordered-float = "4.2.0" reqwest = "0.12.4" redis = {version = "0", features = ["aio", "tokio-comp"] } divsufsort = "2.0.0" -libc = "0.2.158" +libc = { version = "0.2.158", optional = true } [profile.release] lto = false diff --git a/build.rs b/build.rs index 8a90d66..6b8f942 100644 --- a/build.rs +++ b/build.rs @@ -1,9 +1,10 @@ -use std::env; -use std::path::PathBuf; - +#[cfg(feature = "logcloud")] fn main() { + use std::env; + use std::path::PathBuf; + let dir = env::var("CARGO_MANIFEST_DIR").unwrap(); - let path = PathBuf::from(dir); + let path = PathBuf::from(dir).join("src").join("lava").join("logcloud"); // Specify the directory containing the .a files println!("cargo:rustc-link-search=native={}", path.display()); @@ -18,6 +19,9 @@ fn main() { println!("cargo:rustc-link-lib=dylib=stdc++"); // Rerun the build script if the static libraries change - println!("cargo:rerun-if-changed=libCompressor.a"); - println!("cargo:rerun-if-changed=libTrainer.a"); + println!("cargo:rerun-if-changed=src/lava/logcloud/libCompressor.a"); + println!("cargo:rerun-if-changed=src/lava/logcloud/libTrainer.a"); } + +#[cfg(not(feature = "logcloud"))] +fn main() {} diff --git a/python/rottnest/ahupuaa.py b/python/rottnest/ahupuaa.py deleted file mode 100644 index c0d91ac..0000000 --- a/python/rottnest/ahupuaa.py +++ /dev/null @@ -1,87 +0,0 @@ -""" -The term for the subdivision of an island in Hawaii into multiple agricultural zones is "ahupuaʻa." -Ahupuaʻa is a traditional Hawaiian system of land division that typically extends from -the mountains to the sea, encompassing a variety of ecosystems. This system allowed for the -sustainable management of resources across different ecological zones. Each ahupuaʻa contained -nearly all the resources the Hawaiian community living within its boundaries would need, -which minimized the need for long-distance travel to gather resources. These zones included -upland forests, agricultural lands, and fishing grounds, among others. - -The methods here mostly use Polars and numpy, so moving to Rust is not that necessary so far. 
-""" - -import polars -import numpy as np -from typing import Optional - -class Partition: - def __init__(self) -> None: - pass - -class MetricsPartition(Partition): - def __init__(self, ts: np.ndarray, values: np.ndarray, metadata = Optional[dict]) -> None: - super().__init__() - self.ts = ts - self.values = values - self.metadata = metadata - - def __str__(self): - return f"metrics segment: " + str(self.compute_stats()) - - def compute_stats(self) -> None: - min_val, max_val, average_val = np.min(self.values), np.max(self.values), np.mean(self.values) - return {"num_measurements": len(self.ts), - "min_ts": str(np.min(self.ts)), - "max_ts": str(np.max(self.ts)), - "time_span": str(np.max(self.ts) - np.min(self.ts)), - "min_val": min_val, - "max_val": max_val, - "average_val": average_val} - - -def check_time_series(ts: np.ndarray, values: np.ndarray): - if len(ts) != len(values): - raise ValueError("ts and values must have the same length") - try: - ts = np.array(ts, dtype = np.datetime64) - except: - raise ValueError("ts must be a numpy array of datetime64 or convertible to it") - return ts - -def partition_sessionize(ts: np.ndarray, values: np.ndarray, gap = 20): - - ts = check_time_series(ts, values) - - nnz = np.nonzero(values) - breaks = np.diff(nnz[0]) > gap - assgnmts = np.hstack([[0], np.cumsum(breaks)]) - split_indices = np.argwhere(assgnmts[:-1] != assgnmts[1:]).flatten() + 1 - ts = ts[nnz] - values = values[nnz] - return [MetricsPartition(ts_split, values_split) for ts_split, values_split in - zip(np.split(ts, split_indices), np.split(values, split_indices))] - - -def partition_periodic(ts: np.ndarray, values: np.ndarray, period = None): - - ts = check_time_series(ts, values) - - confidence_metric = 100 if period is not None else None - if period is None: - dft = np.fft.fft(values) - freqs = np.fft.fftfreq(len(values), d=1) - magnitudes = np.abs(dft) - magnitudes_no_zero = magnitudes[1:] - freqs_no_zero = freqs[1:] - max_magnitude_idx = np.argmax(magnitudes_no_zero) - dominant_freq = freqs_no_zero[max_magnitude_idx] - dominant_magnitude = magnitudes_no_zero[max_magnitude_idx] - most_likely_period = 1 / dominant_freq - - total_magnitude = np.sum(magnitudes_no_zero) - confidence_metric = dominant_magnitude / total_magnitude - period = most_likely_period - - offsets = np.arange(int(period), len(values), int(period)) - return [MetricsPartition(ts_split, values_split) for ts_split, values_split in - zip(np.split(ts, offsets), np.split(values, offsets))] \ No newline at end of file diff --git a/src/formats/parquet.rs b/src/formats/parquet.rs index 8a8d8a1..ae7c8ba 100644 --- a/src/formats/parquet.rs +++ b/src/formats/parquet.rs @@ -41,11 +41,18 @@ use super::readers::ReaderType; use serde::{Deserialize, Serialize}; use tokio::task::JoinSet; -async fn get_metadata_bytes(reader: &mut AsyncReader, file_size: usize) -> Result { +async fn get_metadata_bytes( + reader: &mut AsyncReader, + file_size: usize, +) -> Result { // check file is large enough to hold footer - let footer: [u8; 8] = - reader.read_range(file_size as u64 - 8, file_size as u64).await?.to_byte_slice().try_into().unwrap(); + let footer: [u8; 8] = reader + .read_range(file_size as u64 - 8, file_size as u64) + .await? 
+ .to_byte_slice() + .try_into() + .unwrap(); let metadata_len = decode_footer(&footer)?; let footer_metadata_len = FOOTER_SIZE + metadata_len; @@ -57,7 +64,9 @@ async fn get_metadata_bytes(reader: &mut AsyncReader, file_size: usize) -> Resul } let start = file_size as u64 - footer_metadata_len as u64; - let bytes = reader.read_range(start, start + metadata_len as u64).await?; + let bytes = reader + .read_range(start, start + metadata_len as u64) + .await?; Ok(bytes) } @@ -72,7 +81,8 @@ pub(crate) fn decode_page( let mut can_decompress = true; if let Some(ref header_v2) = page_header.data_page_header_v2 { - offset = (header_v2.definition_levels_byte_length + header_v2.repetition_levels_byte_length) as usize; + offset = (header_v2.definition_levels_byte_length + header_v2.repetition_levels_byte_length) + as usize; // When is_compressed flag is missing the page is considered compressed can_decompress = header_v2.is_compressed.unwrap_or(true); } @@ -85,10 +95,16 @@ pub(crate) fn decode_page( let mut decompressed = Vec::with_capacity(uncompressed_size); let compressed = &buffer.as_ref()[offset..]; decompressed.extend_from_slice(&buffer.as_ref()[..offset]); - decompressor.decompress(compressed, &mut decompressed, Some(uncompressed_size - offset))?; + decompressor.decompress( + compressed, + &mut decompressed, + Some(uncompressed_size - offset), + )?; if decompressed.len() != uncompressed_size { - return Err(LavaError::from(ParquetError::General("messed decompression".to_string()))); + return Err(LavaError::from(ParquetError::General( + "messed decompression".to_string(), + ))); } Bytes::from(decompressed) @@ -98,10 +114,9 @@ pub(crate) fn decode_page( let result = match page_header.type_ { PageType::DICTIONARY_PAGE => { - let dict_header = page_header - .dictionary_page_header - .as_ref() - .ok_or_else(|| ParquetError::General("Missing dictionary page header".to_string()))?; + let dict_header = page_header.dictionary_page_header.as_ref().ok_or_else(|| { + ParquetError::General("Missing dictionary page header".to_string()) + })?; let is_sorted = dict_header.is_sorted.unwrap_or(false); Page::DictionaryPage { buf: buffer, @@ -149,7 +164,10 @@ pub(crate) fn decode_page( Ok(result) } -fn read_page_header(reader: &C, offset: u64) -> Result<(usize, PageHeader), LavaError> { +fn read_page_header( + reader: &C, + offset: u64, +) -> Result<(usize, PageHeader), LavaError> { struct TrackedRead(R, usize); impl Read for TrackedRead { @@ -167,7 +185,10 @@ fn read_page_header(reader: &C, offset: u64) -> Result<(usize, P Ok((tracked.1, header)) } -async fn parse_metadatas(file_paths: &Vec, reader_type: ReaderType) -> HashMap { +async fn parse_metadatas( + file_paths: &Vec, + reader_type: ReaderType, +) -> HashMap { let iter = file_paths.iter().dedup(); let handles = stream::iter(iter) @@ -176,17 +197,25 @@ async fn parse_metadatas(file_paths: &Vec, reader_type: ReaderType) -> H let reader_type = reader_type.clone(); tokio::spawn(async move { - let (file_size, mut reader) = get_file_size_and_reader(file_path.clone(), reader_type).await.unwrap(); + let (file_size, mut reader) = + get_file_size_and_reader(file_path.clone(), reader_type) + .await + .unwrap(); - let metadata_bytes = get_metadata_bytes(&mut reader, file_size as usize).await.unwrap(); + let metadata_bytes = get_metadata_bytes(&mut reader, file_size as usize) + .await + .unwrap(); - let metadata = decode_metadata(metadata_bytes.to_byte_slice()).map_err(LavaError::from).unwrap(); + let metadata = decode_metadata(metadata_bytes.to_byte_slice()) + 
.map_err(LavaError::from) + .unwrap(); (file_path, metadata) }) }) .collect::>() .await; - let res: Vec> = futures::future::join_all(handles).await; + let res: Vec> = + futures::future::join_all(handles).await; let mut metadatas = HashMap::new(); @@ -217,11 +246,14 @@ pub async fn get_parquet_layout( file_path: &str, reader_type: ReaderType, ) -> Result<(Vec, ParquetLayout), LavaError> { - let (file_size, mut reader) = get_file_size_and_reader(file_path.to_string(), reader_type).await?; + let (file_size, mut reader) = + get_file_size_and_reader(file_path.to_string(), reader_type).await?; let metadata_bytes = get_metadata_bytes(&mut reader, file_size as usize).await?; let metadata = decode_metadata(metadata_bytes.to_byte_slice()).map_err(LavaError::from)?; - let codec_options = CodecOptionsBuilder::default().set_backward_compatible_lz4(false).build(); + let codec_options = CodecOptionsBuilder::default() + .set_backward_compatible_lz4(false) + .build(); let mut parquet_layout = ParquetLayout { num_row_groups: metadata.num_row_groups(), @@ -242,14 +274,19 @@ pub async fn get_parquet_layout( .columns() .iter() .position(|column| column.name() == column_name) - .expect(&format!("column {} not found in parquet file {}", column_name, file_path)); + .expect(&format!( + "column {} not found in parquet file {}", + column_name, file_path + )); //TODO: @rain we should parallelize this across row groups using tokio // this need to refactor the ParquetLayout data structure, since it won't cost too much time, postpone for now. for row_group in 0..metadata.num_row_groups() { let column = metadata.row_group(row_group).column(column_index); - let mut start = column.dictionary_page_offset().unwrap_or_else(|| column.data_page_offset()) as u64; + let mut start = column + .dictionary_page_offset() + .unwrap_or_else(|| column.data_page_offset()) as u64; let end = start + column.compressed_size() as u64; let compression_scheme = column.compression(); @@ -283,8 +320,10 @@ pub async fn get_parquet_layout( dictionary_page_size = page_header.compressed_page_size as usize + header_len; let page: Page = decode_page( page_header, - column_chunk_bytes - .slice((start as usize + header_len)..(start as usize + dictionary_page_size as usize)), + column_chunk_bytes.slice( + (start as usize + header_len) + ..(start as usize + dictionary_page_size as usize), + ), Type::BYTE_ARRAY, codec.as_mut(), ) @@ -294,10 +333,16 @@ pub async fn get_parquet_layout( } PageType::DATA_PAGE | PageType::DATA_PAGE_V2 => { let compressed_page_size = page_header.compressed_page_size; - parquet_layout.data_page_sizes.push(compressed_page_size as usize + header_len); - parquet_layout.data_page_offsets.push((column_chunk_offset + start) as usize); - - parquet_layout.dictionary_page_sizes.push(dictionary_page_size); + parquet_layout + .data_page_sizes + .push(compressed_page_size as usize + header_len); + parquet_layout + .data_page_offsets + .push((column_chunk_offset + start) as usize); + + parquet_layout + .dictionary_page_sizes + .push(dictionary_page_size); total_data_pages += 1; let page = decode_page( @@ -311,7 +356,9 @@ pub async fn get_parquet_layout( ) .unwrap(); - parquet_layout.data_page_num_rows.push(page.num_values() as usize); + parquet_layout + .data_page_num_rows + .push(page.num_values() as usize); total_values += page.num_values() as usize; start += compressed_page_size as u64 + header_len as u64; @@ -345,18 +392,23 @@ pub async fn get_parquet_layout( for _ in (0..total_values).step_by(10_000) { let array = 
array_reader.next_batch(10_000).unwrap(); - let new_array: Result<&arrow_array::GenericByteArray>, ArrowError> = - array - .as_any() - .downcast_ref::() - .ok_or_else(|| ArrowError::ParseError("Expects string array as first argument".to_string())); + let new_array: Result< + &arrow_array::GenericByteArray>, + ArrowError, + > = array.as_any().downcast_ref::().ok_or_else(|| { + ArrowError::ParseError("Expects string array as first argument".to_string()) + }); let data = match new_array { Ok(_) => new_array.unwrap().to_data(), Err(_) => array .as_any() .downcast_ref::() - .ok_or_else(|| ArrowError::ParseError("Expects string or binary array as first argument".to_string())) + .ok_or_else(|| { + ArrowError::ParseError( + "Expects string or binary array as first argument".to_string(), + ) + }) .unwrap() .to_data(), }; @@ -389,14 +441,21 @@ pub async fn read_indexed_pages_async( // current implementation might re-read dictionary pages, this should be optimized // we are assuming that all the files are either on disk or cloud. - let codec_options = CodecOptionsBuilder::default().set_backward_compatible_lz4(false).build(); + let codec_options = CodecOptionsBuilder::default() + .set_backward_compatible_lz4(false) + .build(); let metadatas = match file_metadatas { Some(file_metadatas) => { println!("Using provided file metadatas"); let mut metadatas: HashMap = HashMap::new(); for (key, value) in file_metadatas.into_iter() { - metadatas.insert(key, decode_metadata(value.to_byte_slice()).map_err(LavaError::from).unwrap()); + metadatas.insert( + key, + decode_metadata(value.to_byte_slice()) + .map_err(LavaError::from) + .unwrap(), + ); } metadatas } @@ -405,9 +464,17 @@ pub async fn read_indexed_pages_async( let in_order: bool = in_order.unwrap_or(true); - let mut reader = get_reader(file_paths[0].clone(), reader_type.clone()).await.unwrap(); + let mut reader = get_reader(file_paths[0].clone(), reader_type.clone()) + .await + .unwrap(); - let iter = izip!(file_paths, row_groups, page_offsets, page_sizes, dict_page_sizes); + let iter = izip!( + file_paths, + row_groups, + page_offsets, + page_sizes, + dict_page_sizes + ); let start = std::time::Instant::now(); @@ -415,84 +482,115 @@ pub async fn read_indexed_pages_async( let mut join_set = JoinSet::new(); let iter: Vec<_> = stream::iter(iter) - .map(|(file_path, row_group, page_offset, page_size, dict_page_size)| { - let column_index = metadatas[&file_path] - .file_metadata() - .schema_descr() - .columns() - .iter() - .position(|column| column.name() == column_name) - .expect(&format!("column {} not found in parquet file {}", column_name, file_path)); - let column_descriptor = metadatas[&file_path].row_group(row_group).schema_descr().column(column_index); - - let compression_scheme = metadatas[&file_path].row_group(row_group).column(column_index).compression(); - let dict_page_offset = - metadatas[&file_path].row_group(row_group).column(column_index).dictionary_page_offset(); - let mut codec = create_codec(compression_scheme, &codec_options).unwrap().unwrap(); - - let mut reader_c = reader.clone(); - reader_c.update_filename(file_path).unwrap(); - - let future = async move { - let mut pages: Vec = Vec::new(); - if dict_page_size > 0 { - let start = dict_page_offset.unwrap() as u64; - let dict_page_bytes = reader_c.read_range(start, start + dict_page_size as u64).await.unwrap(); - let dict_page_bytes = Bytes::from(dict_page_bytes); - let (dict_header_len, dict_header) = read_page_header(&dict_page_bytes, 0).unwrap(); - let dict_page = decode_page( - 
dict_header, - dict_page_bytes.slice(dict_header_len..dict_page_size), + .map( + |(file_path, row_group, page_offset, page_size, dict_page_size)| { + let column_index = metadatas[&file_path] + .file_metadata() + .schema_descr() + .columns() + .iter() + .position(|column| column.name() == column_name) + .expect(&format!( + "column {} not found in parquet file {}", + column_name, file_path + )); + let column_descriptor = metadatas[&file_path] + .row_group(row_group) + .schema_descr() + .column(column_index); + + let compression_scheme = metadatas[&file_path] + .row_group(row_group) + .column(column_index) + .compression(); + let dict_page_offset = metadatas[&file_path] + .row_group(row_group) + .column(column_index) + .dictionary_page_offset(); + let mut codec = create_codec(compression_scheme, &codec_options) + .unwrap() + .unwrap(); + + let mut reader_c = reader.clone(); + reader_c.update_filename(file_path).unwrap(); + + let future = async move { + let mut pages: Vec = Vec::new(); + if dict_page_size > 0 { + let start = dict_page_offset.unwrap() as u64; + let dict_page_bytes = reader_c + .read_range(start, start + dict_page_size as u64) + .await + .unwrap(); + let dict_page_bytes = Bytes::from(dict_page_bytes); + let (dict_header_len, dict_header) = + read_page_header(&dict_page_bytes, 0).unwrap(); + let dict_page = decode_page( + dict_header, + dict_page_bytes.slice(dict_header_len..dict_page_size), + Type::BYTE_ARRAY, + Some(&mut codec), + ) + .unwrap(); + pages.push(dict_page); + } + + let page_bytes = reader_c + .read_range(page_offset, page_offset + page_size as u64) + .await + .unwrap(); + let (header_len, header) = read_page_header(&page_bytes, 0).unwrap(); + let page: Page = decode_page( + header, + page_bytes.slice(header_len..page_size), Type::BYTE_ARRAY, Some(&mut codec), ) .unwrap(); - pages.push(dict_page); - } - - let page_bytes = reader_c.read_range(page_offset, page_offset + page_size as u64).await.unwrap(); - let (header_len, header) = read_page_header(&page_bytes, 0).unwrap(); - let page: Page = - decode_page(header, page_bytes.slice(header_len..page_size), Type::BYTE_ARRAY, Some(&mut codec)) - .unwrap(); - let num_values = page.num_values(); - - pages.push(page); - let page_iterator = InMemoryPageIterator::new(vec![pages]); - let mut array_reader = - make_byte_array_reader(Box::new(page_iterator), column_descriptor.clone(), None).unwrap(); - let array = array_reader.next_batch(num_values as usize).unwrap(); - - let new_array: Result< - &arrow_array::GenericByteArray>, - ArrowError, - > = array - .as_any() - .downcast_ref::() - .ok_or_else(|| ArrowError::ParseError("Expects string array as first argument".to_string())); - - let data = match new_array { - Ok(_) => new_array.unwrap().to_data(), - Err(_) => array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - ArrowError::ParseError("Expects string or binary array as first argument".to_string()) - }) - .unwrap() - .to_data(), + let num_values = page.num_values(); + + pages.push(page); + let page_iterator = InMemoryPageIterator::new(vec![pages]); + let mut array_reader = make_byte_array_reader( + Box::new(page_iterator), + column_descriptor.clone(), + None, + ) + .unwrap(); + let array = array_reader.next_batch(num_values as usize).unwrap(); + + let new_array: Result< + &arrow_array::GenericByteArray>, + ArrowError, + > = array.as_any().downcast_ref::().ok_or_else(|| { + ArrowError::ParseError("Expects string array as first argument".to_string()) + }); + + let data = match new_array { + Ok(_) => 
new_array.unwrap().to_data(), + Err(_) => array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::ParseError( + "Expects string or binary array as first argument".to_string(), + ) + }) + .unwrap() + .to_data(), + }; + + data }; - data - }; - - if in_order { - let handle = tokio::spawn(future); - future_handles.push(handle); - } else { - join_set.spawn(future); - } - }) + if in_order { + let handle = tokio::spawn(future); + future_handles.push(handle); + } else { + join_set.spawn(future); + } + }, + ) .collect::>() .await; @@ -529,7 +627,10 @@ pub fn read_indexed_pages( file_metadatas: Option>, in_order: Option, ) -> Result, LavaError> { - let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap(); + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); let res = rt.block_on(read_indexed_pages_async( column_name, diff --git a/src/formats/readers/mod.rs b/src/formats/readers/mod.rs index 5eb1b33..08b444a 100644 --- a/src/formats/readers/mod.rs +++ b/src/formats/readers/mod.rs @@ -3,6 +3,7 @@ use crate::lava::error::LavaError; use async_trait::async_trait; use bytes::Bytes; use local_reader::AsyncLocalReader; +use serde::de::DeserializeOwned; use std::collections::BTreeMap; use std::{env, os}; use std::{ @@ -108,7 +109,11 @@ impl AsyncReader { // only check the cache if self.filename has extension .lava if self.filename.ends_with(".lava") { - if "true" == env::var_os("CACHE_ENABLE").map(|s| s.to_ascii_lowercase()).unwrap_or_default() { + if "true" + == env::var_os("CACHE_ENABLE") + .map(|s| s.to_ascii_lowercase()) + .unwrap_or_default() + { // let path = std::path::Path::new(&value); // find path/filename.cache let mut conn = cache::get_redis_connection().await?; @@ -121,7 +126,9 @@ impl AsyncReader { if from >= start as u64 && to <= end as u64 { println!("cache hit"); let data = conn.get_data(&self.filename, from, to).await?; - let data = data[(from - start as u64) as usize..(to - start as u64) as usize].to_vec(); + let data = data + [(from - start as u64) as usize..(to - start as u64) as usize] + .to_vec(); return Ok(Bytes::from(data)); } } @@ -133,7 +140,11 @@ impl AsyncReader { } // theoretically we should try to return different types here, but Vec is def. 
the most common - pub async fn read_range_and_decompress(&mut self, from: u64, to: u64) -> Result, LavaError> { + pub async fn read_range_and_decompress( + &mut self, + from: u64, + to: u64, + ) -> Result, LavaError> { let compressed_posting_list_offsets = self.read_range(from, to).await?; let mut decompressor = Decoder::new(&compressed_posting_list_offsets[..])?; let mut serialized_posting_list_offsets: Vec = @@ -144,7 +155,9 @@ impl AsyncReader { } pub async fn read_usize_from_end(&mut self, n: u64) -> Result, LavaError> { - self.deref_mut().read_usize_from_end(-8 * (n as i64), n).await + self.deref_mut() + .read_usize_from_end(-8 * (n as i64), n) + .await } } @@ -194,14 +207,22 @@ pub async fn get_file_sizes_and_readers( readers.push(reader); } Ok(Err(e)) => return Err(e), // Handle error from inner task - Err(e) => return Err(LavaError::Parse(format!("Task join error: {}", e.to_string()))), // Handle join error + Err(e) => { + return Err(LavaError::Parse(format!( + "Task join error: {}", + e.to_string() + ))) + } // Handle join error } } Ok((file_sizes, readers)) } -pub async fn get_readers(files: &[String], reader_type: ReaderType) -> Result, LavaError> { +pub async fn get_readers( + files: &[String], + reader_type: ReaderType, +) -> Result, LavaError> { let tasks: Vec<_> = files .iter() .map(|file| { @@ -223,7 +244,12 @@ pub async fn get_readers(files: &[String], reader_type: ReaderType) -> Result return Err(e), // Handle error from inner task - Err(e) => return Err(LavaError::Parse(format!("Task join error: {}", e.to_string()))), // Handle join error + Err(e) => { + return Err(LavaError::Parse(format!( + "Task join error: {}", + e.to_string() + ))) + } // Handle join error } } @@ -300,3 +326,19 @@ pub async fn get_reader(file: String, reader_type: ReaderType) -> Result( + reader: &mut AsyncReader, + start: u64, + size: u64, +) -> Result +where + T: DeserializeOwned, +{ + let compressed = reader.read_range(start, start + size).await?; + let mut decompressor = Decoder::new(&compressed[..]).unwrap(); + let mut decompressed = Vec::new(); + std::io::copy(&mut decompressor, &mut decompressed)?; + let result: T = bincode::deserialize(&decompressed)?; + Ok(result) +} diff --git a/src/lava/bm25/bm25.rs b/src/lava/bm25/bm25.rs new file mode 100644 index 0000000..2e9384a --- /dev/null +++ b/src/lava/bm25/bm25.rs @@ -0,0 +1,581 @@ +use crate::formats::readers::get_file_size_and_reader; +use crate::lava::error::LavaError; +use crate::lava::plist::PListChunk; +use arrow::array::{make_array, Array, ArrayData, LargeStringArray, UInt64Array}; +use bincode; + +use std::collections::BTreeMap; + +use std::fs::File; +use std::io::Read; +use std::io::{BufWriter, Seek, SeekFrom, Write}; +use tokenizers::parallelism::MaybeParallelIterator; +use zstd::stream::encode_all; + +/* +Structure of the lava file +It is important to put the posting lists first. Just trust me bro. +compressed_serialized_tokenizer | compressed posting lists line by line | compressed term dictionary | compressed posting list offsets| +8 bytes = offsets of compressed term dict | 8 bytes = offset of compressed posting list offsets +*/ + +fn get_tokenizer(tokenizer_file: Option) -> Result<(Tokenizer, Vec), LavaError> { + // if the tokenizer file is provided, check if the file exists. 
If it does not exist, raise an Error + let tokenizer = if let Some(tokenizer_file) = tokenizer_file { + if !std::path::Path::new(&tokenizer_file).exists() { + return Err(LavaError::Parse( + "Tokenizer file does not exist".to_string(), + )); + } + println!("Tokenizer file: {}", tokenizer_file); + Tokenizer::from_file(tokenizer_file).unwrap() + } else { + Tokenizer::from_pretrained("bert-base-uncased", None).unwrap() + }; + + let serialized_tokenizer = serde_json::to_string(&tokenizer).unwrap(); + let compressed_tokenizer = + encode_all(serialized_tokenizer.as_bytes(), 0).expect("Compression failed"); + Ok((tokenizer, compressed_tokenizer)) +} + +/// Function that tokenizes the input text and returns a list of tokens. +#[tokio::main] +pub async fn build_lava_bm25( + output_file_name: String, + array: ArrayData, + uid: ArrayData, + tokenizer_file: Option, + k1: Option, + b: Option, +) -> Result, LavaError> { + // if k1 and b are not provided, set them to default value + let k1: f32 = k1.unwrap_or(1.2); + let b: f32 = b.unwrap_or(0.75); + + let array = make_array(array); + // let uid = make_array(ArrayData::from_pyarrow(uid)?); + let uid = make_array(uid); + let array: &arrow_array::GenericByteArray> = array + .as_any() + .downcast_ref::() + .ok_or(LavaError::Parse( + "Expects string array as first argument".to_string(), + ))?; + + let uid = uid + .as_any() + .downcast_ref::() + .ok_or(LavaError::Parse( + "Expects uint64 array as second argument".to_string(), + ))?; + + if array.len() != uid.len() { + return Err(LavaError::Parse( + "The length of the array and the uid array must be the same".to_string(), + )); + } + + let (tokenizer, compressed_tokenizer) = get_tokenizer(tokenizer_file)?; + let vocab_size: usize = tokenizer.get_vocab_size(false); + + let mut texts = Vec::with_capacity(array.len()); + for i in 0..array.len() { + let text = array.value(i); + texts.push(text); + } + + let encodings = texts + .into_maybe_par_iter() + .map(|text| { + let encoding = tokenizer.encode(text, false).unwrap(); + encoding.get_ids().to_vec() + }) + .collect::>>(); + + let mut inverted_index: Vec> = vec![BTreeMap::new(); vocab_size]; + let mut token_counts: Vec = vec![0; vocab_size]; + + let mut avg_len: f32 = 0.0; + for encoding in encodings.iter() { + avg_len += encoding.len() as f32; + } + avg_len /= encodings.len() as f32; + + for (i, encoding) in encodings.iter().enumerate() { + let this_uid = uid.value(i) as usize; + let mut local_token_counts: BTreeMap = BTreeMap::new(); + for key in encoding { + *local_token_counts.entry(*key).or_insert(0) += 1; + } + for key in local_token_counts.keys() { + let local_count = local_token_counts[key]; + let local_factor: f32 = (local_count as f32) * (k1 + 1.0) + / (local_count as f32 + k1 * (1.0 - b + b * encoding.len() as f32 / avg_len)); + + inverted_index[*key as usize] + .entry(this_uid) + .and_modify(|e| *e = (*e).max(local_factor)) + .or_insert(local_factor); + + token_counts[*key as usize] += 1; + } + } + + let mut file = File::create(output_file_name)?; + file.write_all(&(compressed_tokenizer.len() as u64).to_le_bytes())?; + file.write_all(&compressed_tokenizer)?; + + let bytes = bincode::serialize(&token_counts)?; + let compressed_token_counts: Vec = encode_all(&bytes[..], 0).expect("Compression failed"); + + // Handle the compressed data (for example, saving to a file or sending over a network) + println!( + "Compressed token counts size: {} number of tokens: {}", + compressed_token_counts.len(), + inverted_index.len() + ); + + let mut plist_offsets: Vec = 
vec![file.seek(SeekFrom::Current(0))?]; + let mut plist_elems: Vec = vec![0]; + let mut plist_chunk = PListChunk::new()?; + let mut counter: u64 = 0; + + for (_key, value) in inverted_index.iter().enumerate() { + let plist = if value.len() == 0 { + vec![] + } else { + let mut result = vec![]; + for (key, val) in value.iter() { + result.push(*key as u64); + // quantize the score to int. + result.push((*val * 100 as f32) as u64); + } + result + }; + + counter += 1; + + let written = plist_chunk.add_plist(&plist)?; + if written > 1024 * 1024 || counter == inverted_index.len() as u64 { + let bytes = plist_chunk.finalize_compression()?; + file.write_all(&bytes)?; + plist_offsets.push(plist_offsets[plist_offsets.len() - 1] + bytes.len() as u64); + plist_elems.push(counter); + plist_chunk = PListChunk::new()?; + } + } + + plist_offsets.append(&mut plist_elems); + + let compressed_term_dict_offset = file.seek(SeekFrom::Current(0))?; + file.write_all(&compressed_token_counts)?; + + let compressed_plist_offsets_offset = file.seek(SeekFrom::Current(0))?; + let serialized = bincode::serialize(&plist_offsets).unwrap(); + let compressed_plist_offsets = + encode_all(&serialized[..], 0).expect("Compression of plist offsets failed"); + file.write_all(&compressed_plist_offsets)?; + + file.write_all(&(compressed_term_dict_offset as u64).to_le_bytes())?; + file.write_all(&(compressed_plist_offsets_offset as u64).to_le_bytes())?; + file.write_all(&(encodings.len() as u64).to_le_bytes())?; + + let cache_end = file.seek(SeekFrom::Current(0))? as usize; + + Ok(vec![(compressed_term_dict_offset as usize, cache_end)]) +} + +struct PListChunkIterator { + reader: AsyncReader, + current_offset_in_chunk: usize, + current_chunk_offset: usize, + current_chunk: Vec>, + plist_offsets: Vec, + plist_elems: Vec, +} + +impl PListChunkIterator { + // take ownership of the data structures + pub async fn new( + mut reader: AsyncReader, + plist_offsets: Vec, + plist_elems: Vec, + ) -> Result { + // read the first chunk + + let buffer3 = reader + .read_range(plist_offsets[0], plist_offsets[1]) + .await?; + let result: Vec> = + PListChunk::search_compressed(buffer3.to_vec(), &(0..plist_elems[1]).collect()) + .unwrap(); + + Ok(Self { + reader: reader, + current_offset_in_chunk: 0, + current_chunk_offset: 0, + current_chunk: result, + plist_offsets: plist_offsets, + plist_elems: plist_elems, + }) + } + + pub fn get(&mut self) -> Vec { + self.current_chunk[self.current_offset_in_chunk as usize].clone() + } + + pub async fn advance(&mut self) -> Result<(), LavaError> { + self.current_offset_in_chunk += 1; + if self.current_offset_in_chunk == self.current_chunk.len() { + // read the next chunk + self.current_offset_in_chunk = 0; + self.current_chunk_offset += 1; + if self.current_chunk_offset + 2 > self.plist_offsets.len() { + return Err(LavaError::Parse("out of chunks".to_string())); + } + + let buffer3 = self + .reader + .read_range( + self.plist_offsets[self.current_chunk_offset], + self.plist_offsets[self.current_chunk_offset + 1], + ) + .await?; + + self.current_chunk = PListChunk::search_compressed( + buffer3.to_vec(), + &(0..(self.plist_elems[self.current_chunk_offset + 1] + - self.plist_elems[self.current_chunk_offset])) + .collect(), + ) + .unwrap(); + } + + Ok(()) + } +} + +pub(crate) async fn merge_lava_bm25( + condensed_lava_file: &str, + lava_files: Vec, + uid_offsets: Vec, + reader_type: ReaderType, +) -> Result, LavaError> { + // let mut builder = Fs::default(); + // let current_path = env::current_dir()?; + // 
builder.root(current_path.to_str().expect("no path")); + // let operator = Operator::new(builder)?.finish(); + + let mut file_sizes: Vec = Vec::with_capacity(lava_files.len()); + let mut plist_chunk_iterators: Vec = Vec::with_capacity(lava_files.len()); + + let mut combined_token_counts: Vec = Vec::new(); + let mut total_num_documents: u64 = 0; + let mut compressed_tokenizer: Option> = None; + + for file in lava_files { + let reader_type = reader_type.clone(); + let (file_size, mut reader) = get_file_size_and_reader(file, reader_type).await?; + let file_size = file_size as u64; + + let results = reader.read_usize_from_end(3).await?; + let compressed_term_dict_offset = results[0]; + let compressed_plist_offsets_offset = results[1]; + let num_documents = results[2]; + total_num_documents += num_documents; + + let compressed_token_counts = reader + .read_range(compressed_term_dict_offset, compressed_plist_offsets_offset) + .await?; + + let mut decompressed_token_counts: Vec = Vec::new(); + let mut decompressor: Decoder<'_, BufReader<&[u8]>> = + Decoder::new(&compressed_token_counts[..])?; + decompressor.read_to_end(&mut decompressed_token_counts)?; + let token_counts: Vec = bincode::deserialize(&decompressed_token_counts)?; + + if combined_token_counts.len() == 0 { + combined_token_counts = token_counts; + } else { + // add token_counts to combined_token_counts + for (i, count) in token_counts.iter().enumerate() { + combined_token_counts[i] += count; + } + } + + let buffer2 = reader + .read_range(compressed_plist_offsets_offset, file_size - 24) + .await?; + + decompressor = Decoder::new(&buffer2[..])?; + let mut decompressed_serialized_plist_offsets: Vec = + Vec::with_capacity(buffer2.len() as usize); + decompressor.read_to_end(&mut decompressed_serialized_plist_offsets)?; + let this_plist_offsets: Vec = + bincode::deserialize(&decompressed_serialized_plist_offsets)?; + + if (this_plist_offsets.len() % 2) != 0 { + let err = LavaError::Parse("data corruption".to_string()); + return Err(err); + } + let num_elements = this_plist_offsets.len() / 2; + + let compressed_tokenizer_size = reader.read_usize_from_start(0, 1).await?[0]; + let this_compressed_tokenizer: bytes::Bytes = + reader.read_range(8, 8 + compressed_tokenizer_size).await?; + + match &compressed_tokenizer { + Some(value) => assert!( + this_compressed_tokenizer == value, + "detected different tokenizers, cannot merge, something is very wrong." 
+ ), + None => compressed_tokenizer = Some(this_compressed_tokenizer.to_vec()), + } + + file_sizes.push(file_size); + plist_chunk_iterators.push( + PListChunkIterator::new( + reader, + this_plist_offsets[..num_elements].to_vec(), + this_plist_offsets[num_elements..].to_vec(), + ) + .await?, + ); + } + + let mut output_file = File::create(condensed_lava_file)?; + + let compressed_tokenizer = compressed_tokenizer.unwrap(); + // let compressed_tokenizer_len = compressed_tokenizer.len(); + output_file.write_all(&(compressed_tokenizer.len() as u64).to_le_bytes())?; + output_file.write_all(&compressed_tokenizer)?; + + let mut new_plist_offsets: Vec = vec![output_file.seek(SeekFrom::Current(0))?]; + let mut new_plist_elems: Vec = vec![0]; + let mut plist_chunk = PListChunk::new()?; + let mut counter: u64 = 0; + + for tok in 0..combined_token_counts.len() { + // Find the smallest current line + + let mut plist: Vec = vec![]; + + for i in 0..plist_chunk_iterators.len() { + let this_plist: Vec = plist_chunk_iterators[i].get(); + assert_eq!(this_plist.len() % 2, 0); + + for (j, item) in this_plist.iter().enumerate() { + if j % 2 == 0 { + // page offset + plist.push(*item + uid_offsets[i]); + } else { + // quantized score + plist.push(*item); + } + } + + // this will return error for the last one, but it's ok + let _ = plist_chunk_iterators[i].advance().await; + } + + counter += 1; + + let plist = Vec::from_iter(plist.into_iter()); + let written = plist_chunk.add_plist(&plist)?; + if written > 1024 * 1024 || tok == combined_token_counts.len() - 1 { + let bytes = plist_chunk.finalize_compression()?; + let this_len: u64 = bytes.len() as u64; + + output_file.write(&bytes)?; + new_plist_offsets.push(new_plist_offsets[new_plist_offsets.len() - 1] + this_len); + new_plist_elems.push(counter); + plist_chunk = PListChunk::new()?; + } + } + + new_plist_offsets.append(&mut new_plist_elems); + + let bytes = bincode::serialize(&combined_token_counts)?; + let compressed_token_counts = encode_all(&bytes[..], 0).expect("Compression failed"); + + let compressed_term_dict_offset = output_file.seek(SeekFrom::Current(0))?; + output_file.write(&compressed_token_counts)?; + + let serialized = bincode::serialize(&new_plist_offsets).unwrap(); + let compressed_plist_offsets = + encode_all(&serialized[..], 0).expect("Compression of plist offsets failed"); + + let compressed_plist_offsets_offset = + compressed_term_dict_offset + compressed_token_counts.len() as u64; + output_file.write(&compressed_plist_offsets)?; + + output_file.write(&(compressed_term_dict_offset as u64).to_le_bytes())?; + output_file.write(&(compressed_plist_offsets_offset as u64).to_le_bytes())?; + output_file.write(&(total_num_documents as u64).to_le_bytes())?; + + Ok(vec![( + compressed_term_dict_offset as usize, + output_file.seek(SeekFrom::Current(0))? 
as usize, + )]) +} + +pub(crate) async fn search_bm25_async( + file_sizes: Vec, + mut readers: Vec, + query_tokens: Vec, + query_weights: Vec, + k: usize, +) -> Result, LavaError> { + let mut idf: HashMap = HashMap::new(); + let mut total_token_counts: HashMap = HashMap::new(); + for token in query_tokens.iter() { + total_token_counts.insert(*token, 0); + } + let mut total_documents: usize = 0; + let mut all_plist_offsets: Vec> = Vec::new(); + let mut chunks_to_search: HashMap<(usize, usize), Vec<(u32, u64)>> = HashMap::new(); + + for i in 0..readers.len() { + let results = readers[i].read_usize_from_end(3).await?; + let compressed_term_dictionary_offset = results[0]; + let compressed_plist_offsets_offset = results[1]; + let num_documents = results[2]; + + // now read the term dictionary + let token_counts = readers[i] + .read_range_and_decompress( + compressed_term_dictionary_offset, + compressed_plist_offsets_offset, + ) + .await?; + + for query_token in query_tokens.iter() { + total_token_counts.insert( + *query_token, + total_token_counts[query_token] + token_counts[*query_token as usize] as usize, + ); + } + total_documents += num_documents as usize; + + let plist_offsets = readers[i] + .read_range_and_decompress(compressed_plist_offsets_offset, file_sizes[i] as u64 - 24) + .await?; + + if plist_offsets.len() % 2 != 0 { + let err = LavaError::Parse("data corruption".to_string()); + return Err(err); + } + + let num_chunks: usize = plist_offsets.len() / 2; + let term_dict_len: &[u64] = &plist_offsets[num_chunks..]; + + for token in query_tokens.iter() { + let tok = *token as u64; + let (idx, offset) = match term_dict_len.binary_search(&tok) { + Ok(idx) => (idx, 0), + Err(idx) => (idx - 1, tok - term_dict_len[idx - 1]), + }; + + chunks_to_search + .entry((i as usize, idx)) + .or_insert_with(Vec::new) + .push((*token, offset as u64)); + } + + all_plist_offsets.push(plist_offsets); + } + + // compute the weighted IDF for each query token + for (i, query_token) in query_tokens.iter().enumerate() { + let query_weight = query_weights[i]; + let query_token = *query_token; + let token_count = total_token_counts[&query_token]; + idf.insert( + query_token, + query_weight + * ((total_documents as f32 - token_count as f32 + 0.5) + / (token_count as f32 + 0.5) + + 1.0) + .ln(), + ); + } + + let mut plist_result: Vec<(u64, u64)> = Vec::new(); + let mut page_scores: HashMap<(u64, u64), f32> = HashMap::new(); + + let mut join_set: JoinSet, LavaError>> = JoinSet::new(); + // need to parallelize this @Rain. 
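
[reviewer note, not part of the patch] The scoring is split across writer and reader: build_lava_bm25 stores a quantized per-document term factor in the posting lists, and search_bm25_async recombines it with a weighted IDF from the merged term dictionary, accumulating idf * quantized_factor per (file_id, uid). A minimal sketch of the two formulas, with illustrative function names:

fn bm25_idf(total_documents: usize, token_count: usize) -> f32 {
    // ln((N - n_t + 0.5) / (n_t + 0.5) + 1), as computed in search_bm25_async.
    ((total_documents as f32 - token_count as f32 + 0.5) / (token_count as f32 + 0.5) + 1.0).ln()
}

fn bm25_term_factor(term_count: u32, doc_len: usize, avg_len: f32, k1: f32, b: f32) -> f32 {
    // The local_factor computed in build_lava_bm25 before quantization.
    (term_count as f32) * (k1 + 1.0)
        / (term_count as f32 + k1 * (1.0 - b + b * doc_len as f32 / avg_len))
}
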
+ for (file_id, chunk_id, tokens, offsets) in + chunks_to_search + .into_iter() + .map(|((file_id, chunk_id), token_offsets)| { + let (tokens, offsets): (Vec, Vec) = token_offsets.into_iter().unzip(); + (file_id, chunk_id, Arc::new(tokens), Arc::new(offsets)) + }) + { + let reader_type = match readers[file_id].reader { + ClonableAsyncReader::AwsSdk(_) => ReaderType::AwsSdk, + ClonableAsyncReader::Http(_) => ReaderType::Http, + ClonableAsyncReader::Local(_) => ReaderType::Local, + }; + + let mut reader = match reader_type { + ReaderType::AwsSdk | ReaderType::Http => readers[file_id].clone(), + ReaderType::Local => { + get_file_size_and_reader(readers[file_id].filename.clone(), reader_type) + .await + .unwrap() + .1 + } + }; + let start = all_plist_offsets[file_id][chunk_id]; + let end = all_plist_offsets[file_id][chunk_id + 1]; + let tokens = tokens.clone(); + let offsets = offsets.clone(); + + join_set.spawn(async move { + // println!("file_id: {}, chunk_id: {}", file_id, chunk_id); + let buffer3 = reader.read_range(start, end).await?; + + // get all the second item in the offsets into its own vector + + let results: Vec> = + PListChunk::search_compressed(buffer3.to_vec(), offsets.as_ref())?; + + let mut res = vec![]; + for (i, result) in results.iter().enumerate() { + let token = &tokens[i]; + assert_eq!(result.len() % 2, 0); + for i in (0..result.len()).step_by(2) { + let uid = result[i]; + let page_score = result[i + 1]; + res.push((file_id, uid, *token, page_score)); + } + } + Ok(res) + }); + } + + while let Some(res) = join_set.join_next().await { + let res = res.map_err(|e| LavaError::Parse(format!("join error: {:?}", e)))??; + for (file_id, uid, token, page_score) in res { + page_scores + .entry((file_id as u64, uid)) + .and_modify(|e| *e += idf[&token] * page_score as f32) + .or_insert(idf[&token] * page_score as f32); + } + } + + // sort the page scores by descending order + let mut page_scores_vec: Vec<((u64, u64), f32)> = page_scores.into_iter().collect(); + page_scores_vec.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + + // get the top k results + for (uid, _score) in page_scores_vec.iter().take(k) { + // println!("{}", score); + plist_result.push(*uid); + } + + Ok(plist_result) +} diff --git a/src/lava/bm25/mod.rs b/src/lava/bm25/mod.rs new file mode 100644 index 0000000..af29b15 --- /dev/null +++ b/src/lava/bm25/mod.rs @@ -0,0 +1,5 @@ +mod bm25; + +pub use bm25::build_lava_bm25; +pub(crate) use bm25::merge_lava_bm25; +pub(crate) use bm25::search_bm25_async; diff --git a/src/lava/build.rs b/src/lava/build.rs deleted file mode 100644 index eaf1cbc..0000000 --- a/src/lava/build.rs +++ /dev/null @@ -1,661 +0,0 @@ -use crate::lava::constants::*; -use crate::lava::error::LavaError; -use crate::lava::plist::PListChunk; -use crate::lava::trie::{BinaryTrieNode, FastTrie}; -use crate::lava::wavelet_tree::{construct_wavelet_tree, write_wavelet_tree_to_disk, WaveletTree}; -use arrow::array::{make_array, Array, ArrayData, LargeStringArray, UInt64Array}; -use bincode; -use byteorder::{LittleEndian, WriteBytesExt}; -use bytes; -use divsufsort::sort_in_place; -use itertools::Itertools; -use serde_json; -use std::collections::BTreeMap; -use std::collections::BTreeSet; -use std::collections::HashMap; -use std::collections::HashSet; -use std::fs::File; -use std::io::Read; -use std::io::{BufWriter, Seek, SeekFrom, Write}; -use tokenizers::parallelism::MaybeParallelIterator; -use tokenizers::tokenizer::Tokenizer; // You'll need the `byteorder` crate -use zstd::stream::encode_all; - -use 
rayon::prelude::*; - -fn get_tokenizer(tokenizer_file: Option) -> Result<(Tokenizer, Vec), LavaError> { - // if the tokenizer file is provided, check if the file exists. If it does not exist, raise an Error - let tokenizer = if let Some(tokenizer_file) = tokenizer_file { - if !std::path::Path::new(&tokenizer_file).exists() { - return Err(LavaError::Parse("Tokenizer file does not exist".to_string())); - } - println!("Tokenizer file: {}", tokenizer_file); - Tokenizer::from_file(tokenizer_file).unwrap() - } else { - Tokenizer::from_pretrained("bert-base-uncased", None).unwrap() - }; - - let serialized_tokenizer = serde_json::to_string(&tokenizer).unwrap(); - let compressed_tokenizer = encode_all(serialized_tokenizer.as_bytes(), 0).expect("Compression failed"); - Ok((tokenizer, compressed_tokenizer)) -} - -#[tokio::main] -pub async fn build_lava_uuid( - output_file_name: String, - array: ArrayData, - uid: ArrayData, -) -> Result, LavaError> { - let array = make_array(array); - // let uid = make_array(ArrayData::from_pyarrow(uid)?); - let uid = make_array(uid); - let array: &arrow_array::GenericByteArray> = array - .as_any() - .downcast_ref::() - .ok_or(LavaError::Parse("Expects string array as first argument".to_string()))?; - - let uid: &arrow_array::PrimitiveArray = uid - .as_any() - .downcast_ref::() - .ok_or(LavaError::Parse("Expects uint64 array as second argument".to_string()))?; - - if array.len() != uid.len() { - return Err(LavaError::Parse("The length of the array and the uid array must be the same".to_string())); - } - - let mut texts = Vec::with_capacity(array.len()); - for i in 0..array.len() { - let text = array.value(i); - texts.push(text.as_bytes().to_vec()); - } - let mut inds = Vec::with_capacity(array.len()); - for i in 0..uid.len() { - inds.push(vec![uid.value(i) as usize]); - } - - let root = BinaryTrieNode::build(&texts, &inds); - let fast_trie = FastTrie::new(root, Some(16)); - let (serialized_fast_trie, (cache_start, cache_end)) = fast_trie.serialize(); - std::fs::write(output_file_name, serialized_fast_trie).unwrap(); - - Ok(vec![(cache_start, cache_end)]) -} - -/* -Structure of the lava file -It is important to put the posting lists first. Just trust me bro. -compressed_serialized_tokenizer | compressed posting lists line by line | compressed term dictionary | compressed posting list offsets| -8 bytes = offsets of compressed term dict | 8 bytes = offset of compressed posting list offsets -*/ - -/// Function that tokenizes the input text and returns a list of tokens. 
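
[reviewer note, not part of the patch] The "Structure of the lava file" comment above (kept verbatim in the new src/lava/bm25/bm25.rs) is easier to check against the read path once the footer is spelled out: build_lava_bm25 ends the file with three little-endian u64 values, which is exactly what the readers recover with read_usize_from_end(3). A sketch of decoding that 24-byte footer, assuming the slice has already been read from the end of the file; the helper name is illustrative:

fn decode_lava_bm25_footer(footer: &[u8; 24]) -> (u64, u64, u64) {
    let compressed_term_dict_offset = u64::from_le_bytes(footer[0..8].try_into().unwrap());
    let compressed_plist_offsets_offset = u64::from_le_bytes(footer[8..16].try_into().unwrap());
    let num_documents = u64::from_le_bytes(footer[16..24].try_into().unwrap());
    (compressed_term_dict_offset, compressed_plist_offsets_offset, num_documents)
}
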
-#[tokio::main] -pub async fn build_lava_bm25( - output_file_name: String, - array: ArrayData, - uid: ArrayData, - tokenizer_file: Option, - k1: Option, - b: Option, -) -> Result, LavaError> { - // if k1 and b are not provided, set them to default value - let k1: f32 = k1.unwrap_or(1.2); - let b: f32 = b.unwrap_or(0.75); - - let array = make_array(array); - // let uid = make_array(ArrayData::from_pyarrow(uid)?); - let uid = make_array(uid); - let array: &arrow_array::GenericByteArray> = array - .as_any() - .downcast_ref::() - .ok_or(LavaError::Parse("Expects string array as first argument".to_string()))?; - - let uid = uid - .as_any() - .downcast_ref::() - .ok_or(LavaError::Parse("Expects uint64 array as second argument".to_string()))?; - - if array.len() != uid.len() { - return Err(LavaError::Parse("The length of the array and the uid array must be the same".to_string())); - } - - let (tokenizer, compressed_tokenizer) = get_tokenizer(tokenizer_file)?; - let vocab_size: usize = tokenizer.get_vocab_size(false); - - let mut texts = Vec::with_capacity(array.len()); - for i in 0..array.len() { - let text = array.value(i); - texts.push(text); - } - - let encodings = texts - .into_maybe_par_iter() - .map(|text| { - let encoding = tokenizer.encode(text, false).unwrap(); - encoding.get_ids().to_vec() - }) - .collect::>>(); - - let mut inverted_index: Vec> = vec![BTreeMap::new(); vocab_size]; - let mut token_counts: Vec = vec![0; vocab_size]; - - let mut avg_len: f32 = 0.0; - for encoding in encodings.iter() { - avg_len += encoding.len() as f32; - } - avg_len /= encodings.len() as f32; - - for (i, encoding) in encodings.iter().enumerate() { - let this_uid = uid.value(i) as usize; - let mut local_token_counts: BTreeMap = BTreeMap::new(); - for key in encoding { - *local_token_counts.entry(*key).or_insert(0) += 1; - } - for key in local_token_counts.keys() { - let local_count = local_token_counts[key]; - let local_factor: f32 = (local_count as f32) * (k1 + 1.0) - / (local_count as f32 + k1 * (1.0 - b + b * encoding.len() as f32 / avg_len)); - - inverted_index[*key as usize] - .entry(this_uid) - .and_modify(|e| *e = (*e).max(local_factor)) - .or_insert(local_factor); - - token_counts[*key as usize] += 1; - } - } - - let mut file = File::create(output_file_name)?; - file.write_all(&(compressed_tokenizer.len() as u64).to_le_bytes())?; - file.write_all(&compressed_tokenizer)?; - - let bytes = bincode::serialize(&token_counts)?; - let compressed_token_counts: Vec = encode_all(&bytes[..], 0).expect("Compression failed"); - - // Handle the compressed data (for example, saving to a file or sending over a network) - println!( - "Compressed token counts size: {} number of tokens: {}", - compressed_token_counts.len(), - inverted_index.len() - ); - - let mut plist_offsets: Vec = vec![file.seek(SeekFrom::Current(0))?]; - let mut plist_elems: Vec = vec![0]; - let mut plist_chunk = PListChunk::new()?; - let mut counter: u64 = 0; - - for (_key, value) in inverted_index.iter().enumerate() { - let plist = if value.len() == 0 { - vec![] - } else { - let mut result = vec![]; - for (key, val) in value.iter() { - result.push(*key as u64); - // quantize the score to int. 
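
[reviewer note, not part of the patch] Both the old build.rs and the new bm25/bm25.rs store the f32 term factor as an integer so the posting list stays a flat Vec<u64>; `100 as f32` is simply 100.0, i.e. two decimal digits are kept, and the search path multiplies the quantized value by the IDF without dequantizing. Illustrative helper:

fn quantize_score(score: f32) -> u64 {
    // Keep two decimal digits of the BM25 term factor.
    (score * 100.0) as u64
}
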
- result.push((*val * 100 as f32) as u64); - } - result - }; - - counter += 1; - - let written = plist_chunk.add_plist(&plist)?; - if written > 1024 * 1024 || counter == inverted_index.len() as u64 { - let bytes = plist_chunk.finalize_compression()?; - file.write_all(&bytes)?; - plist_offsets.push(plist_offsets[plist_offsets.len() - 1] + bytes.len() as u64); - plist_elems.push(counter); - plist_chunk = PListChunk::new()?; - } - } - - plist_offsets.append(&mut plist_elems); - - let compressed_term_dict_offset = file.seek(SeekFrom::Current(0))?; - file.write_all(&compressed_token_counts)?; - - let compressed_plist_offsets_offset = file.seek(SeekFrom::Current(0))?; - let serialized = bincode::serialize(&plist_offsets).unwrap(); - let compressed_plist_offsets = encode_all(&serialized[..], 0).expect("Compression of plist offsets failed"); - file.write_all(&compressed_plist_offsets)?; - - file.write_all(&(compressed_term_dict_offset as u64).to_le_bytes())?; - file.write_all(&(compressed_plist_offsets_offset as u64).to_le_bytes())?; - file.write_all(&(encodings.len() as u64).to_le_bytes())?; - - let cache_end = file.seek(SeekFrom::Current(0))? as usize; - - Ok(vec![(compressed_term_dict_offset as usize, cache_end)]) -} - -pub async fn _build_lava_substring_char_wavelet( - output_file_name: String, - texts: Vec<(u64, String)>, - char_skip_factor: u32, -) -> Result, LavaError> { - let named_encodings = texts - .into_iter() - .map(|(uid, text)| { - let lower: String = text.chars().flat_map(|c| c.to_lowercase()).collect(); - let result: Vec = if char_skip_factor == 1 { - lower.chars().filter(|id| !SKIP.chars().contains(id)).map(|c| c as u8).collect() - } else { - lower - .chars() - .filter(|id| !SKIP.chars().contains(id)) - .enumerate() - .filter(|&(index, _)| index % char_skip_factor as usize == 1) - .map(|(_, c)| c as u8) - .collect() - }; - (vec![uid; result.len()], result) - }) - .collect::, Vec)>>(); - - let uids: Vec = named_encodings.iter().map(|(uid, _)| uid).flatten().cloned().collect::>(); - let encodings: Vec = named_encodings.into_iter().map(|(_, text)| text).flatten().collect::>(); - - let mut sa: Vec = (0..encodings.len() as i32).collect(); - - sort_in_place(&encodings, &mut sa); - - let mut idx: Vec = Vec::with_capacity(encodings.len()); - let mut bwt: Vec = Vec::with_capacity(encodings.len()); - let mut total_counts: Vec = vec![0; 256]; - for i in 0..sa.len() { - let char = if sa[i] == 0 { encodings[encodings.len() - 1] } else { encodings[(sa[i] - 1) as usize] }; - bwt.push(char); - total_counts[char as usize] += 1; - if sa[i] == 0 { - idx.push(uids[uids.len() - 1]); - } else { - idx.push(uids[(sa[i] - 1) as usize]); - } - } - - let mut cumulative_counts = vec![0; 256]; - cumulative_counts[0] = 0; - for i in 1..256 { - cumulative_counts[i] = cumulative_counts[i - 1] + total_counts[i - 1]; - } - - let wavelet_tree = construct_wavelet_tree(&bwt); - - let mut file = File::create(output_file_name)?; - - let (offsets, level_offsets) = write_wavelet_tree_to_disk(&wavelet_tree, &mut file).unwrap(); - - // print out total file size so far - println!("total file size: {}", file.seek(SeekFrom::Current(0))?); - - let mut posting_list_offsets: Vec = vec![file.seek(SeekFrom::Current(0))? 
as usize]; - - for i in (0..idx.len()).step_by(FM_CHUNK_TOKS) { - let slice = &idx[i..std::cmp::min(idx.len(), i + FM_CHUNK_TOKS)]; - let serialized_slice = bincode::serialize(slice)?; - let compressed_slice = encode_all(&serialized_slice[..], 0).expect("Compression failed"); - file.write_all(&compressed_slice)?; - posting_list_offsets.push(file.seek(SeekFrom::Current(0))? as usize); - } - - let metadata: (Vec, Vec, Vec, Vec, usize) = - (offsets, level_offsets, posting_list_offsets, cumulative_counts, bwt.len()); - - let cache_start = file.seek(SeekFrom::Current(0))? as usize; - - let serialized_metadata = bincode::serialize(&metadata)?; - let compressed_metadata = encode_all(&serialized_metadata[..], 0).expect("Compression failed"); - file.write_all(&compressed_metadata)?; - file.write_all(&cache_start.to_le_bytes())?; - - let cache_end = file.seek(SeekFrom::Current(0))? as usize; - - Ok(vec![(cache_start, cache_end)]) -} - -pub async fn _build_lava_substring_char( - output_file_name: String, - texts: Vec<(u64, String)>, - char_skip_factor: u32, -) -> Result, LavaError> { - let named_encodings = texts - .into_iter() - .map(|(uid, text)| { - let lower: String = text.chars().flat_map(|c| c.to_lowercase()).collect(); - let result: Vec = if char_skip_factor == 1 { - lower.chars().filter(|id| !SKIP.chars().contains(id)).map(|c| c as u8).collect() - } else { - lower - .chars() - .filter(|id| !SKIP.chars().contains(id)) - .enumerate() - .filter(|&(index, _)| index % char_skip_factor as usize == 1) - .map(|(_, c)| c as u8) - .collect() - }; - (vec![uid; result.len()], result) - }) - .collect::, Vec)>>(); - - let uids: Vec = named_encodings.iter().map(|(uid, _)| uid).flatten().cloned().collect::>(); - let encodings: Vec = named_encodings.into_iter().map(|(_, text)| text).flatten().collect::>(); - - let mut sa: Vec = (0..encodings.len() as i32).collect(); - - sort_in_place(&encodings, &mut sa); - - let mut idx: Vec = Vec::with_capacity(encodings.len()); - let mut bwt: Vec = Vec::with_capacity(encodings.len()); - for i in 0..sa.len() { - if sa[i] == 0 { - bwt.push(encodings[encodings.len() - 1]); - idx.push(uids[uids.len() - 1]); - } else { - bwt.push(encodings[(sa[i] - 1) as usize]); - idx.push(uids[(sa[i] - 1) as usize]); - } - } - - let mut file = File::create(output_file_name)?; - - let mut fm_chunk_offsets: Vec = vec![file.seek(SeekFrom::Current(0))? as usize]; - - let mut current_chunk: Vec = vec![]; - let mut current_chunk_counts: HashMap = HashMap::new(); - let mut next_chunk_counts: HashMap = HashMap::new(); - - for i in 0..bwt.len() { - let current_tok = bwt[i]; - next_chunk_counts.entry(current_tok).and_modify(|count| *count += 1).or_insert(1); - current_chunk.push(current_tok); - - if ((i + 1) % FM_CHUNK_TOKS == 0) || i == bwt.len() - 1 { - let serialized_counts = bincode::serialize(¤t_chunk_counts)?; - let compressed_counts = encode_all(&serialized_counts[..], 10).expect("Compression failed"); - println!("chunk size: {}", compressed_counts.len()); - file.write_all(&(compressed_counts.len() as u64).to_le_bytes())?; - file.write_all(&compressed_counts)?; - let serialized_chunk = bincode::serialize(¤t_chunk)?; - let compressed_chunk = encode_all(&serialized_chunk[..], 10).expect("Compression failed"); - file.write_all(&compressed_chunk)?; - fm_chunk_offsets.push(file.seek(SeekFrom::Current(0))? 
as usize); - current_chunk_counts = next_chunk_counts.clone(); - current_chunk = vec![]; - } - } - // print out total file size so far - println!("total file size: {}", file.seek(SeekFrom::Current(0))?); - - let mut cumulative_counts: Vec = vec![0]; - for i in 0..256 { - cumulative_counts.push(cumulative_counts[i] + *current_chunk_counts.get(&(i as u8)).unwrap_or(&0)); - } - - let mut posting_list_offsets: Vec = vec![file.seek(SeekFrom::Current(0))? as usize]; - - for i in (0..idx.len()).step_by(FM_CHUNK_TOKS) { - let slice = &idx[i..std::cmp::min(idx.len(), i + FM_CHUNK_TOKS)]; - let serialized_slice = bincode::serialize(slice)?; - let compressed_slice = encode_all(&serialized_slice[..], 0).expect("Compression failed"); - file.write_all(&compressed_slice)?; - posting_list_offsets.push(file.seek(SeekFrom::Current(0))? as usize); - } - - let cache_start = file.seek(SeekFrom::Current(0))? as usize; - - let fm_chunk_offsets_offset = file.seek(SeekFrom::Current(0))? as usize; - let serialized_fm_chunk_offsets = bincode::serialize(&fm_chunk_offsets)?; - let compressed_fm_chunk_offsets = encode_all(&serialized_fm_chunk_offsets[..], 0).expect("Compression failed"); - file.write_all(&compressed_fm_chunk_offsets)?; - - let posting_list_offsets_offset = file.seek(SeekFrom::Current(0))? as usize; - let serialized_posting_list_offsets = bincode::serialize(&posting_list_offsets)?; - let compressed_posting_list_offsets = - encode_all(&serialized_posting_list_offsets[..], 0).expect("Compression failed"); - file.write_all(&compressed_posting_list_offsets)?; - - let total_counts_offset = file.seek(SeekFrom::Current(0))? as usize; - let serialized_total_counts = bincode::serialize(&cumulative_counts)?; - let compressed_total_counts: Vec = encode_all(&serialized_total_counts[..], 0).expect("Compression failed"); - file.write_all(&compressed_total_counts)?; - - file.write_all(&(fm_chunk_offsets_offset as u64).to_le_bytes())?; - file.write_all(&(posting_list_offsets_offset as u64).to_le_bytes())?; - file.write_all(&(total_counts_offset as u64).to_le_bytes())?; - file.write_all(&(bwt.len() as u64).to_le_bytes())?; - - let cache_end = file.seek(SeekFrom::Current(0))? 
as usize; - - Ok(vec![(cache_start, cache_end)]) -} - -#[tokio::main] -pub async fn build_lava_substring_char( - output_file_name: String, - array: ArrayData, - uid: ArrayData, - char_skip_factor: Option, -) -> Result, LavaError> { - let array = make_array(array); - // let uid = make_array(ArrayData::from_pyarrow(uid)?); - let uid = make_array(uid); - - let char_skip_factor = char_skip_factor.unwrap_or(1); - - let array: &arrow_array::GenericByteArray> = array - .as_any() - .downcast_ref::() - .ok_or(LavaError::Parse("Expects string array as first argument".to_string()))?; - - let uid = uid - .as_any() - .downcast_ref::() - .ok_or(LavaError::Parse("Expects uint64 array as second argument".to_string()))?; - - if array.len() != uid.len() { - return Err(LavaError::Parse("The length of the array and the uid array must be the same".to_string())); - } - - let mut texts: Vec<(u64, String)> = Vec::with_capacity(array.len()); - for i in 0..array.len() { - let text = array.value(i); - texts.push((uid.value(i), text.to_string())); - } - - println!("made it to this point"); - // _build_lava_substring_char(output_file_name, texts, char_skip_factor).await - _build_lava_substring_char_wavelet(output_file_name, texts, char_skip_factor).await -} - -#[tokio::main] -pub async fn build_lava_substring( - output_file_name: String, - array: ArrayData, - uid: ArrayData, - tokenizer_file: Option, - token_skip_factor: Option, -) -> Result, LavaError> { - let array = make_array(array); - // let uid = make_array(ArrayData::from_pyarrow(uid)?); - let uid = make_array(uid); - - let token_skip_factor = token_skip_factor.unwrap_or(1); - - let tokenizer = if let Some(tokenizer_file) = tokenizer_file { - if !std::path::Path::new(&tokenizer_file).exists() { - return Err(LavaError::Parse("Tokenizer file does not exist".to_string())); - } - println!("Tokenizer file: {}", tokenizer_file); - Tokenizer::from_file(tokenizer_file).unwrap() - } else { - Tokenizer::from_pretrained("bert-base-uncased", None).unwrap() - }; - - let serialized_tokenizer = serde_json::to_string(&tokenizer).unwrap(); - let compressed_tokenizer = encode_all(serialized_tokenizer.as_bytes(), 0).expect("Compression failed"); - - let array: &arrow_array::GenericByteArray> = array - .as_any() - .downcast_ref::() - .ok_or(LavaError::Parse("Expects string array as first argument".to_string()))?; - - let uid = uid - .as_any() - .downcast_ref::() - .ok_or(LavaError::Parse("Expects uint64 array as second argument".to_string()))?; - - if array.len() != uid.len() { - return Err(LavaError::Parse("The length of the array and the uid array must be the same".to_string())); - } - - let mut texts: Vec<(u64, &str)> = Vec::with_capacity(array.len()); - for i in 0..array.len() { - let text = array.value(i); - texts.push((uid.value(i), text)); - } - - let mut skip_tokens: HashSet = HashSet::new(); - for char in SKIP.chars() { - let char_str = char.to_string(); - skip_tokens.extend(tokenizer.encode(char_str.clone(), false).unwrap().get_ids().to_vec()); - skip_tokens.extend(tokenizer.encode(format!(" {}", char_str), false).unwrap().get_ids().to_vec()); - skip_tokens.extend(tokenizer.encode(format!("{} ", char_str), false).unwrap().get_ids().to_vec()); - } - - let named_encodings = texts - .into_maybe_par_iter() - .map(|(uid, text)| { - // strip out things in skip in text - - let lower: String = text.chars().flat_map(|c| c.to_lowercase()).collect(); - let encoding = tokenizer.encode(lower, false).unwrap(); - let result: Vec = encoding.get_ids().iter().filter(|id| 
!skip_tokens.contains(id)).cloned().collect(); - (vec![uid; result.len()], result) - }) - .collect::, Vec)>>(); - - let uids: Vec = named_encodings.iter().map(|(uid, _)| uid).flatten().cloned().collect::>(); - let encodings: Vec = named_encodings.into_iter().map(|(_, text)| text).flatten().collect::>(); - - let mut suffices: Vec> = vec![]; - - let (encodings, uids) = if token_skip_factor > 1 { - let encodings: Vec = encodings - .into_iter() - .enumerate() // Enumerate to get the index and value - .filter(|&(index, _)| index % token_skip_factor as usize == 1) // Keep only elements with odd indices (every second element) - .map(|(_, value)| value) // Extract the value - .collect(); // Collect into a vector - - let uids: Vec = uids - .into_iter() - .enumerate() // Enumerate to get the index and value - .filter(|&(index, _)| index % token_skip_factor as usize == 1) // Keep only elements with odd indices (every second element) - .map(|(_, value)| value) // Extract the value - .collect(); - (encodings, uids) - } else { - (encodings, uids) - }; - - for i in 10..encodings.len() { - suffices.push(encodings[i - 10..i].to_vec()); - } - - for i in encodings.len()..encodings.len() + 10 { - let mut suffix = encodings[i - 10..encodings.len()].to_vec(); - suffix.append(&mut vec![0; i - encodings.len()]); - suffices.push(suffix); - } - - let mut sa: Vec = (0..suffices.len()).collect(); - - sa.par_sort_by(|&a, &b| suffices[a].cmp(&suffices[b])); - - let mut idx: Vec = Vec::with_capacity(encodings.len()); - let mut bwt: Vec = Vec::with_capacity(encodings.len()); - for i in 0..sa.len() { - if sa[i] == 0 { - bwt.push(encodings[encodings.len() - 1]); - idx.push(uids[uids.len() - 1]); - } else { - bwt.push(encodings[(sa[i] - 1) as usize]); - idx.push(uids[(sa[i] - 1) as usize]); - } - } - - let mut file = File::create(output_file_name)?; - file.write_all(&(compressed_tokenizer.len() as u64).to_le_bytes())?; - file.write_all(&compressed_tokenizer)?; - - let mut fm_chunk_offsets: Vec = vec![file.seek(SeekFrom::Current(0))? as usize]; - - let mut current_chunk: Vec = vec![]; - let mut current_chunk_counts: HashMap = HashMap::new(); - let mut next_chunk_counts: HashMap = HashMap::new(); - - for i in 0..bwt.len() { - let current_tok = bwt[i]; - next_chunk_counts.entry(current_tok).and_modify(|count| *count += 1).or_insert(1); - current_chunk.push(current_tok); - - if ((i + 1) % FM_CHUNK_TOKS == 0) || i == bwt.len() - 1 { - let serialized_counts = bincode::serialize(&current_chunk_counts)?; - let compressed_counts = encode_all(&serialized_counts[..], 10).expect("Compression failed"); - - file.write_all(&(compressed_counts.len() as u64).to_le_bytes())?; - file.write_all(&compressed_counts)?; - let serialized_chunk = bincode::serialize(&current_chunk)?; - let compressed_chunk = encode_all(&serialized_chunk[..], 10).expect("Compression failed"); - file.write_all(&compressed_chunk)?; - - fm_chunk_offsets.push(file.seek(SeekFrom::Current(0))? as usize); - current_chunk_counts = next_chunk_counts.clone(); - current_chunk = vec![]; - } - } - // print out total file size so far - println!("total file size: {}", file.seek(SeekFrom::Current(0))?); - - let mut cumulative_counts: Vec = vec![0]; - for i in 0..tokenizer.get_vocab_size(false) { - cumulative_counts.push(cumulative_counts[i] + *current_chunk_counts.get(&(i as u32)).unwrap_or(&0)); - } - - let mut posting_list_offsets: Vec = vec![file.seek(SeekFrom::Current(0))?
as usize]; - - for i in (0..idx.len()).step_by(FM_CHUNK_TOKS) { - let slice = &idx[i..std::cmp::min(idx.len(), i + FM_CHUNK_TOKS)]; - let serialized_slice = bincode::serialize(slice)?; - let compressed_slice = encode_all(&serialized_slice[..], 0).expect("Compression failed"); - file.write_all(&compressed_slice)?; - posting_list_offsets.push(file.seek(SeekFrom::Current(0))? as usize); - } - - let cache_start = file.seek(SeekFrom::Current(0))? as usize; - - let fm_chunk_offsets_offset = file.seek(SeekFrom::Current(0))? as usize; - let serialized_fm_chunk_offsets = bincode::serialize(&fm_chunk_offsets)?; - let compressed_fm_chunk_offsets = encode_all(&serialized_fm_chunk_offsets[..], 0).expect("Compression failed"); - file.write_all(&compressed_fm_chunk_offsets)?; - - let posting_list_offsets_offset = file.seek(SeekFrom::Current(0))? as usize; - let serialized_posting_list_offsets = bincode::serialize(&posting_list_offsets)?; - let compressed_posting_list_offsets = - encode_all(&serialized_posting_list_offsets[..], 0).expect("Compression failed"); - file.write_all(&compressed_posting_list_offsets)?; - - let total_counts_offset = file.seek(SeekFrom::Current(0))? as usize; - let serialized_total_counts = bincode::serialize(&cumulative_counts)?; - let compressed_total_counts: Vec = encode_all(&serialized_total_counts[..], 0).expect("Compression failed"); - file.write_all(&compressed_total_counts)?; - - file.write_all(&(fm_chunk_offsets_offset as u64).to_le_bytes())?; - file.write_all(&(posting_list_offsets_offset as u64).to_le_bytes())?; - file.write_all(&(total_counts_offset as u64).to_le_bytes())?; - file.write_all(&(bwt.len() as u64).to_le_bytes())?; - - let cache_end = file.seek(SeekFrom::Current(0))? as usize; - - Ok(vec![(cache_start, cache_end)]) -} diff --git a/libCompressor.a b/src/lava/logcloud/libCompressor.a similarity index 100% rename from libCompressor.a rename to src/lava/logcloud/libCompressor.a diff --git a/libTrainer.a b/src/lava/logcloud/libTrainer.a similarity index 100% rename from libTrainer.a rename to src/lava/logcloud/libTrainer.a diff --git a/src/lava/logcloud.rs b/src/lava/logcloud/logcloud.rs similarity index 74% rename from src/lava/logcloud.rs rename to src/lava/logcloud/logcloud.rs index de98b4c..51fc04c 100644 --- a/src/lava/logcloud.rs +++ b/src/lava/logcloud/logcloud.rs @@ -7,14 +7,14 @@ use tokio::{task::JoinSet, time::sleep}; use crate::{ formats::readers::{ - get_file_size_and_reader, get_file_sizes_and_readers, get_reader, AsyncReader, ClonableAsyncReader, ReaderType, - READ_RANGE_COUNTER, + get_file_size_and_reader, get_file_sizes_and_readers, get_reader, AsyncReader, + ClonableAsyncReader, ReaderType, READ_RANGE_COUNTER, }, lava::{ - build::{_build_lava_substring_char, _build_lava_substring_char_wavelet}, error::LavaError, - logcloud_common::{get_all_types, get_type, PListChunk, PlistSize}, + logcloud::logcloud_common::{get_all_types, get_type, PListChunk, PlistSize}, search::_search_lava_substring_char, + substring::{_build_lava_substring_char, _build_lava_substring_char_wavelet}, }, }; use serde::de::DeserializeOwned; @@ -32,12 +32,14 @@ use std::{ }; use zstd::stream::{encode_all, read::Decoder}; -use super::wavelet_tree; - const BRUTE_THRESHOLD: usize = 5; const USE_EXPERIMENTAL_NUMERICS: bool = false; -async fn read_and_decompress(reader: &mut AsyncReader, start: u64, size: u64) -> Result +async fn read_and_decompress( + reader: &mut AsyncReader, + start: u64, + size: u64, +) -> Result where T: DeserializeOwned, { @@ -84,12 +86,20 @@ fn merge_files( 
for (i, file) in input_files_linenumbers.iter_mut().enumerate() { let mut line = String::new(); file.read_line(&mut line)?; - current_linenumbers[i] = line.split_whitespace().filter_map(|n| n.parse::().ok()).collect(); + current_linenumbers[i] = line + .split_whitespace() + .filter_map(|n| n.parse::().ok()) + .collect(); } while current_lines.iter().any(|s| !s.is_empty()) { // Find the smallest string in `current_lines` without holding a reference to it - let it = current_lines.iter().filter(|s| !s.is_empty()).min().cloned().unwrap_or_else(|| String::new()); + let it = current_lines + .iter() + .filter(|s| !s.is_empty()) + .min() + .cloned() + .unwrap_or_else(|| String::new()); if it.is_empty() { // If `it` is empty, print a warning with the current lines @@ -129,8 +139,10 @@ fn merge_files( *line = next_line; let mut lineno_line = String::new(); input_files_linenumbers[i].read_line(&mut lineno_line)?; - current_linenumbers[i] = - lineno_line.split_whitespace().filter_map(|n| n.parse::().ok()).collect(); + current_linenumbers[i] = lineno_line + .split_whitespace() + .filter_map(|n| n.parse::().ok()) + .collect(); } } } @@ -152,7 +164,12 @@ fn compact(num_groups: usize) -> io::Result<()> { } if !input_filenames.is_empty() { - merge_files(&input_filenames, &input_filenames_linenumbers, "compressed/outlier", "compressed/outlier_lineno")?; + merge_files( + &input_filenames, + &input_filenames_linenumbers, + "compressed/outlier", + "compressed/outlier_lineno", + )?; } // Process types 1 to 63 @@ -167,7 +184,8 @@ fn compact(num_groups: usize) -> io::Result<()> { } input_filenames.push(format!("compressed/{}/compacted_type_{}", i, type_)); - input_filenames_linenumbers.push(format!("compressed/{}/compacted_type_{}_lineno", i, type_)); + input_filenames_linenumbers + .push(format!("compressed/{}/compacted_type_{}_lineno", i, type_)); } if input_filenames.is_empty() { @@ -177,7 +195,12 @@ fn compact(num_groups: usize) -> io::Result<()> { let output_filename = format!("compressed/compacted_type_{}", type_); let output_filename_linenumbers = format!("compressed/compacted_type_{}_lineno", type_); - merge_files(&input_filenames, &input_filenames_linenumbers, &output_filename, &output_filename_linenumbers)?; + merge_files( + &input_filenames, + &input_filenames_linenumbers, + &output_filename, + &output_filename_linenumbers, + )?; } println!("Files merged"); @@ -205,7 +228,10 @@ pub fn write_kauai(filename: &str, num_groups: usize) -> std::io::Result<()> { for group_number in 0..num_groups { let mut group_template_idx = HashMap::new(); - let template_file = BufReader::new(File::open(format!("compressed/{}_{}.templates", filename, group_number))?); + let template_file = BufReader::new(File::open(format!( + "compressed/{}_{}.templates", + filename, group_number + ))?); for line in template_file.lines().skip(1) { let line = line?; @@ -220,19 +246,33 @@ pub fn write_kauai(filename: &str, num_groups: usize) -> std::io::Result<()> { } let total_chunks = (0..) 
- .take_while(|&chunk| Path::new(&format!("compressed/{}/chunk{:04}.eid", group_number, chunk)).exists()) + .take_while(|&chunk| { + Path::new(&format!( + "compressed/{}/chunk{:04}.eid", + group_number, chunk + )) + .exists() + }) .count(); for chunk in 0..total_chunks { println!("Reading chunk {}", chunk); - let eid_file = File::open(format!("compressed/{}/chunk{:04}.eid", group_number, chunk))?; + let eid_file = + File::open(format!("compressed/{}/chunk{:04}.eid", group_number, chunk))?; // let mut outlier_file = File::open(format!("compressed/{}/chunk{:04}.outlier", group_number, chunk))?; - let mut outlier_file: Option> = - if File::open(format!("compressed/{}/chunk{:04}.outlier", group_number, chunk)).is_ok() { - Some(BufReader::new(File::open(format!("compressed/{}/chunk{:04}.outlier", group_number, chunk))?)) - } else { - None - }; + let mut outlier_file: Option> = if File::open(format!( + "compressed/{}/chunk{:04}.outlier", + group_number, chunk + )) + .is_ok() + { + Some(BufReader::new(File::open(format!( + "compressed/{}/chunk{:04}.outlier", + group_number, chunk + ))?)) + } else { + None + }; for line in BufReader::new(eid_file).lines() { let eid: i64 = line?.parse().unwrap(); @@ -249,7 +289,10 @@ pub fn write_kauai(filename: &str, num_groups: usize) -> std::io::Result<()> { let idx = if group_template_idx.contains_key(&(eid as usize)) { group_template_idx[&(eid as usize)] } else { - panic!("Template not found for eid: {}, {:?}", eid, group_template_idx); + panic!( + "Template not found for eid: {}, {:?}", + eid, group_template_idx + ); }; // if template_posting_lists[idx].is_empty() // || template_posting_lists[idx].last() != Some(&(lineno as u32)) @@ -297,22 +340,38 @@ pub fn write_kauai(filename: &str, num_groups: usize) -> std::io::Result<()> { let mut outlier_type_linenos = Vec::new(); let outlier_file = File::open("compressed/outlier")?; let outlier_lineno_file = File::open("compressed/outlier_lineno")?; - for (line, outlier_type_line) in - BufReader::new(outlier_lineno_file).lines().zip(BufReader::new(outlier_file).lines()) + for (line, outlier_type_line) in BufReader::new(outlier_lineno_file) + .lines() + .zip(BufReader::new(outlier_file).lines()) { let line = line?; let outlier_type_line = outlier_type_line?; outlier_type_str.push_str(&outlier_type_line); outlier_type_str.push('\n'); - let numbers: Vec = - line.split_whitespace().filter_map(|s| s.parse().ok()).collect::>().into_iter().collect(); + let numbers: Vec = line + .split_whitespace() + .filter_map(|s| s.parse().ok()) + .collect::>() + .into_iter() + .collect(); outlier_type_linenos.push(numbers); } // print out sum of each outlier length - println!("outlier length {:?}", outliers.iter().map(|s| s.len()).sum::()); + println!( + "outlier length {:?}", + outliers.iter().map(|s| s.len()).sum::() + ); let outlier_str = outliers.join("") + "\n"; - let kauai_metadata: (String, String, Vec>, String, Vec>, String, Vec>) = ( + let kauai_metadata: ( + String, + String, + Vec>, + String, + Vec>, + String, + Vec>, + ) = ( dictionary_str, template_str, template_posting_lists, @@ -322,7 +381,8 @@ pub fn write_kauai(filename: &str, num_groups: usize) -> std::io::Result<()> { outlier_type_linenos, ); - let compressed_metadata_page = encode_all(&bincode::serialize(&kauai_metadata).unwrap()[..], 10).unwrap(); + let compressed_metadata_page = + encode_all(&bincode::serialize(&kauai_metadata).unwrap()[..], 10).unwrap(); fp.write_all(&compressed_metadata_page)?; fp.write_all(&compressed_metadata_page.len().to_le_bytes())?; @@ 
-343,24 +403,47 @@ async fn search_kauai( let metadata_page_length = reader.read_usize_from_end(1).await?[0]; // Read the metadata page let start_time = std::time::Instant::now(); - let metadata_page: (String, String, Vec>, String, Vec>, String, Vec>) = - read_and_decompress( - &mut reader, - file_size as u64 - metadata_page_length as u64 - 8, - metadata_page_length as u64, - ) - .await?; - let (dictionary, template, template_plist, outlier, outlier_plist, outlier_type, outlier_type_plist) = - metadata_page; + let metadata_page: ( + String, + String, + Vec>, + String, + Vec>, + String, + Vec>, + ) = read_and_decompress( + &mut reader, + file_size as u64 - metadata_page_length as u64 - 8, + metadata_page_length as u64, + ) + .await?; + let ( + dictionary, + template, + template_plist, + outlier, + outlier_plist, + outlier_type, + outlier_type_plist, + ) = metadata_page; // print out all the sizes println!("dictionary length: {}", dictionary.len()); println!("template length: {}", template.len()); - println!("template_pl total size : {}", template_plist.iter().map(|x| x.len()).sum::()); + println!( + "template_pl total size : {}", + template_plist.iter().map(|x| x.len()).sum::() + ); println!("outlier length: {}", outlier.len()); - println!("outlier_pl total size : {}", outlier_plist.iter().map(|x| x.len()).sum::()); + println!( + "outlier_pl total size : {}", + outlier_plist.iter().map(|x| x.len()).sum::() + ); println!("outlier_type length: {}", outlier_type.len()); - println!("outlier_type_pl total size : {}", outlier_type_plist.iter().map(|x| x.len()).sum::()); + println!( + "outlier_type_pl total size : {}", + outlier_type_plist.iter().map(|x| x.len()).sum::() + ); let end_time = std::time::Instant::now(); @@ -410,37 +493,61 @@ async fn search_kauai( search_text(&query, &outlier, &outlier_plist, &mut match_uids, false); if match_uids.len() >= limit.try_into().unwrap() { - println!("inexact query for top K satisfied by template and outlier {}", query); + println!( + "inexact query for top K satisfied by template and outlier {}", + query + ); return Ok((1, match_uids)); } // Search in outlier types - search_text(&query, &outlier_type, &outlier_type_plist, &mut match_uids, false); + search_text( + &query, + &outlier_type, + &outlier_type_plist, + &mut match_uids, + false, + ); return Ok((1, match_uids)); } -fn write_1_block(fp: &mut File, numbers: Vec, lineno_buffer: &[Vec], byte_offsets: &mut Vec) { - let compressed_buffer = zstd::encode_all(&bincode::serialize(&numbers).unwrap()[..], 10).unwrap(); +fn write_1_block( + fp: &mut File, + numbers: Vec, + lineno_buffer: &[Vec], + byte_offsets: &mut Vec, +) { + let compressed_buffer = + zstd::encode_all(&bincode::serialize(&numbers).unwrap()[..], 10).unwrap(); let plist = PListChunk::new(lineno_buffer.to_vec()); let serialized = plist.serialize().unwrap(); - fp.write_all(&(compressed_buffer.len() as u64).to_le_bytes()).unwrap(); + fp.write_all(&(compressed_buffer.len() as u64).to_le_bytes()) + .unwrap(); fp.write_all(&compressed_buffer).unwrap(); fp.write_all(&serialized).unwrap(); - byte_offsets.push(byte_offsets.last().unwrap() + compressed_buffer.len() + serialized.len() + 8); + byte_offsets + .push(byte_offsets.last().unwrap() + compressed_buffer.len() + serialized.len() + 8); } -fn write_block(fp: &mut File, buffer: &str, lineno_buffer: &[Vec], byte_offsets: &mut Vec) { +fn write_block( + fp: &mut File, + buffer: &str, + lineno_buffer: &[Vec], + byte_offsets: &mut Vec, +) { let compressed_buffer = zstd::encode_all(buffer.as_bytes(), 
0).unwrap(); let plist = PListChunk::new(lineno_buffer.to_vec()); let serialized = plist.serialize().unwrap(); - fp.write_all(&(compressed_buffer.len() as u64).to_le_bytes()).unwrap(); + fp.write_all(&(compressed_buffer.len() as u64).to_le_bytes()) + .unwrap(); fp.write_all(&compressed_buffer).unwrap(); fp.write_all(&serialized).unwrap(); - byte_offsets.push(byte_offsets.last().unwrap() + compressed_buffer.len() + serialized.len() + 8); + byte_offsets + .push(byte_offsets.last().unwrap() + compressed_buffer.len() + serialized.len() + 8); } const BLOCK_BYTE_LIMIT: usize = 1000000; @@ -453,7 +560,12 @@ pub fn write_oahu(output_name: &str) -> Vec<(u64, String)> { let path_str = path.to_str().unwrap(); if path_str.contains("compacted_type") && !path_str.contains("lineno") { println!("Processing file: {}", path_str); - let type_num = path_str.split("compacted_type_").nth(1).unwrap().parse::().unwrap(); + let type_num = path_str + .split("compacted_type_") + .nth(1) + .unwrap() + .parse::() + .unwrap(); if type_num != 0 { types.push(type_num); } @@ -486,18 +598,26 @@ pub fn write_oahu(output_name: &str) -> Vec<(u64, String)> { let mut lines_in_buffer = 0; if USE_EXPERIMENTAL_NUMERICS && (type_num == 1) { let mut all_numbers: Vec = vec![]; - for (str_line, lineno_line) in BufReader::new(string_file).lines().zip(BufReader::new(lineno_file).lines()) + for (str_line, lineno_line) in BufReader::new(string_file) + .lines() + .zip(BufReader::new(lineno_file).lines()) { let str_line = str_line.unwrap(); let lineno_line = lineno_line.unwrap(); //cast the str_line to a usize let number: usize = str_line.parse().unwrap(); all_numbers.push(number); - let numbers: Vec = lineno_line.split_whitespace().map(|n| n.parse().unwrap()).collect(); + let numbers: Vec = lineno_line + .split_whitespace() + .map(|n| n.parse().unwrap()) + .collect(); lineno_buffer.push(numbers); } // sort numbers, lineno_buffer by numbers - let mut paired: Vec<_> = all_numbers.into_iter().zip(lineno_buffer.into_iter()).collect(); + let mut paired: Vec<_> = all_numbers + .into_iter() + .zip(lineno_buffer.into_iter()) + .collect(); paired.par_sort_unstable_by(|a, b| a.0.cmp(&b.0)); all_numbers = paired.iter().map(|a| a.0).collect(); lineno_buffer = paired.into_iter().map(|a| a.1).collect(); @@ -507,7 +627,9 @@ pub fn write_oahu(output_name: &str) -> Vec<(u64, String)> { let mut buffer = String::new(); let mut this_for_hawaii: Vec<(u64, String)> = vec![]; - for (str_line, lineno_line) in BufReader::new(string_file).lines().zip(BufReader::new(lineno_file).lines()) + for (str_line, lineno_line) in BufReader::new(string_file) + .lines() + .zip(BufReader::new(lineno_file).lines()) { let str_line = str_line.unwrap(); let lineno_line = lineno_line.unwrap(); @@ -517,16 +639,21 @@ pub fn write_oahu(output_name: &str) -> Vec<(u64, String)> { lines_in_buffer += 1; this_for_hawaii.push((byte_offsets.len() as u64 - 1, str_line)); - let numbers: Vec = lineno_line.split_whitespace().map(|n| n.parse().unwrap()).collect(); + let numbers: Vec = lineno_line + .split_whitespace() + .map(|n| n.parse().unwrap()) + .collect(); lineno_buffer.push(numbers); if uncompressed_lines_in_block == 0 && buffer.len() > BLOCK_BYTE_LIMIT / 2 { let compressed_buffer = zstd::encode_all(buffer.as_bytes(), 0).unwrap(); uncompressed_lines_in_block = - ((BLOCK_BYTE_LIMIT as f32 / compressed_buffer.len() as f32) * lines_in_buffer as f32) as usize; + ((BLOCK_BYTE_LIMIT as f32 / compressed_buffer.len() as f32) + * lines_in_buffer as f32) as usize; } - if uncompressed_lines_in_block > 
0 && lines_in_buffer == uncompressed_lines_in_block { + if uncompressed_lines_in_block > 0 && lines_in_buffer == uncompressed_lines_in_block + { write_block(&mut fp, &buffer, &lineno_buffer, &mut byte_offsets); buffer.clear(); lines_in_buffer = 0; @@ -554,12 +681,18 @@ pub fn write_oahu(output_name: &str) -> Vec<(u64, String)> { } println!("type_chunks: {:?}", type_chunks); - println!("type_uncompressed_lines_in_block: {:?}", type_uncompressed_lines_in_block); + println!( + "type_uncompressed_lines_in_block: {:?}", + type_uncompressed_lines_in_block + ); - let metadata_page: (Vec, Vec, Vec, Vec) = (types, type_offsets, byte_offsets, hawaii_types); - let compressed_metadata = encode_all(&bincode::serialize(&metadata_page).unwrap()[..], 10).unwrap(); + let metadata_page: (Vec, Vec, Vec, Vec) = + (types, type_offsets, byte_offsets, hawaii_types); + let compressed_metadata = + encode_all(&bincode::serialize(&metadata_page).unwrap()[..], 10).unwrap(); fp.write_all(&compressed_metadata).unwrap(); - fp.write_all(&(compressed_metadata.len() as u64).to_le_bytes()).unwrap(); + fp.write_all(&(compressed_metadata.len() as u64).to_le_bytes()) + .unwrap(); println!("{:?}", for_hawaii.len()); @@ -578,7 +711,11 @@ pub async fn search_hawaii_oahu( ) -> Result, LavaError> { info!("query: {}", query); - let types_to_search = if exact { vec![get_type(&query)] } else { get_all_types(get_type(&query)) }; + let types_to_search = if exact { + vec![get_type(&query)] + } else { + get_all_types(get_type(&query)) + }; for &type_to_search in &types_to_search { info!("type to search: {}", type_to_search); } @@ -597,11 +734,17 @@ pub async fn search_hawaii_oahu( // see if anything in hawaii_types intersects with type_to_search - let type_intersection = - hawaii_types.iter().filter(|&&type_| types_to_search.contains(&type_)).copied().collect::>(); + let type_intersection = hawaii_types + .iter() + .filter(|&&type_| types_to_search.contains(&type_)) + .copied() + .collect::>(); - let remainder_types = - types_to_search.iter().filter(|&&type_| !type_intersection.contains(&type_)).copied().collect::>(); + let remainder_types = types_to_search + .iter() + .filter(|&&type_| !type_intersection.contains(&type_)) + .copied() + .collect::>(); let mut chunks: Vec = if type_intersection.is_empty() { vec![] @@ -639,7 +782,8 @@ pub async fn search_hawaii_oahu( ) .await .unwrap(); - let compressed_nums_length = u64::from_le_bytes(block[0..8].try_into().unwrap()) as usize; + let compressed_nums_length = + u64::from_le_bytes(block[0..8].try_into().unwrap()) as usize; println!("compressed_nums_length {:?}", compressed_nums_length); println!("total bytes {:?}", block.len()); let mut decompressor = Decoder::new(&block[8..8 + compressed_nums_length]).unwrap(); @@ -652,7 +796,13 @@ pub async fn search_hawaii_oahu( for (line_number, this_number) in all_numbers.iter().enumerate() { if this_number.to_string().contains(&query) { - all_uids.extend(plist.lookup(line_number).unwrap().iter().map(|x| (file_id, *x))); + all_uids.extend( + plist + .lookup(line_number) + .unwrap() + .iter() + .map(|x| (file_id, *x)), + ); } } } else { @@ -662,18 +812,23 @@ pub async fn search_hawaii_oahu( println!("chunks {:?}", chunks); let search_chunk = |block: bytes::Bytes, query_clone: String| { - let compressed_strings_length = u64::from_le_bytes(block[0..8].try_into().unwrap()) as usize; + let compressed_strings_length = + u64::from_le_bytes(block[0..8].try_into().unwrap()) as usize; let compressed_strings = &block[8..8 + compressed_strings_length]; let mut 
decompressor = Decoder::new(compressed_strings).unwrap(); - let mut decompressed_strings: Vec = Vec::with_capacity(compressed_strings.len() as usize); + let mut decompressed_strings: Vec = + Vec::with_capacity(compressed_strings.len() as usize); decompressor.read_to_end(&mut decompressed_strings).unwrap(); let compressed_plist = &block[8 + compressed_strings_length..]; let plist = PListChunk::from_compressed(compressed_plist).unwrap(); let mut uids = Vec::new(); - for (line_number, line) in String::from_utf8_lossy(&decompressed_strings).lines().enumerate() { + for (line_number, line) in String::from_utf8_lossy(&decompressed_strings) + .lines() + .enumerate() + { if format!("\n{}\n", line).contains(&query_clone) { uids.extend(plist.lookup(line_number).unwrap()); } @@ -691,7 +846,10 @@ pub async fn search_hawaii_oahu( set.spawn(async move { let block: bytes::Bytes = reader_clone - .read_range(byte_offsets_clone[chunk as usize] as u64, byte_offsets_clone[chunk as usize + 1] as u64) + .read_range( + byte_offsets_clone[chunk as usize] as u64, + byte_offsets_clone[chunk as usize + 1] as u64, + ) .await .unwrap(); search_chunk(block, query_clone) @@ -716,9 +874,13 @@ pub async fn index_logcloud(index_name: &str, num_groups: usize, use_wavelet: Op let _ = write_kauai(index_name, num_groups).unwrap(); let texts: Vec<(u64, String)> = write_oahu(index_name); if use_wavelet { - let _ = _build_lava_substring_char_wavelet(format!("{}.hawaii", index_name), texts, 1).await.unwrap(); + let _ = _build_lava_substring_char_wavelet(format!("{}.hawaii", index_name), texts, 1) + .await + .unwrap(); } else { - let _ = _build_lava_substring_char(format!("{}.hawaii", index_name), texts, 1).await.unwrap(); + let _ = _build_lava_substring_char(format!("{}.hawaii", index_name), texts, 1) + .await + .unwrap(); } } @@ -735,9 +897,13 @@ pub async fn index_analysis(split_index_prefixes: Vec, reader_type: Read .collect::>(); let (oahu_sizes, mut reader_oahus) = - get_file_sizes_and_readers(&oahu_filenames, reader_type.clone()).await.unwrap(); + get_file_sizes_and_readers(&oahu_filenames, reader_type.clone()) + .await + .unwrap(); let (hawaii_sizes, mut reader_hawaiis) = - get_file_sizes_and_readers(&hawaii_filenames, reader_type.clone()).await.unwrap(); + get_file_sizes_and_readers(&hawaii_filenames, reader_type.clone()) + .await + .unwrap(); let mut total_fm_index_size = 0; let mut total_suffix_array_size = 0; @@ -746,16 +912,20 @@ pub async fn index_analysis(split_index_prefixes: Vec, reader_type: Read let mut total_csr_length = 0; let mut total_roaring_length = 0; - for (hawaii_size, mut reader_hawaii) in hawaii_sizes.into_iter().zip(reader_hawaiis.into_iter()) { + for (hawaii_size, mut reader_hawaii) in hawaii_sizes.into_iter().zip(reader_hawaiis.into_iter()) + { let results = reader_hawaii.read_usize_from_end(4).await.unwrap(); let posting_list_offsets_offset = results[1]; let total_counts_offset = results[2]; - let posting_list_offsets: Vec = - reader_hawaii.read_range_and_decompress(posting_list_offsets_offset, total_counts_offset).await.unwrap(); + let posting_list_offsets: Vec = reader_hawaii + .read_range_and_decompress(posting_list_offsets_offset, total_counts_offset) + .await + .unwrap(); total_fm_index_size += posting_list_offsets[0]; - total_suffix_array_size += posting_list_offsets[posting_list_offsets.len() - 1] - posting_list_offsets[0]; + total_suffix_array_size += + posting_list_offsets[posting_list_offsets.len() - 1] - posting_list_offsets[0]; } for (oahu_size, mut reader_oahu) in 
oahu_sizes.into_iter().zip(reader_oahus.into_iter()) { @@ -771,12 +941,17 @@ pub async fn index_analysis(split_index_prefixes: Vec, reader_type: Read let (types, type_offsets, byte_offsets, hawaii_types) = metadata_page; for i in 0..byte_offsets.len() - 1 { - let block = reader_oahu.read_range(byte_offsets[i] as u64, byte_offsets[i + 1] as u64).await.unwrap(); - let compressed_strings_length = u64::from_le_bytes(block[0..8].try_into().unwrap()) as usize; + let block = reader_oahu + .read_range(byte_offsets[i] as u64, byte_offsets[i + 1] as u64) + .await + .unwrap(); + let compressed_strings_length = + u64::from_le_bytes(block[0..8].try_into().unwrap()) as usize; let compressed_strings = &block[8..8 + compressed_strings_length]; let mut decompressor = Decoder::new(compressed_strings).unwrap(); - let mut decompressed_strings: Vec = Vec::with_capacity(compressed_strings.len() as usize); + let mut decompressed_strings: Vec = + Vec::with_capacity(compressed_strings.len() as usize); decompressor.read_to_end(&mut decompressed_strings).unwrap(); let compressed_plist = &block[8 + compressed_strings_length..]; @@ -809,12 +984,14 @@ pub async fn index_analysis(split_index_prefixes: Vec, reader_type: Read let compressed_roaring_offsets = encode_all(&bincode::serialize(&roaring_offsets).unwrap()[..], 10).unwrap(); - block_roaring_length = - encode_all(&total_serialized_string[..], 10).unwrap().len() + compressed_roaring_offsets.len(); + block_roaring_length = encode_all(&total_serialized_string[..], 10).unwrap().len() + + compressed_roaring_offsets.len(); // Compress CSR offsets and values - let compressed_csr_offsets = encode_all(&bincode::serialize(&csr_offsets).unwrap()[..], 10).unwrap(); - let compressed_values = encode_all(&bincode::serialize(&values).unwrap()[..], 10).unwrap(); + let compressed_csr_offsets = + encode_all(&bincode::serialize(&csr_offsets).unwrap()[..], 10).unwrap(); + let compressed_values = + encode_all(&bincode::serialize(&values).unwrap()[..], 10).unwrap(); // println!("Block {} compressed csr offsets length: {}", i, compressed_csr_offsets.len()); // println!("Block {} compressed values length: {}", i, compressed_values.len()); @@ -831,8 +1008,14 @@ pub async fn index_analysis(split_index_prefixes: Vec, reader_type: Read println!("Total fm index size: {}", total_fm_index_size); println!("Total suffix array size: {}", total_suffix_array_size); - println!("Total compressed strings length: {}", total_compressed_strings_length); - println!("Total compressed plist length: {}", total_compressed_plist_length); + println!( + "Total compressed strings length: {}", + total_compressed_strings_length + ); + println!( + "Total compressed plist length: {}", + total_compressed_plist_length + ); println!("Total compressed roaring length: {}", total_roaring_length); println!("Total compressed csr length: {}", total_csr_length); } @@ -855,13 +1038,26 @@ pub async fn search_logcloud( .map(|split_index_prefix| format!("{}.kauai", split_index_prefix)) .collect::>(); - let (kauai_sizes, reader_kauais) = get_file_sizes_and_readers(&kauai_filenames, reader_type.clone()).await?; + let (kauai_sizes, reader_kauais) = + get_file_sizes_and_readers(&kauai_filenames, reader_type.clone()).await?; let mut set = JoinSet::new(); - for (file_id, (kauai_size, reader_kauai)) in kauai_sizes.into_iter().zip(reader_kauais.into_iter()).enumerate() { + for (file_id, (kauai_size, reader_kauai)) in kauai_sizes + .into_iter() + .zip(reader_kauais.into_iter()) + .enumerate() + { let query_clone = query.clone(); 
set.spawn(async move { - search_kauai(file_id, reader_kauai, kauai_size, query_clone, limit.try_into().unwrap()).await.unwrap() + search_kauai( + file_id, + reader_kauai, + kauai_size, + query_clone, + limit.try_into().unwrap(), + ) + .await + .unwrap() }); } @@ -880,7 +1076,11 @@ pub async fn search_logcloud( return Ok((1, all_uids)); } } - _ => return Err(LavaError::Parse("Unexpected result from search_kauai".to_string())), + _ => { + return Err(LavaError::Parse( + "Unexpected result from search_kauai".to_string(), + )) + } } } @@ -901,12 +1101,17 @@ pub async fn search_logcloud( .map(|split_index_prefix| format!("{}.hawaii", split_index_prefix)) .collect::>(); - let (oahu_sizes, mut reader_oahus) = get_file_sizes_and_readers(&oahu_filenames, reader_type.clone()).await?; + let (oahu_sizes, mut reader_oahus) = + get_file_sizes_and_readers(&oahu_filenames, reader_type.clone()).await?; let mut set = JoinSet::new(); let new_limit = limit - all_uids.len(); - for (file_id, (oahu_size, reader_oahu)) in oahu_sizes.into_iter().zip(reader_oahus.into_iter()).enumerate() { + for (file_id, (oahu_size, reader_oahu)) in oahu_sizes + .into_iter() + .zip(reader_oahus.into_iter()) + .enumerate() + { let hawaii_filename = hawaii_filenames.remove(0); let query_clone = query.clone(); set.spawn(async move { diff --git a/src/lava/logcloud_common.rs b/src/lava/logcloud/logcloud_common.rs similarity index 100% rename from src/lava/logcloud_common.rs rename to src/lava/logcloud/logcloud_common.rs diff --git a/src/lava/logcloud_rex.rs b/src/lava/logcloud/logcloud_rex.rs similarity index 71% rename from src/lava/logcloud_rex.rs rename to src/lava/logcloud/logcloud_rex.rs index bd3e2aa..aaa9bf5 100644 --- a/src/lava/logcloud_rex.rs +++ b/src/lava/logcloud/logcloud_rex.rs @@ -11,8 +11,8 @@ use std::io::{BufRead, BufReader, Write}; use std::path::{Path, PathBuf}; use std::{fs, panic}; +use super::logcloud_common::{get_all_types, get_type}; use crate::lava::error::LavaError; -use crate::lava::logcloud_common::{get_all_types, get_type}; const CHUNK_SIZE: usize = 67108864; // const CHUNK_SIZE: usize = 268435456; @@ -32,16 +32,20 @@ extern "C" { fn trainer_wrapper_rust(sample_str: &str, output_path: &str) -> PyResult<()> { let sample_str = remove_null_bytes(&sample_str); - let sample_str_c = - CString::new(sample_str).map_err(|e| PyErr::new::(e.to_string())).unwrap(); + let sample_str_c = CString::new(sample_str) + .map_err(|e| PyErr::new::(e.to_string())) + .unwrap(); // strip blackslashes from the sample_str - let output_path_c = - CString::new(output_path).map_err(|e| PyErr::new::(e.to_string())).unwrap(); + let output_path_c = CString::new(output_path) + .map_err(|e| PyErr::new::(e.to_string())) + .unwrap(); let result = panic::catch_unwind(|| unsafe { let result = trainer_wrapper(sample_str_c.as_ptr(), output_path_c.as_ptr()); if result != 0 { - return Err(PyErr::new::("trainer_wrapper_c failed")); + return Err(PyErr::new::( + "trainer_wrapper_c failed", + )); } Ok(()) }); @@ -61,21 +65,33 @@ fn remove_null_bytes(s: &str) -> String { } } -fn compressor_wrapper_rust(chunk: &str, output_path: &str, template_path: &str, prefix: i32) -> PyResult<()> { +fn compressor_wrapper_rust( + chunk: &str, + output_path: &str, + template_path: &str, + prefix: i32, +) -> PyResult<()> { let chunk_c = CString::new(remove_null_bytes(chunk)) .map_err(|e| PyErr::new::(e.to_string())) .unwrap(); - let output_path_c = - CString::new(output_path).map_err(|e| PyErr::new::(e.to_string())).unwrap(); + let output_path_c = 
CString::new(output_path) + .map_err(|e| PyErr::new::(e.to_string())) + .unwrap(); let template_path_c = CString::new(template_path) .map_err(|e| PyErr::new::(e.to_string())) .unwrap(); unsafe { - let result = - compressor_wrapper(chunk_c.as_ptr(), output_path_c.as_ptr(), template_path_c.as_ptr(), prefix as c_int); + let result = compressor_wrapper( + chunk_c.as_ptr(), + output_path_c.as_ptr(), + template_path_c.as_ptr(), + prefix as c_int, + ); if result != 0 { - return Err(PyErr::new::("compressor_wrapper_c failed")); + return Err(PyErr::new::( + "compressor_wrapper_c failed", + )); } } println!("compressor_wrapper_rust done"); @@ -85,7 +101,10 @@ fn compressor_wrapper_rust(chunk: &str, output_path: &str, template_path: &str, fn get_variable_info( total_chunks: usize, group_number: usize, -) -> PyResult<(HashMap>, HashMap>)> { +) -> PyResult<( + HashMap>, + HashMap>, +)> { let mut variable_to_type = HashMap::new(); let mut chunk_variables: HashMap> = HashMap::new(); let mut eid_to_variables: HashMap> = HashMap::new(); @@ -98,9 +117,9 @@ fn get_variable_info( for line in reader.lines() { let line = line?; let mut parts = line.split_whitespace(); - let variable_str = parts - .next() - .ok_or_else(|| PyErr::new::("Invalid variable string"))?; + let variable_str = parts.next().ok_or_else(|| { + PyErr::new::("Invalid variable string") + })?; let tag = parts .next() .ok_or_else(|| PyErr::new::("Invalid tag"))? @@ -108,20 +127,32 @@ fn get_variable_info( let mut var_parts = variable_str.split('_'); - let a_part = var_parts - .next() - .ok_or_else(|| PyErr::new::("Invalid variable format"))?; - let a = - a_part.chars().skip_while(|c| !c.is_digit(10)).collect::().parse::().map_err(|_| { - PyErr::new::("Invalid integer in variable format") + let a_part = var_parts.next().ok_or_else(|| { + PyErr::new::("Invalid variable format") + })?; + let a = a_part + .chars() + .skip_while(|c| !c.is_digit(10)) + .collect::() + .parse::() + .map_err(|_| { + PyErr::new::( + "Invalid integer in variable format", + ) })?; - let b_part = var_parts - .next() - .ok_or_else(|| PyErr::new::("Invalid variable format"))?; - let b = - b_part.chars().skip_while(|c| !c.is_digit(10)).collect::().parse::().map_err(|_| { - PyErr::new::("Invalid integer in variable format") + let b_part = var_parts.next().ok_or_else(|| { + PyErr::new::("Invalid variable format") + })?; + let b = b_part + .chars() + .skip_while(|c| !c.is_digit(10)) + .collect::() + .parse::() + .map_err(|_| { + PyErr::new::( + "Invalid integer in variable format", + ) })?; let variable = (a, b); @@ -131,7 +162,10 @@ fn get_variable_info( } } - let eid_to_variables = eid_to_variables.into_iter().map(|(k, v)| (k, v.into_iter().collect())).collect(); + let eid_to_variables = eid_to_variables + .into_iter() + .map(|(k, v)| (k, v.into_iter().collect())) + .collect(); Ok((chunk_variables, eid_to_variables)) } @@ -161,19 +195,26 @@ fn compress_chunk( println!("compressing chunk"); let chunk_filename = format!("compressed/{}/chunk{:04}", group_number, chunk_file_counter); - compressor_wrapper_rust(current_chunk, &chunk_filename, template_name, chunk_file_counter as i32)?; + compressor_wrapper_rust( + current_chunk, + &chunk_filename, + template_name, + chunk_file_counter as i32, + )?; // Rename files let source_dir = dir_path; - let target_dir = - Path::new("compressed").join(group_number.to_string()).join(format!("variable_{}", chunk_file_counter)); + let target_dir = Path::new("compressed") + .join(group_number.to_string()) + .join(format!("variable_{}", 
chunk_file_counter)); println!("source_dir: {:?}", source_dir); println!("target_dir: {:?}", target_dir); std::fs::rename(&source_dir, &target_dir)?; let source_tag = tag_path; - let target_tag = - Path::new("compressed").join(group_number.to_string()).join(format!("variable_{}_tag.txt", chunk_file_counter)); + let target_tag = Path::new("compressed") + .join(group_number.to_string()) + .join(format!("variable_{}_tag.txt", chunk_file_counter)); if !source_tag.exists() { return Err(format!("Source tag file does not exist: {:?}", source_tag).into()); @@ -213,15 +254,21 @@ pub async fn compress_logs( let array: &arrow_array::GenericByteArray> = array .as_any() .downcast_ref::() - .ok_or(LavaError::Parse("Expects string array as first argument".to_string()))?; + .ok_or(LavaError::Parse( + "Expects string array as first argument".to_string(), + ))?; let uid: &arrow_array::PrimitiveArray = uid .as_any() .downcast_ref::() - .ok_or(LavaError::Parse("Expects uint64 array as second argument".to_string()))?; + .ok_or(LavaError::Parse( + "Expects uint64 array as second argument".to_string(), + ))?; if array.len() != uid.len() { - return Err(LavaError::Parse("The length of the array and the uid array must be the same".to_string())); + return Err(LavaError::Parse( + "The length of the array and the uid array must be the same".to_string(), + )); } let mut logs1 = Vec::with_capacity(array.len()); @@ -263,7 +310,10 @@ pub async fn compress_logs( // Attempt to parse the timestamp let mut epoch_ts = if line.len() >= timestamp_bytes { let extract_timestamp_from_this_line = &line[..timestamp_bytes]; - match NaiveDateTime::parse_from_str(extract_timestamp_from_this_line.trim(), &timestamp_format) { + match NaiveDateTime::parse_from_str( + extract_timestamp_from_this_line.trim(), + &timestamp_format, + ) { Ok(dt) => dt.timestamp() as u64, Err(_) => last_timestamp, } @@ -338,22 +388,34 @@ pub async fn compress_logs( */ let total_chunks = chunk_uids.len(); - let (chunk_variables, eid_to_variables) = get_variable_info(total_chunks, group_number).unwrap(); + let (chunk_variables, eid_to_variables) = + get_variable_info(total_chunks, group_number).unwrap(); let mut touched_types = std::collections::HashSet::new(); - let mut expanded_items: std::collections::HashMap> = std::collections::HashMap::new(); - let mut expanded_lineno: std::collections::HashMap> = std::collections::HashMap::new(); + let mut expanded_items: std::collections::HashMap> = + std::collections::HashMap::new(); + let mut expanded_lineno: std::collections::HashMap> = + std::collections::HashMap::new(); println!("total_chunks: {}", total_chunks); for chunk in 0..total_chunks { let mut variable_files = std::collections::HashMap::new(); let mut variable_idx = std::collections::HashMap::new(); - for &variable in chunk_variables.get(&chunk).unwrap_or(&std::collections::HashSet::new()) { - let file_path = format!("compressed/{}/variable_{}/E{}_V{}", group_number, chunk, variable.0, variable.1); + for &variable in chunk_variables + .get(&chunk) + .unwrap_or(&std::collections::HashSet::new()) + { + let file_path = format!( + "compressed/{}/variable_{}/E{}_V{}", + group_number, chunk, variable.0, variable.1 + ); let file_content = fs::read_to_string(file_path).unwrap(); - let lines = file_content.lines().map(String::from).collect::>(); + let lines = file_content + .lines() + .map(String::from) + .collect::>(); variable_files.insert(variable, lines); variable_idx.insert(variable, 0); } @@ -371,7 +433,8 @@ pub async fn compress_logs( let mut type_vars =
std::collections::HashMap::new(); for &variable in this_variables { - let item = variable_files.get_mut(&variable).unwrap()[variable_idx[&variable]].to_string(); + let item = + variable_files.get_mut(&variable).unwrap()[variable_idx[&variable]].to_string(); variable_idx.entry(variable).and_modify(|v| *v += 1); let t = get_type(&item); if t == 0 { @@ -387,7 +450,10 @@ pub async fn compress_logs( for (&t, items) in &type_vars { // println!("{} {} {}", chunk, t, items.len()); - expanded_items.entry(t).or_default().extend(items.iter().cloned()); + expanded_items + .entry(t) + .or_default() + .extend(items.iter().cloned()); expanded_lineno .entry(t) .or_default() @@ -399,17 +465,25 @@ pub async fn compress_logs( // Process and write compacted types and outliers let mut compacted_type_files = std::collections::HashMap::new(); let mut compacted_lineno_files = std::collections::HashMap::new(); - let mut outlier_file = std::fs::File::create(format!("compressed/{}/outlier", group_number)).unwrap(); - let mut outlier_lineno_file = std::fs::File::create(format!("compressed/{}/outlier_lineno", group_number)).unwrap(); + let mut outlier_file = + std::fs::File::create(format!("compressed/{}/outlier", group_number)).unwrap(); + let mut outlier_lineno_file = + std::fs::File::create(format!("compressed/{}/outlier_lineno", group_number)).unwrap(); let mut outlier_items = Vec::new(); let mut outlier_lineno = Vec::new(); for &t in &touched_types { if expanded_items[&t].is_empty() { - panic!("Error in variable extraction. No items detected for type {}", t); + panic!( + "Error in variable extraction. No items detected for type {}", + t + ); } - let mut paired: Vec<_> = expanded_items[&t].iter().zip(expanded_lineno[&t].iter()).collect(); + let mut paired: Vec<_> = expanded_items[&t] + .iter() + .zip(expanded_lineno[&t].iter()) + .collect(); paired.par_sort_unstable_by(|a, b| a.0.cmp(b.0).then_with(|| a.1.cmp(b.1))); let mut compacted_items = Vec::new(); @@ -428,16 +502,29 @@ pub async fn compress_logs( if compacted_items.len() > OUTLIER_THRESHOLD { let type_file = compacted_type_files.entry(t).or_insert_with(|| { - std::fs::File::create(format!("compressed/{}/compacted_type_{}", group_number, t)).unwrap() + std::fs::File::create(format!("compressed/{}/compacted_type_{}", group_number, t)) + .unwrap() }); let lineno_file = compacted_lineno_files.entry(t).or_insert_with(|| { - std::fs::File::create(format!("compressed/{}/compacted_type_{}_lineno", group_number, t)).unwrap() + std::fs::File::create(format!( + "compressed/{}/compacted_type_{}_lineno", + group_number, t + )) + .unwrap() }); for (item, linenos) in compacted_items.iter().zip(compacted_lineno.iter()) { writeln!(type_file, "{}", item).unwrap(); - writeln!(lineno_file, "{}", linenos.iter().map(|&n| n.to_string()).collect::>().join(" ")) - .unwrap(); + writeln!( + lineno_file, + "{}", + linenos + .iter() + .map(|&n| n.to_string()) + .collect::>() + .join(" ") + ) + .unwrap(); } } else { outlier_items.extend(compacted_items); @@ -446,12 +533,23 @@ pub async fn compress_logs( } // Sort and write outliers - let mut paired: Vec<_> = outlier_items.into_iter().zip(outlier_lineno.into_iter()).collect(); + let mut paired: Vec<_> = outlier_items + .into_iter() + .zip(outlier_lineno.into_iter()) + .collect(); paired.par_sort_unstable_by(|a, b| a.0.cmp(&b.0)); for (item, linenos) in paired { writeln!(outlier_file, "{}", item).unwrap(); - writeln!(outlier_lineno_file, "{}", linenos.iter().map(|&n| n.to_string()).collect::>().join(" ")) - .unwrap(); + writeln!( + 
outlier_lineno_file, + "{}", + linenos + .iter() + .map(|&n| n.to_string()) + .collect::>() + .join(" ") + ) + .unwrap(); } // flush the files diff --git a/src/lava/logcloud/mod.rs b/src/lava/logcloud/mod.rs new file mode 100644 index 0000000..acdd7a8 --- /dev/null +++ b/src/lava/logcloud/mod.rs @@ -0,0 +1,8 @@ +mod logcloud; +mod logcloud_common; +mod logcloud_rex; + +pub use logcloud::index_analysis; +pub use logcloud::index_logcloud; +pub use logcloud::search_logcloud; +pub use logcloud_rex::compress_logs; diff --git a/src/lava/merge.rs b/src/lava/merge.rs index 9aa891b..a4f75cd 100644 --- a/src/lava/merge.rs +++ b/src/lava/merge.rs @@ -1,639 +1,18 @@ -use arrow::datatypes::ToByteSlice; use async_recursion::async_recursion; -use bincode; -use bit_vec::BitVec; + use itertools::Itertools; -use ndarray::{concatenate, Array2, Axis}; use std::collections::BTreeSet; -use std::fs::File; -use std::io::{BufReader, Read, Seek, SeekFrom, Write}; use std::sync::{Arc, Mutex}; -use zstd::stream::encode_all; -use zstd::stream::read::Decoder; -use crate::formats::readers::{get_file_size_and_reader, get_file_sizes_and_readers, AsyncReader, ReaderType}; -use crate::lava::constants::*; -use crate::lava::error::LavaError; -use crate::lava::fm_chunk::FMChunk; -use crate::lava::plist::PListChunk; -use crate::lava::trie::FastTrie; -use std::collections::HashMap; +use crate::formats::readers::ReaderType; -use crate::vamana::{access::InMemoryAccessMethodF32, merge_indexes_par, EuclideanF32, IndexParams, VamanaIndex}; +use crate::lava::bm25::merge_lava_bm25; +use crate::lava::error::LavaError; +use crate::lava::substring::merge_lava_substring; +use crate::lava::uuid::merge_lava_uuid; // @Rain chore: we need to simplify all the iterator impls -struct PListIterator { - reader: AsyncReader, - plist_offsets: Vec, - current_chunk_offset: usize, - pub current_chunk: Vec, -} - -impl PListIterator { - // take ownership of the data structures - pub async fn new(mut reader: AsyncReader, plist_offsets: Vec) -> Result { - let plist_chunk = reader.read_range_and_decompress(plist_offsets[0], plist_offsets[1]).await?; - Ok(Self { reader: reader, plist_offsets: plist_offsets, current_chunk_offset: 0, current_chunk: plist_chunk }) - } - - pub async fn advance(&mut self) -> Result<(), LavaError> { - self.current_chunk_offset += 1; - if self.current_chunk_offset + 2 > self.plist_offsets.len() { - return Err(LavaError::Parse("out of chunks".to_string())); - } - self.current_chunk = self - .reader - .read_range_and_decompress( - self.plist_offsets[self.current_chunk_offset], - self.plist_offsets[self.current_chunk_offset + 1], - ) - .await?; - Ok(()) - } -} - -struct FMChunkIterator { - reader: AsyncReader, - fm_chunk_offsets: Vec, - current_chunk_offset: usize, - pub current_chunk: FMChunk, -} - -impl FMChunkIterator { - // take ownership of the data structures - pub async fn new(mut reader: AsyncReader, fm_chunk_offsets: Vec) -> Result { - let buffer3 = reader.read_range(fm_chunk_offsets[0], fm_chunk_offsets[1]).await?; - let current_chunk = FMChunk::new(buffer3)?; - - Ok(Self { - reader: reader, - fm_chunk_offsets: fm_chunk_offsets, - current_chunk_offset: 0, - current_chunk: current_chunk, - }) - } - - pub async fn advance(&mut self) -> Result<(), LavaError> { - self.current_chunk_offset += 1; - - if self.current_chunk_offset + 2 > self.fm_chunk_offsets.len() { - return Err(LavaError::Parse("out of chunks".to_string())); - } - let buffer3 = self - .reader - .read_range( - self.fm_chunk_offsets[self.current_chunk_offset], - 
self.fm_chunk_offsets[self.current_chunk_offset + 1], - ) - .await?; - self.current_chunk = FMChunk::new(buffer3)?; - - Ok(()) - } - - pub async fn reset(&mut self) -> Result<(), LavaError> { - self.current_chunk = - FMChunk::new(self.reader.read_range(self.fm_chunk_offsets[0], self.fm_chunk_offsets[1]).await?)?; - self.current_chunk_offset = 0; - - Ok(()) - } -} - -struct PListChunkIterator { - reader: AsyncReader, - current_offset_in_chunk: usize, - current_chunk_offset: usize, - current_chunk: Vec>, - plist_offsets: Vec, - plist_elems: Vec, -} - -impl PListChunkIterator { - // take ownership of the data structures - pub async fn new( - mut reader: AsyncReader, - plist_offsets: Vec, - plist_elems: Vec, - ) -> Result { - // read the first chunk - - let buffer3 = reader.read_range(plist_offsets[0], plist_offsets[1]).await?; - let result: Vec> = - PListChunk::search_compressed(buffer3.to_vec(), &(0..plist_elems[1]).collect()).unwrap(); - - Ok(Self { - reader: reader, - current_offset_in_chunk: 0, - current_chunk_offset: 0, - current_chunk: result, - plist_offsets: plist_offsets, - plist_elems: plist_elems, - }) - } - - pub fn get(&mut self) -> Vec { - self.current_chunk[self.current_offset_in_chunk as usize].clone() - } - - pub async fn advance(&mut self) -> Result<(), LavaError> { - self.current_offset_in_chunk += 1; - if self.current_offset_in_chunk == self.current_chunk.len() { - // read the next chunk - self.current_offset_in_chunk = 0; - self.current_chunk_offset += 1; - if self.current_chunk_offset + 2 > self.plist_offsets.len() { - return Err(LavaError::Parse("out of chunks".to_string())); - } - - let buffer3 = self - .reader - .read_range( - self.plist_offsets[self.current_chunk_offset], - self.plist_offsets[self.current_chunk_offset + 1], - ) - .await?; - - self.current_chunk = PListChunk::search_compressed( - buffer3.to_vec(), - &(0..(self.plist_elems[self.current_chunk_offset + 1] - self.plist_elems[self.current_chunk_offset])) - .collect(), - ) - .unwrap(); - } - - Ok(()) - } -} - -async fn merge_lava_uuid( - condensed_lava_file: &str, - lava_files: Vec, - uid_offsets: Vec, - reader_type: ReaderType, -) -> Result, LavaError> { - // currently only support merging two files, but can support more in the future. 
- assert_eq!(lava_files.len(), 2); - assert_eq!(uid_offsets.len(), 2); - - let (file_size1, mut reader1) = get_file_size_and_reader(lava_files[0].clone(), reader_type.clone()).await?; - let (file_size2, mut reader2) = get_file_size_and_reader(lava_files[1].clone(), reader_type.clone()).await?; - - // let buffer: bytes::Bytes = reader1.read_range(0, file_size1 as u64).await?; - // let mut fast_trie1 = FastTrie::deserialize(buffer.to_vec()); - // let buffer: bytes::Bytes = reader2.read_range(0, file_size2 as u64).await?; - // let mut fast_trie2 = FastTrie::deserialize(buffer.to_vec()); - - // fast_trie1.extend( - // &mut fast_trie2, - // uid_offsets[0] as usize, - // uid_offsets[1] as usize, - // ); - // let (serialized, (cache_start, cache_end)) = fast_trie1.serialize(); - // let mut output_file = File::create(condensed_lava_file)?; - // output_file.write(&serialized)?; - - let (cache_start, cache_end) = FastTrie::extend_with_readers_into_file( - file_size1, - &mut reader1, - file_size2, - &mut reader2, - condensed_lava_file, - uid_offsets[0] as usize, - uid_offsets[1] as usize, - ) - .await?; - - Ok(vec![(cache_start, cache_end)]) -} - -async fn merge_lava_bm25( - condensed_lava_file: &str, - lava_files: Vec, - uid_offsets: Vec, - reader_type: ReaderType, -) -> Result, LavaError> { - // let mut builder = Fs::default(); - // let current_path = env::current_dir()?; - // builder.root(current_path.to_str().expect("no path")); - // let operator = Operator::new(builder)?.finish(); - - let mut file_sizes: Vec = Vec::with_capacity(lava_files.len()); - let mut plist_chunk_iterators: Vec = Vec::with_capacity(lava_files.len()); - - let mut combined_token_counts: Vec = Vec::new(); - let mut total_num_documents: u64 = 0; - let mut compressed_tokenizer: Option> = None; - - for file in lava_files { - let reader_type = reader_type.clone(); - let (file_size, mut reader) = get_file_size_and_reader(file, reader_type).await?; - let file_size = file_size as u64; - - let results = reader.read_usize_from_end(3).await?; - let compressed_term_dict_offset = results[0]; - let compressed_plist_offsets_offset = results[1]; - let num_documents = results[2]; - total_num_documents += num_documents; - - let compressed_token_counts = - reader.read_range(compressed_term_dict_offset, compressed_plist_offsets_offset).await?; - - let mut decompressed_token_counts: Vec = Vec::new(); - let mut decompressor: Decoder<'_, BufReader<&[u8]>> = Decoder::new(&compressed_token_counts[..])?; - decompressor.read_to_end(&mut decompressed_token_counts)?; - let token_counts: Vec = bincode::deserialize(&decompressed_token_counts)?; - - if combined_token_counts.len() == 0 { - combined_token_counts = token_counts; - } else { - // add token_counts to combined_token_counts - for (i, count) in token_counts.iter().enumerate() { - combined_token_counts[i] += count; - } - } - - let buffer2 = reader.read_range(compressed_plist_offsets_offset, file_size - 24).await?; - - decompressor = Decoder::new(&buffer2[..])?; - let mut decompressed_serialized_plist_offsets: Vec = Vec::with_capacity(buffer2.len() as usize); - decompressor.read_to_end(&mut decompressed_serialized_plist_offsets)?; - let this_plist_offsets: Vec = bincode::deserialize(&decompressed_serialized_plist_offsets)?; - - if (this_plist_offsets.len() % 2) != 0 { - let err = LavaError::Parse("data corruption".to_string()); - return Err(err); - } - let num_elements = this_plist_offsets.len() / 2; - - let compressed_tokenizer_size = reader.read_usize_from_start(0, 1).await?[0]; - let 
this_compressed_tokenizer: bytes::Bytes = reader.read_range(8, 8 + compressed_tokenizer_size).await?; - - match &compressed_tokenizer { - Some(value) => assert!( - this_compressed_tokenizer == value, - "detected different tokenizers, cannot merge, something is very wrong." - ), - None => compressed_tokenizer = Some(this_compressed_tokenizer.to_vec()), - } - - file_sizes.push(file_size); - plist_chunk_iterators.push( - PListChunkIterator::new( - reader, - this_plist_offsets[..num_elements].to_vec(), - this_plist_offsets[num_elements..].to_vec(), - ) - .await?, - ); - } - - let mut output_file = File::create(condensed_lava_file)?; - - let compressed_tokenizer = compressed_tokenizer.unwrap(); - // let compressed_tokenizer_len = compressed_tokenizer.len(); - output_file.write_all(&(compressed_tokenizer.len() as u64).to_le_bytes())?; - output_file.write_all(&compressed_tokenizer)?; - - let mut new_plist_offsets: Vec = vec![output_file.seek(SeekFrom::Current(0))?]; - let mut new_plist_elems: Vec = vec![0]; - let mut plist_chunk = PListChunk::new()?; - let mut counter: u64 = 0; - - for tok in 0..combined_token_counts.len() { - // Find the smallest current line - - let mut plist: Vec = vec![]; - - for i in 0..plist_chunk_iterators.len() { - let this_plist: Vec = plist_chunk_iterators[i].get(); - assert_eq!(this_plist.len() % 2, 0); - - for (j, item) in this_plist.iter().enumerate() { - if j % 2 == 0 { - // page offset - plist.push(*item + uid_offsets[i]); - } else { - // quantized score - plist.push(*item); - } - } - - // this will return error for the last one, but it's ok - let _ = plist_chunk_iterators[i].advance().await; - } - - counter += 1; - - let plist = Vec::from_iter(plist.into_iter()); - let written = plist_chunk.add_plist(&plist)?; - if written > 1024 * 1024 || tok == combined_token_counts.len() - 1 { - let bytes = plist_chunk.finalize_compression()?; - let this_len: u64 = bytes.len() as u64; - - output_file.write(&bytes)?; - new_plist_offsets.push(new_plist_offsets[new_plist_offsets.len() - 1] + this_len); - new_plist_elems.push(counter); - plist_chunk = PListChunk::new()?; - } - } - - new_plist_offsets.append(&mut new_plist_elems); - - let bytes = bincode::serialize(&combined_token_counts)?; - let compressed_token_counts = encode_all(&bytes[..], 0).expect("Compression failed"); - - let compressed_term_dict_offset = output_file.seek(SeekFrom::Current(0))?; - output_file.write(&compressed_token_counts)?; - - let serialized = bincode::serialize(&new_plist_offsets).unwrap(); - let compressed_plist_offsets = encode_all(&serialized[..], 0).expect("Compression of plist offsets failed"); - - let compressed_plist_offsets_offset = compressed_term_dict_offset + compressed_token_counts.len() as u64; - output_file.write(&compressed_plist_offsets)?; - - output_file.write(&(compressed_term_dict_offset as u64).to_le_bytes())?; - output_file.write(&(compressed_plist_offsets_offset as u64).to_le_bytes())?; - output_file.write(&(total_num_documents as u64).to_le_bytes())?; - - Ok(vec![(compressed_term_dict_offset as usize, output_file.seek(SeekFrom::Current(0))? 
as usize)]) -} - -async fn compute_interleave( - bwt0_reader: &mut FMChunkIterator, - bwt1_reader: &mut FMChunkIterator, - lens: (usize, usize), - cumulative_counts: &Vec, -) -> Result { - let (bwt0_len, bwt1_len) = lens; - - let mut interleave = BitVec::from_elem(bwt0_len + bwt1_len, true); - for i in 0..bwt0_len { - interleave.set(i, false); - } - - // let mut interleave_iterations = 0; - - for _ in 0..10 { - let mut ind: [usize; 2] = [0, 0]; - - let mut bwt0 = &bwt0_reader.current_chunk.bwt_chunk; - let mut bwt1 = &bwt1_reader.current_chunk.bwt_chunk; - - let mut offsets = cumulative_counts.clone(); - let mut new_interleave = BitVec::from_elem(interleave.len(), false); - for i in 0..interleave.len() { - if interleave[i] { - new_interleave.set(offsets[bwt1[ind[1]] as usize] as usize, true); - offsets[bwt1[ind[1]] as usize] += 1; - ind[1] += 1; - - if ind[1] == bwt1.len() { - // will return an Err for the last chunk, that's ok - let _ = bwt1_reader.advance().await; - bwt1 = &bwt1_reader.current_chunk.bwt_chunk; - ind[1] = 0; - } - } else { - offsets[bwt0[ind[0]] as usize] += 1; - ind[0] += 1; - - if ind[0] == bwt0.len() { - let _ = bwt0_reader.advance().await; - bwt0 = &bwt0_reader.current_chunk.bwt_chunk; - ind[0] = 0; - } - } - } - - bwt0_reader.reset().await?; - bwt1_reader.reset().await?; - - // interleave_iterations += 1; - // println!( - // "{} ", - // interleave_iterations, - // ); - - if new_interleave == interleave { - break; - } - interleave = new_interleave; - } - - // println!("interleave iterations: {}", interleave_iterations); - Ok(interleave) -} - -async fn merge_lava_substring( - condensed_lava_file: &str, - lava_files: Vec, - uid_offsets: Vec, - reader_type: ReaderType, -) -> Result, LavaError> { - // first merge the tokenizer, then merge the fm indices then merge the posting lists. - // let mut builder = Fs::default(); - // let current_path = env::current_dir()?; - // builder.root(current_path.to_str().expect("no path")); - // let operator = Operator::new(builder)?.finish(); - - let mut compressed_tokenizer: Option> = None; - - // currently only support merging two files, but can support more in the future. - assert_eq!(lava_files.len(), 2); - assert_eq!(uid_offsets.len(), 2); - - let mut ns: Vec = vec![]; - let mut combined_cumulative_counts: Vec = vec![]; - let mut fm_chunk_iterators: Vec = vec![]; - let mut plist_iterators: Vec = vec![]; - - for file in lava_files { - // @Rain just make two different readers for now because this is hopefully low overhead - // instead of bothering with wrapping this thing in Arc>. Lots of tech debt to clean up - // needed for the FMChunkIterator and PListIterator - let (_, mut reader) = get_file_size_and_reader(file.clone(), reader_type.clone()).await?; - let (file_size, reader1) = get_file_size_and_reader(file.clone(), reader_type.clone()).await?; - let file_size = file_size as u64; - - let results = reader.read_usize_from_end(4).await?; - let fm_chunk_offsets_offset = results[0]; - let posting_list_offsets_offset = results[1]; - let total_counts_offset = results[2]; - let n = results[3]; - - ns.push(n); - - let compressed_tokenizer_size = reader.read_usize_from_start(0, 1).await?[0]; - let this_compressed_tokenizer: bytes::Bytes = reader.read_range(8, 8 + compressed_tokenizer_size).await?; - - match &compressed_tokenizer { - Some(value) => assert!( - this_compressed_tokenizer == value, - "detected different tokenizers, cannot merge, something is very wrong." 
- ), - None => compressed_tokenizer = Some(this_compressed_tokenizer.to_vec()), - } - - let fm_chunk_offsets: Vec = - reader.read_range_and_decompress(fm_chunk_offsets_offset, posting_list_offsets_offset).await?; - let posting_list_offsets: Vec = - reader.read_range_and_decompress(posting_list_offsets_offset, total_counts_offset).await?; - let cumulative_counts: Vec = - reader.read_range_and_decompress(total_counts_offset, (file_size - 32) as u64).await?; - - // println!("{} {}", file, cumulative_counts.len()); - - fm_chunk_iterators.push(FMChunkIterator::new(reader, fm_chunk_offsets).await?); - plist_iterators.push(PListIterator::new(reader1, posting_list_offsets).await?); - - if combined_cumulative_counts.len() == 0 { - combined_cumulative_counts = cumulative_counts; - } else { - // add cumulative_counts to combined_cumulative_counts - for (i, count) in cumulative_counts.iter().enumerate() { - combined_cumulative_counts[i] += count; - } - } - } - - let mut bwt0_reader = fm_chunk_iterators.remove(0); - let mut bwt1_reader = fm_chunk_iterators.remove(0); - let mut plist0_reader = plist_iterators.remove(0); - let mut plist1_reader = plist_iterators.remove(0); - - // let start = std::time::Instant::now(); - let interleave: BitVec = compute_interleave( - &mut bwt0_reader, - &mut bwt1_reader, - (ns[0] as usize, ns[1] as usize), - &combined_cumulative_counts, - ) - .await?; - - let _ = bwt0_reader.reset().await?; - let _ = bwt1_reader.reset().await?; - - // let duration = start.elapsed(); - // println!("interleave time: {:?}", duration); - - let mut output_file = File::create(condensed_lava_file)?; - let compressed_tokenizer = compressed_tokenizer.unwrap(); - output_file.write_all(&(compressed_tokenizer.len() as u64).to_le_bytes())?; - output_file.write_all(&compressed_tokenizer)?; - - let mut bwt_output: Vec = Vec::with_capacity(interleave.len()); - let mut index_output: Vec = Vec::with_capacity(interleave.len()); - - let mut bwt_ind0 = 0; - let mut bwt_ind1 = 0; - let mut idx_ind0 = 0; - let mut idx_ind1 = 0; - - let mut bwt0 = &bwt0_reader.current_chunk.bwt_chunk; - let mut bwt1 = &bwt1_reader.current_chunk.bwt_chunk; - let mut idx0 = &plist0_reader.current_chunk; - let mut idx1 = &plist1_reader.current_chunk; - - for i in 0..interleave.len() { - if interleave[i] { - bwt_output.push(bwt1[bwt_ind1]); - index_output.push(idx1[idx_ind1] + uid_offsets[1]); - - bwt_ind1 += 1; - if bwt_ind1 == bwt1.len() { - let _ = bwt1_reader.advance().await; - bwt1 = &bwt1_reader.current_chunk.bwt_chunk; - bwt_ind1 = 0; - } - - idx_ind1 += 1; - if idx_ind1 == idx1.len() { - let _ = plist1_reader.advance().await; - idx1 = &plist1_reader.current_chunk; - idx_ind1 = 0; - } - } else { - bwt_output.push(bwt0[bwt_ind0]); - index_output.push(idx0[idx_ind0] + uid_offsets[0]); - - bwt_ind0 += 1; - if bwt_ind0 == bwt0.len() { - let _ = bwt0_reader.advance().await; - bwt0 = &bwt0_reader.current_chunk.bwt_chunk; - bwt_ind0 = 0; - } - - idx_ind0 += 1; - if idx_ind0 == idx0.len() { - let _ = plist0_reader.advance().await; - idx0 = &plist0_reader.current_chunk; - idx_ind0 = 0; - } - } - } - - let mut current_chunk: Vec = vec![]; - let mut current_chunk_counts: HashMap = HashMap::new(); - let mut next_chunk_counts: HashMap = HashMap::new(); - let mut fm_chunk_offsets: Vec = vec![output_file.seek(SeekFrom::Current(0))? 
as usize]; - - for i in 0..bwt_output.len() { - let current_tok = bwt_output[i]; - next_chunk_counts.entry(current_tok).and_modify(|count| *count += 1).or_insert(1); - current_chunk.push(current_tok); - - if ((i + 1) % FM_CHUNK_TOKS == 0) || i == bwt_output.len() - 1 { - let serialized_counts = bincode::serialize(¤t_chunk_counts)?; - let compressed_counts = encode_all(&serialized_counts[..], 0).expect("Compression failed"); - output_file.write_all(&(compressed_counts.len() as u64).to_le_bytes())?; - output_file.write_all(&compressed_counts)?; - let serialized_chunk = bincode::serialize(¤t_chunk)?; - let compressed_chunk = encode_all(&serialized_chunk[..], 0).expect("Compression failed"); - output_file.write_all(&compressed_chunk)?; - fm_chunk_offsets.push(output_file.seek(SeekFrom::Current(0))? as usize); - current_chunk_counts = next_chunk_counts.clone(); - current_chunk = vec![]; - } - } - - let mut posting_list_offsets: Vec = vec![output_file.seek(SeekFrom::Current(0))? as usize]; - - for i in (0..index_output.len()).step_by(FM_CHUNK_TOKS) { - let slice = &index_output[i..std::cmp::min(index_output.len(), i + FM_CHUNK_TOKS)]; - let serialized_slice = bincode::serialize(slice)?; - let compressed_slice = encode_all(&serialized_slice[..], 0).expect("Compression failed"); - output_file.write_all(&compressed_slice)?; - posting_list_offsets.push(output_file.seek(SeekFrom::Current(0))? as usize); - } - - let cache_start = output_file.seek(SeekFrom::Current(0))? as usize; - - let fm_chunk_offsets_offset = output_file.seek(SeekFrom::Current(0))? as usize; - let serialized_fm_chunk_offsets = bincode::serialize(&fm_chunk_offsets)?; - let compressed_fm_chunk_offsets = encode_all(&serialized_fm_chunk_offsets[..], 0).expect("Compression failed"); - output_file.write_all(&compressed_fm_chunk_offsets)?; - - let posting_list_offsets_offset = output_file.seek(SeekFrom::Current(0))? as usize; - let serialized_posting_list_offsets = bincode::serialize(&posting_list_offsets)?; - let compressed_posting_list_offsets = - encode_all(&serialized_posting_list_offsets[..], 0).expect("Compression failed"); - output_file.write_all(&compressed_posting_list_offsets)?; - - let total_counts_offset = output_file.seek(SeekFrom::Current(0))? as usize; - let serialized_total_counts = bincode::serialize(&combined_cumulative_counts)?; - let compressed_total_counts: Vec = encode_all(&serialized_total_counts[..], 0).expect("Compression failed"); - output_file.write_all(&compressed_total_counts)?; - - output_file.write_all(&(fm_chunk_offsets_offset as u64).to_le_bytes())?; - output_file.write_all(&(posting_list_offsets_offset as u64).to_le_bytes())?; - output_file.write_all(&(total_counts_offset as u64).to_le_bytes())?; - output_file.write_all(&(bwt_output.len() as u64).to_le_bytes())?; - - Ok(vec![(cache_start, output_file.seek(SeekFrom::Current(0))? 
as usize)]) -} - #[async_recursion] async fn async_parallel_merge_files( condensed_lava_file: String, @@ -665,17 +44,34 @@ async fn async_parallel_merge_files( let merged_files_shared = Arc::new(Mutex::new(vec![])); let new_uid_offsets_shared = Arc::new(Mutex::new(vec![])); - let chunked_files: Vec> = - files.into_iter().chunks(k).into_iter().map(|chunk| chunk.collect()).collect(); - - let chunked_uid_offsets: Vec> = - uid_offsets.into_iter().chunks(k).into_iter().map(|chunk| chunk.collect()).collect(); - - for (file_chunk, uid_chunk) in chunked_files.into_iter().zip(chunked_uid_offsets.into_iter()) { + let chunked_files: Vec> = files + .into_iter() + .chunks(k) + .into_iter() + .map(|chunk| chunk.collect()) + .collect(); + + let chunked_uid_offsets: Vec> = uid_offsets + .into_iter() + .chunks(k) + .into_iter() + .map(|chunk| chunk.collect()) + .collect(); + + for (file_chunk, uid_chunk) in chunked_files + .into_iter() + .zip(chunked_uid_offsets.into_iter()) + { if file_chunk.len() == 1 { // If there's an odd file out, directly move it to the next level - merged_files_shared.lock().unwrap().push(file_chunk[0].clone()); - new_uid_offsets_shared.lock().unwrap().push(uid_chunk[0].clone()); + merged_files_shared + .lock() + .unwrap() + .push(file_chunk[0].clone()); + new_uid_offsets_shared + .lock() + .unwrap() + .push(uid_chunk[0].clone()); continue; } @@ -741,15 +137,22 @@ async fn async_parallel_merge_files( } // Wait for all tasks to complete, MUST BE IN ORDER due to cache_ranges! - let cache_ranges: Vec> = - futures::future::join_all(tasks).await.into_iter().collect::, _>>().unwrap(); + let cache_ranges: Vec> = futures::future::join_all(tasks) + .await + .into_iter() + .collect::, _>>() + .unwrap(); // Extract the merged files for the next level of merging - let merged_files: Vec = - Arc::try_unwrap(merged_files_shared).expect("Lock still has multiple owners").into_inner().unwrap(); + let merged_files: Vec = Arc::try_unwrap(merged_files_shared) + .expect("Lock still has multiple owners") + .into_inner() + .unwrap(); - let new_uid_offsets = - Arc::try_unwrap(new_uid_offsets_shared).expect("Lock still has multiple owners").into_inner().unwrap(); + let new_uid_offsets = Arc::try_unwrap(new_uid_offsets_shared) + .expect("Lock still has multiple owners") + .into_inner() + .unwrap(); // Recurse with the newly merged files async_parallel_merge_files( @@ -777,9 +180,17 @@ pub async fn parallel_merge_files( reader_type: ReaderType, ) -> Result, LavaError> { let do_not_delete = BTreeSet::from_iter(files.clone().into_iter()); - let result = - async_parallel_merge_files(condensed_lava_file, files, do_not_delete, uid_offsets, k, mode, reader_type, None) - .await?; + let result = async_parallel_merge_files( + condensed_lava_file, + files, + do_not_delete, + uid_offsets, + k, + mode, + reader_type, + None, + ) + .await?; Ok(result) } @@ -805,7 +216,10 @@ mod tests { pub fn test_merge_lava_substring() { let res = parallel_merge_files( "merged.lava".to_string(), - vec!["chinese_index/0.lava".to_string(), "chinese_index/1.lava".to_string()], + vec![ + "chinese_index/0.lava".to_string(), + "chinese_index/1.lava".to_string(), + ], vec![0, 1000000], 2, 1, diff --git a/src/lava/mod.rs b/src/lava/mod.rs index 333621b..e345e7f 100644 --- a/src/lava/mod.rs +++ b/src/lava/mod.rs @@ -1,32 +1,28 @@ -mod build; -mod constants; pub mod error; -mod fm_chunk; + +mod bm25; mod logcloud; -mod logcloud_common; -mod logcloud_rex; mod merge; mod plist; mod search; -mod trie; -mod wavelet_tree; +mod substring; +mod uuid; +mod 
vector; -pub use build::build_lava_bm25; -pub use build::build_lava_substring; -pub use build::build_lava_substring_char; -pub use build::build_lava_uuid; +pub use bm25::build_lava_bm25; +pub use substring::build_lava_substring; +pub use substring::build_lava_substring_char; +pub use uuid::build_lava_uuid; pub use merge::parallel_merge_files; -pub use search::_search_lava_substring_char; pub use search::get_tokenizer_vocab; -pub use search::search_lava_bm25; pub use search::search_lava_substring; pub use search::search_lava_substring_char; pub use search::search_lava_uuid; -pub use search::search_lava_vector; +pub use vector::search_lava_vector; +pub use logcloud::compress_logs; pub use logcloud::index_analysis; pub use logcloud::index_logcloud; pub use logcloud::search_logcloud; -pub use logcloud_rex::compress_logs; diff --git a/src/lava/search.rs b/src/lava/search.rs index f2516b3..c8851c4 100644 --- a/src/lava/search.rs +++ b/src/lava/search.rs @@ -1,7 +1,6 @@ use crate::lava::constants::*; use crate::lava::fm_chunk::FMChunk; use crate::lava::plist::PListChunk; -use crate::lava::wavelet_tree::search_wavelet_tree_from_reader; use crate::{ formats::readers::{ get_file_size_and_reader, get_file_sizes_and_readers, get_reader, get_readers, AsyncReader, @@ -11,7 +10,6 @@ use crate::{ }; use byteorder::{ByteOrder, LittleEndian, ReadBytesExt}; use itertools::Itertools; -use ndarray::{concatenate, stack, Array1, Array2, Axis}; use serde::de::DeserializeOwned; use std::collections::BTreeSet; use std::sync::Arc; @@ -22,29 +20,31 @@ use std::{ }; use tokenizers::tokenizer::Tokenizer; use tokio::task::JoinSet; -use zstd::stream::read::Decoder; -use super::trie::FastTrie; -use super::wavelet_tree; use futures::stream::{FuturesUnordered, StreamExt}; use std::cmp::Ordering; use std::io::{self, Cursor}; +use super::bm25::search_bm25_async; + enum QueryParam { SubstringCharWavelet(Vec>), SubstringChar(Vec>), Substring(Vec>), Uuid(String), } -use std::fmt::Debug; -async fn get_tokenizer_async(mut readers: Vec) -> Result<(Tokenizer, Vec), LavaError> { +async fn get_tokenizer_async( + mut readers: Vec, +) -> Result<(Tokenizer, Vec), LavaError> { let mut compressed_tokenizer: Option> = None; for i in 0..readers.len() { // now interpret this as a usize // readers[i].seek(SeekFrom::Start(0)).await?; let compressed_tokenizer_size = readers[i].read_usize_from_start(0, 1).await?[0]; - let this_compressed_tokenizer: bytes::Bytes = readers[i].read_range(8, 8 + compressed_tokenizer_size).await?; + let this_compressed_tokenizer: bytes::Bytes = readers[i] + .read_range(8, 8 + compressed_tokenizer_size) + .await?; match &compressed_tokenizer { Some(value) => assert!( this_compressed_tokenizer == value, @@ -70,261 +70,6 @@ async fn get_tokenizer_async(mut readers: Vec) -> Result<(Tokenizer Ok((tokenizer, result)) } -use num_traits::{AsPrimitive, PrimInt, Unsigned}; -use serde::{Deserialize, Serialize}; -use std::ops::Add; - -async fn process_substring_query( - query: Vec, - n: u64, - fm_chunk_offsets: &[u64], - cumulative_counts: &[u64], - posting_list_offsets: &[u64], - reader: &mut AsyncReader, - file_id: u64, -) -> Vec<(u64, u64)> -where - T: PrimInt - + Unsigned - + Serialize - + for<'de> Deserialize<'de> - + Clone - + Eq - + std::hash::Hash - + AsPrimitive - + 'static, - usize: AsPrimitive, -{ - let mut res: Vec<(u64, u64)> = vec![]; - let mut start: usize = 0; - let mut end: usize = n as usize; - - for i in (0..query.len()).rev() { - let current_token = query[i]; - - let start_byte = fm_chunk_offsets[start / 
FM_CHUNK_TOKS]; - let end_byte = fm_chunk_offsets[start / FM_CHUNK_TOKS + 1]; - let start_chunk = reader.read_range(start_byte, end_byte).await.unwrap(); - - let start_byte = fm_chunk_offsets[end / FM_CHUNK_TOKS]; - let end_byte = fm_chunk_offsets[end / FM_CHUNK_TOKS + 1]; - let end_chunk = reader.read_range(start_byte, end_byte).await.unwrap(); - - start = cumulative_counts[current_token.as_()] as usize - + FMChunk::::new(start_chunk).unwrap().search(current_token, start % FM_CHUNK_TOKS).unwrap() as usize; - end = cumulative_counts[current_token.as_()] as usize - + FMChunk::::new(end_chunk).unwrap().search(current_token, end % FM_CHUNK_TOKS).unwrap() as usize; - - if start >= end { - return res; - } - - if end <= start + 2 { - break; - } - } - - let start_offset = posting_list_offsets[start / FM_CHUNK_TOKS]; - let end_offset = posting_list_offsets[end / FM_CHUNK_TOKS + 1]; - let total_chunks = end / FM_CHUNK_TOKS - start / FM_CHUNK_TOKS + 1; - - let plist_chunks = reader.read_range(start_offset, end_offset).await.unwrap(); - - let mut chunk_set = JoinSet::new(); - - for i in 0..total_chunks { - let this_start = posting_list_offsets[start / FM_CHUNK_TOKS + i]; - let this_end = posting_list_offsets[start / FM_CHUNK_TOKS + i + 1]; - let this_chunk = - plist_chunks[(this_start - start_offset) as usize..(this_end - start_offset) as usize].to_vec(); - - chunk_set.spawn(async move { - let mut decompressor = Decoder::new(&this_chunk[..]).unwrap(); - let mut serialized_plist_chunk: Vec = Vec::with_capacity(this_chunk.len()); - decompressor.read_to_end(&mut serialized_plist_chunk).unwrap(); - let plist_chunk: Vec = bincode::deserialize(&serialized_plist_chunk).unwrap(); - - let chunk_res: Vec<(u64, u64)> = if i == 0 { - if total_chunks == 1 { - plist_chunk[start % FM_CHUNK_TOKS..end % FM_CHUNK_TOKS].iter().map(|&uid| (file_id, uid)).collect() - } else { - plist_chunk[start % FM_CHUNK_TOKS..].iter().map(|&uid| (file_id, uid)).collect() - } - } else if i == total_chunks - 1 { - plist_chunk[..end % FM_CHUNK_TOKS].iter().map(|&uid| (file_id, uid)).collect() - } else { - plist_chunk.iter().map(|&uid| (file_id, uid)).collect() - }; - - chunk_res - }); - } - - while let Some(chunk_res) = chunk_set.join_next().await { - res.extend(chunk_res.unwrap()); - } - - res -} - -async fn read_and_decompress(reader: &mut AsyncReader, start: u64, size: u64) -> Result -where - T: DeserializeOwned, -{ - let compressed = reader.read_range(start, start + size).await?; - let mut decompressor = Decoder::new(&compressed[..]).unwrap(); - let mut decompressed = Vec::new(); - std::io::copy(&mut decompressor, &mut decompressed)?; - let result: T = bincode::deserialize(&decompressed)?; - Ok(result) -} - -async fn search_substring_wavelet( - file_id: u64, - mut reader: AsyncReader, - file_size: usize, - queries: Vec>, -) -> Result, LavaError> { - println!("{:?}", queries); - - let metadata_start = reader.read_usize_from_end(1).await?[0]; - - let metadata: (Vec, Vec, Vec, Vec, usize) = - read_and_decompress(&mut reader, metadata_start as u64, file_size as u64 - metadata_start - 8).await.unwrap(); - let (offsets, level_offsets, posting_list_offsets, cumulative_counts, n) = metadata; - - // let mut query_set = JoinSet::new(); - - let mut res: Vec<(u64, u64)> = vec![]; - - for query in queries { - let mut reader = reader.clone(); - let (start, end) = - search_wavelet_tree_from_reader(&mut reader, &query, n, &offsets, &level_offsets, &cumulative_counts) - .await?; - - println!("{} {}", start, end); - - if start == end { - continue; - } 
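// Illustration only (not part of the diff): both search_substring_one_file and the wavelet
// variant refine a suffix-array range with classic FM-index backward search before reading
// posting lists for that range. A sketch under simplifying assumptions (whole BWT in memory,
// rank computed by a linear scan rather than the chunked counts the index actually stores):
fn backward_search(bwt: &[u8], cumulative_counts: &[u64; 256], pattern: &[u8]) -> (usize, usize) {
    // rank(c, i): occurrences of c in bwt[..i]
    let rank = |c: u8, i: usize| bwt[..i].iter().filter(|&&b| b == c).count();
    let (mut start, mut end) = (0usize, bwt.len());
    for &c in pattern.iter().rev() {
        start = cumulative_counts[c as usize] as usize + rank(c, start);
        end = cumulative_counts[c as usize] as usize + rank(c, end);
        if start >= end {
            return (start, start); // empty range: the pattern does not occur
        }
    }
    (start, end) // matching suffixes occupy [start, end)
}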
- - let start_offset = posting_list_offsets[start / FM_CHUNK_TOKS]; - let end_offset = posting_list_offsets[end / FM_CHUNK_TOKS + 1]; - let total_chunks = end / FM_CHUNK_TOKS - start / FM_CHUNK_TOKS + 1; - - let plist_chunks = reader.read_range(start_offset as u64, end_offset as u64).await.unwrap(); - - let mut chunk_set = JoinSet::new(); - - for i in 0..total_chunks { - let this_start = posting_list_offsets[start / FM_CHUNK_TOKS + i]; - let this_end = posting_list_offsets[start / FM_CHUNK_TOKS + i + 1]; - let this_chunk = - plist_chunks[(this_start - start_offset) as usize..(this_end - start_offset) as usize].to_vec(); - - chunk_set.spawn(async move { - let mut decompressor = Decoder::new(&this_chunk[..]).unwrap(); - let mut serialized_plist_chunk: Vec = Vec::with_capacity(this_chunk.len()); - decompressor.read_to_end(&mut serialized_plist_chunk).unwrap(); - let plist_chunk: Vec = bincode::deserialize(&serialized_plist_chunk).unwrap(); - - let chunk_res: Vec<(u64, u64)> = if i == 0 { - if total_chunks == 1 { - plist_chunk[start % FM_CHUNK_TOKS..end % FM_CHUNK_TOKS] - .iter() - .map(|&uid| (file_id, uid)) - .collect() - } else { - plist_chunk[start % FM_CHUNK_TOKS..].iter().map(|&uid| (file_id, uid)).collect() - } - } else if i == total_chunks - 1 { - plist_chunk[..end % FM_CHUNK_TOKS].iter().map(|&uid| (file_id, uid)).collect() - } else { - plist_chunk.iter().map(|&uid| (file_id, uid)).collect() - }; - - chunk_res - }); - } - - while let Some(chunk_res) = chunk_set.join_next().await { - res.extend(chunk_res.unwrap()); - } - } - - // let mut res = Vec::new(); - // while let Some(query_res) = query_set.join_next().await { - // res.extend(query_res.unwrap()); - // } - - Ok(res) -} - -async fn search_substring_one_file( - file_id: u64, - mut reader: AsyncReader, - file_size: usize, - queries: Vec>, -) -> Result, LavaError> -where - T: PrimInt - + Unsigned - + Serialize - + for<'de> Deserialize<'de> - + Clone - + Eq - + std::hash::Hash - + AsPrimitive - + Debug - + Send - + 'static, - usize: AsPrimitive, -{ - println!("{:?}", queries); - - let results = reader.read_usize_from_end(4).await?; - let fm_chunk_offsets_offset = results[0]; - let posting_list_offsets_offset = results[1]; - let total_counts_offset = results[2]; - let n = results[3]; - - let fm_chunk_offsets: Vec = - reader.read_range_and_decompress(fm_chunk_offsets_offset, posting_list_offsets_offset).await?; - let posting_list_offsets: Vec = - reader.read_range_and_decompress(posting_list_offsets_offset, total_counts_offset).await?; - let cumulative_counts: Vec = - reader.read_range_and_decompress(total_counts_offset, (file_size - 32) as u64).await?; - - let mut query_set = JoinSet::new(); - - for query in queries { - let fm_chunk_offsets = fm_chunk_offsets.clone(); - let cumulative_counts = cumulative_counts.clone(); - let posting_list_offsets = posting_list_offsets.clone(); - let mut reader = reader.clone(); - - query_set.spawn(async move { - process_substring_query::( - query, - n, - &fm_chunk_offsets, - &cumulative_counts, - &posting_list_offsets, - &mut reader, - file_id, - ) - .await - }); - } - - let mut res = Vec::new(); - while let Some(query_res) = query_set.join_next().await { - res.extend(query_res.unwrap()); - } - Ok(res) -} - async fn search_uuid_one_file( file_id: u64, mut reader: AsyncReader, @@ -334,14 +79,10 @@ async fn search_uuid_one_file( let mut result: Vec<(u64, u64)> = Vec::new(); let mut start_time = Instant::now(); - let this_result: Vec = FastTrie::query_with_reader(file_size, &mut reader, &query).await?; + 
let this_result: Vec = + FastTrie::query_with_reader(file_size, &mut reader, &query).await?; result.extend(this_result.iter().map(|x| (file_id, *x as u64))); - // println!( - // "search_uuid_one_file: {}ms", - // start_time.elapsed().as_millis() - // ); - Ok(result) } @@ -360,16 +101,36 @@ async fn search_generic_async( match query { QueryParam::Substring(ref value) => { - join_set.spawn(search_substring_one_file::(file_id as u64, reader, file_size, value.clone())); + join_set.spawn(search_substring_one_file::( + file_id as u64, + reader, + file_size, + value.clone(), + )); } QueryParam::SubstringChar(ref value) => { - join_set.spawn(search_substring_one_file::(file_id as u64, reader, file_size, value.clone())); + join_set.spawn(search_substring_one_file::( + file_id as u64, + reader, + file_size, + value.clone(), + )); } QueryParam::SubstringCharWavelet(ref value) => { - join_set.spawn(search_substring_wavelet(file_id as u64, reader, file_size, value.clone())); + join_set.spawn(search_substring_wavelet_one_file( + file_id as u64, + reader, + file_size, + value.clone(), + )); } QueryParam::Uuid(ref value) => { - join_set.spawn(search_uuid_one_file(file_id as u64, reader, file_size, value.clone())); + join_set.spawn(search_uuid_one_file( + file_id as u64, + reader, + file_size, + value.clone(), + )); } _ => panic!("invalid mode"), } @@ -395,148 +156,6 @@ async fn search_generic_async( Ok(result) } -async fn search_bm25_async( - file_sizes: Vec, - mut readers: Vec, - query_tokens: Vec, - query_weights: Vec, - k: usize, -) -> Result, LavaError> { - let mut idf: HashMap = HashMap::new(); - let mut total_token_counts: HashMap = HashMap::new(); - for token in query_tokens.iter() { - total_token_counts.insert(*token, 0); - } - let mut total_documents: usize = 0; - let mut all_plist_offsets: Vec> = Vec::new(); - let mut chunks_to_search: HashMap<(usize, usize), Vec<(u32, u64)>> = HashMap::new(); - - for i in 0..readers.len() { - let results = readers[i].read_usize_from_end(3).await?; - let compressed_term_dictionary_offset = results[0]; - let compressed_plist_offsets_offset = results[1]; - let num_documents = results[2]; - - // now read the term dictionary - let token_counts = readers[i] - .read_range_and_decompress(compressed_term_dictionary_offset, compressed_plist_offsets_offset) - .await?; - - for query_token in query_tokens.iter() { - total_token_counts - .insert(*query_token, total_token_counts[query_token] + token_counts[*query_token as usize] as usize); - } - total_documents += num_documents as usize; - - let plist_offsets = - readers[i].read_range_and_decompress(compressed_plist_offsets_offset, file_sizes[i] as u64 - 24).await?; - - if plist_offsets.len() % 2 != 0 { - let err = LavaError::Parse("data corruption".to_string()); - return Err(err); - } - - let num_chunks: usize = plist_offsets.len() / 2; - let term_dict_len: &[u64] = &plist_offsets[num_chunks..]; - - for token in query_tokens.iter() { - let tok = *token as u64; - let (idx, offset) = match term_dict_len.binary_search(&tok) { - Ok(idx) => (idx, 0), - Err(idx) => (idx - 1, tok - term_dict_len[idx - 1]), - }; - - chunks_to_search.entry((i as usize, idx)).or_insert_with(Vec::new).push((*token, offset as u64)); - } - - all_plist_offsets.push(plist_offsets); - } - - // compute the weighted IDF for each query token - for (i, query_token) in query_tokens.iter().enumerate() { - let query_weight = query_weights[i]; - let query_token = *query_token; - let token_count = total_token_counts[&query_token]; - idf.insert( - query_token, - 
query_weight - * ((total_documents as f32 - token_count as f32 + 0.5) / (token_count as f32 + 0.5) + 1.0).ln(), - ); - } - - let mut plist_result: Vec<(u64, u64)> = Vec::new(); - let mut page_scores: HashMap<(u64, u64), f32> = HashMap::new(); - - let mut join_set: JoinSet, LavaError>> = JoinSet::new(); - // need to parallelize this @Rain. - for (file_id, chunk_id, tokens, offsets) in - chunks_to_search.into_iter().map(|((file_id, chunk_id), token_offsets)| { - let (tokens, offsets): (Vec, Vec) = token_offsets.into_iter().unzip(); - (file_id, chunk_id, Arc::new(tokens), Arc::new(offsets)) - }) - { - let reader_type = match readers[file_id].reader { - ClonableAsyncReader::AwsSdk(_) => ReaderType::AwsSdk, - ClonableAsyncReader::Http(_) => ReaderType::Http, - ClonableAsyncReader::Local(_) => ReaderType::Local, - }; - - let mut reader = match reader_type { - ReaderType::AwsSdk | ReaderType::Http => readers[file_id].clone(), - ReaderType::Local => { - get_file_size_and_reader(readers[file_id].filename.clone(), reader_type).await.unwrap().1 - } - }; - let start = all_plist_offsets[file_id][chunk_id]; - let end = all_plist_offsets[file_id][chunk_id + 1]; - let tokens = tokens.clone(); - let offsets = offsets.clone(); - - join_set.spawn(async move { - // println!("file_id: {}, chunk_id: {}", file_id, chunk_id); - let buffer3 = reader.read_range(start, end).await?; - - // get all the second item in the offsets into its own vector - - let results: Vec> = PListChunk::search_compressed(buffer3.to_vec(), offsets.as_ref())?; - - let mut res = vec![]; - for (i, result) in results.iter().enumerate() { - let token = &tokens[i]; - assert_eq!(result.len() % 2, 0); - for i in (0..result.len()).step_by(2) { - let uid = result[i]; - let page_score = result[i + 1]; - res.push((file_id, uid, *token, page_score)); - } - } - Ok(res) - }); - } - - while let Some(res) = join_set.join_next().await { - let res = res.map_err(|e| LavaError::Parse(format!("join error: {:?}", e)))??; - for (file_id, uid, token, page_score) in res { - page_scores - .entry((file_id as u64, uid)) - .and_modify(|e| *e += idf[&token] * page_score as f32) - .or_insert(idf[&token] * page_score as f32); - } - } - - // sort the page scores by descending order - let mut page_scores_vec: Vec<((u64, u64), f32)> = page_scores.into_iter().collect(); - page_scores_vec.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); - - // get the top k results - for (uid, _score) in page_scores_vec.iter().take(k) { - // println!("{}", score); - plist_result.push(*uid); - } - - Ok(plist_result) -} - #[tokio::main] pub async fn search_lava_bm25( files: Vec, @@ -575,18 +194,48 @@ pub async fn search_lava_substring( let mut skip_tokens: HashSet = HashSet::new(); for char in SKIP.chars() { let char_str = char.to_string(); - skip_tokens.extend(tokenizer.encode(char_str.clone(), false).unwrap().get_ids().to_vec()); - skip_tokens.extend(tokenizer.encode(format!(" {}", char_str), false).unwrap().get_ids().to_vec()); - skip_tokens.extend(tokenizer.encode(format!("{} ", char_str), false).unwrap().get_ids().to_vec()); + skip_tokens.extend( + tokenizer + .encode(char_str.clone(), false) + .unwrap() + .get_ids() + .to_vec(), + ); + skip_tokens.extend( + tokenizer + .encode(format!(" {}", char_str), false) + .unwrap() + .get_ids() + .to_vec(), + ); + skip_tokens.extend( + tokenizer + .encode(format!("{} ", char_str), false) + .unwrap() + .get_ids() + .to_vec(), + ); } let lower: String = query.chars().flat_map(|c| c.to_lowercase()).collect(); let encoding = tokenizer.encode(lower, 
false).unwrap(); - let result: Vec = encoding.get_ids().iter().filter(|id| !skip_tokens.contains(id)).cloned().collect(); + let result: Vec = encoding + .get_ids() + .iter() + .filter(|id| !skip_tokens.contains(id)) + .cloned() + .collect(); let mut query: Vec> = if let Some(sample_factor) = sample_factor { (0..sample_factor) - .map(|offset| result.iter().skip(offset).step_by(sample_factor).cloned().collect::>()) + .map(|offset| { + result + .iter() + .skip(offset) + .step_by(sample_factor) + .cloned() + .collect::>() + }) .filter(|vec| !vec.is_empty()) .collect() } else { @@ -599,7 +248,13 @@ pub async fn search_lava_substring( if let Some(token_viable_limit) = token_viable_limit { query.iter_mut().for_each(|vec| { if vec.len() > token_viable_limit { - *vec = vec.iter().rev().take(token_viable_limit).rev().cloned().collect(); + *vec = vec + .iter() + .rev() + .take(token_viable_limit) + .rev() + .cloned() + .collect(); } }); } @@ -610,253 +265,32 @@ pub async fn search_lava_substring( search_generic_async(file_sizes, readers, QueryParam::Substring(query), k).await } -pub async fn _search_lava_substring_char( +#[tokio::main] +pub async fn search_lava_substring_char( files: Vec, query: String, k: usize, reader_type: ReaderType, token_viable_limit: Option, sample_factor: Option, - wavelet_tree: bool, ) -> Result, LavaError> { - let lower: String = query.chars().flat_map(|c| c.to_lowercase()).collect(); - let result: Vec = lower.chars().filter(|id| !SKIP.chars().contains(id)).map(|c| c as u8).collect(); - - let mut query: Vec> = if let Some(sample_factor) = sample_factor { - (0..sample_factor) - .map(|offset| result.iter().skip(offset).step_by(sample_factor).cloned().collect::>()) - .filter(|vec| !vec.is_empty()) - .collect() - } else { - vec![result] - }; - - // println!("query {:?}", query); - - // query = [i[-token_viable_limit:] for i in query] - if let Some(token_viable_limit) = token_viable_limit { - query.iter_mut().for_each(|vec| { - if vec.len() > token_viable_limit { - *vec = vec.iter().rev().take(token_viable_limit).rev().cloned().collect(); - } - }); - } - - // println!("query {:?}", query); - - let (file_sizes, readers) = get_file_sizes_and_readers(&files, reader_type).await?; - search_generic_async( - file_sizes, - readers, - if wavelet_tree { QueryParam::SubstringCharWavelet(query) } else { QueryParam::SubstringChar(query) }, + _search_lava_substring_char( + files, + query, k, + reader_type, + token_viable_limit, + sample_factor, + false, ) .await } #[tokio::main] -pub async fn search_lava_substring_char( +pub async fn get_tokenizer_vocab( files: Vec, - query: String, - k: usize, reader_type: ReaderType, - token_viable_limit: Option, - sample_factor: Option, -) -> Result, LavaError> { - _search_lava_substring_char(files, query, k, reader_type, token_viable_limit, sample_factor, false).await -} - -fn bytes_to_f32_vec(bytes: &[u8]) -> Vec { - let mut vec = Vec::with_capacity(bytes.len() / 4); - let mut i = 0; - while i < bytes.len() { - let value = LittleEndian::read_f32(&bytes[i..i + 4]); - vec.push(value); - i += 4; - } - vec -} - -pub fn search_lava_vector( - files: Vec, - query: Vec, - nprobes: usize, - reader_type: ReaderType, -) -> Result<(Vec, Vec>, Vec<(usize, Array1)>), LavaError> { - let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap(); - - let res = rt.block_on(search_lava_vector_async(files, query, nprobes, reader_type)); - rt.shutdown_background(); - res -} - -pub async fn search_lava_vector_async( - files: Vec, - query: Vec, - nprobes: 
usize, - reader_type: ReaderType, -) -> Result<(Vec, Vec>, Vec<(usize, Array1)>), LavaError> { - let start = Instant::now(); - - let (_, mut readers) = get_file_sizes_and_readers(&files, reader_type.clone()).await?; - - let mut futures = Vec::new(); - - for _ in 0..readers.len() { - let mut reader = readers.remove(0); - - futures.push(tokio::spawn(async move { - let results = reader.read_usize_from_end(4).await.unwrap(); - - let centroid_vectors_compressed_bytes = reader.read_range(results[2], results[3]).await.unwrap(); - - // decompress them - let mut decompressor = Decoder::new(centroid_vectors_compressed_bytes.as_ref()).unwrap(); - let mut centroid_vectors: Vec = Vec::with_capacity(centroid_vectors_compressed_bytes.len() as usize); - decompressor.read_to_end(&mut centroid_vectors).unwrap(); - - let centroid_vectors = bytes_to_f32_vec(¢roid_vectors); - let num_vectors = centroid_vectors.len() / 128; - let array2 = Array2::::from_shape_vec((num_vectors, 128), centroid_vectors).unwrap(); - - (num_vectors, array2) - })); - } - - let result: Vec), tokio::task::JoinError>> = futures::future::join_all(futures).await; - - let end = Instant::now(); - println!("Time stage 1 read: {:?}", end - start); - - let start = Instant::now(); - - let arr_lens = result.iter().map(|x| x.as_ref().unwrap().0).collect::>(); - // get cumulative arr len starting from 0 - let cumsum = arr_lens - .iter() - .scan(0, |acc, &x| { - *acc += x; - Some(*acc) - }) - .collect::>(); - - let arrays: Vec> = result.into_iter().map(|x| x.unwrap().1).collect(); - let centroids = - concatenate(Axis(0), arrays.iter().map(|array| array.view()).collect::>().as_slice()).unwrap(); - let query = Array1::::from_vec(query); - let query_broadcast = query.broadcast(centroids.dim()).unwrap(); - - let difference = ¢roids - &query_broadcast; - let norms = difference.map_axis(Axis(1), |row| row.dot(&row).sqrt()); - let mut indices_and_values: Vec<(usize, f32)> = norms.iter().enumerate().map(|(idx, &val)| (idx, val)).collect(); - - indices_and_values.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal)); - let smallest_indices: Vec = indices_and_values.iter().map(|&(idx, _)| idx).take(nprobes).collect(); - - let mut file_indices: Vec> = vec![vec![]; files.len()]; - for idx in smallest_indices.iter() { - // figure out which file idx based on cumsum. 
need to find the index of the thing that is just bigger than idx - - let file_idx = cumsum.iter().enumerate().find(|(_, &val)| val > *idx).unwrap().0; - let last_cumsum = if file_idx == 0 { 0 } else { cumsum[file_idx - 1] }; - let remainder = idx - last_cumsum; - file_indices[file_idx].push(remainder); - } - - let end = Instant::now(); - println!("Time math: {:?}", end - start); - - let start = Instant::now(); - - let (_, mut readers) = get_file_sizes_and_readers(&files, reader_type.clone()).await?; - - let mut file_ids = vec![]; - let mut futures = Vec::new(); - - for file_id in 0..readers.len() { - let mut reader = readers.remove(0); - if file_indices[file_id].len() == 0 { - continue; - } - let my_idx: Vec = file_indices[file_id].clone(); - file_ids.push(file_id); - - futures.push(tokio::spawn(async move { - let results = reader.read_usize_from_end(4).await.unwrap(); - - let pq_bytes = reader.read_range(results[0], results[1]).await.unwrap(); - - let compressed_centroid_offset_bytes = reader.read_range(results[1], results[2]).await.unwrap(); - let mut decompressor = Decoder::new(compressed_centroid_offset_bytes.as_ref()).unwrap(); - let mut centroid_offsets_bytes: Vec = - Vec::with_capacity(compressed_centroid_offset_bytes.len() as usize); - decompressor.read_to_end(&mut centroid_offsets_bytes).unwrap(); - - // now reinterpret centroid_offsets_bytes as a Vec - - let mut centroid_offsets = Vec::with_capacity(centroid_offsets_bytes.len() / 8); - let mut cursor = Cursor::new(centroid_offsets_bytes); - - while cursor.position() < cursor.get_ref().len() as u64 { - let value = cursor.read_u64::().unwrap(); - centroid_offsets.push(value); - } - - let mut this_result: Vec<(usize, u64, u64)> = vec![]; - - for idx in my_idx.iter() { - this_result.push((file_id, centroid_offsets[*idx], centroid_offsets[*idx + 1])); - } - (this_result, Array1::::from_vec(pq_bytes.to_vec())) - })); - } - - let result: Vec, Array1), tokio::task::JoinError>> = - futures::future::join_all(futures).await; - let result: Vec<(Vec<(usize, u64, u64)>, Array1)> = result.into_iter().map(|x| x.unwrap()).collect(); - - let pq_bytes: Vec> = result.iter().map(|x| x.1.clone()).collect::>(); - - let end = Instant::now(); - println!("Time stage 2 read: {:?}", end - start); - - let start = Instant::now(); - let reader = get_reader(files[file_ids[0]].clone(), reader_type.clone()).await.unwrap(); - - let mut futures = FuturesUnordered::new(); - for i in 0..result.len() { - let to_read = result[i].0.clone(); - for (file_id, start, end) in to_read.into_iter() { - let mut reader_c = reader.clone(); - reader_c.update_filename(files[file_id].clone()).unwrap(); - - futures.push(tokio::spawn(async move { - let start_time = Instant::now(); - let codes_and_plist = reader_c.read_range(start, end).await.unwrap(); - // println!( - // "Time to read {:?}, {:?}", - // Instant::now() - start_time, - // codes_and_plist.len() - // ); - (file_id, Array1::::from_vec(codes_and_plist.to_vec())) - })); - } - } - - let mut ranges: Vec<(usize, Array1)> = vec![]; - - while let Some(x) = futures.next().await { - ranges.push(x.unwrap()); - } - - let end = Instant::now(); - println!("Time stage 3 read: {:?}", end - start); - - Ok((file_ids, pq_bytes, ranges)) -} - -#[tokio::main] -pub async fn get_tokenizer_vocab(files: Vec, reader_type: ReaderType) -> Result, LavaError> { +) -> Result, LavaError> { let (_file_sizes, readers) = get_file_sizes_and_readers(&files, reader_type).await?; Ok(get_tokenizer_async(readers).await?.1) } @@ -872,9 +306,14 @@ mod tests { pub fn 
test_search_lava_one() { let file = "msmarco_index/1.lava"; - let res = - search_lava_bm25(vec![file.to_string()], vec![6300, 15050], vec![0.1, 0.2], 10, ReaderType::default()) - .unwrap(); + let res = search_lava_bm25( + vec![file.to_string()], + vec![6300, 15050], + vec![0.1, 0.2], + 10, + ReaderType::default(), + ) + .unwrap(); println!("{:?}", res); } diff --git a/src/lava/constants.rs b/src/lava/substring/constants.rs similarity index 100% rename from src/lava/constants.rs rename to src/lava/substring/constants.rs diff --git a/src/lava/fm_chunk.rs b/src/lava/substring/fm_chunk.rs similarity index 81% rename from src/lava/fm_chunk.rs rename to src/lava/substring/fm_chunk.rs index bc0d1a8..b54afe9 100644 --- a/src/lava/fm_chunk.rs +++ b/src/lava/substring/fm_chunk.rs @@ -1,4 +1,4 @@ -use super::error::LavaError; +use crate::lava::error::LavaError; use bytes::Bytes; use serde::{Deserialize, Serialize}; use std::collections::HashMap; @@ -27,22 +27,28 @@ where let counts: HashMap = bincode::deserialize(&serialized_counts)?; let compressed_fm_chunk = &chunk[(compressed_counts_size + 8) as usize..]; let mut decompressor = Decoder::new(compressed_fm_chunk)?; - let mut serialized_fm_chunk: Vec = Vec::with_capacity(compressed_fm_chunk.len() as usize); + let mut serialized_fm_chunk: Vec = + Vec::with_capacity(compressed_fm_chunk.len() as usize); decompressor.read_to_end(&mut serialized_fm_chunk)?; let fm_chunk: Vec = bincode::deserialize(&serialized_fm_chunk)?; - Ok(Self { counts_so_far: counts, bwt_chunk: fm_chunk }) + Ok(Self { + counts_so_far: counts, + bwt_chunk: fm_chunk, + }) } #[allow(dead_code)] pub fn serialize(&mut self) -> Result, LavaError> { let serialized_counts = bincode::serialize(&self.counts_so_far)?; - let mut compressed_counts = encode_all(&serialized_counts[..], 0).expect("Compression failed"); + let mut compressed_counts = + encode_all(&serialized_counts[..], 0).expect("Compression failed"); let mut result: Vec = vec![]; result.append(&mut (compressed_counts.len() as u64).to_le_bytes().to_vec()); result.append(&mut compressed_counts); let serialized_chunk = bincode::serialize(&self.bwt_chunk)?; - let mut compressed_chunk = encode_all(&serialized_chunk[..], 0).expect("Compression failed"); + let mut compressed_chunk = + encode_all(&serialized_chunk[..], 0).expect("Compression failed"); result.append(&mut compressed_chunk); Ok(result) } diff --git a/src/lava/substring/merge.rs b/src/lava/substring/merge.rs new file mode 100644 index 0000000..ada1f26 --- /dev/null +++ b/src/lava/substring/merge.rs @@ -0,0 +1,389 @@ +use super::constants::*; +use super::fm_chunk::FMChunk; +use crate::formats::readers::{ + get_file_size_and_reader, get_file_sizes_and_readers, AsyncReader, ReaderType, +}; +use zstd::stream::encode_all; + +struct PListIterator { + reader: AsyncReader, + plist_offsets: Vec, + current_chunk_offset: usize, + pub current_chunk: Vec, +} + +impl PListIterator { + // take ownership of the data structures + pub async fn new(mut reader: AsyncReader, plist_offsets: Vec) -> Result { + let plist_chunk = reader + .read_range_and_decompress(plist_offsets[0], plist_offsets[1]) + .await?; + Ok(Self { + reader: reader, + plist_offsets: plist_offsets, + current_chunk_offset: 0, + current_chunk: plist_chunk, + }) + } + + pub async fn advance(&mut self) -> Result<(), LavaError> { + self.current_chunk_offset += 1; + if self.current_chunk_offset + 2 > self.plist_offsets.len() { + return Err(LavaError::Parse("out of chunks".to_string())); + } + self.current_chunk = self + .reader + 
.read_range_and_decompress( + self.plist_offsets[self.current_chunk_offset], + self.plist_offsets[self.current_chunk_offset + 1], + ) + .await?; + Ok(()) + } +} + +struct FMChunkIterator { + reader: AsyncReader, + fm_chunk_offsets: Vec, + current_chunk_offset: usize, + pub current_chunk: FMChunk, +} + +impl FMChunkIterator { + // take ownership of the data structures + pub async fn new( + mut reader: AsyncReader, + fm_chunk_offsets: Vec, + ) -> Result { + let buffer3 = reader + .read_range(fm_chunk_offsets[0], fm_chunk_offsets[1]) + .await?; + let current_chunk = FMChunk::new(buffer3)?; + + Ok(Self { + reader: reader, + fm_chunk_offsets: fm_chunk_offsets, + current_chunk_offset: 0, + current_chunk: current_chunk, + }) + } + + pub async fn advance(&mut self) -> Result<(), LavaError> { + self.current_chunk_offset += 1; + + if self.current_chunk_offset + 2 > self.fm_chunk_offsets.len() { + return Err(LavaError::Parse("out of chunks".to_string())); + } + let buffer3 = self + .reader + .read_range( + self.fm_chunk_offsets[self.current_chunk_offset], + self.fm_chunk_offsets[self.current_chunk_offset + 1], + ) + .await?; + self.current_chunk = FMChunk::new(buffer3)?; + + Ok(()) + } + + pub async fn reset(&mut self) -> Result<(), LavaError> { + self.current_chunk = FMChunk::new( + self.reader + .read_range(self.fm_chunk_offsets[0], self.fm_chunk_offsets[1]) + .await?, + )?; + self.current_chunk_offset = 0; + + Ok(()) + } +} + +async fn compute_interleave( + bwt0_reader: &mut FMChunkIterator, + bwt1_reader: &mut FMChunkIterator, + lens: (usize, usize), + cumulative_counts: &Vec, +) -> Result { + let (bwt0_len, bwt1_len) = lens; + + let mut interleave = BitVec::from_elem(bwt0_len + bwt1_len, true); + for i in 0..bwt0_len { + interleave.set(i, false); + } + + // let mut interleave_iterations = 0; + + for _ in 0..10 { + let mut ind: [usize; 2] = [0, 0]; + + let mut bwt0 = &bwt0_reader.current_chunk.bwt_chunk; + let mut bwt1 = &bwt1_reader.current_chunk.bwt_chunk; + + let mut offsets = cumulative_counts.clone(); + let mut new_interleave = BitVec::from_elem(interleave.len(), false); + for i in 0..interleave.len() { + if interleave[i] { + new_interleave.set(offsets[bwt1[ind[1]] as usize] as usize, true); + offsets[bwt1[ind[1]] as usize] += 1; + ind[1] += 1; + + if ind[1] == bwt1.len() { + // will return an Err for the last chunk, that's ok + let _ = bwt1_reader.advance().await; + bwt1 = &bwt1_reader.current_chunk.bwt_chunk; + ind[1] = 0; + } + } else { + offsets[bwt0[ind[0]] as usize] += 1; + ind[0] += 1; + + if ind[0] == bwt0.len() { + let _ = bwt0_reader.advance().await; + bwt0 = &bwt0_reader.current_chunk.bwt_chunk; + ind[0] = 0; + } + } + } + + bwt0_reader.reset().await?; + bwt1_reader.reset().await?; + + // interleave_iterations += 1; + // println!( + // "{} ", + // interleave_iterations, + // ); + + if new_interleave == interleave { + break; + } + interleave = new_interleave; + } + + // println!("interleave iterations: {}", interleave_iterations); + Ok(interleave) +} + +pub(crate) async fn merge_lava_substring( + condensed_lava_file: &str, + lava_files: Vec, + uid_offsets: Vec, + reader_type: ReaderType, +) -> Result, LavaError> { + // first merge the tokenizer, then merge the fm indices then merge the posting lists. 
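// Illustration only (not part of the diff): compute_interleave above decides, for every
// position of the merged BWT, whether the symbol comes from bwt0 or bwt1, repeating an
// LF-mapping pass until the assignment stabilizes (the crate caps this at 10 passes).
// A simplified in-memory sketch using Vec<bool> and whole BWTs instead of BitVec and the
// chunk iterators; `counts[c]` is the number of symbols smaller than c in the merged text,
// i.e. the same cumulative counts passed to compute_interleave:
fn interleave_bwts(bwt0: &[u8], bwt1: &[u8], counts: &[u64]) -> Vec<bool> {
    let n = bwt0.len() + bwt1.len();
    // start with all of bwt0 first, then all of bwt1
    let mut interleave: Vec<bool> = (0..n).map(|i| i >= bwt0.len()).collect();
    loop {
        let mut offsets: Vec<u64> = counts.to_vec();
        let mut new_interleave = vec![false; n];
        let (mut i0, mut i1) = (0usize, 0usize);
        for &from_bwt1 in interleave.iter() {
            // take the next symbol from whichever BWT the current assignment points at
            let c = if from_bwt1 {
                let c = bwt1[i1];
                i1 += 1;
                c
            } else {
                let c = bwt0[i0];
                i0 += 1;
                c
            };
            if from_bwt1 {
                // its LF-mapped position in the merged BWT belongs to bwt1
                new_interleave[offsets[c as usize] as usize] = true;
            }
            offsets[c as usize] += 1;
        }
        if new_interleave == interleave {
            return interleave;
        }
        interleave = new_interleave;
    }
}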
+ // let mut builder = Fs::default(); + // let current_path = env::current_dir()?; + // builder.root(current_path.to_str().expect("no path")); + // let operator = Operator::new(builder)?.finish(); + + let mut compressed_tokenizer: Option> = None; + + // currently only support merging two files, but can support more in the future. + assert_eq!(lava_files.len(), 2); + assert_eq!(uid_offsets.len(), 2); + + let mut ns: Vec = vec![]; + let mut combined_cumulative_counts: Vec = vec![]; + let mut fm_chunk_iterators: Vec = vec![]; + let mut plist_iterators: Vec = vec![]; + + for file in lava_files { + // @Rain just make two different readers for now because this is hopefully low overhead + // instead of bothering with wrapping this thing in Arc>. Lots of tech debt to clean up + // needed for the FMChunkIterator and PListIterator + let (_, mut reader) = get_file_size_and_reader(file.clone(), reader_type.clone()).await?; + let (file_size, reader1) = + get_file_size_and_reader(file.clone(), reader_type.clone()).await?; + let file_size = file_size as u64; + + let results = reader.read_usize_from_end(4).await?; + let fm_chunk_offsets_offset = results[0]; + let posting_list_offsets_offset = results[1]; + let total_counts_offset = results[2]; + let n = results[3]; + + ns.push(n); + + let compressed_tokenizer_size = reader.read_usize_from_start(0, 1).await?[0]; + let this_compressed_tokenizer: bytes::Bytes = + reader.read_range(8, 8 + compressed_tokenizer_size).await?; + + match &compressed_tokenizer { + Some(value) => assert!( + this_compressed_tokenizer == value, + "detected different tokenizers, cannot merge, something is very wrong." + ), + None => compressed_tokenizer = Some(this_compressed_tokenizer.to_vec()), + } + + let fm_chunk_offsets: Vec = reader + .read_range_and_decompress(fm_chunk_offsets_offset, posting_list_offsets_offset) + .await?; + let posting_list_offsets: Vec = reader + .read_range_and_decompress(posting_list_offsets_offset, total_counts_offset) + .await?; + let cumulative_counts: Vec = reader + .read_range_and_decompress(total_counts_offset, (file_size - 32) as u64) + .await?; + + // println!("{} {}", file, cumulative_counts.len()); + + fm_chunk_iterators.push(FMChunkIterator::new(reader, fm_chunk_offsets).await?); + plist_iterators.push(PListIterator::new(reader1, posting_list_offsets).await?); + + if combined_cumulative_counts.len() == 0 { + combined_cumulative_counts = cumulative_counts; + } else { + // add cumulative_counts to combined_cumulative_counts + for (i, count) in cumulative_counts.iter().enumerate() { + combined_cumulative_counts[i] += count; + } + } + } + + let mut bwt0_reader = fm_chunk_iterators.remove(0); + let mut bwt1_reader = fm_chunk_iterators.remove(0); + let mut plist0_reader = plist_iterators.remove(0); + let mut plist1_reader = plist_iterators.remove(0); + + // let start = std::time::Instant::now(); + let interleave: BitVec = compute_interleave( + &mut bwt0_reader, + &mut bwt1_reader, + (ns[0] as usize, ns[1] as usize), + &combined_cumulative_counts, + ) + .await?; + + let _ = bwt0_reader.reset().await?; + let _ = bwt1_reader.reset().await?; + + // let duration = start.elapsed(); + // println!("interleave time: {:?}", duration); + + let mut output_file = File::create(condensed_lava_file)?; + let compressed_tokenizer = compressed_tokenizer.unwrap(); + output_file.write_all(&(compressed_tokenizer.len() as u64).to_le_bytes())?; + output_file.write_all(&compressed_tokenizer)?; + + let mut bwt_output: Vec = Vec::with_capacity(interleave.len()); + let mut 
index_output: Vec = Vec::with_capacity(interleave.len()); + + let mut bwt_ind0 = 0; + let mut bwt_ind1 = 0; + let mut idx_ind0 = 0; + let mut idx_ind1 = 0; + + let mut bwt0 = &bwt0_reader.current_chunk.bwt_chunk; + let mut bwt1 = &bwt1_reader.current_chunk.bwt_chunk; + let mut idx0 = &plist0_reader.current_chunk; + let mut idx1 = &plist1_reader.current_chunk; + + for i in 0..interleave.len() { + if interleave[i] { + bwt_output.push(bwt1[bwt_ind1]); + index_output.push(idx1[idx_ind1] + uid_offsets[1]); + + bwt_ind1 += 1; + if bwt_ind1 == bwt1.len() { + let _ = bwt1_reader.advance().await; + bwt1 = &bwt1_reader.current_chunk.bwt_chunk; + bwt_ind1 = 0; + } + + idx_ind1 += 1; + if idx_ind1 == idx1.len() { + let _ = plist1_reader.advance().await; + idx1 = &plist1_reader.current_chunk; + idx_ind1 = 0; + } + } else { + bwt_output.push(bwt0[bwt_ind0]); + index_output.push(idx0[idx_ind0] + uid_offsets[0]); + + bwt_ind0 += 1; + if bwt_ind0 == bwt0.len() { + let _ = bwt0_reader.advance().await; + bwt0 = &bwt0_reader.current_chunk.bwt_chunk; + bwt_ind0 = 0; + } + + idx_ind0 += 1; + if idx_ind0 == idx0.len() { + let _ = plist0_reader.advance().await; + idx0 = &plist0_reader.current_chunk; + idx_ind0 = 0; + } + } + } + + let mut current_chunk: Vec = vec![]; + let mut current_chunk_counts: HashMap = HashMap::new(); + let mut next_chunk_counts: HashMap = HashMap::new(); + let mut fm_chunk_offsets: Vec = vec![output_file.seek(SeekFrom::Current(0))? as usize]; + + for i in 0..bwt_output.len() { + let current_tok = bwt_output[i]; + next_chunk_counts + .entry(current_tok) + .and_modify(|count| *count += 1) + .or_insert(1); + current_chunk.push(current_tok); + + if ((i + 1) % FM_CHUNK_TOKS == 0) || i == bwt_output.len() - 1 { + let serialized_counts = bincode::serialize(¤t_chunk_counts)?; + let compressed_counts = + encode_all(&serialized_counts[..], 0).expect("Compression failed"); + output_file.write_all(&(compressed_counts.len() as u64).to_le_bytes())?; + output_file.write_all(&compressed_counts)?; + let serialized_chunk = bincode::serialize(¤t_chunk)?; + let compressed_chunk = + encode_all(&serialized_chunk[..], 0).expect("Compression failed"); + output_file.write_all(&compressed_chunk)?; + fm_chunk_offsets.push(output_file.seek(SeekFrom::Current(0))? as usize); + current_chunk_counts = next_chunk_counts.clone(); + current_chunk = vec![]; + } + } + + let mut posting_list_offsets: Vec = + vec![output_file.seek(SeekFrom::Current(0))? as usize]; + + for i in (0..index_output.len()).step_by(FM_CHUNK_TOKS) { + let slice = &index_output[i..std::cmp::min(index_output.len(), i + FM_CHUNK_TOKS)]; + let serialized_slice = bincode::serialize(slice)?; + let compressed_slice = encode_all(&serialized_slice[..], 0).expect("Compression failed"); + output_file.write_all(&compressed_slice)?; + posting_list_offsets.push(output_file.seek(SeekFrom::Current(0))? as usize); + } + + let cache_start = output_file.seek(SeekFrom::Current(0))? as usize; + + let fm_chunk_offsets_offset = output_file.seek(SeekFrom::Current(0))? as usize; + let serialized_fm_chunk_offsets = bincode::serialize(&fm_chunk_offsets)?; + let compressed_fm_chunk_offsets = + encode_all(&serialized_fm_chunk_offsets[..], 0).expect("Compression failed"); + output_file.write_all(&compressed_fm_chunk_offsets)?; + + let posting_list_offsets_offset = output_file.seek(SeekFrom::Current(0))? 
as usize; + let serialized_posting_list_offsets = bincode::serialize(&posting_list_offsets)?; + let compressed_posting_list_offsets = + encode_all(&serialized_posting_list_offsets[..], 0).expect("Compression failed"); + output_file.write_all(&compressed_posting_list_offsets)?; + + let total_counts_offset = output_file.seek(SeekFrom::Current(0))? as usize; + let serialized_total_counts = bincode::serialize(&combined_cumulative_counts)?; + let compressed_total_counts: Vec = + encode_all(&serialized_total_counts[..], 0).expect("Compression failed"); + output_file.write_all(&compressed_total_counts)?; + + output_file.write_all(&(fm_chunk_offsets_offset as u64).to_le_bytes())?; + output_file.write_all(&(posting_list_offsets_offset as u64).to_le_bytes())?; + output_file.write_all(&(total_counts_offset as u64).to_le_bytes())?; + output_file.write_all(&(bwt_output.len() as u64).to_le_bytes())?; + + Ok(vec![( + cache_start, + output_file.seek(SeekFrom::Current(0))? as usize, + )]) +} diff --git a/src/lava/substring/mod.rs b/src/lava/substring/mod.rs new file mode 100644 index 0000000..b549162 --- /dev/null +++ b/src/lava/substring/mod.rs @@ -0,0 +1,11 @@ +mod constants; +pub(crate) mod fm_chunk; +pub(crate) mod merge; +mod substring; +pub(crate) mod wavelet_tree; +pub(crate) use merge::merge_lava_substring; + +pub(crate) use substring::_build_lava_substring_char; +pub(crate) use substring::_build_lava_substring_char_wavelet; +pub use substring::build_lava_substring; +pub use substring::build_lava_substring_char; diff --git a/src/lava/substring/substring.rs b/src/lava/substring/substring.rs new file mode 100644 index 0000000..90186a2 --- /dev/null +++ b/src/lava/substring/substring.rs @@ -0,0 +1,910 @@ +use super::constants::*; +use super::fm_chunk::FMChunk; +use crate::formats::readers::{get_file_sizes_and_readers, AsyncReader}; +use crate::lava::error::LavaError; + +use crate::lava::substring::wavelet_tree::{construct_wavelet_tree, write_wavelet_tree_to_disk}; +use arrow::array::{make_array, Array, ArrayData, LargeStringArray, UInt64Array}; +use bincode; +use bytes; +use divsufsort::sort_in_place; +use itertools::Itertools; +use serde_json; + +use rayon::prelude::*; +use std::collections::HashMap; +use std::collections::HashSet; +use std::fmt::Debug; +use std::fs::File; +use std::io::Read; +use std::io::{BufWriter, Seek, SeekFrom, Write}; +use tokenizers::parallelism::MaybeParallelIterator; +use tokenizers::tokenizer::Tokenizer; // You'll need the `byteorder` crate +use tokio::task::JoinSet; +use zstd::stream::encode_all; +use zstd::stream::read::Decoder; + +pub async fn _build_lava_substring_char_wavelet( + output_file_name: String, + texts: Vec<(u64, String)>, + char_skip_factor: u32, +) -> Result, LavaError> { + let named_encodings = texts + .into_iter() + .map(|(uid, text)| { + let lower: String = text.chars().flat_map(|c| c.to_lowercase()).collect(); + let result: Vec = if char_skip_factor == 1 { + lower + .chars() + .filter(|id| !SKIP.chars().contains(id)) + .map(|c| c as u8) + .collect() + } else { + lower + .chars() + .filter(|id| !SKIP.chars().contains(id)) + .enumerate() + .filter(|&(index, _)| index % char_skip_factor as usize == 1) + .map(|(_, c)| c as u8) + .collect() + }; + (vec![uid; result.len()], result) + }) + .collect::, Vec)>>(); + + let uids: Vec = named_encodings + .iter() + .map(|(uid, _)| uid) + .flatten() + .cloned() + .collect::>(); + let encodings: Vec = named_encodings + .into_iter() + .map(|(_, text)| text) + .flatten() + .collect::>(); + + let mut sa: Vec = 
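// The merged file (and the non-wavelet substring indexes) ends with four
// little-endian u64 values: fm_chunk_offsets_offset, posting_list_offsets_offset,
// total_counts_offset and n, which is what read_usize_from_end(4) recovers. A rough
// local-file equivalent of that footer read (the real code goes through AsyncReader):
fn read_footer(path: &str) -> std::io::Result<[u64; 4]> {
    use std::io::{Read, Seek, SeekFrom};
    let mut f = std::fs::File::open(path)?;
    f.seek(SeekFrom::End(-32))?;
    let mut buf = [0u8; 32];
    f.read_exact(&mut buf)?;
    let mut out = [0u64; 4];
    for (i, chunk) in buf.chunks_exact(8).enumerate() {
        out[i] = u64::from_le_bytes(chunk.try_into().unwrap());
    }
    Ok(out)
}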
(0..encodings.len() as i32).collect(); + + sort_in_place(&encodings, &mut sa); + + let mut idx: Vec = Vec::with_capacity(encodings.len()); + let mut bwt: Vec = Vec::with_capacity(encodings.len()); + let mut total_counts: Vec = vec![0; 256]; + for i in 0..sa.len() { + let char = if sa[i] == 0 { + encodings[encodings.len() - 1] + } else { + encodings[(sa[i] - 1) as usize] + }; + bwt.push(char); + total_counts[char as usize] += 1; + if sa[i] == 0 { + idx.push(uids[uids.len() - 1]); + } else { + idx.push(uids[(sa[i] - 1) as usize]); + } + } + + let mut cumulative_counts = vec![0; 256]; + cumulative_counts[0] = 0; + for i in 1..256 { + cumulative_counts[i] = cumulative_counts[i - 1] + total_counts[i - 1]; + } + + let wavelet_tree = construct_wavelet_tree(&bwt); + + let mut file = File::create(output_file_name)?; + + let (offsets, level_offsets) = write_wavelet_tree_to_disk(&wavelet_tree, &mut file).unwrap(); + + // print out total file size so far + println!("total file size: {}", file.seek(SeekFrom::Current(0))?); + + let mut posting_list_offsets: Vec = vec![file.seek(SeekFrom::Current(0))? as usize]; + + for i in (0..idx.len()).step_by(FM_CHUNK_TOKS) { + let slice = &idx[i..std::cmp::min(idx.len(), i + FM_CHUNK_TOKS)]; + let serialized_slice = bincode::serialize(slice)?; + let compressed_slice = encode_all(&serialized_slice[..], 0).expect("Compression failed"); + file.write_all(&compressed_slice)?; + posting_list_offsets.push(file.seek(SeekFrom::Current(0))? as usize); + } + + let metadata: (Vec, Vec, Vec, Vec, usize) = ( + offsets, + level_offsets, + posting_list_offsets, + cumulative_counts, + bwt.len(), + ); + + let cache_start = file.seek(SeekFrom::Current(0))? as usize; + + let serialized_metadata = bincode::serialize(&metadata)?; + let compressed_metadata = encode_all(&serialized_metadata[..], 0).expect("Compression failed"); + file.write_all(&compressed_metadata)?; + file.write_all(&cache_start.to_le_bytes())?; + + let cache_end = file.seek(SeekFrom::Current(0))? as usize; + + Ok(vec![(cache_start, cache_end)]) +} + +pub async fn _build_lava_substring_char( + output_file_name: String, + texts: Vec<(u64, String)>, + char_skip_factor: u32, +) -> Result, LavaError> { + let named_encodings = texts + .into_iter() + .map(|(uid, text)| { + let lower: String = text.chars().flat_map(|c| c.to_lowercase()).collect(); + let result: Vec = if char_skip_factor == 1 { + lower + .chars() + .filter(|id| !SKIP.chars().contains(id)) + .map(|c| c as u8) + .collect() + } else { + lower + .chars() + .filter(|id| !SKIP.chars().contains(id)) + .enumerate() + .filter(|&(index, _)| index % char_skip_factor as usize == 1) + .map(|(_, c)| c as u8) + .collect() + }; + (vec![uid; result.len()], result) + }) + .collect::, Vec)>>(); + + let uids: Vec = named_encodings + .iter() + .map(|(uid, _)| uid) + .flatten() + .cloned() + .collect::>(); + let encodings: Vec = named_encodings + .into_iter() + .map(|(_, text)| text) + .flatten() + .collect::>(); + + let mut sa: Vec = (0..encodings.len() as i32).collect(); + + sort_in_place(&encodings, &mut sa); + + let mut idx: Vec = Vec::with_capacity(encodings.len()); + let mut bwt: Vec = Vec::with_capacity(encodings.len()); + for i in 0..sa.len() { + if sa[i] == 0 { + bwt.push(encodings[encodings.len() - 1]); + idx.push(uids[uids.len() - 1]); + } else { + bwt.push(encodings[(sa[i] - 1) as usize]); + idx.push(uids[(sa[i] - 1) as usize]); + } + } + + let mut file = File::create(output_file_name)?; + + let mut fm_chunk_offsets: Vec = vec![file.seek(SeekFrom::Current(0))? 
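// Worked sketch of the BWT construction used here: bwt[i] is the symbol immediately
// preceding the suffix that sorts into position i (wrapping around for suffix 0),
// and the per-position uid is carried along the same permutation so every BWT
// position still knows which row it came from. This mirrors the loop above in a
// standalone form.
fn bwt_with_uids(text: &[u8], uids: &[u64]) -> (Vec<u8>, Vec<u64>) {
    assert!(!text.is_empty());
    assert_eq!(text.len(), uids.len());
    let mut sa: Vec<i32> = (0..text.len() as i32).collect();
    divsufsort::sort_in_place(text, &mut sa); // same suffix-array crate as above
    let n = text.len();
    sa.iter()
        .map(|&s| {
            let p = if s == 0 { n - 1 } else { s as usize - 1 };
            (text[p], uids[p])
        })
        .unzip()
}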
as usize]; + + let mut current_chunk: Vec = vec![]; + let mut current_chunk_counts: HashMap = HashMap::new(); + let mut next_chunk_counts: HashMap = HashMap::new(); + + for i in 0..bwt.len() { + let current_tok = bwt[i]; + next_chunk_counts + .entry(current_tok) + .and_modify(|count| *count += 1) + .or_insert(1); + current_chunk.push(current_tok); + + if ((i + 1) % FM_CHUNK_TOKS == 0) || i == bwt.len() - 1 { + let serialized_counts = bincode::serialize(¤t_chunk_counts)?; + let compressed_counts = + encode_all(&serialized_counts[..], 10).expect("Compression failed"); + println!("chunk size: {}", compressed_counts.len()); + file.write_all(&(compressed_counts.len() as u64).to_le_bytes())?; + file.write_all(&compressed_counts)?; + let serialized_chunk = bincode::serialize(¤t_chunk)?; + let compressed_chunk = + encode_all(&serialized_chunk[..], 10).expect("Compression failed"); + file.write_all(&compressed_chunk)?; + fm_chunk_offsets.push(file.seek(SeekFrom::Current(0))? as usize); + current_chunk_counts = next_chunk_counts.clone(); + current_chunk = vec![]; + } + } + // print out total file size so far + println!("total file size: {}", file.seek(SeekFrom::Current(0))?); + + let mut cumulative_counts: Vec = vec![0]; + for i in 0..256 { + cumulative_counts + .push(cumulative_counts[i] + *current_chunk_counts.get(&(i as u8)).unwrap_or(&0)); + } + + let mut posting_list_offsets: Vec = vec![file.seek(SeekFrom::Current(0))? as usize]; + + for i in (0..idx.len()).step_by(FM_CHUNK_TOKS) { + let slice = &idx[i..std::cmp::min(idx.len(), i + FM_CHUNK_TOKS)]; + let serialized_slice = bincode::serialize(slice)?; + let compressed_slice = encode_all(&serialized_slice[..], 0).expect("Compression failed"); + file.write_all(&compressed_slice)?; + posting_list_offsets.push(file.seek(SeekFrom::Current(0))? as usize); + } + + let cache_start = file.seek(SeekFrom::Current(0))? as usize; + + let fm_chunk_offsets_offset = file.seek(SeekFrom::Current(0))? as usize; + let serialized_fm_chunk_offsets = bincode::serialize(&fm_chunk_offsets)?; + let compressed_fm_chunk_offsets = + encode_all(&serialized_fm_chunk_offsets[..], 0).expect("Compression failed"); + file.write_all(&compressed_fm_chunk_offsets)?; + + let posting_list_offsets_offset = file.seek(SeekFrom::Current(0))? as usize; + let serialized_posting_list_offsets = bincode::serialize(&posting_list_offsets)?; + let compressed_posting_list_offsets = + encode_all(&serialized_posting_list_offsets[..], 0).expect("Compression failed"); + file.write_all(&compressed_posting_list_offsets)?; + + let total_counts_offset = file.seek(SeekFrom::Current(0))? as usize; + let serialized_total_counts = bincode::serialize(&cumulative_counts)?; + let compressed_total_counts: Vec = + encode_all(&serialized_total_counts[..], 0).expect("Compression failed"); + file.write_all(&compressed_total_counts)?; + + file.write_all(&(fm_chunk_offsets_offset as u64).to_le_bytes())?; + file.write_all(&(posting_list_offsets_offset as u64).to_le_bytes())?; + file.write_all(&(total_counts_offset as u64).to_le_bytes())?; + file.write_all(&(bwt.len() as u64).to_le_bytes())?; + + let cache_end = file.seek(SeekFrom::Current(0))? 
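// Sketch of the cumulative count (C) array built above: C[c] is the number of
// symbols strictly smaller than c in the whole BWT, so suffixes starting with c
// occupy the half-open range [C[c], C[c + 1]) of the suffix array. It can be built
// from the final per-symbol totals (which is what current_chunk_counts holds after
// the last chunk, since next_chunk_counts is never reset).
fn c_array(total_counts: &[u64; 256]) -> Vec<u64> {
    let mut c = vec![0u64; 257];
    for i in 0..256 {
        c[i + 1] = c[i] + total_counts[i];
    }
    c
}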
as usize; + + Ok(vec![(cache_start, cache_end)]) +} + +#[tokio::main] +pub async fn build_lava_substring_char( + output_file_name: String, + array: ArrayData, + uid: ArrayData, + char_skip_factor: Option, +) -> Result, LavaError> { + let array = make_array(array); + // let uid = make_array(ArrayData::from_pyarrow(uid)?); + let uid = make_array(uid); + + let char_skip_factor = char_skip_factor.unwrap_or(1); + + let array: &arrow_array::GenericByteArray> = array + .as_any() + .downcast_ref::() + .ok_or(LavaError::Parse( + "Expects string array as first argument".to_string(), + ))?; + + let uid = uid + .as_any() + .downcast_ref::() + .ok_or(LavaError::Parse( + "Expects uint64 array as second argument".to_string(), + ))?; + + if array.len() != uid.len() { + return Err(LavaError::Parse( + "The length of the array and the uid array must be the same".to_string(), + )); + } + + let mut texts: Vec<(u64, String)> = Vec::with_capacity(array.len()); + for i in 0..array.len() { + let text = array.value(i); + texts.push((uid.value(i), text.to_string())); + } + + println!("made it to this point"); + // _build_lava_substring_char(output_file_name, texts, char_skip_factor).await + _build_lava_substring_char_wavelet(output_file_name, texts, char_skip_factor).await +} + +#[tokio::main] +pub async fn build_lava_substring( + output_file_name: String, + array: ArrayData, + uid: ArrayData, + tokenizer_file: Option, + token_skip_factor: Option, +) -> Result, LavaError> { + let array = make_array(array); + // let uid = make_array(ArrayData::from_pyarrow(uid)?); + let uid = make_array(uid); + + let token_skip_factor = token_skip_factor.unwrap_or(1); + + let tokenizer = if let Some(tokenizer_file) = tokenizer_file { + if !std::path::Path::new(&tokenizer_file).exists() { + return Err(LavaError::Parse( + "Tokenizer file does not exist".to_string(), + )); + } + println!("Tokenizer file: {}", tokenizer_file); + Tokenizer::from_file(tokenizer_file).unwrap() + } else { + Tokenizer::from_pretrained("bert-base-uncased", None).unwrap() + }; + + let serialized_tokenizer = serde_json::to_string(&tokenizer).unwrap(); + let compressed_tokenizer = + encode_all(serialized_tokenizer.as_bytes(), 0).expect("Compression failed"); + + let array: &arrow_array::GenericByteArray> = array + .as_any() + .downcast_ref::() + .ok_or(LavaError::Parse( + "Expects string array as first argument".to_string(), + ))?; + + let uid = uid + .as_any() + .downcast_ref::() + .ok_or(LavaError::Parse( + "Expects uint64 array as second argument".to_string(), + ))?; + + if array.len() != uid.len() { + return Err(LavaError::Parse( + "The length of the array and the uid array must be the same".to_string(), + )); + } + + let mut texts: Vec<(u64, &str)> = Vec::with_capacity(array.len()); + for i in 0..array.len() { + let text = array.value(i); + texts.push((uid.value(i), text)); + } + + let mut skip_tokens: HashSet = HashSet::new(); + for char in SKIP.chars() { + let char_str = char.to_string(); + skip_tokens.extend( + tokenizer + .encode(char_str.clone(), false) + .unwrap() + .get_ids() + .to_vec(), + ); + skip_tokens.extend( + tokenizer + .encode(format!(" {}", char_str), false) + .unwrap() + .get_ids() + .to_vec(), + ); + skip_tokens.extend( + tokenizer + .encode(format!("{} ", char_str), false) + .unwrap() + .get_ids() + .to_vec(), + ); + } + + let named_encodings = texts + .into_maybe_par_iter() + .map(|(uid, text)| { + // strip out things in skip in text + + let lower: String = text.chars().flat_map(|c| c.to_lowercase()).collect(); + let encoding = 
tokenizer.encode(lower, false).unwrap(); + let result: Vec = encoding + .get_ids() + .iter() + .filter(|id| !skip_tokens.contains(id)) + .cloned() + .collect(); + (vec![uid; result.len()], result) + }) + .collect::, Vec)>>(); + + let uids: Vec = named_encodings + .iter() + .map(|(uid, _)| uid) + .flatten() + .cloned() + .collect::>(); + let encodings: Vec = named_encodings + .into_iter() + .map(|(_, text)| text) + .flatten() + .collect::>(); + + let mut suffices: Vec> = vec![]; + + let (encodings, uids) = if token_skip_factor > 1 { + let encodings: Vec = encodings + .into_iter() + .enumerate() // Enumerate to get the index and value + .filter(|&(index, _)| index % token_skip_factor as usize == 1) // Keep only elements with odd indices (every second element) + .map(|(_, value)| value) // Extract the value + .collect(); // Collect into a vector + + let uids: Vec = uids + .into_iter() + .enumerate() // Enumerate to get the index and value + .filter(|&(index, _)| index % token_skip_factor as usize == 1) // Keep only elements with odd indices (every second element) + .map(|(_, value)| value) // Extract the value + .collect(); + (encodings, uids) + } else { + (encodings, uids) + }; + + for i in 10..encodings.len() { + suffices.push(encodings[i - 10..i].to_vec()); + } + + for i in encodings.len()..encodings.len() + 10 { + let mut suffix = encodings[i - 10..encodings.len()].to_vec(); + suffix.append(&mut vec![0; i - encodings.len()]); + suffices.push(suffix); + } + + let mut sa: Vec = (0..suffices.len()).collect(); + + sa.par_sort_by(|&a, &b| suffices[a].cmp(&suffices[b])); + + let mut idx: Vec = Vec::with_capacity(encodings.len()); + let mut bwt: Vec = Vec::with_capacity(encodings.len()); + for i in 0..sa.len() { + if sa[i] == 0 { + bwt.push(encodings[encodings.len() - 1]); + idx.push(uids[uids.len() - 1]); + } else { + bwt.push(encodings[(sa[i] - 1) as usize]); + idx.push(uids[(sa[i] - 1) as usize]); + } + } + + let mut file = File::create(output_file_name)?; + file.write_all(&(compressed_tokenizer.len() as u64).to_le_bytes())?; + file.write_all(&compressed_tokenizer)?; + + let mut fm_chunk_offsets: Vec = vec![file.seek(SeekFrom::Current(0))? as usize]; + + let mut current_chunk: Vec = vec![]; + let mut current_chunk_counts: HashMap = HashMap::new(); + let mut next_chunk_counts: HashMap = HashMap::new(); + + for i in 0..bwt.len() { + let current_tok = bwt[i]; + next_chunk_counts + .entry(current_tok) + .and_modify(|count| *count += 1) + .or_insert(1); + current_chunk.push(current_tok); + + if ((i + 1) % FM_CHUNK_TOKS == 0) || i == bwt.len() - 1 { + let serialized_counts = bincode::serialize(¤t_chunk_counts)?; + let compressed_counts = + encode_all(&serialized_counts[..], 10).expect("Compression failed"); + + file.write_all(&(compressed_counts.len() as u64).to_le_bytes())?; + file.write_all(&compressed_counts)?; + let serialized_chunk = bincode::serialize(¤t_chunk)?; + let compressed_chunk = + encode_all(&serialized_chunk[..], 10).expect("Compression failed"); + file.write_all(&compressed_chunk)?; + + fm_chunk_offsets.push(file.seek(SeekFrom::Current(0))? 
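// Sketch of the bounded-depth suffix sort used below: every suffix is represented by
// only its first `window` tokens (zero-padded near the end of the text), which caps
// comparison cost and is sufficient as long as queries are shorter than the window
// (10 tokens above); ties beyond the window are left unresolved. The builder
// materializes the windows once, this sketch recomputes them per comparison to stay
// short.
fn approximate_suffix_array(tokens: &[u32], window: usize) -> Vec<usize> {
    use rayon::prelude::*;
    let key = |i: usize| -> Vec<u32> {
        let end = usize::min(i + window, tokens.len());
        let mut k = tokens[i..end].to_vec();
        k.resize(window, 0); // zero padding, as above
        k
    };
    let mut sa: Vec<usize> = (0..tokens.len()).collect();
    sa.par_sort_by(|&a, &b| key(a).cmp(&key(b)));
    sa
}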
as usize); + current_chunk_counts = next_chunk_counts.clone(); + current_chunk = vec![]; + } + } + // print out total file size so far + println!("total file size: {}", file.seek(SeekFrom::Current(0))?); + + let mut cumulative_counts: Vec = vec![0]; + for i in 0..tokenizer.get_vocab_size(false) { + cumulative_counts + .push(cumulative_counts[i] + *current_chunk_counts.get(&(i as u32)).unwrap_or(&0)); + } + + let mut posting_list_offsets: Vec = vec![file.seek(SeekFrom::Current(0))? as usize]; + + for i in (0..idx.len()).step_by(FM_CHUNK_TOKS) { + let slice = &idx[i..std::cmp::min(idx.len(), i + FM_CHUNK_TOKS)]; + let serialized_slice = bincode::serialize(slice)?; + let compressed_slice = encode_all(&serialized_slice[..], 0).expect("Compression failed"); + file.write_all(&compressed_slice)?; + posting_list_offsets.push(file.seek(SeekFrom::Current(0))? as usize); + } + + let cache_start = file.seek(SeekFrom::Current(0))? as usize; + + let fm_chunk_offsets_offset = file.seek(SeekFrom::Current(0))? as usize; + let serialized_fm_chunk_offsets = bincode::serialize(&fm_chunk_offsets)?; + let compressed_fm_chunk_offsets = + encode_all(&serialized_fm_chunk_offsets[..], 0).expect("Compression failed"); + file.write_all(&compressed_fm_chunk_offsets)?; + + let posting_list_offsets_offset = file.seek(SeekFrom::Current(0))? as usize; + let serialized_posting_list_offsets = bincode::serialize(&posting_list_offsets)?; + let compressed_posting_list_offsets = + encode_all(&serialized_posting_list_offsets[..], 0).expect("Compression failed"); + file.write_all(&compressed_posting_list_offsets)?; + + let total_counts_offset = file.seek(SeekFrom::Current(0))? as usize; + let serialized_total_counts = bincode::serialize(&cumulative_counts)?; + let compressed_total_counts: Vec = + encode_all(&serialized_total_counts[..], 0).expect("Compression failed"); + file.write_all(&compressed_total_counts)?; + + file.write_all(&(fm_chunk_offsets_offset as u64).to_le_bytes())?; + file.write_all(&(posting_list_offsets_offset as u64).to_le_bytes())?; + file.write_all(&(total_counts_offset as u64).to_le_bytes())?; + file.write_all(&(bwt.len() as u64).to_le_bytes())?; + + let cache_end = file.seek(SeekFrom::Current(0))? 
as usize; + + Ok(vec![(cache_start, cache_end)]) +} + +use num_traits::{AsPrimitive, PrimInt, Unsigned}; +use serde::{Deserialize, Serialize}; +use std::ops::Add; + +async fn process_substring_query( + query: Vec, + n: u64, + fm_chunk_offsets: &[u64], + cumulative_counts: &[u64], + posting_list_offsets: &[u64], + reader: &mut AsyncReader, + file_id: u64, +) -> Vec<(u64, u64)> +where + T: PrimInt + + Unsigned + + Serialize + + for<'de> Deserialize<'de> + + Clone + + Eq + + std::hash::Hash + + AsPrimitive + + 'static, + usize: AsPrimitive, +{ + let mut res: Vec<(u64, u64)> = vec![]; + let mut start: usize = 0; + let mut end: usize = n as usize; + + for i in (0..query.len()).rev() { + let current_token = query[i]; + + let start_byte = fm_chunk_offsets[start / FM_CHUNK_TOKS]; + let end_byte = fm_chunk_offsets[start / FM_CHUNK_TOKS + 1]; + let start_chunk = reader.read_range(start_byte, end_byte).await.unwrap(); + + let start_byte = fm_chunk_offsets[end / FM_CHUNK_TOKS]; + let end_byte = fm_chunk_offsets[end / FM_CHUNK_TOKS + 1]; + let end_chunk = reader.read_range(start_byte, end_byte).await.unwrap(); + + start = cumulative_counts[current_token.as_()] as usize + + FMChunk::::new(start_chunk) + .unwrap() + .search(current_token, start % FM_CHUNK_TOKS) + .unwrap() as usize; + end = cumulative_counts[current_token.as_()] as usize + + FMChunk::::new(end_chunk) + .unwrap() + .search(current_token, end % FM_CHUNK_TOKS) + .unwrap() as usize; + + if start >= end { + return res; + } + + if end <= start + 2 { + break; + } + } + + let start_offset = posting_list_offsets[start / FM_CHUNK_TOKS]; + let end_offset = posting_list_offsets[end / FM_CHUNK_TOKS + 1]; + let total_chunks = end / FM_CHUNK_TOKS - start / FM_CHUNK_TOKS + 1; + + let plist_chunks = reader.read_range(start_offset, end_offset).await.unwrap(); + + let mut chunk_set = JoinSet::new(); + + for i in 0..total_chunks { + let this_start = posting_list_offsets[start / FM_CHUNK_TOKS + i]; + let this_end = posting_list_offsets[start / FM_CHUNK_TOKS + i + 1]; + let this_chunk = plist_chunks + [(this_start - start_offset) as usize..(this_end - start_offset) as usize] + .to_vec(); + + chunk_set.spawn(async move { + let mut decompressor = Decoder::new(&this_chunk[..]).unwrap(); + let mut serialized_plist_chunk: Vec = Vec::with_capacity(this_chunk.len()); + decompressor + .read_to_end(&mut serialized_plist_chunk) + .unwrap(); + let plist_chunk: Vec = bincode::deserialize(&serialized_plist_chunk).unwrap(); + + let chunk_res: Vec<(u64, u64)> = if i == 0 { + if total_chunks == 1 { + plist_chunk[start % FM_CHUNK_TOKS..end % FM_CHUNK_TOKS] + .iter() + .map(|&uid| (file_id, uid)) + .collect() + } else { + plist_chunk[start % FM_CHUNK_TOKS..] 
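// In-memory sketch of the backward search that process_substring_query performs:
// the suffix-array range [start, end) of suffixes beginning with the current query
// suffix is refined one symbol at a time, right to left, via
// start = C[c] + rank_c(bwt, start) and end = C[c] + rank_c(bwt, end).
// The real code computes rank_c from the stored FM chunks instead of scanning the
// whole BWT.
fn backward_search(bwt: &[u8], c_array: &[u64], query: &[u8]) -> (usize, usize) {
    let rank = |c: u8, pos: usize| bwt[..pos].iter().filter(|&&x| x == c).count();
    let (mut start, mut end) = (0usize, bwt.len());
    for &c in query.iter().rev() {
        start = c_array[c as usize] as usize + rank(c, start);
        end = c_array[c as usize] as usize + rank(c, end);
        if start >= end {
            break; // the query suffix does not occur
        }
    }
    (start, end)
}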
+ .iter() + .map(|&uid| (file_id, uid)) + .collect() + } + } else if i == total_chunks - 1 { + plist_chunk[..end % FM_CHUNK_TOKS] + .iter() + .map(|&uid| (file_id, uid)) + .collect() + } else { + plist_chunk.iter().map(|&uid| (file_id, uid)).collect() + }; + + chunk_res + }); + } + + while let Some(chunk_res) = chunk_set.join_next().await { + res.extend(chunk_res.unwrap()); + } + + res +} + +use super::wavelet_tree::search_wavelet_tree_from_reader; +use crate::formats::readers::read_and_decompress; + +async fn search_substring_wavelet_one_file( + file_id: u64, + mut reader: AsyncReader, + file_size: usize, + queries: Vec>, +) -> Result, LavaError> { + println!("{:?}", queries); + + let metadata_start = reader.read_usize_from_end(1).await?[0]; + + let metadata: (Vec, Vec, Vec, Vec, usize) = read_and_decompress( + &mut reader, + metadata_start as u64, + file_size as u64 - metadata_start - 8, + ) + .await + .unwrap(); + let (offsets, level_offsets, posting_list_offsets, cumulative_counts, n) = metadata; + + // let mut query_set = JoinSet::new(); + + let mut res: Vec<(u64, u64)> = vec![]; + + for query in queries { + let mut reader = reader.clone(); + let (start, end) = search_wavelet_tree_from_reader( + &mut reader, + &query, + n, + &offsets, + &level_offsets, + &cumulative_counts, + ) + .await?; + + println!("{} {}", start, end); + + if start == end { + continue; + } + + let start_offset = posting_list_offsets[start / FM_CHUNK_TOKS]; + let end_offset = posting_list_offsets[end / FM_CHUNK_TOKS + 1]; + let total_chunks = end / FM_CHUNK_TOKS - start / FM_CHUNK_TOKS + 1; + + let plist_chunks = reader + .read_range(start_offset as u64, end_offset as u64) + .await + .unwrap(); + + let mut chunk_set = JoinSet::new(); + + for i in 0..total_chunks { + let this_start = posting_list_offsets[start / FM_CHUNK_TOKS + i]; + let this_end = posting_list_offsets[start / FM_CHUNK_TOKS + i + 1]; + let this_chunk = plist_chunks + [(this_start - start_offset) as usize..(this_end - start_offset) as usize] + .to_vec(); + + chunk_set.spawn(async move { + let mut decompressor = Decoder::new(&this_chunk[..]).unwrap(); + let mut serialized_plist_chunk: Vec = Vec::with_capacity(this_chunk.len()); + decompressor + .read_to_end(&mut serialized_plist_chunk) + .unwrap(); + let plist_chunk: Vec = bincode::deserialize(&serialized_plist_chunk).unwrap(); + + let chunk_res: Vec<(u64, u64)> = if i == 0 { + if total_chunks == 1 { + plist_chunk[start % FM_CHUNK_TOKS..end % FM_CHUNK_TOKS] + .iter() + .map(|&uid| (file_id, uid)) + .collect() + } else { + plist_chunk[start % FM_CHUNK_TOKS..] 
+ .iter() + .map(|&uid| (file_id, uid)) + .collect() + } + } else if i == total_chunks - 1 { + plist_chunk[..end % FM_CHUNK_TOKS] + .iter() + .map(|&uid| (file_id, uid)) + .collect() + } else { + plist_chunk.iter().map(|&uid| (file_id, uid)).collect() + }; + + chunk_res + }); + } + + while let Some(chunk_res) = chunk_set.join_next().await { + res.extend(chunk_res.unwrap()); + } + } + + // let mut res = Vec::new(); + // while let Some(query_res) = query_set.join_next().await { + // res.extend(query_res.unwrap()); + // } + + Ok(res) +} + +async fn search_substring_one_file( + file_id: u64, + mut reader: AsyncReader, + file_size: usize, + queries: Vec>, +) -> Result, LavaError> +where + T: PrimInt + + Unsigned + + Serialize + + for<'de> Deserialize<'de> + + Clone + + Eq + + std::hash::Hash + + AsPrimitive + + Debug + + Send + + 'static, + usize: AsPrimitive, +{ + println!("{:?}", queries); + + let results = reader.read_usize_from_end(4).await?; + let fm_chunk_offsets_offset = results[0]; + let posting_list_offsets_offset = results[1]; + let total_counts_offset = results[2]; + let n = results[3]; + + let fm_chunk_offsets: Vec = reader + .read_range_and_decompress(fm_chunk_offsets_offset, posting_list_offsets_offset) + .await?; + let posting_list_offsets: Vec = reader + .read_range_and_decompress(posting_list_offsets_offset, total_counts_offset) + .await?; + let cumulative_counts: Vec = reader + .read_range_and_decompress(total_counts_offset, (file_size - 32) as u64) + .await?; + + let mut query_set = JoinSet::new(); + + for query in queries { + let fm_chunk_offsets = fm_chunk_offsets.clone(); + let cumulative_counts = cumulative_counts.clone(); + let posting_list_offsets = posting_list_offsets.clone(); + let mut reader = reader.clone(); + + query_set.spawn(async move { + process_substring_query::( + query, + n, + &fm_chunk_offsets, + &cumulative_counts, + &posting_list_offsets, + &mut reader, + file_id, + ) + .await + }); + } + + let mut res = Vec::new(); + while let Some(query_res) = query_set.join_next().await { + res.extend(query_res.unwrap()); + } + Ok(res) +} + +pub async fn _search_lava_substring_char( + files: Vec, + query: String, + k: usize, + reader_type: ReaderType, + token_viable_limit: Option, + sample_factor: Option, + wavelet_tree: bool, +) -> Result, LavaError> { + let lower: String = query.chars().flat_map(|c| c.to_lowercase()).collect(); + let result: Vec = lower + .chars() + .filter(|id| !SKIP.chars().contains(id)) + .map(|c| c as u8) + .collect(); + + let mut query: Vec> = if let Some(sample_factor) = sample_factor { + (0..sample_factor) + .map(|offset| { + result + .iter() + .skip(offset) + .step_by(sample_factor) + .cloned() + .collect::>() + }) + .filter(|vec| !vec.is_empty()) + .collect() + } else { + vec![result] + }; + + // println!("query {:?}", query); + + // query = [i[-token_viable_limit:] for i in query] + if let Some(token_viable_limit) = token_viable_limit { + query.iter_mut().for_each(|vec| { + if vec.len() > token_viable_limit { + *vec = vec + .iter() + .rev() + .take(token_viable_limit) + .rev() + .cloned() + .collect(); + } + }); + } + + // println!("query {:?}", query); + + let (file_sizes, readers) = get_file_sizes_and_readers(&files, reader_type).await?; + search_generic_async( + file_sizes, + readers, + if wavelet_tree { + QueryParam::SubstringCharWavelet(query) + } else { + QueryParam::SubstringChar(query) + }, + k, + ) + .await +} diff --git a/src/lava/wavelet_tree.rs b/src/lava/substring/wavelet_tree.rs similarity index 100% rename from 
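// Sketch of the query preprocessing in _search_lava_substring_char: lowercase the
// query, drop SKIP characters, and, if the index was built with a sampling factor,
// split the query into that many strided sub-queries so at least one lines up with
// the sampled text; finally keep only the last `token_viable_limit` symbols of each
// sub-query, matching the logic above.
fn preprocess_query(
    query: &str,
    skip: &str,
    sample_factor: Option<usize>,
    token_viable_limit: Option<usize>,
) -> Vec<Vec<u8>> {
    let cleaned: Vec<u8> = query
        .chars()
        .flat_map(|c| c.to_lowercase())
        .filter(|c| !skip.contains(*c))
        .map(|c| c as u8)
        .collect();
    let mut queries: Vec<Vec<u8>> = match sample_factor {
        Some(f) => (0..f)
            .map(|off| cleaned.iter().skip(off).step_by(f).cloned().collect())
            .filter(|v: &Vec<u8>| !v.is_empty())
            .collect(),
        None => vec![cleaned],
    };
    if let Some(limit) = token_viable_limit {
        for q in queries.iter_mut() {
            if q.len() > limit {
                *q = q[q.len() - limit..].to_vec();
            }
        }
    }
    queries
}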
src/lava/wavelet_tree.rs rename to src/lava/substring/wavelet_tree.rs diff --git a/src/lava/uuid/mod.rs b/src/lava/uuid/mod.rs new file mode 100644 index 0000000..53e6c7a --- /dev/null +++ b/src/lava/uuid/mod.rs @@ -0,0 +1,4 @@ +mod trie; +mod uuid; +pub use uuid::build_lava_uuid; +pub(crate) use uuid::merge_lava_uuid; diff --git a/src/lava/trie.rs b/src/lava/uuid/trie.rs similarity index 100% rename from src/lava/trie.rs rename to src/lava/uuid/trie.rs diff --git a/src/lava/uuid/uuid.rs b/src/lava/uuid/uuid.rs new file mode 100644 index 0000000..8ddb8bf --- /dev/null +++ b/src/lava/uuid/uuid.rs @@ -0,0 +1,94 @@ +use crate::{formats::readers::get_file_size_and_reader, lava::error::LavaError}; + +use super::trie::{BinaryTrieNode, FastTrie}; +use arrow::array::{make_array, Array, ArrayData, LargeStringArray, UInt64Array}; + +#[tokio::main] +pub async fn build_lava_uuid( + output_file_name: String, + array: ArrayData, + uid: ArrayData, +) -> Result, LavaError> { + let array = make_array(array); + // let uid = make_array(ArrayData::from_pyarrow(uid)?); + let uid = make_array(uid); + let array: &arrow_array::GenericByteArray> = array + .as_any() + .downcast_ref::() + .ok_or(LavaError::Parse( + "Expects string array as first argument".to_string(), + ))?; + + let uid: &arrow_array::PrimitiveArray = uid + .as_any() + .downcast_ref::() + .ok_or(LavaError::Parse( + "Expects uint64 array as second argument".to_string(), + ))?; + + if array.len() != uid.len() { + return Err(LavaError::Parse( + "The length of the array and the uid array must be the same".to_string(), + )); + } + + let mut texts = Vec::with_capacity(array.len()); + for i in 0..array.len() { + let text = array.value(i); + texts.push(text.as_bytes().to_vec()); + } + let mut inds = Vec::with_capacity(array.len()); + for i in 0..uid.len() { + inds.push(vec![uid.value(i) as usize]); + } + + let root = BinaryTrieNode::build(&texts, &inds); + let fast_trie = FastTrie::new(root, Some(16)); + let (serialized_fast_trie, (cache_start, cache_end)) = fast_trie.serialize(); + std::fs::write(output_file_name, serialized_fast_trie).unwrap(); + + Ok(vec![(cache_start, cache_end)]) +} + +pub(crate) async fn merge_lava_uuid( + condensed_lava_file: &str, + lava_files: Vec, + uid_offsets: Vec, + reader_type: ReaderType, +) -> Result, LavaError> { + // currently only support merging two files, but can support more in the future. 
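// Note on the merge below: the active path, FastTrie::extend_with_readers_into_file,
// streams both serialized tries through their readers, grafts the second trie onto
// the first (presumably shifting stored row ids by uid_offsets[0] / uid_offsets[1],
// mirroring the in-memory FastTrie::extend call kept in the commented-out block),
// and writes the merged trie directly into condensed_lava_file without holding both
// tries in memory at once.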
+ assert_eq!(lava_files.len(), 2); + assert_eq!(uid_offsets.len(), 2); + + let (file_size1, mut reader1) = + get_file_size_and_reader(lava_files[0].clone(), reader_type.clone()).await?; + let (file_size2, mut reader2) = + get_file_size_and_reader(lava_files[1].clone(), reader_type.clone()).await?; + + // let buffer: bytes::Bytes = reader1.read_range(0, file_size1 as u64).await?; + // let mut fast_trie1 = FastTrie::deserialize(buffer.to_vec()); + // let buffer: bytes::Bytes = reader2.read_range(0, file_size2 as u64).await?; + // let mut fast_trie2 = FastTrie::deserialize(buffer.to_vec()); + + // fast_trie1.extend( + // &mut fast_trie2, + // uid_offsets[0] as usize, + // uid_offsets[1] as usize, + // ); + // let (serialized, (cache_start, cache_end)) = fast_trie1.serialize(); + // let mut output_file = File::create(condensed_lava_file)?; + // output_file.write(&serialized)?; + + let (cache_start, cache_end) = FastTrie::extend_with_readers_into_file( + file_size1, + &mut reader1, + file_size2, + &mut reader2, + condensed_lava_file, + uid_offsets[0] as usize, + uid_offsets[1] as usize, + ) + .await?; + + Ok(vec![(cache_start, cache_end)]) +} diff --git a/src/lava/vector/mod.rs b/src/lava/vector/mod.rs new file mode 100644 index 0000000..5119f4c --- /dev/null +++ b/src/lava/vector/mod.rs @@ -0,0 +1,3 @@ +mod vector; + +pub use vector::search_lava_vector; diff --git a/src/lava/vector/vector.rs b/src/lava/vector/vector.rs new file mode 100644 index 0000000..ad32f2b --- /dev/null +++ b/src/lava/vector/vector.rs @@ -0,0 +1,231 @@ +use ndarray::{concatenate, stack, Array1, Array2, Axis}; + +use crate::formats::readers::{get_file_sizes_and_readers, get_reader}; +fn bytes_to_f32_vec(bytes: &[u8]) -> Vec { + let mut vec = Vec::with_capacity(bytes.len() / 4); + let mut i = 0; + while i < bytes.len() { + let value = LittleEndian::read_f32(&bytes[i..i + 4]); + vec.push(value); + i += 4; + } + vec +} +async fn search_lava_vector_async( + files: Vec, + query: Vec, + nprobes: usize, + reader_type: ReaderType, +) -> Result<(Vec, Vec>, Vec<(usize, Array1)>), LavaError> { + let start = Instant::now(); + + let (_, mut readers) = get_file_sizes_and_readers(&files, reader_type.clone()).await?; + + let mut futures = Vec::new(); + + for _ in 0..readers.len() { + let mut reader = readers.remove(0); + + futures.push(tokio::spawn(async move { + let results = reader.read_usize_from_end(4).await.unwrap(); + + let centroid_vectors_compressed_bytes = + reader.read_range(results[2], results[3]).await.unwrap(); + + // decompress them + let mut decompressor = + Decoder::new(centroid_vectors_compressed_bytes.as_ref()).unwrap(); + let mut centroid_vectors: Vec = + Vec::with_capacity(centroid_vectors_compressed_bytes.len() as usize); + decompressor.read_to_end(&mut centroid_vectors).unwrap(); + + let centroid_vectors = bytes_to_f32_vec(¢roid_vectors); + let num_vectors = centroid_vectors.len() / 128; + let array2 = + Array2::::from_shape_vec((num_vectors, 128), centroid_vectors).unwrap(); + + (num_vectors, array2) + })); + } + + let result: Vec), tokio::task::JoinError>> = + futures::future::join_all(futures).await; + + let end = Instant::now(); + println!("Time stage 1 read: {:?}", end - start); + + let start = Instant::now(); + + let arr_lens = result + .iter() + .map(|x| x.as_ref().unwrap().0) + .collect::>(); + // get cumulative arr len starting from 0 + let cumsum = arr_lens + .iter() + .scan(0, |acc, &x| { + *acc += x; + Some(*acc) + }) + .collect::>(); + + let arrays: Vec> = result.into_iter().map(|x| 
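// Sketch of the stage-1 centroid decoding above: the decompressed centroid section
// is a flat little-endian f32 buffer that gets reinterpreted as a
// (num_centroids, 128) matrix; 128 is the vector dimension this file assumes.
fn centroid_matrix(raw: &[u8]) -> ndarray::Array2<f32> {
    let floats: Vec<f32> = raw
        .chunks_exact(4)
        .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
        .collect();
    let rows = floats.len() / 128;
    ndarray::Array2::from_shape_vec((rows, 128), floats)
        .expect("centroid byte length must be a multiple of 4 * 128")
}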
x.unwrap().1).collect(); + let centroids = concatenate( + Axis(0), + arrays + .iter() + .map(|array| array.view()) + .collect::>() + .as_slice(), + ) + .unwrap(); + let query = Array1::::from_vec(query); + let query_broadcast = query.broadcast(centroids.dim()).unwrap(); + + let difference = ¢roids - &query_broadcast; + let norms = difference.map_axis(Axis(1), |row| row.dot(&row).sqrt()); + let mut indices_and_values: Vec<(usize, f32)> = norms + .iter() + .enumerate() + .map(|(idx, &val)| (idx, val)) + .collect(); + + indices_and_values.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal)); + let smallest_indices: Vec = indices_and_values + .iter() + .map(|&(idx, _)| idx) + .take(nprobes) + .collect(); + + let mut file_indices: Vec> = vec![vec![]; files.len()]; + for idx in smallest_indices.iter() { + // figure out which file idx based on cumsum. need to find the index of the thing that is just bigger than idx + + let file_idx = cumsum + .iter() + .enumerate() + .find(|(_, &val)| val > *idx) + .unwrap() + .0; + let last_cumsum = if file_idx == 0 { + 0 + } else { + cumsum[file_idx - 1] + }; + let remainder = idx - last_cumsum; + file_indices[file_idx].push(remainder); + } + + let end = Instant::now(); + println!("Time math: {:?}", end - start); + + let start = Instant::now(); + + let (_, mut readers) = get_file_sizes_and_readers(&files, reader_type.clone()).await?; + + let mut file_ids = vec![]; + let mut futures = Vec::new(); + + for file_id in 0..readers.len() { + let mut reader = readers.remove(0); + if file_indices[file_id].len() == 0 { + continue; + } + let my_idx: Vec = file_indices[file_id].clone(); + file_ids.push(file_id); + + futures.push(tokio::spawn(async move { + let results = reader.read_usize_from_end(4).await.unwrap(); + + let pq_bytes = reader.read_range(results[0], results[1]).await.unwrap(); + + let compressed_centroid_offset_bytes = + reader.read_range(results[1], results[2]).await.unwrap(); + let mut decompressor = Decoder::new(compressed_centroid_offset_bytes.as_ref()).unwrap(); + let mut centroid_offsets_bytes: Vec = + Vec::with_capacity(compressed_centroid_offset_bytes.len() as usize); + decompressor + .read_to_end(&mut centroid_offsets_bytes) + .unwrap(); + + // now reinterpret centroid_offsets_bytes as a Vec + + let mut centroid_offsets = Vec::with_capacity(centroid_offsets_bytes.len() / 8); + let mut cursor = Cursor::new(centroid_offsets_bytes); + + while cursor.position() < cursor.get_ref().len() as u64 { + let value = cursor.read_u64::().unwrap(); + centroid_offsets.push(value); + } + + let mut this_result: Vec<(usize, u64, u64)> = vec![]; + + for idx in my_idx.iter() { + this_result.push((file_id, centroid_offsets[*idx], centroid_offsets[*idx + 1])); + } + (this_result, Array1::::from_vec(pq_bytes.to_vec())) + })); + } + + let result: Vec, Array1), tokio::task::JoinError>> = + futures::future::join_all(futures).await; + let result: Vec<(Vec<(usize, u64, u64)>, Array1)> = + result.into_iter().map(|x| x.unwrap()).collect(); + + let pq_bytes: Vec> = result.iter().map(|x| x.1.clone()).collect::>(); + + let end = Instant::now(); + println!("Time stage 2 read: {:?}", end - start); + + let start = Instant::now(); + let reader = get_reader(files[file_ids[0]].clone(), reader_type.clone()) + .await + .unwrap(); + + let mut futures = FuturesUnordered::new(); + for i in 0..result.len() { + let to_read = result[i].0.clone(); + for (file_id, start, end) in to_read.into_iter() { + let mut reader_c = reader.clone(); + 
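// Sketch of the probe selection math above: compute the L2 distance from the query
// to every centroid row, keep the `nprobes` closest, then map each selected global
// centroid index back to (file index, local centroid index) through the running
// per-file centroid counts in `cumsum`.
fn closest_centroids(
    centroids: &ndarray::Array2<f32>,
    query: &ndarray::Array1<f32>,
    nprobes: usize,
) -> Vec<usize> {
    use ndarray::Axis;
    let mut dists: Vec<(usize, f32)> = centroids
        .axis_iter(Axis(0))
        .enumerate()
        .map(|(i, row)| {
            let diff = &row - query;
            (i, diff.dot(&diff).sqrt())
        })
        .collect();
    dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
    dists.into_iter().take(nprobes).map(|(i, _)| i).collect()
}

fn locate_centroid(cumsum: &[usize], global_idx: usize) -> (usize, usize) {
    let file_idx = cumsum.iter().position(|&c| c > global_idx).expect("index out of range");
    let base = if file_idx == 0 { 0 } else { cumsum[file_idx - 1] };
    (file_idx, global_idx - base)
}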
reader_c.update_filename(files[file_id].clone()).unwrap();
+
+            futures.push(tokio::spawn(async move {
+                let start_time = Instant::now();
+                let codes_and_plist = reader_c.read_range(start, end).await.unwrap();
+                // println!(
+                //     "Time to read {:?}, {:?}",
+                //     Instant::now() - start_time,
+                //     codes_and_plist.len()
+                // );
+                (file_id, Array1::<u8>::from_vec(codes_and_plist.to_vec()))
+            }));
+        }
+    }
+
+    let mut ranges: Vec<(usize, Array1<u8>)> = vec![];
+
+    while let Some(x) = futures.next().await {
+        ranges.push(x.unwrap());
+    }
+
+    let end = Instant::now();
+    println!("Time stage 3 read: {:?}", end - start);
+
+    Ok((file_ids, pq_bytes, ranges))
+}
+
+pub fn search_lava_vector(
+    files: Vec<String>,
+    query: Vec<f32>,
+    nprobes: usize,
+    reader_type: ReaderType,
+) -> Result<(Vec<usize>, Vec<Array1<u8>>, Vec<(usize, Array1<u8>)>), LavaError> {
+    let rt = tokio::runtime::Builder::new_multi_thread()
+        .enable_all()
+        .build()
+        .unwrap();
+
+    let res = rt.block_on(search_lava_vector_async(files, query, nprobes, reader_type));
+    rt.shutdown_background();
+    res
+}
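// Hypothetical usage of the public entry point above (the file names and the
// all-zero 128-dimensional query are placeholders for illustration; the caller
// supplies whatever ReaderType variant matches where the .lava files live).
fn example(reader_type: ReaderType) -> Result<(), LavaError> {
    let files = vec!["part0.lava".to_string(), "part1.lava".to_string()];
    let (file_ids, pq_codes, ranges) = search_lava_vector(files, vec![0.0f32; 128], 8, reader_type)?;
    // file_ids: which files contributed probed centroids; pq_codes: one PQ codebook
    // blob per such file; ranges: (file id, raw codes-and-posting-list bytes) per probe.
    println!("{} files probed, {} centroid ranges fetched", file_ids.len(), ranges.len());
    let _ = pq_codes;
    Ok(())
}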