Commit 968f9e0
marsupialtail committed Aug 22, 2024
1 parent 61f896b commit 968f9e0
Showing 7 changed files with 298 additions and 154 deletions.
22 changes: 12 additions & 10 deletions demo.py
@@ -8,12 +8,12 @@
# print(result)

def substring_test():
internal.index_files_substring(["example_data/a.parquet"], "text", "index0", token_skip_factor = 10)
internal.index_files_substring(["example_data/b.parquet"], "text", "index1", token_skip_factor = 10)
internal.merge_index_substring("merged_index", ["index0", "index1"])
# internal.index_files_substring(["example_data/a.parquet"], "text", "index0", token_skip_factor = 10)
# internal.index_files_substring(["example_data/b.parquet"], "text", "index1", token_skip_factor = 10)
# internal.merge_index_substring("merged_index", ["index0", "index1"])
result = internal.search_index_substring(["index0"],
"One step you have to remember not to skip is to use Disk Utility to partition the SSD as GUID partition scheme HFS+ before doing the clone.",
K = 10, sample_factor = 10)
K = 10, token_viable_limit= 1, sample_factor = 10)
print(result)

# table1 = polars.read_parquet("uuid_data/a.parquet")
@@ -23,16 +23,18 @@ def substring_test():
# write_deltalake("uuid_data_delta", table2.to_arrow(), mode = "append", engine = 'rust', writer_properties = WriterProperties(data_page_size_limit=1000000, compression = 'ZSTD'))
# rottnest.index_delta("uuid_data_delta", "hashes", "uuid_rottnest_index", "uuid")


def uuid_test():
internal.index_files_uuid(["uuid_data/a.parquet"], "hashes", "index0")
internal.index_files_uuid(["uuid_data/b.parquet"], "hashes", "index1")
internal.merge_index_uuid("merged_index_new", ["index0", "index1"])
result = internal.search_index_uuid(["merged_index_new"], "93b9f88dd22cb168cbc45000fcb05042cd1fc4b5602a56e70383fa26d33d21b08d004d78a7c97a463331da2da64e88f5546367e16e5fd2539bb9b8796ffffc7f", K = 10)
print(result)
# internal.index_files_uuid(["uuid_data/a.parquet"], "hashes", "index0")
# internal.index_files_uuid(["uuid_data/b.parquet"], "hashes", "index1")

internal.index_files_uuid(["s3://txhashesbenchmark/0.parquet"], "hashes", "index0")
internal.index_files_uuid(["s3://txhashesbenchmark/1.parquet"], "hashes", "index1")

internal.merge_index_uuid("merged_index", ["index0", "index1"])
result = internal.search_index_uuid(["merged_index"], "93b9f88dd22cb168cbc45000fcb05042cd1fc4b5602a56e70383fa26d33d21b08d004d78a7c97a463331da2da64e88f5546367e16e5fd2539bb9b8796ffffc7f", K = 10)
print(result)

substring_test()

# result = rottnest.search_index_uuid(["merged_index"], "650243a9024fe6595fa953e309c722c225cb2fae1f70c74364917eb901bcdce1f9a878d22345a8576a201646b6da815ebd6397cfd313447ee3a548259f63825a", K = 10)
# print(result)
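For orientation, a minimal end-to-end sketch of the substring workflow this demo exercises after the change, including the new token_viable_limit knob. The import path and parameter values are illustrative, and it assumes the two example parquet files exist locally:

from rottnest import internal

# Build one substring index per parquet file, then merge them into a single index.
internal.index_files_substring(["example_data/a.parquet"], "text", "index0", token_skip_factor = 10)
internal.index_files_substring(["example_data/b.parquet"], "text", "index1", token_skip_factor = 10)
internal.merge_index_substring("merged_index", ["index0", "index1"])

# Query the merged index. K caps the number of hits; token_viable_limit and
# sample_factor are forwarded to the Rust search_lava_substring kernel.
result = internal.search_index_substring(["merged_index"], "Disk Utility to partition the SSD",
                                         K = 10, token_viable_limit = 1, sample_factor = 10)
print(result)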
72 changes: 50 additions & 22 deletions python/rottnest/internal.py
@@ -27,19 +27,19 @@ def index_files_bm25(file_paths: list[str], column_name: str, name = uuid.uuid4(
file_data = file_data.replace_schema_metadata({"cache_ranges": json.dumps(cache_ranges)})
pq.write_table(file_data, f"{name}.meta", write_statistics = False, compression = 'zstd')

def index_files_substring(file_paths: list[str], column_name: str, name = uuid.uuid4().hex, index_mode = "physical", tokenizer_file = None, token_skip_factor = None):
def index_files_substring(file_paths: list[str], column_name: str, name = uuid.uuid4().hex, index_mode = "physical", tokenizer_file = None, token_skip_factor = None, remote = None):

arr, uid, file_data = get_physical_layout(file_paths, column_name) if index_mode == "physical" else get_virtual_layout(file_paths, column_name, "uid")
arr, uid, file_data = get_physical_layout(file_paths, column_name, remote = remote) if index_mode == "physical" else get_virtual_layout(file_paths, column_name, "uid", remote = remote)

cache_ranges = rottnest.build_lava_substring(f"{name}.lava", arr, uid, tokenizer_file, token_skip_factor)

file_data = file_data.to_arrow()
file_data = file_data.replace_schema_metadata({"cache_ranges": json.dumps(cache_ranges)})
pq.write_table(file_data, f"{name}.meta", write_statistics = False, compression = 'zstd')

def index_files_uuid(file_paths: list[str], column_name: str, name = uuid.uuid4().hex, index_mode = "physical"):
def index_files_uuid(file_paths: list[str], column_name: str, name = uuid.uuid4().hex, index_mode = "physical", remote = None):

arr, uid, file_data = get_physical_layout(file_paths, column_name) if index_mode == "physical" else get_virtual_layout(file_paths, column_name, "uid")
arr, uid, file_data = get_physical_layout(file_paths, column_name, remote = remote) if index_mode == "physical" else get_virtual_layout(file_paths, column_name, "uid", remote = remote)

idx = pac.sort_indices(arr)
arr = arr.take(idx)
@@ -51,7 +51,9 @@ def index_files_uuid(file_paths: list[str], column_name: str, name = uuid.uuid4(
file_data = file_data.replace_schema_metadata({"cache_ranges": json.dumps(cache_ranges)})
pq.write_table(file_data, f"{name}.meta", write_statistics = False, compression = 'zstd')

def index_files_vector(file_paths: list[str], column_name: str, name = uuid.uuid4().hex, dtype = 'f32', index_mode = "physical", gpu = False):


def index_files_vector(file_paths: list[str], column_name: str, name = uuid.uuid4().hex, dtype = 'f32', index_mode = "physical", gpu = False, remote = None):

try:
import faiss
@@ -64,7 +66,7 @@ def index_files_vector(file_paths: list[str], column_name: str, name = uuid.uuid
assert dtype == 'f32'
dtype_size = 4

arr, uid, file_data = get_physical_layout(file_paths, column_name, type = "binary") if index_mode == "physical" else get_virtual_layout(file_paths, column_name, "uid", type = "binary")
arr, uid, file_data = get_physical_layout(file_paths, column_name, type = "binary", remote = remote) if index_mode == "physical" else get_virtual_layout(file_paths, column_name, "uid", type = "binary", remote = remote)
uid = uid.to_numpy()

# arr will be a array of largebinary, we need to convert it into numpy, time for some arrow ninja
@@ -75,9 +77,8 @@ def index_files_vector(file_paths: list[str], column_name: str, name = uuid.uuid
dim = diffs.item() // dtype_size
x = np.frombuffer(buffers[2], dtype = np.float32).reshape(len(arr), dim)

num_centroids = 1000 # len(arr) // 10_000
num_centroids = len(arr) // 10_000

# kmeans = faiss.Kmeans(128, len(arr) // 10_000, niter=30, verbose=True, gpu = gpu)
kmeans = faiss.Kmeans(128,num_centroids, niter=30, verbose=True, gpu = gpu)
kmeans.train(x)
centroids = kmeans.centroids
@@ -91,21 +92,45 @@ def index_files_vector(file_paths: list[str], column_name: str, name = uuid.uuid
posting_lists = [[] for _ in range(num_centroids)]
codes_lists = [[] for _ in range(num_centroids)]

for i in tqdm(range(len(arr) // batch_size)):
batch = x[i * batch_size:(i + 1) * batch_size]
if gpu:

res = faiss.StandardGpuResources()
d = centroids.shape[1]
index = faiss.GpuIndexFlatL2(res, d)
index.add(centroids.astype('float32'))

# Process batches
for i in tqdm(range(len(arr) // batch_size)):
batch = x[i * batch_size:(i + 1) * batch_size].astype('float32')
k = 20
distances, indices = index.search(batch, k)

# The indices are already sorted by distance, so we don't need to sort again
closest_centroids = indices[:, 0]

distances = -np.sum(centroids ** 2, axis=1, keepdims=True).T + 2 * np.dot(batch, centroids.T)
indices = np.argpartition(-distances, kth=20, axis=1)[:, :20]
sorted_indices = np.argsort(-distances[np.arange(distances.shape[0])[:, None], indices], axis=1)
indices = indices[np.arange(indices.shape[0])[:, None], sorted_indices]
for k in range(batch_size):
# TODO: this uses UID! Just a warning. because gemv is fast even on lowly CPUs for final reranking.
posting_lists[closest_centroids[k]].append(uid[i * batch_size + k])
codes_lists[closest_centroids[k]].append(codes[i * batch_size + k])

closest_centroids = list(indices[:,0])
# closest2_centroids = list(indices[:,1])

else:
for i in tqdm(range(len(arr) // batch_size)):
batch = x[i * batch_size:(i + 1) * batch_size]

distances = -np.sum(centroids ** 2, axis=1, keepdims=True).T + 2 * np.dot(batch, centroids.T)
indices = np.argpartition(-distances, kth=20, axis=1)[:, :20]
sorted_indices = np.argsort(-distances[np.arange(distances.shape[0])[:, None], indices], axis=1)
indices = indices[np.arange(indices.shape[0])[:, None], sorted_indices]

closest_centroids = list(indices[:,0])
# closest2_centroids = list(indices[:,1])

for k in range(batch_size):
# TODO: this uses UID! Just a warning. because gemv is fast even on lowly CPUs for final reranking.
posting_lists[closest_centroids[k]].append(uid[i * batch_size + k])
codes_lists[closest_centroids[k]].append(codes[i * batch_size + k])

for k in range(batch_size):
# TODO: this uses UID! Just a warning. because gemv is fast even on lowly CPUs for final reranking.
posting_lists[closest_centroids[k]].append(uid[i * batch_size + k])
codes_lists[closest_centroids[k]].append(codes[i * batch_size + k])

f = open(f"{name}.lava", "wb")
centroid_offsets = [0]
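The new gpu branch above assigns each vector to its nearest centroid with a faiss GpuIndexFlatL2, while the else branch keeps the original numpy shortcut: because ||x - c||^2 = ||x||^2 - 2<x, c> + ||c||^2 and ||x||^2 is constant for a given row, ranking centroids by 2<x, c> - ||c||^2 is the same as ranking them by distance. A self-contained sketch of that CPU path (shapes assumed: batch is (B, d), centroids is (C, d), with k < C):

import numpy as np

def nearest_centroids(batch: np.ndarray, centroids: np.ndarray, k: int = 20):
    # Score every (row, centroid) pair; a higher score means a smaller L2 distance.
    scores = -np.sum(centroids ** 2, axis = 1, keepdims = True).T + 2 * np.dot(batch, centroids.T)
    # Take the k best columns per row without a full sort, then order just those k.
    top_k = np.argpartition(-scores, kth = k, axis = 1)[:, :k]
    order = np.argsort(-scores[np.arange(scores.shape[0])[:, None], top_k], axis = 1)
    top_k = top_k[np.arange(top_k.shape[0])[:, None], order]
    return top_k[:, 0], top_k  # closest centroid per row, plus the ranked top-k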
@@ -241,15 +266,18 @@ def search_index_uuid(indices: List[str], query: str, K: int, columns = []):
return return_full_result(result, metadata, column_name, columns)


def search_index_substring(indices: List[str], query: str, K: int, sample_factor = None, columns = []):
def search_index_substring(indices: List[str], query: str, K: int, sample_factor = None, token_viable_limit = 1, columns = []):

metadata = get_metadata_and_populate_cache(indices)

index_search_results = rottnest.search_lava_substring([f"{index_name}.lava" for index_name in indices], query, K, "aws", sample_factor = sample_factor)
index_search_results = rottnest.search_lava_substring([f"{index_name}.lava" for index_name in indices], query, K, "aws", sample_factor = sample_factor, token_viable_limit = token_viable_limit)
print(index_search_results)

if len(index_search_results) == 0:
return None

if len(index_search_results) > 10000:
return "Brute Force Please"

result, column_name, metadata = get_result_from_index_result(metadata, index_search_results)
result = polars.from_arrow(result).filter(polars.col(column_name).str.to_lowercase().str.contains(query.lower(), literal=True))
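With this change search_index_substring has three possible outcomes: None when the index returns no candidates, the sentinel string "Brute Force Please" when more than 10,000 candidates come back, and otherwise a polars table filtered to rows that actually contain the query. A hedged caller-side sketch; the fallback scan and its file glob are hypothetical, not part of the library:

import polars
from rottnest import internal

query = "Disk Utility to partition the SSD"
result = internal.search_index_substring(["merged_index"], query, K = 10, token_viable_limit = 1)
if result is None:
    print("no candidate rows in the index")
elif isinstance(result, str):  # the "Brute Force Please" sentinel
    # Too many candidates for pointwise reads, so scan the source parquet files directly.
    result = polars.scan_parquet("example_data/*.parquet").filter(
        polars.col("text").str.to_lowercase().str.contains(query.lower(), literal = True)
    ).collect()
print(result)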
45 changes: 29 additions & 16 deletions python/rottnest/table.py
@@ -1,10 +1,5 @@
from deltalake import DeltaTable, write_deltalake
import duckdb
import pyarrow
import polars
from deltalake._internal import TableNotFoundError
import uuid

from . import internal
"""
the schema of the metadata table will be:
@@ -16,6 +11,9 @@

def index_delta(table: str, column: str, index_dir: str, type: str, index_impl = 'physical', extra_configs = {}):

from deltalake import DeltaTable, write_deltalake
from deltalake._internal import TableNotFoundError

assert type in {"uuid", "substring", "vector"}
assert index_impl in {"physical", "virtual"}
if index_impl == "virtual":
@@ -62,30 +60,45 @@ def index_delta(table: str, column: str, index_dir: str, type: str, index_impl =
# we should deprecate the type argument, and figure out the type automatically.
def search_delta(table: str, index_dir: str, query, type: str, K: int, snapshot : int | None = None, extra_configs = {}):

from deltalake import DeltaTable, write_deltalake
from deltalake._internal import TableNotFoundError

assert type in {"uuid", "substring", "vector"}

main_table = DeltaTable(table)
if snapshot is not None:
main_table.load_as_version(snapshot)

existing_parquet_files = polars.from_dict({"covered_parquet_files": main_table.file_uris()})
existing_parquet_files = set(main_table.file_uris())

index_dir = index_dir.rstrip("/")
metadata_table_dir = f"{index_dir}/metadata_table"

selected_indices = []
uncovered_parquets = set(existing_parquet_files)

try:
metadata = polars.from_arrow(DeltaTable(metadata_table_dir).to_pyarrow_table())
metadata = metadata.explode('covered_parquet_files')
# convert it to a dictionary
metadata = {row['index_file']: row['covered_parquet_files'] for _, row in metadata.iterrows()}

while True:
overlap_to_index_file = {len(existing_parquet_files.intersection(v)): k for k, v in metadata.items()}
max_overlap = max(overlap_to_index_file.keys())
if max_overlap == 0:
break
index_file = overlap_to_index_file[max_overlap]
selected_indices.append(index_file)
uncovered_parquets = uncovered_parquets.difference(set(metadata[index_file]))
del metadata[index_file]

except TableNotFoundError:
# brute force
pass




# we want to figure out the minimal number of index files that cover the existing parquet files to the best of our abilities.



# if type == "uuid":
# index_search_results = internal.search_lava_uuid([f"{index_name}.lava" for index_name in indices], query, K, "aws")
if type == "uuid":
index_search_results = internal.search_lava_uuid([f"{index_name}.lava" for index_name in selected_indices], query, K, "aws")
elif type == "substring":
index_search_results = internal.search_lava_substring([f"{index_name}.lava" for index_name in selected_indices], query, K, "aws")
elif type == "vector":
index_search_results = internal.search_lava_vector([f"{index_name}.lava" for index_name in selected_indices], query, K, "aws")
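The loop above implements a greedy cover: each round it picks the index file with the largest overlap, removes it from the candidate pool, and stops once no candidate overlaps any longer, leaving the still-uncovered parquet files for brute force. A compact sketch of that selection, assuming metadata maps each index_file to its list of covered_parquet_files and scoring candidates against the still-uncovered files:

def select_index_files(existing_parquet_files: set, metadata: dict) -> tuple:
    # Greedy set cover: repeatedly take the index file that covers the most
    # still-uncovered parquet files; whatever remains must be brute-forced.
    selected, uncovered = [], set(existing_parquet_files)
    while uncovered and metadata:
        best = max(metadata, key = lambda idx: len(uncovered & set(metadata[idx])))
        if not uncovered & set(metadata[best]):
            break
        selected.append(best)
        uncovered -= set(metadata[best])
        del metadata[best]
    return selected, uncovered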
10 changes: 5 additions & 5 deletions python/rottnest/utils.py
@@ -58,7 +58,7 @@ def read_parquet_file(file, row_group, row_nr):
return pyarrow.concat_tables(results)


def get_physical_layout(file_paths: list[str], column_name: str, type = "str"):
def get_physical_layout(file_paths: list[str], column_name: str, type = "str", remote = None):

assert type in {"str", "binary"}

@@ -80,7 +80,7 @@ def get_physical_layout(file_paths: list[str], column_name: str, type = "str"):

metadata = polars.from_dict({
"uid": np.arange(len(data_page_num_rows) + 1),
"file_path": [file_path] * (len(data_page_num_rows) + 1),
"file_path": [file_path if remote is None else remote + file_path] * (len(data_page_num_rows) + 1),
"column_name": [column_name] * (len(data_page_num_rows) + 1),
# TODO: figure out a better way to handle this. Currently this is definitely not a bottleneck. Write ampl factor is almost 10x
# writing just one row followed by a bunch of Nones don't help, likely because it's already smart enough to do dict encoding.
@@ -105,7 +105,7 @@ def get_physical_layout(file_paths: list[str], column_name: str, type = "str"):

return pyarrow.concat_arrays(all_arrs), pyarrow.array(all_uids.astype(np.uint64)), polars.concat(metadatas)

def get_virtual_layout(file_paths: list[str], column_name: str, key_column_name: str, type = "str", stride = 500):
def get_virtual_layout(file_paths: list[str], column_name: str, key_column_name: str, type = "str", stride = 500, remote = None):

fs = get_fs_from_file_path(file_paths[0])
metadatas = []
@@ -119,7 +119,7 @@ def get_virtual_layout(file_paths: list[str], column_name: str, key_column_name:
arr = table[column_name].to_arrow().cast(pyarrow.large_string() if type == 'str' else pyarrow.large_binary())
uid = table['__uid__'].to_arrow().cast(pyarrow.uint64())

metadata = table.groupby("__uid__").agg([polars.col(key_column_name).min().alias("min"), polars.col(key_column_name).max().alias("max")]).sort("__uid__")
metadata = table.group_by("__uid__").agg([polars.col(key_column_name).min().alias("min"), polars.col(key_column_name).max().alias("max")]).sort("__uid__")

return arr, uid, metadata

@@ -164,7 +164,7 @@ def get_result_from_index_result(metadata: polars.DataFrame, index_search_result
def return_full_result(result: polars.DataFrame, metadata: polars.DataFrame, column_name: str, columns: List[str]):
if columns != []:
result = result.join(metadata.select(["__metadata_key__", "file_path", "row_groups"]), on = "__metadata_key__", how = "left")
grouped = result.groupby(["file_path", "row_groups"]).agg([polars.col('__metadata_key__'), polars.col('__row_group_rownr__')])
grouped = result.group_by(["file_path", "row_groups"]).agg([polars.col('__metadata_key__'), polars.col('__row_group_rownr__')])
collected_results = polars.from_arrow(read_columns(grouped["file_path"].to_list(), grouped["row_groups"].to_list(), grouped["__row_group_rownr__"].to_list()))
unnested_metadata_key = grouped['__metadata_key__'].explode()
unnested_row_group_rownr = grouped['__row_group_rownr__'].explode()
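The remote argument threaded through internal.py lands here: when set, it is prepended to every file path recorded in the index metadata, so an index built from locally staged copies can point later searches at the remote objects. A hedged usage sketch; the bucket path is hypothetical and the actual behaviour is simply remote + file_path as shown above:

from rottnest import internal

# Index a locally staged parquet file, but record its object-store location in the
# .meta file so that searches later read from S3 rather than from local disk.
internal.index_files_uuid(["0.parquet"], "hashes", "index0", remote = "s3://my-bucket/data/")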
18 changes: 7 additions & 11 deletions src/lava/build.rs
@@ -523,8 +523,6 @@ pub async fn build_lava_substring(
})
.collect::<Vec<(Vec<u64>, Vec<u32>)>>();

println!("{:?}", named_encodings[1]);

let uids: Vec<u64> = named_encodings
.iter()
.map(|(uid, _)| uid)
@@ -558,8 +556,6 @@
(encodings, uids)
};

println!("{:?}", &encodings[..200]);

for i in 10..encodings.len() {
suffices.push(encodings[i - 10..i].to_vec());
}
@@ -598,16 +594,16 @@

// write out the bwt to a numpy array

let file = File::create("output.bin")?;
let mut writer = BufWriter::new(file);
// let file = File::create("output.bin")?;
// let mut writer = BufWriter::new(file);

// Write each u32 to the file as bytes
for number in bwt.iter() {
writer.write_u32::<LittleEndian>(*number)?;
}
// // Write each u32 to the file as bytes
// for number in bwt.iter() {
// writer.write_u32::<LittleEndian>(*number)?;
// }

// Flush the buffer to ensure all data is written to the file
writer.flush()?;
// writer.flush()?;

let mut file = File::create(output_file_name)?;
file.write_all(&(compressed_tokenizer.len() as u64).to_le_bytes())?;
