Skip to content

Commit

Permalink
feat: Implements arr.shift (pola-rs#14298)
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa authored Feb 7, 2024
1 parent 8c382c3 commit 279540f
Show file tree
Hide file tree
Showing 10 changed files with 193 additions and 1 deletion.
25 changes: 25 additions & 0 deletions crates/polars-core/src/chunked_array/array/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,31 @@ impl ArrayChunked {
.collect_ca_with_dtype(self.name(), self.dtype().clone())
}

/// Zip with a `ChunkedArray` then apply a binary function `F` elementwise.
/// # Safety
// Return series of `F` must has the same dtype and number of elements as input series.
#[must_use]
pub unsafe fn zip_and_apply_amortized_same_type<'a, T, F>(
&'a self,
ca: &'a ChunkedArray<T>,
mut f: F,
) -> Self
where
T: PolarsDataType,
F: FnMut(Option<UnstableSeries<'a>>, Option<T::Physical<'a>>) -> Option<Series>,
{
if self.is_empty() {
return self.clone();
}
self.amortized_iter()
.zip(ca.iter())
.map(|(opt_s, opt_v)| {
let out = f(opt_s, opt_v);
out.map(|s| to_arr(&s))
})
.collect_ca_with_dtype(self.name(), self.dtype().clone())
}

/// Apply a closure `F` elementwise.
#[must_use]
pub fn apply_amortized_generic<'a, F, K, V>(&'a self, f: F) -> ChunkedArray<V>
Expand Down
33 changes: 33 additions & 0 deletions crates/polars-ops/src/chunked_array/array/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,39 @@ pub trait ArrayNameSpace: AsArray {
let ca = self.as_array();
array_count_matches(ca, element)
}

fn array_shift(&self, n: &Series) -> PolarsResult<Series> {
let ca = self.as_array();
let n_s = n.cast(&DataType::Int64)?;
let n = n_s.i64()?;
let out = match n.len() {
1 => {
if let Some(n) = n.get(0) {
// SAFETY: Shift does not change the dtype and number of elements of sub-array.
unsafe { ca.apply_amortized_same_type(|s| s.as_ref().shift(n)) }
} else {
ArrayChunked::full_null_with_dtype(
ca.name(),
ca.len(),
&ca.inner_dtype(),
ca.width(),
)
}
},
_ => {
// SAFETY: Shift does not change the dtype and number of elements of sub-array.
unsafe {
ca.zip_and_apply_amortized_same_type(n, |opt_s, opt_periods| {
match (opt_s, opt_periods) {
(Some(s), Some(n)) => Some(s.as_ref().shift(n)),
_ => None,
}
})
}
},
};
Ok(out.into_series())
}
}

impl ArrayNameSpace for ArrayChunked {}
10 changes: 10 additions & 0 deletions crates/polars-plan/src/dsl/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,4 +177,14 @@ impl ArrayNameSpace {
)
.with_fmt("arr.to_struct")
}

/// Shift every sub-array.
pub fn shift(self, n: Expr) -> Expr {
self.0.map_many_private(
FunctionExpr::ArrayExpr(ArrayFunction::Shift),
&[n],
false,
false,
)
}
}
11 changes: 11 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub enum ArrayFunction {
Contains,
#[cfg(feature = "array_count")]
CountMatches,
Shift,
}

impl ArrayFunction {
Expand All @@ -52,6 +53,7 @@ impl ArrayFunction {
Contains => mapper.with_dtype(DataType::Boolean),
#[cfg(feature = "array_count")]
CountMatches => mapper.with_dtype(IDX_DTYPE),
Shift => mapper.with_same_dtype(),
}
}
}
Expand Down Expand Up @@ -90,6 +92,7 @@ impl Display for ArrayFunction {
Contains => "contains",
#[cfg(feature = "array_count")]
CountMatches => "count_matches",
Shift => "shift",
};
write!(f, "arr.{name}")
}
Expand Down Expand Up @@ -121,6 +124,7 @@ impl From<ArrayFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
Contains => map_as_slice!(contains),
#[cfg(feature = "array_count")]
CountMatches => map_as_slice!(count_matches),
Shift => map_as_slice!(shift),
}
}
}
Expand Down Expand Up @@ -224,3 +228,10 @@ pub(super) fn count_matches(args: &[Series]) -> PolarsResult<Series> {
let ca = s.array()?;
ca.array_count_matches(element.get(0).unwrap())
}

pub(super) fn shift(s: &[Series]) -> PolarsResult<Series> {
let ca = s[0].array()?;
let n = &s[1];

ca.array_shift(n)
}
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/expressions/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ The following methods are available under the `expr.arr` attribute.
Expr.arr.contains
Expr.arr.count_matches
Expr.arr.to_struct
Expr.arr.shift
3 changes: 2 additions & 1 deletion py-polars/docs/source/reference/series/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@ The following methods are available under the `Series.arr` attribute.
Series.arr.explode
Series.arr.contains
Series.arr.count_matches
Series.arr.to_struct
Series.arr.to_struct
Series.arr.shift
49 changes: 49 additions & 0 deletions py-polars/polars/expr/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,3 +701,52 @@ def to_struct(
else:
pyexpr = self._pyexpr.arr_to_struct(fields)
return wrap_expr(pyexpr)

def shift(self, n: int | IntoExprColumn = 1) -> Expr:
"""
Shift array values by the given number of indices.
Parameters
----------
n
Number of indices to shift forward. If a negative value is passed, values
are shifted in the opposite direction instead.
Notes
-----
This method is similar to the `LAG` operation in SQL when the value for `n`
is positive. With a negative value for `n`, it is similar to `LEAD`.
Examples
--------
By default, array values are shifted forward by one index.
>>> df = pl.DataFrame(
... {"a": [[1, 2, 3], [4, 5, 6]]}, schema={"a": pl.Array(pl.Int64, 3)}
... )
>>> df.with_columns(shift=pl.col("a").arr.shift())
shape: (2, 2)
┌───────────────┬───────────────┐
│ a ┆ shift │
│ --- ┆ --- │
│ array[i64, 3] ┆ array[i64, 3] │
╞═══════════════╪═══════════════╡
│ [1, 2, 3] ┆ [null, 1, 2] │
│ [4, 5, 6] ┆ [null, 4, 5] │
└───────────────┴───────────────┘
Pass a negative value to shift in the opposite direction instead.
>>> df.with_columns(shift=pl.col("a").arr.shift(-2))
shape: (2, 2)
┌───────────────┬─────────────────┐
│ a ┆ shift │
│ --- ┆ --- │
│ array[i64, 3] ┆ array[i64, 3] │
╞═══════════════╪═════════════════╡
│ [1, 2, 3] ┆ [3, null, null] │
│ [4, 5, 6] ┆ [6, null, null] │
└───────────────┴─────────────────┘
"""
n = parse_as_expression(n)
return wrap_expr(self._pyexpr.arr_shift(n))
39 changes: 39 additions & 0 deletions py-polars/polars/series/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,3 +562,42 @@ def to_struct(
"""
s = wrap_s(self._s)
return s.to_frame().select(F.col(s.name).arr.to_struct(fields)).to_series()

def shift(self, n: int | IntoExprColumn = 1) -> Series:
"""
Shift array values by the given number of indices.
Parameters
----------
n
Number of indices to shift forward. If a negative value is passed, values
are shifted in the opposite direction instead.
Notes
-----
This method is similar to the `LAG` operation in SQL when the value for `n`
is positive. With a negative value for `n`, it is similar to `LEAD`.
Examples
--------
By default, array values are shifted forward by one index.
>>> s = pl.Series([[1, 2, 3], [4, 5, 6]], dtype=pl.Array(pl.Int64, 3))
>>> s.arr.shift()
shape: (2,)
Series: '' [array[i64, 3]]
[
[null, 1, 2]
[null, 4, 5]
]
Pass a negative value to shift in the opposite direction instead.
>>> s.arr.shift(-2)
shape: (2,)
Series: '' [array[i64, 3]]
[
[3, null, null]
[6, null, null]
]
"""
4 changes: 4 additions & 0 deletions py-polars/src/expr/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,8 @@ impl PyExpr {

Ok(self.inner.clone().arr().to_struct(name_gen).into())
}

fn arr_shift(&self, n: PyExpr) -> Self {
self.inner.clone().arr().shift(n.inner).into()
}
}
19 changes: 19 additions & 0 deletions py-polars/tests/unit/namespaces/array/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,3 +347,22 @@ def test_array_to_struct() -> None:
assert df.lazy().select(pl.col("a").arr.to_struct()).unnest(
"a"
).sum().collect().columns == ["field_0", "field_1", "field_2"]


def test_array_shift() -> None:
df = pl.DataFrame(
{"a": [[1, 2, 3], None, [4, 5, 6], [7, 8, 9]], "n": [None, 1, 1, -2]},
schema={"a": pl.Array(pl.Int64, 3), "n": pl.Int64},
)

out = df.select(
lit=pl.col("a").arr.shift(1), expr=pl.col("a").arr.shift(pl.col("n"))
)
expected = pl.DataFrame(
{
"lit": [[None, 1, 2], None, [None, 4, 5], [None, 7, 8]],
"expr": [None, None, [None, 4, 5], [9, None, None]],
},
schema={"lit": pl.Array(pl.Int64, 3), "expr": pl.Array(pl.Int64, 3)},
)
assert_frame_equal(out, expected)

0 comments on commit 279540f

Please sign in to comment.