From 7a218ac76187a7143b4ad86170f9b891ebaad1da Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Wed, 12 Oct 2022 12:28:06 +0200 Subject: [PATCH 01/45] =?UTF-8?q?=F0=9F=94=A7=20FIX=20auto=5Fpersist=20dec?= =?UTF-8?q?orator=20typing=20(#239)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is used in aiida-core, to decorate the `WorkChain` class. The problem currently is that, because of the decorator typing, no methods defined on subclasses of `Savable` are available for static analysers (like IDE auto-completions). This change ensures the output of the decorator is the same as the input. --- pyproject.toml | 3 ++- src/plumpy/persistence.py | 9 ++++++--- tox.ini | 2 +- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b7d42530..e0382961 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,9 +46,10 @@ docs = [ 'myst-nb~=0.11.0', 'sphinx~=3.2.0', 'sphinx-book-theme~=0.0.39', + 'importlib-metadata~=4.12.0', ] pre-commit = [ - 'mypy==0.790', + 'mypy==0.982', 'pre-commit~=2.2', 'pylint==2.12.2', ] diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index 10bfbed9..7b49a330 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -9,7 +9,7 @@ import os import pickle from types import MethodType -from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Optional, Set, Type, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Optional, Set, TypeVar, Union import yaml @@ -340,9 +340,12 @@ def delete_process_checkpoints(self, pid: PID_TYPE) -> None: del self._checkpoints[pid] -def auto_persist(*members: str) -> Callable[[Type['Savable']], Type['Savable']]: +SavableClsType = TypeVar('SavableClsType', bound='Type[Savable]') - def wrapped(savable: Type['Savable']) -> Type['Savable']: + +def auto_persist(*members: str) -> Callable[[SavableClsType], SavableClsType]: + + def wrapped(savable: 
SavableClsType) -> SavableClsType: # pylint: disable=protected-access if savable._auto_persist is None: savable._auto_persist = set() diff --git a/tox.ini b/tox.ini index 1759c26e..3b47c63b 100644 --- a/tox.ini +++ b/tox.ini @@ -18,7 +18,7 @@ extras = tests commands = pytest {posargs} -[testenv:pre-commit] +[testenv:py{37,38,39}-pre-commit] description = Run the style checks and formatting extras = pre-commit From 3b9d39faa007a72e744556e1c5818b933e663cd5 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Wed, 12 Oct 2022 20:25:37 +0200 Subject: [PATCH 02/45] `Process`: Add the `is_excepted` property Will return `True` when the process state is `ProcessState.EXCEPTED`. --- src/plumpy/processes.py | 8 ++++++++ test/test_processes.py | 4 +++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index c8b7b2a7..6f28b843 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -487,6 +487,14 @@ def exception(self) -> Optional[BaseException]: return None + @property + def is_excepted(self) -> bool: + """Return whether the process excepted. + + :return: boolean, True if the process is in ``EXCEPTED`` state. + """ + return self.state == process_states.ProcessState.EXCEPTED + def done(self) -> bool: """Return True if the call was successfully killed or finished running. 
diff --git a/test/test_processes.py b/test/test_processes.py index 05ba1056..b5ce5ad5 100644 --- a/test/test_processes.py +++ b/test/test_processes.py @@ -499,6 +499,8 @@ def run(self): with self.assertRaises(ValueError): proc.execute() + assert proc.is_excepted + def test_missing_output(self): proc = utils.MissingOutputProcess() @@ -507,7 +509,7 @@ def test_missing_output(self): proc.execute() - self.assertFalse(proc.successful()) + self.assertFalse(proc.is_successful) def test_unsuccessful_result(self): ERROR_CODE = 256 From 84ab1de41bfe447d38a4a1c7554ad8534268f3d9 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Wed, 12 Oct 2022 21:03:58 +0200 Subject: [PATCH 03/45] `Process`: Fix incorrect overriding of `transition_failed` The `Process` class intended to override the `transition_failed` method of the `StateMachine` base class, but accidentally implemented the method `transition_excepted`, which was therefore never called. --- src/plumpy/base/state_machine.py | 12 +++++------- src/plumpy/processes.py | 4 ++-- test/test_processes.py | 7 +++---- 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 0cad7c45..3400026a 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -338,16 +338,14 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A self._transition_failing = False self._transitioning = False - @staticmethod - def transition_failed( - initial_state: Hashable, final_state: Hashable, exception: Exception, trace: TracebackType + def transition_failed( # pylint: disable=no-self-use + self, initial_state: Hashable, final_state: Hashable, exception: Exception, trace: TracebackType ) -> None: - """ - Called when a state transitions fails. This method can be overwritten - to change the default behaviour which is to raise the exception. + """Called when a state transitions fails. 
- :param exception: The transition failed exception + This method can be overwritten to change the default behaviour which is to raise the exception. + :param exception: The transition failed exception. """ raise exception.with_traceback(trace) diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 6f28b843..523fe2ea 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -978,8 +978,8 @@ def close(self) -> None: # region State related methods - def transition_excepted( - self, _initial_state: Any, final_state: process_states.ProcessState, exception: Exception, trace: TracebackType + def transition_failed( + self, initial_state: Hashable, final_state: Hashable, exception: Exception, trace: TracebackType ) -> None: # If we are creating, then reraise instead of failing. if final_state == process_states.ProcessState.CREATED: diff --git a/test/test_processes.py b/test/test_processes.py index b5ce5ad5..7df907ed 100644 --- a/test/test_processes.py +++ b/test/test_processes.py @@ -242,10 +242,9 @@ def test_forget_to_call_parent(self): proc.execute() def test_forget_to_call_parent_kill(self): - with self.assertRaises(AssertionError): - proc = ForgetToCallParent('kill') - proc.kill() - proc.execute() + proc = ForgetToCallParent('kill') + proc.kill() + assert proc.is_excepted def test_pid(self): # Test auto generation of pid From f4be3fef4677d8f4e0f36d23b842cada9bb325ff Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Wed, 12 Oct 2022 22:09:18 +0200 Subject: [PATCH 04/45] `StateMachine`: transition directly to excepted if transition failed Before, if a state transition failed a transition to the excepted state would be initiated. However, if the original failure came from a method that would be called in all state transitions, i.e. also when transitioning to the excepted, it would be guaranteed to be hit again. In the second transition failure, the exception would simply be raised again and bubble up. 
In the case of a transition failure and so `self._transition_failed` is set to `True`, the current state should not be explicitly exited but one should transition straight to the excepted state. This change now effectively allows the state machine to transition from a `FINISHED` state to the `EXCEPTED` state. A process could transition to the `FINISHED` state and on exiting the `FINISHED` state, an exception could be raised. In this case the result of the future would already be set, so the `on_except` method needs to check for this, and when set, first reset the future before setting the exception. --- src/plumpy/base/state_machine.py | 5 ++++- src/plumpy/processes.py | 7 +++++++ test/test_processes.py | 34 ++++++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 1 deletion(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 3400026a..c9a08848 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -315,7 +315,10 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A # Make sure we have a state instance new_state = self._create_state_instance(new_state, *args, **kwargs) label = new_state.LABEL - self._exit_current_state(new_state) + + # If the previous transition failed, do not try to exit it but go straight to next state + if not self._transition_failing: + self._exit_current_state(new_state) try: self._enter_next_state(new_state) diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 523fe2ea..41bd3609 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -828,6 +828,13 @@ def on_except(self, exc_info: Tuple[Any, Exception, TracebackType]) -> None: """Entering the EXCEPTED state.""" exception = exc_info[1] exception.__traceback__ = exc_info[2] + + # It is possible that we already got into a finished state and the future result was set, in which case, we + # should reset it before setting the exception or else ``asyncio`` 
will raise an exception. + future = self.future() + + if future.done(): + self._future = persistence.SavableFuture(loop=self._loop) self.future().set_exception(exception) @super_check diff --git a/test/test_processes.py b/test/test_processes.py index 7df907ed..e357fc4d 100644 --- a/test/test_processes.py +++ b/test/test_processes.py @@ -654,6 +654,40 @@ def test_execute_twice(self): with self.assertRaises(plumpy.ClosedError): proc.execute() + def test_exception_during_on_entered(self): + """Test that an exception raised during ``on_entered`` will cause the process to be excepted.""" + + class RaisingProcess(Process): + + def on_entered(self, from_state): + if from_state is not None and from_state.label == ProcessState.RUNNING: + raise RuntimeError('exception during on_entered') + super().on_entered(from_state) + + process = RaisingProcess() + + with self.assertRaises(RuntimeError): + process.execute() + + assert not process.is_successful + assert process.is_excepted + assert str(process.exception()) == 'exception during on_entered' + + def test_exception_during_run(self): + + class RaisingProcess(Process): + + def run(self): + raise RuntimeError('exception during run') + + process = RaisingProcess() + + with self.assertRaises(RuntimeError): + process.execute() + + assert process.is_excepted + assert str(process.exception()) == 'exception during run' + @plumpy.auto_persist('steps_ran') class SavePauseProc(plumpy.Process): From 449f7e62ec3edad201792f90157a5eaa67373b23 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Thu, 13 Oct 2022 11:03:32 +0200 Subject: [PATCH 05/45] Dependencies: Add lower limit for patch version of `nest-asyncio` (#241) The package breaks for `nest-asyncio==1.5.0` so the requirement is updated to add the lower bound `nest-asyncio>=1.5.1`. 
--- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e0382961..006ad77e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ requires-python = '>=3.7' dependencies = [ 'aio-pika~=6.6', 'kiwipy[rmq]~=0.7.4', - 'nest_asyncio~=1.5', + 'nest_asyncio~=1.5,>=1.5.1', 'pyyaml~=5.4', ] From ebbf17a57dd0d44abcdd106727072007650e9a8b Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Thu, 13 Oct 2022 11:12:10 +0200 Subject: [PATCH 06/45] Dependencies: Add support for Python 3.10 (#242) --- .github/workflows/cd.yml | 2 +- .github/workflows/ci.yml | 4 ++-- pyproject.toml | 1 + 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index 58d7d3c7..716be813 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -58,7 +58,7 @@ jobs: strategy: matrix: - python-version: ['3.7', '3.8', '3.9'] + python-version: ['3.7', '3.8', '3.9', '3.10'] services: postgres: diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0377da9c..06a058cc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,7 +12,7 @@ jobs: - name: Set up Python 3.8 uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: '3.8' - name: Install Python dependencies run: pip install -e .[pre-commit] @@ -25,7 +25,7 @@ jobs: strategy: matrix: - python-version: ['3.7', '3.8', '3.9'] + python-version: ['3.7', '3.8', '3.9', '3.10'] services: rabbitmq: diff --git a/pyproject.toml b/pyproject.toml index 006ad77e..59042096 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ classifiers = [ 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', ] keywords = ['workflow', 'multithreaded', 'rabbitmq'] requires-python = '>=3.7' From b5d1fbed8b8362515b8086f24746f4727b9d995e Mon Sep 17 00:00:00 2001 
From: Sebastiaan Huber Date: Wed, 26 Oct 2022 13:07:15 +0200 Subject: [PATCH 07/45] Tests: Fix remaining warnings (#244) --- pyproject.toml | 4 +--- test/test_process_spec.py | 2 +- test/test_utils.py | 28 +++++++++++++++------------- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 59042096..734a5787 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,9 +108,7 @@ minversion = '6.0' testpaths = [ 'test', ] -filterwarnings = [ - 'ignore::DeprecationWarning:frozendict:', -] +filterwarnings = [] [tool.yapf] align_closing_bracket_with_visual_indent = true diff --git a/test/test_process_spec.py b/test/test_process_spec.py index 73dcf498..443f7a64 100644 --- a/test/test_process_spec.py +++ b/test/test_process_spec.py @@ -77,7 +77,7 @@ def test_validator(self): """Test the port validator with default.""" def dict_validator(dictionary, port): - if 'key' not in dictionary or dictionary['key'] is not 'value': + if 'key' not in dictionary or dictionary['key'] != 'value': return 'Invalid dictionary' self.spec.input('dict', default={'key': 'value'}, validator=dict_validator) diff --git a/test/test_utils.py b/test/test_utils.py index e6a0e249..546261f2 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -2,30 +2,32 @@ import asyncio import functools import inspect -import unittest +import warnings + +import pytest from plumpy.utils import AttributesFrozendict, ensure_coroutine, load_function -class TestAttributesFrozendict(unittest.TestCase): +class TestAttributesFrozendict: def test_getitem(self): d = AttributesFrozendict({'a': 5}) - self.assertEqual(d['a'], 5) + assert d['a'] == 5 - with self.assertRaises(KeyError): + with pytest.raises(KeyError): d['b'] def test_getattr(self): d = AttributesFrozendict({'a': 5}) - self.assertEqual(d.a, 5) + assert d.a == 5 - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): d.b def test_setitem(self): d = AttributesFrozendict() - with 
self.assertRaises(TypeError): + with pytest.raises(TypeError): d['a'] = 5 @@ -37,7 +39,7 @@ async def async_fct(): pass -class TestEnsureCoroutine(unittest.TestCase): +class TestEnsureCoroutine: def test_sync_func(self): coro = ensure_coroutine(fct) @@ -48,8 +50,6 @@ def test_async_func(self): assert coro is async_fct def test_callable_class(self): - """ - """ class AsyncDummy: @@ -60,8 +60,6 @@ async def __call__(self): assert coro is AsyncDummy def test_callable_object(self): - """ - """ class AsyncDummy: @@ -76,7 +74,11 @@ def test_functools_partial(self): fct_wrap = functools.partial(fct) coro = ensure_coroutine(fct_wrap) assert coro is not fct_wrap - assert asyncio.iscoroutine(coro()) + # The following will emit a RuntimeWarning ``coroutine 'ensure_coroutine..wrap' was never awaited`` + # since were not actually ever awaiting ``core`` but that is not the point of the test. + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + assert asyncio.iscoroutine(coro()) def test_load_function(): From fdf4c6322b01652a977c6c49969a03aa37e060ee Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Fri, 28 Oct 2022 19:19:43 +0200 Subject: [PATCH 08/45] Dependencies: Update requirement `kiwipy~=0.8.2` (#243) This new release of `kiwipy` comes with support for the recent release of `aio-pika` which comes with stability fixes for the dropping of connections. 
--- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 734a5787..42c56364 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,8 +28,7 @@ classifiers = [ keywords = ['workflow', 'multithreaded', 'rabbitmq'] requires-python = '>=3.7' dependencies = [ - 'aio-pika~=6.6', - 'kiwipy[rmq]~=0.7.4', + 'kiwipy[rmq]~=0.8.2', 'nest_asyncio~=1.5,>=1.5.1', 'pyyaml~=5.4', ] @@ -43,6 +42,7 @@ Documentation = 'https://plumpy.readthedocs.io' docs = [ 'ipython~=7.0', 'jinja2==2.11.3', + 'kiwipy[docs]~=0.8.2', 'markupsafe==2.0.1', 'myst-nb~=0.11.0', 'sphinx~=3.2.0', From d9a200acba8d562cc7d3616be2724adb476b502a Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Mon, 31 Oct 2022 11:12:30 +0100 Subject: [PATCH 09/45] DevOps: Update the `mypy` pre-commit dependency (#246) The requirement is updated to `mypy==0.930`. This allows us to move the configuration from `tox.ini` to the `pyproject.toml` file. The config in `.pre-commit-config.yaml` is also corrected as the include path did not properly include the `src` starting folder. The remaining configuration for `tox.ini` is also moved to the `pyproject.toml` and deleted. 
--- .pre-commit-config.yaml | 11 ++++-- pyproject.toml | 63 +++++++++++++++++++++++++++++ src/plumpy/__init__.py | 4 +- src/plumpy/base/state_machine.py | 10 ++--- src/plumpy/events.py | 2 +- src/plumpy/mixins.py | 2 +- src/plumpy/persistence.py | 8 ++-- src/plumpy/processes.py | 5 ++- src/plumpy/utils.py | 2 +- tox.ini | 68 -------------------------------- 10 files changed, 88 insertions(+), 87 deletions(-) delete mode 100644 tox.ini diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9cf325c8..18364c25 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,14 +28,17 @@ repos: additional_dependencies: ['toml'] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.790 + rev: v0.982 hooks: - id: mypy - args: [--config-file=tox.ini] - additional_dependencies: ['aio_pika~=6.6'] + args: [--config-file=pyproject.toml] + additional_dependencies: [ + 'toml', + 'types-pyyaml', + ] files: > (?x)^( - plumpy/.*py| + src/plumpy/.*py| )$ - repo: https://github.com/PyCQA/pylint diff --git a/pyproject.toml b/pyproject.toml index 42c56364..f31c988b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ pre-commit = [ 'mypy==0.982', 'pre-commit~=2.2', 'pylint==2.12.2', + 'types-pyyaml' ] tests = [ 'ipykernel==6.12.1', @@ -83,6 +84,28 @@ include_trailing_comma = true line_length = 120 multi_line_output = 3 +[tool.mypy] +show_error_codes = true +disallow_untyped_defs = true +disallow_incomplete_defs = true +check_untyped_defs = true +warn_unused_ignores = true +warn_redundant_casts = true + +[[tool.mypy.overrides]] +module = 'test.*' +check_untyped_defs = false + +[[tool.mypy.overrides]] +module = [ + 'aio_pika.*', + 'aiocontextvars.*', + 'kiwipy.*', + 'nest_asyncio.*', + 'tblib.*', +] +ignore_missing_imports = true + [tool.pylint.format] max-line-length = 120 @@ -118,3 +141,43 @@ column_limit = 120 dedent_closing_brackets = true indent_dictionary_value = false split_arguments_when_comma_terminated = true + +[tool.tox] 
+legacy_tox_ini = """ +[tox] +envlist = py37 + +[testenv] +usedevelop = true + +[testenv:py{37,38,39,310}] +description = Run the unit tests +extras = + tests +commands = pytest {posargs} + +[testenv:py{37,38,39,310}-pre-commit] +description = Run the style checks and formatting +extras = + pre-commit + tests +commands = pre-commit run {posargs} + +[testenv:docs-{update,clean}] +description = Build the documentation +extras = docs +whitelist_externals = rm +commands = + clean: rm -rf docs/_build + sphinx-build -nW --keep-going -b {posargs:html} docs/source/ docs/_build/{posargs:html} + +[testenv:docs-live] +description = Build the documentation and launch browser (with live updates) +extras = docs +deps = sphinx-autobuild +commands = + sphinx-autobuild \ + --re-ignore _build/.* \ + --port 0 --open-browser \ + -n -b {posargs:html} docs/source/ docs/_build/{posargs:html} +""" diff --git a/src/plumpy/__init__.py b/src/plumpy/__init__.py index 6015a744..a044d5c7 100644 --- a/src/plumpy/__init__.py +++ b/src/plumpy/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- +# mypy: disable-error-code=name-defined # pylint: disable=undefined-variable -# type: ignore[name-defined] __version__ = '0.21.0' import logging @@ -33,7 +33,7 @@ # for more details class NullHandler(logging.Handler): - def emit(self, record): + def emit(self, record: logging.LogRecord) -> None: pass diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index c9a08848..274ac968 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -76,12 +76,12 @@ def event( """A decorator to check for correct transitions, raising ``EventError`` on invalid transitions.""" if from_states != '*': if inspect.isclass(from_states): - from_states = (from_states,) # type: ignore + from_states = (from_states,) if not all(issubclass(state, State) for state in from_states): # type: ignore raise TypeError(f'from_states: {from_states}') if to_states != '*': if 
inspect.isclass(to_states): - to_states = (to_states,) # type: ignore + to_states = (to_states,) if not all(issubclass(state, State) for state in to_states): # type: ignore raise TypeError(f'to_states: {to_states}') @@ -111,7 +111,7 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any: return transition if inspect.isfunction(from_states): - return wrapper(from_states) # type: ignore + return wrapper(from_states) return wrapper @@ -397,8 +397,8 @@ def _create_state_instance(self, state: Union[Hashable, State, Type[State]], *ar return state_cls(self, *args, **kwargs) def _ensure_state_class(self, state: Union[Hashable, Type[State]]) -> Type[State]: - if inspect.isclass(state) and issubclass(state, State): # type: ignore - return cast(Type[State], state) + if inspect.isclass(state) and issubclass(state, State): + return state try: return self.get_states_map()[cast(Hashable, state)] # pylint: disable=unsubscriptable-object diff --git a/src/plumpy/events.py b/src/plumpy/events.py index 64f5dd9c..735fb3f6 100644 --- a/src/plumpy/events.py +++ b/src/plumpy/events.py @@ -23,7 +23,7 @@ def new_event_loop(*args: Any, **kwargs: Any) -> asyncio.AbstractEventLoop: raise NotImplementedError('this method is not implemented because `plumpy` uses a single reentrant loop') -class PlumpyEventLoopPolicy(asyncio.DefaultEventLoopPolicy): # type: ignore +class PlumpyEventLoopPolicy(asyncio.DefaultEventLoopPolicy): """Custom event policy that always returns the same event loop that is made reentrant by ``nest_asyncio``.""" _loop: Optional[asyncio.AbstractEventLoop] = None diff --git a/src/plumpy/mixins.py b/src/plumpy/mixins.py index 05cb4c11..6302184b 100644 --- a/src/plumpy/mixins.py +++ b/src/plumpy/mixins.py @@ -15,7 +15,7 @@ class ContextMixin(persistence.Savable): CONTEXT: str = '_context' def __init__(self, *args: Any, **kwargs: Any): - super().__init__(*args, **kwargs) # type: ignore + super().__init__(*args, **kwargs) self._context: Optional[AttributesDict] = AttributesDict() 
@property diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index 7b49a330..b646546d 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -73,7 +73,7 @@ def _bundle_constructor(loader: yaml.Loader, data: Any) -> Generator[Bundle, Non yaml.add_representer(Bundle, _bundle_representer) -yaml.add_constructor(_BUNDLE_TAG, _bundle_constructor) +yaml.add_constructor(_BUNDLE_TAG, _bundle_constructor) # type: ignore[arg-type] class Persister(metaclass=abc.ABCMeta): @@ -340,7 +340,7 @@ def delete_process_checkpoints(self, pid: PID_TYPE) -> None: del self._checkpoints[pid] -SavableClsType = TypeVar('SavableClsType', bound='Type[Savable]') +SavableClsType = TypeVar('SavableClsType', bound='Type[Savable]') # type: ignore[name-defined] def auto_persist(*members: str) -> Callable[[SavableClsType], SavableClsType]: @@ -650,5 +650,5 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: LoadS super().load_instance_state(saved_state, load_context) if self._callbacks: # typing says asyncio.Future._callbacks needs to be called, but in the python 3.7 code it is a simple list - for callback in self._callbacks: # type: ignore - self.remove_done_callback(callback) + for callback in self._callbacks: + self.remove_done_callback(callback) # type: ignore[arg-type] diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 41bd3609..0066cb8b 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -72,7 +72,10 @@ class ProcessStateMachineMeta(abc.ABCMeta, state_machine.StateMachineMeta): # Make ProcessStateMachineMeta instances (classes) YAML - able -yaml.representer.Representer.add_representer(ProcessStateMachineMeta, yaml.representer.Representer.represent_name) +yaml.representer.Representer.add_representer( + ProcessStateMachineMeta, + yaml.representer.Representer.represent_name # type: ignore[arg-type] +) def ensure_not_closed(func: Callable[..., Any]) -> Callable[..., Any]: diff --git a/src/plumpy/utils.py 
b/src/plumpy/utils.py index 91cbe21e..a11ebd01 100644 --- a/src/plumpy/utils.py +++ b/src/plumpy/utils.py @@ -171,7 +171,7 @@ def load_function(name: str, instance: Optional[Any] = None) -> Callable[..., An obj = load_object(name) if inspect.ismethod(obj): if instance is not None: - return obj.__get__(instance, instance.__class__) + return obj.__get__(instance, instance.__class__) # type: ignore[attr-defined] return obj diff --git a/tox.ini b/tox.ini deleted file mode 100644 index 3b47c63b..00000000 --- a/tox.ini +++ /dev/null @@ -1,68 +0,0 @@ -# To use tox, see https://tox.readthedocs.io -# Simply pip or conda install tox -# If you use conda, you may also want to install tox-conda -# then run `tox` or `tox -- {pytest args}` -# run in parallel using `tox -p` - -# see also test/rmq/docker-compose.yml to start a rabbitmq server, required for the those tests - -[tox] -envlist = py37 - -[testenv] -usedevelop = true - -[testenv:py{37,38,39}] -description = Run the unit tests -extras = - tests -commands = pytest {posargs} - -[testenv:py{37,38,39}-pre-commit] -description = Run the style checks and formatting -extras = - pre-commit - tests -commands = pre-commit run {posargs} - -[testenv:docs-{update,clean}] -description = Build the documentation -extras = docs -whitelist_externals = rm -commands = - clean: rm -rf docs/_build - sphinx-build -nW --keep-going -b {posargs:html} docs/source/ docs/_build/{posargs:html} - -[testenv:docs-live] -description = Build the documentation and launch browser (with live updates) -extras = docs -deps = sphinx-autobuild -commands = - sphinx-autobuild \ - --re-ignore _build/.* \ - --port 0 --open-browser \ - -n -b {posargs:html} docs/source/ docs/_build/{posargs:html} - - -[mypy] -show_error_codes = True -disallow_untyped_defs = True -disallow_incomplete_defs = True -check_untyped_defs = True -warn_unused_ignores = True -warn_redundant_casts = True - -[mypy-aiocontextvars.*] -ignore_missing_imports = True - -[mypy-frozendict.*] 
-ignore_missing_imports = True - -[mypy-kiwipy.*] -ignore_missing_imports = True - -[mypy-nest_asyncio.*] -ignore_missing_imports = True - -[mypy-tblib.*] -ignore_missing_imports = True From 9d490a4276b99a9a5b08cc701e12f427482f14b7 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Mon, 31 Oct 2022 11:20:59 +0100 Subject: [PATCH 10/45] Release `v0.22.0` --- CHANGELOG.md | 21 ++++++++++++++++++++- src/plumpy/__init__.py | 2 +- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d9a0cd3..1ec9e450 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,10 +1,29 @@ # Changelog +## v0.22.0 - 2022-10-31 + +### Features +- `Process`: Add the `is_excepted` property [[#240]](https://github.com/aiidateam/plumpy/pull/240) + +### Bug fixes +- `StateMachine`: transition directly to excepted if transition failed [[#240]](https://github.com/aiidateam/plumpy/pull/240) +- `Process`: Fix incorrect overriding of `transition_failed` [[#240]](https://github.com/aiidateam/plumpy/pull/240) + +### Dependencies +- Add support for Python 3.10 [[#242]](https://github.com/aiidateam/plumpy/pull/242) +- Update requirement `kiwipy~=0.8.2` [[#243]](https://github.com/aiidateam/plumpy/pull/243) +- Add lower limit for patch version of `nest-asyncio` [[#241]](https://github.com/aiidateam/plumpy/pull/241) + +### Devops +- Fix `auto_persist` decorator typing [[#239]](https://github.com/aiidateam/plumpy/pull/239) +- Fix remaining warnings in unit tests [[#244]](https://github.com/aiidateam/plumpy/pull/244) +- Update the `mypy` pre-commit dependency [[#246]](https://github.com/aiidateam/plumpy/pull/246) + ## v0.21.0 - 2022-04-08 ### Bug fixes -- Fix UnboundLocalError in DefaultObjectLoader.load_object. [[#225]](https://github.com/aiidateam/plumpy/pull/225) +- Fix `UnboundLocalError` in `DefaultObjectLoader.load_object`. [[#225]](https://github.com/aiidateam/plumpy/pull/225) ### Dependencies - Drop support for Python 3.6. 
[[#228]](https://github.com/aiidateam/plumpy/pull/228) diff --git a/src/plumpy/__init__.py b/src/plumpy/__init__.py index a044d5c7..8a61c400 100644 --- a/src/plumpy/__init__.py +++ b/src/plumpy/__init__.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # mypy: disable-error-code=name-defined # pylint: disable=undefined-variable -__version__ = '0.21.0' +__version__ = '0.22.0' import logging From 85970dcab62b7e255cf0889935aa0670811ae534 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Mon, 21 Nov 2022 00:16:40 +0100 Subject: [PATCH 11/45] Dependencies: Update requirement `pyyaml~=6.0` (#248) This latest version adds official support for Python 3.11. Also need to update `kiwipy` to `~=0.8.3` since older versions pin the `pyyaml` version to `~=5.0`. --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f31c988b..1db506c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,9 +28,9 @@ classifiers = [ keywords = ['workflow', 'multithreaded', 'rabbitmq'] requires-python = '>=3.7' dependencies = [ - 'kiwipy[rmq]~=0.8.2', + 'kiwipy[rmq]~=0.8.3', 'nest_asyncio~=1.5,>=1.5.1', - 'pyyaml~=5.4', + 'pyyaml~=6.0', ] [project.urls] @@ -42,7 +42,7 @@ Documentation = 'https://plumpy.readthedocs.io' docs = [ 'ipython~=7.0', 'jinja2==2.11.3', - 'kiwipy[docs]~=0.8.2', + 'kiwipy[docs]~=0.8.3', 'markupsafe==2.0.1', 'myst-nb~=0.11.0', 'sphinx~=3.2.0', From e31f936456090da3125826852b866328560690b4 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Mon, 21 Nov 2022 10:59:15 +0100 Subject: [PATCH 12/45] Dependencies: Add support for Python 3.11 (#249) --- .github/workflows/cd.yml | 2 +- .github/workflows/ci.yml | 2 +- pyproject.toml | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index 716be813..1de6e969 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -58,7 +58,7 @@ jobs: strategy: matrix: - python-version: ['3.7', 
'3.8', '3.9', '3.10'] + python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] services: postgres: diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 06a058cc..73dd3710 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,7 +25,7 @@ jobs: strategy: matrix: - python-version: ['3.7', '3.8', '3.9', '3.10'] + python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] services: rabbitmq: diff --git a/pyproject.toml b/pyproject.toml index 1db506c9..be8655c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ classifiers = [ 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', ] keywords = ['workflow', 'multithreaded', 'rabbitmq'] requires-python = '>=3.7' @@ -60,8 +61,9 @@ tests = [ 'pytest==6.2.5', 'pytest-asyncio==0.16.0', 'pytest-cov==3.0.0', - 'pytest-notebook==0.7.0', + 'pytest-notebook==0.8.1', 'shortuuid==1.0.8', + 'importlib-resources~=5.2', ] [tool.flit.module] From 096741a99438fad5e71f91788dbb64c1854b8eea Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Mon, 21 Nov 2022 11:17:54 +0100 Subject: [PATCH 13/45] Release `v0.22.1` --- CHANGELOG.md | 7 +++++++ src/plumpy/__init__.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1ec9e450..2a21fcd6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## v0.22.1 - 2022-11-21 + +### Dependencies +- Add support for Python 3.11 [[#249]](https://github.com/aiidateam/plumpy/pull/249) +- Update requirement `pyyaml~=6.0` [[#248]](https://github.com/aiidateam/plumpy/pull/248) + + ## v0.22.0 - 2022-10-31 ### Features diff --git a/src/plumpy/__init__.py b/src/plumpy/__init__.py index 8a61c400..15de5b7f 100644 --- a/src/plumpy/__init__.py +++ b/src/plumpy/__init__.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # mypy: disable-error-code=name-defined # pylint: disable=undefined-variable 
-__version__ = '0.22.0' +__version__ = '0.22.1' import logging From 870230f5ffe24ca030ee904c70e0cc052da8904d Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Tue, 29 Nov 2022 16:17:25 +0100 Subject: [PATCH 14/45] Docs: Update `CHANGELOG.md` with support releases --- CHANGELOG.md | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2a21fcd6..12d4c42c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,31 @@ - Update the `mypy` pre-commit dependency [[#246]](https://github.com/aiidateam/plumpy/pull/246) +## v0.21.2 - 2022-11-29 + +### Bug fixes +`Process`: Ensure that the raw inputs are not mutated [[#251]](https://github.com/aiidateam/plumpy/pull/251) + +### Dependencies +- Add support for Python 3.10 and 3.11 [[#254]](https://github.com/aiidateam/plumpy/pull/254) +- Update requirement `pytest-notebook>=0.8.1` [[#254]](https://github.com/aiidateam/plumpy/pull/254) +- Update requirement `pyyaml~=6.0` [[#254]](https://github.com/aiidateam/plumpy/pull/254) +- Update requirement `kiwipy[rmq]~=0.7.7` [[#254]](https://github.com/aiidateam/plumpy/pull/254) +- Update the `myst-nb` and `sphinx` requirements [[#253]](https://github.com/aiidateam/plumpy/pull/253) + + +## v0.21.1 - 2022-11-21 + +This is a backport of changes introduced in `v0.22.0`. + +### Features +- `Process`: Add the `is_excepted` property [[#240]](https://github.com/aiidateam/plumpy/pull/240) + +### Bug fixes +- `StateMachine`: transition directly to excepted if transition failed [[#240]](https://github.com/aiidateam/plumpy/pull/240) +- `Process`: Fix incorrect overriding of `transition_failed` [[#240]](https://github.com/aiidateam/plumpy/pull/240) + + ## v0.21.0 - 2022-04-08 ### Bug fixes From cdc4f7ce30c59564802bb8f20316e0cd0339a7a2 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Tue, 29 Nov 2022 16:27:38 +0100 Subject: [PATCH 15/45] Docs: Add logo to `README.md` This makes it identical to the `README.md` of `kiwipy`. 
--- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 32d845ff..d7abb616 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,6 @@ -# plumpy +## plumpy + + [![Build status][github-ci]][github-link] [![Docs status][rtd-badge]][rtd-link] From ff959dfd33558e482a5c100f3ce19c5a370fb1d1 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Tue, 29 Nov 2022 15:04:26 +0100 Subject: [PATCH 16/45] Process: Ensure that the raw inputs are not mutated (#251) The `Process.raw_inputs` property returns an `AttributesFrozenDict` instance with the inputs that were originally passed to the constructor of the process instance. These inputs should not be mutated as the final inputs are generated from it by pre-processing them with respect to the process spec, filling in defaults for missing ports. However, the `spec().inputs.pre_process` call in `on_create` was passing a shallow copy of the `raw_inputs` and so nested dictionaries would get modified. The problem is fixed by passing in a deep copy instead. This is done using a custom inline function `recursively_copy_dictionaries` instead of the maybe more obvious choice `copy.deepcopy`. The reason is that the latter not only copies the namespaces but also the values, which is not what we want here because we want to maintain the original values. 
Cherry-pick: 4701ad3627a3e2f3e86d2cb82ec84a7676014b87 --- src/plumpy/processes.py | 14 ++++++++++++-- test/test_processes.py | 22 ++++++++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 0066cb8b..21bb0555 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -734,8 +734,18 @@ def on_create(self) -> None: """Entering the CREATED state.""" self._creation_time = time.time() - # This will parse the inputs with respect to the input portnamespace of the spec and validate them - raw_inputs = dict(self._raw_inputs) if self._raw_inputs else {} + def recursively_copy_dictionaries(value: Any) -> Any: + """Recursively copy the mapping but only create copies of the dictionaries not the values.""" + if isinstance(value, dict): + return {key: recursively_copy_dictionaries(subvalue) for key, subvalue in value.items()} + return value + + # This will parse the inputs with respect to the input portnamespace of the spec and validate them. The + # ``pre_process`` method of the inputs port namespace modifies its argument in place, and since the + # ``_raw_inputs`` should not be modified, we pass a clone of it. Note that we only need a clone of the nested + # dictionaries, so we don't use ``copy.deepcopy`` (which might seem like the obvious choice) as that will also + # create a clone of the values, which we don't want. 
+ raw_inputs = recursively_copy_dictionaries(dict(self._raw_inputs)) if self._raw_inputs else {} self._parsed_inputs = self.spec().inputs.pre_process(raw_inputs) result = self.spec().inputs.validate(self._parsed_inputs) diff --git a/test/test_processes.py b/test/test_processes.py index e357fc4d..d0868e49 100644 --- a/test/test_processes.py +++ b/test/test_processes.py @@ -115,6 +115,28 @@ def define(cls, spec): with self.assertRaises(AttributeError): p.raw_inputs.b + def test_raw_inputs(self): + """Test that the ``raw_inputs`` are not mutated by the ``Process`` constructor. + + Regression test for https://github.com/aiidateam/plumpy/issues/250 + """ + + class Proc(Process): + + @classmethod + def define(cls, spec): + super().define(spec) + spec.input('a') + spec.input('nested.a') + spec.input('nested.b', default='default-value') + + inputs = {'a': 5, 'nested': {'a': 'value'}} + process = Proc(inputs) + + # Compare against a clone of the original inputs dictionary as the original is modified. It should not contain + # the default value of the ``nested.b`` port. + self.assertDictEqual(dict(process.raw_inputs), {'a': 5, 'nested': {'a': 'value'}}) + def test_inputs_default(self): class Proc(utils.DummyProcess): From 11c527886272d7592f25c3a699be41e7c3fe7178 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Wed, 7 Dec 2022 17:34:32 +0100 Subject: [PATCH 17/45] `PortNamespace`: Fix bug in valid type checking of dynamic namespaces (#255) If a `valid_type` is set on a `PortNamespace` it is automatically made dynamic. If a namespace is dynamic, it means that it should accept any nested namespace, however deeply nested. However, the leafs of the inputs should still respect the `valid_type`. This was not the case though. As soon as a port namespace was made dynamic, any nested input namespace would be expected, regardless of the types of the leaf values. The `PortNamespace.validate_dynamic_ports` is made recursive. 
If a port value is a dictionary, it recursively calls itself to validate its nested values, ultimately making sure that leaf values have a valid type if one is specified for the namespace. Cherry-pick: 9fb5f1c5f7dbe8d58d21652ee903a94d9e7f657b --- src/plumpy/ports.py | 20 ++++++++++++-------- test/test_port.py | 14 ++++++++++++-- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/src/plumpy/ports.py b/src/plumpy/ports.py index 39c87e77..1e4c3375 100644 --- a/src/plumpy/ports.py +++ b/src/plumpy/ports.py @@ -730,18 +730,22 @@ def validate_dynamic_ports( :return: if invalid returns a string with the reason for the validation failure, otherwise None :rtype: typing.Optional[str] """ - breadcrumbs = (*breadcrumbs, self.name) - if port_values and not self.dynamic: msg = f'Unexpected ports {port_values}, for a non dynamic namespace' + return PortValidationError(msg, breadcrumbs_to_port((*breadcrumbs, self.name))) + + if self.valid_type is None: + return None + + if isinstance(port_values, dict): + for key, value in port_values.items(): + result = self.validate_dynamic_ports(value, (*breadcrumbs, self.name, key)) + if result is not None: + return result + elif not isinstance(port_values, self.valid_type): + msg = f'Invalid type {type(port_values)} for dynamic port value: expected {self.valid_type}' return PortValidationError(msg, breadcrumbs_to_port(breadcrumbs)) - if self.valid_type is not None: - valid_type = self.valid_type - for port_name, port_value in port_values.items(): - if not isinstance(port_value, valid_type): - msg = f'Invalid type {type(port_value)} for dynamic port value: expected {valid_type}' - return PortValidationError(msg, breadcrumbs_to_port(breadcrumbs + (port_name,))) return None @staticmethod diff --git a/test/test_port.py b/test/test_port.py index 029c109c..55ddbe66 100644 --- a/test/test_port.py +++ b/test/test_port.py @@ -249,9 +249,14 @@ def test_port_namespace_set_valid_type(self): self.assertIsNone(self.port_namespace.valid_type) 
def test_port_namespace_validate(self): - """Check that validating of sub namespaces works correctly""" + """Check that validating of sub namespaces works correctly. + + By setting a valid type on a port namespace, it automatically becomes dynamic. Port namespaces that are dynamic + should accept arbitrarily nested input and should validate, as long as all leaf values satisfy the `valid_type`. + """ port_namespace_sub = self.port_namespace.create_port_namespace('sub.space') port_namespace_sub.valid_type = int + assert port_namespace_sub.dynamic # Check that passing a non mapping type raises validation_error = self.port_namespace.validate(5) @@ -261,7 +266,12 @@ def test_port_namespace_validate(self): validation_error = self.port_namespace.validate({'sub': {'space': {'output': 5}}}) self.assertIsNone(validation_error) - # Invalid input + # Valid input: `sub.space` is dynamic, so should allow arbitrarily nested namespaces as long as the leaf values + # match the valid type, which is `int` in this example. 
+ validation_error = self.port_namespace.validate({'sub': {'space': {'output': {'invalid': 5}}}}) + self.assertIsNone(validation_error) + + # Invalid input - the value in ``space`` is not ``int`` but a ``str`` validation_error = self.port_namespace.validate({'sub': {'space': {'output': '5'}}}) self.assertIsNotNone(validation_error) From e652528effffb8376753a5b7f3936c8eeea0ee81 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Wed, 7 Dec 2022 20:42:31 +0100 Subject: [PATCH 18/45] Docs: Update `CHANGELOG.md` with notes of `v0.21.3` --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 12d4c42c..9db1bbb1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,12 @@ - Update the `mypy` pre-commit dependency [[#246]](https://github.com/aiidateam/plumpy/pull/246) +## v0.21.3 - 2022-12-07 + +### Bug fixes +- `PortNamespace`: Fix bug in valid type checking of dynamic namespaces [[#255]](https://github.com/aiidateam/plumpy/pull/255) + + ## v0.21.2 - 2022-11-29 ### Bug fixes From f2dbc1dcc0581c7b6b3438029eeca866df57a24b Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Wed, 7 Dec 2022 23:03:40 +0100 Subject: [PATCH 19/45] =?UTF-8?q?=F0=9F=94=A7=20Update=20mypy=20and=20pyli?= =?UTF-8?q?nt=20(#256)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .pre-commit-config.yaml | 8 ++++---- pyproject.toml | 14 ++++++-------- src/plumpy/base/state_machine.py | 2 +- src/plumpy/communications.py | 2 +- src/plumpy/mixins.py | 4 ++-- src/plumpy/persistence.py | 2 +- src/plumpy/ports.py | 2 +- src/plumpy/processes.py | 4 +--- src/plumpy/utils.py | 4 ++-- 9 files changed, 19 insertions(+), 23 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 18364c25..af40ef64 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.5.0 + rev: v4.4.0 hooks: - id: 
double-quote-string-fixer - id: end-of-file-fixer @@ -9,7 +9,7 @@ repos: - id: trailing-whitespace - repo: https://github.com/ikamensh/flynt/ - rev: '0.76' + rev: '0.77' hooks: - id: flynt @@ -28,7 +28,7 @@ repos: additional_dependencies: ['toml'] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.982 + rev: v0.991 hooks: - id: mypy args: [--config-file=pyproject.toml] @@ -42,7 +42,7 @@ repos: )$ - repo: https://github.com/PyCQA/pylint - rev: v2.12.2 + rev: v2.15.8 hooks: - id: pylint language: system diff --git a/pyproject.toml b/pyproject.toml index be8655c2..b3ba0c88 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,9 +51,9 @@ docs = [ 'importlib-metadata~=4.12.0', ] pre-commit = [ - 'mypy==0.982', + 'mypy==0.991', 'pre-commit~=2.2', - 'pylint==2.12.2', + 'pylint==2.15.8', 'types-pyyaml' ] tests = [ @@ -88,11 +88,10 @@ multi_line_output = 3 [tool.mypy] show_error_codes = true -disallow_untyped_defs = true -disallow_incomplete_defs = true -check_untyped_defs = true -warn_unused_ignores = true -warn_redundant_casts = true +strict = true +# reduce stricness, eventually these should be removed +disallow_any_generics = false +warn_return_any = false [[tool.mypy.overrides]] module = 'test.*' @@ -113,7 +112,6 @@ max-line-length = 120 [tool.pylint.messages_control] disable = [ - 'bad-continuation', 'duplicate-code', 'global-statement', 'import-outside-toplevel', diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 274ac968..b62825e1 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -341,7 +341,7 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A self._transition_failing = False self._transitioning = False - def transition_failed( # pylint: disable=no-self-use + def transition_failed( self, initial_state: Hashable, final_state: Hashable, exception: Exception, trace: TracebackType ) -> None: """Called when a state transitions fails. 
diff --git a/src/plumpy/communications.py b/src/plumpy/communications.py index 901d2e46..51dff60d 100644 --- a/src/plumpy/communications.py +++ b/src/plumpy/communications.py @@ -119,7 +119,7 @@ def wrap_communicator( return LoopCommunicator(communicator, loop) -class LoopCommunicator(kiwipy.Communicator): +class LoopCommunicator(kiwipy.Communicator): # type: ignore """Wrapper around a `kiwipy.Communicator` that schedules any subscriber messages on a given event loop.""" def __init__(self, communicator: kiwipy.Communicator, loop: Optional[asyncio.AbstractEventLoop] = None): diff --git a/src/plumpy/mixins.py b/src/plumpy/mixins.py index 6302184b..a8dcca1e 100644 --- a/src/plumpy/mixins.py +++ b/src/plumpy/mixins.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- -from typing import Any +from typing import Any, Optional from . import persistence -from .utils import SAVED_STATE_TYPE, AttributesDict, Optional +from .utils import SAVED_STATE_TYPE, AttributesDict __all__ = ['ContextMixin'] diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index b646546d..7a15b1cc 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -340,7 +340,7 @@ def delete_process_checkpoints(self, pid: PID_TYPE) -> None: del self._checkpoints[pid] -SavableClsType = TypeVar('SavableClsType', bound='Type[Savable]') # type: ignore[name-defined] +SavableClsType = TypeVar('SavableClsType', bound='Type[Savable]') # type: ignore[name-defined] # pylint: disable=invalid-name def auto_persist(*members: str) -> Callable[[SavableClsType], SavableClsType]: diff --git a/src/plumpy/ports.py b/src/plumpy/ports.py index 1e4c3375..7e69bef8 100644 --- a/src/plumpy/ports.py +++ b/src/plumpy/ports.py @@ -602,7 +602,7 @@ def project(self, port_values: MutableMapping[str, Any]) -> MutableMapping[str, def validate( # pylint: disable=arguments-differ self, - port_values: Mapping[str, Any] = None, + port_values: Optional[Mapping[str, Any]] = None, breadcrumbs: Sequence[str] = () ) -> 
Optional[PortValidationError]: """ diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 21bb0555..4c1d53ce 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -660,7 +660,7 @@ def add_process_listener(self, listener: ProcessListener) -> None: the specific state condition. """ - assert (listener != self), 'Cannot listen to yourself!' + assert (listener != self), 'Cannot listen to yourself!' # type: ignore self.__event_helper.add_listener(listener) def remove_process_listener(self, listener: ProcessListener) -> None: @@ -1324,7 +1324,6 @@ def encode_input_args(self, inputs: Any) -> Any: :param inputs: A mapping of the inputs as passed to the process :return: The encoded inputs """ - # pylint: disable=no-self-use return copy.deepcopy(inputs) @protected @@ -1337,7 +1336,6 @@ def decode_input_args(self, encoded: Any) -> Any: :param encoded: :return: The decoded input args """ - # pylint: disable=no-self-use return copy.deepcopy(encoded) def get_status_info(self, out_status_info: dict) -> None: diff --git a/src/plumpy/utils.py b/src/plumpy/utils.py index a11ebd01..0ba2b910 100644 --- a/src/plumpy/utils.py +++ b/src/plumpy/utils.py @@ -14,7 +14,7 @@ from .settings import check_override, check_protected if TYPE_CHECKING: - from .processes import ProcessListener # pylint: disable=cyclic-import + from .process_listener import ProcessListener # pylint: disable=cyclic-import __all__ = ['AttributesDict'] @@ -171,7 +171,7 @@ def load_function(name: str, instance: Optional[Any] = None) -> Callable[..., An obj = load_object(name) if inspect.ismethod(obj): if instance is not None: - return obj.__get__(instance, instance.__class__) # type: ignore[attr-defined] + return obj.__get__(instance, instance.__class__) # type: ignore[attr-defined] # pylint: disable=unnecessary-dunder-call return obj From 514041b001b0a5bbe8f358cd338a8196877453fb Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Tue, 7 Mar 2023 14:49:22 +0100 Subject: [PATCH 20/45] 
Dependencies: Update pre-commit requirement `isort==5.12.0` (#260) Older versions were breaking due to a release of `poetry-core` causing our pre-commit job in the CI to fail. For details, see: https://github.com/PyCQA/isort/issues/2077 Cherry-pick: e5a0ce126d04cc1b6e10aab500f23ea533736800 --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index af40ef64..b3950100 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,7 +14,7 @@ repos: - id: flynt - repo: https://github.com/pycqa/isort - rev: '5.10.1' + rev: '5.12.0' hooks: - id: isort From f8521f6c7358cf2d890cb99984b8a3a26898bde7 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Thu, 9 Mar 2023 14:08:23 +0100 Subject: [PATCH 21/45] Workchains: Raise if `if_/while_` predicate does not return boolean (#259) The `if_` and `while_` conditionals are constructed with a predicate. The interface expects the predicate to be a callable that returns a boolean, which if true, the body of the conditional is entered. The problem is that the type of the value returned by the predicate was not explicitly checked, and any value that would evaluate as truthy would be accepted. This could potentially lead to unexpected behavior, such as an infinite loop for the `while_` construct. Here the `_Conditional.is_true` method is updated to explicitly check the type of the value returned by the predicate. If anything but a boolean is returned, a `TypeError` is raised. 
Cherry-pick: 800bcf154c0ea0d4576636b95d2ad2285adec266 --- src/plumpy/workchains.py | 7 ++++++- test/test_workchains.py | 11 +++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 1bf0196b..e4eb6b57 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -388,7 +388,12 @@ def predicate(self) -> PREDICATE_TYPE: return self._predicate def is_true(self, workflow: 'WorkChain') -> bool: - return self._predicate(workflow) + result = self._predicate(workflow) + + if not isinstance(result, bool): + raise TypeError(f'The conditional predicate `{self._predicate.__name__}` did not return a boolean') + + return result def __call__(self, *instructions: Union[_Instruction, WC_COMMAND_TYPE]) -> _Instruction: assert self._body is None, 'Instructions have already been set' diff --git a/test/test_workchains.py b/test/test_workchains.py index 7ac020d6..2748c955 100644 --- a/test/test_workchains.py +++ b/test/test_workchains.py @@ -618,3 +618,14 @@ def step_two(self): workchain = Wf(inputs=dict(subspace={'one': 1, 'two': 2})) workchain.execute() + + +@pytest.mark.parametrize('construct', (if_, while_)) +def test_conditional_return_type(construct): + """Test that a conditional passed to the ``if_`` and ``while_`` functions that does not return a ``bool`` raises.""" + + def invalid_conditional(self): + return 'true' + + with pytest.raises(TypeError, match='The conditional predicate `invalid_conditional` did not return a boolean'): + construct(invalid_conditional)[0].is_true(None) From c8a0a892d70f7b6f3a7b8fd79c753dd29ef89907 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Thu, 9 Mar 2023 14:32:38 +0100 Subject: [PATCH 22/45] Docs: Update `CHANGELOG.md` with notes of `v0.21.4` --- CHANGELOG.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9db1bbb1..4ffd59ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,15 @@ - Update the `mypy` 
pre-commit dependency [[#246]](https://github.com/aiidateam/plumpy/pull/246) +## v0.21.4 - 2023-03-09 + +### Bug fixes +- Workchains: Raise if `if_/while_` predicate does not return boolean [[#259]](https://github.com/aiidateam/plumpy/pull/259) + +### Dependencies +- Dependencies: Update pre-commit requirement `isort==5.12.0` [[#260]](https://github.com/aiidateam/plumpy/pull/260) + + ## v0.21.3 - 2022-12-07 ### Bug fixes From b30ec7ed80da8c1eb9ed2240600303aafeae17cb Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Tue, 14 Mar 2023 18:07:00 +0100 Subject: [PATCH 23/45] Workchains: Accept but deprecate conditional predicates returning None (#261) In 800bcf154c0ea0d4576636b95d2ad2285adec266, the behavior of predicates passed to `_Conditional` instances was changed to raise if anything but a boolean was returned. This was to catch cases where users would return non-boolean values by accident, which would most likely lead to broken behavior of the workchain. However, quite a number of existing implementations used to do something like the following in a predicate def conditional_predicate(self): if some_condition is True: return True In the case that `some_condition` was not equal to `True`, the predicate would return `None` which would be evaluated as "falsy" and so would function as if `False` had been returned. In order to not break all implementations using this behavior, `None` is still accepted and automatically converted to `False`. A `UserWarning` is emitted to warn that the behavior is deprecated. This is used in favor of a `DeprecationWarning` since those are not shown by default, which means the majority of users wouldn't see the deprecation warning.
Cherry-pick: f47627adf0dece414c26254515f53fb7414bb3bf --- src/plumpy/workchains.py | 9 +++++++++ test/test_workchains.py | 13 +++++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index e4eb6b57..673a03c6 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -390,6 +390,15 @@ def predicate(self) -> PREDICATE_TYPE: def is_true(self, workflow: 'WorkChain') -> bool: result = self._predicate(workflow) + if result is None: + import warnings + warnings.warn( + f'The conditional predicate `{self._predicate.__name__}` returned `None` but it should return a bool. ' + 'The behavior is deprecated and will soon start raising an exception, please return ``False`` instead.', + UserWarning + ) + return False + if not isinstance(result, bool): raise TypeError(f'The conditional predicate `{self._predicate.__name__}` did not return a boolean') diff --git a/test/test_workchains.py b/test/test_workchains.py index 2748c955..768989fc 100644 --- a/test/test_workchains.py +++ b/test/test_workchains.py @@ -621,11 +621,20 @@ def step_two(self): @pytest.mark.parametrize('construct', (if_, while_)) -def test_conditional_return_type(construct): - """Test that a conditional passed to the ``if_`` and ``while_`` functions that does not return a ``bool`` raises.""" +def test_conditional_return_type(construct, caplog): + """Test that a conditional passed to the ``if_`` and ``while_`` functions that does not return a ``bool`` raises. + + For now ``None`` is still accepted and is interpreted as ``False`` but it emits a deprecation warning. 
+ """ def invalid_conditional(self): return 'true' with pytest.raises(TypeError, match='The conditional predicate `invalid_conditional` did not return a boolean'): construct(invalid_conditional)[0].is_true(None) + + def deprecated_conditional(self): + return None + + with pytest.warns(UserWarning): + construct(deprecated_conditional)[0].is_true(None) From b457c78a548e59cee39128a0505389a8250371d5 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Tue, 14 Mar 2023 18:13:30 +0100 Subject: [PATCH 24/45] Docs: Update `CHANGELOG.md` with notes of `v0.21.5` --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ffd59ff..1a1fc4dd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,12 @@ - Update the `mypy` pre-commit dependency [[#246]](https://github.com/aiidateam/plumpy/pull/246) +## v0.21.5 - 2023-03-14 + +### Bug fixes +- Workchains: Accept but deprecate conditional predicates returning `None` [[#261]](https://github.com/aiidateam/plumpy/pull/261) + + ## v0.21.4 - 2023-03-09 ### Bug fixes From 2f324d89683b6c8762048c8319ad2b5bc57912df Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Mon, 3 Apr 2023 19:25:45 +0200 Subject: [PATCH 25/45] Workchains: Turn exception into warning for incorrect return type in conditional predicates (#265) In f47627adf0dece414c26254515f53fb7414bb3bf, the recently added check on the return type of a conditional predicate was updated slightly to only emit a warning when `None` would be returned instead of raising a `TypeError`. This was done because there quite a number of workchains that unwittingly were relying on this behavior and with the change added in 800bcf154c0ea0d4576636b95d2ad2285adec266 would break all of these. Since then, it was discovered that the exception for `None` may still not be enough as it turns out that there are also quite a number of workchains that use objects that behave like booleans, but are not actually instances of the built-in `bool` base type. 
Examples are `numpy.bool` as well as `aiida.orm.Bool`. Since here the behavior of the predicate would be as expected, breaking the existing workflow is actually not desirable. Therefore, the requirement on the return type is relaxed slightly further to any type that implements the `__bool__` method. This includes `None`, but also types like `numpy.bool` and `aiida.orm.Bool` which are often used in the workchain logic by AiiDA code. There might still be edge cases where users may (accidentally) return a type from a function that does implement `__bool__` but whose evaluated value is not what the user intended, but that risk is preferable than forcing users to cast bool-like instances explicitly to a `bool`. Cherry-pick: 6f5d326d940f0c94d73fc42166d33aafc095c447 --- src/plumpy/workchains.py | 12 ++++-------- test/test_workchains.py | 29 +++++++++++++++++++---------- 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 673a03c6..90e35482 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -390,17 +390,13 @@ def predicate(self) -> PREDICATE_TYPE: def is_true(self, workflow: 'WorkChain') -> bool: result = self._predicate(workflow) - if result is None: + if not hasattr(result, '__bool__'): import warnings warnings.warn( - f'The conditional predicate `{self._predicate.__name__}` returned `None` but it should return a bool. ' - 'The behavior is deprecated and will soon start raising an exception, please return ``False`` instead.', - UserWarning + f'The conditional predicate `{self._predicate.__name__}` returned `{result}` which is not boolean-like.' + ' The return value should be `True` or `False` or implement the `__bool__` method. 
This behavior is ' + 'deprecated and will soon start raising an exception.', UserWarning ) - return False - - if not isinstance(result, bool): - raise TypeError(f'The conditional predicate `{self._predicate.__name__}` did not return a boolean') return result diff --git a/test/test_workchains.py b/test/test_workchains.py index 768989fc..71cd0f6a 100644 --- a/test/test_workchains.py +++ b/test/test_workchains.py @@ -621,20 +621,29 @@ def step_two(self): @pytest.mark.parametrize('construct', (if_, while_)) -def test_conditional_return_type(construct, caplog): - """Test that a conditional passed to the ``if_`` and ``while_`` functions that does not return a ``bool`` raises. +def test_conditional_return_type(construct, recwarn): + """Test that a conditional passed to the ``if_`` and ``while_`` functions warns for incorrect type.""" - For now ``None`` is still accepted and is interpreted as ``False`` but it emits a deprecation warning. - """ + class BoolLike: + """Instances that implement ``__bool__`` are valid return types for conditional predicate.""" - def invalid_conditional(self): - return 'true' + def __bool__(self): + return True - with pytest.raises(TypeError, match='The conditional predicate `invalid_conditional` did not return a boolean'): - construct(invalid_conditional)[0].is_true(None) + def valid_conditional(self): + return BoolLike() + + construct(valid_conditional)[0].is_true(None) + assert len(recwarn) == 0 - def deprecated_conditional(self): + def conditional_returning_none(self): return None + construct(conditional_returning_none)[0].is_true(None) + assert len(recwarn) == 0 + + def invalid_conditional(self): + return 'true' + with pytest.warns(UserWarning): - construct(deprecated_conditional)[0].is_true(None) + construct(invalid_conditional)[0].is_true(None) From f9ccdd7b5580237643fc76e4bcfbc48c22a30032 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Mon, 3 Apr 2023 22:42:56 +0200 Subject: [PATCH 26/45] `PortNamespace`: Make `dynamic` apply 
recursively (#263) The `dynamic` attribute of a port namespace indicates whether it should accept ports that are not explicitly defined. This was mostly used during validation, when a dictionary of port values was matched against a given `PortNamespace`. This was, however, not being applied recursively _and_ only during validation. For example, given a dynamic portnamespace, validating a dictionary: { 'nested': { 'output': 'some_value' } } would pass validation without problems. However, the `Process.out` call that would actually attempt to attach the output to the process instance would call: self.spec().outputs.get_port(namespace_separator.join(namespace)) which would raise, since `get_port` would raise a `ValueError`: ValueError: port 'output' does not exist in port namespace 'nested' The problem is that the `nested` namespace is expected, because the top level namespace was marked as dynamic, however, it itself would not also be treated as dynamic and so attempting to retrieve `some_value` from the `nested` namespace, would trigger a `KeyError`. Here the logic in `PortNamespace.get_port` is updated to check in advance whether the port exists in the namespace, and if not the case _and_ the namespace is dynamic, the nested port namespace is created. The attributes of the new namespace are inherited from its parent namespace, making the `dynamic` attribute act recursively. 
Cherry-pick: 4c29f4459c8eb8a8263049ac338189c604702e4e --- src/plumpy/ports.py | 13 ++++++++++++- test/test_port.py | 21 +++++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/src/plumpy/ports.py b/src/plumpy/ports.py index 7e69bef8..399c598a 100644 --- a/src/plumpy/ports.py +++ b/src/plumpy/ports.py @@ -446,9 +446,20 @@ def get_port(self, name: str) -> Union[Port, 'PortNamespace']: namespace = name.split(self.NAMESPACE_SEPARATOR) port_name = namespace.pop(0) - if port_name not in self: + if port_name not in self and not self.dynamic: raise ValueError(f"port '{port_name}' does not exist in port namespace '{self.name}'") + if port_name not in self and self.dynamic: + self[port_name] = self.__class__( + name=port_name, + required=self.required, + validator=self.validator, + valid_type=self.valid_type, + default=self.default, + dynamic=self.dynamic, + populate_defaults=self.populate_defaults + ) + if namespace: portnamespace = cast(PortNamespace, self[port_name]) return portnamespace.get_port(self.NAMESPACE_SEPARATOR.join(namespace)) diff --git a/test/test_port.py b/test/test_port.py index 55ddbe66..eca1d09f 100644 --- a/test/test_port.py +++ b/test/test_port.py @@ -208,6 +208,27 @@ def test_port_namespace_get_port(self): port = self.port_namespace.get_port('sub.name.space.' + self.BASE_PORT_NAME) self.assertEqual(port, self.port) + def test_port_namespace_get_port_dynamic(self): + """Test that ``get_port`` does not raise if a port does not exist as long as the namespace is dynamic. + + In this case, the method should create the subnamespace on-the-fly with the same stats as the host namespace. 
+ """ + port_namespace = PortNamespace(self.BASE_PORT_NAMESPACE_NAME, dynamic=True) + + name = 'undefined' + sub_namespace = port_namespace.get_port(name) + + assert isinstance(sub_namespace, PortNamespace) + assert sub_namespace.dynamic + assert sub_namespace.name == name + + name = 'nested.undefined' + sub_namespace = port_namespace.get_port(name) + + assert isinstance(sub_namespace, PortNamespace) + assert sub_namespace.dynamic + assert sub_namespace.name == 'undefined' + def test_port_namespace_create_port_namespace(self): """ Test the create_port_namespace function of the PortNamespace class From 9e3cdb5531604992723d750984408cea6918a8c5 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Mon, 3 Apr 2023 22:51:31 +0200 Subject: [PATCH 27/45] Docs: Update `CHANGELOG.md` with notes of `v0.21.6` --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a1fc4dd..391c7d2c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,13 @@ - Update the `mypy` pre-commit dependency [[#246]](https://github.com/aiidateam/plumpy/pull/246) +## v0.21.6 - 2023-04-03 + +### Bug fixes +- `PortNamespace`: Make `dynamic` apply recursively [[#263]](https://github.com/aiidateam/plumpy/pull/263) +- Workchains: Turn exception into warning for incorrect return type in conditional predicates: any type that implements `__bool__` will be accepted [[#265]](https://github.com/aiidateam/plumpy/pull/265) + + ## v0.21.5 - 2023-03-14 ### Bug fixes From 02b8731125a28fb41163e0407c5fc6d27e490442 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Tue, 18 Apr 2023 10:52:11 +0200 Subject: [PATCH 28/45] Tests: Add regression tests for `PortNamespace.get_port` In commit 4c29f4459c8eb8a8263049ac338189c604702e4e the implementation of `PortNamespace.get_port` was updated to not raise a `ValueError` if the requested port does not exist, as long as the namespace is dynamic, but instead create the port on-the-fly and return it. 
This was done to allow processes to attach arbitrarily deeply nested outputs in dynamic namespaces, which seems a valid use-case and what a user would expect from a dynamic output namespace. However, the changes had unintended consequences in other parts of the code. Most notably in the validation of input namespaces that include exposed namespaces which excluded ports containing validators. Normally the validator would not be called because the port was excluded, but if the host namespace of the exposed namespace is dynamic, the new change would create the excluded port on the fly and cause it to be validated nevertheless, which would break existing workflows that didn't actually specify the port value in the inputs since it used to be excluded. Here a regression test is added that checks for this behavior, which currently fails. In addition, the test `test_processes.py::test_namespaced_process_outputs` is modified such that it properly tests the functionality that the commit 4c29f4459c8eb8a8263049ac338189c604702e4e aimed to enable. It already was testing the `Process.out` method for a dynamic output namespace, but it was a "shallowly" nested namespace. That is to say, the namespace `integer.namespace` was created explicitly and made dynamic and the test verified that arbitrary integer outputs could be added to it. This therefore didn't test that outputs could be added on arbitrarily deep levels of nesting. The test is changed to only explicitly create the `integer` top level output namespace, and then add outputs on two levels deeper, verifying that the dynamicity of the namespace works recursively, and the intermediate namespace `integer.nested` is created on the fly. 
Cherry-pick: 879d41a667c4c48e1d8c774494abad7d86fb2b71 --- test/test_expose.py | 39 +++++++++++++++++++++++++++++++++++++++ test/test_processes.py | 19 ++++++++++--------- 2 files changed, 49 insertions(+), 9 deletions(-) diff --git a/test/test_expose.py b/test/test_expose.py index 05af962a..1a495727 100644 --- a/test/test_expose.py +++ b/test/test_expose.py @@ -491,3 +491,42 @@ def define(cls, spec): self.check_ports(ExposeProcess, 'base', ['namespace']) self.check_ports(ExposeProcess, 'base.namespace', ['sub_one']) self.check_namespace_properties(BaseNamespaceProcess, 'namespace', ExposeProcess, 'base.namespace') + + def test_expose_exclude_port_with_validator(self): + """Test that validators of excluded ports are not called, even if the parent namespace is dynamic. + + This is a regression test for https://github.com/aiidateam/plumpy/issues/267. Changes to the method + ``PortNamespace.get_port`` would recursively create and return non-existing ports as long as the parent + namespace is dynamic. This would result in a problem with the validationn of namespaces that contained exposed + namespaces with validators that are dependent on excluded ports. Even though the port was excluded, the changes + in ``get_port`` would now recreate the port on the fly when the validation attempted to retrieve it, thereby + undoing the exclusion of the port when exposed. 
+ """ + + class BaseProcess(NewLoopProcess): + + @classmethod + def define(cls, spec): + super().define(spec) + spec.input('a', required=False) + spec.inputs.dynamic = True + spec.inputs.validator = cls.validator + + @classmethod + def validator(cls, value, ctx): + try: + ctx.get_port('a') + except ValueError: + return None + + if not isinstance(value['a'], str): + return f'value for input `a` should be a str, but got: {type(value["a"])}' + + class ExposeProcess(NewLoopProcess): + + @classmethod + def define(cls, spec): + super().define(spec) + spec.expose_inputs(BaseProcess, namespace='base', exclude=('a',)) + + assert ExposeProcess.spec().inputs.validate({}) is None diff --git a/test/test_processes.py b/test/test_processes.py index d0868e49..737b463d 100644 --- a/test/test_processes.py +++ b/test/test_processes.py @@ -961,7 +961,8 @@ def define(cls, spec): def test_namespaced_process_outputs(self): """Test the output namespacing and validation.""" - namespace = 'integer.namespace' + namespace = 'integer' + namespace_nested = f'{namespace}.nested' class OutputMode(enum.Enum): @@ -983,14 +984,14 @@ def run(self): if self.inputs.output_mode == OutputMode.NONE: pass elif self.inputs.output_mode == OutputMode.DYNAMIC_PORT_NAMESPACE: - self.out(namespace + '.one', 1) - self.out(namespace + '.two', 2) + self.out(namespace_nested + '.one', 1) + self.out(namespace_nested + '.two', 2) elif self.inputs.output_mode == OutputMode.SINGLE_REQUIRED_PORT: self.out('required_bool', False) elif self.inputs.output_mode == OutputMode.BOTH_SINGLE_AND_NAMESPACE: self.out('required_bool', False) - self.out(namespace + '.one', 1) - self.out(namespace + '.two', 2) + self.out(namespace_nested + '.one', 1) + self.out(namespace_nested + '.two', 2) # Run the process in default mode which should not add any outputs and therefore fail process = DummyDynamicProcess() @@ -1006,8 +1007,8 @@ def run(self): self.assertEqual(process.state, ProcessState.FINISHED) 
self.assertFalse(process.is_successful) - self.assertEqual(process.outputs['integer']['namespace']['one'], 1) - self.assertEqual(process.outputs['integer']['namespace']['two'], 2) + self.assertEqual(process.outputs[namespace]['nested']['one'], 1) + self.assertEqual(process.outputs[namespace]['nested']['two'], 2) # Attaching only the single required top-level port should be fine process = DummyDynamicProcess(inputs={'output_mode': OutputMode.SINGLE_REQUIRED_PORT}) @@ -1025,8 +1026,8 @@ def run(self): self.assertEqual(process.state, ProcessState.FINISHED) self.assertTrue(process.is_successful) self.assertEqual(process.outputs['required_bool'], False) - self.assertEqual(process.outputs['integer']['namespace']['one'], 1) - self.assertEqual(process.outputs['integer']['namespace']['two'], 2) + self.assertEqual(process.outputs[namespace]['nested']['one'], 1) + self.assertEqual(process.outputs[namespace]['nested']['two'], 2) class TestProcessEvents(unittest.TestCase): From 1607eb3387e7c4ac772b6d1370d67e8cbbeadb94 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Tue, 18 Apr 2023 11:39:26 +0200 Subject: [PATCH 29/45] `PortNamespace.get_port`: Only create if `create_dynamically` is `True` In 4c29f4459c8eb8a8263049ac338189c604702e4e, the `get_port` method was updated to automatically create a port if it didn't exist and the namespace is dynamic instead of raising. But this had unwanted knock-on consequences (see previous commit for details). Here, the change in behavior is only triggered when the new argument `create_dynamically` is explicitly set to `True`. This means that by default the old behavior is maintained of a `ValueError` being raised if the requested port doesn't exist. But now `Process.out` can override the default to `True` to automatically support nested namespaces in a dynamic output namespace. 
Cherry-pick: ad100016745e0557c9f2736b6b5695a2d9035137 --- src/plumpy/ports.py | 16 ++++++++++------ src/plumpy/processes.py | 2 +- test/test_port.py | 6 +++--- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/plumpy/ports.py b/src/plumpy/ports.py index 399c598a..fc5f138f 100644 --- a/src/plumpy/ports.py +++ b/src/plumpy/ports.py @@ -428,12 +428,14 @@ def get_description(self) -> Dict[str, Dict[str, Any]]: return description - def get_port(self, name: str) -> Union[Port, 'PortNamespace']: + def get_port(self, name: str, create_dynamically: bool = False) -> Union[Port, 'PortNamespace']: """ Retrieve a (namespaced) port from this PortNamespace. If any of the sub namespaces of the terminal port itself cannot be found, a ValueError will be raised - :param name: name (potentially namespaced) of the port to retrieve + :param name: name (potentially namespaced) of the port to retrieve. + :param create_dynamically: If set to ``True``, dynamically create the requested port if it doesn't exist and the + namespace is dynamic, instead of raising a ``ValueError``. 
:returns: Port :raises: ValueError if port or namespace does not exist """ @@ -446,10 +448,10 @@ def get_port(self, name: str) -> Union[Port, 'PortNamespace']: namespace = name.split(self.NAMESPACE_SEPARATOR) port_name = namespace.pop(0) - if port_name not in self and not self.dynamic: - raise ValueError(f"port '{port_name}' does not exist in port namespace '{self.name}'") + if port_name not in self: + if not self.dynamic or not create_dynamically: + raise ValueError(f"port '{port_name}' does not exist in port namespace '{self.name}'") - if port_name not in self and self.dynamic: self[port_name] = self.__class__( name=port_name, required=self.required, @@ -462,7 +464,9 @@ def get_port(self, name: str) -> Union[Port, 'PortNamespace']: if namespace: portnamespace = cast(PortNamespace, self[port_name]) - return portnamespace.get_port(self.NAMESPACE_SEPARATOR.join(namespace)) + return portnamespace.get_port( + self.NAMESPACE_SEPARATOR.join(namespace), create_dynamically=create_dynamically + ) return self[port_name] diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 4c1d53ce..d005450d 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -1288,7 +1288,7 @@ def out(self, output_port: str, value: Any) -> None: if namespace: port_namespace = cast( ports.PortNamespace, - self.spec().outputs.get_port(namespace_separator.join(namespace)) + self.spec().outputs.get_port(namespace_separator.join(namespace), create_dynamically=True) ) else: port_namespace = self.spec().outputs diff --git a/test/test_port.py b/test/test_port.py index eca1d09f..ab9b51a6 100644 --- a/test/test_port.py +++ b/test/test_port.py @@ -209,21 +209,21 @@ def test_port_namespace_get_port(self): self.assertEqual(port, self.port) def test_port_namespace_get_port_dynamic(self): - """Test that ``get_port`` does not raise if a port does not exist as long as the namespace is dynamic. + """Test ``get_port`` with the ``create_dynamically=True`` keyword. 
In this case, the method should create the subnamespace on-the-fly with the same stats as the host namespace. """ port_namespace = PortNamespace(self.BASE_PORT_NAMESPACE_NAME, dynamic=True) name = 'undefined' - sub_namespace = port_namespace.get_port(name) + sub_namespace = port_namespace.get_port(name, create_dynamically=True) assert isinstance(sub_namespace, PortNamespace) assert sub_namespace.dynamic assert sub_namespace.name == name name = 'nested.undefined' - sub_namespace = port_namespace.get_port(name) + sub_namespace = port_namespace.get_port(name, create_dynamically=True) assert isinstance(sub_namespace, PortNamespace) assert sub_namespace.dynamic From c24784efbd64edb9ac1aaa911f367374831a3c9c Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Thu, 20 Apr 2023 15:45:34 +0200 Subject: [PATCH 30/45] Docs: Update `CHANGELOG.md` with notes of `v0.21.7` --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 391c7d2c..3a83c6f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,12 @@ - Update the `mypy` pre-commit dependency [[#246]](https://github.com/aiidateam/plumpy/pull/246) +## v0.21.7 - 2023-04-20 + +### Bug fixes +- `PortNamespace.get_port`: Only create if `create_dynamically` is `True` [[#268]](https://github.com/aiidateam/plumpy/pull/268) + + ## v0.21.6 - 2023-04-03 ### Bug fixes From 171d0623c487e0b59f2c6a8c24ee665b0eaed8ab Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Wed, 7 Jun 2023 12:03:46 +0200 Subject: [PATCH 31/45] Dependencies: Update requirement `mypy==1.3.0` (#270) Also move the configuration from `tox.ini` to `pyproject.toml` and change the `.pre-commit-config.yaml` to use the local install of `mypy` instead of in the pre-commit virtual environment. 
Cherry-pick: 87982d02a2d929c93cc6bcc233d138659cd70ce4 --- .pre-commit-config.yaml | 29 +++++++++++++++-------------- pyproject.toml | 8 +++++++- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b3950100..cae9888f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,20 +27,6 @@ repos: args: ['-i'] additional_dependencies: ['toml'] -- repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.991 - hooks: - - id: mypy - args: [--config-file=pyproject.toml] - additional_dependencies: [ - 'toml', - 'types-pyyaml', - ] - files: > - (?x)^( - src/plumpy/.*py| - )$ - - repo: https://github.com/PyCQA/pylint rev: v2.15.8 hooks: @@ -51,3 +37,18 @@ repos: docs/source/conf.py| test/.*| )$ + +- repo: local + hooks: + - id: mypy + name: mypy + entry: mypy + args: [--config-file=pyproject.toml] + language: python + types: [python] + require_serial: true + pass_filenames: true + files: >- + (?x)^( + src/.*py| + )$ diff --git a/pyproject.toml b/pyproject.toml index b3ba0c88..f8c79b15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ docs = [ 'importlib-metadata~=4.12.0', ] pre-commit = [ - 'mypy==0.991', + 'mypy==1.3.0', 'pre-commit~=2.2', 'pylint==2.15.8', 'types-pyyaml' @@ -91,6 +91,11 @@ show_error_codes = true strict = true # reduce stricness, eventually these should be removed disallow_any_generics = false +disallow_untyped_defs = true +disallow_incomplete_defs = true +check_untyped_defs = true +warn_unused_ignores = true +warn_redundant_casts = true warn_return_any = false [[tool.mypy.overrides]] @@ -101,6 +106,7 @@ check_untyped_defs = false module = [ 'aio_pika.*', 'aiocontextvars.*', + 'frozendict.*', 'kiwipy.*', 'nest_asyncio.*', 'tblib.*', From 079711376e5a287dc19e8ff2b4f21781e8362bdd Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Wed, 7 Jun 2023 12:19:16 +0200 Subject: [PATCH 32/45] Docs: Update `CHANGELOG.md` with notes of `v0.21.8` --- CHANGELOG.md | 6 
++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a83c6f0..af4e87b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,12 @@ - Update the `mypy` pre-commit dependency [[#246]](https://github.com/aiidateam/plumpy/pull/246) +## v0.21.8 - 2023-06-07 + +### Devops +- Dependencies: Update requirement `mypy==1.3.0` [[#270]](https://github.com/aiidateam/plumpy/pull/270) + + ## v0.21.7 - 2023-04-20 ### Bug fixes From 44d27d18b817f86cbff7888059340086c5f108da Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Fri, 23 Jun 2023 15:55:53 +0200 Subject: [PATCH 33/45] Release `v0.22.2` --- CHANGELOG.md | 11 +++++++++++ src/plumpy/__init__.py | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index af4e87b3..2f13d214 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,16 @@ # Changelog +## v0.22.2 - 2023-06-23 + +This release applies the fixes that were released on the support branch of `v0.21.x`. + +### Bug fixes +- Workchains: Accept but deprecate conditional predicates returning `None` [[#261]](https://github.com/aiidateam/plumpy/pull/261) +- `PortNamespace`: Fix bug in valid type checking of dynamic namespaces [[#255]](https://github.com/aiidateam/plumpy/pull/255) +- `PortNamespace`: Make `dynamic` apply recursively [[#263]](https://github.com/aiidateam/plumpy/pull/263) +- `PortNamespace.get_port`: Only create if `create_dynamically` is `True` [[#268]](https://github.com/aiidateam/plumpy/pull/268) + + ## v0.22.1 - 2022-11-21 ### Dependencies diff --git a/src/plumpy/__init__.py b/src/plumpy/__init__.py index 15de5b7f..a22deecc 100644 --- a/src/plumpy/__init__.py +++ b/src/plumpy/__init__.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # mypy: disable-error-code=name-defined # pylint: disable=undefined-variable -__version__ = '0.22.1' +__version__ = '0.22.2' import logging From 5ddba0f0136994a4703d507d7de98062c26928e7 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Thu, 26 Oct 2023 
14:16:15 +0200 Subject: [PATCH 34/45] Docs: Update ReadTheDocs configuration file (#276) Add the new required `build.os` key and fix it to Ubuntu 22.04 Cherry-pick: 31f85c71730b488aafd680f240485a51884722b7 --- .readthedocs.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.readthedocs.yml b/.readthedocs.yml index 246f3c6a..df0ea532 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -1,7 +1,11 @@ version: 2 +build: + os: ubuntu-22.04 + tools: + python: '3.8' + python: - version: 3.8 install: - method: pip path: . From fc660017e4d21a1122dc1236e4ade2f54d146500 Mon Sep 17 00:00:00 2001 From: Riccardo Bertossa <33728857+rikigigi@users.noreply.github.com> Date: Fri, 10 Nov 2023 15:21:32 +0100 Subject: [PATCH 35/45] Make `ProcessListener` instances persistable (#277) The `ProcessListener` is made persistable by deriving it, as well as the `EventHelper` class from `persistence.Savable`. The class `EventHelper` is moved to a new file because of a circular import that would result between the `utils` and `persistence` modules. There was a circular reference issue in the test listener that was storing a reference to the process inside it, making its serialization impossible. To fix the tests an ugly hack was used: storing the reference to the process outside the class in a global dict using id as keys. Some more ugly hacks were needed to correctly check the equality of two processes. Instances having different listeners should be ignored. 
Cherry-pick: 98a375f07db0cacaacdc1545d4d12f25dd00bf1d --- src/plumpy/base/utils.py | 2 + src/plumpy/event_helper.py | 54 ++++++++++++++++++++++ src/plumpy/process_listener.py | 26 ++++++++++- src/plumpy/processes.py | 17 ++++--- src/plumpy/utils.py | 41 ----------------- test/test_processes.py | 5 +- test/test_workchains.py | 40 ++++++++++++++++ test/utils.py | 84 +++++++++++++++++++++++++++++----- 8 files changed, 206 insertions(+), 63 deletions(-) create mode 100644 src/plumpy/event_helper.py diff --git a/src/plumpy/base/utils.py b/src/plumpy/base/utils.py index c4820f1b..232c5d26 100644 --- a/src/plumpy/base/utils.py +++ b/src/plumpy/base/utils.py @@ -16,6 +16,8 @@ def wrapper(self: Any, *args: Any, **kwargs: Any) -> None: wrapped(self, *args, **kwargs) self._called -= 1 + # Forward wrapped function name to the decorator to show the correct name in the ``call_with_super_check`` + wrapper.__name__ = wrapped.__name__ return wrapper diff --git a/src/plumpy/event_helper.py b/src/plumpy/event_helper.py new file mode 100644 index 00000000..3a342321 --- /dev/null +++ b/src/plumpy/event_helper.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +import logging +from typing import TYPE_CHECKING, Any, Callable + +from . 
import persistence + +if TYPE_CHECKING: + from typing import Set, Type + + from .process_listener import ProcessListener # pylint: disable=cyclic-import + +_LOGGER = logging.getLogger(__name__) + + +@persistence.auto_persist('_listeners', '_listener_type') +class EventHelper(persistence.Savable): + + def __init__(self, listener_type: 'Type[ProcessListener]'): + assert listener_type is not None, 'Must provide valid listener type' + + self._listener_type = listener_type + self._listeners: 'Set[ProcessListener]' = set() + + def add_listener(self, listener: 'ProcessListener') -> None: + assert isinstance(listener, self._listener_type), 'Listener is not of right type' + self._listeners.add(listener) + + def remove_listener(self, listener: 'ProcessListener') -> None: + self._listeners.discard(listener) + + def remove_all_listeners(self) -> None: + self._listeners.clear() + + @property + def listeners(self) -> 'Set[ProcessListener]': + return self._listeners + + def fire_event(self, event_function: Callable[..., Any], *args: Any, **kwargs: Any) -> None: + """Call an event method on all listeners. 
+ + :param event_function: the method of the ProcessListener + :param args: arguments to pass to the method + :param kwargs: keyword arguments to pass to the method + + """ + if event_function is None: + raise ValueError('Must provide valid event method') + + # Make a copy of the list for iteration just in case it changes in a callback + for listener in list(self.listeners): + try: + getattr(listener, event_function.__name__)(*args, **kwargs) + except Exception as exception: # pylint: disable=broad-except + _LOGGER.error("Listener '%s' produced an exception:\n%s", listener, exception) diff --git a/src/plumpy/process_listener.py b/src/plumpy/process_listener.py index c0a49d9a..110394a2 100644 --- a/src/plumpy/process_listener.py +++ b/src/plumpy/process_listener.py @@ -1,6 +1,9 @@ # -*- coding: utf-8 -*- import abc -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Dict, Optional + +from . import persistence +from .utils import SAVED_STATE_TYPE, protected __all__ = ['ProcessListener'] @@ -8,7 +11,26 @@ from .processes import Process # pylint: disable=cyclic-import -class ProcessListener(metaclass=abc.ABCMeta): +@persistence.auto_persist('_params') +class ProcessListener(persistence.Savable, metaclass=abc.ABCMeta): + + # region Persistence methods + + def __init__(self) -> None: + super().__init__() + self._params: Dict[str, Any] = {} + + def init(self, **kwargs: Any) -> None: + self._params = kwargs + + @protected + def load_instance_state( + self, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] + ) -> None: + super().load_instance_state(saved_state, load_context) + self.init(**saved_state['_params']) + + # endregion def on_process_created(self, process: 'Process') -> None: """ diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index d005450d..57e12af5 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -42,6 +42,7 @@ from .base import state_machine from .base.state_machine 
import StateEntryFailed, StateMachine, TransitionFailed, event from .base.utils import call_with_super_check, super_check +from .event_helper import EventHelper from .process_listener import ProcessListener from .process_spec import ProcessSpec from .utils import PID_TYPE, SAVED_STATE_TYPE, protected @@ -91,7 +92,9 @@ def func_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: return func_wrapper -@persistence.auto_persist('_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status') +@persistence.auto_persist( + '_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status', '_event_helper' +) class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMeta): """ The Process class is the base for any unit of work in plumpy. @@ -289,7 +292,7 @@ def __init__( # Runtime variables self._future = persistence.SavableFuture(loop=self._loop) - self.__event_helper = utils.EventHelper(ProcessListener) + self._event_helper = EventHelper(ProcessListener) self._logger = logger self._communicator = communicator @@ -612,7 +615,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi # Runtime variables, set initial states self._future = persistence.SavableFuture() - self.__event_helper = utils.EventHelper(ProcessListener) + self._event_helper = EventHelper(ProcessListener) self._logger = None self._communicator = None @@ -661,11 +664,11 @@ def add_process_listener(self, listener: ProcessListener) -> None: """ assert (listener != self), 'Cannot listen to yourself!' 
# type: ignore - self.__event_helper.add_listener(listener) + self._event_helper.add_listener(listener) def remove_process_listener(self, listener: ProcessListener) -> None: """Remove a process listener from the process.""" - self.__event_helper.remove_listener(listener) + self._event_helper.remove_listener(listener) @protected def set_logger(self, logger: logging.Logger) -> None: @@ -778,7 +781,7 @@ def on_output_emitting(self, output_port: str, value: Any) -> None: """Output is about to be emitted.""" def on_output_emitted(self, output_port: str, value: Any, dynamic: bool) -> None: - self.__event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic) + self._event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic) @super_check def on_wait(self, awaitables: Sequence[Awaitable]) -> None: @@ -891,7 +894,7 @@ def on_close(self) -> None: self._closed = True def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> None: - self.__event_helper.fire_event(evt, self, *args, **kwargs) + self._event_helper.fire_event(evt, self, *args, **kwargs) # endregion diff --git a/src/plumpy/utils.py b/src/plumpy/utils.py index 0ba2b910..4eab8efe 100644 --- a/src/plumpy/utils.py +++ b/src/plumpy/utils.py @@ -27,47 +27,6 @@ PID_TYPE = Hashable # pylint: disable=invalid-name -class EventHelper: - - def __init__(self, listener_type: 'Type[ProcessListener]'): - assert listener_type is not None, 'Must provide valid listener type' - - self._listener_type = listener_type - self._listeners: 'Set[ProcessListener]' = set() - - def add_listener(self, listener: 'ProcessListener') -> None: - assert isinstance(listener, self._listener_type), 'Listener is not of right type' - self._listeners.add(listener) - - def remove_listener(self, listener: 'ProcessListener') -> None: - self._listeners.discard(listener) - - def remove_all_listeners(self) -> None: - self._listeners.clear() - - @property - def listeners(self) 
-> 'Set[ProcessListener]': - return self._listeners - - def fire_event(self, event_function: Callable[..., Any], *args: Any, **kwargs: Any) -> None: - """Call an event method on all listeners. - - :param event_function: the method of the ProcessListener - :param args: arguments to pass to the method - :param kwargs: keyword arguments to pass to the method - - """ - if event_function is None: - raise ValueError('Must provide valid event method') - - # Make a copy of the list for iteration just in case it changes in a callback - for listener in list(self.listeners): - try: - getattr(listener, event_function.__name__)(*args, **kwargs) - except Exception as exception: # pylint: disable=broad-except - _LOGGER.error("Listener '%s' produced an exception:\n%s", listener, exception) - - class Frozendict(Mapping): """ An immutable wrapper around dictionaries that implements the complete :py:class:`collections.abc.Mapping` diff --git a/test/test_processes.py b/test/test_processes.py index 737b463d..158cce66 100644 --- a/test/test_processes.py +++ b/test/test_processes.py @@ -800,7 +800,8 @@ def test_instance_state_with_outputs(self): # Check that it is a copy self.assertIsNot(outputs, bundle.get(BundleKeys.OUTPUTS, {})) # Check the contents are the same - self.assertDictEqual(outputs, bundle.get(BundleKeys.OUTPUTS, {})) + # Remove the ``ProcessSaver`` instance that is only used for testing + utils.compare_dictionaries(None, None, outputs, bundle.get(BundleKeys.OUTPUTS, {}), exclude={'_listeners'}) self.assertIsNot(proc.outputs, saver.snapshots[-1].get(BundleKeys.OUTPUTS, {})) @@ -875,7 +876,7 @@ def _check_round_trip(self, proc1): bundle2 = plumpy.Bundle(proc2) self.assertEqual(proc1.pid, proc2.pid) - self.assertDictEqual(bundle1, bundle2) + utils.compare_dictionaries(None, None, bundle1, bundle2, exclude={'_listeners'}) class TestProcessNamespace(unittest.TestCase): diff --git a/test/test_workchains.py b/test/test_workchains.py index 71cd0f6a..c698aff9 100644 --- 
a/test/test_workchains.py +++ b/test/test_workchains.py @@ -283,6 +283,46 @@ def test_checkpointing(self): if step not in ['isA', 's2', 'isB', 's3']: self.assertTrue(finished, f'Step {step} was not called by workflow') + def test_listener_persistence(self): + persister = plumpy.InMemoryPersister() + process_finished_count = 0 + + class TestListener(plumpy.ProcessListener): + + def on_process_finished(self, process, output): + nonlocal process_finished_count + process_finished_count += 1 + + class SimpleWorkChain(plumpy.WorkChain): + + @classmethod + def define(cls, spec): + super().define(spec) + spec.outline( + cls.step1, + cls.step2, + ) + + def step1(self): + persister.save_checkpoint(self, 'step1') + + def step2(self): + persister.save_checkpoint(self, 'step2') + + # add SimpleWorkChain and TestListener to this module global namespace, so they can be reloaded from checkpoint + globals()['SimpleWorkChain'] = SimpleWorkChain + globals()['TestListener'] = TestListener + + workchain = SimpleWorkChain() + workchain.add_process_listener(TestListener()) + output = workchain.execute() + + self.assertEqual(process_finished_count, 1) + + workchain_checkpoint = persister.load_checkpoint(workchain.pid, 'step1').unbundle() + workchain_checkpoint.execute() + self.assertEqual(process_finished_count, 2) + def test_return_in_outline(self): class WcWithReturn(WorkChain): diff --git a/test/utils.py b/test/utils.py index feb3d1c8..1f7408f6 100644 --- a/test/utils.py +++ b/test/utils.py @@ -185,7 +185,7 @@ def run(self): self.out('test', 5) return process_states.Continue(self.middle_step) - def middle_step(self,): + def middle_step(self): return process_states.Continue(self.last_step) def last_step(self): @@ -260,25 +260,72 @@ def _save(self, p): self.outputs.append(p.outputs.copy()) -class ProcessSaver(plumpy.ProcessListener, Saver): +_ProcessSaverProcReferences = {} +_ProcessSaver_Saver = {} + + +class ProcessSaver(plumpy.ProcessListener): """ - Save the instance state of a 
process each time it is about to enter a new state + Save the instance state of a process each time it is about to enter a new state. + NB: this is not a general purpose saver, it is only intended to be used for testing + The listener instances inside a process are persisted, so if we store a process + reference in the ProcessSaver instance, we will have a circular reference that cannot be + persisted. So we store the Saver instance in a global dictionary with the key the id of the + ProcessSaver instance. + In the init_not_persistent method we initialize the instances that cannot be persisted, + like the saver instance. The __del__ method is used to clean up the global dictionaries + (note there is no guarantee that __del__ will be called) + """ + def __del__(self): + global _ProcessSaver_Saver + global _ProcessSaverProcReferences + if _ProcessSaverProcReferences is not None and id(self) in _ProcessSaverProcReferences: + del _ProcessSaverProcReferences[id(self)] + if _ProcessSaver_Saver is not None and id(self) in _ProcessSaver_Saver: + del _ProcessSaver_Saver[id(self)] + + def get_process(self): + global _ProcessSaverProcReferences + return _ProcessSaverProcReferences[id(self)] + + def _save(self, p): + global _ProcessSaver_Saver + _ProcessSaver_Saver[id(self)]._save(p) + + def set_process(self, process): + global _ProcessSaverProcReferences + _ProcessSaverProcReferences[id(self)] = process + def __init__(self, proc): - plumpy.ProcessListener.__init__(self) - Saver.__init__(self) - self.process = proc + super().__init__() proc.add_process_listener(self) + self.init_not_persistent(proc) + + def init_not_persistent(self, proc): + global _ProcessSaver_Saver + _ProcessSaver_Saver[id(self)] = Saver() + self.set_process(proc) def capture(self): - self._save(self.process) - if not self.process.has_terminated(): + self._save(self.get_process()) + if not self.get_process().has_terminated(): try: - self.process.execute() + self.get_process().execute() except Exception: 
pass + @property + def snapshots(self): + global _ProcessSaver_Saver + return _ProcessSaver_Saver[id(self)].snapshots + + @property + def outputs(self): + global _ProcessSaver_Saver + return _ProcessSaver_Saver[id(self)].outputs + @utils.override def on_process_running(self, process): self._save(process) @@ -335,7 +382,13 @@ def check_process_against_snapshots(loop, proc_class, snapshots): """ for i, bundle in zip(list(range(0, len(snapshots))), snapshots): loaded = bundle.unbundle(plumpy.LoadSaveContext(loop=loop)) - saver = ProcessSaver(loaded) + # the process listeners are persisted + saver = list(loaded._event_helper._listeners)[0] + assert isinstance(saver, ProcessSaver) + # the process reference inside this particular implementation of process listener + # cannot be persisted because of a circular reference. So we load it there + # also the saver is not persisted for the same reason. We load it manually + saver.init_not_persistent(loaded) saver.capture() # Now check going backwards until running that the saved states match @@ -345,7 +398,11 @@ def check_process_against_snapshots(loop, proc_class, snapshots): break compare_dictionaries( - snapshots[-j], saver.snapshots[-j], snapshots[-j], saver.snapshots[-j], exclude={'exception'} + snapshots[-j], + saver.snapshots[-j], + snapshots[-j], + saver.snapshots[-j], + exclude={'exception', '_listeners'} ) j += 1 @@ -376,6 +433,11 @@ def compare_value(bundle1, bundle2, v1, v2, exclude=None): elif isinstance(v1, list) and isinstance(v2, list): for vv1, vv2 in zip(v1, v2): compare_value(bundle1, bundle2, vv1, vv2, exclude) + elif isinstance(v1, set) and isinstance(v2, set) and len(v1) == len(v2) and len(v1) <= 1: + # TODO: implement sets with more than one element + compare_value(bundle1, bundle2, list(v1), list(v2), exclude) + elif isinstance(v1, set) and isinstance(v2, set): + raise NotImplementedError('Comparison between sets not implemented') else: if v1 != v2: raise ValueError(f'Dict values mismatch for :\n{v1} != 
{v2}') From fda244a139e26472c37f698d7d64f3d6e13ad57b Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Fri, 10 Nov 2023 15:40:02 +0100 Subject: [PATCH 36/45] Catch `ChannelInvalidStateError` in process state change (#278) In `Process.on_entered`, the `Communicator.broadcast_send` method is called to broadcast the state change to subscribers over RabbitMQ. This can throw a `ChannelInvalidStateError` in addition to the `ConnectionClose` exception that was already being caught, in case there is a problem with the connection. Cherry-pick: db2af9acf7c139798a21e574d6308ae21b3b7513 --- src/plumpy/processes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 57e12af5..c9c1a118 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -34,7 +34,7 @@ except ModuleNotFoundError: from contextvars import ContextVar -from aio_pika.exceptions import ConnectionClosed +from aio_pika.exceptions import ChannelInvalidStateError, ConnectionClosed import kiwipy import yaml @@ -718,7 +718,7 @@ def on_entered(self, from_state: Optional[process_states.State]) -> None: self.logger.info('Process<%s>: Broadcasting state change: %s', self.pid, subject) try: self._communicator.broadcast_send(body=None, sender=self.pid, subject=subject) - except ConnectionClosed: + except (ConnectionClosed, ChannelInvalidStateError): message = 'Process<%s>: no connection available to broadcast state change from %s to %s' self.logger.warning(message, self.pid, from_label, self.state.value) except kiwipy.TimeoutError: From 8043a630c50b60d599d3d8e7d85aefbb3ff46ade Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Mon, 13 Nov 2023 10:07:20 +0100 Subject: [PATCH 37/45] Dependencies: Add support for Python 3.12 (#275) Cherry-pick: 2af390738df3f151c8225c01e265527b65d7a005 --- .github/workflows/cd.yml | 2 +- .github/workflows/ci.yml | 2 +- pyproject.toml | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git 
a/.github/workflows/cd.yml b/.github/workflows/cd.yml index 1de6e969..427ccfd5 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -58,7 +58,7 @@ jobs: strategy: matrix: - python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] + python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12'] services: postgres: diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 73dd3710..611473f3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,7 +25,7 @@ jobs: strategy: matrix: - python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] + python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12'] services: rabbitmq: diff --git a/pyproject.toml b/pyproject.toml index f8c79b15..dd070e4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ classifiers = [ 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', ] keywords = ['workflow', 'multithreaded', 'rabbitmq'] requires-python = '>=3.7' @@ -61,7 +62,7 @@ tests = [ 'pytest==6.2.5', 'pytest-asyncio==0.16.0', 'pytest-cov==3.0.0', - 'pytest-notebook==0.8.1', + 'pytest-notebook>=0.8.0', 'shortuuid==1.0.8', 'importlib-resources~=5.2', ] From f7b320e94cb3f7d9c019da1e414079304fbf046b Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Mon, 13 Nov 2023 09:21:22 +0100 Subject: [PATCH 38/45] Docs: Update `CHANGELOG.md` with notes of `v0.21.9` --- CHANGELOG.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f13d214..ba626fc3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,18 @@ This release applies the fixes that were released on the support branch of `v0.2 - Update the `mypy` pre-commit dependency [[#246]](https://github.com/aiidateam/plumpy/pull/246) +## v0.21.9 - 2023-11-10 + +### Features +- Make `ProcessListener` instances persistable 
[[98a375f]](https://github.com/aiidateam/plumpy/commit/98a375f07db0cacaacdc1545d4d12f25dd00bf1d) + +### Fixes +- Catch `ChannelInvalidStateError` in process state change [[db2af9a]](https://github.com/aiidateam/plumpy/commit/db2af9acf7c139798a21e574d6308ae21b3b7513) + +### Devops +- Update ReadTheDocs configuration file [[31f85c7]](https://github.com/aiidateam/plumpy/commit/31f85c71730b488aafd680f240485a51884722b7) + + ## v0.21.8 - 2023-06-07 ### Devops From ff5770f55da9974b693bed5e731c211ee47c39cd Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Mon, 13 Nov 2023 09:21:40 +0100 Subject: [PATCH 39/45] Docs: Update `CHANGELOG.md` with notes of `v0.21.10` --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ba626fc3..0f675534 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,13 @@ This release applies the fixes that were released on the support branch of `v0.2 - Update the `mypy` pre-commit dependency [[#246]](https://github.com/aiidateam/plumpy/pull/246) +## v0.21.10 - 2023-11-13 + +### Dependencies + +- Dependencies: Add support for Python 3.12 [[2af3907]](https://github.com/aiidateam/plumpy/commit/2af390738df3f151c8225c01e265527b65d7a005) + + ## v0.21.9 - 2023-11-10 ### Features From 14b7c1a83026699fcc9f01c56a85872a8dd3c25c Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Fri, 2 Feb 2024 15:44:23 +0100 Subject: [PATCH 40/45] Dependencies: Drop support for Python 3.7 (#282) --- .github/workflows/cd.yml | 2 +- .github/workflows/ci.yml | 2 +- pyproject.toml | 3 +-- src/plumpy/events.py | 2 -- src/plumpy/process_states.py | 6 ------ src/plumpy/processes.py | 6 ------ 6 files changed, 3 insertions(+), 18 deletions(-) diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index 427ccfd5..656c8d7e 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -58,7 +58,7 @@ jobs: strategy: matrix: - python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12'] + python-version: 
['3.8', '3.9', '3.10', '3.11', '3.12'] services: postgres: diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 611473f3..f324ec7f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,7 +25,7 @@ jobs: strategy: matrix: - python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12'] + python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] services: rabbitmq: diff --git a/pyproject.toml b/pyproject.toml index dd070e4a..33944c33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,6 @@ classifiers = [ 'License :: OSI Approved :: MIT License', 'License :: OSI Approved :: GNU General Public License v3 (GPLv3)', 'Programming Language :: Python', - 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', @@ -28,7 +27,7 @@ classifiers = [ 'Programming Language :: Python :: 3.12', ] keywords = ['workflow', 'multithreaded', 'rabbitmq'] -requires-python = '>=3.7' +requires-python = '>=3.8' dependencies = [ 'kiwipy[rmq]~=0.8.3', 'nest_asyncio~=1.5,>=1.5.1', diff --git a/src/plumpy/events.py b/src/plumpy/events.py index 735fb3f6..60a5306e 100644 --- a/src/plumpy/events.py +++ b/src/plumpy/events.py @@ -55,8 +55,6 @@ def reset_event_loop_policy() -> None: cls = loop.__class__ del cls._check_running # type: ignore - # typo in Python 3.7 source - del cls._check_runnung # type: ignore del cls._nest_patched # type: ignore # pylint: enable=protected-access diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index da91b506..d962eb77 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- -import asyncio from enum import Enum import sys import traceback @@ -231,11 +230,6 @@ async def execute(self) -> State: # type: ignore # pylint: disable=invalid-over except Interruption: # Let this bubble up to the caller raise - except asyncio.CancelledError: # 
pylint: disable=try-except-raise - # note this re-raise is only required in python<=3.7, - # for python>=3.8 asyncio.CancelledError does not inherit from Exception, - # so will not be caught below - raise except Exception: # pylint: disable=broad-except excepted = self.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:]) return cast(State, excepted) diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index c9c1a118..b6e14ad9 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -1235,12 +1235,6 @@ async def step(self) -> None: except KeyboardInterrupt: # pylint: disable=try-except-raise raise - except asyncio.CancelledError: # pylint: disable=try-except-raise - # note this re-raise is only required in python<=3.7, - # where asyncio.CancelledError == concurrent.futures.CancelledError - # it is encountered when the run_task is cancelled - # for python>=3.8 asyncio.CancelledError does not inherit from Exception, so will not be caught below - raise except Exception: # pylint: disable=broad-except # Overwrite the next state to go to excepted directly next_state = self.create_state(process_states.ProcessState.EXCEPTED, *sys.exc_info()[1:]) From c180f5260d8229ef46a8480000f8b0658c48e70b Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Fri, 2 Feb 2024 15:37:47 +0100 Subject: [PATCH 41/45] Release `v0.22.3` --- CHANGELOG.md | 14 ++++++++++++++ src/plumpy/__init__.py | 2 +- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0f675534..aff0ad12 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,19 @@ # Changelog +## v0.22.3 - 2024-02-02 + +### Bug fixes +- Catch `ChannelInvalidStateError` in process state change [[fda244a]](https://github.com/aiidateam/plumpy/commit/fda244a139e26472c37f698d7d64f3d6e13ad57b) +- Make `ProcessListener` instances persistable [[fc66001]](https://github.com/aiidateam/plumpy/commit/fc660017e4d21a1122dc1236e4ade2f54d146500) + +### Dependencies +- Add support for 
Python 3.12 [[8043a63]](https://github.com/aiidateam/plumpy/commit/8043a630c50b60d599d3d8e7d85aefbb3ff46ade) +- Drop support for Python 3.7 [[14b7c1a]](https://github.com/aiidateam/plumpy/commit/14b7c1a83026699fcc9f01c56a85872a8dd3c25c) + +### Devops +- Update ReadTheDocs configuration file [[5ddba0f]](https://github.com/aiidateam/plumpy/commit/5ddba0f0136994a4703d507d7de98062c26928e7) + + ## v0.22.2 - 2023-06-23 This release applies the fixes that were released on the support branch of `v0.21.x`. diff --git a/src/plumpy/__init__.py b/src/plumpy/__init__.py index a22deecc..ea88f872 100644 --- a/src/plumpy/__init__.py +++ b/src/plumpy/__init__.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # mypy: disable-error-code=name-defined # pylint: disable=undefined-variable -__version__ = '0.22.2' +__version__ = '0.22.3' import logging From 20e5898e0c9037624988fe321e784f4fe38a2e8d Mon Sep 17 00:00:00 2001 From: Sebastiano Bisacchi <33641204+sebaB003@users.noreply.github.com> Date: Mon, 24 Jun 2024 11:20:53 +0200 Subject: [PATCH 42/45] Make `Waiting.resume()` idempotent (#285) Calling `Waiting.resume()` when it had already been resumed would raise an exception. Here, the method is made idempotent by checking first whether the future has already been resolved. This fix ensures the behavior matches the behavior of the other state transitions: calling `play` on an already running process and calling `pause` on an already paused process isn't raising any error.
--- src/plumpy/process_states.py | 4 ++++ test/rmq/docker-compose.yml | 11 +++++------ test/test_processes.py | 25 +++++++++++++++++++++++++ 3 files changed, 34 insertions(+), 6 deletions(-) diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index d962eb77..3407412d 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -330,6 +330,10 @@ async def execute(self) -> State: # type: ignore # pylint: disable=invalid-over def resume(self, value: Any = NULL) -> None: assert self._waiting_future is not None, 'Not yet waiting' + + if self._waiting_future.done(): + return + self._waiting_future.set_result(value) diff --git a/test/rmq/docker-compose.yml b/test/rmq/docker-compose.yml index 6e743f7d..456690e0 100644 --- a/test/rmq/docker-compose.yml +++ b/test/rmq/docker-compose.yml @@ -12,16 +12,15 @@ version: '3.4' services: - rabbit: - image: rabbitmq:3.8.3-management - container_name: plumpy-rmq + image: rabbitmq:3-management-alpine + container_name: plumpy_rmq + ports: + - 5672:5672 + - 15672:15672 environment: RABBITMQ_DEFAULT_USER: guest RABBITMQ_DEFAULT_PASS: guest - ports: - - '5672:5672' - - '15672:15672' healthcheck: test: rabbitmq-diagnostics -q ping interval: 30s diff --git a/test/test_processes.py b/test/test_processes.py index 158cce66..0cb4161b 100644 --- a/test/test_processes.py +++ b/test/test_processes.py @@ -836,6 +836,31 @@ async def async_test(): loop.create_task(proc.step_until_terminated()) loop.run_until_complete(async_test()) + def test_double_restart(self): + """Test that consecutive restarts do not cause any issues, this is tested for concurrency reasons.""" + loop = asyncio.get_event_loop() + proc = _RestartProcess() + + async def async_test(): + await utils.run_until_waiting(proc) + + # Save the state of the process + saved_state = plumpy.Bundle(proc) + + # Load a process from the saved state + loaded_proc = saved_state.unbundle() + self.assertEqual(loaded_proc.state, ProcessState.WAITING) + + # Now 
resume it twice in succession + loaded_proc.resume() + loaded_proc.resume() + + await loaded_proc.step_until_terminated() + self.assertEqual(loaded_proc.outputs, {'finished': True}) + + loop.create_task(proc.step_until_terminated()) + loop.run_until_complete(async_test()) + def test_wait_save_continue(self): """ Test that process saved while in WAITING state restarts correctly when loaded """ loop = asyncio.get_event_loop() From b3837fc9dbf7dc5aca0785e93b94cf5b89d04a91 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Mon, 1 Jul 2024 21:17:53 +0200 Subject: [PATCH 43/45] Docs: Update `CHANGELOG.md` with notes of `v0.21.11` --- CHANGELOG.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index aff0ad12..fb4f1efe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -52,6 +52,15 @@ This release applies the fixes that were released on the support branch of `v0.2 - Update the `mypy` pre-commit dependency [[#246]](https://github.com/aiidateam/plumpy/pull/246) +## v0.21.11 - 2024-07-01 + +### Fixes +- Make `Waiting.resume()` idempotent [[a79497b]](https://github.com/aiidateam/plumpy/commit/a79497ba37cef7bc609cee90535ad86708fc48f9) + +### Dependencies +- Add requirement `nbdime<4` [[94df0df]](https://github.com/aiidateam/plumpy/commit/94df0dfd0a3ea93174aa4de83ac5e06246350c27) + + ## v0.21.10 - 2023-11-13 ### Dependencies From 3b9318c67e543a03c85ed280e453db4c385fdefd Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 2 Dec 2024 08:46:55 +0100 Subject: [PATCH 44/45] Dependencies: Add `type-extensions` to direct dependencies (#293) This is a temporary workaround for `aio-pika` breaking for v4.5.0 See https://github.com/mosquito/aio-pika/issues/649 --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 33944c33..78c34807 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,8 @@ dependencies = [ 'kiwipy[rmq]~=0.8.3', 'nest_asyncio~=1.5,>=1.5.1', 'pyyaml~=6.0', + # XXX: 
workaround for https://github.com/mosquito/aio-pika/issues/649 + 'typing-extensions~=4.12', ] [project.urls] From 55e05e956c9715fb69785d83d0194b65811b4720 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 2 Dec 2024 09:17:12 +0100 Subject: [PATCH 45/45] Devops: Switch to ruff and other devops improvements (#289) * Bump version for pytest and pytest tools * Use ruff with aligning to aiida-core ruff config * Update mypy to 1.13.0 * Remove all pylint ignore annotations * Bump ruff to 0.8.0 for sort the imports * Update python version to 3.12 for pre-commit CI * Double quotes for eval inside single quotes f-strings * Exclude tests/.* for ruff linting * Do not fail-fast so tests for different python version independent --- .github/workflows/cd.yml | 2 +- .github/workflows/ci.yml | 11 +- .github/workflows/validate_release_tag.py | 8 +- .pre-commit-config.yaml | 61 +++--- docs/source/conf.py | 32 ++- docs/source/tutorial.ipynb | 203 ++++++++---------- examples/process_helloworld.py | 1 - examples/process_wait_and_resume.py | 7 +- examples/workchain_simple.py | 1 - pyproject.toml | 46 +++- src/plumpy/__init__.py | 19 +- src/plumpy/base/__init__.py | 4 +- src/plumpy/base/state_machine.py | 60 +++--- src/plumpy/base/utils.py | 2 +- src/plumpy/communications.py | 23 +- src/plumpy/event_helper.py | 5 +- src/plumpy/events.py | 20 +- src/plumpy/exceptions.py | 2 +- src/plumpy/futures.py | 13 +- src/plumpy/lang.py | 9 +- src/plumpy/loaders.py | 6 +- src/plumpy/mixins.py | 1 + src/plumpy/persistence.py | 43 ++-- src/plumpy/ports.py | 71 +++--- src/plumpy/process_comms.py | 64 +++--- src/plumpy/process_listener.py | 3 +- src/plumpy/process_spec.py | 13 +- src/plumpy/process_states.py | 55 ++--- src/plumpy/processes.py | 83 +++---- src/plumpy/settings.py | 1 - src/plumpy/utils.py | 31 +-- src/plumpy/workchains.py | 79 +++---- {test => tests}/__init__.py | 0 {test => tests}/base/__init__.py | 0 {test => tests}/base/test_statemachine.py | 14 +- {test => tests}/base/test_utils.py | 6 
- {test => tests}/conftest.py | 1 + .../notebooks/get_event_loop.ipynb | 4 +- {test => tests}/persistence/__init__.py | 0 {test => tests}/persistence/test_inmemory.py | 31 +-- {test => tests}/persistence/test_pickle.py | 28 +-- {test => tests}/rmq/__init__.py | 0 {test => tests}/rmq/docker-compose.yml | 0 {test => tests}/rmq/test_communicator.py | 39 ++-- {test => tests}/rmq/test_process_comms.py | 12 +- {test => tests}/test_communications.py | 4 +- {test => tests}/test_events.py | 1 + {test => tests}/test_expose.py | 146 +++++-------- {test => tests}/test_lang.py | 19 +- {test => tests}/test_loaders.py | 7 +- {test => tests}/test_persistence.py | 6 +- {test => tests}/test_port.py | 16 +- {test => tests}/test_process_comms.py | 12 +- {test => tests}/test_process_spec.py | 2 - {test => tests}/test_processes.py | 81 +++---- {test => tests}/test_utils.py | 6 - {test => tests}/test_waiting_process.py | 1 - {test => tests}/test_workchains.py | 87 ++++---- {test => tests}/utils.py | 51 ++--- 59 files changed, 712 insertions(+), 841 deletions(-) rename {test => tests}/__init__.py (100%) rename {test => tests}/base/__init__.py (100%) rename {test => tests}/base/test_statemachine.py (90%) rename {test => tests}/base/test_utils.py (99%) rename {test => tests}/conftest.py (99%) rename {test => tests}/notebooks/get_event_loop.ipynb (90%) rename {test => tests}/persistence/__init__.py (100%) rename {test => tests}/persistence/test_inmemory.py (88%) rename {test => tests}/persistence/test_pickle.py (89%) rename {test => tests}/rmq/__init__.py (100%) rename {test => tests}/rmq/docker-compose.yml (100%) rename {test => tests}/rmq/test_communicator.py (90%) rename {test => tests}/rmq/test_process_comms.py (97%) rename {test => tests}/test_communications.py (100%) rename {test => tests}/test_events.py (99%) rename {test => tests}/test_expose.py (84%) rename {test => tests}/test_lang.py (97%) rename {test => tests}/test_loaders.py (97%) rename {test => tests}/test_persistence.py 
(97%) rename {test => tests}/test_port.py (99%) rename {test => tests}/test_process_comms.py (87%) rename {test => tests}/test_process_spec.py (99%) rename {test => tests}/test_processes.py (97%) rename {test => tests}/test_utils.py (99%) rename {test => tests}/test_waiting_process.py (99%) rename {test => tests}/test_workchains.py (93%) rename {test => tests}/utils.py (93%) diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index 656c8d7e..ad524fb2 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -88,7 +88,7 @@ jobs: run: pip install -e .[tests] - name: Run pytest - run: pytest -sv --cov=plumpy test + run: pytest -s --cov=plumpy tests - name: Create xml coverage run: coverage xml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f324ec7f..8d813780 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,4 +1,4 @@ -name: continuous-integration +name: ci on: [push, pull_request] @@ -9,10 +9,10 @@ jobs: steps: - uses: actions/checkout@v2 - - name: Set up Python 3.8 + - name: Set up Python 3.12 uses: actions/setup-python@v2 with: - python-version: '3.8' + python-version: '3.12' - name: Install Python dependencies run: pip install -e .[pre-commit] @@ -26,6 +26,7 @@ jobs: strategy: matrix: python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] + fail-fast: false services: rabbitmq: @@ -42,10 +43,10 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install python dependencies - run: pip install -e .[tests] + run: pip install .[tests] - name: Run pytest - run: pytest -sv --cov=plumpy test + run: pytest -s --cov=plumpy tests/ - name: Create xml coverage run: coverage xml diff --git a/.github/workflows/validate_release_tag.py b/.github/workflows/validate_release_tag.py index bdd35537..4caf68b8 100644 --- a/.github/workflows/validate_release_tag.py +++ b/.github/workflows/validate_release_tag.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """Validate that the version in the tag label matches 
the version of the package.""" + import argparse import ast from pathlib import Path @@ -17,8 +18,11 @@ def get_version_from_module(content: str) -> str: try: return next( - ast.literal_eval(statement.value) for statement in module.body if isinstance(statement, ast.Assign) - for target in statement.targets if isinstance(target, ast.Name) and target.id == '__version__' + ast.literal_eval(statement.value) + for statement in module.body + if isinstance(statement, ast.Assign) + for target in statement.targets + if isinstance(target, ast.Name) and target.id == '__version__' ) except StopIteration as exception: raise IOError('Unable to find the `__version__` attribute in the module.') from exception diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cae9888f..970f9d71 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,46 +1,35 @@ repos: -- repo: https://github.com/pre-commit/pre-commit-hooks + - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: - - id: double-quote-string-fixer - - id: end-of-file-fixer - - id: fix-encoding-pragma - - id: mixed-line-ending - - id: trailing-whitespace + - id: double-quote-string-fixer + - id: end-of-file-fixer + - id: fix-encoding-pragma + - id: mixed-line-ending + - id: trailing-whitespace -- repo: https://github.com/ikamensh/flynt/ - rev: '0.77' + - repo: https://github.com/ikamensh/flynt/ + rev: 1.0.1 hooks: - - id: flynt + - id: flynt + args: [--line-length=120, --fail-on-change] -- repo: https://github.com/pycqa/isort - rev: '5.12.0' + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.8.0 hooks: - - id: isort + - id: ruff-format + exclude: &exclude_ruff > + (?x)^( + tests/.*| + )$ -- repo: https://github.com/pre-commit/mirrors-yapf - rev: v0.32.0 - hooks: - - id: yapf - name: yapf - types: [python] - args: ['-i'] - additional_dependencies: ['toml'] - -- repo: https://github.com/PyCQA/pylint - rev: v2.15.8 - hooks: - - id: pylint - language: system - exclude: > 
- (?x)^( - docs/source/conf.py| - test/.*| - )$ + - id: ruff + exclude: *exclude_ruff + args: [--fix, --exit-non-zero-on-fix, --show-fixes] -- repo: local + - repo: local hooks: - - id: mypy + - id: mypy name: mypy entry: mypy args: [--config-file=pyproject.toml] @@ -49,6 +38,6 @@ repos: require_serial: true pass_filenames: true files: >- - (?x)^( - src/.*py| - )$ + (?x)^( + src/.*py| + )$ diff --git a/docs/source/conf.py b/docs/source/conf.py index a1c6f26e..b1a2a019 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -8,11 +8,9 @@ import filecmp import os -from pathlib import Path import shutil -import subprocess -import sys import tempfile +from pathlib import Path import plumpy @@ -32,8 +30,12 @@ master_doc = 'index' language = None extensions = [ - 'myst_nb', 'sphinx.ext.autodoc', 'sphinx.ext.doctest', 'sphinx.ext.viewcode', 'sphinx.ext.intersphinx', - 'IPython.sphinxext.ipython_console_highlighting' + 'myst_nb', + 'sphinx.ext.autodoc', + 'sphinx.ext.doctest', + 'sphinx.ext.viewcode', + 'sphinx.ext.intersphinx', + 'IPython.sphinxext.ipython_console_highlighting', ] # List of patterns, relative to source directory, that match files and @@ -46,14 +48,14 @@ intersphinx_mapping = { 'python': ('https://docs.python.org/3.8', None), - 'kiwipy': ('https://kiwipy.readthedocs.io/en/latest/', None) + 'kiwipy': ('https://kiwipy.readthedocs.io/en/latest/', None), } myst_enable_extensions = ['colon_fence', 'deflist', 'html_image', 'smartquotes', 'substitution'] myst_url_schemes = ('http', 'https', 'mailto') myst_substitutions = { 'rabbitmq': '[RabbitMQ](https://www.rabbitmq.com/)', - 'kiwipy': '[kiwipy](https://kiwipy.readthedocs.io)' + 'kiwipy': '[kiwipy](https://kiwipy.readthedocs.io)', } jupyter_execute_notebooks = 'cache' execution_show_tb = 'READTHEDOCS' in os.environ @@ -84,7 +86,7 @@ 'use_issues_button': True, 'path_to_docs': 'docs', 'use_edit_page_button': True, - 'extra_navbar': '' + 'extra_navbar': '', } # API Documentation @@ -112,9 +114,17 @@ def 
run_apidoc(app): # this ensures that document rebuilds are not triggered every time (due to change in file mtime) with tempfile.TemporaryDirectory() as tmpdirname: options = [ - '-o', tmpdirname, - str(package_dir), '--private', '--force', '--module-first', '--separate', '--no-toc', '--maxdepth', '4', - '-q' + '-o', + tmpdirname, + str(package_dir), + '--private', + '--force', + '--module-first', + '--separate', + '--no-toc', + '--maxdepth', + '4', + '-q', ] os.environ['SPHINX_APIDOC_OPTIONS'] = 'members,special-members,private-members,undoc-members,show-inheritance' diff --git a/docs/source/tutorial.ipynb b/docs/source/tutorial.ipynb index 90194728..c1fdb3b2 100644 --- a/docs/source/tutorial.ipynb +++ b/docs/source/tutorial.ipynb @@ -34,10 +34,11 @@ "outputs": [], "source": [ "import asyncio\n", - "from pprint import pprint\n", "import time\n", + "from pprint import pprint\n", "\n", "import kiwipy\n", + "\n", "import plumpy\n", "\n", "# this is required because jupyter is already running an event loop\n", @@ -116,16 +117,16 @@ ], "source": [ "class SimpleProcess(plumpy.Process):\n", - "\n", " def run(self):\n", " print(self.state.name)\n", - " \n", + "\n", + "\n", "process = SimpleProcess()\n", "print(process.state.name)\n", "process.execute()\n", "print(process.state.name)\n", - "print(\"Success\", process.is_successful)\n", - "print(\"Result\", process.result())" + "print('Success', process.is_successful)\n", + "print('Result', process.result())" ] }, { @@ -204,17 +205,16 @@ ], "source": [ "class SpecProcess(plumpy.Process):\n", - " \n", " @classmethod\n", " def define(cls, spec: plumpy.ProcessSpec):\n", " super().define(spec)\n", " spec.input('input1', valid_type=str, help='A help string')\n", " spec.output('output1')\n", - " \n", + "\n", " spec.input_namespace('input2')\n", " spec.input('input2.input2a')\n", " spec.input('input2.input2b', default='default')\n", - " \n", + "\n", " spec.output_namespace('output2')\n", " spec.output('output2.output2a')\n", " 
spec.output('output2.output2b')\n", @@ -223,12 +223,10 @@ " self.out('output1', self.inputs.input1)\n", " self.out('output2.output2a', self.inputs.input2.input2a)\n", " self.out('output2.output2b', self.inputs.input2.input2b)\n", - " \n", + "\n", + "\n", "pprint(SpecProcess.spec().get_description())\n", - "process = SpecProcess(inputs={\n", - " 'input1': 'my input',\n", - " 'input2': {'input2a': 'other input'}\n", - "})\n", + "process = SpecProcess(inputs={'input1': 'my input', 'input2': {'input2a': 'other input'}})\n", "process.execute()\n", "process.outputs" ] @@ -276,20 +274,20 @@ ], "source": [ "class ContinueProcess(plumpy.Process):\n", - "\n", " def run(self):\n", - " print(\"running\")\n", + " print('running')\n", " return plumpy.Continue(self.continue_fn)\n", - " \n", + "\n", " def continue_fn(self):\n", - " print(\"continuing\")\n", + " print('continuing')\n", " # message is stored in the process status\n", - " return plumpy.Kill(\"I was killed\")\n", - " \n", + " return plumpy.Kill('I was killed')\n", + "\n", + "\n", "process = ContinueProcess()\n", "try:\n", " process.execute()\n", - "except plumpy.KilledError as error:\n", + "except plumpy.KilledError:\n", " pass\n", "\n", "print(process.state)\n", @@ -330,7 +328,6 @@ ], "source": [ "class WaitListener(plumpy.ProcessListener):\n", - "\n", " def on_process_running(self, process):\n", " print(process.state.name)\n", "\n", @@ -338,14 +335,15 @@ " print(process.state.name)\n", " process.resume()\n", "\n", - "class WaitProcess(plumpy.Process):\n", "\n", + "class WaitProcess(plumpy.Process):\n", " def run(self):\n", " return plumpy.Wait(self.resume_fn)\n", - " \n", + "\n", " def resume_fn(self):\n", " return plumpy.Stop(None, True)\n", "\n", + "\n", "process = WaitProcess()\n", "print(process.state.name)\n", "\n", @@ -394,33 +392,32 @@ ], "source": [ "async def async_fn():\n", - " print(\"async_fn start\")\n", - " await asyncio.sleep(.01)\n", - " print(\"async_fn end\")\n", + " print('async_fn start')\n", + " 
await asyncio.sleep(0.01)\n", + " print('async_fn end')\n", + "\n", "\n", "class NamedProcess(plumpy.Process):\n", - " \n", " @classmethod\n", " def define(cls, spec: plumpy.ProcessSpec):\n", " super().define(spec)\n", " spec.input('name')\n", "\n", " def run(self):\n", - " print(self.inputs.name, \"run\")\n", + " print(self.inputs.name, 'run')\n", " return plumpy.Continue(self.continue_fn)\n", "\n", " def continue_fn(self):\n", - " print(self.inputs.name, \"continued\")\n", + " print(self.inputs.name, 'continued')\n", + "\n", + "\n", + "process1 = NamedProcess({'name': 'process1'})\n", + "process2 = NamedProcess({'name': 'process2'})\n", "\n", - "process1 = NamedProcess({\"name\": \"process1\"})\n", - "process2 = NamedProcess({\"name\": \"process2\"})\n", "\n", "async def execute():\n", - " await asyncio.gather(\n", - " async_fn(),\n", - " process1.step_until_terminated(),\n", - " process2.step_until_terminated()\n", - " )\n", + " await asyncio.gather(async_fn(), process1.step_until_terminated(), process2.step_until_terminated())\n", + "\n", "\n", "plumpy.get_event_loop().run_until_complete(execute())" ] @@ -468,31 +465,33 @@ ], "source": [ "class SimpleProcess(plumpy.Process):\n", - " \n", " def run(self):\n", " print(self.get_name())\n", - " \n", - "class PauseProcess(plumpy.Process):\n", "\n", + "\n", + "class PauseProcess(plumpy.Process):\n", " def run(self):\n", - " print(f\"{self.get_name()}: pausing\")\n", + " print(f'{self.get_name()}: pausing')\n", " self.pause()\n", - " print(f\"{self.get_name()}: continue step\")\n", + " print(f'{self.get_name()}: continue step')\n", " return plumpy.Continue(self.next_step)\n", - " \n", + "\n", " def next_step(self):\n", - " print(f\"{self.get_name()}: next step\")\n", + " print(f'{self.get_name()}: next step')\n", + "\n", "\n", "pause_proc = PauseProcess()\n", "simple_proc = SimpleProcess()\n", "\n", + "\n", "async def play(proc):\n", " while True:\n", " if proc.paused:\n", - " print(f\"{proc.get_name()}: playing 
(state={proc.state.name})\")\n", + " print(f'{proc.get_name()}: playing (state={proc.state.name})')\n", " proc.play()\n", " break\n", "\n", + "\n", "async def execute():\n", " return await asyncio.gather(\n", " pause_proc.step_until_terminated(),\n", @@ -500,6 +499,7 @@ " play(pause_proc),\n", " )\n", "\n", + "\n", "outputs = plumpy.get_event_loop().run_until_complete(execute())" ] }, @@ -555,7 +555,8 @@ "\n", " def step2(self):\n", " print('step2')\n", - " \n", + "\n", + "\n", "workchain = SimpleWorkChain()\n", "output = workchain.execute()" ] @@ -601,11 +602,7 @@ " super().define(spec)\n", " spec.input('run', valid_type=bool)\n", "\n", - " spec.outline(\n", - " plumpy.if_(cls.if_step)(\n", - " cls.conditional_step\n", - " )\n", - " )\n", + " spec.outline(plumpy.if_(cls.if_step)(cls.conditional_step))\n", "\n", " def if_step(self):\n", " print(' if')\n", @@ -613,12 +610,13 @@ "\n", " def conditional_step(self):\n", " print(' conditional')\n", - " \n", - "workchain = IfWorkChain({\"run\": False})\n", + "\n", + "\n", + "workchain = IfWorkChain({'run': False})\n", "print('execute False')\n", "output = workchain.execute()\n", "\n", - "workchain = IfWorkChain({\"run\": True})\n", + "workchain = IfWorkChain({'run': True})\n", "print('execute True')\n", "output = workchain.execute()" ] @@ -666,23 +664,19 @@ " super().define(spec)\n", " spec.input('steps', valid_type=int, default=3)\n", "\n", - " spec.outline(\n", - " cls.init_step,\n", - " plumpy.while_(cls.while_step)(\n", - " cls.conditional_step\n", - " )\n", - " )\n", - " \n", + " spec.outline(cls.init_step, plumpy.while_(cls.while_step)(cls.conditional_step))\n", + "\n", " def init_step(self):\n", " self.ctx.iterator = 0\n", "\n", " def while_step(self):\n", " self.ctx.iterator += 1\n", - " return (self.ctx.iterator <= self.inputs.steps)\n", + " return self.ctx.iterator <= self.inputs.steps\n", "\n", " def conditional_step(self):\n", " print('step', self.ctx.iterator)\n", - " \n", + "\n", + "\n", "workchain = 
WhileWorkChain()\n", "output = workchain.execute()" ] @@ -714,13 +708,12 @@ "outputs": [], "source": [ "async def awaitable_func(msg):\n", - " await asyncio.sleep(.01)\n", + " await asyncio.sleep(0.01)\n", " print(msg)\n", " return True\n", - " \n", "\n", - "class InternalProcess(plumpy.Process):\n", "\n", + "class InternalProcess(plumpy.Process):\n", " @classmethod\n", " def define(cls, spec):\n", " super().define(spec)\n", @@ -733,7 +726,6 @@ "\n", "\n", "class InterstepWorkChain(plumpy.WorkChain):\n", - "\n", " @classmethod\n", " def define(cls, spec):\n", " super().define(spec)\n", @@ -745,31 +737,24 @@ " cls.step2,\n", " cls.step3,\n", " )\n", - " \n", + "\n", " def step1(self):\n", " print(self.inputs.name, 'step1')\n", "\n", " def step2(self):\n", " print(self.inputs.name, 'step2')\n", - " time.sleep(.01)\n", - " \n", + " time.sleep(0.01)\n", + "\n", " if self.inputs.awaitable:\n", " self.to_context(\n", - " awaitable=asyncio.ensure_future(\n", - " awaitable_func(f'{self.inputs.name} step2 awaitable'),\n", - " loop=self.loop\n", - " )\n", + " awaitable=asyncio.ensure_future(awaitable_func(f'{self.inputs.name} step2 awaitable'), loop=self.loop)\n", " )\n", " if self.inputs.process:\n", - " self.to_context(\n", - " process=self.launch(\n", - " InternalProcess, \n", - " inputs={'name': f'{self.inputs.name} step2 process'})\n", - " )\n", + " self.to_context(process=self.launch(InternalProcess, inputs={'name': f'{self.inputs.name} step2 process'}))\n", "\n", " def step3(self):\n", " print(self.inputs.name, 'step3')\n", - " print(f\" ctx={self.ctx}\")" + " print(f' ctx={self.ctx}')" ] }, { @@ -803,11 +788,10 @@ "wkchain1 = InterstepWorkChain({'name': 'wkchain1'})\n", "wkchain2 = InterstepWorkChain({'name': 'wkchain2'})\n", "\n", + "\n", "async def execute():\n", - " return await asyncio.gather(\n", - " wkchain1.step_until_terminated(),\n", - " wkchain2.step_until_terminated()\n", - " )\n", + " return await asyncio.gather(wkchain1.step_until_terminated(), 
wkchain2.step_until_terminated())\n", + "\n", "\n", "output = plumpy.get_event_loop().run_until_complete(execute())" ] @@ -847,11 +831,10 @@ "wkchain1 = InterstepWorkChain({'name': 'wkchain1', 'process': True})\n", "wkchain2 = InterstepWorkChain({'name': 'wkchain2', 'process': True})\n", "\n", + "\n", "async def execute():\n", - " return await asyncio.gather(\n", - " wkchain1.step_until_terminated(),\n", - " wkchain2.step_until_terminated()\n", - " )\n", + " return await asyncio.gather(wkchain1.step_until_terminated(), wkchain2.step_until_terminated())\n", + "\n", "\n", "output = plumpy.get_event_loop().run_until_complete(execute())" ] @@ -882,11 +865,10 @@ "wkchain1 = InterstepWorkChain({'name': 'wkchain1', 'awaitable': True})\n", "wkchain2 = InterstepWorkChain({'name': 'wkchain2', 'awaitable': True})\n", "\n", + "\n", "async def execute():\n", - " return await asyncio.gather(\n", - " wkchain1.step_until_terminated(),\n", - " wkchain2.step_until_terminated()\n", - " )\n", + " return await asyncio.gather(wkchain1.step_until_terminated(), wkchain2.step_until_terminated())\n", + "\n", "\n", "output = plumpy.get_event_loop().run_until_complete(execute())" ] @@ -926,11 +908,10 @@ "wkchain1 = InterstepWorkChain({'name': 'wkchain1', 'process': True, 'awaitable': True})\n", "wkchain2 = InterstepWorkChain({'name': 'wkchain2', 'process': True, 'awaitable': True})\n", "\n", + "\n", "async def execute():\n", - " return await asyncio.gather(\n", - " wkchain1.step_until_terminated(),\n", - " wkchain2.step_until_terminated()\n", - " )\n", + " return await asyncio.gather(wkchain1.step_until_terminated(), wkchain2.step_until_terminated())\n", + "\n", "\n", "output = plumpy.get_event_loop().run_until_complete(execute())" ] @@ -972,8 +953,8 @@ "source": [ "persister = plumpy.InMemoryPersister()\n", "\n", - "class PersistWorkChain(plumpy.WorkChain):\n", "\n", + "class PersistWorkChain(plumpy.WorkChain):\n", " @classmethod\n", " def define(cls, spec):\n", " super().define(spec)\n", @@ 
-982,10 +963,10 @@ " cls.step2,\n", " cls.step3,\n", " )\n", - " \n", + "\n", " def __repr__(self):\n", - " return f\"PersistWorkChain(ctx={self.ctx})\"\n", - " \n", + " return f'PersistWorkChain(ctx={self.ctx})'\n", + "\n", " def init_step(self):\n", " self.ctx.step = 1\n", " persister.save_checkpoint(self, 'init')\n", @@ -997,7 +978,8 @@ " def step3(self):\n", " self.ctx.step += 1\n", " persister.save_checkpoint(self, 'step3')\n", - " \n", + "\n", + "\n", "workchain = PersistWorkChain()\n", "workchain.execute()\n", "workchain" @@ -1129,9 +1111,11 @@ "source": [ "communicator = kiwipy.LocalCommunicator()\n", "\n", + "\n", "class SimpleProcess(plumpy.Process):\n", " pass\n", "\n", + "\n", "process = SimpleProcess(communicator=communicator)\n", "\n", "pprint(communicator.rpc_send(str(process.pid), plumpy.STATUS_MSG).result())" @@ -1161,43 +1145,42 @@ ], "source": [ "class ControlledWorkChain(plumpy.WorkChain):\n", - "\n", " @classmethod\n", " def define(cls, spec):\n", " super().define(spec)\n", " spec.input('steps', valid_type=int, default=10)\n", " spec.output('result', valid_type=int)\n", "\n", - " spec.outline(\n", - " cls.init_step,\n", - " plumpy.while_(cls.while_step)(cls.loop_step),\n", - " cls.final_step\n", - " )\n", - " \n", + " spec.outline(cls.init_step, plumpy.while_(cls.while_step)(cls.loop_step), cls.final_step)\n", + "\n", " def init_step(self):\n", " self.ctx.iterator = 0\n", "\n", " def while_step(self):\n", - " return (self.ctx.iterator <= self.inputs.steps)\n", - " \n", + " return self.ctx.iterator <= self.inputs.steps\n", + "\n", " def loop_step(self):\n", " self.ctx.iterator += 1\n", "\n", " def final_step(self):\n", " self.out('result', self.ctx.iterator)\n", "\n", + "\n", "loop_communicator = plumpy.wrap_communicator(kiwipy.LocalCommunicator())\n", "loop_communicator.add_task_subscriber(plumpy.ProcessLauncher())\n", "controller = plumpy.RemoteProcessController(loop_communicator)\n", "\n", "wkchain = 
ControlledWorkChain(communicator=loop_communicator)\n", - " \n", + "\n", + "\n", "async def run_wait():\n", " return await controller.launch_process(ControlledWorkChain)\n", "\n", + "\n", "async def run_nowait():\n", " return await controller.launch_process(ControlledWorkChain, nowait=True)\n", "\n", + "\n", "print(plumpy.get_event_loop().run_until_complete(run_wait()))\n", "print(plumpy.get_event_loop().run_until_complete(run_nowait()))" ] diff --git a/examples/process_helloworld.py b/examples/process_helloworld.py index cf043eba..db2eff0f 100644 --- a/examples/process_helloworld.py +++ b/examples/process_helloworld.py @@ -3,7 +3,6 @@ class HelloWorld(plumpy.Process): - @classmethod def define(cls, spec): super().define(spec) diff --git a/examples/process_wait_and_resume.py b/examples/process_wait_and_resume.py index 03e8b57a..d4aa20b4 100644 --- a/examples/process_wait_and_resume.py +++ b/examples/process_wait_and_resume.py @@ -5,7 +5,6 @@ class WaitForResumeProc(plumpy.Process): - def run(self): print(f'Now I am running: {self.state}') return plumpy.Wait(self.after_resume_and_exec) @@ -15,12 +14,10 @@ def after_resume_and_exec(self): kwargs = { - 'connection_params': { - 'url': 'amqp://guest:guest@127.0.0.1:5672/' - }, + 'connection_params': {'url': 'amqp://guest:guest@127.0.0.1:5672/'}, 'message_exchange': 'WaitForResume.uuid-0', 'task_exchange': 'WaitForResume.uuid-0', - 'task_queue': 'WaitForResume.uuid-0' + 'task_queue': 'WaitForResume.uuid-0', } if __name__ == '__main__': diff --git a/examples/workchain_simple.py b/examples/workchain_simple.py index 078de3ca..aa189d3b 100644 --- a/examples/workchain_simple.py +++ b/examples/workchain_simple.py @@ -3,7 +3,6 @@ class AddAndMulWF(plumpy.WorkChain): - @classmethod def define(cls, spec): super().define(spec) diff --git a/pyproject.toml b/pyproject.toml index 78c34807..9d996d5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,16 +53,15 @@ docs = [ 'importlib-metadata~=4.12.0', ] pre-commit = [ - 
'mypy==1.3.0', + 'mypy==1.13.0', 'pre-commit~=2.2', - 'pylint==2.15.8', 'types-pyyaml' ] tests = [ 'ipykernel==6.12.1', - 'pytest==6.2.5', - 'pytest-asyncio==0.16.0', - 'pytest-cov==3.0.0', + 'pytest~=7.0', + 'pytest-asyncio~=0.12,<0.17', + 'pytest-cov~=4.1', 'pytest-notebook>=0.8.0', 'shortuuid==1.0.8', 'importlib-resources~=5.2', @@ -75,18 +74,43 @@ name = 'plumpy' exclude = [ 'docs/', 'examples/', - 'test/', + 'tests/', ] [tool.flynt] line-length = 120 fail-on-change = true -[tool.isort] -force_sort_within_sections = true -include_trailing_comma = true -line_length = 120 -multi_line_output = 3 +[tool.ruff] +line-length = 120 + +[tool.ruff.format] +quote-style = 'single' + +[tool.ruff.lint] +ignore = [ + 'F403', # Star imports unable to detect undefined names + 'F405', # Import may be undefined or defined from star imports + 'PLR0911', # Too many return statements + 'PLR0912', # Too many branches + 'PLR0913', # Too many arguments in function definition + 'PLR0915', # Too many statements + 'PLR2004', # Magic value used in comparison + 'RUF005', # Consider iterable unpacking instead of concatenation + 'RUF012' # Mutable class attributes should be annotated with `typing.ClassVar` +] +select = [ + 'E', # pydocstyle + 'W', # pydocstyle + 'F', # pyflakes + 'I', # isort + 'N', # pep8-naming + 'PLC', # pylint-convention + 'PLE', # pylint-error + 'PLR', # pylint-refactor + 'PLW', # pylint-warning + 'RUF' # ruff +] [tool.mypy] show_error_codes = true diff --git a/src/plumpy/__init__.py b/src/plumpy/__init__.py index ea88f872..6f94b5bf 100644 --- a/src/plumpy/__init__.py +++ b/src/plumpy/__init__.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- # mypy: disable-error-code=name-defined -# pylint: disable=undefined-variable __version__ = '0.22.3' import logging @@ -21,9 +20,20 @@ from .workchains import * __all__ = ( - events.__all__ + exceptions.__all__ + processes.__all__ + utils.__all__ + futures.__all__ + mixins.__all__ + - persistence.__all__ + communications.__all__ + 
process_comms.__all__ + process_listener.__all__ + - workchains.__all__ + loaders.__all__ + ports.__all__ + process_states.__all__ + events.__all__ + + exceptions.__all__ + + processes.__all__ + + utils.__all__ + + futures.__all__ + + mixins.__all__ + + persistence.__all__ + + communications.__all__ + + process_comms.__all__ + + process_listener.__all__ + + workchains.__all__ + + loaders.__all__ + + ports.__all__ + + process_states.__all__ ) @@ -32,7 +42,6 @@ # https://docs.python.org/3.1/library/logging.html#library-config # for more details class NullHandler(logging.Handler): - def emit(self, record: logging.LogRecord) -> None: pass diff --git a/src/plumpy/base/__init__.py b/src/plumpy/base/__init__.py index 79450590..a4e3132e 100644 --- a/src/plumpy/base/__init__.py +++ b/src/plumpy/base/__init__.py @@ -1,7 +1,5 @@ # -*- coding: utf-8 -*- -# pylint: disable=undefined-variable -# type: ignore from .state_machine import * from .utils import * -__all__ = (state_machine.__all__ + utils.__all__) +__all__ = state_machine.__all__ + utils.__all__ # type: ignore[name-defined] diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index b62825e1..d99d0705 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """The state machine for processes""" + import enum import functools import inspect @@ -13,24 +14,24 @@ from .utils import call_with_super_check, super_check -__all__ = ['StateMachine', 'StateMachineMeta', 'event', 'TransitionFailed'] +__all__ = ['StateMachine', 'StateMachineMeta', 'TransitionFailed', 'event'] _LOGGER = logging.getLogger(__name__) -LABEL_TYPE = Union[None, enum.Enum, str] # pylint: disable=invalid-name -EVENT_CALLBACK_TYPE = Callable[['StateMachine', Hashable, Optional['State']], None] # pylint: disable=invalid-name +LABEL_TYPE = Union[None, enum.Enum, str] +EVENT_CALLBACK_TYPE = Callable[['StateMachine', Hashable, Optional['State']], None] class 
StateMachineError(Exception): """Base class for state machine errors""" -class StateEntryFailed(Exception): +class StateEntryFailed(Exception): # noqa: N818 """ Failed to enter a state, can provide the next state to go to via this exception """ - def __init__(self, state: Hashable = None, *args: Any, **kwargs: Any) -> None: # pylint: disable=keyword-arg-before-vararg + def __init__(self, state: Hashable = None, *args: Any, **kwargs: Any) -> None: super().__init__('failed to enter state') self.state = state self.args = args @@ -42,20 +43,16 @@ class InvalidStateError(Exception): class EventError(StateMachineError): - def __init__(self, evt: str, msg: str): super().__init__(msg) self.event = evt -class TransitionFailed(Exception): +class TransitionFailed(Exception): # noqa: N818 """A state transition failed""" def __init__( - self, - initial_state: 'State', - final_state: Optional['State'] = None, - traceback_str: Optional[str] = None + self, initial_state: 'State', final_state: Optional['State'] = None, traceback_str: Optional[str] = None ) -> None: self.initial_state = initial_state self.final_state = final_state @@ -71,7 +68,7 @@ def _format_msg(self) -> str: def event( from_states: Union[str, Type['State'], Iterable[Type['State']]] = '*', - to_states: Union[str, Type['State'], Iterable[Type['State']]] = '*' + to_states: Union[str, Type['State'], Iterable[Type['State']]] = '*', ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """A decorator to check for correct transitions, raising ``EventError`` on invalid transitions.""" if from_states != '*': @@ -102,8 +99,8 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any: raise EventError(evt_label, 'Machine did not transition') raise EventError( - evt_label, 'Event produced invalid state transition from ' - f'{initial.LABEL} to {self._state.LABEL}' + evt_label, + 'Event produced invalid state transition from ' f'{initial.LABEL} to {self._state.LABEL}', ) return result @@ -126,7 +123,7 @@ class State: def 
is_terminal(cls) -> bool: return not cls.ALLOWED - def __init__(self, state_machine: 'StateMachine', *args: Any, **kwargs: Any): # pylint: disable=unused-argument + def __init__(self, state_machine: 'StateMachine', *args: Any, **kwargs: Any): """ :param state_machine: The process this state belongs to """ @@ -138,12 +135,12 @@ def __str__(self) -> str: @property def label(self) -> LABEL_TYPE: - """ Convenience property to get the state label """ + """Convenience property to get the state label""" return self.LABEL @super_check def enter(self) -> None: - """ Entering the state """ + """Entering the state""" def execute(self) -> Optional['State']: """ @@ -153,7 +150,7 @@ def execute(self) -> Optional['State']: @super_check def exit(self) -> None: - """ Exiting the state """ + """Exiting the state""" if self.is_terminal(): raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') @@ -175,13 +172,13 @@ class StateEventHook(enum.Enum): procedure. The callback will be passed a state instance whose meaning will differ depending on the hook as commented below. """ + ENTERING_STATE: int = 0 # State passed will be the state that is being entered ENTERED_STATE: int = 1 # State passed will be the last state that we entered from EXITING_STATE: int = 2 # State passed will be the next state that will be entered (or None for terminal) class StateMachineMeta(type): - def __call__(cls, *args: Any, **kwargs: Any) -> 'StateMachine': """ Create the state machine and enter the initial state. 
@@ -220,13 +217,13 @@ def get_states(cls) -> Sequence[Type[State]]: def initial_state_label(cls) -> LABEL_TYPE: cls.__ensure_built() assert cls.STATES is not None - return cls.STATES[0].LABEL # pylint: disable=unsubscriptable-object + return cls.STATES[0].LABEL @classmethod def get_state_class(cls, label: LABEL_TYPE) -> Type[State]: cls.__ensure_built() assert cls._STATES_MAP is not None - return cls._STATES_MAP[label] # pylint: disable=unsubscriptable-object + return cls._STATES_MAP[label] @classmethod def __ensure_built(cls) -> None: @@ -238,15 +235,15 @@ def __ensure_built(cls) -> None: pass cls.STATES = cls.get_states() - assert isinstance(cls.STATES, Iterable) # pylint: disable=isinstance-second-argument-not-valid-type + assert isinstance(cls.STATES, Iterable) # Build the states map cls._STATES_MAP = {} - for state_cls in cls.STATES: # pylint: disable=not-an-iterable + for state_cls in cls.STATES: assert issubclass(state_cls, State) label = state_cls.LABEL - assert label not in cls._STATES_MAP, f"Duplicate label '{label}'" # pylint: disable=unsupported-membership-test - cls._STATES_MAP[label] = state_cls # pylint: disable=unsupported-assignment-operation + assert label not in cls._STATES_MAP, f"Duplicate label '{label}'" + cls._STATES_MAP[label] = state_cls # should class initialise sealed = False? 
cls.sealed = True # type: ignore @@ -301,11 +298,10 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None: @super_check def on_terminated(self) -> None: - """ Called when a terminal state is entered """ + """Called when a terminal state is entered""" def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> None: - assert not self._transitioning, \ - 'Cannot call transition_to when already transitioning state' + assert not self._transitioning, 'Cannot call transition_to when already transitioning state' initial_state_label = self._state.LABEL if self._state is not None else None label = None @@ -331,7 +327,7 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A if self._state is not None and self._state.is_terminal(): call_with_super_check(self.on_terminated) - except Exception: # pylint: disable=broad-except + except Exception: self._transitioning = False if self._transition_failing: raise @@ -360,12 +356,12 @@ def set_debug(self, enabled: bool) -> None: def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State: try: - return self.get_states_map()[state_label](self, *args, **kwargs) # pylint: disable=unsubscriptable-object + return self.get_states_map()[state_label](self, *args, **kwargs) except KeyError: raise ValueError(f'{state_label} is not a valid state') def _exit_current_state(self, next_state: State) -> None: - """ Exit the given state """ + """Exit the given state""" # If we're just being constructed we may not have a state yet to exit, # in which case check the new state is the initial state @@ -401,6 +397,6 @@ def _ensure_state_class(self, state: Union[Hashable, Type[State]]) -> Type[State return state try: - return self.get_states_map()[cast(Hashable, state)] # pylint: disable=unsubscriptable-object + return self.get_states_map()[cast(Hashable, state)] except KeyError: raise ValueError(f'{state} is not a valid state') diff --git 
a/src/plumpy/base/utils.py b/src/plumpy/base/utils.py index 232c5d26..8c35b903 100644 --- a/src/plumpy/base/utils.py +++ b/src/plumpy/base/utils.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- from typing import Any, Callable -__all__ = ['super_check', 'call_with_super_check'] +__all__ = ['call_with_super_check', 'super_check'] def super_check(wrapped: Callable[..., Any]) -> Callable[..., Any]: diff --git a/src/plumpy/communications.py b/src/plumpy/communications.py index 51dff60d..1d7e775b 100644 --- a/src/plumpy/communications.py +++ b/src/plumpy/communications.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """Module for general kiwipy communication methods""" + import asyncio import functools from typing import TYPE_CHECKING, Any, Callable, Hashable, Optional @@ -10,7 +11,12 @@ from .utils import ensure_coroutine __all__ = [ - 'Communicator', 'RemoteException', 'DeliveryFailed', 'TaskRejected', 'plum_to_kiwi_future', 'wrap_communicator' + 'Communicator', + 'DeliveryFailed', + 'RemoteException', + 'TaskRejected', + 'plum_to_kiwi_future', + 'wrap_communicator', ] RemoteException = kiwipy.RemoteException @@ -20,7 +26,7 @@ if TYPE_CHECKING: # identifiers for subscribers - ID_TYPE = Hashable # pylint: disable=invalid-name + ID_TYPE = Hashable Subscriber = Callable[..., Any] # RPC subscriber params: communicator, msg RpcSubscriber = Callable[[kiwipy.Communicator, Any], Any] @@ -55,8 +61,9 @@ def on_done(_plum_future: futures.Future) -> None: return kiwi_future -def convert_to_comm(callback: 'Subscriber', - loop: Optional[asyncio.AbstractEventLoop] = None) -> Callable[..., kiwipy.Future]: +def convert_to_comm( + callback: 'Subscriber', loop: Optional[asyncio.AbstractEventLoop] = None +) -> Callable[..., kiwipy.Future]: """ Take a callback function and converted it to one that will schedule a callback on the given even loop and return a kiwi future representing the future outcome @@ -67,7 +74,6 @@ def convert_to_comm(callback: 'Subscriber', :return: a new callback function that 
returns a future """ if isinstance(callback, kiwipy.BroadcastFilter): - # if the broadcast is filtered for this callback, # we don't want to go through the (costly) process # of setting up async tasks and callbacks @@ -75,16 +81,15 @@ def convert_to_comm(callback: 'Subscriber', def _passthrough(*args: Any, **kwargs: Any) -> bool: sender = kwargs.get('sender', args[1]) subject = kwargs.get('subject', args[2]) - return callback.is_filtered(sender, subject) # type: ignore[attr-defined] + return callback.is_filtered(sender, subject) else: - def _passthrough(*args: Any, **kwargs: Any) -> bool: # pylint: disable=unused-argument + def _passthrough(*args: Any, **kwargs: Any) -> bool: return False coro = ensure_coroutine(callback) def converted(communicator: kiwipy.Communicator, *args: Any, **kwargs: Any) -> kiwipy.Future: - if _passthrough(*args, **kwargs): kiwi_future = kiwipy.Future() kiwi_future.set_result(None) @@ -170,7 +175,7 @@ def broadcast_send( body: Optional[Any], sender: Optional[str] = None, subject: Optional[str] = None, - correlation_id: Optional['ID_TYPE'] = None + correlation_id: Optional['ID_TYPE'] = None, ) -> futures.Future: return self._communicator.broadcast_send(body, sender, subject, correlation_id) diff --git a/src/plumpy/event_helper.py b/src/plumpy/event_helper.py index 3a342321..47ad4956 100644 --- a/src/plumpy/event_helper.py +++ b/src/plumpy/event_helper.py @@ -7,14 +7,13 @@ if TYPE_CHECKING: from typing import Set, Type - from .process_listener import ProcessListener # pylint: disable=cyclic-import + from .process_listener import ProcessListener _LOGGER = logging.getLogger(__name__) @persistence.auto_persist('_listeners', '_listener_type') class EventHelper(persistence.Savable): - def __init__(self, listener_type: 'Type[ProcessListener]'): assert listener_type is not None, 'Must provide valid listener type' @@ -50,5 +49,5 @@ def fire_event(self, event_function: Callable[..., Any], *args: Any, **kwargs: A for listener in list(self.listeners): 
try: getattr(listener, event_function.__name__)(*args, **kwargs) - except Exception as exception: # pylint: disable=broad-except + except Exception as exception: _LOGGER.error("Listener '%s' produced an exception:\n%s", listener, exception) diff --git a/src/plumpy/events.py b/src/plumpy/events.py index 60a5306e..3de81987 100644 --- a/src/plumpy/events.py +++ b/src/plumpy/events.py @@ -1,18 +1,24 @@ # -*- coding: utf-8 -*- """Event and loop related classes and functions""" + import asyncio import sys from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence __all__ = [ - 'new_event_loop', 'set_event_loop', 'get_event_loop', 'run_until_complete', 'set_event_loop_policy', - 'reset_event_loop_policy', 'PlumpyEventLoopPolicy' + 'PlumpyEventLoopPolicy', + 'get_event_loop', + 'new_event_loop', + 'reset_event_loop_policy', + 'run_until_complete', + 'set_event_loop', + 'set_event_loop_policy', ] if TYPE_CHECKING: - from .processes import Process # pylint: disable=cyclic-import + from .processes import Process -get_event_loop = asyncio.get_event_loop # pylint: disable=invalid-name +get_event_loop = asyncio.get_event_loop def set_event_loop(*args: Any, **kwargs: Any) -> None: @@ -51,12 +57,10 @@ def reset_event_loop_policy() -> None: """Reset the event loop policy to the default.""" loop = get_event_loop() - # pylint: disable=protected-access cls = loop.__class__ del cls._check_running # type: ignore del cls._nest_patched # type: ignore - # pylint: enable=protected-access asyncio.set_event_loop_policy(None) @@ -69,7 +73,7 @@ def run_until_complete(future: asyncio.Future, loop: Optional[asyncio.AbstractEv class ProcessCallback: """Object returned by callback registration methods.""" - __slots__ = ('_callback', '_args', '_kwargs', '_process', '_cancelled', '__weakref__') + __slots__ = ('__weakref__', '_args', '_callback', '_cancelled', '_kwargs', '_process') def __init__( self, process: 'Process', callback: Callable[..., Any], args: Sequence[Any], kwargs: 
Dict[str, Any] @@ -93,7 +97,7 @@ async def run(self) -> None: if not self._cancelled: try: await self._callback(*self._args, **self._kwargs) - except Exception: # pylint: disable=broad-except + except Exception: exc_info = sys.exc_info() self._process.callback_excepted(self._callback, exc_info[1], exc_info[2]) finally: diff --git a/src/plumpy/exceptions.py b/src/plumpy/exceptions.py index 40d3e12d..70b5aa2d 100644 --- a/src/plumpy/exceptions.py +++ b/src/plumpy/exceptions.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- from typing import Optional -__all__ = ['KilledError', 'UnsuccessfulResult', 'InvalidStateError', 'PersistenceError', 'ClosedError'] +__all__ = ['ClosedError', 'InvalidStateError', 'KilledError', 'PersistenceError', 'UnsuccessfulResult'] class KilledError(Exception): diff --git a/src/plumpy/futures.py b/src/plumpy/futures.py index 365b8008..161244cd 100644 --- a/src/plumpy/futures.py +++ b/src/plumpy/futures.py @@ -2,12 +2,13 @@ """ Module containing future related methods and classes """ + import asyncio from typing import Any, Callable, Coroutine, Optional import kiwipy -__all__ = ['Future', 'gather', 'chain', 'copy_future', 'CancelledError', 'create_task'] +__all__ = ['CancelledError', 'Future', 'chain', 'copy_future', 'create_task', 'gather'] CancelledError = kiwipy.CancelledError @@ -16,11 +17,11 @@ class InvalidStateError(Exception): """Exception for when a future or action is in an invalid state""" -copy_future = kiwipy.copy_future # pylint: disable=invalid-name -chain = kiwipy.chain # pylint: disable=invalid-name -gather = asyncio.gather # pylint: disable=invalid-name +copy_future = kiwipy.copy_future +chain = kiwipy.chain +gather = asyncio.gather -Future = asyncio.Future # pylint: disable=invalid-name +Future = asyncio.Future class CancellableAction(Future): @@ -35,7 +36,7 @@ def __init__(self, action: Callable[..., Any], cookie: Any = None): @property def cookie(self) -> Any: - """ A cookie that can be used to correlate the actions with 
something """ + """A cookie that can be used to correlate the actions with something""" return self._cookie def run(self, *args: Any, **kwargs: Any) -> None: diff --git a/src/plumpy/lang.py b/src/plumpy/lang.py index 6d9290af..450927d6 100644 --- a/src/plumpy/lang.py +++ b/src/plumpy/lang.py @@ -2,13 +2,13 @@ """ Python language utilities and tools. """ + import functools import inspect from typing import Any, Callable def protected(check: bool = False) -> Callable[[Callable[..., Any]], Callable[..., Any]]: - def wrap(func: Callable[..., Any]) -> Callable[..., Any]: if isinstance(func, property): raise RuntimeError('Protected must go after @property decorator') @@ -31,7 +31,7 @@ def wrapped_fn(self: Any, *args: Any, **kwargs: Any) -> Callable[..., Any]: return func(self, *args, **kwargs) else: - wrapped_fn = func + wrapped_fn = func # type: ignore[assignment] return wrapped_fn @@ -60,15 +60,14 @@ def wrapped_fn(self: Any, *args: Any, **kwargs: Any) -> Callable[..., Any]: return func(self, *args, **kwargs) else: - wrapped_fn = func + wrapped_fn = func # type: ignore[assignment] return wrapped_fn return wrap -class __NULL: # pylint: disable=invalid-name - +class __NULL: # noqa: N801 def __eq__(self, other: Any) -> bool: return isinstance(other, self.__class__) diff --git a/src/plumpy/loaders.py b/src/plumpy/loaders.py index 59f33f64..a01f9b60 100644 --- a/src/plumpy/loaders.py +++ b/src/plumpy/loaders.py @@ -3,7 +3,7 @@ import importlib from typing import Any, Optional -__all__ = ['ObjectLoader', 'DefaultObjectLoader', 'set_object_loader', 'get_object_loader'] +__all__ = ['DefaultObjectLoader', 'ObjectLoader', 'get_object_loader', 'set_object_loader'] class ObjectLoader(metaclass=abc.ABCMeta): @@ -74,7 +74,7 @@ def get_object_loader() -> ObjectLoader: :return: A class loader :rtype: :class:`ObjectLoader` """ - global OBJECT_LOADER + global OBJECT_LOADER # noqa: PLW0603 if OBJECT_LOADER is None: OBJECT_LOADER = DefaultObjectLoader() return OBJECT_LOADER @@ -88,5 +88,5 
@@ def set_object_loader(loader: Optional[ObjectLoader]) -> None: :type loader: :class:`ObjectLoader` :return: """ - global OBJECT_LOADER + global OBJECT_LOADER # noqa: PLW0603 OBJECT_LOADER = loader diff --git a/src/plumpy/mixins.py b/src/plumpy/mixins.py index a8dcca1e..10142eb7 100644 --- a/src/plumpy/mixins.py +++ b/src/plumpy/mixins.py @@ -12,6 +12,7 @@ class ContextMixin(persistence.Savable): Add a context to a Process. The contents of the context will be saved in the instance state unlike standard instance variables. """ + CONTEXT: str = '_context' def __init__(self, *args: Any, **kwargs: Any): diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index 7a15b1cc..ba755bc5 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -18,18 +18,24 @@ from .utils import PID_TYPE, SAVED_STATE_TYPE __all__ = [ - 'Bundle', 'Persister', 'PicklePersister', 'auto_persist', 'Savable', 'SavableFuture', 'LoadSaveContext', - 'PersistedCheckpoint', 'InMemoryPersister' + 'Bundle', + 'InMemoryPersister', + 'LoadSaveContext', + 'PersistedCheckpoint', + 'Persister', + 'PicklePersister', + 'Savable', + 'SavableFuture', + 'auto_persist', ] PersistedCheckpoint = collections.namedtuple('PersistedCheckpoint', ['pid', 'tag']) if TYPE_CHECKING: - from .processes import Process # pylint: disable=cyclic-import + from .processes import Process class Bundle(dict): - def __init__(self, savable: 'Savable', save_context: Optional['LoadSaveContext'] = None, dereference: bool = False): """ Create a bundle from a savable. 
Optionally keep information about the @@ -77,7 +83,6 @@ def _bundle_constructor(loader: yaml.Loader, data: Any) -> Generator[Bundle, Non class Persister(metaclass=abc.ABCMeta): - @abc.abstractmethod def save_checkpoint(self, process: 'Process', tag: Optional[str] = None) -> None: """ @@ -301,7 +306,7 @@ def delete_process_checkpoints(self, pid: PID_TYPE) -> None: class InMemoryPersister(Persister): - """ Mainly to be used in testing/debugging """ + """Mainly to be used in testing/debugging""" def __init__(self, loader: Optional[loaders.ObjectLoader] = None) -> None: super().__init__() @@ -340,13 +345,11 @@ def delete_process_checkpoints(self, pid: PID_TYPE) -> None: del self._checkpoints[pid] -SavableClsType = TypeVar('SavableClsType', bound='Type[Savable]') # type: ignore[name-defined] # pylint: disable=invalid-name +SavableClsType = TypeVar('SavableClsType', bound='type[Savable]') def auto_persist(*members: str) -> Callable[[SavableClsType], SavableClsType]: - def wrapped(savable: SavableClsType) -> SavableClsType: - # pylint: disable=protected-access if savable._auto_persist is None: savable._auto_persist = set() else: @@ -390,7 +393,6 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV class LoadSaveContext: - def __init__(self, loader: Optional[loaders.ObjectLoader] = None, **kwargs: Any) -> None: self._values = dict(**kwargs) self.loader = loader @@ -408,7 +410,7 @@ def __contains__(self, item: Any) -> bool: return self._values.__contains__(item) def copyextend(self, **kwargs: Any) -> 'LoadSaveContext': - """ Add additional information to the context by making a copy with the new values """ + """Add additional information to the context by making a copy with the new values""" extended = self._values.copy() extended.update(kwargs) loader = extended.pop('loader', self.loader) @@ -485,7 +487,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: Optio self.load_members(self._auto_persist, saved_state, 
load_context) @super_check - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: Optional[LoadSaveContext]) -> None: # pylint: disable=unused-argument + def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: Optional[LoadSaveContext]) -> None: self._ensure_persist_configured() if self._auto_persist is not None: self.save_members(self._auto_persist, out_state) @@ -527,10 +529,7 @@ def save_members(self, members: Iterable[str], out_state: SAVED_STATE_TYPE) -> N out_state[member] = value def load_members( - self, - members: Iterable[str], - saved_state: SAVED_STATE_TYPE, - load_context: Optional[LoadSaveContext] = None + self, members: Iterable[str], saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None ) -> None: for member in members: setattr(self, member, self._get_value(saved_state, member, load_context)) @@ -580,8 +579,9 @@ def _get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any: # endregion - def _get_value(self, saved_state: SAVED_STATE_TYPE, name: str, - load_context: Optional[LoadSaveContext]) -> Union[MethodType, 'Savable']: + def _get_value( + self, saved_state: SAVED_STATE_TYPE, name: str, load_context: Optional[LoadSaveContext] + ) -> Union[MethodType, 'Savable']: value = saved_state[name] typ = Savable._get_meta_type(saved_state, name) @@ -626,10 +626,10 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa state = saved_state['_state'] - if state == asyncio.futures._PENDING: # type: ignore # pylint: disable=protected-access + if state == asyncio.futures._PENDING: # type: ignore obj = cls(loop=loop) - if state == asyncio.futures._FINISHED: # type: ignore # pylint: disable=protected-access + if state == asyncio.futures._FINISHED: # type: ignore obj = cls(loop=loop) result = saved_state['_result'] @@ -639,14 +639,13 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa except KeyError: obj.set_result(result) - if state == 
asyncio.futures._CANCELLED: # type: ignore # pylint: disable=protected-access + if state == asyncio.futures._CANCELLED: # type: ignore obj = cls(loop=loop) obj.cancel() return obj def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: - # pylint: disable=attribute-defined-outside-init super().load_instance_state(saved_state, load_context) if self._callbacks: # typing says asyncio.Future._callbacks needs to be called, but in the python 3.7 code it is a simple list diff --git a/src/plumpy/ports.py b/src/plumpy/ports.py index fc5f138f..cfbd92d5 100644 --- a/src/plumpy/ports.py +++ b/src/plumpy/ports.py @@ -1,16 +1,17 @@ # -*- coding: utf-8 -*- """Module for process ports""" + import collections import copy import inspect import json import logging -from typing import Any, Callable, Dict, Iterator, List, Mapping, MutableMapping, Optional, Sequence, Type, Union, cast import warnings +from typing import Any, Callable, Dict, Iterator, List, Mapping, MutableMapping, Optional, Sequence, Type, Union, cast from plumpy.utils import AttributesFrozendict, is_mutable_property, type_check -__all__ = ['UNSPECIFIED', 'PortValidationError', 'PortNamespace', 'Port', 'InputPort', 'OutputPort'] +__all__ = ['UNSPECIFIED', 'InputPort', 'OutputPort', 'Port', 'PortNamespace', 'PortValidationError'] _LOGGER = logging.getLogger(__name__) UNSPECIFIED = () @@ -19,7 +20,7 @@ This has been deprecated and the new signature is `validator(value, port)` where the `port` argument will be the port instance to which the validator has been assigned.""" -VALIDATOR_TYPE = Callable[[Any, 'Port'], Optional[str]] # pylint: disable=invalid-name +VALIDATOR_TYPE = Callable[[Any, 'Port'], Optional[str]] class PortValidationError(Exception): @@ -66,9 +67,9 @@ def __init__( self, name: str, valid_type: Optional[Type[Any]] = None, - help: Optional[str] = None, # pylint: disable=redefined-builtin + help: Optional[str] = None, required: bool = True, - validator: 
Optional[VALIDATOR_TYPE] = None + validator: Optional[VALIDATOR_TYPE] = None, ) -> None: self._name = name self._valid_type = valid_type @@ -134,7 +135,7 @@ def help(self) -> Optional[str]: return self._help @help.setter - def help(self, help: Optional[str]) -> None: # pylint: disable=redefined-builtin + def help(self, help: Optional[str]) -> None: """Set the help string for this port :param help: the help string @@ -198,9 +199,9 @@ def validate(self, value: Any, breadcrumbs: Sequence[str] = ()) -> Optional[Port spec = inspect.getfullargspec(self.validator) if len(spec[0]) == 1: warnings.warn(VALIDATOR_SIGNATURE_DEPRECATION_WARNING.format(self.validator.__name__)) - result = self.validator(value) # type: ignore # pylint: disable=not-callable + result = self.validator(value) # type: ignore else: - result = self.validator(value, self) # pylint: disable=not-callable + result = self.validator(value, self) if result is not None: assert isinstance(result, str), 'Validator returned non string type' validation_error = result @@ -233,17 +234,17 @@ def __init__( self, name: str, valid_type: Optional[Type[Any]] = None, - help: Optional[str] = None, # pylint: disable=redefined-builtin + help: Optional[str] = None, default: Any = UNSPECIFIED, required: bool = True, - validator: Optional[VALIDATOR_TYPE] = None - ) -> None: # pylint: disable=too-many-arguments + validator: Optional[VALIDATOR_TYPE] = None, + ) -> None: super().__init__( name, valid_type=valid_type, help=help, required=InputPort.required_override(required, default), - validator=validator + validator=validator, ) if required is not InputPort.required_override(required, default): @@ -252,7 +253,6 @@ def __init__( ) if default is not UNSPECIFIED: - # Only validate the default value if it is not a callable. If it is a callable its return value will always # be validated when the port is validated upon process construction, if the default is was actually used. 
if not callable(default): @@ -304,14 +304,14 @@ class PortNamespace(collections.abc.MutableMapping, Port): def __init__( self, name: str = '', # Note this was set to None, but that would fail if you tried to compute breadcrumbs - help: Optional[str] = None, # pylint: disable=redefined-builtin + help: Optional[str] = None, required: bool = True, validator: Optional[VALIDATOR_TYPE] = None, valid_type: Optional[Type[Any]] = None, default: Any = UNSPECIFIED, dynamic: bool = False, - populate_defaults: bool = True - ) -> None: # pylint: disable=too-many-arguments + populate_defaults: bool = True, + ) -> None: """Construct a port namespace. :param name: the name of the namespace @@ -396,7 +396,7 @@ def valid_type(self, valid_type: Optional[Type[Any]]) -> None: if valid_type is not None: self.dynamic = True - super(PortNamespace, self.__class__).valid_type.fset(self, valid_type) # type: ignore # pylint: disable=no-member + super(PortNamespace, self.__class__).valid_type.fset(self, valid_type) # type: ignore @property def populate_defaults(self) -> bool: @@ -459,7 +459,7 @@ def get_port(self, name: str, create_dynamically: bool = False) -> Union[Port, ' valid_type=self.valid_type, default=self.default, dynamic=self.dynamic, - populate_defaults=self.populate_defaults + populate_defaults=self.populate_defaults, ) if namespace: @@ -495,7 +495,6 @@ def create_port_namespace(self, name: str, **kwargs: Any) -> 'PortNamespace': # If this is True, the (sub) port namespace does not yet exist, so we create it if port_name not in self: - # If there still is a `namespace`, we create a sub namespace, *without* the constructor arguments if namespace: self[port_name] = self.__class__(port_name) @@ -515,7 +514,7 @@ def absorb( port_namespace: 'PortNamespace', exclude: Optional[Sequence[str]] = None, include: Optional[Sequence[str]] = None, - namespace_options: Optional[Dict[str, Any]] = None + namespace_options: Optional[Dict[str, Any]] = None, ) -> List[str]: """Absorb another 
PortNamespace instance into oneself, including all its mutable properties and ports. @@ -531,7 +530,7 @@ def absorb( :param namespace_options: a dictionary with mutable PortNamespace property values to override :return: list of the names of the ports that were absorbed """ - # pylint: disable=too-many-branches + if not isinstance(port_namespace, PortNamespace): raise ValueError('port_namespace has to be an instance of PortNamespace') @@ -559,14 +558,12 @@ def absorb( absorbed_ports = [] for port_name, port in port_namespace.items(): - # If the current port name occurs in the exclude list, simply skip it entirely, there is no need to consider # any of the nested ports it might have, even if it is a port namespace if exclude and port_name in exclude: continue if isinstance(port, PortNamespace): - # If the name does not appear at the start of any of the include rules we continue: if include and not any(rule.startswith(port_name) for rule in include): continue @@ -580,7 +577,7 @@ def absorb( # absorb call that will properly consider the include and exclude rules self[port_name] = copy.copy(port) portnamespace = cast(PortNamespace, self[port_name]) - portnamespace._ports = {} # pylint: disable=protected-access + portnamespace._ports = {} portnamespace.absorb(port, sub_exclude, sub_include) else: # If include rules are specified but the port name does not appear, simply skip it @@ -615,10 +612,8 @@ def project(self, port_values: MutableMapping[str, Any]) -> MutableMapping[str, return result - def validate( # pylint: disable=arguments-differ - self, - port_values: Optional[Mapping[str, Any]] = None, - breadcrumbs: Sequence[str] = () + def validate( + self, port_values: Optional[Mapping[str, Any]] = None, breadcrumbs: Sequence[str] = () ) -> Optional[PortValidationError]: """ Validate the namespace port itself and subsequently all the port_values it contains @@ -627,7 +622,7 @@ def validate( # pylint: disable=arguments-differ :param breadcrumbs: a tuple of the path to 
having reached this point in validation :return: None or tuple containing 0: error string 1: tuple of breadcrumb strings to where the validation failed """ - # pylint: disable=arguments-renamed + breadcrumbs_local = (*breadcrumbs, self.name) message: Optional[str] @@ -665,12 +660,13 @@ def validate( # pylint: disable=arguments-differ spec = inspect.getfullargspec(self.validator) if len(spec[0]) == 1: warnings.warn(VALIDATOR_SIGNATURE_DEPRECATION_WARNING.format(self.validator.__name__)) - message = self.validator(port_values_clone) # type: ignore # pylint: disable=not-callable + message = self.validator(port_values_clone) # type: ignore else: - message = self.validator(port_values_clone, self) # pylint: disable=not-callable + message = self.validator(port_values_clone, self) if message is not None: - assert isinstance(message, str), \ - f"Validator returned something other than None or str: '{type(message)}'" + assert isinstance( + message, str + ), f"Validator returned something other than None or str: '{type(message)}'" return PortValidationError(message, breadcrumbs_to_port(breadcrumbs_local)) return None @@ -682,14 +678,12 @@ def pre_process(self, port_values: MutableMapping[str, Any]) -> AttributesFrozen :return: an AttributesFrozenDict with pre-processed port value mapping, complemented with port default values """ for name, port in self.items(): - # If the port was not specified in the inputs values and the port is a namespace with the property # `populate_defaults=False`, we skip the pre-processing and do not populate defaults. 
if name not in port_values and isinstance(port, PortNamespace) and not port.populate_defaults: continue if name not in port_values: - if port.has_default(): default = port.default if callable(default): @@ -712,8 +706,9 @@ def pre_process(self, port_values: MutableMapping[str, Any]) -> AttributesFrozen return AttributesFrozendict(port_values) - def validate_ports(self, port_values: MutableMapping[str, Any], - breadcrumbs: Sequence[str]) -> Optional[PortValidationError]: + def validate_ports( + self, port_values: MutableMapping[str, Any], breadcrumbs: Sequence[str] + ) -> Optional[PortValidationError]: """ Validate port values with respect to the explicitly defined ports of the port namespace. Ports values that are matched to an actual Port will be popped from the dictionary @@ -791,7 +786,7 @@ def strip_namespace(namespace: str, separator: str, rules: Optional[Sequence[str for rule in rules: if rule.startswith(prefix): - stripped.append(rule[len(prefix):]) + stripped.append(rule[len(prefix) :]) return stripped diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index c66e8431..293c680b 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """Module for process level communication functions and classes""" + import asyncio import copy import logging @@ -11,19 +12,19 @@ from .utils import PID_TYPE __all__ = [ + 'KILL_MSG', 'PAUSE_MSG', 'PLAY_MSG', - 'KILL_MSG', 'STATUS_MSG', 'ProcessLauncher', + 'RemoteProcessController', + 'RemoteProcessThreadController', 'create_continue_body', 'create_launch_body', - 'RemoteProcessThreadController', - 'RemoteProcessController', ] if TYPE_CHECKING: - from .processes import Process # pylint: disable=cyclic-import + from .processes import Process ProcessResult = Any ProcessStatus = Any @@ -34,7 +35,7 @@ class Intent: """Intent constants for a process message""" - # pylint: disable=too-few-public-methods + PLAY: str = 'play' PAUSE: str = 'pause' KILL: str = 
'kill' @@ -71,7 +72,7 @@ def create_launch_body( init_kwargs: Optional[Dict[str, Any]] = None, persist: bool = False, loader: Optional[loaders.ObjectLoader] = None, - nowait: bool = True + nowait: bool = True, ) -> Dict[str, Any]: """ Create a message body for the launch action @@ -95,8 +96,8 @@ def create_launch_body( PERSIST_KEY: persist, NOWAIT_KEY: nowait, ARGS_KEY: init_args, - KWARGS_KEY: init_kwargs - } + KWARGS_KEY: init_kwargs, + }, } return msg_body @@ -119,7 +120,7 @@ def create_create_body( init_args: Optional[Sequence[Any]] = None, init_kwargs: Optional[Dict[str, Any]] = None, persist: bool = False, - loader: Optional[loaders.ObjectLoader] = None + loader: Optional[loaders.ObjectLoader] = None, ) -> Dict[str, Any]: """ Create a message body to create a new process @@ -140,8 +141,8 @@ def create_create_body( PROCESS_CLASS_KEY: loader.identify_object(process_class), PERSIST_KEY: persist, ARGS_KEY: init_args, - KWARGS_KEY: init_kwargs - } + KWARGS_KEY: init_kwargs, + }, } return msg_body @@ -216,11 +217,7 @@ async def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'Pro return result async def continue_process( - self, - pid: 'PID_TYPE', - tag: Optional[str] = None, - nowait: bool = False, - no_reply: bool = False + self, pid: 'PID_TYPE', tag: Optional[str] = None, nowait: bool = False, no_reply: bool = False ) -> Optional['ProcessResult']: """ Continue the process @@ -249,7 +246,7 @@ async def launch_process( persist: bool = False, loader: Optional[loaders.ObjectLoader] = None, nowait: bool = False, - no_reply: bool = False + no_reply: bool = False, ) -> 'ProcessResult': """ Launch a process given the class and constructor arguments @@ -263,7 +260,7 @@ async def launch_process( :param no_reply: if True, this call will be fire-and-forget, i.e. 
no return value :return: the result of launching the process """ - # pylint: disable=too-many-arguments + message = create_launch_body(process_class, init_args, init_kwargs, persist, loader, nowait) launch_future = self._communicator.task_send(message, no_reply=no_reply) future = await asyncio.wrap_future(launch_future) @@ -281,7 +278,7 @@ async def execute_process( init_kwargs: Optional[Dict[str, Any]] = None, loader: Optional[loaders.ObjectLoader] = None, nowait: bool = False, - no_reply: bool = False + no_reply: bool = False, ) -> 'ProcessResult': """ Execute a process. This call will first send a create task and then a continue task over @@ -296,7 +293,7 @@ async def execute_process( :param no_reply: if True, this call will be fire-and-forget, i.e. no return value :return: the result of executing the process """ - # pylint: disable=too-many-arguments + message = create_create_body(process_class, init_args, init_kwargs, persist=True, loader=loader) create_future = self._communicator.task_send(message) @@ -399,11 +396,7 @@ def kill_all(self, msg: Optional[Any]) -> None: self._communicator.broadcast_send(msg, subject=Intent.KILL) def continue_process( - self, - pid: 'PID_TYPE', - tag: Optional[str] = None, - nowait: bool = False, - no_reply: bool = False + self, pid: 'PID_TYPE', tag: Optional[str] = None, nowait: bool = False, no_reply: bool = False ) -> Union[None, PID_TYPE, ProcessResult]: message = create_continue_body(pid=pid, tag=tag, nowait=nowait) return self.task_send(message, no_reply=no_reply) @@ -416,9 +409,8 @@ def launch_process( persist: bool = False, loader: Optional[loaders.ObjectLoader] = None, nowait: bool = False, - no_reply: bool = False + no_reply: bool = False, ) -> Union[None, PID_TYPE, ProcessResult]: - # pylint: disable=too-many-arguments """ Launch the process @@ -441,7 +433,7 @@ def execute_process( init_kwargs: Optional[Dict[str, Any]] = None, loader: Optional[loaders.ObjectLoader] = None, nowait: bool = False, - no_reply: bool = False 
+ no_reply: bool = False, ) -> Union[None, PID_TYPE, ProcessResult]: """ Execute a process. This call will first send a create task and then a continue task over @@ -456,7 +448,7 @@ def execute_process( :param no_reply: if True, this call will be fire-and-forget, i.e. no return value :return: the result of executing the process """ - # pylint: disable=too-many-arguments + message = create_create_body(process_class, init_args, init_kwargs, persist=True, loader=loader) execute_future = kiwipy.Future() @@ -512,7 +504,7 @@ def __init__( loop: Optional[asyncio.AbstractEventLoop] = None, persister: Optional[persistence.Persister] = None, load_context: Optional[persistence.LoadSaveContext] = None, - loader: Optional[loaders.ObjectLoader] = None + loader: Optional[loaders.ObjectLoader] = None, ) -> None: self._loop = loop self._persister = persister @@ -573,7 +565,8 @@ async def _launch( self._persister.save_checkpoint(proc) if nowait: - asyncio.ensure_future(proc.step_until_terminated()) + # XXX: can return a reference and gracefully use task to cancel itself when the upper call stack fails + asyncio.ensure_future(proc.step_until_terminated()) # noqa: RUF006 return proc.pid await proc.step_until_terminated() @@ -581,11 +574,7 @@ async def _launch( return proc.future().result() async def _continue( - self, - _communicator: kiwipy.Communicator, - pid: 'PID_TYPE', - nowait: bool, - tag: Optional[str] = None + self, _communicator: kiwipy.Communicator, pid: 'PID_TYPE', nowait: bool, tag: Optional[str] = None ) -> Union[PID_TYPE, ProcessResult]: """ Continue the process @@ -604,7 +593,8 @@ async def _continue( proc = cast('Process', saved_state.unbundle(self._load_context)) if nowait: - asyncio.ensure_future(proc.step_until_terminated()) + # XXX: can return a reference and gracefully use task to cancel itself when the upper call stack fails + asyncio.ensure_future(proc.step_until_terminated()) # noqa: RUF006 return proc.pid await proc.step_until_terminated() diff --git 
a/src/plumpy/process_listener.py b/src/plumpy/process_listener.py index 110394a2..8e1acf94 100644 --- a/src/plumpy/process_listener.py +++ b/src/plumpy/process_listener.py @@ -8,12 +8,11 @@ __all__ = ['ProcessListener'] if TYPE_CHECKING: - from .processes import Process # pylint: disable=cyclic-import + from .processes import Process @persistence.auto_persist('_params') class ProcessListener(persistence.Savable, metaclass=abc.ABCMeta): - # region Persistence methods def __init__(self) -> None: diff --git a/src/plumpy/process_spec.py b/src/plumpy/process_spec.py index c82d59ee..00f2f3cc 100644 --- a/src/plumpy/process_spec.py +++ b/src/plumpy/process_spec.py @@ -7,9 +7,9 @@ from .ports import InputPort, OutputPort, Port, PortNamespace if TYPE_CHECKING: - from .processes import Process # pylint: disable=cyclic-import + from .processes import Process -EXPOSED_TYPE = Dict[Optional[str], Dict[Type['Process'], Sequence[str]]] # pylint: disable=invalid-name +EXPOSED_TYPE = Dict[Optional[str], Dict[Type['Process'], Sequence[str]]] class ProcessSpec: @@ -22,6 +22,7 @@ class ProcessSpec: Every Process class has one of these. """ + NAME_INPUTS_PORT_NAMESPACE: str = 'inputs' NAME_OUTPUTS_PORT_NAMESPACE: str = 'outputs' PORT_NAMESPACE_TYPE = PortNamespace @@ -184,7 +185,7 @@ def expose_inputs( namespace: Optional[str] = None, exclude: Optional[Sequence[str]] = None, include: Optional[Sequence[str]] = None, - namespace_options: Optional[dict] = None + namespace_options: Optional[dict] = None, ) -> None: """ This method allows one to automatically add the inputs from another Process to this ProcessSpec. @@ -215,7 +216,7 @@ def expose_outputs( namespace: Optional[str] = None, exclude: Optional[Sequence[str]] = None, include: Optional[Sequence[str]] = None, - namespace_options: Optional[dict] = None + namespace_options: Optional[dict] = None, ) -> None: """ This method allows one to automatically add the ouputs from another Process to this ProcessSpec. 
@@ -249,8 +250,8 @@ def _expose_ports( namespace: Optional[str], exclude: Optional[Sequence[str]], include: Optional[Sequence[str]], - namespace_options: Optional[dict] = None - ) -> None: # pylint: disable=too-many-arguments + namespace_options: Optional[dict] = None, + ) -> None: """ Expose ports from a source PortNamespace of the ProcessSpec of a Process class into the destination PortNamespace of this ProcessSpec. If the namespace is specified, the ports will be exposed in that sub diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 3407412d..7ae6e9bd 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- -from enum import Enum import sys import traceback +from enum import Enum from types import TracebackType from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast @@ -22,28 +22,28 @@ from .utils import SAVED_STATE_TYPE __all__ = [ - 'ProcessState', + 'Continue', 'Created', - 'Running', - 'Waiting', - 'Finished', 'Excepted', - 'Killed', + 'Finished', + 'Interruption', # Commands 'Kill', - 'Stop', - 'Wait', - 'Continue', - 'Interruption', 'KillInterruption', + 'Killed', 'PauseInterruption', + 'ProcessState', + 'Running', + 'Stop', + 'Wait', + 'Waiting', ] if TYPE_CHECKING: - from .processes import Process # pylint: disable=cyclic-import + from .processes import Process -class Interruption(Exception): +class Interruption(Exception): # noqa: N818 pass @@ -64,7 +64,6 @@ class Command(persistence.Savable): @auto_persist('msg') class Kill(Command): - def __init__(self, msg: Optional[Any] = None): super().__init__() self.msg = msg @@ -76,7 +75,6 @@ class Pause(Command): @auto_persist('msg', 'data') class Wait(Command): - def __init__( self, continue_fn: Optional[Callable[..., Any]] = None, msg: Optional[Any] = None, data: Optional[Any] = None ): @@ -88,7 +86,6 @@ def __init__( @auto_persist('result') class Stop(Command): - def __init__(self, result: Any, 
successful: bool) -> None: super().__init__() self.result = result @@ -127,6 +124,7 @@ class ProcessState(Enum): """ The possible states that a :class:`~plumpy.processes.Process` can be in. """ + CREATED: str = 'created' RUNNING: str = 'running' WAITING: str = 'waiting' @@ -137,7 +135,6 @@ class ProcessState(Enum): @auto_persist('in_state') class State(state_machine.State, persistence.Savable): - @property def process(self) -> state_machine.StateMachine: """ @@ -149,7 +146,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi super().load_instance_state(saved_state, load_context) self.state_machine = load_context.process - def interrupt(self, reason: Any) -> None: # pylint: disable=unused-argument + def interrupt(self, reason: Any) -> None: pass @@ -183,7 +180,11 @@ def execute(self) -> state_machine.State: class Running(State): LABEL = ProcessState.RUNNING ALLOWED = { - ProcessState.RUNNING, ProcessState.WAITING, ProcessState.FINISHED, ProcessState.KILLED, ProcessState.EXCEPTED + ProcessState.RUNNING, + ProcessState.WAITING, + ProcessState.FINISHED, + ProcessState.KILLED, + ProcessState.EXCEPTED, } RUN_FN = 'run_fn' # The key used to store the function to run @@ -217,7 +218,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi def interrupt(self, reason: Any) -> None: pass - async def execute(self) -> State: # type: ignore # pylint: disable=invalid-overridden-method + async def execute(self) -> State: # type: ignore if self._command is not None: command = self._command else: @@ -230,7 +231,7 @@ async def execute(self) -> State: # type: ignore # pylint: disable=invalid-over except Interruption: # Let this bubble up to the caller raise - except Exception: # pylint: disable=broad-except + except Exception: excepted = self.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:]) return cast(State, excepted) else: @@ -267,7 +268,11 @@ def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> 
State: class Waiting(State): LABEL = ProcessState.WAITING ALLOWED = { - ProcessState.RUNNING, ProcessState.WAITING, ProcessState.KILLED, ProcessState.EXCEPTED, ProcessState.FINISHED + ProcessState.RUNNING, + ProcessState.WAITING, + ProcessState.KILLED, + ProcessState.EXCEPTED, + ProcessState.FINISHED, } DONE_CALLBACK = 'DONE_CALLBACK' @@ -285,7 +290,7 @@ def __init__( process: 'Process', done_callback: Optional[Callable[..., Any]], msg: Optional[str] = None, - data: Optional[Any] = None + data: Optional[Any] = None, ) -> None: super().__init__(process) self.done_callback = done_callback @@ -311,7 +316,7 @@ def interrupt(self, reason: Any) -> None: # This will cause the future in execute() to raise the exception self._waiting_future.set_exception(reason) - async def execute(self) -> State: # type: ignore # pylint: disable=invalid-overridden-method + async def execute(self) -> State: # type: ignore try: result = await self._waiting_future except Interruption: @@ -370,9 +375,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: try: - self.traceback = \ - tblib.Traceback.from_string(saved_state[self.TRACEBACK], - strict=False) + self.traceback = tblib.Traceback.from_string(saved_state[self.TRACEBACK], strict=False) except KeyError: self.traceback = None else: diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index b6e14ad9..ba7967d3 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """The main Process module""" + import abc import asyncio import contextlib @@ -10,6 +11,8 @@ import re import sys import time +import uuid +import warnings from types import TracebackType from typing import ( Any, @@ -26,17 +29,15 @@ Union, cast, ) -import uuid -import warnings try: from aiocontextvars import ContextVar except ModuleNotFoundError: from contextvars import ContextVar -from 
aio_pika.exceptions import ChannelInvalidStateError, ConnectionClosed import kiwipy import yaml +from aio_pika.exceptions import ChannelInvalidStateError, ConnectionClosed from . import events, exceptions, futures, persistence, ports, process_comms, process_states, utils from .base import state_machine @@ -47,9 +48,7 @@ from .process_spec import ProcessSpec from .utils import PID_TYPE, SAVED_STATE_TYPE, protected -# pylint: disable=too-many-lines - -__all__ = ['Process', 'ProcessSpec', 'BundleKeys', 'TransitionFailed'] +__all__ = ['BundleKeys', 'Process', 'ProcessSpec', 'TransitionFailed'] _LOGGER = logging.getLogger(__name__) PROCESS_STACK = ContextVar('process stack', default=[]) @@ -62,7 +61,7 @@ class BundleKeys: See :meth:`plumpy.processes.Process.save_instance_state` and :meth:`plumpy.processes.Process.load_instance_state`. """ - # pylint: disable=too-few-public-methods + INPUTS_RAW = 'INPUTS_RAW' INPUTS_PARSED = 'INPUTS_PARSED' OUTPUTS = 'OUTPUTS' @@ -75,7 +74,7 @@ class ProcessStateMachineMeta(abc.ABCMeta, state_machine.StateMachineMeta): # Make ProcessStateMachineMeta instances (classes) YAML - able yaml.representer.Representer.add_representer( ProcessStateMachineMeta, - yaml.representer.Representer.represent_name # type: ignore[arg-type] + yaml.representer.Representer.represent_name, # type: ignore[arg-type] ) @@ -84,7 +83,6 @@ def ensure_not_closed(func: Callable[..., Any]) -> Callable[..., Any]: @functools.wraps(func) def func_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - # pylint: disable=protected-access if self._closed: raise exceptions.ClosedError('Process is closed') return func(self, *args, **kwargs) @@ -133,8 +131,6 @@ class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMe executed. 
""" - # pylint: disable=too-many-instance-attributes,too-many-public-methods - # Static class stuff ###################### _spec_class = ProcessSpec # Default placeholders, will be populated in init() @@ -167,7 +163,7 @@ def get_states(cls) -> Sequence[Type[process_states.State]]: state_classes = cls.get_state_classes() return ( state_classes[process_states.ProcessState.CREATED], - *[state for state in state_classes.values() if state.LABEL != process_states.ProcessState.CREATED] + *[state for state in state_classes.values() if state.LABEL != process_states.ProcessState.CREATED], ) @classmethod @@ -179,7 +175,7 @@ def get_state_classes(cls) -> Dict[Hashable, Type[process_states.State]]: process_states.ProcessState.WAITING: process_states.Waiting, process_states.ProcessState.FINISHED: process_states.Finished, process_states.ProcessState.EXCEPTED: process_states.Excepted, - process_states.ProcessState.KILLED: process_states.Killed + process_states.ProcessState.KILLED: process_states.Killed, } @classmethod @@ -256,7 +252,7 @@ def __init__( pid: Optional[PID_TYPE] = None, logger: Optional[logging.Logger] = None, loop: Optional[asyncio.AbstractEventLoop] = None, - communicator: Optional[kiwipy.Communicator] = None + communicator: Optional[kiwipy.Communicator] = None, ) -> None: """ The signature of the constructor should not be changed by subclassing processes. 
@@ -278,8 +274,9 @@ def __init__( self._setup_event_hooks() self._status: Optional[str] = None # May hold a current status message - self._pre_paused_status: Optional[ - str] = None # Save status when a pause message replaces it, such that it can be restored + self._pre_paused_status: Optional[str] = ( + None # Save status when a pause message replaces it, such that it can be restored + ) self._paused = None # Input/output @@ -331,12 +328,13 @@ def try_killing(future: futures.Future) -> None: def _setup_event_hooks(self) -> None: """Set the event hooks to process, when it is created or loaded(recreated).""" event_hooks = { - state_machine.StateEventHook.ENTERING_STATE: - lambda _s, _h, state: self.on_entering(cast(process_states.State, state)), - state_machine.StateEventHook.ENTERED_STATE: - lambda _s, _h, from_state: self.on_entered(cast(Optional[process_states.State], from_state)), - state_machine.StateEventHook.EXITING_STATE: - lambda _s, _h, _state: self.on_exiting() + state_machine.StateEventHook.ENTERING_STATE: lambda _s, _h, state: self.on_entering( + cast(process_states.State, state) + ), + state_machine.StateEventHook.ENTERED_STATE: lambda _s, _h, from_state: self.on_entered( + cast(Optional[process_states.State], from_state) + ), + state_machine.StateEventHook.EXITING_STATE: lambda _s, _h, _state: self.on_exiting(), } for hook, callback in event_hooks.items(): self.add_state_event_callback(hook, callback) @@ -356,7 +354,7 @@ def pid(self) -> Optional[PID_TYPE]: @property def uuid(self) -> Optional[uuid.UUID]: - """Return the UUID of the process """ + """Return the UUID of the process""" return self._uuid @property @@ -421,7 +419,7 @@ def launch( process_class: Type['Process'], inputs: Optional[dict] = None, pid: Optional[PID_TYPE] = None, - logger: Optional[logging.Logger] = None + logger: Optional[logging.Logger] = None, ) -> 'Process': """Start running the nested process. @@ -507,7 +505,7 @@ def done(self) -> bool: .. 
deprecated:: 0.18.6 Use the `has_terminated` method instead """ - warnings.warn('method is deprecated, use `has_terminated` instead', DeprecationWarning) # pylint: disable=no-member + warnings.warn('method is deprecated, use `has_terminated` instead', DeprecationWarning) return self._state.is_terminal() # endregion @@ -663,7 +661,7 @@ def add_process_listener(self, listener: ProcessListener) -> None: the specific state condition. """ - assert (listener != self), 'Cannot listen to yourself!' # type: ignore + assert listener != self, 'Cannot listen to yourself!' # type: ignore self._event_helper.add_listener(listener) def remove_process_listener(self, listener: ProcessListener) -> None: @@ -886,7 +884,7 @@ def on_close(self) -> None: for cleanup in self._cleanups or []: try: cleanup() - except Exception: # pylint: disable=broad-except + except Exception: self.logger.exception('Process<%s>: Exception calling cleanup method %s', self.pid, cleanup) self._cleanups = None finally: @@ -926,15 +924,16 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An # Didn't match any known intents raise RuntimeError('Unknown intent') - def broadcast_receive(self, _comm: kiwipy.Communicator, body: Any, sender: Any, subject: Any, - correlation_id: Any) -> Optional[kiwipy.Future]: + def broadcast_receive( + self, _comm: kiwipy.Communicator, body: Any, sender: Any, subject: Any, correlation_id: Any + ) -> Optional[kiwipy.Future]: """ Coroutine called when the process receives a message from the communicator :param _comm: the communicator that sent the message :param msg: the message """ - # pylint: disable=unused-argument + self.logger.debug( "Process<%s>: received broadcast message '%s' with communicator '%s': %r", self.pid, subject, _comm, body ) @@ -1044,7 +1043,7 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable return self._do_pause(msg) def _do_pause(self, state_msg: Optional[str], next_state: 
Optional[process_states.State] = None) -> bool: - """ Carry out the pause procedure, optionally transitioning to the next state first""" + """Carry out the pause procedure, optionally transitioning to the next state first""" try: if next_state is not None: self.transition_to(next_state) @@ -1091,7 +1090,7 @@ def _set_interrupt_action(self, new_action: Optional[futures.CancellableAction]) self._interrupt_action = new_action def _set_interrupt_action_from_exception(self, interrupt_exception: process_states.Interruption) -> None: - """ Set an interrupt action from the corresponding interrupt exception """ + """Set an interrupt action from the corresponding interrupt exception""" action = self._create_interrupt_action(interrupt_exception) self._set_interrupt_action(action) @@ -1233,9 +1232,9 @@ async def step(self) -> None: else: self._set_interrupt_action_from_exception(exception) - except KeyboardInterrupt: # pylint: disable=try-except-raise + except KeyboardInterrupt: raise - except Exception: # pylint: disable=broad-except + except Exception: # Overwrite the next state to go to excepted directly next_state = self.create_state(process_states.ProcessState.EXCEPTED, *sys.exc_info()[1:]) self._set_interrupt_action(None) @@ -1285,7 +1284,7 @@ def out(self, output_port: str, value: Any) -> None: if namespace: port_namespace = cast( ports.PortNamespace, - self.spec().outputs.get_port(namespace_separator.join(namespace), create_dynamically=True) + self.spec().outputs.get_port(namespace_separator.join(namespace), create_dynamically=True), ) else: port_namespace = self.spec().outputs @@ -1341,9 +1340,11 @@ def get_status_info(self, out_status_info: dict) -> None: :param out_status_info: the old status """ - out_status_info.update({ - 'ctime': self.creation_time, - 'paused': self.paused, - 'process_string': str(self), - 'state': str(self.state), - }) + out_status_info.update( + { + 'ctime': self.creation_time, + 'paused': self.paused, + 'process_string': str(self), + 'state': 
str(self.state), + } + ) diff --git a/src/plumpy/settings.py b/src/plumpy/settings.py index 8a136dea..e863311c 100644 --- a/src/plumpy/settings.py +++ b/src/plumpy/settings.py @@ -1,4 +1,3 @@ # -*- coding: utf-8 -*- -# pylint: disable=invalid-name check_protected: bool = False check_override: bool = False diff --git a/src/plumpy/utils.py b/src/plumpy/utils.py index 4eab8efe..36d76bbd 100644 --- a/src/plumpy/utils.py +++ b/src/plumpy/utils.py @@ -1,30 +1,36 @@ # -*- coding: utf-8 -*- import asyncio -from collections import deque -from collections.abc import Mapping import functools import importlib import inspect import logging import types -from typing import Set # pylint: disable=unused-import -from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterator, List, MutableMapping, Optional, Tuple, Type +from collections import deque +from collections.abc import Mapping +from typing import ( + Any, + Callable, + Hashable, + Iterator, + List, + MutableMapping, + Optional, + Tuple, + Type, +) from . 
import lang from .settings import check_override, check_protected -if TYPE_CHECKING: - from .process_listener import ProcessListener # pylint: disable=cyclic-import - __all__ = ['AttributesDict'] -protected = lang.protected(check=check_protected) # pylint: disable=invalid-name -override = lang.override(check=check_override) # pylint: disable=invalid-name +protected = lang.protected(check=check_protected) +override = lang.override(check=check_override) _LOGGER = logging.getLogger(__name__) -SAVED_STATE_TYPE = MutableMapping[str, Any] # pylint: disable=invalid-name -PID_TYPE = Hashable # pylint: disable=invalid-name +SAVED_STATE_TYPE = MutableMapping[str, Any] +PID_TYPE = Hashable class Frozendict(Mapping): @@ -67,7 +73,6 @@ def __hash__(self) -> int: class AttributesFrozendict(Frozendict): - def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) self._initialised: bool = True @@ -130,7 +135,7 @@ def load_function(name: str, instance: Optional[Any] = None) -> Callable[..., An obj = load_object(name) if inspect.ismethod(obj): if instance is not None: - return obj.__get__(instance, instance.__class__) # type: ignore[attr-defined] # pylint: disable=unnecessary-dunder-call + return obj.__get__(instance, instance.__class__) # type: ignore[attr-defined] return obj diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 90e35482..748a44d7 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -1,28 +1,43 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + import abc import asyncio import collections import inspect import logging import re -from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Type, Union, cast +from typing import ( + Any, + Callable, + Dict, + Hashable, + List, + Mapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) import kiwipy from . 
import lang, mixins, persistence, process_states, processes from .utils import PID_TYPE, SAVED_STATE_TYPE -__all__ = ['WorkChain', 'if_', 'while_', 'return_', 'ToContext', 'WorkChainSpec'] +__all__ = ['ToContext', 'WorkChain', 'WorkChainSpec', 'if_', 'return_', 'while_'] ToContext = dict -PREDICATE_TYPE = Callable[['WorkChain'], bool] # pylint: disable=invalid-name -WC_COMMAND_TYPE = Callable[['WorkChain'], Any] # pylint: disable=invalid-name -EXIT_CODE_TYPE = int # pylint: disable=invalid-name +PREDICATE_TYPE = Callable[['WorkChain'], bool] +WC_COMMAND_TYPE = Callable[['WorkChain'], Any] +EXIT_CODE_TYPE = int class WorkChainSpec(processes.ProcessSpec): - def __init__(self) -> None: super().__init__() self._outline: Optional[Union['_Instruction', '_FunctionCall']] = None @@ -55,21 +70,20 @@ def get_outline(self) -> Union['_Instruction', '_FunctionCall']: @persistence.auto_persist('_awaiting') class Waiting(process_states.Waiting): - """ Overwrite the waiting state""" + """Overwrite the waiting state""" def __init__( self, process: 'WorkChain', done_callback: Optional[Callable[..., Any]], msg: Optional[str] = None, - awaiting: Optional[Dict[Union[asyncio.Future, processes.Process], str]] = None + awaiting: Optional[Dict[Union[asyncio.Future, processes.Process], str]] = None, ) -> None: super().__init__(process, done_callback, msg, awaiting) self._awaiting: Dict[asyncio.Future, str] = {} for awaitable, key in (awaiting or {}).items(): - if isinstance(awaitable, processes.Process): - awaitable = awaitable.future() - self._awaiting[awaitable] = key + resolved_awaitable = awaitable.future() if isinstance(awaitable, processes.Process) else awaitable + self._awaiting[resolved_awaitable] = key def enter(self) -> None: super().enter() @@ -85,7 +99,7 @@ def _awaitable_done(self, awaitable: asyncio.Future) -> None: key = self._awaiting.pop(awaitable) try: self.process.ctx[key] = awaitable.result() # type: ignore - except Exception as exception: # pylint: disable=broad-except 
+ except Exception as exception: self._waiting_future.set_exception(exception) else: if not self._awaiting: @@ -97,6 +111,7 @@ class WorkChain(mixins.ContextMixin, processes.Process): A WorkChain is a series of instructions carried out with the ability to save state in between. """ + _spec_class = WorkChainSpec _STEPPER_STATE = 'stepper_state' _CONTEXT = 'CONTEXT' @@ -113,7 +128,7 @@ def __init__( pid: Optional[PID_TYPE] = None, logger: Optional[logging.Logger] = None, loop: Optional[asyncio.AbstractEventLoop] = None, - communicator: Optional[kiwipy.Communicator] = None + communicator: Optional[kiwipy.Communicator] = None, ) -> None: super().__init__(inputs=inputs, pid=pid, logger=logger, loop=loop, communicator=communicator) self._stepper: Optional[Stepper] = None @@ -152,9 +167,9 @@ def to_context(self, **kwargs: Union[asyncio.Future, processes.Process]) -> None to the corresponding key in the context of the workchain """ for key, awaitable in kwargs.items(): - if isinstance(awaitable, processes.Process): - awaitable = awaitable.future() - self._awaitables[awaitable] = key + resolved_awaitable = awaitable.future() if isinstance(awaitable, processes.Process) else awaitable + + self._awaitables[resolved_awaitable] = key def run(self) -> Any: return self._do_step() @@ -169,7 +184,6 @@ def _do_step(self) -> Any: finished, return_value = True, exception.exit_code if not finished and (return_value is None or isinstance(return_value, ToContext)): - if isinstance(return_value, ToContext): self.to_context(**return_value) @@ -182,7 +196,6 @@ def _do_step(self) -> Any: class Stepper(persistence.Savable, metaclass=abc.ABCMeta): - def __init__(self, workchain: 'WorkChain') -> None: self._workchain = workchain @@ -210,11 +223,11 @@ class _Instruction(metaclass=abc.ABCMeta): @abc.abstractmethod def create_stepper(self, workchain: 'WorkChain') -> Stepper: - """ Create a new stepper for this instruction """ + """Create a new stepper for this instruction""" @abc.abstractmethod def 
recreate_stepper(self, saved_state: SAVED_STATE_TYPE, workchain: 'WorkChain') -> Stepper: - """ Recreate a stepper from a previously saved state """ + """Recreate a stepper from a previously saved state""" def __str__(self) -> str: return str(self.get_description()) @@ -229,7 +242,6 @@ def get_description(self) -> Any: class _FunctionStepper(Stepper): - def __init__(self, workchain: 'WorkChain', fn: WC_COMMAND_TYPE): super().__init__(workchain) self._fn = fn @@ -250,7 +262,6 @@ def __str__(self) -> str: class _FunctionCall(_Instruction): - def __init__(self, func: WC_COMMAND_TYPE) -> None: try: args = inspect.getfullargspec(func)[0] @@ -282,7 +293,6 @@ def get_description(self) -> str: @persistence.auto_persist('_pos') class _BlockStepper(Stepper): - def __init__(self, block: Sequence[_Instruction], workchain: 'WorkChain') -> None: super().__init__(workchain) self._block = block @@ -333,14 +343,15 @@ class _Block(_Instruction, collections.abc.Sequence): def __init__(self, instructions: Sequence[Union[_Instruction, WC_COMMAND_TYPE]]) -> None: # Build up the list of commands - comms = [] + comms: MutableSequence[_Instruction | _FunctionCall] = [] for instruction in instructions: if not isinstance(instruction, _Instruction): # Assume it's a function call - instruction = _FunctionCall(instruction) + comms.append(_FunctionCall(instruction)) + else: + comms.append(instruction) - comms.append(instruction) - self._instruction: List[Union[_Instruction, _FunctionCall]] = comms + self._instruction: MutableSequence[_Instruction | _FunctionCall] = comms def __getitem__(self, index: int) -> Union[_Instruction, _FunctionCall]: # type: ignore return self._instruction[index] @@ -392,10 +403,12 @@ def is_true(self, workflow: 'WorkChain') -> bool: if not hasattr(result, '__bool__'): import warnings + warnings.warn( f'The conditional predicate `{self._predicate.__name__}` returned `{result}` which is not boolean-like.' 
' The return value should be `True` or `False` or implement the `__bool__` method. This behavior is ' - 'deprecated and will soon start raising an exception.', UserWarning + 'deprecated and will soon start raising an exception.', + UserWarning, ) return result @@ -411,7 +424,6 @@ def __str__(self) -> str: @persistence.auto_persist('_pos') class _IfStepper(Stepper): - def __init__(self, if_instruction: '_If', workchain: 'WorkChain') -> None: super().__init__(workchain) self._if_instruction = if_instruction @@ -467,7 +479,6 @@ def __str__(self) -> str: class _If(_Instruction, collections.abc.Sequence): - def __init__(self, condition: PREDICATE_TYPE) -> None: super().__init__() self._ifs: List[_Conditional] = [_Conditional(self, condition, label=if_.__name__)] @@ -520,7 +531,6 @@ def get_description(self) -> Mapping[str, Any]: class _WhileStepper(Stepper): - def __init__(self, while_instruction: '_While', workchain: 'WorkChain') -> None: super().__init__(workchain) self._while_instruction = while_instruction @@ -563,7 +573,6 @@ def __str__(self) -> str: class _While(_Conditional, _Instruction, collections.abc.Sequence): - def __init__(self, predicate: PREDICATE_TYPE) -> None: super().__init__(self, predicate, label=while_.__name__) @@ -586,14 +595,12 @@ def get_description(self) -> Dict[str, Any]: class _PropagateReturn(BaseException): - def __init__(self, exit_code: Optional[EXIT_CODE_TYPE]) -> None: super().__init__() self.exit_code = exit_code class _ReturnStepper(Stepper): - def __init__(self, return_instruction: '_Return', workchain: 'WorkChain') -> None: super().__init__(workchain) self._return_instruction = return_instruction @@ -603,7 +610,7 @@ def step(self) -> Tuple[bool, Any]: Raise a _PropagateReturn exception where the value is the exit code set in the _Return instruction upon instantiation """ - raise _PropagateReturn(self._return_instruction._exit_code) # pylint: disable=protected-access + raise _PropagateReturn(self._return_instruction._exit_code) 
class _Return(_Instruction): @@ -670,7 +677,7 @@ def while_(condition: PREDICATE_TYPE) -> _While: return _While(condition) -return_ = _Return() # pylint: disable=invalid-name +return_ = _Return() """ A global singleton that contains a Return instruction that allows to exit out of the workchain outline directly with None as exit code diff --git a/test/__init__.py b/tests/__init__.py similarity index 100% rename from test/__init__.py rename to tests/__init__.py diff --git a/test/base/__init__.py b/tests/base/__init__.py similarity index 100% rename from test/base/__init__.py rename to tests/base/__init__.py diff --git a/test/base/test_statemachine.py b/tests/base/test_statemachine.py similarity index 90% rename from test/base/test_statemachine.py rename to tests/base/test_statemachine.py index 72fed261..5b4b73d8 100644 --- a/test/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -25,7 +25,7 @@ def __init__(self, player, track): super().__init__(player) self.track = track self._last_time = None - self._played = 0. 
+ self._played = 0.0 def __str__(self): if self.in_state: @@ -55,8 +55,7 @@ class Paused(state_machine.State): TRANSITIONS = {STOP: STOPPED} def __init__(self, player, playing_state): - assert isinstance(playing_state, Playing), \ - 'Must provide the playing state to pause' + assert isinstance(playing_state, Playing), 'Must provide the playing state to pause' super().__init__(player) self.playing_state = playing_state @@ -65,7 +64,7 @@ def __str__(self): def play(self, track=None): if track is not None: - self.state_machine.transition_to(Playing, track) + self.state_machine.transition_to(Playing, track=track) else: self.state_machine.transition_to(self.playing_state) @@ -81,7 +80,7 @@ def __str__(self): return '[]' def play(self, track): - self.state_machine.transition_to(Playing, track) + self.state_machine.transition_to(Playing, track=track) class CdPlayer(state_machine.StateMachine): @@ -108,7 +107,7 @@ def play(self, track=None): @state_machine.event(from_states=Playing, to_states=Paused) def pause(self): - self.transition_to(Paused, self._state) + self.transition_to(Paused, playing_state=self._state) return True @state_machine.event(from_states=(Playing, Paused), to_states=Stopped) @@ -117,14 +116,13 @@ def stop(self): class TestStateMachine(unittest.TestCase): - def test_basic(self): cd_player = CdPlayer() self.assertEqual(cd_player.state, STOPPED) cd_player.play('Eminem - The Real Slim Shady') self.assertEqual(cd_player.state, PLAYING) - time.sleep(1.) 
+ time.sleep(1.0) cd_player.pause() self.assertEqual(cd_player.state, PAUSED) diff --git a/test/base/test_utils.py b/tests/base/test_utils.py similarity index 99% rename from test/base/test_utils.py rename to tests/base/test_utils.py index 9aa0237b..d62b1422 100644 --- a/test/base/test_utils.py +++ b/tests/base/test_utils.py @@ -5,7 +5,6 @@ class Root: - @utils.super_check def method(self): pass @@ -15,19 +14,16 @@ def do(self): class DoCall(Root): - def method(self): super().method() class DontCall(Root): - def method(self): pass class TestSuperCheckMixin(unittest.TestCase): - def test_do_call(self): DoCall().do() @@ -36,9 +32,7 @@ def test_dont_call(self): DontCall().do() def dont_call_middle(self): - class ThirdChild(DontCall): - def method(self): super().method() diff --git a/test/conftest.py b/tests/conftest.py similarity index 99% rename from test/conftest.py rename to tests/conftest.py index 43555586..c70088fa 100644 --- a/test/conftest.py +++ b/tests/conftest.py @@ -5,4 +5,5 @@ @pytest.fixture(scope='session') def set_event_loop_policy(): from plumpy import set_event_loop_policy + set_event_loop_policy() diff --git a/test/notebooks/get_event_loop.ipynb b/tests/notebooks/get_event_loop.ipynb similarity index 90% rename from test/notebooks/get_event_loop.ipynb rename to tests/notebooks/get_event_loop.ipynb index 6aa4fbd5..860ca3d2 100644 --- a/test/notebooks/get_event_loop.ipynb +++ b/tests/notebooks/get_event_loop.ipynb @@ -7,7 +7,9 @@ "outputs": [], "source": [ "import asyncio\n", - "from plumpy import set_event_loop_policy, PlumpyEventLoopPolicy\n", + "\n", + "from plumpy import PlumpyEventLoopPolicy, set_event_loop_policy\n", + "\n", "set_event_loop_policy()\n", "assert isinstance(asyncio.get_event_loop_policy(), PlumpyEventLoopPolicy)\n", "assert hasattr(asyncio.get_event_loop(), '_nest_patched')" diff --git a/test/persistence/__init__.py b/tests/persistence/__init__.py similarity index 100% rename from test/persistence/__init__.py rename to 
tests/persistence/__init__.py diff --git a/test/persistence/test_inmemory.py b/tests/persistence/test_inmemory.py similarity index 88% rename from test/persistence/test_inmemory.py rename to tests/persistence/test_inmemory.py index bc03f88b..b0db46e7 100644 --- a/test/persistence/test_inmemory.py +++ b/tests/persistence/test_inmemory.py @@ -1,31 +1,26 @@ # -*- coding: utf-8 -*- -import asyncio -from test.utils import ProcessWithCheckpoint import unittest +from ..utils import ProcessWithCheckpoint + import plumpy +import plumpy -class TestInMemoryPersister(unittest.TestCase): +class TestInMemoryPersister(unittest.TestCase): def test_save_load_roundtrip(self): """ Test the plumpy.PicklePersister by taking a dummpy process, saving a checkpoint and recreating it from the same checkpoint """ - loop = asyncio.get_event_loop() process = ProcessWithCheckpoint() persister = plumpy.InMemoryPersister() persister.save_checkpoint(process) - bundle = persister.load_checkpoint(process.pid) - load_context = plumpy.LoadSaveContext(loop=loop) - recreated = bundle.unbundle(load_context) - def test_get_checkpoints_without_tags(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() @@ -43,8 +38,7 @@ def test_get_checkpoints_without_tags(self): self.assertSetEqual(set(retrieved_checkpoints), set(checkpoints)) def test_get_checkpoints_with_tags(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() tag_a = 'tag_a' @@ -64,15 +58,12 @@ def test_get_checkpoints_with_tags(self): self.assertSetEqual(set(retrieved_checkpoints), set(checkpoints)) def test_get_process_checkpoints(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() checkpoint_a1 = plumpy.PersistedCheckpoint(process_a.pid, '1') checkpoint_a2 = plumpy.PersistedCheckpoint(process_a.pid, '2') - checkpoint_b1 = plumpy.PersistedCheckpoint(process_b.pid, '1') - checkpoint_b2 = 
plumpy.PersistedCheckpoint(process_b.pid, '2') checkpoints = [checkpoint_a1, checkpoint_a2] @@ -87,15 +78,12 @@ def test_get_process_checkpoints(self): self.assertSetEqual(set(retrieved_checkpoints), set(checkpoints)) def test_delete_process_checkpoints(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() checkpoint_a1 = plumpy.PersistedCheckpoint(process_a.pid, '1') checkpoint_a2 = plumpy.PersistedCheckpoint(process_a.pid, '2') - checkpoint_b1 = plumpy.PersistedCheckpoint(process_b.pid, '1') - checkpoint_b2 = plumpy.PersistedCheckpoint(process_b.pid, '2') persister = plumpy.InMemoryPersister() persister.save_checkpoint(process_a, tag='1') @@ -116,8 +104,7 @@ def test_delete_process_checkpoints(self): self.assertSetEqual(set(retrieved_checkpoints), set(checkpoints)) def test_delete_checkpoint(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() diff --git a/test/persistence/test_pickle.py b/tests/persistence/test_pickle.py similarity index 89% rename from test/persistence/test_pickle.py rename to tests/persistence/test_pickle.py index 19e4f52a..dd68b4fd 100644 --- a/test/persistence/test_pickle.py +++ b/tests/persistence/test_pickle.py @@ -1,37 +1,29 @@ # -*- coding: utf-8 -*- -import asyncio import tempfile import unittest if getattr(tempfile, 'TemporaryDirectory', None) is None: from backports import tempfile -from test.utils import ProcessWithCheckpoint +from ..utils import ProcessWithCheckpoint import plumpy class TestPicklePersister(unittest.TestCase): - def test_save_load_roundtrip(self): """ Test the plumpy.PicklePersister by taking a dummpy process, saving a checkpoint and recreating it from the same checkpoint """ - loop = asyncio.get_event_loop() process = ProcessWithCheckpoint() with tempfile.TemporaryDirectory() as directory: persister = plumpy.PicklePersister(directory) persister.save_checkpoint(process) - bundle = persister.load_checkpoint(process.pid) - 
load_context = plumpy.LoadSaveContext(loop=loop) - recreated = bundle.unbundle(load_context) - def test_get_checkpoints_without_tags(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() @@ -50,8 +42,7 @@ def test_get_checkpoints_without_tags(self): self.assertSetEqual(set(retrieved_checkpoints), set(checkpoints)) def test_get_checkpoints_with_tags(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() tag_a = 'tag_a' @@ -72,15 +63,12 @@ def test_get_checkpoints_with_tags(self): self.assertSetEqual(set(retrieved_checkpoints), set(checkpoints)) def test_get_process_checkpoints(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() checkpoint_a1 = plumpy.PersistedCheckpoint(process_a.pid, '1') checkpoint_a2 = plumpy.PersistedCheckpoint(process_a.pid, '2') - checkpoint_b1 = plumpy.PersistedCheckpoint(process_b.pid, '1') - checkpoint_b2 = plumpy.PersistedCheckpoint(process_b.pid, '2') checkpoints = [checkpoint_a1, checkpoint_a2] @@ -96,15 +84,12 @@ def test_get_process_checkpoints(self): self.assertSetEqual(set(retrieved_checkpoints), set(checkpoints)) def test_delete_process_checkpoints(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() checkpoint_a1 = plumpy.PersistedCheckpoint(process_a.pid, '1') checkpoint_a2 = plumpy.PersistedCheckpoint(process_a.pid, '2') - checkpoint_b1 = plumpy.PersistedCheckpoint(process_b.pid, '1') - checkpoint_b2 = plumpy.PersistedCheckpoint(process_b.pid, '2') with tempfile.TemporaryDirectory() as directory: persister = plumpy.PicklePersister(directory) @@ -126,8 +111,7 @@ def test_delete_process_checkpoints(self): self.assertSetEqual(set(retrieved_checkpoints), set(checkpoints)) def test_delete_checkpoint(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() diff --git a/test/rmq/__init__.py 
b/tests/rmq/__init__.py similarity index 100% rename from test/rmq/__init__.py rename to tests/rmq/__init__.py diff --git a/test/rmq/docker-compose.yml b/tests/rmq/docker-compose.yml similarity index 100% rename from test/rmq/docker-compose.yml rename to tests/rmq/docker-compose.yml diff --git a/test/rmq/test_communicator.py b/tests/rmq/test_communicator.py similarity index 90% rename from test/rmq/test_communicator.py rename to tests/rmq/test_communicator.py index 5cedd38d..3f2570d8 100644 --- a/test/rmq/test_communicator.py +++ b/tests/rmq/test_communicator.py @@ -1,15 +1,16 @@ # -*- coding: utf-8 -*- """Tests for the :mod:`plumpy.rmq.communicator` module.""" + import asyncio import functools import shutil import tempfile import uuid -from kiwipy import BroadcastFilter, rmq import pytest import shortuuid import yaml +from kiwipy import BroadcastFilter, rmq import plumpy from plumpy import communications, process_comms @@ -38,7 +39,7 @@ def loop_communicator(): message_exchange=message_exchange, task_exchange=task_exchange, task_queue=task_queue, - decoder=functools.partial(yaml.load, Loader=yaml.Loader) + decoder=functools.partial(yaml.load, Loader=yaml.Loader), ) loop = asyncio.get_event_loop() @@ -61,7 +62,7 @@ class TestLoopCommunicator: @pytest.mark.asyncio async def test_broadcast(self, loop_communicator): - BROADCAST = {'body': 'present', 'sender': 'Martin', 'subject': 'sup', 'correlation_id': 420} + BROADCAST = {'body': 'present', 'sender': 'Martin', 'subject': 'sup', 'correlation_id': 420} # noqa: N806 broadcast_future = plumpy.Future() loop = asyncio.get_event_loop() @@ -69,12 +70,9 @@ async def test_broadcast(self, loop_communicator): def get_broadcast(_comm, body, sender, subject, correlation_id): assert loop is asyncio.get_event_loop() - broadcast_future.set_result({ - 'body': body, - 'sender': sender, - 'subject': subject, - 'correlation_id': correlation_id - }) + broadcast_future.set_result( + {'body': body, 'sender': sender, 'subject': subject, 
'correlation_id': correlation_id} + ) loop_communicator.add_broadcast_subscriber(get_broadcast) loop_communicator.broadcast_send(**BROADCAST) @@ -84,11 +82,8 @@ def get_broadcast(_comm, body, sender, subject, correlation_id): @pytest.mark.asyncio async def test_broadcast_filter(self, loop_communicator): - broadcast_future = plumpy.Future() - loop = asyncio.get_event_loop() - def ignore_broadcast(_comm, body, sender, subject, correlation_id): broadcast_future.set_exception(AssertionError('broadcast received')) @@ -98,12 +93,7 @@ def get_broadcast(_comm, body, sender, subject, correlation_id): loop_communicator.add_broadcast_subscriber(BroadcastFilter(ignore_broadcast, subject='other')) loop_communicator.add_broadcast_subscriber(get_broadcast) loop_communicator.broadcast_send( - **{ - 'body': 'present', - 'sender': 'Martin', - 'subject': 'sup', - 'correlation_id': 420 - } + **{'body': 'present', 'sender': 'Martin', 'subject': 'sup', 'correlation_id': 420} ) result = await broadcast_future @@ -111,7 +101,7 @@ def get_broadcast(_comm, body, sender, subject, correlation_id): @pytest.mark.asyncio async def test_rpc(self, loop_communicator): - MSG = 'rpc this' + MSG = 'rpc this' # noqa: N806 rpc_future = plumpy.Future() loop = asyncio.get_event_loop() @@ -128,7 +118,7 @@ def get_rpc(_comm, msg): @pytest.mark.asyncio async def test_task(self, loop_communicator): - TASK = 'task this' + TASK = 'task this' # noqa: N806 task_future = plumpy.Future() loop = asyncio.get_event_loop() @@ -145,7 +135,6 @@ def get_task(_comm, msg): class TestTaskActions: - @pytest.mark.asyncio async def test_launch(self, loop_communicator, async_controller, persister): # Let the process run to the end @@ -157,7 +146,7 @@ async def test_launch(self, loop_communicator, async_controller, persister): @pytest.mark.asyncio async def test_launch_nowait(self, loop_communicator, async_controller, persister): - """ Testing launching but don't wait, just get the pid """ + """Testing launching but don't wait, 
just get the pid""" loop = asyncio.get_event_loop() loop_communicator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) pid = await async_controller.launch_process(utils.DummyProcess, nowait=True) @@ -165,7 +154,7 @@ async def test_launch_nowait(self, loop_communicator, async_controller, persiste @pytest.mark.asyncio async def test_execute_action(self, loop_communicator, async_controller, persister): - """ Test the process execute action """ + """Test the process execute action""" loop = asyncio.get_event_loop() loop_communicator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) result = await async_controller.execute_process(utils.DummyProcessWithOutput) @@ -173,7 +162,7 @@ async def test_execute_action(self, loop_communicator, async_controller, persist @pytest.mark.asyncio async def test_execute_action_nowait(self, loop_communicator, async_controller, persister): - """ Test the process execute action """ + """Test the process execute action""" loop = asyncio.get_event_loop() loop_communicator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) pid = await async_controller.execute_process(utils.DummyProcessWithOutput, nowait=True) @@ -197,7 +186,7 @@ async def test_launch_many(self, loop_communicator, async_controller, persister) @pytest.mark.asyncio async def test_continue(self, loop_communicator, async_controller, persister): - """ Test continuing a saved process """ + """Test continuing a saved process""" loop = asyncio.get_event_loop() loop_communicator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) process = utils.DummyProcessWithOutput() diff --git a/test/rmq/test_process_comms.py b/tests/rmq/test_process_comms.py similarity index 97% rename from test/rmq/test_process_comms.py rename to tests/rmq/test_process_comms.py index 6afccf46..7223b888 100644 --- a/test/rmq/test_process_comms.py +++ b/tests/rmq/test_process_comms.py @@ -1,14 +1,15 @@ # -*- coding: utf-8 -*- import 
asyncio +import copy import kiwipy -from kiwipy import rmq import pytest import shortuuid +from kiwipy import rmq import plumpy -from plumpy import process_comms import plumpy.communications +from plumpy import process_comms from .. import utils @@ -43,7 +44,6 @@ def sync_controller(thread_communicator: rmq.RmqThreadCommunicator): class TestRemoteProcessController: - @pytest.mark.asyncio async def test_pause(self, thread_communicator, async_controller): proc = utils.WaitForSignalProcess(communicator=thread_communicator) @@ -122,7 +122,6 @@ def on_broadcast_receive(**msg): class TestRemoteProcessThreadController: - @pytest.mark.asyncio async def test_pause(self, thread_communicator, sync_controller): proc = utils.WaitForSignalProcess(communicator=thread_communicator) @@ -197,7 +196,10 @@ async def test_kill_all(self, thread_communicator, sync_controller): for _ in range(10): procs.append(utils.WaitForSignalProcess(communicator=thread_communicator)) - sync_controller.kill_all('bang bang, I shot you down') + msg = copy.copy(process_comms.KILL_MSG) + msg[process_comms.MESSAGE_KEY] = 'bang bang, I shot you down' + + sync_controller.kill_all(msg) await utils.wait_util(lambda: all([proc.killed() for proc in procs])) assert all([proc.state == plumpy.ProcessState.KILLED for proc in procs]) diff --git a/test/test_communications.py b/tests/test_communications.py similarity index 100% rename from test/test_communications.py rename to tests/test_communications.py index f82036bd..f7e04255 100644 --- a/test/test_communications.py +++ b/tests/test_communications.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- """Tests for the :mod:`plumpy.communications` module.""" -from kiwipy import CommunicatorHelper + import pytest +from kiwipy import CommunicatorHelper from plumpy.communications import LoopCommunicator @@ -14,7 +15,6 @@ def __call__(self): class Communicator(CommunicatorHelper): - def task_send(self, task, no_reply=False): pass diff --git a/test/test_events.py 
b/tests/test_events.py similarity index 99% rename from test/test_events.py rename to tests/test_events.py index e6260f1d..964bd6f7 100644 --- a/test/test_events.py +++ b/tests/test_events.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """Tests for the :mod:`plumpy.events` module.""" + import asyncio import pathlib diff --git a/test/test_expose.py b/tests/test_expose.py similarity index 84% rename from test/test_expose.py rename to tests/test_expose.py index 1a495727..0f6f8087 100644 --- a/test/test_expose.py +++ b/tests/test_expose.py @@ -1,56 +1,50 @@ # -*- coding: utf-8 -*- -from test.utils import NewLoopProcess import unittest +from .utils import NewLoopProcess + from plumpy.ports import PortNamespace from plumpy.process_spec import ProcessSpec from plumpy.processes import Process -class TestExposeProcess(unittest.TestCase): +def validator_function(input, port): + pass - def setUp(self): - super().setUp() - def validator_function(input, port): - pass +class BaseNamespaceProcess(NewLoopProcess): + @classmethod + def define(cls, spec): + super().define(spec) + spec.input('top') + spec.input('namespace.sub_one') + spec.input('namespace.sub_two') + spec.inputs['namespace'].valid_type = (int, float) + spec.inputs['namespace'].validator = validator_function - class BaseNamespaceProcess(NewLoopProcess): - @classmethod - def define(cls, spec): - super().define(spec) - spec.input('top') - spec.input('namespace.sub_one') - spec.input('namespace.sub_two') - spec.inputs['namespace'].valid_type = (int, float) - spec.inputs['namespace'].validator = validator_function +class BaseProcess(NewLoopProcess): + @classmethod + def define(cls, spec): + super().define(spec) + spec.input('a', valid_type=str, default='a') + spec.input('b', valid_type=str, default='b') + spec.inputs.dynamic = True + spec.inputs.valid_type = str - class BaseProcess(NewLoopProcess): - @classmethod - def define(cls, spec): - super().define(spec) - spec.input('a', valid_type=str, default='a') - spec.input('b', 
valid_type=str, default='b') - spec.inputs.dynamic = True - spec.inputs.valid_type = str +class ExposeProcess(NewLoopProcess): + @classmethod + def define(cls, spec): + super().define(spec) + spec.expose_inputs(BaseProcess, namespace='base.name.space') + spec.input('c', valid_type=int, default=1) + spec.input('d', valid_type=int, default=2) + spec.inputs.dynamic = True + spec.inputs.valid_type = int - class ExposeProcess(NewLoopProcess): - - @classmethod - def define(cls, spec): - super().define(spec) - spec.expose_inputs(BaseProcess, namespace='base.name.space') - spec.input('c', valid_type=int, default=1) - spec.input('d', valid_type=int, default=2) - spec.inputs.dynamic = True - spec.inputs.valid_type = int - - self.BaseNamespaceProcess = BaseNamespaceProcess - self.BaseProcess = BaseProcess - self.ExposeProcess = ExposeProcess +class TestExposeProcess(unittest.TestCase): def check_ports(self, process, namespace, expected_port_names): """Check the port namespace of a given process inputs spec for existence of set of expected port names.""" port_namespace = process.spec().inputs @@ -68,24 +62,21 @@ def check_namespace_properties(self, process_left, namespace_left, process_right port_namespace_left = process_left.spec().inputs.get_port(namespace_left) port_namespace_right = process_right.spec().inputs.get_port(namespace_right) - # Pop the ports in stored in the `_ports` attribute - port_namespace_left.__dict__.pop('_ports', None) - port_namespace_right.__dict__.pop('_ports', None) + left_dict = {k: v for k, v in port_namespace_left.__dict__.items() if k != '_ports'} + right_dict = {k: v for k, v in port_namespace_right.__dict__.items() if k != '_ports'} - self.assertEqual(port_namespace_left.__dict__, port_namespace_right.__dict__) + self.assertEqual(left_dict, right_dict) def test_expose_dynamic(self): """Test that exposing a dynamic namespace remains dynamic.""" class Lower(Process): - @classmethod def define(cls, spec): super(Lower, cls).define(spec) 
spec.input_namespace('foo', dynamic=True) class Upper(Process): - @classmethod def define(cls, spec): super(Upper, cls).define(spec) @@ -96,7 +87,7 @@ def define(cls, spec): def test_expose_nested_namespace(self): """Test that expose_inputs can create nested namespaces while maintaining own ports.""" - inputs = self.ExposeProcess.spec().inputs + inputs = ExposeProcess.spec().inputs # Verify that the nested namespaces are present self.assertTrue('base' in inputs) @@ -116,7 +107,7 @@ def test_expose_nested_namespace(self): def test_expose_ports(self): """Test that the exposed ports are present and properly deepcopied.""" - exposed_inputs = self.ExposeProcess.spec().inputs.get_port('base.name.space') + exposed_inputs = ExposeProcess.spec().inputs.get_port('base.name.space') self.assertEqual(len(exposed_inputs), 2) self.assertTrue('a' in exposed_inputs) @@ -125,32 +116,30 @@ def test_expose_ports(self): self.assertEqual(exposed_inputs['b'].default, 'b') # Change the default of base process port and verify they don't change the exposed port - self.BaseProcess.spec().inputs['a'].default = 'c' - self.assertEqual(self.BaseProcess.spec().inputs['a'].default, 'c') + BaseProcess.spec().inputs['a'].default = 'c' + self.assertEqual(BaseProcess.spec().inputs['a'].default, 'c') self.assertEqual(exposed_inputs['a'].default, 'a') def test_expose_attributes(self): """Test that the attributes of the exposed PortNamespace are maintained and properly deepcopied.""" - inputs = self.ExposeProcess.spec().inputs - exposed_inputs = self.ExposeProcess.spec().inputs.get_port('base.name.space') + inputs = ExposeProcess.spec().inputs + exposed_inputs = ExposeProcess.spec().inputs.get_port('base.name.space') - self.assertEqual(str, self.BaseProcess.spec().inputs.valid_type) + self.assertEqual(str, BaseProcess.spec().inputs.valid_type) self.assertEqual(str, exposed_inputs.valid_type) self.assertEqual(int, inputs.valid_type) # Now change the valid type of the BaseProcess inputs and verify it does 
not affect ExposeProcess - self.BaseProcess.spec().inputs.valid_type = float + BaseProcess.spec().inputs.valid_type = float - self.assertEqual(self.BaseProcess.spec().inputs.valid_type, float) + self.assertEqual(BaseProcess.spec().inputs.valid_type, float) self.assertEqual(exposed_inputs.valid_type, str) self.assertEqual(inputs.valid_type, int) def test_expose_exclude(self): """Test that the exclude argument of exposed_inputs works correctly and excludes ports from being absorbed.""" - BaseProcess = self.BaseProcess class ExcludeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -165,10 +154,8 @@ def define(cls, spec): def test_expose_include(self): """Test that the include argument of exposed_inputs works correctly and includes only specified ports.""" - BaseProcess = self.BaseProcess class ExcludeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -183,10 +170,8 @@ def define(cls, spec): def test_expose_exclude_include_mutually_exclusive(self): """Test that passing both exclude and include raises.""" - BaseProcess = self.BaseProcess class ExcludeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -208,7 +193,7 @@ def validator_function(input, port): # Define child process with all mutable properties of the inputs PortNamespace to a non-default value # This way we can check if the defaults of the ParentProcessSpec will be properly overridden - ChildProcessSpec = ProcessSpec() + ChildProcessSpec = ProcessSpec() # noqa: N806 ChildProcessSpec.input('a', valid_type=int) ChildProcessSpec.input('b', valid_type=str) ChildProcessSpec.inputs.validator = validator_function @@ -218,7 +203,7 @@ def validator_function(input, port): ChildProcessSpec.inputs.default = True ChildProcessSpec.inputs.help = 'testing' - ParentProcessSpec = ProcessSpec() + ParentProcessSpec = ProcessSpec() # noqa: N806 ParentProcessSpec.input('c', valid_type=float) ParentProcessSpec._expose_ports( 
process_class=None, @@ -228,7 +213,7 @@ def validator_function(input, port): namespace=None, exclude=(), include=None, - namespace_options={} + namespace_options={}, ) # Verify that all the ports are there @@ -256,7 +241,7 @@ def validator_function(input, port): # Define child process with all mutable properties of the inputs PortNamespace to a non-default value # This way we can check if the defaults of the ParentProcessSpec will be properly overridden - ChildProcessSpec = ProcessSpec() + ChildProcessSpec = ProcessSpec() # noqa: N806 ChildProcessSpec.input('a', valid_type=int) ChildProcessSpec.input('b', valid_type=str) ChildProcessSpec.inputs.validator = validator_function @@ -266,7 +251,7 @@ def validator_function(input, port): ChildProcessSpec.inputs.default = True ChildProcessSpec.inputs.help = 'testing' - ParentProcessSpec = ProcessSpec() + ParentProcessSpec = ProcessSpec() # noqa: N806 ParentProcessSpec.input('c', valid_type=float) ParentProcessSpec._expose_ports( process_class=None, @@ -283,7 +268,7 @@ def validator_function(input, port): 'dynamic': False, 'default': None, 'help': None, - } + }, ) # Verify that all the ports are there @@ -310,7 +295,7 @@ def validator_function(input, port): # Define child process with all mutable properties of the inputs PortNamespace to a non-default value # This way we can check if the defaults of the ParentProcessSpec will be properly overridden - ChildProcessSpec = ProcessSpec() + ChildProcessSpec = ProcessSpec() # noqa: N806 ChildProcessSpec.input('a', valid_type=int) ChildProcessSpec.input('b', valid_type=str) ChildProcessSpec.inputs.validator = validator_function @@ -320,7 +305,7 @@ def validator_function(input, port): ChildProcessSpec.inputs.default = True ChildProcessSpec.inputs.help = 'testing' - ParentProcessSpec = ProcessSpec() + ParentProcessSpec = ProcessSpec() # noqa: N806 ParentProcessSpec.input('c', valid_type=float) ParentProcessSpec._expose_ports( process_class=None, @@ -330,7 +315,7 @@ def 
validator_function(input, port): namespace='namespace', exclude=(), include=None, - namespace_options={} + namespace_options={}, ) # Verify that all the ports are there @@ -351,8 +336,8 @@ def test_expose_ports_namespace_options_non_existent(self): Verify that passing non-supported PortNamespace mutable properties in namespace_options will raise a ValueError """ - ChildProcessSpec = ProcessSpec() - ParentProcessSpec = ProcessSpec() + ChildProcessSpec = ProcessSpec() # noqa: N806 + ParentProcessSpec = ProcessSpec() # noqa: N806 with self.assertRaises(ValueError): ParentProcessSpec._expose_ports( @@ -365,15 +350,13 @@ def test_expose_ports_namespace_options_non_existent(self): include=None, namespace_options={ 'non_existent': None, - } + }, ) def test_expose_nested_include_top_level(self): """Test the include rules can be nested and are properly unwrapped.""" - BaseNamespaceProcess = self.BaseNamespaceProcess class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -384,10 +367,8 @@ def define(cls, spec): def test_expose_nested_include_namespace(self): """Test the include rules can be nested and are properly unwrapped.""" - BaseNamespaceProcess = self.BaseNamespaceProcess class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -400,10 +381,8 @@ def define(cls, spec): def test_expose_nested_include_namespace_sub(self): """Test the include rules can be nested and are properly unwrapped.""" - BaseNamespaceProcess = self.BaseNamespaceProcess class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -416,10 +395,8 @@ def define(cls, spec): def test_expose_nested_include_combination(self): """Test the include rules can be nested and are properly unwrapped.""" - BaseNamespaceProcess = self.BaseNamespaceProcess class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -432,10 +409,8 @@ def define(cls, spec): 
def test_expose_nested_exclude_top_level(self): """Test the exclude rules can be nested and are properly unwrapped.""" - BaseNamespaceProcess = self.BaseNamespaceProcess class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -448,10 +423,8 @@ def define(cls, spec): def test_expose_nested_exclude_namespace(self): """Test the exclude rules can be nested and are properly unwrapped.""" - BaseNamespaceProcess = self.BaseNamespaceProcess class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -462,10 +435,8 @@ def define(cls, spec): def test_expose_nested_exclude_namespace_sub(self): """Test the exclude rules can be nested and are properly unwrapped.""" - BaseNamespaceProcess = self.BaseNamespaceProcess class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -478,10 +449,8 @@ def define(cls, spec): def test_expose_nested_exclude_combination(self): """Test the exclude rules can be nested and are properly unwrapped.""" - BaseNamespaceProcess = self.BaseNamespaceProcess class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -504,7 +473,6 @@ def test_expose_exclude_port_with_validator(self): """ class BaseProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -520,10 +488,10 @@ def validator(cls, value, ctx): return None if not isinstance(value['a'], str): - return f'value for input `a` should be a str, but got: {type(value["a"])}' + a_type = type(value['a']) + return f'value for input `a` should be a str, but got: {a_type}' class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) diff --git a/test/test_lang.py b/tests/test_lang.py similarity index 97% rename from test/test_lang.py rename to tests/test_lang.py index 13136530..a55af31a 100644 --- a/test/test_lang.py +++ b/tests/test_lang.py @@ -5,7 +5,6 @@ class A: - def 
__init__(self): self._a = None @@ -22,27 +21,24 @@ def protected_property(self): def protected_fn_nocheck(self): return self._a - def testA(self): + def testA(self): # noqa: N802 self.protected_fn() self.protected_property class B(A): - - def testB(self): + def testB(self): # noqa: N802 self.protected_fn() self.protected_property class C(B): - - def testC(self): + def testC(self): # noqa: N802 self.protected_fn() self.protected_property class TestProtected(TestCase): - def test_free_function(self): with self.assertRaises(RuntimeError): @@ -79,7 +75,6 @@ def test_incorrect_usage(self): with self.assertRaises(RuntimeError): class TestWrongDecoratorOrder: - @protected(check=True) @property def a(self): @@ -87,13 +82,11 @@ def a(self): class Superclass: - def test(self): pass class TestOverride(TestCase): - def test_free_function(self): with self.assertRaises(RuntimeError): @@ -102,9 +95,7 @@ def some_func(): pass def test_correct_usage(self): - class Derived(Superclass): - @override(check=True) def test(self): return True @@ -115,7 +106,6 @@ class Middle(Superclass): pass class Next(Middle): - @override(check=True) def test(self): return True @@ -123,9 +113,7 @@ def test(self): self.assertTrue(Next().test()) def test_incorrect_usage(self): - class Derived: - @override(check=True) def test(self): pass @@ -136,7 +124,6 @@ def test(self): with self.assertRaises(RuntimeError): class TestWrongDecoratorOrder(Superclass): - @override(check=True) @property def test(self): diff --git a/test/test_loaders.py b/tests/test_loaders.py similarity index 97% rename from test/test_loaders.py rename to tests/test_loaders.py index 3058b77c..a1813f09 100644 --- a/test/test_loaders.py +++ b/tests/test_loaders.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """Tests for the :mod:`plumpy.loaders` module.""" + import pytest import plumpy @@ -7,6 +8,7 @@ class DummyClass: """Dummy class for testing.""" + pass @@ -38,11 +40,12 @@ def test_default_object_roundtrip(): @pytest.mark.parametrize( - 
'identifier, match', ( + 'identifier, match', + ( ('plumpy.non_existing_module.SomeClass', r'identifier `.*` has an invalid format.'), ('plumpy.non_existing_module:SomeClass', r'module `.*` from identifier `.*` could not be loaded.'), ('plumpy.loaders:NonExistingClass', r'object `.*` form identifier `.*` could not be loaded.'), - ) + ), ) def test_default_object_loader_load_object_except(identifier, match): """Test the :meth:`plumpy.DefaultObjectLoader.load_object` when it is expected to raise.""" diff --git a/test/test_persistence.py b/tests/test_persistence.py similarity index 97% rename from test/test_persistence.py rename to tests/test_persistence.py index 2c9cf4f9..78724aa0 100644 --- a/test/test_persistence.py +++ b/tests/test_persistence.py @@ -15,7 +15,6 @@ class SaveEmpty(plumpy.Savable): @plumpy.auto_persist('test', 'test_method') class Save1(plumpy.Savable): - def __init__(self): self.test = 'sup yp' self.test_method = self.m @@ -26,13 +25,11 @@ def m(): @plumpy.auto_persist('test') class Save(plumpy.Savable): - def __init__(self): self.test = Save1() class TestSavable(unittest.TestCase): - def test_empty_savable(self): self._save_round_trip(SaveEmpty()) @@ -79,9 +76,8 @@ def _save_round_trip_with_loader(self, savable): class TestBundle(unittest.TestCase): - def test_bundle_load_context(self): - """ Check that the loop from the load context is used """ + """Check that the loop from the load context is used""" loop1 = asyncio.get_event_loop() proc = utils.DummyProcess(loop=loop1) bundle = plumpy.Bundle(proc) diff --git a/test/test_port.py b/tests/test_port.py similarity index 99% rename from test/test_port.py rename to tests/test_port.py index ab9b51a6..da483e81 100644 --- a/test/test_port.py +++ b/tests/test_port.py @@ -7,7 +7,6 @@ class TestPort(TestCase): - def test_required(self): spec = Port('required_value', required=True) @@ -21,7 +20,6 @@ def test_validate(self): self.assertIsNotNone(spec.validate('a')) def test_validator(self): - def 
validate(value, port): assert isinstance(port, Port) if not isinstance(value, int): @@ -45,7 +43,6 @@ def validate(value, port): class TestInputPort(TestCase): - def test_default(self): """Test the default value property for the InputPort.""" port = InputPort('test', default=5) @@ -81,12 +78,14 @@ def test_lambda_default(self): # Testing that passing an actual lambda as a value is alos possible port = InputPort('test', valid_type=(types.FunctionType, int), default=lambda: 5) - some_lambda = lambda: 'string' + + def some_lambda(): + return 'string' + self.assertIsNone(port.validate(some_lambda)) class TestOutputPort(TestCase): - def test_default(self): """ Test the default value property for the InputPort @@ -108,7 +107,6 @@ def validator(value, port): class TestPortNamespace(TestCase): - BASE_PORT_NAME = 'port' BASE_PORT_NAMESPACE_NAME = 'port' @@ -299,7 +297,7 @@ def test_port_namespace_validate(self): # Check the breadcrumbs are correct self.assertEqual( validation_error.port, - self.port_namespace.NAMESPACE_SEPARATOR.join((self.BASE_PORT_NAMESPACE_NAME, 'sub', 'space', 'output')) + self.port_namespace.NAMESPACE_SEPARATOR.join((self.BASE_PORT_NAMESPACE_NAME, 'sub', 'space', 'output')), ) def test_port_namespace_required(self): @@ -371,7 +369,9 @@ def test_port_namespace_lambda_defaults(self): self.assertIsNone(port_namespace.validate(inputs)) # When passing a lambda directly as the value, it should NOT be evaluated during pre_processing - some_lambda = lambda: 5 + def some_lambda(): + return 5 + inputs = port_namespace.pre_process({'lambda_default': some_lambda}) self.assertEqual(inputs['lambda_default'], some_lambda) self.assertIsNone(port_namespace.validate(inputs)) diff --git a/test/test_process_comms.py b/tests/test_process_comms.py similarity index 87% rename from test/test_process_comms.py rename to tests/test_process_comms.py index 6d3d335c..c59737ac 100644 --- a/test/test_process_comms.py +++ b/tests/test_process_comms.py @@ -1,23 +1,17 @@ # -*- coding: 
utf-8 -*- -import asyncio -from test import utils -import unittest - -from kiwipy import rmq import pytest +from tests import utils import plumpy -from plumpy import communications, process_comms +from plumpy import process_comms class Process(plumpy.Process): - def run(self): pass class CustomObjectLoader(plumpy.DefaultObjectLoader): - def load_object(self, identifier): if identifier == 'jimmy': return Process @@ -49,7 +43,7 @@ async def test_continue(): @pytest.mark.asyncio async def test_loader_is_used(): - """ Make sure that the provided class loader is used by the process launcher """ + """Make sure that the provided class loader is used by the process launcher""" loader = CustomObjectLoader() proc = Process() persister = plumpy.InMemoryPersister(loader=loader) diff --git a/test/test_process_spec.py b/tests/test_process_spec.py similarity index 99% rename from test/test_process_spec.py rename to tests/test_process_spec.py index 443f7a64..3be8c1d2 100644 --- a/test/test_process_spec.py +++ b/tests/test_process_spec.py @@ -10,7 +10,6 @@ class StrSubtype(str): class TestProcessSpec(TestCase): - def setUp(self): self.spec = ProcessSpec() @@ -18,7 +17,6 @@ def test_get_port_namespace_base(self): """ Get the root, inputs and outputs port namespaces of the ProcessSpec """ - ports = self.spec.ports input_ports = self.spec.inputs output_ports = self.spec.outputs diff --git a/test/test_processes.py b/tests/test_processes.py similarity index 97% rename from test/test_processes.py rename to tests/test_processes.py index 0cb4161b..faea9eae 100644 --- a/test/test_processes.py +++ b/tests/test_processes.py @@ -1,20 +1,22 @@ # -*- coding: utf-8 -*- """Process tests""" + import asyncio +import copy import enum -from test import utils import unittest import kiwipy import pytest +from tests import utils import plumpy from plumpy import BundleKeys, Process, ProcessState +from plumpy.process_comms import KILL_MSG, MESSAGE_KEY from plumpy.utils import AttributesFrozendict class 
ForgetToCallParent(plumpy.Process): - def __init__(self, forget_on): super().__init__() self.forget_on = forget_on @@ -42,9 +44,7 @@ def on_kill(self, msg): @pytest.mark.asyncio async def test_process_scope(): - class ProcessTaskInterleave(plumpy.Process): - async def task(self, steps: list): steps.append(f'[{self.pid}] started') assert plumpy.Process.current() is self @@ -64,7 +64,6 @@ async def task(self, steps: list): class TestProcess(unittest.TestCase): - def test_spec(self): """ Check that the references to specs are doing the right thing... @@ -82,12 +81,10 @@ class Proc(utils.DummyProcess): self.assertIs(p.spec(), Proc.spec()) def test_dynamic_inputs(self): - class NoDynamic(Process): pass class WithDynamic(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -100,9 +97,7 @@ def define(cls, spec): proc.execute() def test_inputs(self): - class Proc(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -122,7 +117,6 @@ def test_raw_inputs(self): """ class Proc(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -138,9 +132,7 @@ def define(cls, spec): self.assertDictEqual(dict(process.raw_inputs), {'a': 5, 'nested': {'a': 'value'}}) def test_inputs_default(self): - class Proc(utils.DummyProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -199,7 +191,6 @@ def test_inputs_default_that_evaluate_to_false(self): for def_val in (True, False, 0, 1): class Proc(utils.DummyProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -214,7 +205,6 @@ def test_nested_namespace_defaults(self): """Process with a default in a nested namespace should be created, even if top level namespace not supplied.""" class SomeProcess(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -229,7 +219,6 @@ def test_raise_in_define(self): """Process which raises in its 'define' method. 
Check that the spec is not set.""" class BrokenProcess(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -293,12 +282,11 @@ def test_run_kill(self): proc.execute() def test_get_description(self): - class ProcWithoutSpec(Process): pass class ProcWithSpec(Process): - """ Process with a spec and a docstring """ + """Process with a spec and a docstring""" @classmethod def define(cls, spec): @@ -324,9 +312,7 @@ def define(cls, spec): self.assertIsInstance(desc_with_spec['description'], str) def test_logging(self): - class LoggerTester(Process): - def run(self, **kwargs): self.logger.info('Test') @@ -335,11 +321,13 @@ def run(self, **kwargs): proc.execute() def test_kill(self): - proc = utils.DummyProcess() + proc: Process = utils.DummyProcess() - proc.kill('Farewell!') + msg = copy.copy(KILL_MSG) + msg[MESSAGE_KEY] = 'Farewell!' + proc.kill(msg) self.assertTrue(proc.killed()) - self.assertEqual(proc.killed_msg(), 'Farewell!') + self.assertEqual(proc.killed_msg(), msg) self.assertEqual(proc.state, ProcessState.KILLED) def test_wait_continue(self): @@ -438,12 +426,13 @@ async def async_test(): self.assertEqual(proc.state, ProcessState.FINISHED) def test_kill_in_run(self): - class KillProcess(Process): after_kill = False def run(self, **kwargs): - self.kill('killed') + msg = copy.copy(KILL_MSG) + msg[MESSAGE_KEY] = 'killed' + self.kill(msg) # The following line should be executed because kill will not # interrupt execution of a method call in the RUNNING state self.after_kill = True @@ -456,9 +445,7 @@ def run(self, **kwargs): self.assertEqual(proc.state, ProcessState.KILLED) def test_kill_when_paused_in_run(self): - class PauseProcess(Process): - def run(self, **kwargs): self.pause() self.kill() @@ -510,9 +497,7 @@ def test_run_multiple(self): self.assertDictEqual(proc_class.EXPECTED_OUTPUTS, result) def test_invalid_output(self): - class InvalidOutput(plumpy.Process): - def run(self): self.out('invalid', 5) @@ -536,7 +521,6 @@ def 
test_unsuccessful_result(self): ERROR_CODE = 256 class Proc(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -550,11 +534,10 @@ def run(self): self.assertEqual(proc.result(), ERROR_CODE) def test_pause_in_process(self): - """ Test that we can pause and cancel that by playing within the process """ + """Test that we can pause and cancel that by playing within the process""" test_case = self class TestPausePlay(plumpy.Process): - def run(self): fut = self.pause() test_case.assertIsInstance(fut, plumpy.Future) @@ -574,12 +557,11 @@ def run(self): self.assertEqual(plumpy.ProcessState.FINISHED, proc.state) def test_pause_play_in_process(self): - """ Test that we can pause and play that by playing within the process """ + """Test that we can pause and play that by playing within the process""" test_case = self class TestPausePlay(plumpy.Process): - def run(self): fut = self.pause() test_case.assertIsInstance(fut, plumpy.Future) @@ -596,7 +578,6 @@ def test_process_stack(self): test_case = self class StackTest(plumpy.Process): - def run(self): test_case.assertIs(self, Process.current()) @@ -613,7 +594,6 @@ def test_nested(process): expect_true.append(process == Process.current()) class StackTest(plumpy.Process): - def run(self): # TODO: unexpected behaviour here # if assert error happend here not raise @@ -623,7 +603,6 @@ def run(self): test_nested(self) class ParentProcess(plumpy.Process): - def run(self): expect_true.append(self == Process.current()) StackTest().execute() @@ -646,21 +625,17 @@ def test_process_nested(self): """ class StackTest(plumpy.Process): - def run(self): pass class ParentProcess(plumpy.Process): - def run(self): StackTest().execute() ParentProcess().execute() def test_call_soon(self): - class CallSoon(plumpy.Process): - def run(self): self.call_soon(self.do_except) @@ -680,7 +655,6 @@ def test_exception_during_on_entered(self): """Test that an exception raised during ``on_entered`` will cause the process to be excepted.""" 
class RaisingProcess(Process): - def on_entered(self, from_state): if from_state is not None and from_state.label == ProcessState.RUNNING: raise RuntimeError('exception during on_entered') @@ -696,9 +670,7 @@ def on_entered(self, from_state): assert str(process.exception()) == 'exception during on_entered' def test_exception_during_run(self): - class RaisingProcess(Process): - def run(self): raise RuntimeError('exception during run') @@ -862,7 +834,7 @@ async def async_test(): loop.run_until_complete(async_test()) def test_wait_save_continue(self): - """ Test that process saved while in WAITING state restarts correctly when loaded """ + """Test that process saved while in WAITING state restarts correctly when loaded""" loop = asyncio.get_event_loop() proc = utils.WaitForSignalProcess() @@ -905,7 +877,6 @@ def _check_round_trip(self, proc1): class TestProcessNamespace(unittest.TestCase): - def test_namespaced_process(self): """ Test that inputs in nested namespaces are properly validated and the returned @@ -913,7 +884,6 @@ def test_namespaced_process(self): """ class NameSpacedProcess(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -938,7 +908,6 @@ def test_namespaced_process_inputs(self): """ class NameSpacedProcess(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -964,7 +933,6 @@ def test_namespaced_process_dynamic(self): namespace = 'name.space' class DummyDynamicProcess(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -991,14 +959,12 @@ def test_namespaced_process_outputs(self): namespace_nested = f'{namespace}.nested' class OutputMode(enum.Enum): - NONE = 0 DYNAMIC_PORT_NAMESPACE = 1 SINGLE_REQUIRED_PORT = 2 BOTH_SINGLE_AND_NAMESPACE = 3 class DummyDynamicProcess(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -1057,7 +1023,6 @@ def run(self): class TestProcessEvents(unittest.TestCase): - def test_basic_events(self): proc = utils.DummyProcessWithOutput() events_tester 
= utils.ProcessListenerTester( @@ -1077,11 +1042,14 @@ def test_killed(self): def test_excepted(self): proc = utils.ExceptionProcess() - events_tester = utils.ProcessListenerTester(proc, ( - 'excepted', - 'running', - 'output_emitted', - )) + events_tester = utils.ProcessListenerTester( + proc, + ( + 'excepted', + 'running', + 'output_emitted', + ), + ) with self.assertRaises(RuntimeError): proc.execute() proc.result() @@ -1120,7 +1088,6 @@ def on_broadcast_receive(_comm, body, sender, subject, correlation_id): class _RestartProcess(utils.WaitForSignalProcess): - @classmethod def define(cls, spec): super().define(spec) diff --git a/test/test_utils.py b/tests/test_utils.py similarity index 99% rename from test/test_utils.py rename to tests/test_utils.py index 546261f2..9567db7a 100644 --- a/test/test_utils.py +++ b/tests/test_utils.py @@ -10,7 +10,6 @@ class TestAttributesFrozendict: - def test_getitem(self): d = AttributesFrozendict({'a': 5}) assert d['a'] == 5 @@ -40,7 +39,6 @@ async def async_fct(): class TestEnsureCoroutine: - def test_sync_func(self): coro = ensure_coroutine(fct) assert inspect.iscoroutinefunction(coro) @@ -50,9 +48,7 @@ def test_async_func(self): assert coro is async_fct def test_callable_class(self): - class AsyncDummy: - async def __call__(self): pass @@ -60,9 +56,7 @@ async def __call__(self): assert coro is AsyncDummy def test_callable_object(self): - class AsyncDummy: - async def __call__(self): pass diff --git a/test/test_waiting_process.py b/tests/test_waiting_process.py similarity index 99% rename from test/test_waiting_process.py rename to tests/test_waiting_process.py index 87d39192..90427554 100644 --- a/test/test_waiting_process.py +++ b/tests/test_waiting_process.py @@ -9,7 +9,6 @@ class TestWaitingProcess(unittest.TestCase): - def test_instance_state(self): proc = utils.ThreeSteps() wl = utils.ProcessSaver(proc) diff --git a/test/test_workchains.py b/tests/test_workchains.py similarity index 93% rename from 
test/test_workchains.py rename to tests/test_workchains.py index c698aff9..08c7317a 100644 --- a/test/test_workchains.py +++ b/tests/test_workchains.py @@ -33,9 +33,17 @@ def on_create(self): super().on_create() # Reset the finished step self.finished_steps = { - k: False for k in [ - self.s1.__name__, self.s2.__name__, self.s3.__name__, self.s4.__name__, self.s5.__name__, - self.s6.__name__, self.isA.__name__, self.isB.__name__, self.ltN.__name__ + k: False + for k in [ + self.s1.__name__, + self.s2.__name__, + self.s3.__name__, + self.s4.__name__, + self.s5.__name__, + self.s6.__name__, + self.isA.__name__, + self.isB.__name__, + self.ltN.__name__, ] } @@ -59,15 +67,15 @@ def s6(self): self.ctx.counter = self.ctx.counter + 1 self._set_finished(inspect.stack()[0][3]) - def isA(self): + def isA(self): # noqa: N802 self._set_finished(inspect.stack()[0][3]) return self.inputs.value == 'A' - def isB(self): + def isB(self): # noqa: N802 self._set_finished(inspect.stack()[0][3]) return self.inputs.value == 'B' - def ltN(self): + def ltN(self): # noqa: N802 keep_looping = self.ctx.counter < self.inputs.n if not keep_looping: self._set_finished(inspect.stack()[0][3]) @@ -78,7 +86,6 @@ def _set_finished(self, function_name): class IfTest(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -101,7 +108,6 @@ def step2(self): class DummyWc(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -112,7 +118,6 @@ def do_nothing(self): class TestContext(unittest.TestCase): - def test_attributes(self): wc = DummyWc() wc.ctx.new_attr = 5 @@ -136,9 +141,9 @@ class TestWorkchain(unittest.TestCase): maxDiff = None def test_run(self): - A = 'A' - B = 'B' - C = 'C' + A = 'A' # noqa: N806 + B = 'B' # noqa: N806 + C = 'C' # noqa: N806 three = 3 # Try the if(..) 
part @@ -163,9 +168,7 @@ def test_run(self): self.assertTrue(finished, f'Step {step} was not called by workflow') def test_incorrect_outline(self): - class Wf(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -176,9 +179,7 @@ def define(cls, spec): Wf.spec() def test_same_input_node(self): - class Wf(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -195,11 +196,10 @@ def check_a_b(self): Wf(inputs=dict(a=x, b=x)).execute() def test_context(self): - A = 'a' - B = 'b' + A = 'a' # noqa: N806 + B = 'b' # noqa: N806 class ReturnA(plumpy.Process): - @classmethod def define(cls, spec): super().define(spec) @@ -209,7 +209,6 @@ def run(self): self.out('res', A) class ReturnB(plumpy.Process): - @classmethod def define(cls, spec): super().define(spec) @@ -219,7 +218,6 @@ def run(self): self.out('res', B) class Wf(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -257,9 +255,9 @@ def test_malformed_outline(self): spec.outline(lambda x, y: 5) def test_checkpointing(self): - A = 'A' - B = 'B' - C = 'C' + A = 'A' # noqa: N806 + B = 'B' # noqa: N806 + C = 'C' # noqa: N806 three = 3 # Try the if(..) 
part @@ -288,13 +286,11 @@ def test_listener_persistence(self): process_finished_count = 0 class TestListener(plumpy.ProcessListener): - def on_process_finished(self, process, output): nonlocal process_finished_count process_finished_count += 1 class SimpleWorkChain(plumpy.WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -315,7 +311,8 @@ def step2(self): workchain = SimpleWorkChain() workchain.add_process_listener(TestListener()) - output = workchain.execute() + + workchain.execute() self.assertEqual(process_finished_count, 1) @@ -324,7 +321,6 @@ def step2(self): self.assertEqual(process_finished_count, 2) def test_return_in_outline(self): - class WcWithReturn(WorkChain): FAILED_CODE = 1 @@ -360,9 +356,7 @@ def default(self): workchain.execute() def test_return_in_step(self): - class WcWithReturn(WorkChain): - FAILED_CODE = 1 @classmethod @@ -393,9 +387,7 @@ def after(self): workchain.execute() def test_tocontext_schedule_workchain(self): - class MainWorkChain(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -409,7 +401,6 @@ def check(self): assert self.ctx.subwc.out.value == 5 class SubWorkChain(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -446,14 +437,13 @@ async def async_test(): self.assertTrue(workchain.ctx.s2) loop = asyncio.get_event_loop() - loop.create_task(workchain.step_until_terminated()) + loop.create_task(workchain.step_until_terminated()) # noqa: RUF006 loop.run_until_complete(async_test()) def test_to_context(self): val = 5 class SimpleWc(plumpy.Process): - @classmethod def define(cls, spec): super().define(spec) @@ -463,7 +453,6 @@ def run(self): self.out('_return', val) class Workchain(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -484,7 +473,6 @@ def test_output_namespace(self): """Test running a workchain with nested outputs.""" class TestWorkChain(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -501,7 +489,6 @@ 
def test_exception_tocontext(self): my_exception = RuntimeError('Should not be reached') class Workchain(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -528,7 +515,6 @@ def test_stepper_info(self): """Check status information provided by steppers""" class Wf(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -539,7 +525,13 @@ def define(cls, spec): cls.chill, cls.chill, ), - if_(cls.do_step)(cls.chill,).elif_(cls.do_step)(cls.chill,).else_(cls.chill), + if_(cls.do_step)( + cls.chill, + ) + .elif_(cls.do_step)( + cls.chill, + ) + .else_(cls.chill), ) def check_n(self): @@ -560,7 +552,6 @@ def do_step(self): return False class StatusCollector(ProcessListener): - def __init__(self): self.stepper_strings = [] @@ -574,9 +565,15 @@ def on_process_running(self, process): wf.execute() stepper_strings = [ - '0:check_n', '1:while_(do_step)', '1:while_(do_step)(1:chill)', '1:while_(do_step)', - '1:while_(do_step)(1:chill)', '1:while_(do_step)', '1:while_(do_step)(1:chill)', '1:while_(do_step)', - '2:if_(do_step)' + '0:check_n', + '1:while_(do_step)', + '1:while_(do_step)(1:chill)', + '1:while_(do_step)', + '1:while_(do_step)(1:chill)', + '1:while_(do_step)', + '1:while_(do_step)(1:chill)', + '1:while_(do_step)', + '2:if_(do_step)', ] self.assertListEqual(collector.stepper_strings, stepper_strings) @@ -593,7 +590,6 @@ def test_immutable_input(self): test_class = self class Wf(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -630,7 +626,6 @@ def test_immutable_input_namespace(self): test_class = self class Wf(WorkChain): - @classmethod def define(cls, spec): super().define(spec) diff --git a/test/utils.py b/tests/utils.py similarity index 93% rename from test/utils.py rename to tests/utils.py index 1f7408f6..f2a58dfc 100644 --- a/test/utils.py +++ b/tests/utils.py @@ -1,15 +1,15 @@ # -*- coding: utf-8 -*- """Utilities for tests""" + import asyncio import collections -from collections.abc import Mapping 
+import copy import unittest - -import kiwipy.rmq -import shortuuid +from collections.abc import Mapping import plumpy from plumpy import persistence, process_states, processes, utils +from plumpy.process_comms import KILL_MSG, MESSAGE_KEY Snapshot = collections.namedtuple('Snapshot', ['state', 'bundle', 'outputs']) @@ -24,7 +24,9 @@ class DummyProcess(processes.Process): """ EXPECTED_STATE_SEQUENCE = [ - process_states.ProcessState.CREATED, process_states.ProcessState.RUNNING, process_states.ProcessState.FINISHED + process_states.ProcessState.CREATED, + process_states.ProcessState.RUNNING, + process_states.ProcessState.FINISHED, ] EXPECTED_OUTPUTS = {} @@ -58,14 +60,12 @@ def run(self, **kwargs): class KeyboardInterruptProc(processes.Process): - @utils.override def run(self): raise KeyboardInterrupt() class ProcessWithCheckpoint(processes.Process): - @utils.override def run(self): return process_states.Continue(self.last_step) @@ -75,7 +75,6 @@ def last_step(self): class WaitForSignalProcess(processes.Process): - @utils.override def run(self): return process_states.Wait(self.last_step) @@ -85,14 +84,15 @@ def last_step(self): class KillProcess(processes.Process): - @utils.override def run(self): - return process_states.Kill('killed') + msg = copy.copy(KILL_MSG) + msg[MESSAGE_KEY] = 'killed' + return process_states.Kill(msg=msg) class MissingOutputProcess(processes.Process): - """ A process that does not generate a required output """ + """A process that does not generate a required output""" @classmethod def define(cls, spec): @@ -101,7 +101,6 @@ def define(cls, spec): class NewLoopProcess(processes.Process): - def __init__(self, *args, **kwargs): kwargs['loop'] = plumpy.new_event_loop() super().__init__(*args, **kwargs) @@ -118,8 +117,7 @@ def called(cls, event): cls.called_events.append(event) def __init__(self, *args, **kwargs): - assert isinstance(self, processes.Process), \ - 'Mixin has to be used with a type derived from a Process' + assert isinstance(self, 
processes.Process), 'Mixin has to be used with a type derived from a Process' super().__init__(*args, **kwargs) self.__class__.called_events = [] @@ -165,7 +163,6 @@ def on_terminate(self): class ProcessEventsTester(EventsTesterMixin, processes.Process): - @classmethod def define(cls, spec): super().define(spec) @@ -193,7 +190,6 @@ def last_step(self): class TwoCheckpointNoFinish(ProcessEventsTester): - def run(self): self.out('test', 5) return process_states.Continue(self.middle_step) @@ -203,21 +199,18 @@ def middle_step(self): class ExceptionProcess(ProcessEventsTester): - def run(self): self.out('test', 5) raise RuntimeError('Great scott!') class ThreeStepsThenException(ThreeSteps): - @utils.override def last_step(self): raise RuntimeError('Great scott!') class ProcessListenerTester(plumpy.ProcessListener): - def __init__(self, process, expected_events): process.add_process_listener(self) self.expected_events = set(expected_events) @@ -249,7 +242,6 @@ def on_process_killed(self, process, msg): class Saver: - def __init__(self): self.snapshots = [] self.outputs = [] @@ -357,7 +349,11 @@ def on_process_killed(self, process, msg): TEST_PROCESSES = [DummyProcess, DummyProcessWithOutput, DummyProcessWithDynamicOutput, ThreeSteps] TEST_WAITING_PROCESSES = [ - ProcessWithCheckpoint, TwoCheckpointNoFinish, ExceptionProcess, ProcessEventsTester, ThreeStepsThenException + ProcessWithCheckpoint, + TwoCheckpointNoFinish, + ExceptionProcess, + ProcessEventsTester, + ThreeStepsThenException, ] TEST_EXCEPTION_PROCESSES = [ExceptionProcess, ThreeStepsThenException, MissingOutputProcess] @@ -402,7 +398,7 @@ def check_process_against_snapshots(loop, proc_class, snapshots): saver.snapshots[-j], snapshots[-j], saver.snapshots[-j], - exclude={'exception', '_listeners'} + exclude={'exception', '_listeners'}, ) j += 1 @@ -438,9 +434,8 @@ def compare_value(bundle1, bundle2, v1, v2, exclude=None): compare_value(bundle1, bundle2, list(v1), list(v2), exclude) elif isinstance(v1, set) and 
isinstance(v2, set): raise NotImplementedError('Comparison between sets not implemented') - else: - if v1 != v2: - raise ValueError(f'Dict values mismatch for :\n{v1} != {v2}') + elif v1 != v2: + raise ValueError(f'Dict values mismatch for :\n{v1} != {v2}') class TestPersister(persistence.Persister): @@ -449,7 +444,7 @@ class TestPersister(persistence.Persister): """ def save_checkpoint(self, process, tag=None): - """ Create the checkpoint bundle """ + """Create the checkpoint bundle""" persistence.Bundle(process) def load_checkpoint(self, pid, tag=None): @@ -469,7 +464,7 @@ def delete_process_checkpoints(self, pid): def run_until_waiting(proc): - """ Set up a future that will be resolved on entering the WAITING state """ + """Set up a future that will be resolved on entering the WAITING state""" from plumpy import ProcessState listener = plumpy.ProcessListener() @@ -490,7 +485,7 @@ def on_waiting(_waiting_proc): def run_until_paused(proc): - """ Set up a future that will be resolved when the process is paused """ + """Set up a future that will be resolved when the process is paused""" listener = plumpy.ProcessListener() paused = plumpy.Future()