diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs
index 155f391d4c04..a743359d83ae 100644
--- a/datafusion/src/physical_plan/functions.rs
+++ b/datafusion/src/physical_plan/functions.rs
@@ -50,7 +50,9 @@ use arrow::{
     compute::length::length,
     datatypes::TimeUnit,
     datatypes::{DataType, Field, Schema},
+    error::{ArrowError, Result as ArrowResult},
     record_batch::RecordBatch,
+    types::NativeType,
 };
 use fmt::{Debug, Formatter};
 use std::convert::From;
@@ -720,6 +722,53 @@ macro_rules! invoke_if_unicode_expressions_feature_flag {
     };
 }
 
+/// Applies `op` to the byte length of every string in `array`.
+///
+/// Lengths are derived from consecutive offsets, so the string data itself is
+/// never read; the input validity bitmap is carried over unchanged.
+fn unary_offsets_string<O, F>(array: &Utf8Array<O>, op: F) -> PrimitiveArray<O>
+where
+    O: Offset + NativeType,
+    F: Fn(O) -> O,
+{
+    let values = array
+        .offsets()
+        .windows(2)
+        .map(|offset| op(offset[1] - offset[0]));
+
+    let values = arrow::buffer::Buffer::from_trusted_len_iter(values);
+
+    let data_type = if O::is_large() {
+        DataType::Int64
+    } else {
+        DataType::Int32
+    };
+
+    PrimitiveArray::<O>::from_data(data_type, values, array.validity().cloned())
+}
+
+/// Returns an array of integers with the number of bits in each string of the array.
+///
+/// `Utf8` input yields `Int32` values and `LargeUtf8` yields `Int64`; any other
+/// data type is rejected with `ArrowError::InvalidArgumentError`.
+/// TODO: contribute this back upstream?
+fn bit_length(array: &dyn Array) -> ArrowResult<Box<dyn Array>> {
+    match array.data_type() {
+        DataType::Utf8 => {
+            let array = array.as_any().downcast_ref::<Utf8Array<i32>>().unwrap();
+            Ok(Box::new(unary_offsets_string::<i32, _>(array, |x| x * 8)))
+        }
+        DataType::LargeUtf8 => {
+            let array = array.as_any().downcast_ref::<Utf8Array<i64>>().unwrap();
+            Ok(Box::new(unary_offsets_string::<i64, _>(array, |x| x * 8)))
+        }
+        _ => Err(ArrowError::InvalidArgumentError(format!(
+            "bit_length not supported for {:?}",
+            array.data_type()
+        ))),
+    }
+}
+
 /// Create a physical scalar function.
 pub fn create_physical_fun(
     fun: &BuiltinScalarFunction,
@@ -761,7 +810,9 @@ pub fn create_physical_fun(
             ))),
         }),
         BuiltinScalarFunction::BitLength => Arc::new(|args| match &args[0] {
-            ColumnarValue::Array(_v) => todo!(),
+            ColumnarValue::Array(v) => {
+                Ok(ColumnarValue::Array(bit_length(v.as_ref())?.into()))
+            }
             ColumnarValue::Scalar(v) => match v {
                 ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32(
                     v.as_ref().map(|x| (x.len() * 8) as i32),