diff --git a/src/lib.rs b/src/lib.rs index 6a58876bd..d8363d38b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,6 +21,7 @@ pub mod write; pub use streaming_decompression::fallible_streaming_iterator; pub use streaming_decompression::FallibleStreamingIterator; +const HEADER_SIZE: u64 = PARQUET_MAGIC.len() as u64; const FOOTER_SIZE: u64 = 8; const PARQUET_MAGIC: [u8; 4] = [b'P', b'A', b'R', b'1']; diff --git a/src/read/metadata.rs b/src/read/metadata.rs index b57358f26..6f166c1d0 100644 --- a/src/read/metadata.rs +++ b/src/read/metadata.rs @@ -10,6 +10,7 @@ use parquet_format_async_temp::FileMetaData as TFileMetaData; use super::super::{metadata::FileMetaData, DEFAULT_FOOTER_READ_SIZE, FOOTER_SIZE, PARQUET_MAGIC}; use crate::error::{Error, Result}; +use crate::HEADER_SIZE; pub(super) fn metadata_len(buffer: &[u8], len: usize) -> i32 { i32::from_le_bytes(buffer[len - 8..len - 4].try_into().unwrap()) @@ -41,9 +42,9 @@ fn stream_len(seek: &mut impl Seek) -> std::result::Result pub fn read_metadata(reader: &mut R) -> Result { // check file is large enough to hold footer let file_size = stream_len(reader)?; - if file_size < FOOTER_SIZE { + if file_size < HEADER_SIZE + FOOTER_SIZE { return Err(general_err!( - "Invalid Parquet file. Size is smaller than footer" + "Invalid Parquet file. 
Size is smaller than header + footer" )); } diff --git a/src/read/stream.rs b/src/read/stream.rs index 4c4c5e58f..5e4c4108c 100644 --- a/src/read/stream.rs +++ b/src/read/stream.rs @@ -7,6 +7,7 @@ use parquet_format_async_temp::FileMetaData as TFileMetaData; use super::super::{metadata::FileMetaData, DEFAULT_FOOTER_READ_SIZE, FOOTER_SIZE, PARQUET_MAGIC}; use super::metadata::metadata_len; use crate::error::{Error, Result}; +use crate::HEADER_SIZE; async fn stream_len( seek: &mut (impl AsyncSeek + std::marker::Unpin), @@ -30,7 +31,7 @@ pub async fn read_metadata let file_size = stream_len(reader).await?; // check file is large enough to hold footer - if file_size < FOOTER_SIZE { + if file_size < HEADER_SIZE + FOOTER_SIZE { return Err(general_err!( - "Invalid Parquet file. Size is smaller than footer" + "Invalid Parquet file. Size is smaller than header + footer" )); diff --git a/src/write/file.rs b/src/write/file.rs index 36c4560ef..ecb276a5b 100644 --- a/src/write/file.rs +++ b/src/write/file.rs @@ -16,6 +16,7 @@ use super::page::PageWriteSpec; use super::{row_group::write_row_group, RowGroupIter, WriteOptions}; pub use crate::metadata::KeyValue; +use crate::write::State; pub(super) fn start_file(writer: &mut W) -> Result { writer.write_all(&PARQUET_MAGIC)?; @@ -52,6 +53,8 @@ pub struct FileWriter { offset: u64, row_groups: Vec, page_specs: Vec>>, + /// Used to store the current state for writing the file + state: State, } // Accessors @@ -83,13 +86,24 @@ impl FileWriter { offset: 0, row_groups: vec![], page_specs: vec![], + state: State::Initialised, } } - /// Writes the header of the file - pub fn start(&mut self) -> Result<()> { - self.offset = start_file(&mut self.writer)? as u64; - Ok(()) + /// Writes the header of the file. + /// + /// This is automatically called by [`Self::write`] if not called following [`Self::new`]. + /// + /// # Errors + /// Returns an error if data has been written to the file. + fn start(&mut self) -> Result<()> { + if self.offset == 0 { + self.offset = start_file(&mut self.writer)? 
as u64; + self.state = State::Started; + Ok(()) + } else { + Err(Error::General("Start cannot be called twice".to_string())) + } } /// Writes a row group to the file. @@ -101,9 +115,7 @@ impl FileWriter { E: std::error::Error, { if self.offset == 0 { - return Err(Error::General( - "You must call `start` before writing the first row group".to_string(), - )); + self.start()?; } let ordinal = self.row_groups.len(); let (group, specs, size) = write_row_group( @@ -119,8 +131,16 @@ impl FileWriter { Ok(()) } - /// Writes the footer of the parquet file. Returns the total size of the file. + /// Writes the footer of the parquet file and returns the total size of the file. + /// Returns an error if the footer has already been written. pub fn end(&mut self, key_value_metadata: Option>) -> Result { + if self.offset == 0 { + self.start()?; + } + + if self.state != State::Started { + return Err(Error::General("End cannot be called twice".to_string())); + } // compute file stats let num_rows = self.row_groups.iter().map(|group| group.num_rows).sum(); @@ -176,6 +196,7 @@ impl FileWriter { ); let len = end_file(&mut self.writer, metadata)?; + self.state = State::Finished; Ok(self.offset + len) } diff --git a/src/write/mod.rs b/src/write/mod.rs index e9e6dce57..6fdb4fe14 100644 --- a/src/write/mod.rs +++ b/src/write/mod.rs @@ -41,6 +41,14 @@ pub enum Version { V2, } +/// Used to recall the state of the parquet writer - whether sync or async. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum State { + Initialised, + Started, + Finished, +} + impl From for i32 { fn from(version: Version) -> Self { match version { diff --git a/src/write/stream.rs b/src/write/stream.rs index 37160e5f5..a398edece 100644 --- a/src/write/stream.rs +++ b/src/write/stream.rs @@ -7,6 +7,7 @@ use parquet_format_async_temp::{ FileMetaData, RowGroup, }; +use crate::write::State; use crate::{ error::{Error, Result}, metadata::{KeyValue, SchemaDescriptor}, @@ -52,6 +53,8 @@ pub struct FileStreamer { offset: u64, row_groups: Vec, + /// Used to store the current state for writing the file + state: State, } // Accessors @@ -82,13 +85,24 @@ impl FileStreamer { created_by, offset: 0, row_groups: vec![], + state: State::Initialised, } } - /// Writes the header of the file - pub async fn start(&mut self) -> Result<()> { - self.offset = start_file(&mut self.writer).await? as u64; - Ok(()) + /// Writes the header of the file. + /// + /// This is automatically called by [`Self::write`] if not called following [`Self::new`]. + /// + /// # Errors + /// Returns an error if data has been written to the file. + async fn start(&mut self) -> Result<()> { + if self.offset == 0 { + self.offset = start_file(&mut self.writer).await? as u64; + self.state = State::Started; + Ok(()) + } else { + Err(Error::General("Start cannot be called twice".to_string())) + } } /// Writes a row group to the file. @@ -98,9 +112,7 @@ impl FileStreamer { E: std::error::Error, { if self.offset == 0 { - return Err(Error::General( - "You must call `start` before writing the first row group".to_string(), - )); + self.start().await?; } let (group, _specs, size) = write_row_group_async( &mut self.writer, @@ -116,23 +128,36 @@ impl FileStreamer { - /// Writes the footer of the parquet file. Returns the total size of the file and the - /// underlying writer. 
- pub async fn end(mut self, key_value_metadata: Option>) -> Result<(u64, W)> { + /// Writes the footer of the parquet file and returns the total size of the file. + /// Use [`Self::into_inner`] to recover the underlying writer. + pub async fn end(&mut self, key_value_metadata: Option>) -> Result { + if self.offset == 0 { + self.start().await?; + } + + if self.state != State::Started { + return Err(Error::General("End cannot be called twice".to_string())); + } // compute file stats let num_rows = self.row_groups.iter().map(|group| group.num_rows).sum(); let metadata = FileMetaData::new( self.options.version.into(), - self.schema.into_thrift(), + self.schema.clone().into_thrift(), num_rows, - self.row_groups, + self.row_groups.clone(), key_value_metadata, - self.created_by, + self.created_by.clone(), None, None, None, ); let len = end_file(&mut self.writer, metadata).await?; + self.state = State::Finished; - Ok((self.offset + len, self.writer)) + Ok(self.offset + len) + } + + /// Returns the underlying writer. + pub fn into_inner(self) -> W { + self.writer } } diff --git a/tests/it/write/indexes.rs b/tests/it/write/indexes.rs index 891bc7cf7..545cf392b 100644 --- a/tests/it/write/indexes.rs +++ b/tests/it/write/indexes.rs @@ -51,7 +51,6 @@ fn write_file() -> Result> { let writer = Cursor::new(vec![]); let mut writer = FileWriter::new(writer, schema, options, None); - writer.start()?; writer.write(DynIter::new(columns))?; writer.end(None)?; diff --git a/tests/it/write/mod.rs b/tests/it/write/mod.rs index 1364c44bc..eccb0f589 100644 --- a/tests/it/write/mod.rs +++ b/tests/it/write/mod.rs @@ -90,7 +90,6 @@ fn test_column(column: &str, compression: CompressionOptions) -> Result<()> { let writer = Cursor::new(vec![]); let mut writer = FileWriter::new(writer, schema, options, None); - writer.start()?; writer.write(DynIter::new(columns))?; writer.end(None)?; @@ -214,7 +213,6 @@ fn basic() -> Result<()> { let writer = Cursor::new(vec![]); let mut writer = FileWriter::new(writer, schema, options, None); - writer.start()?; writer.write(DynIter::new(columns))?; writer.end(None)?; @@ -272,11 +270,10 @@ async fn test_column_async(column: &str) -> 
Result<()> { let writer = futures::io::Cursor::new(vec![]); let mut writer = FileStreamer::new(writer, schema, options, None); - writer.start().await?; writer.write(DynIter::new(columns)).await?; - let writer = writer.end(None).await?.1; + writer.end(None).await?; - let data = writer.into_inner(); + let data = writer.into_inner().into_inner(); let (result, statistics) = read_column_async(&mut futures::io::Cursor::new(data)).await?; assert_eq!(array, result);