Feature: scheduler #275

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
aiida_workgraph/cli/__init__.py (3 changes: 2 additions & 1 deletion)

@@ -5,6 +5,7 @@
 from aiida_workgraph.cli import cmd_graph
 from aiida_workgraph.cli import cmd_web
 from aiida_workgraph.cli import cmd_task
+from aiida_workgraph.cli import cmd_scheduler


-__all__ = ["cmd_graph", "cmd_web", "cmd_task"]
+__all__ = ["cmd_graph", "cmd_web", "cmd_task", "cmd_scheduler"]
aiida_workgraph/cli/cmd_scheduler.py (new file, 163 additions)

@@ -0,0 +1,163 @@
from aiida_workgraph.cli.cmd_workgraph import workgraph
import click
from aiida.cmdline.utils import decorators, echo
from aiida.cmdline.commands.cmd_daemon import validate_daemon_workers
from aiida.cmdline.params import options
from aiida_workgraph.engine.scheduler.client import get_scheduler_client
import sys


@workgraph.group("scheduler")
def scheduler():
    """Commands to manage the scheduler process."""


@scheduler.command()
def worker():
    """Start a scheduler worker."""
    from aiida_workgraph.engine.scheduler.client import start_scheduler_worker

    click.echo("Starting the scheduler worker...")

    start_scheduler_worker()


@scheduler.command()
@click.option("--foreground", is_flag=True, help="Run in foreground.")
@click.argument("number", required=False, type=int, callback=validate_daemon_workers)
@options.TIMEOUT(default=None, required=False, type=int)
@decorators.with_dbenv()
@decorators.requires_broker
@decorators.check_circus_zmq_version
def start(foreground, number, timeout):
    """Start the scheduler daemon."""
    from aiida_workgraph.engine.scheduler.client import start_scheduler_process

    click.echo("Starting the scheduler process...")

    client = get_scheduler_client()
    client.start_daemon(number_workers=number, foreground=foreground, timeout=timeout)
    start_scheduler_process(number)


@scheduler.command()
@click.option("--no-wait", is_flag=True, help="Do not wait for confirmation.")
@click.option("--all", "all_profiles", is_flag=True, help="Stop all daemons.")
@options.TIMEOUT(default=None, required=False, type=int)
@decorators.requires_broker
@click.pass_context
def stop(ctx, no_wait, all_profiles, timeout):
    """Stop the scheduler daemon.

    Returns exit code 0 if the daemon was shut down successfully (or was not running), non-zero if there was an error.
    """
    if all_profiles is True:
        profiles = [
            profile
            for profile in ctx.obj.config.profiles
            if not profile.is_test_profile
        ]
    else:
        profiles = [ctx.obj.profile]

    for profile in profiles:
        echo.echo("Profile: ", fg=echo.COLORS["report"], bold=True, nl=False)
        echo.echo(f"{profile.name}", bold=True)
        echo.echo("Stopping the daemon... ", nl=False)
        try:
            client = get_scheduler_client()
            client.stop_daemon(wait=not no_wait, timeout=timeout)
        except Exception as exception:
            echo.echo_error(f"Failed to stop the daemon: {exception}")


@scheduler.command(hidden=True)
@click.option("--foreground", is_flag=True, help="Run in foreground.")
@click.argument("number", required=False, type=int, callback=validate_daemon_workers)
@decorators.with_dbenv()
@decorators.requires_broker
@decorators.check_circus_zmq_version
def start_circus(foreground, number):
    """This will actually launch the circus daemon, either daemonized in the background or in the foreground.

    If run in the foreground all logs are redirected to stdout.

    .. note:: this should not be called directly from the commandline!
    """
    get_scheduler_client()._start_daemon(number_workers=number, foreground=foreground)


@scheduler.command()
@click.option("--all", "all_profiles", is_flag=True, help="Show status of all daemons.")
@options.TIMEOUT(default=None, required=False, type=int)
@click.pass_context
@decorators.requires_loaded_profile()
@decorators.requires_broker
def status(ctx, all_profiles, timeout):
    """Print the status of the scheduler daemon.

    Returns exit code 0 if all requested daemons are running, else exit code 3.
    """
    from tabulate import tabulate

    from aiida.cmdline.utils.common import format_local_time
    from aiida.engine.daemon.client import DaemonException

    if all_profiles is True:
        profiles = [
            profile
            for profile in ctx.obj.config.profiles
            if not profile.is_test_profile
        ]
    else:
        profiles = [ctx.obj.profile]

    daemons_running = []

    for profile in profiles:
        client = get_scheduler_client(profile.name)
        echo.echo("Profile: ", fg=echo.COLORS["report"], bold=True, nl=False)
        echo.echo(f"{profile.name}", bold=True)

        try:
            client.get_status(timeout=timeout)
        except DaemonException as exception:
            echo.echo_error(str(exception))
            daemons_running.append(False)
            continue

        worker_response = client.get_worker_info()
        daemon_response = client.get_daemon_info()

        workers = []
        for pid, info in worker_response["info"].items():
            if isinstance(info, dict):
                row = [
                    pid,
                    info["mem"],
                    info["cpu"],
                    format_local_time(info["create_time"]),
                ]
            else:
                row = [pid, "-", "-", "-"]
            workers.append(row)

        if workers:
            workers_info = tabulate(
                workers, headers=["PID", "MEM %", "CPU %", "started"], tablefmt="simple"
            )
        else:
            workers_info = (
                "--> No workers are running. Use `workgraph scheduler start` to start some!\n"
            )

        start_time = format_local_time(daemon_response["info"]["create_time"])
        echo.echo(
            f'Daemon is running as PID {daemon_response["info"]["pid"]} since {start_time}\n'
            f"Active workers [{len(workers)}]:\n{workers_info}\n"
            "Use `workgraph scheduler stop` and `workgraph scheduler start [num]` to change the number of workers"
        )

    if not all(daemons_running):
        sys.exit(3)
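
Since the new commands hang off the existing `workgraph` click group, they can be smoke-tested with click's `CliRunner`. A hedged sketch (not part of the diff; a configured AiiDA profile and message broker are assumed, since the decorators above require them):

    from click.testing import CliRunner

    from aiida_workgraph.cli.cmd_workgraph import workgraph

    runner = CliRunner()
    result = runner.invoke(workgraph, ["scheduler", "status"])
    print(result.output)
    # Per the docstring above, exit code 3 means a requested daemon is not running.
    print(result.exit_code)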
aiida_workgraph/engine/launch.py (new file, 179 additions)

@@ -0,0 +1,179 @@
from __future__ import annotations

import time
import typing as t

from aiida.common import InvalidOperation
from aiida.common.log import AIIDA_LOGGER
from aiida.manage import manager
from aiida.orm import ProcessNode

from aiida.engine.processes.builder import ProcessBuilder
from aiida.engine.processes.functions import get_stack_size
from aiida.engine.processes.process import Process
from aiida.engine.utils import prepare_inputs
from .utils import instantiate_process

import signal
import sys

from aiida.manage import get_manager

if t.TYPE_CHECKING:
    from aiida.engine.runners import Runner

__all__ = ("run_get_node", "submit")

TYPE_RUN_PROCESS = t.Union[Process, t.Type[Process], ProcessBuilder]
# run can also be process function, but it is not clear what type this should be
TYPE_SUBMIT_PROCESS = t.Union[Process, t.Type[Process], ProcessBuilder]
LOGGER = AIIDA_LOGGER.getChild("engine.launch")


"""
Note: I modified the run_get_node and submit functions to include the parent_pid argument.
This is necessary for keeping track of the provenance of the processes.

"""


def run_get_node(
    process_class, *args, **kwargs
) -> tuple[dict[str, t.Any] | None, "ProcessNode"]:
    """Run the FunctionProcess with the supplied inputs in a local runner.

    :param args: input arguments to construct the FunctionProcess
    :param kwargs: input keyword arguments to construct the FunctionProcess
    :return: tuple of the outputs of the process and the process node
    """
    parent_pid = kwargs.pop("parent_pid", None)
    frame_delta = 1000
    frame_count = get_stack_size()
    stack_limit = sys.getrecursionlimit()
    LOGGER.info(
        "Executing process function, current stack status: %d frames of %d",
        frame_count,
        stack_limit,
    )
    # If the current frame count is more than 80% of the stack limit, or comes within 200 frames, increase the
    # stack limit by ``frame_delta``.
    if frame_count > min(0.8 * stack_limit, stack_limit - 200):
        LOGGER.warning(
            "Current stack contains %d frames which is close to the limit of %d. Increasing the limit by %d",
            frame_count,
            stack_limit,
            frame_delta,
        )
        sys.setrecursionlimit(stack_limit + frame_delta)
    manager = get_manager()
    runner = manager.get_runner()
    inputs = process_class.create_inputs(*args, **kwargs)
    # Remove all the known inputs from the kwargs
    for port in process_class.spec().inputs:
        kwargs.pop(port, None)
    # If any kwargs remain, the spec should be dynamic, so we raise if it isn't
    if kwargs and not process_class.spec().inputs.dynamic:
        raise ValueError(
            f"{process_class.__name__} does not support these kwargs: {kwargs.keys()}"
        )
    process = process_class(inputs=inputs, runner=runner, parent_pid=parent_pid)
    # Only add handlers for the interrupt signal to kill the process if we are in a local and not a daemon runner.
    # Without this check, process functions running in a daemon worker would be killed if the daemon is shut down.
    current_runner = manager.get_runner()
    original_handler = None
    kill_signal = signal.SIGINT
    if not current_runner.is_daemon_runner:

        def kill_process(_num, _frame):
            """Send the kill signal to the process in the current scope."""
            LOGGER.critical(
                "runner received interrupt, killing process %s", process.pid
            )
            result = process.kill(
                msg="Process was killed because the runner received an interrupt"
            )
            return result

        # Store the current handler on the signal such that it can be restored after the process has terminated
        original_handler = signal.getsignal(kill_signal)
        signal.signal(kill_signal, kill_process)
    try:
        result = process.execute()
    finally:
        # If ``original_handler`` is set, ``kill_process`` was bound and needs to be reset
        if original_handler:
            signal.signal(signal.SIGINT, original_handler)
    store_provenance = inputs.get("metadata", {}).get("store_provenance", True)
    if not store_provenance:
        process.node._storable = False
        process.node._unstorable_message = (
            "cannot store node because it was run with `store_provenance=False`"
        )
    return result, process.node


def submit(
    process: TYPE_SUBMIT_PROCESS,
    inputs: dict[str, t.Any] | None = None,
    *,
    wait: bool = False,
    wait_interval: int = 5,
    parent_pid: int | None = None,
    runner: Runner | None = None,
    **kwargs: t.Any,
) -> ProcessNode:
    """Submit the process with the supplied inputs to the daemon, immediately returning control to the interpreter.

    .. warning:: this should not be used within another process. Instead, one should use the ``submit`` method of
        the wrapping process itself, i.e. use ``self.submit``.

    .. warning:: submission of processes requires ``store_provenance=True``.

    :param process: the process class, instance or builder to submit
    :param inputs: the input dictionary to be passed to the process
    :param wait: when set to ``True``, the submission will be blocking and wait for the process to complete at which
        point the function returns the calculation node.
    :param wait_interval: the number of seconds to wait between checking the state of the process when ``wait=True``.
    :param parent_pid: the pid of the parent process, if any, used to keep track of the provenance of the processes.
    :param runner: the runner to use for the submission; defaults to the runner of the current manager.
    :param kwargs: inputs to be passed to the process. This is an alternative to the positional ``inputs`` argument.
    :return: the calculation node of the process
    """
    inputs = prepare_inputs(inputs, **kwargs)

    # Submitting from within another process requires ``self.submit`` unless it is a work function, in which case the
    # current process in the scope should be an instance of ``FunctionProcess``.
    # if is_process_scoped() and not isinstance(Process.current(), FunctionProcess):
    #     raise InvalidOperation('Cannot use top-level `submit` from within another process, use `self.submit` instead')

    if not runner:
        runner = manager.get_manager().get_runner()
    assert runner.persister is not None, "runner does not have a persister"
    assert runner.controller is not None, "runner does not have a controller"

    process_inited = instantiate_process(
        runner, process, parent_pid=parent_pid, **inputs
    )

    # If a dry run is requested, simply forward to `run`, because it is not compatible with `submit`. We choose this
    # instead of raising, so that the user does not have to change the launcher when testing. The same goes for when
    # `remote_folder` is present in the inputs, which means we are importing an already completed calculation.
    if process_inited.metadata.get("dry_run", False) or "remote_folder" in inputs:
        _, node = run_get_node(process_inited)
        return node

    if not process_inited.metadata.store_provenance:
        raise InvalidOperation("cannot submit a process with `store_provenance=False`")

    runner.persister.save_checkpoint(process_inited)
    process_inited.close()

    # Do not wait for the future's result, because in the case of a single worker this would dead-lock itself
    runner.controller.continue_process(process_inited.pid, nowait=False, no_reply=True)
    node = process_inited.node

    if not wait:
        return node

    while not node.is_terminated:
        LOGGER.report(
            f"Process<{node.pk}> has not yet terminated, current state is `{node.process_state}`. "
            f"Waiting for {wait_interval} seconds."
        )
        time.sleep(wait_interval)

    return node
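
For illustration, a hedged usage sketch of the modified launcher. `MyWorkChain` and its inputs are placeholders (not part of this PR), and `parent_pid` would normally be the pid of the calling process so that provenance is tracked as the module note describes:

    from aiida import load_profile

    from aiida_workgraph.engine.launch import submit

    load_profile()

    node = submit(
        MyWorkChain,        # hypothetical process class
        inputs={"x": 1},    # hypothetical inputs for that class
        parent_pid=1234,    # hypothetical pid of the calling process
        wait=True,          # block until the process terminates
        wait_interval=5,    # poll the process state every 5 seconds
    )
    print(node.process_state)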