From d0017a3da086f43ac44846451c952d78e18becd3 Mon Sep 17 00:00:00 2001 From: Colin Jermain Date: Sun, 15 May 2022 16:53:32 -0400 Subject: [PATCH] Adding iter methods to io::ndjson::read --- src/io/ndjson/read/deserialize.rs | 20 ++++++++++++++++++++ src/io/ndjson/read/file.rs | 20 ++++++++++++++++++++ src/io/ndjson/read/mod.rs | 4 ++-- tests/it/io/ndjson/read.rs | 23 +++++++++++++++++++++++ 4 files changed, 65 insertions(+), 2 deletions(-) diff --git a/src/io/ndjson/read/deserialize.rs b/src/io/ndjson/read/deserialize.rs index e02927d0248..230580453c8 100644 --- a/src/io/ndjson/read/deserialize.rs +++ b/src/io/ndjson/read/deserialize.rs @@ -31,3 +31,23 @@ pub fn deserialize(rows: &[String], data_type: DataType) -> Result( + rows: impl Iterator>, + data_type: DataType, +) -> Result, ArrowError> { + // deserialize strings to `Value`s + let rows = rows + .map(|row| serde_json::from_str(row.unwrap_or("null")).map_err(ArrowError::from)) + .collect::, ArrowError>>()?; + + // deserialize &[Value] to Array + Ok(_deserialize(&rows, data_type)) +} diff --git a/src/io/ndjson/read/file.rs b/src/io/ndjson/read/file.rs index c45dc139e1d..f75069ef1f7 100644 --- a/src/io/ndjson/read/file.rs +++ b/src/io/ndjson/read/file.rs @@ -125,3 +125,23 @@ pub fn infer( let v: Vec<&DataType> = data_types.iter().collect(); Ok(coerce_data_type(&v)) } + +/// Infers the [`DataType`] from an iterator of JSON strings. A limited number of +/// rows can be used by passing `rows.take(number_of_rows)` as an input. +/// +/// # Implementation +/// This implementation infers each row by going through the entire iterator. +pub fn infer_iter<'a>(rows: impl Iterator>) -> Result +{ + let mut data_types = HashSet::new(); + for row in rows.flatten() { + let v: Value = serde_json::from_str(row)?; + let data_type = infer_json(&v)?; + if data_type != DataType::Null { + data_types.insert(data_type); + } + } + + let v: Vec<&DataType> = data_types.iter().collect(); + Ok(coerce_data_type(&v)) +} diff --git a/src/io/ndjson/read/mod.rs b/src/io/ndjson/read/mod.rs index 5c52bd183fc..6e7da6131bf 100644 --- a/src/io/ndjson/read/mod.rs +++ b/src/io/ndjson/read/mod.rs @@ -4,5 +4,5 @@ pub use fallible_streaming_iterator::FallibleStreamingIterator; mod deserialize; mod file; -pub use deserialize::deserialize; -pub use file::{infer, FileReader}; +pub use deserialize::{deserialize, deserialize_iter}; +pub use file::{infer, infer_iter, FileReader}; diff --git a/tests/it/io/ndjson/read.rs b/tests/it/io/ndjson/read.rs index 4b74248c569..82a3cb1eba0 100644 --- a/tests/it/io/ndjson/read.rs +++ b/tests/it/io/ndjson/read.rs @@ -273,3 +273,26 @@ fn skip_empty_lines() -> Result<()> { assert_eq!(3, arrays[0].len()); Ok(()) } + +#[test] +fn utf8_array() -> Result<()> { + let array = Utf8Array::::from([ + Some(r#"{"a": 1, "b": [{"c": 0}, {"c": 1}]}"#), + None, + Some(r#"{"a": 2, "b": [{"c": 2}, {"c": 5}]}"#), + None, + ]); + let data_type = ndjson_read::infer_iter(array.iter()).unwrap(); + let new_array = ndjson_read::deserialize_iter(array.iter(), data_type).unwrap(); + + // Explicitly cast as StructArray + let new_array = new_array.as_any().downcast_ref::().unwrap(); + + assert_eq!(array.len(), new_array.len()); + assert_eq!(array.null_count(), new_array.null_count()); + assert_eq!(array.validity().unwrap(), new_array.validity().unwrap()); + + let field_names: Vec = new_array.fields().iter().map(|f| f.name.clone()).collect(); + assert_eq!(field_names, vec!["a".to_string(), "b".to_string()]); + Ok(()) +}