Skip to content

Commit

Permalink
fix str to timestamp scalarvalue casting
Browse files Browse the repository at this point in the history
  • Loading branch information
houqp committed Sep 25, 2021
1 parent 4030615 commit fde82cf
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 17 deletions.
28 changes: 14 additions & 14 deletions datafusion/src/physical_plan/datetime_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@ use arrow::{
array::*,
compute::cast,
datatypes::{DataType, TimeUnit},
scalar::PrimitiveScalar,
types::NativeType,
};
use arrow::{compute::temporal, temporal_conversions::timestamp_ns_to_datetime};
use chrono::prelude::{DateTime, Utc};
use chrono::Datelike;
use chrono::Duration;
use chrono::Timelike;
use std::convert::TryInto;

/// given a function `op` that maps a `&str` to a Result of an arrow native type,
/// returns a `PrimitiveArray` after the application
Expand Down Expand Up @@ -81,7 +83,7 @@ where
// given an function that maps a `&str` to a arrow native type,
// returns a `ColumnarValue` where the function is applied to either a `ArrayRef` or `ScalarValue`
// depending on the `args`'s variant.
fn handle<'a, O, F, S>(
fn handle<'a, O, F>(
args: &'a [ColumnarValue],
op: F,
name: &str,
Expand All @@ -90,7 +92,6 @@ fn handle<'a, O, F, S>(
where
O: NativeType,
ScalarValue: From<Option<O>>,
S: NativeType,
F: Fn(&'a str) -> Result<O>,
{
match &args[0] {
Expand All @@ -117,14 +118,13 @@ where
))),
},
ColumnarValue::Scalar(scalar) => match scalar {
ScalarValue::Utf8(a) => {
let result = a.as_ref().map(|x| (op)(x)).transpose()?;
Ok(ColumnarValue::Scalar(result.into()))
}
ScalarValue::LargeUtf8(a) => {
let result = a.as_ref().map(|x| (op)(x)).transpose()?;
Ok(ColumnarValue::Scalar(result.into()))
}
ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => Ok(match a {
Some(s) => {
let s = PrimitiveScalar::<O>::new(data_type, Some((op)(s)?));
ColumnarValue::Scalar(s.try_into()?)
}
None => ColumnarValue::Scalar(ScalarValue::new_null(data_type)),
}),
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function {}",
other, name
Expand All @@ -140,7 +140,7 @@ fn string_to_timestamp_nanos_shim(s: &str) -> Result<i64> {

/// to_timestamp SQL function
pub fn to_timestamp(args: &[ColumnarValue]) -> Result<ColumnarValue> {
handle::<i64, _, i64>(
handle::<i64, _>(
args,
string_to_timestamp_nanos_shim,
"to_timestamp",
Expand All @@ -150,7 +150,7 @@ pub fn to_timestamp(args: &[ColumnarValue]) -> Result<ColumnarValue> {

/// to_timestamp_millis SQL function
pub fn to_timestamp_millis(args: &[ColumnarValue]) -> Result<ColumnarValue> {
handle::<i64, _, i64>(
handle::<i64, _>(
args,
|s| string_to_timestamp_nanos_shim(s).map(|n| n / 1_000_000),
"to_timestamp_millis",
Expand All @@ -160,7 +160,7 @@ pub fn to_timestamp_millis(args: &[ColumnarValue]) -> Result<ColumnarValue> {

/// to_timestamp_micros SQL function
pub fn to_timestamp_micros(args: &[ColumnarValue]) -> Result<ColumnarValue> {
handle::<i64, _, i64>(
handle::<i64, _>(
args,
|s| string_to_timestamp_nanos_shim(s).map(|n| n / 1_000),
"to_timestamp_micros",
Expand All @@ -170,7 +170,7 @@ pub fn to_timestamp_micros(args: &[ColumnarValue]) -> Result<ColumnarValue> {

/// to_timestamp_seconds SQL function
pub fn to_timestamp_seconds(args: &[ColumnarValue]) -> Result<ColumnarValue> {
handle::<i64, _, i64>(
handle::<i64, _>(
args,
|s| string_to_timestamp_nanos_shim(s).map(|n| n / 1_000_000_000),
"to_timestamp_seconds",
Expand Down
2 changes: 1 addition & 1 deletion datafusion/src/physical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ pub enum Distribution {
}

/// Represents the result from an expression
#[derive(Clone)]
#[derive(Clone, Debug)]
pub enum ColumnarValue {
/// Array of values
Array(ArrayRef),
Expand Down
52 changes: 50 additions & 2 deletions datafusion/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};

use crate::error::{DataFusionError, Result};
use arrow::scalar::Scalar;
use arrow::{
array::*,
buffer::MutableBuffer,
datatypes::{DataType, Field, IntervalUnit, TimeUnit},
types::days_ms,
scalar::{PrimitiveScalar, Scalar},
types::{days_ms, NativeType},
};
use ordered_float::OrderedFloat;
use std::cmp::Ordering;
Expand Down Expand Up @@ -421,6 +421,25 @@ macro_rules! eq_array_primitive {
}

impl ScalarValue {
/// Create null scalar value for specific data type.
pub fn new_null(dt: DataType) -> Self {
match dt {
DataType::Timestamp(TimeUnit::Second, _) => {
ScalarValue::TimestampSecond(None)
}
DataType::Timestamp(TimeUnit::Millisecond, _) => {
ScalarValue::TimestampMillisecond(None)
}
DataType::Timestamp(TimeUnit::Microsecond, _) => {
ScalarValue::TimestampMicrosecond(None)
}
DataType::Timestamp(TimeUnit::Nanosecond, _) => {
ScalarValue::TimestampNanosecond(None)
}
_ => todo!("Create null scalar value for datatype: {:?}", dt),
}
}

/// Getter for the `DataType` of the value
pub fn get_datatype(&self) -> DataType {
match self {
Expand Down Expand Up @@ -1272,6 +1291,35 @@ impl TryInto<Box<dyn Scalar>> for &ScalarValue {
}
}

impl<T: NativeType> TryFrom<PrimitiveScalar<T>> for ScalarValue {
type Error = DataFusionError;

fn try_from(s: PrimitiveScalar<T>) -> Result<ScalarValue> {
match s.data_type() {
DataType::Timestamp(TimeUnit::Second, _) => {
let s = s.as_any().downcast_ref::<PrimitiveScalar<i64>>().unwrap();
Ok(ScalarValue::TimestampSecond(Some(s.value())))
}
DataType::Timestamp(TimeUnit::Microsecond, _) => {
let s = s.as_any().downcast_ref::<PrimitiveScalar<i64>>().unwrap();
Ok(ScalarValue::TimestampMicrosecond(Some(s.value())))
}
DataType::Timestamp(TimeUnit::Millisecond, _) => {
let s = s.as_any().downcast_ref::<PrimitiveScalar<i64>>().unwrap();
Ok(ScalarValue::TimestampMillisecond(Some(s.value())))
}
DataType::Timestamp(TimeUnit::Nanosecond, _) => {
let s = s.as_any().downcast_ref::<PrimitiveScalar<i64>>().unwrap();
Ok(ScalarValue::TimestampNanosecond(Some(s.value())))
}
_ => Err(DataFusionError::Internal(
format!(
"Conversion from arrow Scalar to Datafusion ScalarValue not implemented for: {:?}", s))
),
}
}
}

impl TryFrom<&DataType> for ScalarValue {
type Error = DataFusionError;

Expand Down

0 comments on commit fde82cf

Please sign in to comment.