Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ScalarValue::try_as_str to get str value from logical strings #14167

Merged
merged 1 commit into from
Jan 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2849,6 +2849,50 @@ impl ScalarValue {
ScalarValue::from(value).cast_to(target_type)
}

/// Returns the Some(`&str`) representation of `ScalarValue` of logical string type
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the new function

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💯 doc

///
/// Returns `None` if this `ScalarValue` is not a logical string type or the
/// `ScalarValue` represents the `NULL` value.
///
/// Note you can use [`Option::flatten`] to check for non null logical
/// strings.
///
/// For example, [`ScalarValue::Utf8`], [`ScalarValue::LargeUtf8`], and
/// [`ScalarValue::Dictionary`] with a logical string value and store
/// strings and can be accessed as `&str` using this method.
///
/// # Example: logical strings
/// ```
/// # use datafusion_common::ScalarValue;
/// /// non strings return None
/// let scalar = ScalarValue::from(42);
/// assert_eq!(scalar.try_as_str(), None);
/// // Non null logical string returns Some(Some(&str))
/// let scalar = ScalarValue::from("hello");
/// assert_eq!(scalar.try_as_str(), Some(Some("hello")));
/// // Null logical string returns Some(None)
/// let scalar = ScalarValue::Utf8(None);
/// assert_eq!(scalar.try_as_str(), Some(None));
/// ```
///
/// # Example: use [`Option::flatten`] to check for non-null logical strings
/// ```
/// # use datafusion_common::ScalarValue;
/// // Non null logical string returns Some(Some(&str))
/// let scalar = ScalarValue::from("hello");
/// assert_eq!(scalar.try_as_str().flatten(), Some("hello"));
/// ```
pub fn try_as_str(&self) -> Option<Option<&str>> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not -> Result<Option<&str>> for this try method?
Caller can always convert to an option.

(Also, most of the use cases in this PR are converting a returned None to an error).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the review @wiedld

DataFusionError always has an owned String in it, so returning an Result is actually quite slow as it needs to allocate some memory and copy stuff around. Thus I think this API should return an Option

let v = match self {
ScalarValue::Utf8(v) => v,
ScalarValue::LargeUtf8(v) => v,
ScalarValue::Utf8View(v) => v,
ScalarValue::Dictionary(_, v) => return v.try_as_str(),
_ => return None,
};
Some(v.as_ref().map(|v| v.as_str()))
}

/// Try to cast this value to a ScalarValue of type `data_type`
pub fn cast_to(&self, target_type: &DataType) -> Result<Self> {
self.cast_to_with_options(target_type, &DEFAULT_CAST_OPTIONS)
Expand Down
18 changes: 5 additions & 13 deletions datafusion/core/tests/sql/path_partition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,11 @@ async fn parquet_distinct_partition_col() -> Result<()> {
assert_eq!(min_limit, resulting_limit);

let s = ScalarValue::try_from_array(results[0].column(1), 0)?;
let month = match extract_as_utf(&s) {
Some(month) => month,
s => panic!("Expected month as Dict(_, Utf8) found {s:?}"),
};
assert!(
matches!(s.data_type(), DataType::Dictionary(_, v) if v.as_ref() == &DataType::Utf8),
"Expected month as Dict(_, Utf8) found {s:?}"
);
let month = s.try_as_str().flatten().unwrap();

let sql_on_partition_boundary = format!(
"SELECT month from t where month = '{}' LIMIT {}",
Expand All @@ -241,15 +242,6 @@ async fn parquet_distinct_partition_col() -> Result<()> {
Ok(())
}

fn extract_as_utf(v: &ScalarValue) -> Option<String> {
if let ScalarValue::Dictionary(_, v) = v {
if let ScalarValue::Utf8(v) = v.as_ref() {
return v.clone();
}
}
None
}

#[tokio::test]
async fn csv_filter_with_file_col() -> Result<()> {
let ctx = SessionContext::new_with_config(
Expand Down
15 changes: 7 additions & 8 deletions datafusion/functions-aggregate/src/string_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,14 @@ impl AggregateUDFImpl for StringAgg {

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
if let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::<Literal>() {
return match lit.value() {
ScalarValue::Utf8(Some(delimiter))
| ScalarValue::LargeUtf8(Some(delimiter)) => {
Ok(Box::new(StringAggAccumulator::new(delimiter.as_str())))
return match lit.value().try_as_str() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a pretty good example of can reducing the repetition in the code to check for string literal values. This also now implicitly will work for Dictionary values where it would not have before

Some(Some(delimiter)) => {
Ok(Box::new(StringAggAccumulator::new(delimiter)))
}
Some(None) => Ok(Box::new(StringAggAccumulator::new(""))),
None => {
not_impl_err!("StringAgg not supported for delimiter {}", lit.value())
}
ScalarValue::Utf8(None)
| ScalarValue::LargeUtf8(None)
| ScalarValue::Null => Ok(Box::new(StringAggAccumulator::new(""))),
e => not_impl_err!("StringAgg not supported for delimiter {}", e),
};
}

Expand Down
8 changes: 3 additions & 5 deletions datafusion/functions/src/crypto/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,9 @@ pub fn digest(args: &[ColumnarValue]) -> Result<ColumnarValue> {
);
}
let digest_algorithm = match &args[1] {
ColumnarValue::Scalar(scalar) => match scalar {
ScalarValue::Utf8View(Some(method))
| ScalarValue::Utf8(Some(method))
| ScalarValue::LargeUtf8(Some(method)) => method.parse::<DigestAlgorithm>(),
other => exec_err!("Unsupported data type {other:?} for function digest"),
ColumnarValue::Scalar(scalar) => match scalar.try_as_str() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also avoid a bunch more duplication of stuff like this:

        let part = if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = part {

If we added a similar convenience method ColumnarValue::try_as_scalar_str() that returned a Option<Option<&str>>

Similarly we could do the same with Expr::try_as_scalar_str()

Some(Some(method)) => method.parse::<DigestAlgorithm>(),
_ => exec_err!("Unsupported data type {scalar:?} for function digest"),
},
ColumnarValue::Array(_) => {
internal_err!("Digest using dynamically decided method is not yet supported")
Expand Down
33 changes: 10 additions & 23 deletions datafusion/functions/src/datetime/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,14 +211,12 @@ where
))),
other => exec_err!("Unsupported data type {other:?} for function {name}"),
},
ColumnarValue::Scalar(scalar) => match scalar {
ScalarValue::Utf8View(a)
| ScalarValue::LargeUtf8(a)
| ScalarValue::Utf8(a) => {
ColumnarValue::Scalar(scalar) => match scalar.try_as_str() {
Some(a) => {
let result = a.as_ref().map(|x| op(x)).transpose()?;
Ok(ColumnarValue::Scalar(S::scalar(result)))
}
other => exec_err!("Unsupported data type {other:?} for function {name}"),
_ => exec_err!("Unsupported data type {scalar:?} for function {name}"),
},
}
}
Expand Down Expand Up @@ -270,10 +268,8 @@ where
}
},
// if the first argument is a scalar utf8 all arguments are expected to be scalar utf8
ColumnarValue::Scalar(scalar) => match scalar {
ScalarValue::Utf8View(a)
| ScalarValue::LargeUtf8(a)
| ScalarValue::Utf8(a) => {
ColumnarValue::Scalar(scalar) => match scalar.try_as_str() {
Some(a) => {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is clearer now

let a = a.as_ref();
// ASK: Why do we trust `a` to be non-null at this point?
let a = unwrap_or_internal_err!(a);
Expand All @@ -291,7 +287,7 @@ where
};

if let Some(s) = x {
match op(a.as_str(), s.as_str()) {
match op(a, s.as_str()) {
Ok(r) => {
ret = Some(Ok(ColumnarValue::Scalar(S::scalar(Some(
op2(r),
Expand Down Expand Up @@ -408,19 +404,10 @@ where
DataType::Utf8 => Ok(a.as_string::<i32>().value(pos)),
other => exec_err!("Unexpected type encountered '{other}'"),
},
ColumnarValue::Scalar(s) => match s {
ScalarValue::Utf8View(a)
| ScalarValue::LargeUtf8(a)
| ScalarValue::Utf8(a) => {
if let Some(v) = a {
Ok(v.as_str())
} else {
continue;
}
}
other => {
exec_err!("Unexpected scalar type encountered '{other}'")
}
ColumnarValue::Scalar(s) => match s.try_as_str() {
Some(Some(v)) => Ok(v),
Some(None) => continue, // null string
None => exec_err!("Unexpected scalar type encountered '{s}'"),
},
}?;

Expand Down
16 changes: 6 additions & 10 deletions datafusion/functions/src/encoding/inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -546,12 +546,10 @@ fn encode(args: &[ColumnarValue]) -> Result<ColumnarValue> {
);
}
let encoding = match &args[1] {
ColumnarValue::Scalar(scalar) => match scalar {
ScalarValue::Utf8(Some(method)) | ScalarValue::Utf8View(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => {
method.parse::<Encoding>()
}
ColumnarValue::Scalar(scalar) => match scalar.try_as_str() {
Some(Some(method)) => method.parse::<Encoding>(),
_ => not_impl_err!(
"Second argument to encode must be a constant: Encode using dynamically decided method is not yet supported"
"Second argument to encode must be non null constant string: Encode using dynamically decided method is not yet supported. Got {scalar:?}"
),
},
ColumnarValue::Array(_) => not_impl_err!(
Expand All @@ -572,12 +570,10 @@ fn decode(args: &[ColumnarValue]) -> Result<ColumnarValue> {
);
}
let encoding = match &args[1] {
ColumnarValue::Scalar(scalar) => match scalar {
ScalarValue::Utf8(Some(method)) | ScalarValue::Utf8View(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => {
method.parse::<Encoding>()
}
ColumnarValue::Scalar(scalar) => match scalar.try_as_str() {
Some(Some(method))=> method.parse::<Encoding>(),
_ => not_impl_err!(
"Second argument to decode must be a utf8 constant: Decode using dynamically decided method is not yet supported"
"Second argument to decode must be a non null constant string: Decode using dynamically decided method is not yet supported. Got {scalar:?}"
),
},
ColumnarValue::Array(_) => not_impl_err!(
Expand Down
20 changes: 9 additions & 11 deletions datafusion/functions/src/string/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,18 +134,16 @@ impl ScalarUDFImpl for ConcatFunc {
if array_len.is_none() {
let mut result = String::new();
for arg in args {
match arg {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(v)))
| ColumnarValue::Scalar(ScalarValue::Utf8View(Some(v)))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(v))) => {
result.push_str(v);
}
ColumnarValue::Scalar(ScalarValue::Utf8(None))
| ColumnarValue::Scalar(ScalarValue::Utf8View(None))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {}
other => plan_err!(
let ColumnarValue::Scalar(scalar) = arg else {
return internal_err!("concat expected scalar value, got {arg:?}");
};

match scalar.try_as_str() {
Some(Some(v)) => result.push_str(v),
Some(None) => {} // null literal
None => plan_err!(
"Concat function does not support scalar type {:?}",
other
scalar
)?,
}
}
Expand Down
62 changes: 34 additions & 28 deletions datafusion/functions/src/string/concat_ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,48 +124,54 @@ impl ScalarUDFImpl for ConcatWsFunc {

// Scalar
if array_len.is_none() {
let sep = match &args[0] {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
| ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s)))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => s,
ColumnarValue::Scalar(ScalarValue::Utf8(None))
| ColumnarValue::Scalar(ScalarValue::Utf8View(None))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {
let ColumnarValue::Scalar(scalar) = &args[0] else {
// loop above checks for all args being scalar
unreachable!()
};
let sep = match scalar.try_as_str() {
Some(Some(s)) => s,
Some(None) => {
// null literal string
return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
}
_ => unreachable!(),
None => return internal_err!("Expected string literal, got {scalar:?}"),
};

let mut result = String::new();
let iter = &mut args[1..].iter();

for arg in iter.by_ref() {
match arg {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
| ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s)))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => {
// iterator over Option<str>
let iter = &mut args[1..].iter().map(|arg| {
let ColumnarValue::Scalar(scalar) = arg else {
// loop above checks for all args being scalar
unreachable!()
};
scalar.try_as_str()
});

// append first non null arg
for scalar in iter.by_ref() {
match scalar {
Some(Some(s)) => {
result.push_str(s);
break;
}
ColumnarValue::Scalar(ScalarValue::Utf8(None))
| ColumnarValue::Scalar(ScalarValue::Utf8View(None))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {}
_ => unreachable!(),
Some(None) => {} // null literal string
None => {
return internal_err!("Expected string literal, got {scalar:?}")
}
}
}

for arg in iter.by_ref() {
match arg {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
| ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s)))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => {
// handle subsequent non null args
for scalar in iter.by_ref() {
match scalar {
Some(Some(s)) => {
result.push_str(sep);
result.push_str(s);
}
ColumnarValue::Scalar(ScalarValue::Utf8(None))
| ColumnarValue::Scalar(ScalarValue::Utf8View(None))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {}
_ => unreachable!(),
Some(None) => {} // null literal string
None => {
return internal_err!("Expected string literal, got {scalar:?}")
}
}
}

Expand Down
7 changes: 1 addition & 6 deletions datafusion/optimizer/src/unwrap_cast_in_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -475,12 +475,7 @@ fn try_cast_string_literal(
lit_value: &ScalarValue,
target_type: &DataType,
) -> Option<ScalarValue> {
let string_value = match lit_value {
ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s) | ScalarValue::Utf8View(s) => {
s.clone()
}
_ => return None,
};
let string_value = lit_value.try_as_str()?.map(|s| s.to_string());
let scalar_value = match target_type {
DataType::Utf8 => ScalarValue::Utf8(string_value),
DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value),
Expand Down
21 changes: 6 additions & 15 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,22 +251,13 @@ macro_rules! compute_utf8_flag_op_scalar {
.downcast_ref::<$ARRAYTYPE>()
.expect("compute_utf8_flag_op_scalar failed to downcast array");

let string_value = match $RIGHT {
ScalarValue::Utf8(Some(string_value)) | ScalarValue::LargeUtf8(Some(string_value)) => string_value,
ScalarValue::Dictionary(_, value) => {
match *value {
ScalarValue::Utf8(Some(string_value)) | ScalarValue::LargeUtf8(Some(string_value)) => string_value,
other => return internal_err!(
"compute_utf8_flag_op_scalar failed to cast dictionary value {} for operation '{}'",
other, stringify!($OP)
)
}
},
let string_value = match $RIGHT.try_as_str() {
Some(Some(string_value)) => string_value,
// null literal or non string
_ => return internal_err!(
"compute_utf8_flag_op_scalar failed to cast literal value {} for operation '{}'",
$RIGHT, stringify!($OP)
)

"compute_utf8_flag_op_scalar failed to cast literal value {} for operation '{}'",
$RIGHT, stringify!($OP)
)
};

let flag = $FLAG.then_some("i");
Expand Down
Loading
Loading