Commit 3774868

Fix potential misuse of FileWriter APIs (sync + async) (#138)
TurnOfACard authored May 19, 2022
1 parent 19fbc5c commit 3774868
Showing 8 changed files with 81 additions and 29 deletions.
1 change: 1 addition & 0 deletions src/lib.rs
@@ -21,6 +21,7 @@ pub mod write;
 pub use streaming_decompression::fallible_streaming_iterator;
 pub use streaming_decompression::FallibleStreamingIterator;
 
+const HEADER_SIZE: u64 = PARQUET_MAGIC.len() as u64;
 const FOOTER_SIZE: u64 = 8;
 const PARQUET_MAGIC: [u8; 4] = [b'P', b'A', b'R', b'1'];
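
A short aside on the sizes involved (a sketch, assuming the standard Parquet layout of a leading "PAR1" magic, the data and metadata, a 4-byte little-endian metadata length, and a trailing "PAR1" magic): the smallest file that can possibly be valid is HEADER_SIZE + FOOTER_SIZE = 12 bytes, which is the bound the readers below now enforce. The helper name here is hypothetical, not part of the crate:

    const PARQUET_MAGIC: [u8; 4] = [b'P', b'A', b'R', b'1'];
    const HEADER_SIZE: u64 = PARQUET_MAGIC.len() as u64; // 4 bytes: leading "PAR1"
    const FOOTER_SIZE: u64 = 8; // 4-byte metadata length + 4-byte trailing "PAR1"

    /// Hypothetical helper: the up-front size check that `read_metadata`
    /// performs before attempting to parse the footer.
    fn is_plausible_parquet_len(file_size: u64) -> bool {
        file_size >= HEADER_SIZE + FOOTER_SIZE // 12 bytes minimum
    }
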
5 changes: 3 additions & 2 deletions src/read/metadata.rs
@@ -10,6 +10,7 @@ use parquet_format_async_temp::FileMetaData as TFileMetaData;
 use super::super::{metadata::FileMetaData, DEFAULT_FOOTER_READ_SIZE, FOOTER_SIZE, PARQUET_MAGIC};
 
 use crate::error::{Error, Result};
+use crate::HEADER_SIZE;
 
 pub(super) fn metadata_len(buffer: &[u8], len: usize) -> i32 {
     i32::from_le_bytes(buffer[len - 8..len - 4].try_into().unwrap())
@@ -41,9 +42,9 @@ fn stream_len(seek: &mut impl Seek) -> std::result::Result<u64, std::io::Error>
 pub fn read_metadata<R: Read + Seek>(reader: &mut R) -> Result<FileMetaData> {
     // check file is large enough to hold footer
     let file_size = stream_len(reader)?;
-    if file_size < FOOTER_SIZE {
+    if file_size < HEADER_SIZE + FOOTER_SIZE {
         return Err(general_err!(
-            "Invalid Parquet file. Size is smaller than footer"
+            "Invalid Parquet file. Size is smaller than header + footer"
         ));
     }
 
3 changes: 2 additions & 1 deletion src/read/stream.rs
@@ -7,6 +7,7 @@ use parquet_format_async_temp::FileMetaData as TFileMetaData;
 use super::super::{metadata::FileMetaData, DEFAULT_FOOTER_READ_SIZE, FOOTER_SIZE, PARQUET_MAGIC};
 use super::metadata::metadata_len;
 use crate::error::{Error, Result};
+use crate::HEADER_SIZE;
 
 async fn stream_len(
     seek: &mut (impl AsyncSeek + std::marker::Unpin),
@@ -30,7 +31,7 @@ pub async fn read_metadata<R: AsyncRead + AsyncSeek + Send + std::marker::Unpin>
     let file_size = stream_len(reader).await?;
 
     // check file is large enough to hold footer
-    if file_size < FOOTER_SIZE {
+    if file_size < HEADER_SIZE + FOOTER_SIZE {
         return Err(general_err!(
             "Invalid Parquet file. Size is smaller than footer"
         ));

37 changes: 29 additions & 8 deletions src/write/file.rs
@@ -16,6 +16,7 @@ use super::page::PageWriteSpec;
 use super::{row_group::write_row_group, RowGroupIter, WriteOptions};
 
 pub use crate::metadata::KeyValue;
+use crate::write::State;
 
 pub(super) fn start_file<W: Write>(writer: &mut W) -> Result<u64> {
     writer.write_all(&PARQUET_MAGIC)?;
@@ -52,6 +53,8 @@ pub struct FileWriter<W: Write> {
     offset: u64,
     row_groups: Vec<RowGroup>,
     page_specs: Vec<Vec<Vec<PageWriteSpec>>>,
+    /// Used to store the current state for writing the file
+    state: State,
 }
 
 // Accessors
@@ -83,13 +86,24 @@ impl<W: Write> FileWriter<W> {
             offset: 0,
             row_groups: vec![],
             page_specs: vec![],
+            state: State::Initialised,
         }
     }
 
-    /// Writes the header of the file
-    pub fn start(&mut self) -> Result<()> {
-        self.offset = start_file(&mut self.writer)? as u64;
-        Ok(())
+    /// Writes the header of the file.
+    ///
+    /// This is automatically called by [`Self::write`] if not called following [`Self::new`].
+    ///
+    /// # Errors
+    /// Returns an error if data has been written to the file.
+    fn start(&mut self) -> Result<()> {
+        if self.offset == 0 {
+            self.offset = start_file(&mut self.writer)? as u64;
+            self.state = State::Started;
+            Ok(())
+        } else {
+            Err(Error::General("Start cannot be called twice".to_string()))
+        }
     }
 
     /// Writes a row group to the file.
@@ -101,9 +115,7 @@ impl<W: Write> FileWriter<W> {
         E: std::error::Error,
     {
         if self.offset == 0 {
-            return Err(Error::General(
-                "You must call `start` before writing the first row group".to_string(),
-            ));
+            self.start()?;
         }
         let ordinal = self.row_groups.len();
         let (group, specs, size) = write_row_group(
@@ -119,8 +131,16 @@ impl<W: Write> FileWriter<W> {
         Ok(())
     }
 
-    /// Writes the footer of the parquet file. Returns the total size of the file.
+    /// Writes the footer of the parquet file. Returns the total size of the file and the
+    /// underlying writer.
     pub fn end(&mut self, key_value_metadata: Option<Vec<KeyValue>>) -> Result<u64> {
+        if self.offset == 0 {
+            self.start()?;
+        }
+
+        if self.state != State::Started {
+            return Err(Error::General("End cannot be called twice".to_string()));
+        }
         // compute file stats
         let num_rows = self.row_groups.iter().map(|group| group.num_rows).sum();
 
@@ -176,6 +196,7 @@ impl<W: Write> FileWriter<W> {
         );
 
         let len = end_file(&mut self.writer, metadata)?;
+        self.state = State::Finished;
         Ok(self.offset + len)
     }
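
To make the new contract concrete, a hedged sketch of the sync API after this change (`schema`, `options`, and `columns` are assumed to be built as in the crate's tests, e.g. tests/it/write/mod.rs): `start` is now private and guarded by the state machine, so callers simply construct the writer, write row groups, and finish with `end`:

    use std::io::Cursor;

    // Sketch only: setup of schema/options/columns omitted.
    let cursor = Cursor::new(vec![]);
    let mut writer = FileWriter::new(cursor, schema, options, None);

    // No explicit `start` call: `write` performs it on first use.
    writer.write(DynIter::new(columns))?;

    // `end` also starts the file if nothing was written yet, and errors
    // if the writer is not in the `Started` state (i.e. called twice).
    let _file_size = writer.end(None)?;
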
8 changes: 8 additions & 0 deletions src/write/mod.rs
@@ -41,6 +41,14 @@ pub enum Version {
     V2,
 }
 
+/// Used to recall the state of the parquet writer - whether sync or async.
+#[derive(PartialEq)]
+enum State {
+    Initialised,
+    Started,
+    Finished,
+}
+
 impl From<Version> for i32 {
     fn from(version: Version) -> Self {
         match version {
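
The `#[derive(PartialEq)]` is what lets both writers compare states with `!=` in their `end` methods. A minimal sketch of the guard pattern this enum enables (the free function is illustrative, not from the crate):

    #[derive(PartialEq)]
    enum State {
        Initialised,
        Started,
        Finished,
    }

    // Illustrative guard: mirrors the check inside `FileWriter::end`.
    fn check_started(state: &State) -> Result<(), String> {
        if *state != State::Started {
            return Err("End cannot be called twice".to_string());
        }
        Ok(())
    }
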
48 changes: 36 additions & 12 deletions src/write/stream.rs
@@ -7,6 +7,7 @@ use parquet_format_async_temp::{
     FileMetaData, RowGroup,
 };
 
+use crate::write::State;
 use crate::{
     error::{Error, Result},
     metadata::{KeyValue, SchemaDescriptor},
@@ -52,6 +53,8 @@ pub struct FileStreamer<W: AsyncWrite + Unpin + Send> {
 
     offset: u64,
     row_groups: Vec<RowGroup>,
+    /// Used to store the current state for writing the file
+    state: State,
 }
 
 // Accessors
@@ -82,13 +85,24 @@ impl<W: AsyncWrite + Unpin + Send> FileStreamer<W> {
             created_by,
             offset: 0,
             row_groups: vec![],
+            state: State::Initialised,
         }
     }
 
-    /// Writes the header of the file
-    pub async fn start(&mut self) -> Result<()> {
-        self.offset = start_file(&mut self.writer).await? as u64;
-        Ok(())
+    /// Writes the header of the file.
+    ///
+    /// This is automatically called by [`Self::write`] if not called following [`Self::new`].
+    ///
+    /// # Errors
+    /// Returns an error if data has been written to the file.
+    async fn start(&mut self) -> Result<()> {
+        if self.offset == 0 {
+            self.offset = start_file(&mut self.writer).await? as u64;
+            self.state = State::Started;
+            Ok(())
+        } else {
+            Err(Error::General("Start cannot be called twice".to_string()))
+        }
     }
 
     /// Writes a row group to the file.
@@ -98,9 +112,7 @@ impl<W: AsyncWrite + Unpin + Send> FileStreamer<W> {
         E: std::error::Error,
     {
         if self.offset == 0 {
-            return Err(Error::General(
-                "You must call `start` before writing the first row group".to_string(),
-            ));
+            self.start().await?;
         }
         let (group, _specs, size) = write_row_group_async(
             &mut self.writer,
@@ -116,23 +128,35 @@ impl<W: AsyncWrite + Unpin + Send> FileStreamer<W> {
 
     /// Writes the footer of the parquet file. Returns the total size of the file and the
     /// underlying writer.
-    pub async fn end(mut self, key_value_metadata: Option<Vec<KeyValue>>) -> Result<(u64, W)> {
+    pub async fn end(&mut self, key_value_metadata: Option<Vec<KeyValue>>) -> Result<u64> {
+        if self.offset == 0 {
+            self.start().await?;
+        }
+
+        if self.state != State::Started {
+            return Err(Error::General("End cannot be called twice".to_string()));
+        }
         // compute file stats
         let num_rows = self.row_groups.iter().map(|group| group.num_rows).sum();
 
         let metadata = FileMetaData::new(
             self.options.version.into(),
-            self.schema.into_thrift(),
+            self.schema.clone().into_thrift(),
             num_rows,
-            self.row_groups,
+            self.row_groups.clone(),
             key_value_metadata,
-            self.created_by,
+            self.created_by.clone(),
             None,
             None,
             None,
         );
 
         let len = end_file(&mut self.writer, metadata).await?;
-        Ok((self.offset + len, self.writer))
+        Ok(self.offset + len)
     }
+
+    /// Returns the underlying writer.
+    pub fn into_inner(self) -> W {
+        self.writer
+    }
 }
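
And a matching sketch of the async API after this change (again assuming the setup from the crate's tests): `end` now borrows the streamer and returns only the file size, with the underlying writer recovered separately via `into_inner`:

    // Sketch only: `schema`, `options`, and `columns` as in the tests.
    let cursor = futures::io::Cursor::new(vec![]);
    let mut writer = FileStreamer::new(cursor, schema, options, None);

    // `start` is private now; `write` calls it on first use.
    writer.write(DynIter::new(columns)).await?;

    // `end` returns only the total size instead of `(size, writer)`.
    let _size = writer.end(None).await?;

    // The outer `into_inner` is the streamer's; the inner one unwraps
    // the futures Cursor to get the written bytes back.
    let data = writer.into_inner().into_inner();
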
1 change: 0 additions & 1 deletion tests/it/write/indexes.rs
@@ -51,7 +51,6 @@ fn write_file() -> Result<Vec<u8>> {
     let writer = Cursor::new(vec![]);
     let mut writer = FileWriter::new(writer, schema, options, None);
 
-    writer.start()?;
     writer.write(DynIter::new(columns))?;
     writer.end(None)?;
 
7 changes: 2 additions & 5 deletions tests/it/write/mod.rs
@@ -90,7 +90,6 @@ fn test_column(column: &str, compression: CompressionOptions) -> Result<()> {
     let writer = Cursor::new(vec![]);
     let mut writer = FileWriter::new(writer, schema, options, None);
 
-    writer.start()?;
     writer.write(DynIter::new(columns))?;
     writer.end(None)?;
 
@@ -214,7 +213,6 @@ fn basic() -> Result<()> {
     let writer = Cursor::new(vec![]);
     let mut writer = FileWriter::new(writer, schema, options, None);
 
-    writer.start()?;
     writer.write(DynIter::new(columns))?;
     writer.end(None)?;
 
@@ -272,11 +270,10 @@ async fn test_column_async(column: &str) -> Result<()> {
     let writer = futures::io::Cursor::new(vec![]);
     let mut writer = FileStreamer::new(writer, schema, options, None);
 
-    writer.start().await?;
     writer.write(DynIter::new(columns)).await?;
-    let writer = writer.end(None).await?.1;
+    writer.end(None).await?;
 
-    let data = writer.into_inner();
+    let data = writer.into_inner().into_inner();
 
     let (result, statistics) = read_column_async(&mut futures::io::Cursor::new(data)).await?;
     assert_eq!(array, result);
