Incorporate dyn scalar kernels #1685
Changes from 2 commits
@@ -16,97 +16,98 @@ | |
# under the License. | ||
|
||
[package] | ||
name = "datafusion" | ||
authors = ["Apache Arrow <[email protected]>"] | ||
description = "DataFusion is an in-memory query engine that uses Apache Arrow as the memory model" | ||
version = "6.0.0" | ||
edition = "2021" | ||
homepage = "https://github.com/apache/arrow-datafusion" | ||
repository = "https://github.com/apache/arrow-datafusion" | ||
readme = "../README.md" | ||
authors = ["Apache Arrow <[email protected]>"] | ||
license = "Apache-2.0" | ||
keywords = [ "arrow", "query", "sql" ] | ||
include = [ | ||
"benches/*.rs", | ||
"src/**/*.rs", | ||
"Cargo.toml", | ||
"benches/*.rs", | ||
"src/**/*.rs", | ||
"Cargo.toml", | ||
] | ||
edition = "2021" | ||
keywords = ["arrow", "query", "sql"] | ||
license = "Apache-2.0" | ||
name = "datafusion" | ||
readme = "../README.md" | ||
repository = "https://github.com/apache/arrow-datafusion" | ||
rust-version = "1.58" | ||
version = "6.0.0" | ||
|
||
[lib] | ||
name = "datafusion" | ||
path = "src/lib.rs" | ||
|
||
[features] | ||
default = ["crypto_expressions", "regex_expressions", "unicode_expressions"] | ||
simd = ["arrow/simd"] | ||
crypto_expressions = ["md-5", "sha2", "blake2", "blake3"] | ||
default = ["crypto_expressions", "regex_expressions", "unicode_expressions"] | ||
pyarrow = ["pyo3", "arrow/pyarrow"] | ||
regex_expressions = ["regex"] | ||
simd = ["arrow/simd"] | ||
unicode_expressions = ["unicode-segmentation"] | ||
pyarrow = ["pyo3", "arrow/pyarrow"] | ||
# Used for testing ONLY: causes all values to hash to the same value (test for collisions) | ||
force_hash_collisions = [] | ||
# Used to enable the avro format | ||
avro = ["avro-rs", "num-traits"] | ||
|
||
[dependencies] | ||
ahash = { version = "0.7", default-features = false } | ||
hashbrown = { version = "0.12", features = ["raw"] } | ||
arrow = { version = "8.0.0", features = ["prettyprint"] } | ||
parquet = { version = "8.0.0", features = ["arrow"] } | ||
sqlparser = "0.13" | ||
paste = "^1.0" | ||
num_cpus = "1.13.0" | ||
chrono = { version = "0.4", default-features = false } | ||
ahash = {version = "0.7", default-features = false} | ||
arrow = {version = "8.0.0", features = ["prettyprint"]} | ||
async-trait = "0.1.41" | ||
avro-rs = {version = "0.13", features = ["snappy"], optional = true} | ||
blake2 = {version = "^0.10.2", optional = true} | ||
blake3 = {version = "1.0", optional = true} | ||
chrono = {version = "0.4", default-features = false} | ||
futures = "0.3" | ||
pin-project-lite= "^0.2.7" | ||
tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs"] } | ||
tokio-stream = "0.1" | ||
hashbrown = {version = "0.12", features = ["raw"]} | ||
lazy_static = {version = "^1.4.0"} | ||
log = "^0.4" | ||
md-5 = { version = "^0.10.0", optional = true } | ||
sha2 = { version = "^0.10.1", optional = true } | ||
blake2 = { version = "^0.10.2", optional = true } | ||
blake3 = { version = "1.0", optional = true } | ||
md-5 = {version = "^0.10.0", optional = true} | ||
num = "0.4" | ||
num-traits = {version = "0.2", optional = true} | ||
num_cpus = "1.13.0" | ||
ordered-float = "2.0" | ||
unicode-segmentation = { version = "^1.7.1", optional = true } | ||
regex = { version = "^1.4.3", optional = true } | ||
lazy_static = { version = "^1.4.0" } | ||
smallvec = { version = "1.6", features = ["union"] } | ||
parquet = {version = "8.0.0", features = ["arrow"]} | ||
paste = "^1.0" | ||
pin-project-lite = "^0.2.7" | ||
pyo3 = {version = "0.15", optional = true} | ||
rand = "0.8" | ||
avro-rs = { version = "0.13", features = ["snappy"], optional = true } | ||
num-traits = { version = "0.2", optional = true } | ||
pyo3 = { version = "0.15", optional = true } | ||
regex = {version = "^1.4.3", optional = true} | ||
sha2 = {version = "^0.10.1", optional = true} | ||
smallvec = {version = "1.6", features = ["union"]} | ||
sqlparser = "0.13" | ||
tempfile = "3" | ||
tokio = {version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs"]} | ||
tokio-stream = "0.1" | ||
unicode-segmentation = {version = "^1.7.1", optional = true} | ||
|
||
[dev-dependencies] | ||
criterion = "0.3" | ||
doc-comment = "0.3" | ||
|
||
[[bench]] | ||
name = "aggregate_query_sql" | ||
harness = false | ||
name = "aggregate_query_sql" | ||
|
||
[[bench]] | ||
name = "sort_limit_query_sql" | ||
harness = false | ||
name = "sort_limit_query_sql" | ||
|
||
[[bench]] | ||
name = "math_query_sql" | ||
harness = false | ||
name = "math_query_sql" | ||
|
||
[[bench]] | ||
name = "filter_query_sql" | ||
harness = false | ||
name = "filter_query_sql" | ||
|
||
[[bench]] | ||
name = "window_query_sql" | ||
harness = false | ||
name = "window_query_sql" | ||
|
||
[[bench]] | ||
name = "scalar" | ||
harness = false | ||
name = "scalar" | ||
|
||
[[bench]] | ||
name = "physical_plan" | ||
harness = false | ||
name = "physical_plan" |
@@ -15,6 +15,7 @@ | |
// specific language governing permissions and limitations | ||
// under the License. | ||
|
||
use std::convert::TryInto; | ||
use std::{any::Any, sync::Arc}; | ||
|
||
use arrow::array::TimestampMillisecondArray; | ||
|
@@ -28,6 +29,18 @@ use arrow::compute::kernels::comparison::{ | |
eq_bool, eq_bool_scalar, gt_bool, gt_bool_scalar, gt_eq_bool, gt_eq_bool_scalar, | ||
lt_bool, lt_bool_scalar, lt_eq_bool, lt_eq_bool_scalar, neq_bool, neq_bool_scalar, | ||
}; | ||
use arrow::compute::kernels::comparison::{ | ||
eq_dyn_bool_scalar, gt_dyn_bool_scalar, gt_eq_dyn_bool_scalar, lt_dyn_bool_scalar, | ||
lt_eq_dyn_bool_scalar, neq_dyn_bool_scalar, | ||
}; | ||
use arrow::compute::kernels::comparison::{ | ||
eq_dyn_scalar, gt_dyn_scalar, gt_eq_dyn_scalar, lt_dyn_scalar, lt_eq_dyn_scalar, | ||
neq_dyn_scalar, | ||
}; | ||
use arrow::compute::kernels::comparison::{ | ||
eq_dyn_utf8_scalar, gt_dyn_utf8_scalar, gt_eq_dyn_utf8_scalar, lt_dyn_utf8_scalar, | ||
lt_eq_dyn_utf8_scalar, neq_dyn_utf8_scalar, | ||
}; | ||
use arrow::compute::kernels::comparison::{ | ||
eq_scalar, gt_eq_scalar, gt_scalar, lt_eq_scalar, lt_scalar, neq_scalar, | ||
}; | ||
|
@@ -44,6 +57,8 @@ use arrow::datatypes::{ArrowNumericType, DataType, Schema, TimeUnit}; | |
use arrow::error::ArrowError::DivideByZero; | ||
use arrow::record_batch::RecordBatch; | ||
|
||
use num::ToPrimitive; | ||
|
||
use crate::error::{DataFusionError, Result}; | ||
use crate::logical_plan::Operator; | ||
use crate::physical_plan::coercion_rule::binary_rule::coerce_types; | ||
|
@@ -429,6 +444,24 @@ macro_rules! compute_utf8_op_scalar { | |
}}; | ||
} | ||
|
||
/// Invoke a compute kernel on a data array and a scalar value | ||
macro_rules! compute_utf8_op_dyn_scalar { | ||
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ | ||
if let ScalarValue::Utf8(Some(string_value)) = $RIGHT { | ||
Ok(Arc::new(paste::expr! {[<$OP _dyn_utf8_scalar>]}( | ||
$LEFT, | ||
&string_value, | ||
)?)) | ||
} else { | ||
Err(DataFusionError::Internal(format!( | ||
"compute_utf8_op_scalar for '{}' failed to cast literal value {}", | ||
stringify!($OP), | ||
$RIGHT | ||
))) | ||
} | ||
}}; | ||
} | ||
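To make the paste-generated name concrete: for $OP = eq the macro above resolves to a call to arrow's eq_dyn_utf8_scalar, which this PR imports and which compares every row of a string (or dictionary-encoded string) array against a single &str. The following is a minimal, self-contained sketch of that call, not code from the PR; the helper name eq_string_scalar and the sample data are illustrative only.

use std::sync::Arc;

use arrow::array::{Array, BooleanArray, StringArray};
use arrow::compute::kernels::comparison::eq_dyn_utf8_scalar;
use arrow::error::Result;

// Hypothetical helper: roughly what compute_utf8_op_dyn_scalar!(values, scalar, eq)
// expands to once the ScalarValue::Utf8 literal has been unwrapped to a &str.
fn eq_string_scalar(values: &dyn Array, literal: &str) -> Result<Arc<BooleanArray>> {
    Ok(Arc::new(eq_dyn_utf8_scalar(values, literal)?))
}

fn main() -> Result<()> {
    let names = StringArray::from(vec!["a", "b", "a"]);
    let mask = eq_string_scalar(&names, "a")?;
    assert!(mask.value(0) && !mask.value(1) && mask.value(2));
    Ok(())
}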
|
||
/// Invoke a compute kernel on a boolean data array and a scalar value | ||
macro_rules! compute_bool_op_scalar { | ||
($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ | ||
|
@@ -446,6 +479,18 @@ macro_rules! compute_bool_op_scalar { | |
}}; | ||
} | ||
|
||
/// Invoke a compute kernel on a boolean data array and a scalar value | ||
macro_rules! compute_bool_op_dyn_scalar { | ||
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ | ||
// generate the scalar function name, such as lt_dyn_bool_scalar, from the $OP parameter | ||
// (which could have a value of lt) and the suffix _dyn_bool_scalar | ||
Ok(Arc::new(paste::expr! {[<$OP _dyn_bool_scalar>]}( | ||
$LEFT, | ||
$RIGHT.try_into()?, | ||
)?)) | ||
}}; | ||
} | ||
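The boolean variant works the same way. Below is a minimal sketch of the generated call for $OP = eq, assuming (as the $RIGHT.try_into()? above implies) that arrow's eq_dyn_bool_scalar takes a &dyn Array and a plain bool; the helper name and data are illustrative, not part of the PR.

use arrow::array::{Array, BooleanArray};
use arrow::compute::kernels::comparison::eq_dyn_bool_scalar;
use arrow::error::Result;

// Hypothetical helper: the call compute_bool_op_dyn_scalar!(flags, scalar, eq)
// generates once the ScalarValue::Boolean has been converted to a bool.
fn eq_bool_scalar_sketch(flags: &dyn Array, literal: bool) -> Result<BooleanArray> {
    eq_dyn_bool_scalar(flags, literal)
}

fn main() -> Result<()> {
    let flags = BooleanArray::from(vec![true, false, true]);
    let mask = eq_bool_scalar_sketch(&flags, true)?;
    assert!(mask.value(0) && !mask.value(1));
    Ok(())
}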
|
||
/// Invoke a bool compute kernel on array(s) | ||
macro_rules! compute_bool_op { | ||
// invoke binary operator | ||
|
@@ -474,7 +519,7 @@ macro_rules! compute_bool_op { | |
/// LEFT is array, RIGHT is scalar value | ||
macro_rules! compute_op_scalar { | ||
($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ | ||
use std::convert::TryInto; | ||
// use std::convert::TryInto; | ||
Review comment: it can be removed.
||
let ll = $LEFT | ||
.as_any() | ||
.downcast_ref::<$DT>() | ||
|
@@ -488,6 +533,19 @@ macro_rules! compute_op_scalar { | |
}}; | ||
} | ||
|
||
/// Invoke a dyn compute kernel on a data array and a scalar value | ||
/// LEFT is Primitive or Dictionary array of numeric values, RIGHT is scalar value | ||
macro_rules! compute_op_dyn_scalar { | ||
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ | ||
// generate the scalar function name, such as lt_dyn_scalar, from the $OP parameter | ||
// (which could have a value of lt) and the suffix _dyn_scalar | ||
Ok(Arc::new(paste::expr! {[<$OP _dyn_scalar>]}( | ||
$LEFT, | ||
$RIGHT, | ||
)?)) | ||
}}; | ||
} | ||
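For numeric (and dictionary-of-numeric) arrays the same pattern applies. The sketch below assumes, as the num::ToPrimitive import here and the ToPrimitive impl for ScalarValue later in this PR suggest, that arrow's *_dyn_scalar kernels accept a right-hand value convertible through num::ToPrimitive and convert it to the array's native type internally; the helper name lt_five is hypothetical.

use arrow::array::{Array, BooleanArray, Int32Array};
use arrow::compute::kernels::comparison::lt_dyn_scalar;
use arrow::error::Result;

// Hypothetical helper: roughly what compute_op_dyn_scalar!(values, 5, lt) resolves to.
// lt_dyn_scalar dispatches on the dynamic type of `values`, so the same call also
// covers dictionary-encoded numeric arrays.
fn lt_five(values: &dyn Array) -> Result<BooleanArray> {
    lt_dyn_scalar(values, 5)
}

fn main() -> Result<()> {
    let values = Int32Array::from(vec![1, 7, 3]);
    let mask = lt_five(&values)?;
    assert!(mask.value(0) && !mask.value(1) && mask.value(2));
    Ok(())
}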
|
||
/// Invoke a compute kernel on array(s) | ||
macro_rules! compute_op { | ||
// invoke binary operator | ||
|
@@ -878,26 +936,92 @@ impl PhysicalExpr for BinaryExpr { | |
} | ||
} | ||
|
||
/// The binary_array_op_dyn_scalar macro includes types that extend beyond the primitive, | ||
/// such as Utf8 strings. | ||
#[macro_export] | ||
macro_rules! binary_array_op_dyn_scalar { | ||
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ | ||
let is_numeric = DataType::is_numeric($LEFT.data_type()); | ||
let is_numeric_dict = match $LEFT.data_type() { | ||
DataType::Dictionary(_, val_type) => DataType::is_numeric(val_type), | ||
_ => false | ||
}; | ||
let numeric_like = is_numeric | is_numeric_dict; | ||
|
||
let is_string = ($LEFT.data_type() == &DataType::Utf8) | ($LEFT.data_type() == &DataType::LargeUtf8); | ||
Review comment: It seems to me that the … I wonder if we could do something like

match $RIGHT {
    ScalarValue::Utf8(v) => {
        Ok(Arc::new(paste::expr! {[<$OP _dyn_scalar_utf8>]}(
            $LEFT,
            v,
        )?))
    }
    ..
    ScalarValue::Int8(v) => {
        Ok(Arc::new(paste::expr! {[<$OP _dyn_scalar>]}(
            $LEFT,
            v,
        )?))
    }
    ...
}

Though we will probably need some sort of wrapper to handle types not yet supported in arrow-rs 🤔

Reply: @alamb this basically solved everything i think...we'll see how CI comes out
||
let is_string_dict = match $LEFT.data_type() { | ||
DataType::Dictionary(_, val_type) => match **val_type { | ||
DataType::Utf8 | DataType::LargeUtf8 => true, | ||
_ => false | ||
} | ||
_ => false | ||
}; | ||
let string_like = is_string | is_string_dict; | ||
|
||
let result: Result<Arc<dyn Array>> = if numeric_like { | ||
compute_op_dyn_scalar!($LEFT, $RIGHT.try_into()?, $OP) | ||
} else if string_like { | ||
compute_utf8_op_dyn_scalar!($LEFT, $RIGHT, $OP) | ||
} else { | ||
let r: Result<Arc<dyn Array>> = match $LEFT.data_type() { | ||
|
||
DataType::Decimal(_,_) => compute_decimal_op_scalar!($LEFT, $RIGHT, $OP, DecimalArray), | ||
DataType::Boolean => compute_bool_op_dyn_scalar!($LEFT, $RIGHT, $OP), | ||
DataType::Timestamp(TimeUnit::Nanosecond, _) => { | ||
compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampNanosecondArray) | ||
} | ||
DataType::Timestamp(TimeUnit::Microsecond, _) => { | ||
compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampMicrosecondArray) | ||
} | ||
DataType::Timestamp(TimeUnit::Millisecond, _) => { | ||
compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampMillisecondArray) | ||
} | ||
DataType::Timestamp(TimeUnit::Second, _) => { | ||
compute_op_scalar!($LEFT, $RIGHT, $OP, TimestampSecondArray) | ||
} | ||
DataType::Date32 => { | ||
compute_op_scalar!($LEFT, $RIGHT, $OP, Date32Array) | ||
} | ||
DataType::Date64 => { | ||
compute_op_scalar!($LEFT, $RIGHT, $OP, Date64Array) | ||
} | ||
other => Err(DataFusionError::Internal(format!( | ||
"Data type {:?} not supported for scalar operation '{}' on dyn array", | ||
other, stringify!($OP) | ||
))), | ||
}; | ||
r | ||
}; | ||
Some(result) | ||
}} | ||
} | ||
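Reading the macro above as plain code: it inspects the left-hand array's DataType, routes numeric and string inputs (including their dictionary-encoded forms) to the new dyn kernels, and keeps the older typed compute_op_scalar! path only for decimals, timestamps, and dates. A simplified, hypothetical restatement of that routing, using the same DataType::is_numeric check the macro itself uses:

use arrow::datatypes::DataType;

// Hypothetical summary of the dispatch performed by binary_array_op_dyn_scalar!.
fn dyn_kernel_family(left: &DataType) -> &'static str {
    match left {
        dt if DataType::is_numeric(dt) => "*_dyn_scalar (numeric)",
        DataType::Dictionary(_, v) if DataType::is_numeric(v) => "*_dyn_scalar (numeric)",
        DataType::Utf8 | DataType::LargeUtf8 => "*_dyn_utf8_scalar (strings)",
        DataType::Dictionary(_, v)
            if matches!(**v, DataType::Utf8 | DataType::LargeUtf8) =>
        {
            "*_dyn_utf8_scalar (strings)"
        }
        DataType::Boolean => "*_dyn_bool_scalar (booleans)",
        _ => "typed compute_op_scalar! fallback (decimal, timestamp, date)",
    }
}

fn main() {
    assert_eq!(dyn_kernel_family(&DataType::Int32), "*_dyn_scalar (numeric)");
    assert_eq!(dyn_kernel_family(&DataType::Utf8), "*_dyn_utf8_scalar (strings)");
}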
|
||
impl BinaryExpr { | ||
/// Evaluate the expression if the left input is an array and | ||
/// right is literal - use scalar operations | ||
fn evaluate_array_scalar( | ||
&self, | ||
array: &ArrayRef, | ||
array: &dyn Array, | ||
scalar: &ScalarValue, | ||
) -> Result<Option<Result<ArrayRef>>> { | ||
let scalar_result = match &self.op { | ||
Operator::Lt => binary_array_op_scalar!(array, scalar.clone(), lt), | ||
Operator::Lt => { | ||
binary_array_op_dyn_scalar!(array, scalar.clone(), lt) | ||
} | ||
Operator::LtEq => { | ||
binary_array_op_scalar!(array, scalar.clone(), lt_eq) | ||
binary_array_op_dyn_scalar!(array, scalar.clone(), lt_eq) | ||
} | ||
Operator::Gt => { | ||
binary_array_op_dyn_scalar!(array, scalar.clone(), gt) | ||
} | ||
Operator::Gt => binary_array_op_scalar!(array, scalar.clone(), gt), | ||
Operator::GtEq => { | ||
binary_array_op_scalar!(array, scalar.clone(), gt_eq) | ||
binary_array_op_dyn_scalar!(array, scalar.clone(), gt_eq) | ||
} | ||
Operator::Eq => { | ||
binary_array_op_dyn_scalar!(array, scalar.clone(), eq) | ||
} | ||
Operator::Eq => binary_array_op_scalar!(array, scalar.clone(), eq), | ||
Operator::NotEq => { | ||
binary_array_op_scalar!(array, scalar.clone(), neq) | ||
binary_array_op_dyn_scalar!(array, scalar.clone(), neq) | ||
} | ||
Operator::Like => { | ||
binary_string_array_op_scalar!(array, scalar.clone(), like) | ||
|
@@ -333,6 +333,29 @@ impl std::hash::Hash for ScalarValue { | |||||
} | ||||||
} | ||||||
|
||||||
impl num::ToPrimitive for ScalarValue { | ||||||
Review comment: 👍
||||||
fn to_i64(&self) -> Option<i64> { | ||||||
use ScalarValue::*; | ||||||
match self { | ||||||
Int8(v) => Some(v.unwrap() as i64), | ||||||
Review comment: I think you can avoid the unwrap using something like:
Suggested change
||||||
Int16(v) => Some(v.unwrap() as i64), | ||||||
Int32(v) => Some(v.unwrap() as i64), | ||||||
Int64(v) => Some(v.unwrap() as i64), | ||||||
_ => None, | ||||||
Review comment: we can probably implement it for UInt* as well and the timestamp types
||||||
} | ||||||
} | ||||||
fn to_u64(&self) -> Option<u64> { | ||||||
use ScalarValue::*; | ||||||
match self { | ||||||
UInt8(v) => Some(v.unwrap() as u64), | ||||||
Review comment: It might be ok to implement for
||||||
UInt16(v) => Some(v.unwrap() as u64), | ||||||
UInt32(v) => Some(v.unwrap() as u64), | ||||||
UInt64(v) => Some(v.unwrap() as u64), | ||||||
_ => None, | ||||||
} | ||||||
} | ||||||
} | ||||||
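A small usage sketch of the new impl, assuming ScalarValue at its usual datafusion::scalar::ScalarValue path and num = "0.4" in the caller's Cargo.toml (both assumptions, not shown in this diff). Note that, as the review comments point out, the unwrap() calls mean a null scalar such as Int32(None) would currently panic rather than return None.

use datafusion::scalar::ScalarValue;
use num::ToPrimitive; // brings to_i64 / to_u64 into scope

fn main() {
    // Signed integer variants are handled by to_i64 ...
    assert_eq!(ScalarValue::Int32(Some(5)).to_i64(), Some(5));
    // ... and unsigned variants by to_u64.
    assert_eq!(ScalarValue::UInt16(Some(7)).to_u64(), Some(7));
    // Anything else falls through to None.
    assert_eq!(ScalarValue::Utf8(Some("x".to_string())).to_i64(), None);
    assert_eq!(ScalarValue::Int32(Some(5)).to_u64(), None);
}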
|
||||||
// return the index into the dictionary values for array@index as well | ||||||
// as a reference to the dictionary values array. Returns None for the | ||||||
// index if the array is NULL at index | ||||||
|
Review comment: 😭 why are there so many meaningless changes?

Reply: I believe I accidentally started auto sorting my Cargo.toml. Or maybe it was related to upgrading cargo / rust? I'm not sure to be honest - trying to figure out what caused the change.