Skip to content

Commit

Permalink
adjust tests for 24.04
Browse files Browse the repository at this point in the history
  • Loading branch information
rjzamora committed Jul 29, 2024
1 parent b635b33 commit 3862601
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
8 changes: 7 additions & 1 deletion nvtabular/ops/categorify.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from dask.highlevelgraph import HighLevelGraph
from dask.utils import parse_bytes
from fsspec.core import get_fs_token_paths
from packaging.version import Version

from merlin.core import dispatch
from merlin.core.dispatch import DataFrameType, annotate, is_cpu_object, nullable_series
Expand All @@ -53,6 +54,7 @@
PAD_OFFSET = 0
NULL_OFFSET = 1
OOV_OFFSET = 2
PA_GE_14 = Version(pa.__version__) >= Version("14.0")


class Categorify(StatOperator):
Expand Down Expand Up @@ -907,7 +909,11 @@ def _general_concat(
):
# Concatenate DataFrame or pa.Table objects
if isinstance(frames[0], pa.Table):
df = pa.concat_tables(frames, promote=True)
if PA_GE_14:
df = pa.concat_tables(frames, promote_options="default")
else:
df = pa.concat_tables(frames, promote=True)

if (
cardinality_memory_limit
and col_selector is not None
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/ops/test_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def test_lambdaop_dtype_multi_op_propagation(cpu):
{
"a": np.arange(size),
"b": np.random.choice(["apple", "banana", "orange"], size),
"c": np.random.choice([0, 1], size).astype(np.float16),
"c": np.random.choice([0, 1], size),
}
)
ddf0 = dd.from_pandas(df0, npartitions=4)
Expand Down
10 changes: 6 additions & 4 deletions tests/unit/test_dask_nvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,12 @@ def test_dask_groupby_stats(client, tmpdir, datasets, part_mem_fraction):
gb_e = expect.groupby("name-cat").aggregate({"name-cat": "count", "x": ["sum", "min", "std"]})
gb_e.columns = ["count", "sum", "min", "std"]
df_check = got.merge(gb_e, left_on="name-cat", right_index=True, how="left")
assert_eq(df_check["name-cat_count"], df_check["count"], check_names=False)
assert_eq(df_check["name-cat_x_sum"], df_check["sum"], check_names=False)
assert_eq(df_check["name-cat_x_min"], df_check["min"], check_names=False)
assert_eq(df_check["name-cat_x_std"], df_check["std"].astype("float32"), check_names=False)
# Names and dtypes don't need to match (just values)
options = {"check_names": False, "check_dtype": False}
assert_eq(df_check["name-cat_count"], df_check["count"], **options)
assert_eq(df_check["name-cat_x_sum"], df_check["sum"], **options)
assert_eq(df_check["name-cat_x_min"], df_check["min"], **options)
assert_eq(df_check["name-cat_x_std"], df_check["std"], **options)


@pytest.mark.parametrize("part_mem_fraction", [0.01])
Expand Down

0 comments on commit 3862601

Please sign in to comment.