Skip to content

Commit

Permalink
fix: fix return type hint for dedupe() (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb authored May 14, 2024
1 parent a8660f4 commit ba42fd1
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os
import textwrap
import typing
from collections import defaultdict
from collections.abc import Generator

import tomlkit
Expand Down Expand Up @@ -42,7 +41,7 @@ def delete_existing_files(root: os.PathLike) -> None:

def dedupe(
dependencies: list[typing.Union[str, _config.PipRequirements]],
) -> list[typing.Union[str, dict[str, str]]]:
) -> typing.Sequence[typing.Union[str, dict[str, list[str]]]]:
"""Generate the unique set of dependencies contained in a dependency list.
Parameters
Expand All @@ -52,18 +51,21 @@ def dedupe(
Returns
-------
list[str | dict[str, str]]
The `dependencies` with all duplicates removed.
Sequence[str | dict[str, list[str]]]
The ``dependencies`` with all duplicates removed.
"""
deduped = sorted({dep for dep in dependencies if isinstance(dep, str)})
dict_deps = defaultdict(list)
for dep in filter(lambda dep: not isinstance(dep, str), dependencies):
if isinstance(dep, _config.PipRequirements):
dict_deps["pip"].extend(dep.pip)
dict_deps["pip"] = sorted(set(dict_deps["pip"]))
if dict_deps:
deduped.append(dict(dict_deps))
return deduped
string_deps: set[str] = set()
pip_deps: set[str] = set()
for dep in dependencies:
if isinstance(dep, str):
string_deps.add(dep)
elif isinstance(dep, _config.PipRequirements):
pip_deps.update(dep.pip)

if pip_deps:
return [*sorted(string_deps), {"pip": sorted(pip_deps)}]
else:
return sorted(string_deps)


def grid(gridspec: dict[str, list[str]]) -> Generator[dict[str, str]]:
Expand Down Expand Up @@ -97,7 +99,7 @@ def make_dependency_file(
config_file: os.PathLike,
output_dir: os.PathLike,
conda_channels: list[str],
dependencies: list[typing.Union[str, dict[str, list[str]]]],
dependencies: typing.Sequence[typing.Union[str, dict[str, list[str]]]],
extras: _config.FileExtras,
):
"""Generate the contents of the dependency file.
Expand All @@ -115,7 +117,7 @@ def make_dependency_file(
conda_channels : list[str]
The channels to include in the file. Only used when `file_type` is
CONDA.
dependencies : list[str | dict[str, list[str]]]
dependencies : Sequence[str | dict[str, list[str]]]
The dependencies to include in the file.
extras : FileExtras
Any extra information provided for generating this dependency file.
Expand Down
11 changes: 10 additions & 1 deletion tests/test_rapids_dependency_file_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_dedupe():
deduped = dedupe(["dep1", "dep1", "dep2"])
assert deduped == ["dep1", "dep2"]

# list w/ pip dependencies
# list w/ mix of simple and pip dependencies
deduped = dedupe(
[
"dep1",
Expand All @@ -30,6 +30,15 @@ def test_dedupe():
)
assert deduped == ["dep1", {"pip": ["pip_dep1", "pip_dep2"]}]

# list w/ only pip dependencies
deduped = dedupe(
[
_config.PipRequirements(pip=["pip_dep1", "pip_dep2"]),
_config.PipRequirements(pip=["pip_dep3", "pip_dep1"]),
]
)
assert deduped == [{"pip": ["pip_dep1", "pip_dep2", "pip_dep3"]}]


@mock.patch(
"rapids_dependency_file_generator._rapids_dependency_file_generator.os.path.relpath"
Expand Down

0 comments on commit ba42fd1

Please sign in to comment.