Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Updating iter methods to support more general AsRef<str>
Browse files Browse the repository at this point in the history
  • Loading branch information
cjermain committed May 15, 2022
1 parent d0017a3 commit e9a1e30
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 18 deletions.
16 changes: 4 additions & 12 deletions src/io/ndjson/read/deserialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,30 +22,22 @@ pub fn deserialize(rows: &[String], data_type: DataType) -> Result<Arc<dyn Array
));
}

// deserialize strings to `Value`s
let rows = rows
.iter()
.map(|row| serde_json::from_str(row.as_ref()).map_err(ArrowError::from))
.collect::<Result<Vec<Value>, ArrowError>>()?;

// deserialize &[Value] to Array
Ok(_deserialize(&rows, data_type))
deserialize_iter(rows.iter(), data_type)
}


/// Deserializes an iterator of rows into an [`Array`] of [`DataType`].
/// # Implementation
/// This function is CPU-bounded.
/// This function is guaranteed to return an array of length equal to the number of rows yielded by the iterator.
/// # Errors
/// This function errors iff any of the rows is not valid JSON (i.e. the format is not valid NDJSON).
pub fn deserialize_iter<'a>(
rows: impl Iterator<Item=Option<&'a str>>,
pub fn deserialize_iter<A: AsRef<str>>(
rows: impl Iterator<Item = A>,
data_type: DataType,
) -> Result<Arc<dyn Array>, ArrowError> {
// deserialize strings to `Value`s
let rows = rows
.map(|row| serde_json::from_str(row.unwrap_or("null")).map_err(ArrowError::from))
.map(|row| serde_json::from_str(row.as_ref()).map_err(ArrowError::from))
.collect::<Result<Vec<Value>, ArrowError>>()?;

// deserialize &[Value] to Array
Expand Down
7 changes: 3 additions & 4 deletions src/io/ndjson/read/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,10 @@ pub fn infer<R: std::io::BufRead>(
///
/// # Implementation
/// This implementation infers each row by going through the entire iterator.
pub fn infer_iter<'a>(rows: impl Iterator<Item=Option<&'a str>>) -> Result<DataType>
{
pub fn infer_iter<A: AsRef<str>>(rows: impl Iterator<Item = A>) -> Result<DataType> {
let mut data_types = HashSet::new();
for row in rows.flatten() {
let v: Value = serde_json::from_str(row)?;
for row in rows {
let v: Value = serde_json::from_str(row.as_ref())?;
let data_type = infer_json(&v)?;
if data_type != DataType::Null {
data_types.insert(data_type);
Expand Down
6 changes: 4 additions & 2 deletions tests/it/io/ndjson/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,10 @@ fn utf8_array() -> Result<()> {
Some(r#"{"a": 2, "b": [{"c": 2}, {"c": 5}]}"#),
None,
]);
let data_type = ndjson_read::infer_iter(array.iter()).unwrap();
let new_array = ndjson_read::deserialize_iter(array.iter(), data_type).unwrap();
let data_type = ndjson_read::infer_iter(array.iter().map(|x| x.unwrap_or("null"))).unwrap();
let new_array =
ndjson_read::deserialize_iter(array.iter().map(|x| x.unwrap_or("null")), data_type)
.unwrap();

// Explicitly cast as StructArray
let new_array = new_array.as_any().downcast_ref::<StructArray>().unwrap();
Expand Down

0 comments on commit e9a1e30

Please sign in to comment.