From 59c2f715a96e2ee90ebdadea721cf3f9b391c29d Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sat, 20 Nov 2021 08:49:03 +0000 Subject: [PATCH] Simplified example. --- examples/s3/Cargo.toml | 1 + examples/s3/src/main.rs | 12 ++-- examples/s3/src/stream.rs | 113 -------------------------------------- 3 files changed, 6 insertions(+), 120 deletions(-) delete mode 100644 examples/s3/src/stream.rs diff --git a/examples/s3/Cargo.toml b/examples/s3/Cargo.toml index ed47cc23aea..84d7b8c0216 100644 --- a/examples/s3/Cargo.toml +++ b/examples/s3/Cargo.toml @@ -8,3 +8,4 @@ arrow2 = { path = "../../", default-features = false, features = ["io_parquet", rust-s3 = { version = "0.27.0", features = ["tokio"] } futures = "0.3" tokio = { version = "1.0.0", features = ["macros", "rt-multi-thread"] } +range-reader = "0.1" diff --git a/examples/s3/src/main.rs b/examples/s3/src/main.rs index 4e1f1d70ad2..fb544c34f5d 100644 --- a/examples/s3/src/main.rs +++ b/examples/s3/src/main.rs @@ -5,11 +5,9 @@ use arrow2::io::parquet::read::{ decompress, get_page_stream, page_stream_to_array, read_metadata_async, }; use futures::{future::BoxFuture, StreamExt}; +use range_reader::{RangeOutput, RangedAsyncReader}; use s3::Bucket; -mod stream; -use stream::{RangedStreamer, SeekOutput}; - #[tokio::main] async fn main() -> Result<()> { let bucket_name = "dev-jorgecardleitao"; @@ -36,12 +34,12 @@ async fn main() -> Result<()> { .map_err(|x| std::io::Error::new(std::io::ErrorKind::Other, x.to_string()))?; println!("got {}/{} bytes starting at {}", data.len(), length, start); data.truncate(length); - Ok(SeekOutput { start, data }) - }) as BoxFuture<'static, std::io::Result<SeekOutput>> + Ok(RangeOutput { start, data }) + }) as BoxFuture<'static, std::io::Result<RangeOutput>> }); - // at least 4kb per s3 request. Adjust as you like. - let mut reader = RangedStreamer::new(length, 4 * 1024, range_get); + // at least 4kb per s3 request. Adjust to your liking. 
+ let mut reader = RangedAsyncReader::new(length, 4 * 1024, range_get); let metadata = read_metadata_async(&mut reader).await?; diff --git a/examples/s3/src/stream.rs b/examples/s3/src/stream.rs deleted file mode 100644 index d09d26e7dd1..00000000000 --- a/examples/s3/src/stream.rs +++ /dev/null @@ -1,113 +0,0 @@ -// Special thanks to Alice for the help: https://users.rust-lang.org/t/63019/6 -use std::io::{Result, SeekFrom}; -use std::pin::Pin; - -use futures::{ - future::BoxFuture, - io::{AsyncRead, AsyncSeek}, - Future, -}; - -pub struct RangedStreamer { - pos: u64, - length: u64, // total size - state: State, - range_get: F, - min_request_size: usize, // requests have at least this size -} - -enum State { - HasChunk(SeekOutput), - Seeking(BoxFuture<'static, std::io::Result<SeekOutput>>), -} - -pub struct SeekOutput { - pub start: u64, - pub data: Vec<u8>, -} - -pub type F = Box< - dyn Fn(u64, usize) -> BoxFuture<'static, std::io::Result<SeekOutput>> + Send + Sync, ->; - -impl RangedStreamer { - pub fn new(length: usize, min_request_size: usize, range_get: F) -> Self { - let length = length as u64; - Self { - pos: 0, - length, - state: State::HasChunk(SeekOutput { - start: 0, - data: vec![], - }), - range_get, - min_request_size, - } - } -} - -// whether `test_interval` is inside `a` (start, length). 
-fn range_includes(a: (usize, usize), test_interval: (usize, usize)) -> bool { - if test_interval.0 < a.0 { - return false; - } - let test_end = test_interval.0 + test_interval.1; - let a_end = a.0 + a.1; - if test_end > a_end { - return false; - } - true -} - -impl AsyncRead for RangedStreamer { - fn poll_read( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut [u8], - ) -> std::task::Poll<Result<usize>> { - let requested_range = (self.pos as usize, buf.len()); - let min_request_size = self.min_request_size; - match &mut self.state { - State::HasChunk(output) => { - let existing_range = (output.start as usize, output.data.len()); - if range_includes(existing_range, requested_range) { - let offset = requested_range.0 - existing_range.0; - buf.copy_from_slice(&output.data[offset..offset + buf.len()]); - self.pos += buf.len() as u64; - std::task::Poll::Ready(Ok(buf.len())) - } else { - let start = requested_range.0 as u64; - let length = std::cmp::max(min_request_size, requested_range.1); - let future = (self.range_get)(start, length); - self.state = State::Seeking(Box::pin(future)); - self.poll_read(cx, buf) - } - } - State::Seeking(ref mut future) => match Pin::new(future).poll(cx) { - std::task::Poll::Ready(v) => { - match v { - Ok(output) => self.state = State::HasChunk(output), - Err(e) => return std::task::Poll::Ready(Err(e)), - }; - self.poll_read(cx, buf) - } - std::task::Poll::Pending => std::task::Poll::Pending, - }, - } - } -} - -impl AsyncSeek for RangedStreamer { - fn poll_seek( - mut self: std::pin::Pin<&mut Self>, - _: &mut std::task::Context<'_>, - pos: std::io::SeekFrom, - ) -> std::task::Poll<Result<u64>> { - match pos { - SeekFrom::Start(pos) => self.pos = pos, - SeekFrom::End(pos) => self.pos = (self.length as i64 + pos) as u64, - SeekFrom::Current(pos) => self.pos = (self.pos as i64 + pos) as u64, - }; - std::task::Poll::Ready(Ok(self.pos)) - } -}