Skip to content

Commit

Permalink
Reinstate broadcast subscription of workchain for child processes
Browse files Browse the repository at this point in the history
When a `WorkChain` step submits sub processes, awaitables will be
created for them. Only when these awaitables have been resolved, meaning
the subprocesses have terminated, can the workchain continue to the
next step.

The original concept was, for each awaitable, to schedule a callback
once the process had reached a terminal state. The callback was supposed
to be triggered by having the runner add a broadcast subscriber that would
listen for state changes of the sub process. As a fail-safe, a polling
mechanism would also check periodically, in case the broadcast message
was missed, to prevent the caller from waiting indefinitely.

However, the broadcast subscriber was never added and so the system
relied solely on the polling mechanism. This completely undermines the
benefits of having an event-based mechanism, so in this commit the
`Runner.call_on_process_finish` now also registers the broadcast
subscriber.

Note that the affected code had some references to `calculation` which
has been generalized to `process`, since this also applies to workflows
that might be waited upon. The `CalculationFuture` has been renamed to
`ProcessFuture` in a similar vein. It is currently not used, but it could
have been used for the problem that this commit solves, so it has been
decided to leave it in for now and not remove it entirely.
  • Loading branch information
sphuber committed Jun 17, 2020
1 parent d558e46 commit 7441fb8
Show file tree
Hide file tree
Showing 10 changed files with 134 additions and 90 deletions.
2 changes: 1 addition & 1 deletion .ci/test_daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def main():

# Run the `MultiplyAddWorkChain`
print('Running the `MultiplyAddWorkChain`')
run_base_restart_workchain()
run_multiply_add_workchain()

# Submitting the Calculations the new way directly through the launchers
print('Submitting {} calculations to the daemon'.format(NUMBER_CALCULATIONS))
Expand Down
48 changes: 23 additions & 25 deletions aiida/engine/processes/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,64 +14,62 @@
import plumpy
import kiwipy

__all__ = ('CalculationFuture',)
__all__ = ('ProcessFuture',)


class CalculationFuture(plumpy.Future):
"""
A future that waits for a calculation to complete using both polling and
listening for broadcast events if possible
"""
class ProcessFuture(plumpy.Future):
"""Future that waits for a process to complete using both polling and listening for broadcast events if possible."""

_filtered = None

def __init__(self, pk, loop=None, poll_interval=None, communicator=None):
"""
Get a future for a calculation node being finished. If a None poll_interval is
supplied polling will not be used. If a communicator is supplied it will be used
"""Construct a future for a process node being finished.
If a None poll_interval is supplied polling will not be used. If a communicator is supplied it will be used
to listen for broadcast messages.
:param pk: The calculation pk
:param pk: process pk
:param loop: An event loop
:param poll_interval: The polling interval. Can be None in which case no polling.
:param communicator: A communicator. Can be None in which case no broadcast listens.
:param poll_interval: optional polling interval, if None, polling is not activated.
:param communicator: optional communicator, if None, will not subscribe to broadcasts.
"""
from aiida.orm import load_node
from .process import ProcessState

super().__init__()
assert not (poll_interval is None and communicator is None), 'Must poll or have a communicator to use'

calc_node = load_node(pk=pk)
node = load_node(pk=pk)

if calc_node.is_terminated:
self.set_result(calc_node)
if node.is_terminated:
self.set_result(node)
else:
self._communicator = communicator
self.add_done_callback(lambda _: self.cleanup())

# Try setting up a filtered broadcast subscriber
if self._communicator is not None:
self._filtered = kiwipy.BroadcastFilter(lambda *args, **kwargs: self.set_result(calc_node), sender=pk)
broadcast_filter = kiwipy.BroadcastFilter(lambda *args, **kwargs: self.set_result(node), sender=pk)
for state in [ProcessState.FINISHED, ProcessState.KILLED, ProcessState.EXCEPTED]:
self._filtered.add_subject_filter('state_changed.*.{}'.format(state.value))
self._communicator.add_broadcast_subscriber(self._filtered)
broadcast_filter.add_subject_filter('state_changed.*.{}'.format(state.value))
self._broadcast_identifier = self._communicator.add_broadcast_subscriber(broadcast_filter)

# Start polling
if poll_interval is not None:
loop.add_callback(self._poll_calculation, calc_node, poll_interval)
loop.add_callback(self._poll_process, node, poll_interval)

def cleanup(self):
"""Clean up the future by removing broadcast subscribers from the communicator if it still exists."""
if self._communicator is not None:
self._communicator.remove_broadcast_subscriber(self._filtered)
self._filtered = None
self._communicator.remove_broadcast_subscriber(self._broadcast_identifier)
self._communicator = None
self._broadcast_identifier = None

@tornado.gen.coroutine
def _poll_calculation(self, calc_node, poll_interval):
"""Poll whether the calculation node has reached a terminal state."""
while not self.done() and not calc_node.is_terminated:
def _poll_process(self, node, poll_interval):
"""Poll whether the process node has reached a terminal state."""
while not self.done() and not node.is_terminated:
yield tornado.gen.sleep(poll_interval)

if not self.done():
self.set_result(calc_node)
self.set_result(node)
1 change: 0 additions & 1 deletion aiida/engine/processes/workchains/awaitable.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
###########################################################################
# pylint: disable=too-few-public-methods
"""Enums and function for the awaitables of Processes."""

from enum import Enum

from plumpy.utils import AttributesDict
Expand Down
15 changes: 8 additions & 7 deletions aiida/engine/processes/workchains/workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def on_wait(self, awaitables):
self.call_soon(self.resume)

def action_awaitables(self):
"""Handle the awaitables that are currently registered with the work chain
"""Handle the awaitables that are currently registered with the work chain.
Depending on the class type of the awaitable's target a different callback
function will be bound with the awaitable and the runner will be asked to
Expand All @@ -243,24 +243,24 @@ def action_awaitables(self):
for awaitable in self._awaitables:
if awaitable.target == AwaitableTarget.PROCESS:
callback = functools.partial(self._run_task, self.on_process_finished, awaitable)
self.runner.call_on_calculation_finish(awaitable.pk, callback)
self.runner.call_on_process_finish(awaitable.pk, callback)
else:
assert "invalid awaitable target '{}'".format(awaitable.target)

def on_process_finished(self, awaitable, pk):
def on_process_finished(self, awaitable):
"""Callback function called by the runner when the process instance identified by pk is completed.
The awaitable will be effectuated on the context of the work chain and removed from the internal list. If all
awaitables have been dealt with, the work chain process is resumed.
:param awaitable: an Awaitable instance
:param pk: the pk of the awaitable's target
:type pk: int
"""
self.logger.info('received callback that awaitable %d has terminated', awaitable.pk)

try:
node = load_node(pk)
node = load_node(awaitable.pk)
except (exceptions.MultipleObjectsError, exceptions.NotExistent):
raise ValueError('provided pk<{}> could not be resolved to a valid Node instance'.format(pk))
raise ValueError('provided pk<{}> could not be resolved to a valid Node instance'.format(awaitable.pk))

if awaitable.outputs:
value = {entry.link_label: entry.node for entry in node.get_outgoing()}
Expand All @@ -275,5 +275,6 @@ def on_process_finished(self, awaitable, pk):
assert "invalid awaitable action '{}'".format(awaitable.action)

self.remove_awaitable(awaitable)

if self.state == ProcessState.WAITING and not self._awaitables:
self.resume()
77 changes: 57 additions & 20 deletions aiida/engine/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,22 @@
###########################################################################
# pylint: disable=global-statement
"""Runners that can run and submit processes."""

import collections
import functools
import logging
import signal
import tornado.ioloop
import threading
import uuid

import kiwipy
import plumpy
import tornado.ioloop

from aiida.common import exceptions
from aiida.orm import load_node
from aiida.plugins.utils import PluginVersionProvider

from .processes import futures
from .processes import futures, ProcessState
from .processes.calcjobs import manager
from . import transports
from . import utils
Expand Down Expand Up @@ -262,27 +265,61 @@ def run_get_pk(self, process, *args, **inputs):
result, node = self._run(process, *args, **inputs)
return ResultAndPk(result, node.pk)

def call_on_calculation_finish(self, pk, callback):
"""
Callback to be called when the calculation of the given pk is terminated
def call_on_process_finish(self, pk, callback):
"""Schedule a callback when the process of the given pk is terminated.
:param pk: the pk of the calculation
:param callback: the function to be called upon calculation termination
"""
calculation = load_node(pk=pk)
self._poll_calculation(calculation, callback)
This method will add a broadcast subscriber that will listen for state changes of the target process to be
terminated. As a fail-safe, a polling-mechanism is used to check the state of the process, should the broadcast
message be missed by the subscriber, in order to prevent the caller to wait indefinitely.
def get_calculation_future(self, pk):
:param pk: pk of the process
:param callback: function to be called upon process termination
"""
Get a future for an orm Calculation. The future will have the calculation node
as the result when finished.
node = load_node(pk=pk)
subscriber_identifier = str(uuid.uuid4())
event = threading.Event()

def inline_callback(event, *args, **kwargs): # pylint: disable=unused-argument
"""Callback to wrap the actual callback, that will always remove the subscriber that will be registered.
As soon as the callback is called successfully once, the `event` instance is toggled, such that if this
inline callback is called a second time, the actual callback is not called again.
"""
if event.is_set():
return

:return: A future representing the completion of the calculation node
try:
callback()
finally:
event.set()
self._communicator.remove_broadcast_subscriber(subscriber_identifier)

broadcast_filter = kiwipy.BroadcastFilter(functools.partial(inline_callback, event), sender=pk)
for state in [ProcessState.FINISHED, ProcessState.KILLED, ProcessState.EXCEPTED]:
broadcast_filter.add_subject_filter('state_changed.*.{}'.format(state.value))

LOGGER.info('adding subscriber for broadcasts of %d', pk)
self._communicator.add_broadcast_subscriber(broadcast_filter, subscriber_identifier)
self._poll_process(node, functools.partial(inline_callback, event))

def get_process_future(self, pk):
"""Return a future for a process.
The future will have the process node as the result when finished.
:return: A future representing the completion of the process node
"""
return futures.CalculationFuture(pk, self._loop, self._poll_interval, self._communicator)
return futures.ProcessFuture(pk, self._loop, self._poll_interval, self._communicator)

def _poll_process(self, node, callback):
"""Check whether the process state of the node is terminated and call the callback or reschedule it.
def _poll_calculation(self, calc_node, callback):
if calc_node.is_terminated:
self._loop.add_callback(callback, calc_node.pk)
:param node: the process node
:param callback: callback to be called when process is terminated
"""
if node.is_terminated:
args = [node.__class__.__name__, node.pk]
LOGGER.info('%s<%d> confirmed to be terminated by backup polling mechanism', *args)
self._loop.add_callback(callback)
else:
self._loop.call_later(self._poll_interval, self._poll_calculation, calc_node, callback)
self._loop.call_later(self._poll_interval, self._poll_process, node, callback)
2 changes: 1 addition & 1 deletion aiida/manage/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def get_communicator(self):
return self._communicator

def create_communicator(self, task_prefetch_count=None, with_orm=True):
"""Create a Communicator
"""Create a Communicator.
:param task_prefetch_count: optional specify how many tasks this communicator take simultaneously
:param with_orm: if True, use ORM (de)serializers. If false, use json.
Expand Down
2 changes: 1 addition & 1 deletion aiida/orm/nodes/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,7 +799,7 @@ def validate_incoming(self, source, link_type, link_label):
"""Validate adding a link of the given type from a given node to ourself.
This function will first validate the types of the inputs, followed by the node and link types and validate
whether in principle a link of that type between the nodes of these types is allowed.the
whether in principle a link of that type between the nodes of these types is allowed.
Subsequently, the validity of the "degree" of the proposed link is validated, which means validating the
number of links of the given type from the given node type is allowed.
Expand Down
4 changes: 2 additions & 2 deletions tests/engine/test_futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_calculation_future_broadcasts(self):
process = test_processes.DummyProcess()

# No polling
future = processes.futures.CalculationFuture(
future = processes.futures.ProcessFuture(
pk=process.pid, poll_interval=None, communicator=manager.get_communicator()
)

Expand All @@ -46,7 +46,7 @@ def test_calculation_future_polling(self):
process = test_processes.DummyProcess()

# No communicator
future = processes.futures.CalculationFuture(pk=process.pid, loop=runner.loop, poll_interval=0)
future = processes.futures.ProcessFuture(pk=process.pid, loop=runner.loop, poll_interval=0)

runner.run(process)
calc_node = runner.run_until_complete(gen.with_timeout(self.TIMEOUT, future))
Expand Down
12 changes: 6 additions & 6 deletions tests/engine/test_rmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_submit_simple(self):
@gen.coroutine
def do_submit():
calc_node = submit(test_processes.DummyProcess)
yield self.wait_for_calc(calc_node)
yield self.wait_for_process(calc_node)

self.assertTrue(calc_node.is_finished_ok)
self.assertEqual(calc_node.process_state.value, plumpy.ProcessState.FINISHED.value)
Expand All @@ -61,7 +61,7 @@ def do_launch():
term_b = Int(10)

calc_node = submit(test_processes.AddProcess, a=term_a, b=term_b)
yield self.wait_for_calc(calc_node)
yield self.wait_for_process(calc_node)
self.assertTrue(calc_node.is_finished_ok)
self.assertEqual(calc_node.process_state.value, plumpy.ProcessState.FINISHED.value)

Expand All @@ -77,7 +77,7 @@ def test_exception_process(self):
@gen.coroutine
def do_exception():
calc_node = submit(test_processes.ExceptionProcess)
yield self.wait_for_calc(calc_node)
yield self.wait_for_process(calc_node)

self.assertFalse(calc_node.is_finished_ok)
self.assertEqual(calc_node.process_state.value, plumpy.ProcessState.EXCEPTED.value)
Expand Down Expand Up @@ -147,15 +147,15 @@ def do_kill():
result = yield self.wait_future(future)
self.assertTrue(result)

self.wait_for_calc(calc_node)
self.wait_for_process(calc_node)
self.assertTrue(calc_node.is_killed)
self.assertEqual(calc_node.process_status, kill_message)

self.runner.loop.run_sync(do_kill)

@gen.coroutine
def wait_for_calc(self, calc_node, timeout=2.):
future = self.runner.get_calculation_future(calc_node.pk)
def wait_for_process(self, calc_node, timeout=2.):
future = self.runner.get_process_future(calc_node.pk)
raise gen.Return((yield with_timeout(future, timeout)))

@staticmethod
Expand Down
Loading

0 comments on commit 7441fb8

Please sign in to comment.