diff --git a/plumpy/base/state_machine.py b/plumpy/base/state_machine.py index a3996d5a..e90cb827 100644 --- a/plumpy/base/state_machine.py +++ b/plumpy/base/state_machine.py @@ -286,6 +286,9 @@ def add_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_TYPE self._event_callbacks.setdefault(hook, []).append(callback) def remove_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_TYPE) -> None: + if getattr(self, '_closed', False): + # if the process is closed, then all callbacks have already been removed + return None try: self._event_callbacks[hook].remove(callback) except (KeyError, ValueError): diff --git a/plumpy/processes.py b/plumpy/processes.py index 782d9512..7dab678f 100644 --- a/plumpy/processes.py +++ b/plumpy/processes.py @@ -317,17 +317,16 @@ def try_killing(future: futures.Future) -> None: def _setup_event_hooks(self) -> None: """Set the event hooks to process, when it is created or loaded(recreated).""" - self.add_state_event_callback( - state_machine.StateEventHook.ENTERING_STATE, - lambda _s, _h, state: self.on_entering(cast(process_states.State, state)) - ) - self.add_state_event_callback( - state_machine.StateEventHook.ENTERED_STATE, - lambda _s, _h, from_state: self.on_entered(cast(Optional[process_states.State], from_state)) - ) - self.add_state_event_callback( - state_machine.StateEventHook.EXITING_STATE, lambda _s, _h, _state: self.on_exiting() - ) + event_hooks = { + state_machine.StateEventHook.ENTERING_STATE: + lambda _s, _h, state: self.on_entering(cast(process_states.State, state)), + state_machine.StateEventHook.ENTERED_STATE: + lambda _s, _h, from_state: self.on_entered(cast(Optional[process_states.State], from_state)), + state_machine.StateEventHook.EXITING_STATE: + lambda _s, _h, _state: self.on_exiting() + } + for hook, callback in event_hooks.items(): + self.add_state_event_callback(hook, callback) @property def creation_time(self) -> Optional[float]: @@ -845,6 +844,7 @@ def on_close(self) -> None: self.logger.exception('Exception calling cleanup method %s', cleanup) self._cleanups = None finally: + self._event_callbacks = {} self._closed = True def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> None: