From ba42fd1dfd0147819854cd44616e2e0d8ca66f23 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Tue, 14 May 2024 10:54:02 -0500 Subject: [PATCH] fix: fix return type hint for dedupe() (#90) --- .../_rapids_dependency_file_generator.py | 32 ++++++++++--------- .../test_rapids_dependency_file_generator.py | 11 ++++++- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/src/rapids_dependency_file_generator/_rapids_dependency_file_generator.py b/src/rapids_dependency_file_generator/_rapids_dependency_file_generator.py index c17f208a..7416ed9e 100644 --- a/src/rapids_dependency_file_generator/_rapids_dependency_file_generator.py +++ b/src/rapids_dependency_file_generator/_rapids_dependency_file_generator.py @@ -3,7 +3,6 @@ import os import textwrap import typing -from collections import defaultdict from collections.abc import Generator import tomlkit @@ -42,7 +41,7 @@ def delete_existing_files(root: os.PathLike) -> None: def dedupe( dependencies: list[typing.Union[str, _config.PipRequirements]], -) -> list[typing.Union[str, dict[str, str]]]: +) -> typing.Sequence[typing.Union[str, dict[str, list[str]]]]: """Generate the unique set of dependencies contained in a dependency list. Parameters @@ -52,18 +51,21 @@ def dedupe( Returns ------- - list[str | dict[str, str]] - The `dependencies` with all duplicates removed. + Sequence[str | dict[str, list[str]]] + The ``dependencies`` with all duplicates removed. """ - deduped = sorted({dep for dep in dependencies if isinstance(dep, str)}) - dict_deps = defaultdict(list) - for dep in filter(lambda dep: not isinstance(dep, str), dependencies): - if isinstance(dep, _config.PipRequirements): - dict_deps["pip"].extend(dep.pip) - dict_deps["pip"] = sorted(set(dict_deps["pip"])) - if dict_deps: - deduped.append(dict(dict_deps)) - return deduped + string_deps: set[str] = set() + pip_deps: set[str] = set() + for dep in dependencies: + if isinstance(dep, str): + string_deps.add(dep) + elif isinstance(dep, _config.PipRequirements): + pip_deps.update(dep.pip) + + if pip_deps: + return [*sorted(string_deps), {"pip": sorted(pip_deps)}] + else: + return sorted(string_deps) def grid(gridspec: dict[str, list[str]]) -> Generator[dict[str, str]]: @@ -97,7 +99,7 @@ def make_dependency_file( config_file: os.PathLike, output_dir: os.PathLike, conda_channels: list[str], - dependencies: list[typing.Union[str, dict[str, list[str]]]], + dependencies: typing.Sequence[typing.Union[str, dict[str, list[str]]]], extras: _config.FileExtras, ): """Generate the contents of the dependency file. @@ -115,7 +117,7 @@ def make_dependency_file( conda_channels : list[str] The channels to include in the file. Only used when `file_type` is CONDA. - dependencies : list[str | dict[str, list[str]]] + dependencies : Sequence[str | dict[str, list[str]]] The dependencies to include in the file. extras : FileExtras Any extra information provided for generating this dependency file. diff --git a/tests/test_rapids_dependency_file_generator.py b/tests/test_rapids_dependency_file_generator.py index 2f34a40a..33b60744 100644 --- a/tests/test_rapids_dependency_file_generator.py +++ b/tests/test_rapids_dependency_file_generator.py @@ -19,7 +19,7 @@ def test_dedupe(): deduped = dedupe(["dep1", "dep1", "dep2"]) assert deduped == ["dep1", "dep2"] - # list w/ pip dependencies + # list w/ mix of simple and pip dependencies deduped = dedupe( [ "dep1", @@ -30,6 +30,15 @@ def test_dedupe(): ) assert deduped == ["dep1", {"pip": ["pip_dep1", "pip_dep2"]}] + # list w/ only pip dependencies + deduped = dedupe( + [ + _config.PipRequirements(pip=["pip_dep1", "pip_dep2"]), + _config.PipRequirements(pip=["pip_dep3", "pip_dep1"]), + ] + ) + assert deduped == [{"pip": ["pip_dep1", "pip_dep2", "pip_dep3"]}] + @mock.patch( "rapids_dependency_file_generator._rapids_dependency_file_generator.os.path.relpath"