diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f8a9a045a75..96191833fe1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -172,7 +172,7 @@ jobs: submodules: true - uses: actions-rs/toolchain@v1 with: - toolchain: nightly + toolchain: nightly-2022-03-16 target: ${{ matrix.target }} override: true - uses: Swatinem/rust-cache@v1 diff --git a/examples/extension.rs b/examples/extension.rs index c2299028be5..70b9e076029 100644 --- a/examples/extension.rs +++ b/examples/extension.rs @@ -38,11 +38,13 @@ fn write_ipc(writer: W, array: impl Array + 'static) -> Result< let schema = vec![Field::new("a", array.data_type().clone(), false)].into(); let options = write::WriteOptions { compression: None }; - let mut writer = write::FileWriter::try_new(writer, &schema, None, options)?; + let mut writer = write::FileWriter::new(writer, schema, None, options); let batch = Chunk::try_new(vec![Arc::new(array) as Arc])?; + writer.start()?; writer.write(&batch, None)?; + writer.finish()?; Ok(writer.into_inner()) } diff --git a/examples/ipc_file_write.rs b/examples/ipc_file_write.rs index 3bd64981724..a1e4e43e2ec 100644 --- a/examples/ipc_file_write.rs +++ b/examples/ipc_file_write.rs @@ -7,12 +7,13 @@ use arrow2::datatypes::{DataType, Field, Schema}; use arrow2::error::Result; use arrow2::io::ipc::write; -fn write_batches(path: &str, schema: &Schema, columns: &[Chunk>]) -> Result<()> { +fn write_batches(path: &str, schema: Schema, columns: &[Chunk>]) -> Result<()> { let file = File::create(path)?; let options = write::WriteOptions { compression: None }; - let mut writer = write::FileWriter::try_new(file, schema, None, options)?; + let mut writer = write::FileWriter::new(file, schema, None, options); + writer.start()?; for columns in columns { writer.write(columns, None)? } @@ -37,6 +38,6 @@ fn main() -> Result<()> { let batch = Chunk::try_new(vec![Arc::new(a) as Arc, Arc::new(b)])?; // write it - write_batches(file_path, &schema, &[batch])?; + write_batches(file_path, schema, &[batch])?; Ok(()) } diff --git a/src/io/ipc/write/writer.rs b/src/io/ipc/write/writer.rs index 7c3527931de..467e9fefd8f 100644 --- a/src/io/ipc/write/writer.rs +++ b/src/io/ipc/write/writer.rs @@ -15,6 +15,13 @@ use crate::chunk::Chunk; use crate::datatypes::*; use crate::error::{ArrowError, Result}; +#[derive(Clone, Copy, PartialEq, Eq)] +enum State { + None, + Started, + Finished, +} + /// Arrow file writer pub struct FileWriter { /// The object to write to @@ -31,47 +38,49 @@ pub struct FileWriter { /// Record blocks that will be written as part of the IPC footer record_blocks: Vec, /// Whether the writer footer has been written, and the writer is finished - finished: bool, + state: State, /// Keeps track of dictionaries that have been written dictionary_tracker: DictionaryTracker, } impl FileWriter { - /// Try create a new writer, with the schema written as part of the header + /// Creates a new [`FileWriter`] and writes the header to `writer` pub fn try_new( - mut writer: W, + writer: W, schema: &Schema, ipc_fields: Option>, options: WriteOptions, ) -> Result { - // write magic to header - writer.write_all(&ARROW_MAGIC[..])?; - // create an 8-byte boundary after the header - writer.write_all(&[0, 0])?; - // write the schema, set the written bytes to the schema + let mut slf = Self::new(writer, schema.clone(), ipc_fields, options); + slf.start()?; + + Ok(slf) + } + /// Creates a new [`FileWriter`]. + pub fn new( + writer: W, + schema: Schema, + ipc_fields: Option>, + options: WriteOptions, + ) -> Self { let ipc_fields = if let Some(ipc_fields) = ipc_fields { ipc_fields } else { default_ipc_fields(&schema.fields) }; - let encoded_message = EncodedData { - ipc_message: schema_to_bytes(schema, &ipc_fields), - arrow_data: vec![], - }; - let (meta, data) = write_message(&mut writer, encoded_message)?; - Ok(Self { + Self { writer, options, - schema: schema.clone(), + schema, ipc_fields, - block_offsets: meta + data + 8, + block_offsets: 0, dictionary_blocks: vec![], record_blocks: vec![], - finished: false, + state: State::None, dictionary_tracker: DictionaryTracker::new(true), - }) + } } /// Consumes itself into the inner writer @@ -79,17 +88,40 @@ impl FileWriter { self.writer } + /// Writes the header and first (schema) message to the file. + /// # Errors + /// Errors if the file has been started or has finished. + pub fn start(&mut self) -> Result<()> { + if self.state != State::None { + return Err(ArrowError::oos("The IPC file can only be started once")); + } + // write magic to header + self.writer.write_all(&ARROW_MAGIC[..])?; + // create an 8-byte boundary after the header + self.writer.write_all(&[0, 0])?; + // write the schema, set the written bytes to the schema + + let encoded_message = EncodedData { + ipc_message: schema_to_bytes(&self.schema, &self.ipc_fields), + arrow_data: vec![], + }; + + let (meta, data) = write_message(&mut self.writer, encoded_message)?; + self.block_offsets += meta + data + 8; // 8 <=> arrow magic + 2 bytes for alignment + self.state = State::Started; + Ok(()) + } + /// Writes [`Chunk`] to the file pub fn write( &mut self, columns: &Chunk>, ipc_fields: Option<&[IpcField]>, ) -> Result<()> { - if self.finished { - return Err(ArrowError::Io(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "Cannot write to a finished file".to_string(), - ))); + if self.state != State::Started { + return Err(ArrowError::oos( + "The IPC file must be started before it can be written to. Call `start` before `write`", + )); } let ipc_fields = if let Some(ipc_fields) = ipc_fields { @@ -132,6 +164,12 @@ impl FileWriter { /// Write footer and closing tag, then mark the writer as done pub fn finish(&mut self) -> Result<()> { + if self.state != State::Started { + return Err(ArrowError::oos( + "The IPC file must be started before it can be finished. Call `start` before `finish`", + )); + } + // write EOS write_continuation(&mut self.writer, 0)?; @@ -151,7 +189,7 @@ impl FileWriter { .write_all(&(footer_data.len() as i32).to_le_bytes())?; self.writer.write_all(&ARROW_MAGIC)?; self.writer.flush()?; - self.finished = true; + self.state = State::Finished; Ok(()) }