From e9a1e3035fc9fc29f6951c12c2fc08f6d0ab91e6 Mon Sep 17 00:00:00 2001 From: Colin Jermain Date: Sun, 15 May 2022 18:26:41 -0400 Subject: [PATCH] Updating iter methods to support more general AsRef --- src/io/ndjson/read/deserialize.rs | 16 ++++------------ src/io/ndjson/read/file.rs | 7 +++---- tests/it/io/ndjson/read.rs | 6 ++++-- 3 files changed, 11 insertions(+), 18 deletions(-) diff --git a/src/io/ndjson/read/deserialize.rs b/src/io/ndjson/read/deserialize.rs index 230580453c8..5960b164989 100644 --- a/src/io/ndjson/read/deserialize.rs +++ b/src/io/ndjson/read/deserialize.rs @@ -22,30 +22,22 @@ pub fn deserialize(rows: &[String], data_type: DataType) -> Result, ArrowError>>()?; - - // deserialize &[Value] to Array - Ok(_deserialize(&rows, data_type)) + deserialize_iter(rows.iter(), data_type) } - /// Deserializes an iterator of rows into an [`Array`] of [`DataType`]. /// # Implementation /// This function is CPU-bounded. /// This function is guaranteed to return an array of length equal to the leng /// # Errors /// This function errors iff any of the rows is not a valid JSON (i.e. the format is not valid NDJSON). -pub fn deserialize_iter<'a>( - rows: impl Iterator>, +pub fn deserialize_iter>( + rows: impl Iterator, data_type: DataType, ) -> Result, ArrowError> { // deserialize strings to `Value`s let rows = rows - .map(|row| serde_json::from_str(row.unwrap_or("null")).map_err(ArrowError::from)) + .map(|row| serde_json::from_str(row.as_ref()).map_err(ArrowError::from)) .collect::, ArrowError>>()?; // deserialize &[Value] to Array diff --git a/src/io/ndjson/read/file.rs b/src/io/ndjson/read/file.rs index f75069ef1f7..9bf6129543e 100644 --- a/src/io/ndjson/read/file.rs +++ b/src/io/ndjson/read/file.rs @@ -131,11 +131,10 @@ pub fn infer( /// /// # Implementation /// This implementation infers each row by going through the entire iterator. -pub fn infer_iter<'a>(rows: impl Iterator>) -> Result -{ +pub fn infer_iter>(rows: impl Iterator) -> Result { let mut data_types = HashSet::new(); - for row in rows.flatten() { - let v: Value = serde_json::from_str(row)?; + for row in rows { + let v: Value = serde_json::from_str(row.as_ref())?; let data_type = infer_json(&v)?; if data_type != DataType::Null { data_types.insert(data_type); diff --git a/tests/it/io/ndjson/read.rs b/tests/it/io/ndjson/read.rs index 82a3cb1eba0..5e19c1f351c 100644 --- a/tests/it/io/ndjson/read.rs +++ b/tests/it/io/ndjson/read.rs @@ -282,8 +282,10 @@ fn utf8_array() -> Result<()> { Some(r#"{"a": 2, "b": [{"c": 2}, {"c": 5}]}"#), None, ]); - let data_type = ndjson_read::infer_iter(array.iter()).unwrap(); - let new_array = ndjson_read::deserialize_iter(array.iter(), data_type).unwrap(); + let data_type = ndjson_read::infer_iter(array.iter().map(|x| x.unwrap_or("null"))).unwrap(); + let new_array = + ndjson_read::deserialize_iter(array.iter().map(|x| x.unwrap_or("null")), data_type) + .unwrap(); // Explicitly cast as StructArray let new_array = new_array.as_any().downcast_ref::().unwrap();