Skip to content

Commit

Permalink
Support read decimal data from csv reader if user provide the schema …
Browse files Browse the repository at this point in the history
…with decimal data type (#941) (#974)

* support decimal data type for csv reader

* format code and fix lint check

* fix the clippy error

* enchance the parse csv to decimal and add more test

Co-authored-by: Kun Liu <[email protected]>
  • Loading branch information
alamb and liukun4515 authored Nov 24, 2021
1 parent 4fa0d4d commit 6c570cf
Show file tree
Hide file tree
Showing 4 changed files with 275 additions and 4 deletions.
4 changes: 2 additions & 2 deletions arrow/src/array/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1118,7 +1118,7 @@ pub struct FixedSizeBinaryBuilder {
builder: FixedSizeListBuilder<UInt8Builder>,
}

const MAX_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
pub const MAX_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
9,
99,
999,
Expand Down Expand Up @@ -1158,7 +1158,7 @@ const MAX_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
9999999999999999999999999999999999999,
170141183460469231731687303715884105727,
];
const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
pub const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
-9,
-99,
-999,
Expand Down
2 changes: 2 additions & 0 deletions arrow/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,8 @@ pub use self::builder::StringBuilder;
pub use self::builder::StringDictionaryBuilder;
pub use self::builder::StructBuilder;
pub use self::builder::UnionBuilder;
pub use self::builder::MAX_DECIMAL_FOR_EACH_PRECISION;
pub use self::builder::MIN_DECIMAL_FOR_EACH_PRECISION;

pub type Int8Builder = PrimitiveBuilder<Int8Type>;
pub type Int16Builder = PrimitiveBuilder<Int16Type>;
Expand Down
263 changes: 261 additions & 2 deletions arrow/src/csv/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,20 @@ use std::io::{Read, Seek, SeekFrom};
use std::sync::Arc;

use crate::array::{
ArrayRef, BooleanArray, DictionaryArray, PrimitiveArray, StringArray,
ArrayRef, BooleanArray, DecimalBuilder, DictionaryArray, PrimitiveArray, StringArray,
MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION,
};
use crate::compute::kernels::cast_utils::string_to_timestamp_nanos;
use crate::datatypes::*;
use crate::error::{ArrowError, Result};
use crate::record_batch::RecordBatch;

use csv_crate::{ByteRecord, StringRecord};
use std::ops::Neg;

lazy_static! {
static ref PARSE_DECIMAL_RE: Regex =
Regex::new(r"^-?(\d+\.?\d*|\d*\.?\d+)$").unwrap();
static ref DECIMAL_RE: Regex = Regex::new(r"^-?(\d*\.\d+|\d+\.\d*)$").unwrap();
static ref INTEGER_RE: Regex = Regex::new(r"^-?(\d+)$").unwrap();
static ref BOOLEAN_RE: Regex = RegexBuilder::new(r"^(true)$|^(false)$")
Expand Down Expand Up @@ -99,7 +103,7 @@ fn infer_field_schema(string: &str) -> DataType {
///
/// If `max_read_records` is not set, the whole file is read to infer its schema.
///
/// Return infered schema and number of records used for inference. This function does not change
/// Return inferred schema and number of records used for inference. This function does not change
/// reader cursor offset.
pub fn infer_file_schema<R: Read + Seek>(
reader: &mut R,
Expand Down Expand Up @@ -513,6 +517,9 @@ fn parse(
let field = &fields[i];
match field.data_type() {
DataType::Boolean => build_boolean_array(line_number, rows, i),
DataType::Decimal(precision, scale) => {
build_decimal_array(line_number, rows, i, *precision, *scale)
}
DataType::Int8 => build_primitive_array::<Int8Type>(line_number, rows, i),
DataType::Int16 => {
build_primitive_array::<Int16Type>(line_number, rows, i)
Expand Down Expand Up @@ -728,6 +735,161 @@ fn parse_bool(string: &str) -> Option<bool> {
}
}

// parse the column string to an Arrow Array
fn build_decimal_array(
_line_number: usize,
rows: &[StringRecord],
col_idx: usize,
precision: usize,
scale: usize,
) -> Result<ArrayRef> {
let mut decimal_builder = DecimalBuilder::new(rows.len(), precision, scale);
for row in rows {
let col_s = row.get(col_idx);
match col_s {
None => {
// No data for this row
decimal_builder.append_null()?;
}
Some(s) => {
if s.is_empty() {
// append null
decimal_builder.append_null()?;
} else {
let decimal_value: Result<i128> =
parse_decimal_with_parameter(s, precision, scale);
match decimal_value {
Ok(v) => {
decimal_builder.append_value(v)?;
}
Err(e) => {
return Err(e);
}
}
}
}
}
}
Ok(Arc::new(decimal_builder.finish()))
}

// Parse the string format decimal value to i128 format and checking the precision and scale.
// The result i128 value can't be out of bounds.
fn parse_decimal_with_parameter(s: &str, precision: usize, scale: usize) -> Result<i128> {
if PARSE_DECIMAL_RE.is_match(s) {
let mut offset = s.len();
let len = s.len();
// each byte is digit、'-' or '.'
let mut base = 1;

// handle the value after the '.' and meet the scale
let delimiter_position = s.find('.');
match delimiter_position {
None => {
// there is no '.'
base = 10_i128.pow(scale as u32);
}
Some(mid) => {
// there is the '.'
if len - mid >= scale + 1 {
// If the string value is "123.12345" and the scale is 2, we should just remain '.12' and drop the '345' value.
offset -= len - mid - 1 - scale;
} else {
// If the string value is "123.12" and the scale is 4, we should append '00' to the tail.
base = 10_i128.pow((scale + 1 + mid - len) as u32);
}
}
};

let bytes = s.as_bytes();
let mut negative = false;
let mut result: i128 = 0;

while offset > 0 {
match bytes[offset - 1] {
b'-' => {
negative = true;
}
b'.' => {
// do nothing
}
b'0'..=b'9' => {
result += i128::from(bytes[offset - 1] - b'0') * base;
base *= 10;
}
_ => {
return Err(ArrowError::ParseError(format!(
"can't match byte {}",
bytes[offset - 1]
)));
}
}
offset -= 1;
}
if negative {
result = result.neg();
}
if result > MAX_DECIMAL_FOR_EACH_PRECISION[precision - 1]
|| result < MIN_DECIMAL_FOR_EACH_PRECISION[precision - 1]
{
return Err(ArrowError::ParseError(format!(
"parse decimal overflow, the precision {}, the scale {}, the value {}",
precision, scale, s
)));
}
Ok(result)
} else {
Err(ArrowError::ParseError(format!(
"can't parse the string value {} to decimal",
s
)))
}
}

// Parse the string format decimal value to i128 format without checking the precision and scale.
// Like "125.12" to 12512_i128.
fn parse_decimal(s: &str) -> Result<i128> {
if PARSE_DECIMAL_RE.is_match(s) {
let mut offset = s.len();
// each byte is digit、'-' or '.'
let bytes = s.as_bytes();
let mut negative = false;
let mut result: i128 = 0;
let mut base = 1;
while offset > 0 {
match bytes[offset - 1] {
b'-' => {
negative = true;
}
b'.' => {
// do nothing
}
b'0'..=b'9' => {
result += i128::from(bytes[offset - 1] - b'0') * base;
base *= 10;
}
_ => {
return Err(ArrowError::ParseError(format!(
"can't match byte {}",
bytes[offset - 1]
)));
}
}
offset -= 1;
}
if negative {
Ok(result.neg())
} else {
Ok(result)
}
} else {
Err(ArrowError::ParseError(format!(
"can't parse the string value {} to decimal",
s
)))
}
}

// parses a specific column (col_idx) into an Arrow Array.
fn build_primitive_array<T: ArrowPrimitiveType + Parser>(
line_number: usize,
Expand Down Expand Up @@ -1055,6 +1217,37 @@ mod tests {
assert_eq!(&metadata, batch.schema().metadata());
}

#[test]
fn test_csv_reader_with_decimal() {
let schema = Schema::new(vec![
Field::new("city", DataType::Utf8, false),
Field::new("lat", DataType::Decimal(26, 6), false),
Field::new("lng", DataType::Decimal(26, 6), false),
]);

let file = File::open("test/data/decimal_test.csv").unwrap();

let mut csv = Reader::new(file, Arc::new(schema), false, None, 1024, None, None);
let batch = csv.next().unwrap().unwrap();
// access data from a primitive array
let lat = batch
.column(1)
.as_any()
.downcast_ref::<DecimalArray>()
.unwrap();

assert_eq!("57.653484", lat.value_as_string(0));
assert_eq!("53.002666", lat.value_as_string(1));
assert_eq!("52.412811", lat.value_as_string(2));
assert_eq!("51.481583", lat.value_as_string(3));
assert_eq!("12.123456", lat.value_as_string(4));
assert_eq!("50.760000", lat.value_as_string(5));
assert_eq!("0.123000", lat.value_as_string(6));
assert_eq!("123.000000", lat.value_as_string(7));
assert_eq!("123.000000", lat.value_as_string(8));
assert_eq!("-50.760000", lat.value_as_string(9));
}

#[test]
fn test_csv_from_buf_reader() {
let schema = Schema::new(vec![
Expand Down Expand Up @@ -1348,6 +1541,8 @@ mod tests {
assert_eq!(infer_field_schema("false"), DataType::Boolean);
assert_eq!(infer_field_schema("2020-11-08"), DataType::Date32);
assert_eq!(infer_field_schema("2020-11-08T14:20:01"), DataType::Date64);
assert_eq!(infer_field_schema("-5.13"), DataType::Float64);
assert_eq!(infer_field_schema("0.1300"), DataType::Float64);
}

#[test]
Expand All @@ -1374,6 +1569,70 @@ mod tests {
);
}

#[test]
fn test_parse_decimal() {
let tests = [
("123.00", 12300i128),
("123.123", 123123i128),
("0.0123", 123i128),
("0.12300", 12300i128),
("-5.123", -5123i128),
("-45.432432", -45432432i128),
];
for (s, i) in tests {
let result = parse_decimal(s);
assert_eq!(i, result.unwrap());
}
}

#[test]
fn test_parse_decimal_with_parameter() {
let tests = [
("123.123", 123123i128),
("123.1234", 123123i128),
("123.1", 123100i128),
("123", 123000i128),
("-123.123", -123123i128),
("-123.1234", -123123i128),
("-123.1", -123100i128),
("-123", -123000i128),
("0.0000123", 0i128),
("12.", 12000i128),
("-12.", -12000i128),
("00.1", 100i128),
("-00.1", -100i128),
("12345678912345678.1234", 12345678912345678123i128),
("-12345678912345678.1234", -12345678912345678123i128),
("99999999999999999.999", 99999999999999999999i128),
("-99999999999999999.999", -99999999999999999999i128),
(".123", 123i128),
("-.123", -123i128),
("123.", 123000i128),
("-123.", -123000i128),
];
for (s, i) in tests {
let result = parse_decimal_with_parameter(s, 20, 3);
assert_eq!(i, result.unwrap())
}
let can_not_parse_tests = ["123,123", "."];
for s in can_not_parse_tests {
let result = parse_decimal_with_parameter(s, 20, 3);
assert_eq!(
format!(
"Parser error: can't parse the string value {} to decimal",
s
),
result.unwrap_err().to_string()
);
}
let overflow_parse_tests = ["12345678", "12345678.9", "99999999.99"];
for s in overflow_parse_tests {
let result = parse_decimal_with_parameter(s, 10, 3);
assert_eq!(format!(
"Parser error: parse decimal overflow, the precision {}, the scale {}, the value {}", 10,3, s),result.unwrap_err().to_string());
}
}

/// Interprets a naive_datetime (with no explicit timezone offset)
/// using the local timezone and returns the timestamp in UTC (0
/// offset)
Expand Down
10 changes: 10 additions & 0 deletions arrow/test/data/decimal_test.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"Elgin, Scotland, the UK",57.653484,-3.335724
"Stoke-on-Trent, Staffordshire, the UK",53.002666,-2.179404
"Solihull, Birmingham, UK",52.412811,-1.778197
"Cardiff, Cardiff county, UK",51.481583,-3.179090
"Cardiff, Cardiff county, UK",12.12345678,-3.179090
"Eastbourne, East Sussex, UK",50.76,0.290472
"Eastbourne, East Sussex, UK",.123,0.290472
"Eastbourne, East Sussex, UK",123.,0.290472
"Eastbourne, East Sussex, UK",123,0.290472
"Eastbourne, East Sussex, UK",-50.76,0.290472

0 comments on commit 6c570cf

Please sign in to comment.