Skip to content

Commit

Permalink
Merge pull request #90 from scipp/allow-replace
Browse files Browse the repository at this point in the history
Support replacing params and providers
  • Loading branch information
SimonHeybrock authored Dec 19, 2023
2 parents 3380fed + 3b6f98d commit 6fbc045
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 7 deletions.
4 changes: 0 additions & 4 deletions src/sciline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,12 +437,8 @@ def _set_provider(
if (origin := get_origin(key)) is not None:
subproviders = self._subproviders.setdefault(origin, {})
args = get_args(key)
if args in subproviders:
raise ValueError(f'Provider for {key} already exists')
subproviders[args] = provider
else:
if key in self._providers:
raise ValueError(f'Provider for {key} already exists')
self._providers[key] = provider

def _get_provider(
Expand Down
109 changes: 106 additions & 3 deletions tests/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,11 +577,114 @@ class B(Generic[T]):
pl[B[int]] = 1.0


def test_setitem_raises_if_key_exists() -> None:
def test_setitem_can_replace_param_with_param() -> None:
pl = sl.Pipeline()
pl[int] = 1
with pytest.raises(ValueError):
pl[int] = 2
pl[int] = 2
assert pl.compute(int) == 2


def test_insert_can_replace_param_with_provider() -> None:
def func() -> int:
return 2

pl = sl.Pipeline()
pl[int] = 1
pl.insert(func)
assert pl.compute(int) == 2


def test_setitem_can_replace_provider_with_param() -> None:
def func() -> int:
return 2

pl = sl.Pipeline()
pl.insert(func)
pl[int] = 1
assert pl.compute(int) == 1


def test_insert_can_replace_provider_with_provider() -> None:
def func1() -> int:
return 1

def func2() -> int:
return 2

pl = sl.Pipeline()
pl.insert(func1)
pl.insert(func2)
assert pl.compute(int) == 2


def test_insert_can_replace_generic_provider_with_generic_provider() -> None:
T = TypeVar('T', int, float)

@dataclass
class A(Generic[T]):
value: T

def func1(x: T) -> A[T]:
return A[T](x)

def func2(x: T) -> A[T]:
return A[T](x + x)

pl = sl.Pipeline()
pl[int] = 1
pl.insert(func1)
pl.insert(func2)
assert pl.compute(A[int]) == A[int](2)


def test_insert_can_replace_generic_param_with_generic_provider() -> None:
T = TypeVar('T', int, float)

@dataclass
class A(Generic[T]):
value: T

def func(x: T) -> A[T]:
return A[T](x + x)

pl = sl.Pipeline()
pl[int] = 1
pl[A[T]] = A[T](1) # type: ignore[valid-type]
assert pl.compute(A[int]) == A[int](1)
pl.insert(func)
assert pl.compute(A[int]) == A[int](2)


def test_setitem_can_replace_generic_provider_with_generic_param() -> None:
T = TypeVar('T', int, float)

@dataclass
class A(Generic[T]):
value: T

def func(x: T) -> A[T]:
return A[T](x + x)

pl = sl.Pipeline()
pl[int] = 1
pl.insert(func)
assert pl.compute(A[int]) == A[int](2)
pl[A[T]] = A[T](1) # type: ignore[valid-type]
assert pl.compute(A[int]) == A[int](1)


def test_setitem_can_replace_generic_param_with_generic_param() -> None:
T = TypeVar('T')

@dataclass
class A(Generic[T]):
value: T

pl = sl.Pipeline()
pl[A[T]] = A[T](1) # type: ignore[valid-type]
assert pl.compute(A[int]) == A[int](1)
pl[A[T]] = A[T](2) # type: ignore[valid-type]
assert pl.compute(A[int]) == A[int](2)


def test_init_with_params() -> None:
Expand Down
8 changes: 8 additions & 0 deletions tests/pipeline_with_param_table_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,3 +605,11 @@ def process(x: float, missing: Missing) -> str:
pl = sl.Pipeline([process])
pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
pl.get(sl.Series[int, str], handler=sl.HandleAsComputeTimeException())


def test_param_table_column_and_param_of_same_type_can_coexist() -> None:
pl = sl.Pipeline()
pl[float] = 1.0
pl.set_param_table(sl.ParamTable(int, {float: [2.0, 3.0]}))
assert pl.compute(float) == 1.0
assert pl.compute(sl.Series[int, float]) == sl.Series(int, {0: 2.0, 1: 3.0})

0 comments on commit 6fbc045

Please sign in to comment.