Skip to content

Commit

Permalink
Rename keys->targets
Browse files Browse the repository at this point in the history
  • Loading branch information
jl-wynen committed Feb 15, 2024
1 parent 30c81f1 commit bb2acd7
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 20 deletions.
2 changes: 1 addition & 1 deletion src/sciline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,7 @@ def get(
else:
graph = self.build(keys, handler=handler) # type: ignore[arg-type]
return TaskGraph(
graph=graph, keys=keys, scheduler=scheduler # type: ignore[arg-type]
graph=graph, targets=keys, scheduler=scheduler # type: ignore[arg-type]
)

@overload
Expand Down
24 changes: 12 additions & 12 deletions src/sciline/task_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ def __init__(
self,
*,
graph: Graph,
keys: Union[type, Tuple[type, ...], Item[T], Tuple[Item[T], ...]],
targets: Union[type, Tuple[type, ...], Item[T], Tuple[Item[T], ...]],
scheduler: Optional[Scheduler] = None,
) -> None:
self._graph = graph
self._keys = keys
self._keys = targets
if scheduler is None:
try:
scheduler = DaskScheduler()
Expand All @@ -84,7 +84,7 @@ def __init__(

def compute(
self,
keys: Optional[
targets: Optional[
Union[type, Tuple[type, ...], Item[T], Tuple[Item[T], ...]]
] = None,
) -> Any:
Expand All @@ -93,7 +93,7 @@ def compute(
Parameters
----------
keys:
targets:
Optional list of keys to compute. This can be used to override the keys
stored in the graph instance. Note that the keys must be present in the
graph as intermediate results, otherwise KeyError is raised.
Expand All @@ -103,18 +103,18 @@ def compute(
If ``keys`` is a single type, returns the single result that was computed.
If ``keys`` is a tuple of types, returns a dictionary with type as keys
and the corresponding results as values.
"""
if keys is None:
keys = self._keys
if isinstance(keys, tuple):
results = self._scheduler.get(self._graph, list(keys))
return dict(zip(keys, results))
if targets is None:
targets = self._keys
if isinstance(targets, tuple):
results = self._scheduler.get(self._graph, list(targets))
return dict(zip(targets, results))
else:
return self._scheduler.get(self._graph, [keys])[0]
return self._scheduler.get(self._graph, [targets])[0]

def keys(self) -> Generator[Key, None, None]:
"""Iterate over all keys of the graph.
"""
Iterate over all keys of the graph.
Yields all keys, i.e., the types of values that can be computed or are
provided as parameters.
Expand Down
17 changes: 10 additions & 7 deletions tests/task_graph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,37 +36,40 @@ def make_task_graph() -> Graph:

def test_default_scheduler_is_dask_when_dask_available() -> None:
_ = pytest.importorskip("dask")
tg = TaskGraph(graph={}, keys=())
tg = TaskGraph(graph={}, targets=())
assert isinstance(tg._scheduler, sl.scheduler.DaskScheduler)


def test_compute_returns_value_when_initialized_with_single_key() -> None:
graph = make_task_graph()
tg = TaskGraph(graph=graph, keys=float)
tg = TaskGraph(graph=graph, targets=float)
assert tg.compute() == 0.5


def test_compute_returns_dict_when_initialized_with_key_tuple() -> None:
graph = make_task_graph()
assert TaskGraph(graph=graph, keys=(float,)).compute() == {float: 0.5}
assert TaskGraph(graph=graph, keys=(float, int)).compute() == {float: 0.5, int: 1}
assert TaskGraph(graph=graph, targets=(float,)).compute() == {float: 0.5}
assert TaskGraph(graph=graph, targets=(float, int)).compute() == {
float: 0.5,
int: 1,
}


def test_compute_returns_value_when_provided_with_single_key() -> None:
graph = make_task_graph()
tg = TaskGraph(graph=graph, keys=float)
tg = TaskGraph(graph=graph, targets=float)
assert tg.compute(int) == 1


def test_compute_returns_dict_when_provided_with_key_tuple() -> None:
graph = make_task_graph()
tg = TaskGraph(graph=graph, keys=float)
tg = TaskGraph(graph=graph, targets=float)
assert tg.compute((int, float)) == {int: 1, float: 0.5}


def test_compute_raises_when_provided_with_key_not_in_graph() -> None:
graph = make_task_graph()
tg = TaskGraph(graph=graph, keys=float)
tg = TaskGraph(graph=graph, targets=float)
with pytest.raises(KeyError):
tg.compute(str)
with pytest.raises(KeyError):
Expand Down

0 comments on commit bb2acd7

Please sign in to comment.