This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Improved inference and deserialization of CSV #483

Merged: 1 commit, Oct 2, 2021.

This PR reworks CSV type inference to operate on raw bytes (via `lexical-core` and `chrono`) instead of regexes, adds inference and deserialization of RFC 3339 timestamps with UTC offsets, generalizes `Timestamp` deserialization over all time units, and drops the `lazy_static` and `regex` dependencies from `io_csv_read`.
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -102,7 +102,7 @@ full = [
]
merge_sort = ["itertools"]
io_csv = ["io_csv_read", "io_csv_write"]
io_csv_read = ["csv", "lazy_static", "regex", "lexical-core"]
io_csv_read = ["csv", "lexical-core"]
io_csv_write = ["csv", "streaming-iterator", "lexical-core"]
io_json = ["serde", "serde_json", "indexmap"]
io_ipc = ["flatbuffers"]
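For context: with this change, a downstream crate that enables only the CSV-read path no longer pulls in `regex` or `lazy_static` transitively. A minimal sketch of such a manifest (the crate version is illustrative):

```toml
# Hypothetical downstream Cargo.toml enabling only the CSV reader;
# after this PR it brings in csv and lexical-core, but not regex/lazy_static.
[dependencies]
arrow2 = { version = "0.6", features = ["io_csv_read"] }
```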
54 changes: 41 additions & 13 deletions src/io/csv/read/deserialize.rs
@@ -8,10 +8,12 @@ use crate::{
datatypes::*,
error::{ArrowError, Result},
record_batch::RecordBatch,
temporal_conversions::EPOCH_DAYS_FROM_CE,
temporal_conversions,
types::{NativeType, NaturalDataType},
};

use super::infer_schema::RFC3339;

fn deserialize_primitive<T, F>(
rows: &[ByteRecord],
column: usize,
@@ -63,6 +65,22 @@ fn deserialize_binary<O: Offset>(rows: &[ByteRecord], column: usize) -> Arc<dyn
Arc::new(BinaryArray::<O>::from_trusted_len_iter(iter))
}

#[inline]
fn deserialize_datetime<T: chrono::TimeZone>(string: &str, tz: &T) -> Option<i64> {
let mut parsed = chrono::format::Parsed::new();
let fmt = chrono::format::StrftimeItems::new(RFC3339);
if chrono::format::parse(&mut parsed, string, fmt).is_ok() {
parsed
.to_datetime()
.map(|x| x.naive_utc())
.map(|x| tz.from_utc_datetime(&x))
.map(|x| x.timestamp_nanos())
.ok()
} else {
None
}
}

/// Deserializes `column` of `rows` into an [`Array`] of [`DataType`] `datatype`.
pub fn deserialize_column(
rows: &[ByteRecord],
@@ -115,7 +133,7 @@ pub fn deserialize_column(
simdutf8::basic::from_utf8(bytes)
.ok()
.and_then(|x| x.parse::<chrono::NaiveDate>().ok())
.map(|x| x.num_days_from_ce() - EPOCH_DAYS_FROM_CE)
.map(|x| x.num_days_from_ce() - temporal_conversions::EPOCH_DAYS_FROM_CE)
}),
Date64 => deserialize_primitive(rows, column, datatype, |bytes| {
simdutf8::basic::from_utf8(bytes)
@@ -139,20 +157,30 @@
.map(|x| x.timestamp_nanos() / 1000)
})
}
Timestamp(TimeUnit::Millisecond, None) => {
deserialize_primitive(rows, column, datatype, |bytes| {
simdutf8::basic::from_utf8(bytes)
.ok()
.and_then(|x| x.parse::<chrono::NaiveDateTime>().ok())
.map(|x| x.timestamp_nanos() / 1_000_000)
})
}
Timestamp(TimeUnit::Second, None) => {
Timestamp(time_unit, None) => deserialize_primitive(rows, column, datatype, |bytes| {
simdutf8::basic::from_utf8(bytes)
.ok()
.and_then(|x| x.parse::<chrono::NaiveDateTime>().ok())
.map(|x| x.timestamp_nanos())
.map(|x| match time_unit {
TimeUnit::Second => x / 1_000_000_000,
TimeUnit::Millisecond => x / 1_000_000,
TimeUnit::Microsecond => x / 1_000,
TimeUnit::Nanosecond => x,
})
}),
Timestamp(time_unit, Some(ref tz)) => {
let tz = temporal_conversions::parse_offset(tz)?;
deserialize_primitive(rows, column, datatype, |bytes| {
simdutf8::basic::from_utf8(bytes)
.ok()
.and_then(|x| x.parse::<chrono::NaiveDateTime>().ok())
.map(|x| x.timestamp_nanos() / 1_000_000_000)
.and_then(|x| deserialize_datetime(x, &tz))
.map(|x| match time_unit {
TimeUnit::Second => x / 1_000_000_000,
TimeUnit::Millisecond => x / 1_000_000,
TimeUnit::Microsecond => x / 1_000,
TimeUnit::Nanosecond => x,
})
})
}
Utf8 => deserialize_utf8::<i32>(rows, column),
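To make the new timezone-aware path concrete, here is a self-contained sketch of the parsing logic `deserialize_datetime` introduces, using the chrono 0.4 API of this era; the `main` driver and its sample values (taken from the new test below) are illustrative:

```rust
// RFC 3339 strftime pattern, as defined in infer_schema.rs in this PR.
const RFC3339: &str = "%Y-%m-%dT%H:%M:%S%.f%:z";

/// Parses an RFC 3339 string into nanoseconds since the epoch, interpreted
/// in timezone `tz` (mirrors the PR's `deserialize_datetime`).
fn deserialize_datetime<T: chrono::TimeZone>(string: &str, tz: &T) -> Option<i64> {
    let mut parsed = chrono::format::Parsed::new();
    let fmt = chrono::format::StrftimeItems::new(RFC3339);
    chrono::format::parse(&mut parsed, string, fmt).ok()?;
    parsed
        .to_datetime()                     // DateTime<FixedOffset> from the parsed fields
        .map(|x| x.naive_utc())            // normalize to naive UTC
        .map(|x| tz.from_utc_datetime(&x)) // re-attach the target timezone
        .map(|x| x.timestamp_nanos())
        .ok()
}

fn main() {
    // "-01:00", i.e. 3600 seconds west of UTC, matching the tz in the new test.
    let tz = chrono::FixedOffset::west(3600);
    let nanos = deserialize_datetime("1996-12-19T16:34:57-02:00", &tz).unwrap();
    // 851020497 seconds since the epoch, as the new test below asserts (in ms).
    assert_eq!(nanos / 1_000_000_000, 851020497);
}
```

The stored timestamp is an absolute instant, so the column's timezone affects display rather than the value; that is why the new test can pass a `-01:00` column type over `-02:00` input and still expect the same epoch milliseconds.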
134 changes: 96 additions & 38 deletions src/io/csv/read/infer_schema.rs
@@ -3,20 +3,82 @@ use std::{
io::{Read, Seek},
};

use super::Reader;
use csv::StringRecord;
use super::{ByteRecord, Reader};

use crate::datatypes::DataType;
use crate::datatypes::{DataType, TimeUnit};
use crate::datatypes::{Field, Schema};
use crate::error::Result;

/// Infer the schema of a CSV file by reading through the first n records of the file,
/// with `max_rows` controlling the maximum number of records to read.
///
/// If `max_rows` is not set, the whole file is read to infer its schema.
pub(super) const RFC3339: &str = "%Y-%m-%dT%H:%M:%S%.f%:z";

fn is_boolean(bytes: &[u8]) -> bool {
bytes.eq_ignore_ascii_case(b"true") | bytes.eq_ignore_ascii_case(b"false")
}

fn is_float(bytes: &[u8]) -> bool {
lexical_core::parse::<f64>(bytes).is_ok()
}

fn is_integer(bytes: &[u8]) -> bool {
lexical_core::parse::<i64>(bytes).is_ok()
}

fn is_date(string: &str) -> bool {
string.parse::<chrono::NaiveDate>().is_ok()
}

fn is_time(string: &str) -> bool {
string.parse::<chrono::NaiveTime>().is_ok()
}

fn is_naive_datetime(string: &str) -> bool {
string.parse::<chrono::NaiveDateTime>().is_ok()
}

fn is_datetime(string: &str) -> Option<String> {
let mut parsed = chrono::format::Parsed::new();
let fmt = chrono::format::StrftimeItems::new(RFC3339);
if chrono::format::parse(&mut parsed, string, fmt).is_ok() {
parsed.offset.map(|x| {
let hours = x / 60 / 60;
let minutes = x / 60 - hours * 60;
format!("{:03}:{:02}", hours, minutes)
})
} else {
None
}
}

/// Infers [`DataType`] from `bytes`
pub fn infer(bytes: &[u8]) -> DataType {
if is_boolean(bytes) {
DataType::Boolean
} else if is_integer(bytes) {
DataType::Int64
} else if is_float(bytes) {
DataType::Float64
} else if let Ok(string) = simdutf8::basic::from_utf8(bytes) {
if is_date(string) {
DataType::Date32
} else if is_time(string) {
DataType::Time32(TimeUnit::Millisecond)
} else if is_naive_datetime(string) {
DataType::Timestamp(TimeUnit::Millisecond, None)
} else if let Some(offset) = is_datetime(string) {
DataType::Timestamp(TimeUnit::Millisecond, Some(offset))
} else {
DataType::Utf8
}
} else {
// invalid utf8
DataType::Binary
}
}

/// Infer the schema of a CSV file by reading through the first n records up to `max_rows`.
///
/// Returns the inferred schema and the number of records used for inference.
pub fn infer_schema<R: Read + Seek, F: Fn(&str) -> DataType>(
pub fn infer_schema<R: Read + Seek, F: Fn(&[u8]) -> DataType>(
reader: &mut Reader<R>,
max_rows: Option<usize>,
has_header: bool,
@@ -25,8 +87,7 @@ pub fn infer_schema<R: Read + Seek, F: Fn(&str) -> DataType>(
// get or create header names
// when has_header is false, creates default column names with column_ prefix
let headers: Vec<String> = if has_header {
let headers = &reader.headers()?.clone();
headers.iter().map(|s| s.to_string()).collect()
reader.headers()?.iter().map(|s| s.to_string()).collect()
} else {
let first_record_count = &reader.headers()?.len();
(0..*first_record_count)
@@ -42,12 +103,11 @@ pub fn infer_schema<R: Read + Seek, F: Fn(&str) -> DataType>(
let mut column_types: Vec<HashSet<DataType>> = vec![HashSet::new(); header_length];

let mut records_count = 0;
let mut fields = vec![];

let mut record = StringRecord::new();
let mut record = ByteRecord::new();
let max_records = max_rows.unwrap_or(usize::MAX);
while records_count < max_records {
if !reader.read_record(&mut record)? {
if !reader.read_byte_record(&mut record)? {
break;
}
records_count += 1;
@@ -60,32 +120,30 @@ pub fn infer_schema<R: Read + Seek, F: Fn(&str) -> DataType>(
}

// build schema from inference results
for i in 0..header_length {
let possibilities = &column_types[i];
let field_name = &headers[i];

// determine data type based on possible types
// if there are incompatible types, use DataType::Utf8
match possibilities.len() {
1 => {
for dtype in possibilities.iter() {
fields.push(Field::new(field_name, dtype.clone(), true));
}
}
2 => {
if possibilities.contains(&DataType::Int64)
&& possibilities.contains(&DataType::Float64)
{
// we have an integer and double, fall down to double
fields.push(Field::new(field_name, DataType::Float64, true));
} else {
// default to Utf8 for conflicting datatypes (e.g bool and int)
fields.push(Field::new(field_name, DataType::Utf8, true));
let fields = headers
.iter()
.zip(column_types.into_iter())
.map(|(field_name, mut possibilities)| {
// determine data type based on possible types
// if there are incompatible types, use DataType::Utf8
let data_type = match possibilities.len() {
1 => possibilities.drain().next().unwrap(),
2 => {
if possibilities.contains(&DataType::Int64)
&& possibilities.contains(&DataType::Float64)
{
// we have an integer and double, fall down to double
DataType::Float64
} else {
// default to Utf8 for conflicting datatypes (e.g bool and int)
DataType::Utf8
}
}
}
_ => fields.push(Field::new(field_name, DataType::Utf8, true)),
}
}
_ => DataType::Utf8,
};
Field::new(field_name, data_type, true)
})
.collect();

// return the reader seek back to the start
reader.seek(position)?;
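The offset arithmetic in `is_datetime` is the subtle part: chrono reports the parsed offset in seconds east of UTC, and the function re-renders it as the `±HH:MM` string that `temporal_conversions::parse_offset` later consumes in deserialize.rs. A standalone sketch of that round trip (chrono 0.4-era field access, as in the diff; the `main` driver is illustrative):

```rust
use chrono::format::{parse, Parsed, StrftimeItems};

// Same pattern the PR stores as `RFC3339`.
const RFC3339: &str = "%Y-%m-%dT%H:%M:%S%.f%:z";

/// Mirrors the PR's `is_datetime`: Some(formatted offset) iff `string`
/// is an RFC 3339 datetime with an explicit UTC offset.
fn is_datetime(string: &str) -> Option<String> {
    let mut parsed = Parsed::new();
    if parse(&mut parsed, string, StrftimeItems::new(RFC3339)).is_ok() {
        // `parsed.offset` is seconds east of UTC (e.g. -7200 for "-02:00").
        parsed.offset.map(|x| {
            let hours = x / 60 / 60;
            let minutes = x / 60 - hours * 60;
            format!("{:+03}:{:02}", hours, minutes)
        })
    } else {
        None
    }
}

fn main() {
    assert_eq!(
        is_datetime("1996-12-19T16:34:57-02:00").as_deref(),
        Some("-02:00")
    );
    // Without material for `%:z` to consume, parsing fails and `infer`
    // falls through to `is_naive_datetime`, i.e. Timestamp(_, None).
    assert_eq!(is_datetime("1996-12-19T16:34:57"), None);
}
```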
2 changes: 1 addition & 1 deletion src/io/csv/read/mod.rs
@@ -8,5 +8,5 @@ pub use csv::{ByteRecord, Reader, ReaderBuilder};
mod infer_schema;

pub use deserialize::{deserialize_batch, deserialize_column};
pub use infer_schema::infer_schema;
pub use infer_schema::{infer, infer_schema};
pub use reader::*;
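Taken together, the re-exports give the full byte-based pipeline: `read_rows` fills reusable `ByteRecord`s and `deserialize_column` turns one column into an array. A usage sketch against an in-memory CSV; the trailing line-number argument of `deserialize_column` is assumed from the crate source of this era:

```rust
use arrow2::array::Array;
use arrow2::datatypes::{DataType, TimeUnit};
use arrow2::error::Result;
use arrow2::io::csv::read;

fn main() -> Result<()> {
    // One timezone-aware timestamp column, with a header row.
    let data = "ts\n1996-12-19T16:34:57-02:00\n1996-12-19T16:34:58-02:00\n";
    let mut reader = read::ReaderBuilder::new().from_reader(std::io::Cursor::new(data));

    // Read up to 100 records into preallocated byte records.
    let mut rows = vec![read::ByteRecord::default(); 100];
    let rows_read = read::read_rows(&mut reader, 0, &mut rows)?;

    // Exercise the new `Timestamp(_, Some(tz))` arm added in deserialize.rs.
    let data_type = DataType::Timestamp(TimeUnit::Millisecond, Some("-01:00".to_string()));
    let array = read::deserialize_column(&rows[..rows_read], 0, data_type, 0)?;
    assert_eq!(array.len(), 2);
    Ok(())
}
```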
37 changes: 0 additions & 37 deletions src/io/csv/read/reader.rs
@@ -1,8 +1,5 @@
use std::io::Read;

use lazy_static::lazy_static;
use regex::{Regex, RegexBuilder};

use super::{ByteRecord, Reader};

use crate::{
@@ -52,37 +49,3 @@ pub fn read_rows<R: Read>(
}
Ok(row_number)
}

lazy_static! {
static ref DECIMAL_RE: Regex = Regex::new(r"^-?(\d+\.\d+)$").unwrap();
static ref INTEGER_RE: Regex = Regex::new(r"^-?(\d+)$").unwrap();
static ref BOOLEAN_RE: Regex = RegexBuilder::new(r"^(true)$|^(false)$")
.case_insensitive(true)
.build()
.unwrap();
static ref DATE_RE: Regex = Regex::new(r"^\d{4}-\d\d-\d\d$").unwrap();
static ref DATETIME_RE: Regex = Regex::new(r"^\d{4}-\d\d-\d\dT\d\d:\d\d:\d\d$").unwrap();
}

/// Infer the data type of a record
pub fn infer(string: &str) -> DataType {
// when quoting is enabled in the reader, these quotes aren't escaped, we default to
// Utf8 for them
if string.starts_with('"') {
return DataType::Utf8;
}
// match regex in a particular order
if BOOLEAN_RE.is_match(string) {
DataType::Boolean
} else if DECIMAL_RE.is_match(string) {
DataType::Float64
} else if INTEGER_RE.is_match(string) {
DataType::Int64
} else if DATETIME_RE.is_match(string) {
DataType::Date64
} else if DATE_RE.is_match(string) {
DataType::Date32
} else {
DataType::Utf8
}
}
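For comparison with the regex-based inferer deleted above, a quick driver over the new `infer` (re-exported from `infer_schema`); the specific inputs are illustrative:

```rust
use arrow2::datatypes::{DataType, TimeUnit};
use arrow2::io::csv::read::infer;

fn main() {
    // Cases the old regexes also classified:
    assert_eq!(infer(b"TRUE"), DataType::Boolean);
    assert_eq!(infer(b"-7"), DataType::Int64);
    assert_eq!(infer(b"3.14"), DataType::Float64);
    assert_eq!(infer(b"2021-10-02"), DataType::Date32);
    // New: lexical-core accepts scientific notation the old DECIMAL_RE missed,
    assert_eq!(infer(b"1e-5"), DataType::Float64);
    // naive datetimes become Timestamp instead of Date64,
    assert_eq!(
        infer(b"2021-10-02T13:00:00"),
        DataType::Timestamp(TimeUnit::Millisecond, None)
    );
    // and an explicit offset is preserved as the timezone string.
    assert_eq!(
        infer(b"2021-10-02T13:00:00+02:00"),
        DataType::Timestamp(TimeUnit::Millisecond, Some("+02:00".to_string()))
    );
}
```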
42 changes: 41 additions & 1 deletion tests/it/io/csv/read.rs
@@ -1,3 +1,5 @@
use proptest::prelude::*;

use std::io::Cursor;
use std::sync::Arc;

@@ -175,7 +177,7 @@ fn float32() -> Result<()> {
}

#[test]
fn binary() -> Result<()> {
fn deserialize_binary() -> Result<()> {
let input = vec!["aa", "bb"];
let input = input.join("\n");

@@ -185,3 +187,41 @@ fn binary() -> Result<()> {
assert_eq!(expected, result.as_ref());
Ok(())
}

#[test]
fn deserialize_timestamp() -> Result<()> {
let input = vec!["1996-12-19T16:34:57-02:00", "1996-12-19T16:34:58-02:00"];
let input = input.join("\n");

let data_type = DataType::Timestamp(TimeUnit::Millisecond, Some("-01:00".to_string()));

let expected = Int64Array::from([Some(851020497000), Some(851020498000)]).to(data_type.clone());

let result = test_deserialize(&input, data_type)?;
assert_eq!(expected, result.as_ref());
Ok(())
}

proptest! {
#[test]
#[cfg_attr(miri, ignore)] // miri and proptest do not work well :(
fn i64(v in any::<i64>()) {
assert_eq!(infer(v.to_string().as_bytes()), DataType::Int64);
}
}

proptest! {
#[test]
#[cfg_attr(miri, ignore)] // miri and proptest do not work well :(
fn utf8(v in "a.*") {
assert_eq!(infer(v.as_bytes()), DataType::Utf8);
}
}

proptest! {
#[test]
#[cfg_attr(miri, ignore)] // miri and proptest do not work well :(
fn dates(v in "1996-12-19T16:3[0-9]:57-02:00") {
assert_eq!(infer(v.as_bytes()), DataType::Timestamp(TimeUnit::Millisecond, Some("-02:00".to_string())));
}
}