diff --git a/polars/polars-core/src/frame/asof_join/groups.rs b/polars/polars-core/src/frame/asof_join/groups.rs index 8ca158c81c0f..6c02e568844e 100644 --- a/polars/polars-core/src/frame/asof_join/groups.rs +++ b/polars/polars-core/src/frame/asof_join/groups.rs @@ -28,24 +28,35 @@ pub(super) unsafe fn join_asof_backward_with_indirection_and_tolerance< if offsets.is_empty() { return (None, 0); } - let mut previous = *offsets.get_unchecked(0); - let first = *right.get_unchecked(previous as usize); + let mut previous_idx = *offsets.get_unchecked(0); + let first = *right.get_unchecked(previous_idx as usize); if val_l < first { (None, 0) } else { for (idx, &offset) in offsets.iter().enumerate() { let val_r = *right.get_unchecked(offset as usize); + + // the point that is larger is not allowed if val_r > val_l { - let dist = val_l - val_r; + // compute the distance of previous point, that one was still backwards + let previous_value = *right.get_unchecked(previous_idx as usize); + let dist = val_l - previous_value; return if dist > tolerance { (None, idx) } else { - (Some(previous), idx) + (Some(previous_idx), idx) }; } - previous = offset + previous_idx = offset + } + // check remaining values that still suffice the distance constraint + let previous_value = *right.get_unchecked(previous_idx as usize); + let dist = val_l - previous_value; + if dist > tolerance { + (None, offsets.len()) + } else { + (Some(previous_idx), offsets.len()) } - (None, offsets.len()) } } diff --git a/py-polars/tests/test_joins.py b/py-polars/tests/test_joins.py index 6000610a1d72..c018a28bbc5c 100644 --- a/py-polars/tests/test_joins.py +++ b/py-polars/tests/test_joins.py @@ -1,3 +1,5 @@ +from datetime import datetime + import numpy as np import polars as pl @@ -122,3 +124,59 @@ def test_join_asof_floats() -> None: "b": ["lrow1", "lrow2", "lrow3"], "b_right": ["rrow1", "rrow2", "rrow3"], } + + +def test_join_asof_tolerance() -> None: + df_trades = pl.DataFrame( + { + "time": [ + datetime(2020, 1, 1, 9, 0, 1), + datetime(2020, 1, 1, 9, 0, 1), + datetime(2020, 1, 1, 9, 0, 3), + datetime(2020, 1, 1, 9, 0, 6), + ], + "stock": ["A", "B", "B", "C"], + "trade": [101, 299, 301, 500], + } + ) + + df_quotes = pl.DataFrame( + { + "time": [ + datetime(2020, 1, 1, 9, 0, 0), + datetime(2020, 1, 1, 9, 0, 2), + datetime(2020, 1, 1, 9, 0, 4), + datetime(2020, 1, 1, 9, 0, 6), + ], + "stock": ["A", "B", "C", "A"], + "quote": [100, 300, 501, 102], + } + ) + + assert df_trades.join_asof( + df_quotes, on="time", by="stock", tolerance="2s" + ).to_dict(False) == { + "time": [ + datetime(2020, 1, 1, 9, 0, 1), + datetime(2020, 1, 1, 9, 0, 1), + datetime(2020, 1, 1, 9, 0, 3), + datetime(2020, 1, 1, 9, 0, 6), + ], + "stock": ["A", "B", "B", "C"], + "trade": [101, 299, 301, 500], + "quote": [100, None, 300, 501], + } + + assert df_trades.join_asof( + df_quotes, on="time", by="stock", tolerance="1s" + ).to_dict(False) == { + "time": [ + datetime(2020, 1, 1, 9, 0, 1), + datetime(2020, 1, 1, 9, 0, 1), + datetime(2020, 1, 1, 9, 0, 3), + datetime(2020, 1, 1, 9, 0, 6), + ], + "stock": ["A", "B", "B", "C"], + "trade": [101, 299, 301, 500], + "quote": [100, None, 300, None], + }