Skip to content

Commit

Permalink
perf: Ensure we hit specialized gather for binary/strings (pola-rs#15886)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Apr 25, 2024
1 parent 44aab96 commit 05475da
Show file tree
Hide file tree
Showing 13 changed files with 91 additions and 60 deletions.
4 changes: 2 additions & 2 deletions crates/polars-core/src/chunked_array/arithmetic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ impl Add for &StringChunked {
type Output = StringChunked;

fn add(self, rhs: Self) -> Self::Output {
unsafe { (self.as_binary() + rhs.as_binary()).to_string() }
unsafe { (self.as_binary() + rhs.as_binary()).to_string_unchecked() }
}
}

Expand All @@ -39,7 +39,7 @@ impl Add<&str> for &StringChunked {
type Output = StringChunked;

fn add(self, rhs: &str) -> Self::Output {
unsafe { ((&self.as_binary()) + rhs.as_bytes()).to_string() }
unsafe { ((&self.as_binary()) + rhs.as_bytes()).to_string_unchecked() }
}
}

Expand Down
4 changes: 2 additions & 2 deletions crates/polars-core/src/chunked_array/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ impl ChunkCast for StringChunked {
impl BinaryChunked {
/// # Safety
/// String is not validated
pub unsafe fn to_string(&self) -> StringChunked {
pub unsafe fn to_string_unchecked(&self) -> StringChunked {
let chunks = self
.downcast_iter()
.map(|arr| arr.to_utf8view_unchecked().boxed())
Expand Down Expand Up @@ -334,7 +334,7 @@ impl ChunkCast for BinaryChunked {

unsafe fn cast_unchecked(&self, data_type: &DataType) -> PolarsResult<Series> {
match data_type {
DataType::String => unsafe { Ok(self.to_string().into_series()) },
DataType::String => unsafe { Ok(self.to_string_unchecked().into_series()) },
_ => self.cast(data_type),
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-core/src/chunked_array/ops/append.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ where

impl<T> ChunkedArray<T>
where
T: PolarsDataType<Structure = Flat>,
T: PolarsDataType<IsNested = FalseT>,
for<'a> T::Physical<'a>: TotalOrd,
{
/// Append in place. This is done by adding the chunks of `other` to this [`ChunkedArray`].
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-core/src/chunked_array/ops/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl ChunkFilter<BooleanType> for BooleanChunked {
impl ChunkFilter<StringType> for StringChunked {
fn filter(&self, filter: &BooleanChunked) -> PolarsResult<ChunkedArray<StringType>> {
let out = self.as_binary().filter(filter)?;
unsafe { Ok(out.to_string()) }
unsafe { Ok(out.to_string_unchecked()) }
}
}

Expand Down
54 changes: 28 additions & 26 deletions crates/polars-core/src/chunked_array/ops/gather.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,10 @@ unsafe fn gather_idx_array_unchecked<A: StaticArray>(
}
}

impl<T: PolarsDataType, I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for ChunkedArray<T> {
impl<T: PolarsDataType, I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for ChunkedArray<T>
where
T: PolarsDataType<HasViews = FalseT>,
{
/// Gather values from ChunkedArray by index.
unsafe fn take_unchecked(&self, indices: &I) -> Self {
let rechunked;
Expand All @@ -161,29 +164,6 @@ impl<T: PolarsDataType, I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for
}
}

trait NotSpecialized {}
impl NotSpecialized for Int8Type {}
impl NotSpecialized for Int16Type {}
impl NotSpecialized for Int32Type {}
impl NotSpecialized for Int64Type {}
#[cfg(feature = "dtype-decimal")]
impl NotSpecialized for Int128Type {}
impl NotSpecialized for UInt8Type {}
impl NotSpecialized for UInt16Type {}
impl NotSpecialized for UInt32Type {}
impl NotSpecialized for UInt64Type {}
impl NotSpecialized for Float32Type {}
impl NotSpecialized for Float64Type {}
impl NotSpecialized for BooleanType {}
impl NotSpecialized for ListType {}
#[cfg(feature = "dtype-array")]
impl NotSpecialized for FixedSizeListType {}
impl NotSpecialized for BinaryOffsetType {}
#[cfg(feature = "dtype-decimal")]
impl NotSpecialized for DecimalType {}
#[cfg(feature = "object")]
impl<T> NotSpecialized for ObjectType<T> {}

pub fn _update_gather_sorted_flag(sorted_arr: IsSorted, sorted_idx: IsSorted) -> IsSorted {
use crate::series::IsSorted::*;
match (sorted_arr, sorted_idx) {
Expand All @@ -196,7 +176,10 @@ pub fn _update_gather_sorted_flag(sorted_arr: IsSorted, sorted_idx: IsSorted) ->
}
}

impl<T: PolarsDataType + NotSpecialized> ChunkTakeUnchecked<IdxCa> for ChunkedArray<T> {
impl<T: PolarsDataType> ChunkTakeUnchecked<IdxCa> for ChunkedArray<T>
where
T: PolarsDataType<HasViews = FalseT>,
{
/// Gather values from ChunkedArray by index.
unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
let rechunked;
Expand Down Expand Up @@ -272,7 +255,26 @@ impl ChunkTakeUnchecked<IdxCa> for BinaryChunked {

impl ChunkTakeUnchecked<IdxCa> for StringChunked {
unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
self.as_binary().take_unchecked(indices).to_string()
self.as_binary()
.take_unchecked(indices)
.to_string_unchecked()
}
}

impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for BinaryChunked {
    /// Gather binary values using a plain slice of indices.
    ///
    /// The slice is wrapped in an [`IdxCa`] (with an empty name) and the
    /// gather is then dispatched on that `IdxCa` argument rather than on the
    /// raw slice.
    ///
    /// # Safety
    /// No bounds checks are performed in this method; presumably every index
    /// must be in bounds for `self` (confirm against the trait's contract).
    unsafe fn take_unchecked(&self, indices: &I) -> Self {
        let idx_ca = IdxCa::mmap_slice("", indices.as_ref());
        self.take_unchecked(&idx_ca)
    }
}

impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for StringChunked {
    /// Gather string values using a plain slice of indices.
    ///
    /// The gather itself runs on the binary representation of this array;
    /// the result is converted back to strings without UTF-8 validation.
    ///
    /// # Safety
    /// No bounds checks are performed in this method; presumably every index
    /// must be in bounds for `self` (confirm against the trait's contract).
    unsafe fn take_unchecked(&self, indices: &I) -> Self {
        let as_bytes = self.as_binary();
        let gathered = as_bytes.take_unchecked(indices);
        // NOTE(review): skipping validation assumes the gathered values are
        // unmodified copies of existing string values — confirm.
        gathered.to_string_unchecked()
    }
}

Expand Down
2 changes: 1 addition & 1 deletion crates/polars-core/src/chunked_array/ops/reverse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ impl ChunkReverse for BinaryChunked {

impl ChunkReverse for StringChunked {
fn reverse(&self) -> Self {
unsafe { self.as_binary().reverse().to_string() }
unsafe { self.as_binary().reverse().to_string_unchecked() }
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/polars-core/src/chunked_array/ops/shift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ impl ChunkShiftFill<StringType, Option<&str>> for StringChunked {
let ca = self.as_binary();
unsafe {
ca.shift_and_fill(periods, fill_value.map(|v| v.as_bytes()))
.to_string()
.to_string_unchecked()
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-core/src/chunked_array/ops/sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ fn ordering_other_columns<'a>(

impl ChunkSort<StringType> for StringChunked {
fn sort_with(&self, options: SortOptions) -> ChunkedArray<StringType> {
unsafe { self.as_binary().sort_with(options).to_string() }
unsafe { self.as_binary().sort_with(options).to_string_unchecked() }
}

fn sort(&self, descending: bool) -> StringChunked {
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-core/src/chunked_array/ops/unique/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ where
impl ChunkUnique<StringType> for StringChunked {
fn unique(&self) -> PolarsResult<Self> {
let out = self.as_binary().unique()?;
Ok(unsafe { out.to_string() })
Ok(unsafe { out.to_string_unchecked() })
}

fn arg_unique(&self) -> PolarsResult<IdxCa> {
Expand Down
67 changes: 48 additions & 19 deletions crates/polars-core/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ use crate::chunked_array::object::PolarsObjectSafe;
use crate::prelude::*;
use crate::utils::Wrap;

pub struct Nested;
pub struct Flat;
pub struct TrueT;
pub struct FalseT;

/// # Safety
///
Expand All @@ -68,7 +68,8 @@ pub unsafe trait PolarsDataType: Send + Sync + Sized {
ValueT<'a> = Self::Physical<'a>,
ZeroableValueT<'a> = Self::ZeroablePhysical<'a>,
>;
type Structure;
type IsNested;
type HasViews;

fn get_dtype() -> DataType
where
Expand All @@ -81,7 +82,8 @@ where
Physical<'a> = Self::Native,
ZeroablePhysical<'a> = Self::Native,
Array = PrimitiveArray<Self::Native>,
Structure = Flat,
IsNested = FalseT,
HasViews = FalseT,
>,
{
type Native: NumericNative;
Expand All @@ -99,7 +101,8 @@ macro_rules! impl_polars_num_datatype {
type Physical<'a> = $physical;
type ZeroablePhysical<'a> = $physical;
type Array = PrimitiveArray<$physical>;
type Structure = Flat;
type IsNested = FalseT;
type HasViews = FalseT;

#[inline]
fn get_dtype() -> DataType {
Expand All @@ -115,16 +118,17 @@ macro_rules! impl_polars_num_datatype {
};
}

macro_rules! impl_polars_datatype2 {
($ca:ident, $dtype:expr, $arr:ty, $lt:lifetime, $phys:ty, $zerophys:ty) => {
macro_rules! impl_polars_datatype_pass_dtype {
($ca:ident, $dtype:expr, $arr:ty, $lt:lifetime, $phys:ty, $zerophys:ty, $has_views:ident) => {
#[derive(Clone, Copy)]
pub struct $ca {}

unsafe impl PolarsDataType for $ca {
type Physical<$lt> = $phys;
type ZeroablePhysical<$lt> = $zerophys;
type Array = $arr;
type Structure = Flat;
type IsNested = FalseT;
type HasViews = $has_views;

#[inline]
fn get_dtype() -> DataType {
Expand All @@ -133,10 +137,31 @@ macro_rules! impl_polars_datatype2 {
}
};
}
/// Declares a data-type marker by forwarding to
/// `impl_polars_datatype_pass_dtype!` with `DataType::$dtype_variant` as the
/// dtype expression and `TrueT` as the trailing argument.
///
/// Arguments, in order: marker struct name, `DataType` variant name, array
/// type, lifetime used by the physical types, physical value type, zeroable
/// physical value type.
macro_rules! impl_polars_binview_datatype {
    ($marker:ident, $dtype_variant:ident, $array:ty, $life:lifetime, $physical:ty, $zeroable:ty) => {
        impl_polars_datatype_pass_dtype!($marker, DataType::$dtype_variant, $array, $life, $physical, $zeroable, TrueT);
    };
}

macro_rules! impl_polars_datatype {
($ca:ident, $variant:ident, $arr:ty, $lt:lifetime, $phys:ty, $zerophys:ty) => {
impl_polars_datatype2!($ca, DataType::$variant, $arr, $lt, $phys, $zerophys);
impl_polars_datatype_pass_dtype!(
$ca,
DataType::$variant,
$arr,
$lt,
$phys,
$zerophys,
FalseT
);
};
}

Expand All @@ -152,24 +177,25 @@ impl_polars_num_datatype!(PolarsFloatType, Float32Type, Float32, f32);
impl_polars_num_datatype!(PolarsFloatType, Float64Type, Float64, f64);
impl_polars_datatype!(DateType, Date, PrimitiveArray<i32>, 'a, i32, i32);
impl_polars_datatype!(TimeType, Time, PrimitiveArray<i64>, 'a, i64, i64);
impl_polars_datatype!(StringType, String, Utf8ViewArray, 'a, &'a str, Option<&'a str>);
impl_polars_datatype!(BinaryType, Binary, BinaryViewArray, 'a, &'a [u8], Option<&'a [u8]>);
impl_polars_binview_datatype!(StringType, String, Utf8ViewArray, 'a, &'a str, Option<&'a str>);
impl_polars_binview_datatype!(BinaryType, Binary, BinaryViewArray, 'a, &'a [u8], Option<&'a [u8]>);
impl_polars_datatype!(BinaryOffsetType, BinaryOffset, BinaryArray<i64>, 'a, &'a [u8], Option<&'a [u8]>);
impl_polars_datatype!(BooleanType, Boolean, BooleanArray, 'a, bool, bool);

#[cfg(feature = "dtype-decimal")]
impl_polars_datatype2!(DecimalType, DataType::Unknown(UnknownKind::Any), PrimitiveArray<i128>, 'a, i128, i128);
impl_polars_datatype2!(DatetimeType, DataType::Unknown(UnknownKind::Any), PrimitiveArray<i64>, 'a, i64, i64);
impl_polars_datatype2!(DurationType, DataType::Unknown(UnknownKind::Any), PrimitiveArray<i64>, 'a, i64, i64);
impl_polars_datatype2!(CategoricalType, DataType::Unknown(UnknownKind::Any), PrimitiveArray<u32>, 'a, u32, u32);
impl_polars_datatype_pass_dtype!(DecimalType, DataType::Unknown(UnknownKind::Any), PrimitiveArray<i128>, 'a, i128, i128, FalseT);
impl_polars_datatype_pass_dtype!(DatetimeType, DataType::Unknown(UnknownKind::Any), PrimitiveArray<i64>, 'a, i64, i64, FalseT);
impl_polars_datatype_pass_dtype!(DurationType, DataType::Unknown(UnknownKind::Any), PrimitiveArray<i64>, 'a, i64, i64, FalseT);
impl_polars_datatype_pass_dtype!(CategoricalType, DataType::Unknown(UnknownKind::Any), PrimitiveArray<u32>, 'a, u32, u32, FalseT);

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ListType {}
unsafe impl PolarsDataType for ListType {
type Physical<'a> = Box<dyn Array>;
type ZeroablePhysical<'a> = Option<Box<dyn Array>>;
type Array = ListArray<i64>;
type Structure = Nested;
type IsNested = TrueT;
type HasViews = FalseT;

fn get_dtype() -> DataType {
// Null as we cannot know anything without self.
Expand All @@ -184,7 +210,8 @@ unsafe impl PolarsDataType for FixedSizeListType {
type Physical<'a> = Box<dyn Array>;
type ZeroablePhysical<'a> = Option<Box<dyn Array>>;
type Array = FixedSizeListArray;
type Structure = Nested;
type IsNested = TrueT;
type HasViews = FalseT;

fn get_dtype() -> DataType {
// Null as we cannot know anything without self.
Expand All @@ -198,7 +225,8 @@ unsafe impl PolarsDataType for Int128Type {
type Physical<'a> = i128;
type ZeroablePhysical<'a> = i128;
type Array = PrimitiveArray<i128>;
type Structure = Flat;
type IsNested = FalseT;
type HasViews = FalseT;

fn get_dtype() -> DataType {
// Scale is not None to allow for get_any_value() to work.
Expand All @@ -218,7 +246,8 @@ unsafe impl<T: PolarsObject> PolarsDataType for ObjectType<T> {
type Physical<'a> = &'a T;
type ZeroablePhysical<'a> = Option<&'a T>;
type Array = ObjectArray<T>;
type Structure = Nested;
type IsNested = TrueT;
type HasViews = FalseT;

fn get_dtype() -> DataType {
DataType::Object(T::type_name(), None)
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-core/src/frame/group_by/aggregations/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,12 @@ impl StringChunked {
#[allow(clippy::needless_lifetimes)]
pub(crate) unsafe fn agg_min<'a>(&'a self, groups: &GroupsProxy) -> Series {
let out = self.as_binary().agg_min(groups);
out.binary().unwrap().to_string().into_series()
out.binary().unwrap().to_string_unchecked().into_series()
}

#[allow(clippy::needless_lifetimes)]
pub(crate) unsafe fn agg_max<'a>(&'a self, groups: &GroupsProxy) -> Series {
let out = self.as_binary().agg_max(groups);
out.binary().unwrap().to_string().into_series()
out.binary().unwrap().to_string_unchecked().into_series()
}
}
4 changes: 2 additions & 2 deletions crates/polars-ops/src/chunked_array/gather/chunked.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ impl TakeChunked for Series {
let ca = phys.str().unwrap();
let ca = ca.as_binary();
let out = take_unchecked_binview(&ca, by, sorted);
out.to_string().into_series()
out.to_string_unchecked().into_series()
},
List(_) => {
let ca = phys.list().unwrap();
Expand Down Expand Up @@ -169,7 +169,7 @@ impl TakeChunked for Series {
let ca = phys.str().unwrap();
let ca = ca.as_binary();
let out = take_unchecked_binview_opt(&ca, by);
out.to_string().into_series()
out.to_string_unchecked().into_series()
},
List(_) => {
let ca = phys.list().unwrap();
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-ops/src/chunked_array/top_k.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ pub fn top_k(s: &[Series], descending: bool) -> PolarsResult<Series> {
DataType::Boolean => Ok(top_k_bool_impl(s.bool().unwrap(), k, descending).into_series()),
DataType::String => {
let ca = top_k_binary_impl(&s.str().unwrap().as_binary(), k, descending);
let ca = unsafe { ca.to_string() };
let ca = unsafe { ca.to_string_unchecked() };
Ok(ca.into_series())
},
DataType::Binary => Ok(top_k_binary_impl(s.binary().unwrap(), k, descending).into_series()),
Expand Down

0 comments on commit 05475da

Please sign in to comment.