Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Added example.
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao committed Aug 9, 2021
1 parent 1ada073 commit b5a9c5f
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 0 deletions.
10 changes: 10 additions & 0 deletions examples/s3/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[package]
name = "s3"
version = "0.1.0"
edition = "2018"

[dependencies]
arrow2 = { path = "../../", default-features = false, features = ["io_parquet", "io_parquet_compression"] }
rust-s3 = { version = "0.27.0-rc4", features = ["tokio"] }
futures = "0.3"
tokio = { version = "1.0.0", features = ["macros", "rt-multi-thread"] }
65 changes: 65 additions & 0 deletions examples/s3/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
use arrow2::array::{Array, Int64Array};
use arrow2::datatypes::DataType;
use arrow2::error::Result;
use arrow2::io::parquet::read::{
decompress, get_page_stream, page_stream_to_array, read_metadata_async,
};
use futures::{future::BoxFuture, StreamExt};
use s3::Bucket;

mod stream;
use stream::{RangedStreamer, SeekOutput};

/// Reads the first column of the first row group of a parquet file stored in a
/// public S3 bucket, downloading only the byte ranges requested by the reader.
#[tokio::main]
async fn main() -> Result<()> {
    let bucket_name = "dev-jorgecardleitao";
    let region = "eu-central-1".parse().unwrap();
    let bucket = Bucket::new_public(bucket_name, region).unwrap();
    let path = "benches_65536.parquet".to_string();

    // the total size of the object is handed to the reader below
    let (data, _) = bucket.head_object(&path).await.unwrap();
    let length = data.content_length.unwrap() as usize;

    // callback used by the reader to fetch `length` bytes starting at `start`
    let range_get = std::sync::Arc::new(move |start: u64, length: usize| {
        // the returned future is `'static`, so it needs its own handles;
        // one clone per call is enough because `async move` takes ownership of them
        let bucket = bucket.clone();
        let path = path.clone();
        Box::pin(async move {
            // just to get a sense of what is being queried in s3
            println!("getting {} bytes starting at {}", length, start);
            // the end of the requested range is inclusive, hence the `- 1`
            let (mut data, _) = bucket
                .get_object_range(&path, start, Some(start + length as u64 - 1))
                .await
                .map_err(|x| std::io::Error::new(std::io::ErrorKind::Other, x.to_string()))?;

            println!("got {}/{} bytes starting at {}", data.len(), length, start);
            // never hand back more than what was asked for
            data.truncate(length);
            Ok(SeekOutput { start, data })
        }) as BoxFuture<'static, std::io::Result<SeekOutput>>
    });

    // requests to s3 are made in chunks of at least 4 KiB
    let mut reader = RangedStreamer::new(length, 4 * 1024, range_get);

    let metadata = read_metadata_async(&mut reader).await?;

    // metadata
    println!("{}", metadata.num_rows);

    // pages of the first row group and first column
    // This is IO-bound and SHOULD be done in a shared thread pool (e.g. Tokio)
    let pages = get_page_stream(&metadata, 0, 0, &mut reader, vec![]).await?;

    // decompress the pages. This is CPU-bound and SHOULD be done in a dedicated thread pool (e.g. Rayon)
    let pages = pages.map(|compressed_page| decompress(compressed_page?, &mut vec![]));

    // deserialize the pages. This is CPU-bound and SHOULD be done in a dedicated thread pool (e.g. Rayon)
    let array =
        page_stream_to_array(pages, &metadata.row_groups[0].columns()[0], DataType::Int64).await?;

    let array = array.as_any().downcast_ref::<Int64Array>().unwrap();
    // ... and have fun with it.
    println!("len: {}", array.len());
    println!("null_count: {}", array.null_count());
    Ok(())
}
113 changes: 113 additions & 0 deletions examples/s3/src/stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Special thanks to Alice for the help: https://users.rust-lang.org/t/63019/6
use std::io::{Result, SeekFrom};
use std::pin::Pin;

use futures::{
future::BoxFuture,
io::{AsyncRead, AsyncSeek},
Future,
};

/// A seekable asynchronous reader that keeps one chunk of bytes in memory and
/// obtains new chunks by calling a user-supplied range function.
pub struct RangedStreamer {
    /// Current position of the reader, in bytes from the start.
    pos: u64,
    /// Total size, in bytes; used to resolve seeks relative to the end.
    length: u64, // total size
    /// Either the chunk currently held or the in-flight request for the next one.
    state: State,
    /// Function called with `(start, length)` to fetch a new chunk.
    range_get: F,
    /// Lower bound on the number of bytes asked for in a single request.
    min_request_size: usize, // requests have at least this size
}

/// Internal state of a [`RangedStreamer`].
enum State {
    /// A chunk is available and reads may be served from it.
    HasChunk(SeekOutput),
    /// A request for a new chunk is in flight; the future resolves to that chunk.
    Seeking(BoxFuture<'static, std::io::Result<SeekOutput>>),
}

/// A contiguous chunk of bytes together with the offset at which it starts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SeekOutput {
    /// Offset, in bytes from the start, of the first byte of `data`.
    pub start: u64,
    /// The bytes of the chunk.
    pub data: Vec<u8>,
}

/// Shared, thread-safe function that, given a start offset and a length,
/// returns a future resolving to the corresponding [`SeekOutput`] or an I/O error.
pub type F = std::sync::Arc<
    dyn Fn(u64, usize) -> BoxFuture<'static, std::io::Result<SeekOutput>> + Send + Sync,
>;

impl RangedStreamer {
    /// Creates a new [`RangedStreamer`] positioned at offset 0.
    ///
    /// * `length` - total size, in bytes
    /// * `min_request_size` - lower bound on the size of each request
    /// * `range_get` - function used to fetch `(start, length)` ranges
    pub fn new(length: usize, min_request_size: usize, range_get: F) -> Self {
        // nothing has been fetched yet: begin with an empty chunk at offset 0
        let empty_chunk = SeekOutput {
            start: 0,
            data: Vec::new(),
        };
        Self {
            pos: 0,
            length: length as u64,
            state: State::HasChunk(empty_chunk),
            range_get,
            min_request_size,
        }
    }
}

/// Returns whether `test_interval` lies entirely inside `a`.
///
/// Both arguments are `(start, length)` pairs describing the half-open range
/// `start..start + length`.
fn range_includes(a: (usize, usize), test_interval: (usize, usize)) -> bool {
    // end offsets are computed in `u128` so that `start + length` cannot overflow,
    // even for positions close to `usize::MAX`
    let end = |(start, length): (usize, usize)| start as u128 + length as u128;
    a.0 <= test_interval.0 && end(test_interval) <= end(a)
}

/// Serves reads from the chunk held in memory, fetching a new chunk through
/// `range_get` whenever the requested bytes are not entirely contained in it.
impl AsyncRead for RangedStreamer {
    /// Copies bytes starting at the current position into `buf` and advances the
    /// position by the number of bytes copied.
    ///
    /// Reads are clamped to the total length: a read that would cross the end
    /// returns only the remaining bytes, and a read at (or past) the end returns
    /// `Ok(0)`. Errors returned by `range_get` are forwarded to the caller.
    fn poll_read(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> std::task::Poll<Result<usize>> {
        // Never ask for bytes past the end: such a range cannot be fetched in full,
        // so the chunk would never contain it and it would be requested forever.
        let remaining = self.length.saturating_sub(self.pos);
        let to_read = std::cmp::min(buf.len() as u64, remaining) as usize;
        if to_read == 0 {
            // end of the data (or an empty `buf`)
            return std::task::Poll::Ready(Ok(0));
        }

        let requested_range = (self.pos as usize, to_read);
        let min_request_size = self.min_request_size;
        match &mut self.state {
            State::HasChunk(output) => {
                let existing_range = (output.start as usize, output.data.len());
                if range_includes(existing_range, requested_range) {
                    // the chunk in memory has everything: serve from it
                    let offset = requested_range.0 - existing_range.0;
                    buf[..to_read].copy_from_slice(&output.data[offset..offset + to_read]);
                    self.pos += to_read as u64;
                    std::task::Poll::Ready(Ok(to_read))
                } else {
                    // fetch a new chunk starting at the current position, at least
                    // `min_request_size` long but never extending past the end
                    let start = requested_range.0 as u64;
                    let wanted = std::cmp::max(min_request_size, to_read);
                    let length = std::cmp::min(wanted as u64, remaining) as usize;
                    let future = (self.range_get)(start, length);
                    self.state = State::Seeking(Box::pin(future));
                    // poll right away so that the new future registers the waker
                    self.poll_read(cx, buf)
                }
            }
            State::Seeking(ref mut future) => match Pin::new(future).poll(cx) {
                std::task::Poll::Ready(v) => {
                    match v {
                        Ok(output) => self.state = State::HasChunk(output),
                        Err(e) => return std::task::Poll::Ready(Err(e)),
                    };
                    // the chunk arrived: retry the read against it
                    self.poll_read(cx, buf)
                }
                std::task::Poll::Pending => std::task::Poll::Pending,
            },
        }
    }
}

/// Seeking only updates the position; no request is made until the next read.
impl AsyncSeek for RangedStreamer {
    /// Moves the position according to `pos` and returns the new position.
    ///
    /// # Errors
    /// Returns [`std::io::ErrorKind::InvalidInput`] when the resulting position
    /// would be negative or would not fit in a `u64`; the position is left
    /// unchanged in that case.
    fn poll_seek(
        mut self: std::pin::Pin<&mut Self>,
        _: &mut std::task::Context<'_>,
        pos: std::io::SeekFrom,
    ) -> std::task::Poll<Result<u64>> {
        // applies a signed offset to an unsigned base; `None` when the result
        // falls outside of the `u64` range (instead of silently wrapping)
        let offset_from = |base: u64, delta: i64| {
            if delta >= 0 {
                base.checked_add(delta.unsigned_abs())
            } else {
                base.checked_sub(delta.unsigned_abs())
            }
        };

        let target = match pos {
            SeekFrom::Start(offset) => Some(offset),
            SeekFrom::End(delta) => offset_from(self.length, delta),
            SeekFrom::Current(delta) => offset_from(self.pos, delta),
        };

        match target {
            Some(new_pos) => {
                self.pos = new_pos;
                std::task::Poll::Ready(Ok(new_pos))
            }
            None => std::task::Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "invalid seek to a negative or overflowing position",
            ))),
        }
    }
}

0 comments on commit b5a9c5f

Please sign in to comment.