Skip to content

Commit

Permalink
[ENH] Change DataChunk to Chunk and make Generic
Browse files Browse the repository at this point in the history
  • Loading branch information
HammadB committed Apr 26, 2024
1 parent ef24ce2 commit 648629e
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 37 deletions.
26 changes: 13 additions & 13 deletions rust/worker/src/execution/data/data_chunk.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
use crate::types::LogRecord;
use std::sync::Arc;

#[derive(Clone, Debug)]
pub(crate) struct DataChunk {
data: Arc<[LogRecord]>,
pub(crate) struct Chunk<T> {
data: Arc<[T]>,
visibility: Arc<[bool]>,
}

impl DataChunk {
pub fn new(data: Arc<[LogRecord]>) -> Self {
impl<T> Chunk<T> {
pub fn new(data: Arc<[T]>) -> Self {
let len = data.len();
DataChunk {
Chunk {
data,
visibility: vec![true; len].into(),
}
Expand All @@ -30,7 +29,7 @@ impl DataChunk {
/// if the index is out of bounds, it returns None
/// # Arguments
/// * `index` - The index of the element
pub fn get(&self, index: usize) -> Option<&LogRecord> {
pub fn get(&self, index: usize) -> Option<&T> {
if index < self.data.len() {
Some(&self.data[index])
} else {
Expand Down Expand Up @@ -69,21 +68,21 @@ impl DataChunk {
/// The iterator returns a tuple of the element and its index
/// # Returns
/// An iterator over the visible elements in the data chunk
pub fn iter(&self) -> DataChunkIteraror<'_> {
pub fn iter(&self) -> DataChunkIteraror<'_, T> {
DataChunkIteraror {
chunk: self,
index: 0,
}
}
}

pub(crate) struct DataChunkIteraror<'a> {
chunk: &'a DataChunk,
pub(crate) struct DataChunkIteraror<'a, T> {
chunk: &'a Chunk<T>,
index: usize,
}

impl<'a> Iterator for DataChunkIteraror<'a> {
type Item = (&'a LogRecord, usize);
impl<'a, T> Iterator for DataChunkIteraror<'a, T> {
type Item = (&'a T, usize);

fn next(&mut self) -> Option<Self::Item> {
while self.index < self.chunk.total_len() {
Expand All @@ -108,6 +107,7 @@ impl<'a> Iterator for DataChunkIteraror<'a> {
#[cfg(test)]
mod tests {
use super::*;
use crate::types::LogRecord;
use crate::types::Operation;
use crate::types::OperationRecord;

Expand Down Expand Up @@ -136,7 +136,7 @@ mod tests {
},
];
let data = data.into();
let mut chunk = DataChunk::new(data);
let mut chunk = Chunk::new(data);
assert_eq!(chunk.len(), 2);
let mut iter = chunk.iter();
let elem = iter.next();
Expand Down
13 changes: 7 additions & 6 deletions rust/worker/src/execution/operators/brute_force_knn.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::execution::data::data_chunk::DataChunk;
use crate::execution::data::data_chunk::Chunk;
use crate::types::LogRecord;
use crate::{distance::DistanceFunction, execution::operator::Operator};
use async_trait::async_trait;
use std::cmp::Ordering;
Expand All @@ -20,7 +21,7 @@ pub struct BruteForceKnnOperator {}
/// * `distance_metric` - The distance metric to use.
#[derive(Debug)]
pub struct BruteForceKnnOperatorInput {
pub data: DataChunk,
pub data: Chunk<LogRecord>,
pub query: Vec<f32>,
pub k: usize,
pub distance_metric: DistanceFunction,
Expand All @@ -35,7 +36,7 @@ pub struct BruteForceKnnOperatorInput {
/// One row for each query vector.
#[derive(Debug)]
pub struct BruteForceKnnOperatorOutput {
pub data: DataChunk,
pub data: Chunk<LogRecord>,
pub indices: Vec<usize>,
pub distances: Vec<f32>,
}
Expand Down Expand Up @@ -172,7 +173,7 @@ mod tests {
},
},
];
let data_chunk = DataChunk::new(data.into());
let data_chunk = Chunk::new(data.into());

let input = BruteForceKnnOperatorInput {
data: data_chunk,
Expand Down Expand Up @@ -230,7 +231,7 @@ mod tests {
},
},
];
let data_chunk = DataChunk::new(data.into());
let data_chunk = Chunk::new(data.into());

let input = BruteForceKnnOperatorInput {
data: data_chunk,
Expand Down Expand Up @@ -264,7 +265,7 @@ mod tests {
},
}];

let data_chunk = DataChunk::new(data.into());
let data_chunk = Chunk::new(data.into());

let input = BruteForceKnnOperatorInput {
data: data_chunk,
Expand Down
23 changes: 14 additions & 9 deletions rust/worker/src/execution/operators/partition.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::errors::{ChromaError, ErrorCodes};
use crate::execution::data::data_chunk::DataChunk;
use crate::execution::data::data_chunk::Chunk;
use crate::execution::operator::Operator;
use crate::types::LogRecord;
use async_trait::async_trait;
use std::collections::HashMap;
use thiserror::Error;
Expand All @@ -18,7 +19,7 @@ pub struct PartitionOperator {}
/// * `records` - The records to partition.
#[derive(Debug)]
pub struct PartitionInput {
pub(crate) records: DataChunk,
pub(crate) records: Chunk<LogRecord>,
pub(crate) max_partition_size: usize,
}

Expand All @@ -29,7 +30,7 @@ impl PartitionInput {
/// * `max_partition_size` - The maximum size of a partition. Because we are trying to
/// partition the records by id, the actual partition size can be larger than this
/// value.
pub fn new(records: DataChunk, max_partition_size: usize) -> Self {
pub fn new(records: Chunk<LogRecord>, max_partition_size: usize) -> Self {
PartitionInput {
records,
max_partition_size,
Expand All @@ -42,7 +43,7 @@ impl PartitionInput {
/// * `records` - The partitioned records.
#[derive(Debug)]
pub struct PartitionOutput {
pub(crate) records: Vec<DataChunk>,
pub(crate) records: Vec<Chunk<LogRecord>>,
}

#[derive(Debug, Error)]
Expand All @@ -66,7 +67,11 @@ impl PartitionOperator {
Box::new(PartitionOperator {})
}

pub fn partition(&self, records: &DataChunk, partition_size: usize) -> Vec<DataChunk> {
pub fn partition(
&self,
records: &Chunk<LogRecord>,
partition_size: usize,
) -> Vec<Chunk<LogRecord>> {
let mut map = HashMap::new();
for data in records.iter() {
let log_record = data.0;
Expand Down Expand Up @@ -174,23 +179,23 @@ mod tests {
let data: Arc<[LogRecord]> = data.into();

// Test group size is larger than the number of records
let chunk = DataChunk::new(data.clone());
let chunk = Chunk::new(data.clone());
let operator = PartitionOperator::new();
let input = PartitionInput::new(chunk, 4);
let result = operator.run(&input).await.unwrap();
assert_eq!(result.records.len(), 1);
assert_eq!(result.records[0].len(), 3);

// Test group size is the same as the number of records
let chunk = DataChunk::new(data.clone());
let chunk = Chunk::new(data.clone());
let operator = PartitionOperator::new();
let input = PartitionInput::new(chunk, 3);
let result = operator.run(&input).await.unwrap();
assert_eq!(result.records.len(), 1);
assert_eq!(result.records[0].len(), 3);

// Test group size is smaller than the number of records
let chunk = DataChunk::new(data.clone());
let chunk = Chunk::new(data.clone());
let operator = PartitionOperator::new();
let input = PartitionInput::new(chunk, 2);
let mut result = operator.run(&input).await.unwrap();
Expand All @@ -206,7 +211,7 @@ mod tests {
}

// Test group size is smaller than the number of records
let chunk = DataChunk::new(data.clone());
let chunk = Chunk::new(data.clone());
let operator = PartitionOperator::new();
let input = PartitionInput::new(chunk, 1);
let mut result = operator.run(&input).await.unwrap();
Expand Down
11 changes: 6 additions & 5 deletions rust/worker/src/execution/operators/pull_log.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::execution::data::data_chunk::DataChunk;
use crate::execution::data::data_chunk::Chunk;
use crate::execution::operator::Operator;
use crate::log::log::Log;
use crate::log::log::PullLogsError;
use crate::types::LogRecord;
use async_trait::async_trait;
use tracing::debug;
use tracing::trace;
Expand Down Expand Up @@ -66,21 +67,21 @@ impl PullLogsInput {
/// The output of the pull logs operator.
#[derive(Debug)]
pub struct PullLogsOutput {
logs: DataChunk,
logs: Chunk<LogRecord>,
}

impl PullLogsOutput {
/// Create a new pull logs output.
/// # Parameters
/// * `logs` - The logs that were read.
pub fn new(logs: DataChunk) -> Self {
pub fn new(logs: Chunk<LogRecord>) -> Self {
PullLogsOutput { logs }
}

/// Get the log entries that were read by an invocation of the pull logs operator.
/// # Returns
/// The log entries that were read.
pub fn logs(&self) -> DataChunk {
pub fn logs(&self) -> Chunk<LogRecord> {
self.logs.clone()
}
}
Expand Down Expand Up @@ -138,7 +139,7 @@ impl Operator<PullLogsInput, PullLogsOutput> for PullLogsOperator {
trace!("Truncated log records {:?}", result);
}
// Convert to Chunk
let data_chunk = DataChunk::new(result.into());
let data_chunk = Chunk::new(result.into());
Ok(PullLogsOutput::new(data_chunk))
}
}
Expand Down
9 changes: 5 additions & 4 deletions rust/worker/src/execution/orchestration/compact.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use super::super::operator::{wrap, TaskMessage};
use crate::compactor::CompactionJob;
use crate::errors::ChromaError;
use crate::execution::data::data_chunk::DataChunk;
use crate::execution::data::data_chunk::Chunk;
use crate::execution::operators::flush_sysdb::FlushSysDbInput;
use crate::execution::operators::flush_sysdb::FlushSysDbOperator;
use crate::execution::operators::flush_sysdb::FlushSysDbResult;
use crate::execution::operators::partition;
use crate::execution::operators::partition::PartitionInput;
use crate::execution::operators::partition::PartitionOperator;
use crate::execution::operators::partition::PartitionResult;
Expand All @@ -18,8 +19,8 @@ use crate::system::Component;
use crate::system::Handler;
use crate::system::Receiver;
use crate::system::System;
use crate::types::LogRecord;
use crate::types::SegmentFlushInfo;
use arrow::compute::kernels::partition;
use async_trait::async_trait;
use std::time::SystemTime;
use std::time::UNIX_EPOCH;
Expand Down Expand Up @@ -130,7 +131,7 @@ impl CompactOrchestrator {

async fn partition(
&mut self,
records: DataChunk,
records: Chunk<LogRecord>,
self_address: Box<dyn Receiver<PartitionResult>>,
) {
self.state = ExecutionState::Partition;
Expand All @@ -147,7 +148,7 @@ impl CompactOrchestrator {
}
}

async fn write(&mut self, partitions: Vec<DataChunk>) {
async fn write(&mut self, partitions: Vec<Chunk<LogRecord>>) {
self.state = ExecutionState::Write;

self.num_write_tasks = partitions.len() as i32;
Expand Down

0 comments on commit 648629e

Please sign in to comment.