From 7b5480e4c4bfa9a462a51950221ae4bbcf427cab Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Wed, 18 May 2022 07:09:37 +0200 Subject: [PATCH] Improve list builders, iteration and construction - Greatly improves performance of the list builders by: https://github.com/jorgecarleitao/arrow2/pull/991 - List builders now also support nested dtypes like List and Struct - Python DataFrame and Series constructor now support better nested dtype construction --- polars/polars-arrow/Cargo.toml | 4 +- polars/polars-arrow/src/array/list.rs | 1 + polars/polars-core/Cargo.toml | 3 +- .../src/chunked_array/builder/list.rs | 129 ++++++++++++++---- .../src/chunked_array/list/iterator.rs | 27 ++++ polars/polars-core/src/datatypes.rs | 47 +++++-- polars/polars-core/src/series/ops/to_list.rs | 16 ++- polars/polars-io/Cargo.toml | 3 +- polars/polars-lazy/src/dsl/functions.rs | 2 +- .../src/chunked_array/list/namespace.rs | 1 + polars/polars-time/Cargo.toml | 5 +- .../polars-time/src/chunkedarray/datetime.rs | 1 + py-polars/Cargo.lock | 3 +- py-polars/polars/internals/construction.py | 29 ++-- py-polars/src/apply/series.rs | 4 +- py-polars/src/conversion.rs | 8 +- py-polars/src/list_construction.rs | 10 +- py-polars/src/series.rs | 6 + py-polars/tests/test_apply.py | 12 +- py-polars/tests/test_struct.py | 25 ++++ 20 files changed, 271 insertions(+), 65 deletions(-) diff --git a/polars/polars-arrow/Cargo.toml b/polars/polars-arrow/Cargo.toml index 51a04ad39d53..1262e936338d 100644 --- a/polars/polars-arrow/Cargo.toml +++ b/polars/polars-arrow/Cargo.toml @@ -9,7 +9,8 @@ description = "Arrow interfaces for Polars DataFrame library" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -arrow = { package = "arrow2", git = "https://github.com/jorgecarleitao/arrow2", rev = "826a2b8ed8598a614c5df9115ea657d1e3c40184", features = ["compute_concatenate"], default-features = false } +arrow = { package = "arrow2", git = "https://github.com/jorgecarleitao/arrow2", rev = "aafba7b4eb4991e016638cbc1d4df676912e8236", features = ["compute_concatenate"], default-features = false } +# arrow = { package = "arrow2", path = "../../../arrow2", features = ["compute_concatenate"], default-features = false } # arrow = { package = "arrow2", git = "https://github.com/ritchie46/arrow2", branch = "polars", default-features = false } # arrow = { package = "arrow2", version = "0.11", default-features = false, features = ["compute_concatenate"] } hashbrown = "0.12" @@ -20,4 +21,5 @@ thiserror = "^1.0" [features] strings = [] compute = ["arrow/compute_cast"] +temporal = ["arrow/compute_temporal"] bigidx = [] diff --git a/polars/polars-arrow/src/array/list.rs b/polars/polars-arrow/src/array/list.rs index 91947e469ddb..9fe835c42d6e 100644 --- a/polars/polars-arrow/src/array/list.rs +++ b/polars/polars-arrow/src/array/list.rs @@ -31,6 +31,7 @@ impl<'a> AnonymousBuilder<'a> { self.arrays.is_empty() } + #[inline] pub fn push(&mut self, arr: &'a dyn Array) { self.size += arr.len() as i64; self.offsets.push(self.size); diff --git a/polars/polars-core/Cargo.toml b/polars/polars-core/Cargo.toml index d18223950099..aade4f50bb19 100644 --- a/polars/polars-core/Cargo.toml +++ b/polars/polars-core/Cargo.toml @@ -173,7 +173,8 @@ thiserror = "^1.0" package = "arrow2" git = "https://github.com/jorgecarleitao/arrow2" # git = "https://github.com/ritchie46/arrow2" -rev = "826a2b8ed8598a614c5df9115ea657d1e3c40184" +rev = "aafba7b4eb4991e016638cbc1d4df676912e8236" +# path = "../../../arrow2" # branch = "polars" # version = "0.11" default-features = false diff --git a/polars/polars-core/src/chunked_array/builder/list.rs b/polars/polars-core/src/chunked_array/builder/list.rs index 62283fb614ff..1e76545fbbeb 100644 --- a/polars/polars-core/src/chunked_array/builder/list.rs +++ b/polars/polars-core/src/chunked_array/builder/list.rs @@ -36,9 +36,9 @@ where pub struct ListPrimitiveChunkedBuilder where - T: NumericNative, + T: PolarsNumericType, { - pub builder: LargePrimitiveBuilder, + pub builder: LargePrimitiveBuilder, field: Field, fast_explode: bool, } @@ -62,7 +62,7 @@ macro_rules! finish_list_builder { impl ListPrimitiveChunkedBuilder where - T: NumericNative, + T: PolarsNumericType, { pub fn new( name: &str, @@ -70,8 +70,8 @@ where values_capacity: usize, logical_type: DataType, ) -> Self { - let values = MutablePrimitiveArray::::with_capacity(values_capacity); - let builder = LargePrimitiveBuilder::::new_with_capacity(values, capacity); + let values = MutablePrimitiveArray::::with_capacity(values_capacity); + let builder = LargePrimitiveBuilder::::new_with_capacity(values, capacity); let field = Field::new(name, DataType::List(Box::new(logical_type))); Self { @@ -81,7 +81,7 @@ where } } - pub fn append_slice(&mut self, opt_v: Option<&[T]>) { + pub fn append_slice(&mut self, opt_v: Option<&[T::Native]>) { match opt_v { Some(items) => { let values = self.builder.mut_values(); @@ -99,7 +99,7 @@ where } /// Appends from an iterator over values #[inline] - pub fn append_iter_values + TrustedLen>(&mut self, iter: I) { + pub fn append_iter_values + TrustedLen>(&mut self, iter: I) { let values = self.builder.mut_values(); if iter.size_hint().0 == 0 { @@ -113,7 +113,7 @@ where /// Appends from an iterator over values #[inline] - pub fn append_iter> + TrustedLen>(&mut self, iter: I) { + pub fn append_iter> + TrustedLen>(&mut self, iter: I) { let values = self.builder.mut_values(); if iter.size_hint().0 == 0 { @@ -128,7 +128,7 @@ where impl ListBuilderTrait for ListPrimitiveChunkedBuilder where - T: NumericNative, + T: PolarsNumericType, { #[inline] fn append_opt_series(&mut self, opt_s: Option<&Series>) { @@ -151,12 +151,10 @@ where if s.is_empty() { self.fast_explode = false; } - let arrays = s.chunks(); + let ca = s.unpack::().unwrap(); let values = self.builder.mut_values(); - arrays.iter().for_each(|x| { - let arr = x.as_any().downcast_ref::>().unwrap(); - + ca.downcast_iter().for_each(|arr| { if !arr.has_validity() { values.extend_from_slice(arr.values().as_slice()) } else { @@ -350,14 +348,23 @@ pub fn get_list_builder( #[cfg(feature = "object")] DataType::Object(_) => _err(), #[cfg(feature = "dtype-struct")] - DataType::Struct(_) => _err(), + DataType::Struct(_) => Ok(Box::new(AnonymousOwnedListBuilder::new( + name, + list_capacity, + physical_type, + ))), + DataType::List(_) => Ok(Box::new(AnonymousOwnedListBuilder::new( + name, + list_capacity, + physical_type, + ))), #[cfg(feature = "dtype-categorical")] DataType::Categorical(_) => _err(), _ => { macro_rules! get_primitive_builder { ($type:ty) => {{ let builder = ListPrimitiveChunkedBuilder::<$type>::new( - &name, + name, list_capacity, value_capacity, dt.clone(), @@ -379,7 +386,7 @@ pub fn get_list_builder( Box::new(builder) }}; } - Ok(match_dtype_to_physical_apply_macro!( + Ok(match_dtype_to_logical_apply_macro!( physical_type, get_primitive_builder, get_utf8_builder, @@ -395,12 +402,18 @@ pub struct AnonymousListBuilder<'a> { pub dtype: DataType, } +impl Default for AnonymousListBuilder<'_> { + fn default() -> Self { + Self::new("", 0, Default::default()) + } +} + impl<'a> AnonymousListBuilder<'a> { - pub fn new(name: &str, capacity: usize, dtype: DataType) -> Self { + pub fn new(name: &str, capacity: usize, inner_dtype: DataType) -> Self { Self { name: name.into(), builder: AnonymousBuilder::new(capacity), - dtype, + dtype: inner_dtype, } } @@ -440,17 +453,85 @@ impl<'a> AnonymousListBuilder<'a> { } } - pub fn finish(self) -> ListChunked { - if self.builder.is_empty() { - ListChunked::full_null_with_dtype(&self.name, 0, &self.dtype) + pub fn finish(&mut self) -> ListChunked { + let slf = std::mem::take(self); + if slf.builder.is_empty() { + ListChunked::full_null_with_dtype(&slf.name, 0, &slf.dtype) } else { - let arr = self + let arr = slf .builder - .finish(Some(&self.dtype.to_physical().to_arrow())) + .finish(Some(&slf.dtype.to_physical().to_arrow())) .unwrap(); let mut ca = ListChunked::from_chunks("", vec![Arc::new(arr)]); - ca.field = Arc::new(Field::new(&self.name, DataType::List(Box::new(self.dtype)))); + ca.field = Arc::new(Field::new(&slf.name, DataType::List(Box::new(slf.dtype)))); ca } } } + +pub struct AnonymousOwnedListBuilder { + name: String, + builder: AnonymousBuilder<'static>, + owned: Vec, + inner_dtype: DataType, +} + +impl Default for AnonymousOwnedListBuilder { + fn default() -> Self { + Self::new("", 0, Default::default()) + } +} + +impl ListBuilderTrait for AnonymousOwnedListBuilder { + fn append_series(&mut self, s: &Series) { + // Safety + // we deref a raw pointer with a lifetime that is not static + // it is safe because we also clone Series (Arc +=1) and therefore the &dyn Arrays + // will not be dropped until the owned series are dropped + unsafe { + match s.dtype() { + #[cfg(feature = "dtype-struct")] + DataType::Struct(_) => self.builder.push(&*(&**s.array_ref(0) as *const dyn Array)), + _ => { + self.builder + .push_multiple(&*(s.chunks().as_ref() as *const [ArrayRef])); + } + } + } + // this make sure that the underlying ArrayRef's are not dropped + self.owned.push(s.clone()); + } + + fn append_null(&mut self) { + self.builder.push_null() + } + + fn finish(&mut self) -> ListChunked { + let slf = std::mem::take(self); + if slf.builder.is_empty() { + ListChunked::full_null_with_dtype(&slf.name, 0, &slf.inner_dtype) + } else { + let arr = slf + .builder + .finish(Some(&slf.inner_dtype.to_physical().to_arrow())) + .unwrap(); + let mut ca = ListChunked::from_chunks("", vec![Arc::new(arr)]); + ca.field = Arc::new(Field::new( + &slf.name, + DataType::List(Box::new(slf.inner_dtype)), + )); + ca + } + } +} + +impl AnonymousOwnedListBuilder { + pub fn new(name: &str, capacity: usize, inner_dtype: DataType) -> Self { + Self { + name: name.into(), + builder: AnonymousBuilder::new(capacity), + owned: Vec::with_capacity(capacity), + inner_dtype, + } + } +} diff --git a/polars/polars-core/src/chunked_array/list/iterator.rs b/polars/polars-core/src/chunked_array/list/iterator.rs index 91b86c220c38..13363d6207a7 100644 --- a/polars/polars-core/src/chunked_array/list/iterator.rs +++ b/polars/polars-core/src/chunked_array/list/iterator.rs @@ -12,6 +12,9 @@ pub struct AmortizedListIter<'a, I: Iterator>> { inner: NonNull, lifetime: PhantomData<&'a ArrayRef>, iter: I, + // used only if feature="dtype-struct" + #[allow(dead_code)] + inner_dtype: DataType, } impl<'a, I: Iterator>> Iterator for AmortizedListIter<'a, I> { @@ -20,7 +23,30 @@ impl<'a, I: Iterator>> Iterator for AmortizedListIter<'a fn next(&mut self) -> Option { self.iter.next().map(|opt_val| { opt_val.map(|array_ref| { + #[cfg(feature = "dtype-struct")] + // structs arrays are bound to the series not to the arrayref + // so we must get a hold to the new array + if matches!(self.inner_dtype, DataType::Struct(_)) { + // Safety + // dtype is known + unsafe { + let array_ref = Arc::from(array_ref); + let mut s = Series::from_chunks_and_dtype_unchecked( + "", + vec![array_ref], + &self.inner_dtype, + ); + // swap the new series with the container + std::mem::swap(&mut *self.series_container, &mut s); + // return a reference to the container + // this lifetime is now bound to 'a + return UnstableSeries::new(&*(&*self.series_container as *const Series)); + } + } + + // update the inner state unsafe { *self.inner.as_mut() = array_ref.into() }; + // Safety // we cannot control the lifetime of an iterators `next` method. // but as long as self is alive the reference to the series container is valid @@ -85,6 +111,7 @@ impl ListChunked { inner: NonNull::new(ptr).unwrap(), lifetime: PhantomData, iter: self.downcast_iter().flat_map(|arr| arr.iter()), + inner_dtype: self.inner_dtype(), } } diff --git a/polars/polars-core/src/datatypes.rs b/polars/polars-core/src/datatypes.rs index e74de81b8306..0c7a422301e4 100644 --- a/polars/polars-core/src/datatypes.rs +++ b/polars/polars-core/src/datatypes.rs @@ -141,17 +141,38 @@ pub trait NumericNative: + FromPrimitive + NativeArithmetics { + type POLARSTYPE; +} +impl NumericNative for i8 { + type POLARSTYPE = Int8Type; +} +impl NumericNative for i16 { + type POLARSTYPE = Int16Type; +} +impl NumericNative for i32 { + type POLARSTYPE = Int32Type; +} +impl NumericNative for i64 { + type POLARSTYPE = Int64Type; +} +impl NumericNative for u8 { + type POLARSTYPE = UInt8Type; +} +impl NumericNative for u16 { + type POLARSTYPE = UInt16Type; +} +impl NumericNative for u32 { + type POLARSTYPE = UInt32Type; +} +impl NumericNative for u64 { + type POLARSTYPE = UInt64Type; +} +impl NumericNative for f32 { + type POLARSTYPE = Float32Type; +} +impl NumericNative for f64 { + type POLARSTYPE = Float64Type; } -impl NumericNative for i8 {} -impl NumericNative for i16 {} -impl NumericNative for i32 {} -impl NumericNative for i64 {} -impl NumericNative for u8 {} -impl NumericNative for u16 {} -impl NumericNative for u32 {} -impl NumericNative for u64 {} -impl NumericNative for f32 {} -impl NumericNative for f64 {} pub trait PolarsNumericType: Send + Sync + PolarsDataType + 'static { type Native: NumericNative; @@ -678,6 +699,12 @@ pub enum DataType { Unknown, } +impl Default for DataType { + fn default() -> Self { + DataType::Unknown + } +} + impl Hash for DataType { fn hash(&self, state: &mut H) { std::mem::discriminant(self).hash(state) diff --git a/polars/polars-core/src/series/ops/to_list.rs b/polars/polars-core/src/series/ops/to_list.rs index 7fa223f7bbe4..09f4eb954d13 100644 --- a/polars/polars-core/src/series/ops/to_list.rs +++ b/polars/polars-core/src/series/ops/to_list.rs @@ -4,11 +4,17 @@ use polars_arrow::kernels::list::array_to_unit_list; use std::borrow::Cow; fn reshape_fast_path(name: &str, s: &Series) -> Series { - let chunks = s - .chunks() - .iter() - .map(|arr| Arc::new(array_to_unit_list(arr.clone())) as ArrayRef) - .collect::>(); + let chunks = match s.dtype() { + #[cfg(feature = "dtype-struct")] + DataType::Struct(_) => { + vec![Arc::new(array_to_unit_list(s.array_ref(0).clone())) as ArrayRef] + } + _ => s + .chunks() + .iter() + .map(|arr| Arc::new(array_to_unit_list(arr.clone())) as ArrayRef) + .collect::>(), + }; let mut ca = ListChunked::from_chunks(name, chunks); ca.set_fast_explode(); diff --git a/polars/polars-io/Cargo.toml b/polars/polars-io/Cargo.toml index a4d17a096683..cf753db8d6f1 100644 --- a/polars/polars-io/Cargo.toml +++ b/polars/polars-io/Cargo.toml @@ -35,9 +35,10 @@ private = ["polars-time/private"] [dependencies] ahash = "0.7" anyhow = "1.0" -arrow = { package = "arrow2", git = "https://github.com/jorgecarleitao/arrow2", rev = "826a2b8ed8598a614c5df9115ea657d1e3c40184", default-features = false } +arrow = { package = "arrow2", git = "https://github.com/jorgecarleitao/arrow2", rev = "aafba7b4eb4991e016638cbc1d4df676912e8236", default-features = false } # arrow = { package = "arrow2", git = "https://github.com/ritchie46/arrow2", branch = "polars", default-features = false } # arrow = { package = "arrow2", version = "0.11", default-features = false } +# arrow = { package = "arrow2", path = "../../../arrow2", default-features = false } csv-core = { version = "0.1.10", optional = true } dirs = "4.0" flate2 = { version = "1", optional = true, default-features = false } diff --git a/polars/polars-lazy/src/dsl/functions.rs b/polars/polars-lazy/src/dsl/functions.rs index ba71a17587cd..ed1b8dd704ec 100644 --- a/polars/polars-lazy/src/dsl/functions.rs +++ b/polars/polars-lazy/src/dsl/functions.rs @@ -275,7 +275,7 @@ pub fn arange(low: Expr, high: Expr, step: usize) -> Expr { let sb = sb.cast(&DataType::Int64)?; let low = sa.i64()?; let high = sb.i64()?; - let mut builder = ListPrimitiveChunkedBuilder::::new( + let mut builder = ListPrimitiveChunkedBuilder::::new( "arange", low.len(), low.len() * 3, diff --git a/polars/polars-ops/src/chunked_array/list/namespace.rs b/polars/polars-ops/src/chunked_array/list/namespace.rs index a5d3248e16bd..ed63876690d4 100644 --- a/polars/polars-ops/src/chunked_array/list/namespace.rs +++ b/polars/polars-ops/src/chunked_array/list/namespace.rs @@ -217,6 +217,7 @@ pub trait ListNameSpaceImpl: AsList { let length = ca.len(); let mut other = other.to_vec(); let dtype = ca.dtype(); + dbg!(&ca, ca.dtype()); let inner_type = ca.inner_dtype(); // broadcasting path in case all unit length diff --git a/polars/polars-time/Cargo.toml b/polars/polars-time/Cargo.toml index ad9fe725f334..cf586d2068ed 100644 --- a/polars/polars-time/Cargo.toml +++ b/polars/polars-time/Cargo.toml @@ -9,16 +9,15 @@ description = "Time related code for the polars dataframe library" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -arrow = { package = "arrow2", git = "https://github.com/jorgecarleitao/arrow2", rev = "826a2b8ed8598a614c5df9115ea657d1e3c40184", default-features = false } chrono = "0.4" lexical = { version = "6", default-features = false, features = ["std", "parse-floats", "parse-integers"] } -polars-arrow = { version = "0.21.1", path = "../polars-arrow", features = ["compute"] } +polars-arrow = { version = "0.21.1", path = "../polars-arrow", features = ["compute", "temporal"] } polars-core = { version = "0.21.1", path = "../polars-core", default-features = false, features = ["private", "dtype-datetime", "dtype-duration", "dtype-time", "dtype-date"] } serde = { version = "1", features = ["derive"], optional = true } [features] dtype-date = ["polars-core/dtype-date", "polars-core/temporal"] -dtype-datetime = ["polars-core/dtype-date", "polars-core/temporal", "arrow/compute_temporal"] +dtype-datetime = ["polars-core/dtype-date", "polars-core/temporal"] dtype-time = ["polars-core/dtype-time", "polars-core/temporal"] dtype-duration = ["polars-core/dtype-duration", "polars-core/temporal"] private = [] diff --git a/polars/polars-time/src/chunkedarray/datetime.rs b/polars/polars-time/src/chunkedarray/datetime.rs index b7ce12e4f4b5..e3bc7db0486e 100644 --- a/polars/polars-time/src/chunkedarray/datetime.rs +++ b/polars/polars-time/src/chunkedarray/datetime.rs @@ -3,6 +3,7 @@ use arrow::array::{Array, ArrayRef, PrimitiveArray}; use arrow::compute::cast::CastOptions; use arrow::compute::{cast::cast, temporal}; use arrow::error::Result as ArrowResult; +use polars_arrow::export::arrow; use polars_core::prelude::*; fn cast_and_apply< diff --git a/py-polars/Cargo.lock b/py-polars/Cargo.lock index 414b91e344e1..4579df434d1b 100644 --- a/py-polars/Cargo.lock +++ b/py-polars/Cargo.lock @@ -74,7 +74,7 @@ dependencies = [ [[package]] name = "arrow2" version = "0.11.2" -source = "git+https://github.com/jorgecarleitao/arrow2?rev=826a2b8ed8598a614c5df9115ea657d1e3c40184#826a2b8ed8598a614c5df9115ea657d1e3c40184" +source = "git+https://github.com/jorgecarleitao/arrow2?rev=aafba7b4eb4991e016638cbc1d4df676912e8236#aafba7b4eb4991e016638cbc1d4df676912e8236" dependencies = [ "arrow-format", "avro-schema", @@ -1226,7 +1226,6 @@ dependencies = [ name = "polars-time" version = "0.21.1" dependencies = [ - "arrow2", "chrono", "lexical", "polars-arrow", diff --git a/py-polars/polars/internals/construction.py b/py-polars/polars/internals/construction.py index dd512befe1dc..fa34317f8227 100644 --- a/py-polars/polars/internals/construction.py +++ b/py-polars/polars/internals/construction.py @@ -131,6 +131,18 @@ def _get_first_non_none(values: Sequence[Optional[Any]]) -> Any: return next((v for v in values if v is not None), None) +def sequence_from_anyvalue_or_object(name: str, values: Sequence[Any]) -> "PySeries": + """ + Last resort conversion. AnyValues are most flexible and if they fail we go for object types + """ + + try: + return PySeries.new_from_anyvalues(name, values) + # raised if we cannot convert to Wrap + except RuntimeError: + return PySeries.new_object(name, values, False) + + def sequence_to_pyseries( name: str, values: Sequence[Any], @@ -208,11 +220,8 @@ def sequence_to_pyseries( else: try: nested_arrow_dtype = py_type_to_arrow_type(nested_dtype) - except ValueError as e: # pragma: no cover - raise ValueError( - f"Cannot construct Series from sequence of {nested_dtype}." - ) from e - + except ValueError: # pragma: no cover + return sequence_from_anyvalue_or_object(name, values) try: arrow_values = pa.array(values, pa.large_list(nested_arrow_dtype)) return arrow_to_pyseries(name, arrow_values) @@ -226,15 +235,15 @@ def sequence_to_pyseries( return PySeries.new_series_list(name, [v.inner() for v in values], strict) elif dtype_ == PySeries: return PySeries.new_series_list(name, values, strict) - else: constructor = py_type_to_constructor(dtype_) if constructor == PySeries.new_object: - np_constructor = numpy_type_to_constructor(dtype_) - if np_constructor is not None: - values = np.array(values) # type: ignore - constructor = np_constructor + try: + return PySeries.new_from_anyvalues(name, values) + # raised if we cannot convert to Wrap + except RuntimeError: + return sequence_from_anyvalue_or_object(name, values) return constructor(name, values, strict) diff --git a/py-polars/src/apply/series.rs b/py-polars/src/apply/series.rs index f20511b00b22..241d850f7985 100644 --- a/py-polars/src/apply/series.rs +++ b/py-polars/src/apply/series.rs @@ -1,5 +1,5 @@ use super::*; -use crate::conversion::to_wrapped; +use crate::conversion::slice_to_wrapped; use crate::series::PySeries; use crate::{PyPolarsErr, Wrap}; use polars::chunked_array::builder::get_list_builder; @@ -1902,7 +1902,7 @@ impl<'a> ApplyLambda<'a> for ObjectChunked { fn make_dict_arg(py: Python, names: &[&str], vals: &[AnyValue]) -> Py { let dict = PyDict::new(py); - for (name, val) in names.iter().zip(to_wrapped(vals)) { + for (name, val) in names.iter().zip(slice_to_wrapped(vals)) { dict.set_item(name, val).unwrap() } dict.into_py(py) diff --git a/py-polars/src/conversion.rs b/py-polars/src/conversion.rs index dc040da1062a..f91ec3ad9818 100644 --- a/py-polars/src/conversion.rs +++ b/py-polars/src/conversion.rs @@ -19,7 +19,13 @@ use pyo3::{PyAny, PyResult}; use std::fmt::{Display, Formatter}; use std::hash::{Hash, Hasher}; -pub(crate) fn to_wrapped(slice: &[T]) -> &[Wrap] { +pub(crate) fn slice_to_wrapped(slice: &[T]) -> &[Wrap] { + // Safety: + // Wrap is transparent. + unsafe { std::mem::transmute(slice) } +} + +pub(crate) fn slice_extract_wrapped(slice: &[Wrap]) -> &[T] { // Safety: // Wrap is transparent. unsafe { std::mem::transmute(slice) } diff --git a/py-polars/src/list_construction.rs b/py-polars/src/list_construction.rs index 53f0129f932a..24d80021db92 100644 --- a/py-polars/src/list_construction.rs +++ b/py-polars/src/list_construction.rs @@ -8,7 +8,7 @@ pub fn py_seq_to_list(name: &str, seq: &PyAny, dtype: &DataType) -> PyResult { let mut builder = - ListPrimitiveChunkedBuilder::::new(name, len, len * 5, DataType::Int64); + ListPrimitiveChunkedBuilder::::new(name, len, len * 5, DataType::Int64); for sub_seq in seq.iter()? { let sub_seq = sub_seq?; let (sub_seq, len) = get_pyseq(sub_seq)?; @@ -32,8 +32,12 @@ pub fn py_seq_to_list(name: &str, seq: &PyAny, dtype: &DataType) -> PyResult { - let mut builder = - ListPrimitiveChunkedBuilder::::new(name, len, len * 5, DataType::Float64); + let mut builder = ListPrimitiveChunkedBuilder::::new( + name, + len, + len * 5, + DataType::Float64, + ); for sub_seq in seq.iter()? { let sub_seq = sub_seq?; let (sub_seq, len) = get_pyseq(sub_seq)?; diff --git a/py-polars/src/series.rs b/py-polars/src/series.rs index ea9525146731..3c73f1c60b9d 100644 --- a/py-polars/src/series.rs +++ b/py-polars/src/series.rs @@ -197,6 +197,12 @@ impl From for PySeries { clippy::len_without_is_empty )] impl PySeries { + #[staticmethod] + pub fn new_from_anyvalues(name: &str, val: Vec>>) -> PySeries { + let avs = slice_extract_wrapped(&val); + Series::new(name, avs).into() + } + #[staticmethod] pub fn new_str(name: &str, val: Wrap, _strict: bool) -> Self { let mut s = val.0.into_series(); diff --git a/py-polars/tests/test_apply.py b/py-polars/tests/test_apply.py index c4bd4f52c2cf..52aeb250f546 100644 --- a/py-polars/tests/test_apply.py +++ b/py-polars/tests/test_apply.py @@ -1,3 +1,4 @@ +import typing from datetime import date, datetime, timedelta from functools import reduce from typing import List, Optional @@ -55,6 +56,7 @@ def test_apply_return_py_object() -> None: assert out.shape == (1, 2) +@typing.no_type_check def test_agg_objects() -> None: df = pl.DataFrame( { @@ -64,8 +66,16 @@ def test_agg_objects() -> None: } ) + class Foo: + def __init__(self, payload): + self.payload = payload + out = df.groupby("groups").agg( - [pl.apply([pl.col("dates"), pl.col("names")], lambda s: dict(zip(s[0], s[1])))] + [ + pl.apply( + [pl.col("dates"), pl.col("names")], lambda s: Foo(dict(zip(s[0], s[1]))) + ) + ] ) assert out.dtypes == [pl.Utf8, pl.Object] diff --git a/py-polars/tests/test_struct.py b/py-polars/tests/test_struct.py index df306fa5402d..78d47e17b798 100644 --- a/py-polars/tests/test_struct.py +++ b/py-polars/tests/test_struct.py @@ -278,3 +278,28 @@ def test_sort_df_with_list_struct() -> None: "a": [1], "b": [[{"c": 1}]], } + + +def test_struct_list_head_tail() -> None: + assert pl.DataFrame( + { + "list_of_struct": [ + [{"a": 1, "b": 4}, {"a": 3, "b": 6}], + [{"a": 10, "b": 40}, {"a": 20, "b": 50}, {"a": 30, "b": 60}], + ] + } + ).with_columns( + [ + pl.col("list_of_struct").arr.head(1).alias("head"), + pl.col("list_of_struct").arr.tail(1).alias("tail"), + ] + ).to_dict( + False + ) == { + "list_of_struct": [ + [{"a": 1, "b": 4}, {"a": 3, "b": 6}], + [{"a": 10, "b": 40}, {"a": 20, "b": 50}, {"a": 30, "b": 60}], + ], + "head": [[{"a": 1, "b": 4}], [{"a": 10, "b": 40}]], + "tail": [[{"a": 3, "b": 6}], [{"a": 30, "b": 60}]], + }