Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better handling of typing.Optional #50

Merged
merged 13 commits into from
Sep 4, 2023
37 changes: 37 additions & 0 deletions src/sciline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,27 @@ def _find_nodes_in_paths(
return list(nodes)


def _get_optional(tp: Key) -> Optional[Any]:
if get_origin(tp) != Union:
return None
args = get_args(tp)
if len(args) != 2 or type(None) not in args:
return None
return args[0] if args[1] == type(None) else args[1] # noqa: E721


def provide_none() -> None:
return None


class ReplicatorBase(Generic[IndexType]):
def __init__(self, index_name: type, index: Iterable[IndexType], path: List[Key]):
if len(path) == 0:
raise UnsatisfiedRequirement(
'Could not find path to param in param table. This is likely caused '
'by requesting a Series that does not depend directly or transitively '
'on any param from a table.'
)
self._index_name = index_name
self.index = index
self._path = path
Expand Down Expand Up @@ -153,6 +172,7 @@ def _copy_node(
)

def key(self, i: IndexType, value_name: Union[Type[T], Item[T]]) -> Item[T]:
value_name = _get_optional(value_name) or value_name
label = Label(self._index_name, i)
if isinstance(value_name, Item):
return Item(value_name.label + (label,), value_name.tp)
Expand Down Expand Up @@ -326,6 +346,8 @@ def __setitem__(self, key: Type[T], param: T) -> None:
param:
Concrete value to provide.
"""
if get_origin(key) == Union:
raise ValueError('Union (or Optional) parameters are not allowed.')
# TODO Switch to isinstance(key, NewType) once our minimum is Python 3.10
# Note that we cannot pass mypy in Python<3.10 since NewType is not a type.
if hasattr(key, '__supertype__'):
Expand Down Expand Up @@ -389,6 +411,10 @@ def _set_provider(
# isinstance does not work here and types.NoneType available only in 3.10+
if key == type(None): # noqa: E721
raise ValueError(f'Provider {provider} returning `None` is not allowed')
if get_origin(key) == Union:
raise ValueError(
f'Provider {provider} returning a Union (or Optional) is not allowed.'
)
if get_origin(key) == Series:
raise ValueError(
f'Provider {provider} returning a sciline.Series is not allowed. '
Expand Down Expand Up @@ -461,6 +487,17 @@ def build(
if get_origin(tp) == Series:
graph.update(self._build_series(tp)) # type: ignore[arg-type]
continue
if (optional_arg := _get_optional(tp)) is not None:
try:
optional_subgraph = self.build(
optional_arg, search_param_tables=search_param_tables
)
except UnsatisfiedRequirement:
graph[tp] = (provide_none, ())
else:
graph[tp] = optional_subgraph.pop(optional_arg)
graph.update(optional_subgraph)
continue
provider: Callable[..., T]
provider, bound = self._get_provider(tp)
tps = get_type_hints(provider)
Expand Down
10 changes: 5 additions & 5 deletions tests/complex_workflow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,21 +54,21 @@ class IofQ(sl.Scope[Run, npt.NDArray[np.float64]], npt.NDArray[np.float64]):


def incident_monitor(x: Raw[Run]) -> IncidentMonitor[Run]:
return IncidentMonitor(x.monitor1)
return IncidentMonitor[Run](x.monitor1)


def transmission_monitor(x: Raw[Run]) -> TransmissionMonitor[Run]:
return TransmissionMonitor(x.monitor2)
return TransmissionMonitor[Run](x.monitor2)


def mask_detector(x: Raw[Run], mask: DetectorMask) -> Masked[Run]:
return Masked(x.data * mask)
return Masked[Run](x.data * mask)


def transmission(
incident: IncidentMonitor[Run], transmission: TransmissionMonitor[Run]
) -> TransmissionFraction[Run]:
return TransmissionFraction(incident / transmission)
return TransmissionFraction[Run](incident / transmission)


def iofq(
Expand All @@ -77,7 +77,7 @@ def iofq(
direct_beam: DirectBeam,
transmission: TransmissionFraction[Run],
) -> IofQ[Run]:
return IofQ(x / (solid_angle * direct_beam * transmission))
return IofQ[Run](x / (solid_angle * direct_beam * transmission))


reduction = [incident_monitor, transmission_monitor, mask_detector, transmission, iofq]
Expand Down
193 changes: 193 additions & 0 deletions tests/pipeline_with_optional_test.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add tests where there is an optional and a required dependency on some value. And check that it works for any order or providers and whether the value can be provided or not.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added one test, not entirely sure what the other cases you are asking for are.

Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from typing import NewType, Optional, Union

import pytest

import sciline as sl


def test_provider_returning_optional_disallowed() -> None:
def make_optional() -> Optional[int]:
return 3

with pytest.raises(ValueError):
sl.Pipeline([make_optional])


def test_provider_returning_union_disallowed() -> None:
def make_union() -> Union[int, float]:
return 3

with pytest.raises(ValueError):
sl.Pipeline([make_union])


def test_parameter_type_union_or_optional_disallowed() -> None:
pipeline = sl.Pipeline()
with pytest.raises(ValueError):
pipeline[Union[int, float]] = 3 # type: ignore[index]
with pytest.raises(ValueError):
pipeline[Optional[int]] = 3 # type: ignore[index]


def test_union_requirement_leads_to_UnsatisfiedRequirement() -> None:
def require_union(x: Union[int, float]) -> str:
return f'{x}'

pipeline = sl.Pipeline([require_union])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a provider for int and/or float so this actually tests something?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, see update!

pipeline[int] = 1
with pytest.raises(sl.UnsatisfiedRequirement):
pipeline.compute(str)


def test_optional_dependency_can_be_filled_by_non_optional_param() -> None:
def use_optional(x: Optional[int]) -> str:
return f'{x or 123}'

pipeline = sl.Pipeline([use_optional], params={int: 1})
assert pipeline.compute(str) == '1'


def test_union_with_none_can_be_used_instead_of_Optional() -> None:
def use_union1(x: Union[int, None]) -> str:
return f'{x or 123}'

def use_union2(x: Union[None, int]) -> str:
return f'{x or 123}'

pipeline = sl.Pipeline([use_union1], params={int: 1})
assert pipeline.compute(str) == '1'
pipeline = sl.Pipeline([use_union2], params={int: 1})
assert pipeline.compute(str) == '1'


def test_optional_requested_directly_can_be_filled_by_non_optional_param() -> None:
pipeline = sl.Pipeline([], params={int: 1})
assert pipeline.compute(Optional[int]) == 1 # type: ignore[call-overload]


def test_optional_dependency_can_be_filled_transitively() -> None:
def use_optional(x: Optional[int]) -> str:
return f'{x or 123}'

def make_int(x: float) -> int:
return int(x)

pipeline = sl.Pipeline([use_optional, make_int], params={float: 2.2})
assert pipeline.compute(str) == '2'


def test_optional_dependency_is_set_to_none_if_no_provider_found() -> None:
def use_optional(x: Optional[int]) -> str:
return f'{x or 123}'

pipeline = sl.Pipeline([use_optional])
assert pipeline.compute(str) == '123'


def test_optional_dependency_is_set_to_none_if_no_provider_found_transitively() -> None:
def use_optional(x: Optional[int]) -> str:
return f'{x or 123}'

def make_int(x: float) -> int:
return int(x)

pipeline = sl.Pipeline([use_optional, make_int])
assert pipeline.compute(str) == '123'


def test_can_have_both_optional_and_non_optional_path_to_param() -> None:
Str1 = NewType('Str1', str)
Str2 = NewType('Str2', str)
Str12 = NewType('Str12', str)
Str21 = NewType('Str21', str)

def use_optional_int(x: Optional[int]) -> Str1:
return Str1(f'{x or 123}')

def use_int(x: int) -> Str2:
return Str2(f'{x}')

def combine12(x: Str1, y: Str2) -> Str12:
return Str12(f'{x} {y}')

def combine21(x: Str2, y: Str1) -> Str21:
return Str21(f'{x} {y}')

pipeline = sl.Pipeline(
[use_optional_int, use_int, combine12, combine21], params={int: 1}
)
assert pipeline.compute(Str12) == '1 1'
assert pipeline.compute(Str21) == '1 1'


def test_presence_of_optional_does_not_affect_related_exception() -> None:
Str1 = NewType('Str1', str)
Str2 = NewType('Str2', str)
Str12 = NewType('Str12', str)
Str21 = NewType('Str21', str)

def use_optional_int(x: Optional[int]) -> Str1:
return Str1(f'{x or 123}')

# Make sure the implementation does not unintentionally put "None" here,
# triggered by the presence of the optional dependency on int in another provider.
def use_int(x: int) -> Str2:
return Str2(f'{x}')

def combine12(x: Str1, y: Str2) -> Str12:
return Str12(f'{x} {y}')

def combine21(x: Str2, y: Str1) -> Str21:
return Str21(f'{x} {y}')

pipeline = sl.Pipeline([use_optional_int, use_int, combine12, combine21])
with pytest.raises(sl.UnsatisfiedRequirement):
pipeline.compute(Str12)
with pytest.raises(sl.UnsatisfiedRequirement):
pipeline.compute(Str21)


def test_optional_dependency_in_node_depending_on_param_table() -> None:
def use_optional(x: float, y: Optional[int]) -> str:
return f'{x} {y or 123}'

pl = sl.Pipeline([use_optional])
pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
assert pl.compute(sl.Series[int, str]) == sl.Series(
int, {0: '1.0 123', 1: '2.0 123', 2: '3.0 123'}
)
pl[int] = 11
assert pl.compute(sl.Series[int, str]) == sl.Series(
int, {0: '1.0 11', 1: '2.0 11', 2: '3.0 11'}
)


def test_optional_dependency_can_be_filled_from_param_table() -> None:
def use_optional(x: Optional[float]) -> str:
return f'{x or 4.0}'

pl = sl.Pipeline([use_optional])
pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
assert pl.compute(sl.Series[int, str]) == sl.Series(
int, {0: '1.0', 1: '2.0', 2: '3.0'}
)


def test_optional_without_anchoring_param_raises_when_requesting_series() -> None:
Param = NewType('Param', float)

def use_optional(x: Optional[float]) -> str:
return f'{x or 4.0}'

pl = sl.Pipeline([use_optional])
pl.set_param_table(sl.ParamTable(int, {Param: [1.0, 2.0, 3.0]}))
# It is a bit ambiguous what we would expect here: Above, we have another param
# used from the table, defining the length of the series. Here, we could replicate
# the output of use_optional(None) based on the `int` param table:
# sl.Series(int, {0: '4.0', 1: '4.0', 2: '4.0'})
# However, we are not supporting this for non-optional dependencies either since
# it is unclear whether that would bring conceptual issues or risk.
with pytest.raises(sl.UnsatisfiedRequirement):
pl.compute(sl.Series[int, str])
13 changes: 13 additions & 0 deletions tests/pipeline_with_param_table_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,19 @@ def test_can_compute_series_of_param_values() -> None:
assert pl.compute(sl.Series[int, float]) == sl.Series(int, {0: 1.0, 1: 2.0, 2: 3.0})


def test_cannot_compute_series_of_non_table_param() -> None:
pl = sl.Pipeline()
# Table for defining length
pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
pl[str] = 'abc'
# The alternative option would be to expect to return
# sl.Series(int, {0: 'abc', 1: 'abc', 2: 'abc'})
# For now, we are not supporting this since it is unclear if this would be
# conceptually sound and risk free.
with pytest.raises(sl.UnsatisfiedRequirement):
pl.compute(sl.Series[int, str])


def test_can_compute_series_of_derived_values() -> None:
def process(x: float) -> str:
return str(x)
Expand Down