diff --git a/src/io/json/read/infer_schema.rs b/src/io/json/read/infer_schema.rs index cef77dab210..150cd5f26de 100644 --- a/src/io/json/read/infer_schema.rs +++ b/src/io/json/read/infer_schema.rs @@ -12,6 +12,8 @@ use super::iterator::ValueIter; type Tracker = HashMap>; +const ITEM_NAME: &str = "item"; + /// Infers the fields of a JSON file by reading the first `number_of_rows` rows. /// # Examples /// ``` @@ -92,10 +94,12 @@ fn infer_value(value: &Value) -> Result { Value::Null => DataType::Null, Value::Number(number) => infer_number(number), Value::String(_) => DataType::Utf8, - Value::Object(_) => { - return Err(ArrowError::NotYetImplemented( - "Inferring schema from nested JSON structs currently not supported".to_string(), - )) + Value::Object(inner) => { + let fields = inner + .iter() + .map(|(key, value)| infer_value(value).map(|dt| Field::new(key, dt, true))) + .collect::>>()?; + DataType::Struct(fields) } }) } @@ -127,7 +131,7 @@ fn infer_array(array: &[Value]) -> Result { Ok(if !types.is_empty() { let types = types.into_iter().collect::>(); let dt = coerce_data_type(&types); - DataType::List(Box::new(Field::new("item", dt, true))) + DataType::List(Box::new(Field::new(ITEM_NAME, dt, true))) } else { DataType::Null }) @@ -168,29 +172,66 @@ fn resolve_fields(spec: HashMap>) -> Vec { /// Coerce an heterogeneous set of [`DataType`] into a single one. Rules: /// * `Int64` and `Float64` are `Float64` /// * Lists and scalars are coerced to a list of a compatible scalar +/// * Structs contain the union of all fields /// * All other types are coerced to `Utf8` -fn coerce_data_type>(dt: &[A]) -> DataType { +fn coerce_data_type>(datatypes: &[A]) -> DataType { use DataType::*; - if dt.len() == 1 { - return dt[0].borrow().clone(); - } else if dt.len() > 2 { - return List(Box::new(Field::new("item", Utf8, true))); + + let are_all_equal = datatypes.windows(2).all(|w| w[0].borrow() == w[1].borrow()); + + if are_all_equal { + return datatypes[0].borrow().clone(); + } + + let are_all_structs = datatypes.iter().all(|x| matches!(x.borrow(), Struct(_))); + + if are_all_structs { + // all are structs => union of all fields (that may have equal names) + let fields = datatypes.iter().fold(vec![], |mut acc, dt| { + if let Struct(new_fields) = dt.borrow() { + acc.extend(new_fields); + }; + acc + }); + // group fields by unique + let fields = fields.iter().fold( + HashMap::<&String, Vec<&DataType>>::new(), + |mut acc, field| { + match acc.entry(&field.name) { + indexmap::map::Entry::Occupied(mut v) => { + v.get_mut().push(&field.data_type); + } + indexmap::map::Entry::Vacant(v) => { + v.insert(vec![&field.data_type]); + } + } + acc + }, + ); + // and finally, coerce each of the fields within the same name + let fields = fields + .into_iter() + .map(|(name, dts)| Field::new(name, coerce_data_type(&dts), true)) + .collect(); + return Struct(fields); + } else if datatypes.len() > 2 { + return Utf8; } - let (lhs, rhs) = (dt[0].borrow(), dt[1].borrow()); + let (lhs, rhs) = (datatypes[0].borrow(), datatypes[1].borrow()); return match (lhs, rhs) { (lhs, rhs) if lhs == rhs => lhs.clone(), (List(lhs), List(rhs)) => { let inner = coerce_data_type(&[lhs.data_type(), rhs.data_type()]); - List(Box::new(Field::new("item", inner, true))) + List(Box::new(Field::new(ITEM_NAME, inner, true))) } (scalar, List(list)) => { let inner = coerce_data_type(&[scalar, list.data_type()]); - List(Box::new(Field::new("item", inner, true))) + List(Box::new(Field::new(ITEM_NAME, inner, true))) } (List(list), scalar) => { let inner = coerce_data_type(&[scalar, list.data_type()]); - List(Box::new(Field::new("item", inner, true))) + List(Box::new(Field::new(ITEM_NAME, inner, true))) } (Float64, Int64) => Float64, (Int64, Float64) => Float64, @@ -209,21 +250,27 @@ mod test { use crate::datatypes::DataType::*; assert_eq!( - coerce_data_type(&[Float64, List(Box::new(Field::new("item", Float64, true)))]), - List(Box::new(Field::new("item", Float64, true))), + coerce_data_type(&[ + Float64, + List(Box::new(Field::new(ITEM_NAME, Float64, true))) + ]), + List(Box::new(Field::new(ITEM_NAME, Float64, true))), ); assert_eq!( - coerce_data_type(&[Float64, List(Box::new(Field::new("item", Int64, true)))]), - List(Box::new(Field::new("item", Float64, true))), + coerce_data_type(&[Float64, List(Box::new(Field::new(ITEM_NAME, Int64, true)))]), + List(Box::new(Field::new(ITEM_NAME, Float64, true))), ); assert_eq!( - coerce_data_type(&[Int64, List(Box::new(Field::new("item", Int64, true)))]), - List(Box::new(Field::new("item", Int64, true))), + coerce_data_type(&[Int64, List(Box::new(Field::new(ITEM_NAME, Int64, true)))]), + List(Box::new(Field::new(ITEM_NAME, Int64, true))), ); // boolean and number are incompatible, return utf8 assert_eq!( - coerce_data_type(&[Boolean, List(Box::new(Field::new("item", Float64, true)))]), - List(Box::new(Field::new("item", Utf8, true))), + coerce_data_type(&[ + Boolean, + List(Box::new(Field::new(ITEM_NAME, Float64, true))) + ]), + List(Box::new(Field::new(ITEM_NAME, Utf8, true))), ); } } diff --git a/tests/it/io/json/read.rs b/tests/it/io/json/read.rs index db973ee0d53..cab14a58380 100644 --- a/tests/it/io/json/read.rs +++ b/tests/it/io/json/read.rs @@ -139,11 +139,7 @@ fn infer_schema_mixed_list() -> Result<()> { DataType::List(Box::new(Field::new("item", DataType::Boolean, true))), true, ), - Field::new( - "d", - DataType::List(Box::new(Field::new("item", DataType::Utf8, true))), - true, - ), + Field::new("d", DataType::Utf8, true), ]; let result = read::infer(&mut Cursor::new(data), None)?; @@ -151,3 +147,27 @@ fn infer_schema_mixed_list() -> Result<()> { assert_eq!(result, fields); Ok(()) } + +#[test] +fn infer_nested_struct() -> Result<()> { + let data = r#"{"a": {"a": 2.0, "b": 2}} + {"a": {"b": 2}} + {"a": {"a": 2.0, "b": 2, "c": true}} + {"a": {"a": 2.0, "b": 2}} + "#; + + let fields = vec![Field::new( + "a", + DataType::Struct(vec![ + Field::new("a", DataType::Float64, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Boolean, true), + ]), + true, + )]; + + let result = read::infer(&mut Cursor::new(data), None)?; + + assert_eq!(result, fields); + Ok(()) +}