diff --git a/datafusion/src/physical_plan/datetime_expressions.rs b/datafusion/src/physical_plan/datetime_expressions.rs index 59a93481cda3..7e965f6b6c56 100644 --- a/datafusion/src/physical_plan/datetime_expressions.rs +++ b/datafusion/src/physical_plan/datetime_expressions.rs @@ -28,6 +28,7 @@ use arrow::{ array::*, compute::cast, datatypes::{DataType, TimeUnit}, + scalar::PrimitiveScalar, types::NativeType, }; use arrow::{compute::temporal, temporal_conversions::timestamp_ns_to_datetime}; @@ -35,6 +36,7 @@ use chrono::prelude::{DateTime, Utc}; use chrono::Datelike; use chrono::Duration; use chrono::Timelike; +use std::convert::TryInto; /// given a function `op` that maps a `&str` to a Result of an arrow native type, /// returns a `PrimitiveArray` after the application @@ -81,7 +83,7 @@ where // given an function that maps a `&str` to a arrow native type, // returns a `ColumnarValue` where the function is applied to either a `ArrayRef` or `ScalarValue` // depending on the `args`'s variant. -fn handle<'a, O, F, S>( +fn handle<'a, O, F>( args: &'a [ColumnarValue], op: F, name: &str, @@ -90,7 +92,6 @@ fn handle<'a, O, F, S>( where O: NativeType, ScalarValue: From>, - S: NativeType, F: Fn(&'a str) -> Result, { match &args[0] { @@ -117,14 +118,13 @@ where ))), }, ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(a) => { - let result = a.as_ref().map(|x| (op)(x)).transpose()?; - Ok(ColumnarValue::Scalar(result.into())) - } - ScalarValue::LargeUtf8(a) => { - let result = a.as_ref().map(|x| (op)(x)).transpose()?; - Ok(ColumnarValue::Scalar(result.into())) - } + ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => Ok(match a { + Some(s) => { + let s = PrimitiveScalar::::new(data_type, Some((op)(s)?)); + ColumnarValue::Scalar(s.try_into()?) + } + None => ColumnarValue::Scalar(ScalarValue::new_null(data_type)), + }), other => Err(DataFusionError::Internal(format!( "Unsupported data type {:?} for function {}", other, name @@ -140,7 +140,7 @@ fn string_to_timestamp_nanos_shim(s: &str) -> Result { /// to_timestamp SQL function pub fn to_timestamp(args: &[ColumnarValue]) -> Result { - handle::( + handle::( args, string_to_timestamp_nanos_shim, "to_timestamp", @@ -150,7 +150,7 @@ pub fn to_timestamp(args: &[ColumnarValue]) -> Result { /// to_timestamp_millis SQL function pub fn to_timestamp_millis(args: &[ColumnarValue]) -> Result { - handle::( + handle::( args, |s| string_to_timestamp_nanos_shim(s).map(|n| n / 1_000_000), "to_timestamp_millis", @@ -160,7 +160,7 @@ pub fn to_timestamp_millis(args: &[ColumnarValue]) -> Result { /// to_timestamp_micros SQL function pub fn to_timestamp_micros(args: &[ColumnarValue]) -> Result { - handle::( + handle::( args, |s| string_to_timestamp_nanos_shim(s).map(|n| n / 1_000), "to_timestamp_micros", @@ -170,7 +170,7 @@ pub fn to_timestamp_micros(args: &[ColumnarValue]) -> Result { /// to_timestamp_seconds SQL function pub fn to_timestamp_seconds(args: &[ColumnarValue]) -> Result { - handle::( + handle::( args, |s| string_to_timestamp_nanos_shim(s).map(|n| n / 1_000_000_000), "to_timestamp_seconds", diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index c9fe567253bf..52ce6d3ad311 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -390,7 +390,7 @@ pub enum Distribution { } /// Represents the result from an expression -#[derive(Clone)] +#[derive(Clone, Debug)] pub enum ColumnarValue { /// Array of values Array(ArrayRef), diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index bbe951fa53af..84af5528b825 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -20,12 +20,12 @@ use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; use crate::error::{DataFusionError, Result}; -use arrow::scalar::Scalar; use arrow::{ array::*, buffer::MutableBuffer, datatypes::{DataType, Field, IntervalUnit, TimeUnit}, - types::days_ms, + scalar::{PrimitiveScalar, Scalar}, + types::{days_ms, NativeType}, }; use ordered_float::OrderedFloat; use std::cmp::Ordering; @@ -421,6 +421,25 @@ macro_rules! eq_array_primitive { } impl ScalarValue { + /// Create null scalar value for specific data type. + pub fn new_null(dt: DataType) -> Self { + match dt { + DataType::Timestamp(TimeUnit::Second, _) => { + ScalarValue::TimestampSecond(None) + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + ScalarValue::TimestampMillisecond(None) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + ScalarValue::TimestampMicrosecond(None) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + ScalarValue::TimestampNanosecond(None) + } + _ => todo!("Create null scalar value for datatype: {:?}", dt), + } + } + /// Getter for the `DataType` of the value pub fn get_datatype(&self) -> DataType { match self { @@ -1272,6 +1291,35 @@ impl TryInto> for &ScalarValue { } } +impl TryFrom> for ScalarValue { + type Error = DataFusionError; + + fn try_from(s: PrimitiveScalar) -> Result { + match s.data_type() { + DataType::Timestamp(TimeUnit::Second, _) => { + let s = s.as_any().downcast_ref::>().unwrap(); + Ok(ScalarValue::TimestampSecond(Some(s.value()))) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + let s = s.as_any().downcast_ref::>().unwrap(); + Ok(ScalarValue::TimestampMicrosecond(Some(s.value()))) + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + let s = s.as_any().downcast_ref::>().unwrap(); + Ok(ScalarValue::TimestampMillisecond(Some(s.value()))) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + let s = s.as_any().downcast_ref::>().unwrap(); + Ok(ScalarValue::TimestampNanosecond(Some(s.value()))) + } + _ => Err(DataFusionError::Internal( + format!( + "Conversion from arrow Scalar to Datafusion ScalarValue not implemented for: {:?}", s)) + ), + } + } +} + impl TryFrom<&DataType> for ScalarValue { type Error = DataFusionError;