Skip to content

Commit

Permalink
support decimal data type in create table (apache#1431)
Browse files Browse the repository at this point in the history
* support decimal data type in create table
  • Loading branch information
liukun4515 authored Dec 11, 2021
1 parent cb37855 commit dc80c11
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 3 deletions.
47 changes: 44 additions & 3 deletions datafusion/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,27 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
SQLDataType::Char(_) | SQLDataType::Varchar(_) | SQLDataType::Text => {
Ok(DataType::Utf8)
}
SQLDataType::Decimal(_, _) => Ok(DataType::Float64),
SQLDataType::Decimal(precision, scale) => {
match (precision, scale) {
(None, _) | (_, None) => {
return Err(DataFusionError::Internal(format!(
"Invalid Decimal type ({:?}), precision or scale can't be empty.",
sql_type
)));
}
(Some(p), Some(s)) => {
// TODO add bound checker in some utils file or function
if *p > 38 || *s > *p {
return Err(DataFusionError::Internal(format!(
"Error Decimal Type ({:?}), precision must be less than or equal to 38 and scale can't be greater than precision",
sql_type
)));
} else {
Ok(DataType::Decimal(*p as usize, *s as usize))
}
}
}
}
SQLDataType::Float(_) => Ok(DataType::Float32),
SQLDataType::Real => Ok(DataType::Float32),
SQLDataType::Double => Ok(DataType::Float64),
Expand Down Expand Up @@ -2022,8 +2042,8 @@ fn extract_possible_join_keys(
}

/// Convert SQL data type to relational representation of data type
pub fn convert_data_type(sql: &SQLDataType) -> Result<DataType> {
match sql {
pub fn convert_data_type(sql_type: &SQLDataType) -> Result<DataType> {
match sql_type {
SQLDataType::Boolean => Ok(DataType::Boolean),
SQLDataType::SmallInt(_) => Ok(DataType::Int16),
SQLDataType::Int(_) => Ok(DataType::Int32),
Expand All @@ -2034,6 +2054,27 @@ pub fn convert_data_type(sql: &SQLDataType) -> Result<DataType> {
SQLDataType::Char(_) | SQLDataType::Varchar(_) => Ok(DataType::Utf8),
SQLDataType::Timestamp => Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)),
SQLDataType::Date => Ok(DataType::Date32),
SQLDataType::Decimal(precision, scale) => {
match (precision, scale) {
(None, _) | (_, None) => {
return Err(DataFusionError::Internal(format!(
"Invalid Decimal type ({:?}), precision or scale can't be empty.",
sql_type
)));
}
(Some(p), Some(s)) => {
// TODO add bound checker in some utils file or function
if *p > 38 || *s > *p {
return Err(DataFusionError::Internal(format!(
"Error Decimal Type ({:?})",
sql_type
)));
} else {
Ok(DataType::Decimal(*p as usize, *s as usize))
}
}
}
}
other => Err(DataFusionError::NotImplemented(format!(
"Unsupported SQL type {:?}",
other
Expand Down
53 changes: 53 additions & 0 deletions datafusion/tests/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3761,6 +3761,28 @@ async fn register_aggregate_csv(ctx: &mut ExecutionContext) -> Result<()> {
Ok(())
}

async fn register_simple_aggregate_csv_with_decimal_by_sql(ctx: &mut ExecutionContext) {
let df = ctx
.sql(
"CREATE EXTERNAL TABLE aggregate_simple (
c1 DECIMAL(10,6) NOT NULL,
c2 DOUBLE NOT NULL,
c3 BOOLEAN NOT NULL
)
STORED AS CSV
WITH HEADER ROW
LOCATION 'tests/aggregate_simple.csv'",
)
.await
.expect("Creating dataframe for CREATE EXTERNAL TABLE with decimal data type");

let results = df.collect().await.expect("Executing CREATE EXTERNAL TABLE");
assert!(
results.is_empty(),
"Expected no rows from executing CREATE EXTERNAL TABLE"
);
}

async fn register_aggregate_simple_csv(ctx: &mut ExecutionContext) -> Result<()> {
// It's not possible to use aggregate_test_100, not enought similar values to test grouping on floats
let schema = Arc::new(Schema::new(vec![
Expand Down Expand Up @@ -6459,3 +6481,34 @@ async fn test_select_wildcard_without_table() -> Result<()> {
}
Ok(())
}

#[tokio::test]
async fn csv_query_with_decimal_by_sql() -> Result<()> {
let mut ctx = ExecutionContext::new();
register_simple_aggregate_csv_with_decimal_by_sql(&mut ctx).await;
let sql = "SELECT c1 from aggregate_simple";
let actual = execute_to_batches(&mut ctx, sql).await;
let expected = vec![
"+----------+",
"| c1 |",
"+----------+",
"| 0.000010 |",
"| 0.000020 |",
"| 0.000020 |",
"| 0.000030 |",
"| 0.000030 |",
"| 0.000030 |",
"| 0.000040 |",
"| 0.000040 |",
"| 0.000040 |",
"| 0.000040 |",
"| 0.000050 |",
"| 0.000050 |",
"| 0.000050 |",
"| 0.000050 |",
"| 0.000050 |",
"+----------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
}

0 comments on commit dc80c11

Please sign in to comment.