From ff0b7d79bec87703a704d51e9776c019dde8cc12 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Tue, 26 Oct 2021 09:33:32 +0200 Subject: [PATCH] reduce monomorphization of take_rand structs --- .../src/chunked_array/ops/compare_inner.rs | 12 +++++------ .../src/chunked_array/ops/take/take_random.rs | 21 ++++++++++--------- polars/polars-core/src/datatypes.rs | 3 ++- polars/polars-core/src/frame/row.rs | 18 ++++++++++++++-- 4 files changed, 35 insertions(+), 19 deletions(-) diff --git a/polars/polars-core/src/chunked_array/ops/compare_inner.rs b/polars/polars-core/src/chunked_array/ops/compare_inner.rs index 1ea4ae48a606..4a3d995c20ec 100644 --- a/polars/polars-core/src/chunked_array/ops/compare_inner.rs +++ b/polars/polars-core/src/chunked_array/ops/compare_inner.rs @@ -43,7 +43,7 @@ macro_rules! impl_traits { ($struct:ty, $T:tt) => { impl<$T> PartialEqInner for $struct where - $T: PolarsNumericType + Sync, + $T: NumericNative + Sync, { #[inline] unsafe fn eq_element_unchecked(&self, idx_a: usize, idx_b: usize) -> bool { @@ -53,7 +53,7 @@ macro_rules! impl_traits { impl<$T> PartialOrdInner for $struct where - $T: PolarsNumericType + Sync, + $T: NumericNative + Sync, { #[inline] unsafe fn cmp_element_unchecked(&self, idx_a: usize, idx_b: usize) -> Ordering { @@ -106,11 +106,11 @@ where }; Box::new(t) } else { - let t = NumTakeRandomSingleChunk::<'_, T> { arr }; + let t = NumTakeRandomSingleChunk::<'_, T::Native> { arr }; Box::new(t) } } else { - let t = NumTakeRandomChunked::<'_, T> { + let t = NumTakeRandomChunked::<'_, T::Native> { chunks: chunks.collect(), chunk_lens: self.chunks.iter().map(|a| a.len() as u32).collect(), }; @@ -219,11 +219,11 @@ where }; Box::new(t) } else { - let t = NumTakeRandomSingleChunk::<'_, T> { arr }; + let t = NumTakeRandomSingleChunk::<'_, T::Native> { arr }; Box::new(t) } } else { - let t = NumTakeRandomChunked::<'_, T> { + let t = NumTakeRandomChunked::<'_, T::Native> { chunks: chunks.collect(), chunk_lens: self.chunks.iter().map(|a| a.len() as u32).collect(), }; diff --git a/polars/polars-core/src/chunked_array/ops/take/take_random.rs b/polars/polars-core/src/chunked_array/ops/take/take_random.rs index 5d576fee72f5..1abdc74558c2 100644 --- a/polars/polars-core/src/chunked_array/ops/take/take_random.rs +++ b/polars/polars-core/src/chunked_array/ops/take/take_random.rs @@ -122,6 +122,7 @@ where } } +#[allow(clippy::type_complexity)] impl<'a, T> IntoTakeRandom<'a> for &'a ChunkedArray where T: PolarsNumericType, @@ -129,8 +130,8 @@ where type Item = T::Native; type TakeRandom = TakeRandBranch3< NumTakeRandomCont<'a, T::Native>, - NumTakeRandomSingleChunk<'a, T>, - NumTakeRandomChunked<'a, T>, + NumTakeRandomSingleChunk<'a, T::Native>, + NumTakeRandomChunked<'a, T::Native>, >; #[inline] @@ -271,17 +272,17 @@ impl<'a> IntoTakeRandom<'a> for &'a ListChunked { pub struct NumTakeRandomChunked<'a, T> where - T: PolarsNumericType, + T: NumericNative, { - pub(crate) chunks: Vec<&'a PrimitiveArray>, + pub(crate) chunks: Vec<&'a PrimitiveArray>, pub(crate) chunk_lens: Vec, } impl<'a, T> TakeRandom for NumTakeRandomChunked<'a, T> where - T: PolarsNumericType, + T: NumericNative, { - type Item = T::Native; + type Item = T; #[inline] fn get(&self, index: usize) -> Option { @@ -317,16 +318,16 @@ where pub struct NumTakeRandomSingleChunk<'a, T> where - T: PolarsNumericType, + T: NumericNative, { - pub(crate) arr: &'a PrimitiveArray, + pub(crate) arr: &'a PrimitiveArray, } impl<'a, T> TakeRandom for NumTakeRandomSingleChunk<'a, T> where - T: PolarsNumericType, + T: NumericNative, { - type Item = T::Native; + type Item = T; #[inline] fn get(&self, index: usize) -> Option { diff --git a/polars/polars-core/src/datatypes.rs b/polars/polars-core/src/datatypes.rs index 57df8d2d42a6..eb36e2d9fcc5 100644 --- a/polars/polars-core/src/datatypes.rs +++ b/polars/polars-core/src/datatypes.rs @@ -128,6 +128,7 @@ pub type CategoricalChunked = ChunkedArray; pub trait NumericNative: PartialOrd + + NativeType + Num + NumCast + Zero @@ -156,7 +157,7 @@ impl NumericNative for f32 {} impl NumericNative for f64 {} pub trait PolarsNumericType: Send + Sync + PolarsDataType + 'static { - type Native: NativeType + NumericNative; + type Native: NumericNative; } impl PolarsNumericType for UInt8Type { type Native = u8; diff --git a/polars/polars-core/src/frame/row.rs b/polars/polars-core/src/frame/row.rs index bcfbef3aebfe..f600e4c39fe3 100644 --- a/polars/polars-core/src/frame/row.rs +++ b/polars/polars-core/src/frame/row.rs @@ -137,9 +137,11 @@ impl DataFrame { let iter = columns.iter().map(|s| { (0..s.len()).zip(row.0.iter_mut()).for_each(|(i, av)| { - *av = s.get(i); + // Safety: + // we iterate over the length of s, so we are in bounds + unsafe { *av = s.get_unchecked(i) }; }); - // borrow checkery does not allow row borrow, so we deref from raw ptr. + // borrow checker does not allow row borrow, so we deref from raw ptr. // we do all this to amortize allocs // Safety: // row is still alive @@ -201,6 +203,8 @@ impl<'a> From<&AnyValue<'a>> for Field { Date(_) => Field::new("", DataType::Date), #[cfg(feature = "dtype-datetime")] Datetime(_) => Field::new("", DataType::Datetime), + #[cfg(feature = "dtype-time")] + Time(_) => Field::new("", DataType::Time), _ => unimplemented!(), } } @@ -232,6 +236,8 @@ pub(crate) enum Buffer { Date(PrimitiveChunkedBuilder), #[cfg(feature = "dtype-datetime")] Datetime(PrimitiveChunkedBuilder), + #[cfg(feature = "dtype-time")] + Time(PrimitiveChunkedBuilder), Float32(PrimitiveChunkedBuilder), Float64(PrimitiveChunkedBuilder), Utf8(Utf8ChunkedBuilder), @@ -250,6 +256,8 @@ impl Debug for Buffer { Date(_) => f.write_str("Date"), #[cfg(feature = "dtype-datetime")] Datetime(_) => f.write_str("datetime"), + #[cfg(feature = "dtype-time")] + Time(_) => f.write_str("time"), Float32(_) => f.write_str("f32"), Float64(_) => f.write_str("f64"), Utf8(_) => f.write_str("utf8"), @@ -275,6 +283,8 @@ impl Buffer { (Date(builder), AnyValue::Null) => builder.append_null(), #[cfg(feature = "dtype-datetime")] (Datetime(builder), AnyValue::Datetime(v)) => builder.append_value(v), + #[cfg(feature = "dtype-time")] + (Time(builder), AnyValue::Time(v)) => builder.append_value(v), (Float32(builder), AnyValue::Null) => builder.append_null(), (Float64(builder), AnyValue::Float64(v)) => builder.append_value(v), (Utf8(builder), AnyValue::Utf8(v)) => builder.append_value(v), @@ -297,6 +307,8 @@ impl Buffer { Date(b) => b.finish().into_date().into_series(), #[cfg(feature = "dtype-datetime")] Datetime(b) => b.finish().into_date().into_series(), + #[cfg(feature = "dtype-time")] + Time(b) => b.finish().into_date().into_series(), Float32(b) => b.finish().into_series(), Float64(b) => b.finish().into_series(), Utf8(b) => b.finish().into_series(), @@ -319,6 +331,8 @@ impl From<(&DataType, usize)> for Buffer { Date => Buffer::Date(PrimitiveChunkedBuilder::new("", len)), #[cfg(feature = "dtype-datetime")] Datetime => Buffer::Datetime(PrimitiveChunkedBuilder::new("", len)), + #[cfg(feature = "dtype-time")] + Time => Buffer::Time(PrimitiveChunkedBuilder::new("", len)), Float32 => Buffer::Float32(PrimitiveChunkedBuilder::new("", len)), Float64 => Buffer::Float64(PrimitiveChunkedBuilder::new("", len)), Utf8 => Buffer::Utf8(Utf8ChunkedBuilder::new("", len, len * 5)),