Support multiple schedulers.
1) Multiple runners (daemon workers) can be started for the scheduler. Each runner listens to the `scheduler_queue` with `prefetch_count` set to 1, so each runner can launch only one Scheduler process.
2) The Scheduler process listens to the `workgraph_queue` to launch WorkGraphs.
3) The Scheduler receives RPC calls to launch WorkGraphs.
4) Users can submit a WorkGraph to the workgraph queue, or select the scheduler that should run it by pk (a minimal sketch follows).
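
A minimal sketch of the queue-based submission path, assuming a loaded AiiDA profile, a running scheduler, the `ControllerWithQueueName` helper added in `aiida_workgraph/engine/override.py` below, and that `Manager.get_communicator()` is available; the message shape is inferred from the subscriber in `scheduler.py`, which reads `msg["args"]["pid"]`, and the pk is a placeholder:

```python
from aiida import load_profile
from aiida.manage import get_manager
from plumpy.process_comms import create_continue_body
from aiida_workgraph.engine.override import ControllerWithQueueName

load_profile()

# Send a "continue" task for an already stored WorkGraph process to the
# shared workgraph queue; the subscribed Scheduler process picks it up
# and launches the WorkGraph.
controller = ControllerWithQueueName(
    queue_name="workgraph_queue",
    communicator=get_manager().get_communicator(),
)
controller.task_send(create_continue_body(pid=12345), no_reply=True)
```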
superstar54 committed Sep 2, 2024
1 parent d267263 commit a791b9d
Showing 8 changed files with 203 additions and 125 deletions.
15 changes: 10 additions & 5 deletions aiida_workgraph/cli/cmd_scheduler.py
@@ -2,6 +2,7 @@
import click
from pathlib import Path
from aiida.cmdline.utils import decorators, echo
from aiida.cmdline.commands.cmd_daemon import validate_daemon_workers
from aiida.cmdline.params import options
from aiida_workgraph.engine.scheduler.client import get_scheduler_client
import sys
@@ -31,7 +32,7 @@ def scheduler():
@scheduler.command()
def worker():
"""Start the scheduler application."""
from aiida_workgraph.engine.launch import start_scheduler_worker
from aiida_workgraph.engine.scheduler.client import start_scheduler_worker

click.echo("Starting the scheduler worker...")

@@ -40,17 +41,20 @@ def worker():

@scheduler.command()
@click.option("--foreground", is_flag=True, help="Run in foreground.")
@click.argument("number", required=False, type=int, callback=validate_daemon_workers)
@options.TIMEOUT(default=None, required=False, type=int)
@decorators.with_dbenv()
@decorators.requires_broker
@decorators.check_circus_zmq_version
def start(foreground, timeout):
def start(foreground, number, timeout):
"""Start the scheduler application."""
from aiida_workgraph.engine.scheduler.client import start_scheduler_process

click.echo("Starting the scheduler process...")

client = get_scheduler_client()
client.start_daemon(foreground=foreground)
client.start_daemon(number_workers=number, foreground=foreground, timeout=timeout)
start_scheduler_process(number)


@scheduler.command()
@@ -86,18 +90,19 @@ def stop(ctx, no_wait, all_profiles, timeout):

@scheduler.command(hidden=True)
@click.option("--foreground", is_flag=True, help="Run in foreground.")
@click.argument("number", required=False, type=int, callback=validate_daemon_workers)
@decorators.with_dbenv()
@decorators.requires_broker
@decorators.check_circus_zmq_version
def start_circus(foreground):
def start_circus(foreground, number):
"""This will actually launch the circus daemon, either daemonized in the background or in the foreground.
If run in the foreground all logs are redirected to stdout.
.. note:: this should not be called directly from the commandline!
"""

get_scheduler_client()._start_daemon(foreground=foreground)
get_scheduler_client()._start_daemon(number_workers=number, foreground=foreground)


@scheduler.command()
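Hedged CLI usage based on the commands defined in this file (the worker-count argument is new in this commit; the `workgraph` entry point comes from `WORKGRAPH_BIN` in `client.py` below):

```console
$ workgraph scheduler start 2   # daemonize circus with two scheduler workers
$ workgraph scheduler worker    # start one scheduler worker directly
$ workgraph scheduler stop      # stop the scheduler daemon
```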
77 changes: 0 additions & 77 deletions aiida_workgraph/engine/launch.py
@@ -170,80 +170,3 @@ def submit(
time.sleep(wait_interval)

return node


def start_scheduler_worker(foreground: bool = False) -> None:
"""Start a scheduler worker for the currently configured profile.
:param foreground: If true, the logging will be configured to write to stdout, otherwise it will be configured to
write to the scheduler log file.
"""
import asyncio
import signal
import sys

from aiida.common.log import configure_logging
from aiida.manage import get_config_option, get_manager
from aiida_workgraph.engine.scheduler import WorkGraphScheduler
from aiida_workgraph.engine.scheduler.client import (
get_scheduler_client,
get_scheduler,
)
from aiida.engine.processes.launcher import ProcessLauncher
from aiida.engine import persistence
from plumpy.persistence import LoadSaveContext
from aiida.engine.daemon.worker import shutdown_worker

daemon_client = get_scheduler_client()
configure_logging(
daemon=not foreground, daemon_log_file=daemon_client.daemon_log_file
)

LOGGER.debug(f"sys.executable: {sys.executable}")
LOGGER.debug(f"sys.path: {sys.path}")

try:
manager = get_manager()
# runner = manager.create_daemon_runner()
runner = manager.create_runner(broker_submit=True)
manager.set_runner(runner)
except Exception:
LOGGER.exception("daemon worker failed to start")
raise

if isinstance(rlimit := get_config_option("daemon.recursion_limit"), int):
LOGGER.info("Setting maximum recursion limit of daemon worker to %s", rlimit)
sys.setrecursionlimit(rlimit)

signals = (signal.SIGTERM, signal.SIGINT)
for s in signals:
# https://github.com/python/mypy/issues/12557
runner.loop.add_signal_handler(s, lambda s=s: asyncio.create_task(shutdown_worker(runner))) # type: ignore[misc]

try:
running_scheduler = get_scheduler()
runner_loop = runner.loop
task_receiver = ProcessLauncher(
loop=runner_loop,
persister=manager.get_persister(),
load_context=LoadSaveContext(runner=runner),
loader=persistence.get_object_loader(),
)
asyncio.run(
task_receiver._continue(
communicator=None, pid=running_scheduler, nowait=True
)
)
except ValueError:
print("Starting a new Scheduler")
process_inited = instantiate_process(runner, WorkGraphScheduler)
runner.loop.create_task(process_inited.step_until_terminated())

try:
LOGGER.info("Starting a daemon worker")
runner.start()
except SystemError as exception:
LOGGER.info("Received a SystemError: %s", exception)
runner.close()

LOGGER.info("Daemon worker started")
65 changes: 65 additions & 0 deletions aiida_workgraph/engine/override.py
@@ -0,0 +1,65 @@
from plumpy.process_comms import RemoteProcessThreadController
from typing import Any, Optional


def create_daemon_runner(
manager, queue_name: Optional[str] = None, loop: Optional["asyncio.AbstractEventLoop"] = None
) -> "Runner":
"""Create and return a new daemon runner.
This is used by workers when the daemon is running and in testing.
:param loop: the (optional) asyncio event loop to use
:return: a runner configured to work in the daemon configuration
"""
from plumpy.persistence import LoadSaveContext
from aiida.engine import persistence
from aiida.engine.processes.launcher import ProcessLauncher
from plumpy.communications import convert_to_comm

runner = manager.create_runner(broker_submit=True, loop=loop)
runner_loop = runner.loop
# Listen for incoming launch requests
task_receiver = ProcessLauncher(
loop=runner_loop,
persister=manager.get_persister(),
load_context=LoadSaveContext(runner=runner),
loader=persistence.get_object_loader(),
)

def callback(_comm, msg):
print("Received message: {}".format(msg))
import asyncio

asyncio.run(task_receiver(_comm, msg))
print("task_receiver._continue done")
return True

assert runner.communicator is not None, "communicator not set for runner"
if queue_name is not None:
print("queue_name: {}".format(queue_name))
queue = runner.communicator._communicator.task_queue(
queue_name, prefetch_count=1
)
# queue.add_task_subscriber(callback)
# important to convert the callback
converted = convert_to_comm(task_receiver, runner.communicator._loop)
queue.add_task_subscriber(converted)
else:
runner.communicator.add_task_subscriber(task_receiver)
return runner


class ControllerWithQueueName(RemoteProcessThreadController):
def __init__(self, queue_name: str, **kwargs):
super().__init__(**kwargs)
self.queue_name = queue_name

def task_send(self, message: Any, no_reply: bool = False) -> Optional[Any]:
"""
Send a task to be performed using the communicator
:param message: the task message
:param no_reply: if True, this call will be fire-and-forget, i.e. no return value
:return: the response from the remote side (if no_reply=False)
"""
queue = self._communicator.task_queue(self.queue_name)
return queue.task_send(message, no_reply=no_reply)
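
The same pattern as the sketch under the commit message, but aimed at the worker queue; this is an assumption about what `create_scheduler_action` (used in `client.py` below) does, not code from this commit. With `prefetch_count=1`, each worker holds at most one such task, so at most one Scheduler process runs per worker:

```python
from aiida import load_profile
from aiida.manage import get_manager
from plumpy.process_comms import create_continue_body
from aiida_workgraph.engine.override import ControllerWithQueueName

load_profile()

# Placeholder pk of a stored, checkpointed WorkGraphScheduler process node.
scheduler_pk = 67890
controller = ControllerWithQueueName(
    queue_name="scheduler_queue",
    communicator=get_manager().get_communicator(),
)
controller.task_send(create_continue_body(pid=scheduler_pk), no_reply=True)
```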
108 changes: 101 additions & 7 deletions aiida_workgraph/engine/scheduler/client.py
@@ -4,8 +4,11 @@
from aiida.common.exceptions import ConfigurationError
import os
from typing import Optional
from aiida.common.log import AIIDA_LOGGER
from typing import List

WORKGRAPH_BIN = shutil.which("workgraph")
LOGGER = AIIDA_LOGGER.getChild("engine.launch")


class SchedulerClient(DaemonClient):
@@ -102,6 +105,7 @@ def cmd_start_daemon(
self.profile.name,
"scheduler",
"start-circus",
str(number_workers),
]

if foreground:
@@ -114,7 +118,7 @@ def cmd_start_daemon_worker(self) -> list[str]:
"""Return the command to start a daemon worker process."""
return [self._workgraph_bin, "-p", self.profile.name, "scheduler", "worker"]

def _start_daemon(self, foreground: bool = False) -> None:
def _start_daemon(self, number_workers: int = 1, foreground: bool = False) -> None:
"""Start the daemon.
.. warning:: This will daemonize the current process and put it in the background. It is most likely not what
@@ -149,7 +153,7 @@ def _start_daemon(self, foreground: bool = False) -> None:
{
"cmd": " ".join(self.cmd_start_daemon_worker),
"name": self.daemon_name,
"numprocesses": 1,
"numprocesses": number_workers,
"virtualenv": self.virtualenv,
"copy_env": True,
"stdout_stream": {
@@ -210,7 +214,7 @@ def get_scheduler_client(profile_name: Optional[str] = None) -> "SchedulerClient"
return SchedulerClient(profile)


def get_scheduler():
def get_scheduler() -> List[int]:
from aiida.orm import QueryBuilder
from aiida_workgraph.engine.scheduler import WorkGraphScheduler

@@ -224,7 +228,97 @@ def get_scheduler():
}
qb.append(WorkGraphScheduler, filters=filters, project=projections, tag="process")
results = qb.all()
if len(results) == 0:
raise ValueError("No scheduler found. Please start the scheduler first.")
scheduler_id = results[0][0]
return scheduler_id
pks = [r[0] for r in results]
return pks


def start_scheduler_worker(foreground: bool = False) -> None:
"""Start a scheduler worker for the currently configured profile.
:param foreground: If true, the logging will be configured to write to stdout, otherwise it will be configured to
write to the scheduler log file.
"""
import asyncio
import signal
import sys
from aiida_workgraph.engine.scheduler.client import get_scheduler_client
from aiida_workgraph.engine.override import create_daemon_runner

from aiida.common.log import configure_logging
from aiida.manage import get_config_option, get_manager
from aiida.engine.daemon.worker import shutdown_worker

daemon_client = get_scheduler_client()
configure_logging(
daemon=not foreground, daemon_log_file=daemon_client.daemon_log_file
)

LOGGER.debug(f"sys.executable: {sys.executable}")
LOGGER.debug(f"sys.path: {sys.path}")

try:
manager = get_manager()
runner = create_daemon_runner(manager, queue_name="scheduler_queue")
except Exception:
LOGGER.exception("daemon worker failed to start")
raise

if isinstance(rlimit := get_config_option("daemon.recursion_limit"), int):
LOGGER.info("Setting maximum recursion limit of daemon worker to %s", rlimit)
sys.setrecursionlimit(rlimit)

signals = (signal.SIGTERM, signal.SIGINT)
for s in signals:
# https://github.com/python/mypy/issues/12557
runner.loop.add_signal_handler(s, lambda s=s: asyncio.create_task(shutdown_worker(runner))) # type: ignore[misc]

try:
LOGGER.info("Starting a daemon worker")
runner.start()
except SystemError as exception:
LOGGER.info("Received a SystemError: %s", exception)
runner.close()

LOGGER.info("Daemon worker started")


def start_scheduler_process(number: int = 1) -> None:
"""Start or restart the specified number of scheduler processes."""
from aiida_workgraph.engine.scheduler import WorkGraphScheduler
from aiida_workgraph.engine.scheduler.client import get_scheduler
from aiida_workgraph.utils.control import create_scheduler_action
from aiida_workgraph.engine.utils import instantiate_process
from aiida.manage import get_manager

try:
schedulers: List[int] = get_scheduler()
existing_schedulers_count = len(schedulers)
print(
"Found {} existing scheduler(s): {}".format(
existing_schedulers_count, " ".join([str(pk) for pk in schedulers])
)
)

count = 0

# Restart existing schedulers if they exceed the number to start
if existing_schedulers_count > number:
for pk in schedulers[:number]:
create_scheduler_action(pk)
print(f"Scheduler with pk {pk} restarted.")
count += 1
else:
count = existing_schedulers_count

# Start new schedulers if more are needed
runner = get_manager().get_runner()
for i in range(count, number):
process_inited = instantiate_process(runner, WorkGraphScheduler)
process_inited.runner.persister.save_checkpoint(process_inited)
process_inited.close()
create_scheduler_action(process_inited.node.pk)
print(f"Scheduler with pk {process_inited.node.pk} started.")

print(f"Total schedulers running: {number}")

except Exception as e:
raise (f"An error occurred while starting schedulers: {e}")
8 changes: 4 additions & 4 deletions aiida_workgraph/engine/scheduler/scheduler.py
@@ -1637,18 +1637,18 @@ def message_receive(
def call_on_receive_workgraph_message(self, _comm, msg):
"""Call on receive workgraph message."""
# self.report(f"Received workgraph message: {msg}")
pk = int(msg)
pk = msg["args"]["pid"]
# To avoid "DbNode is not persistent", we need to schedule the call
self._schedule_rpc(self.launch_workgraph, pk=pk)
return True

def add_workgraph_subsriber(self) -> None:
"""Add workgraph subscriber."""
queue_name = "scheduler_queue"
# self.report(f"Add workgraph subscriber on queue: {queue_name}")
queue_name = "workgraph_queue"
self.report(f"Add workgraph subscriber on queue: {queue_name}")
comm = self.runner.communicator._communicator
queue = comm.task_queue(queue_name, prefetch_count=1000)
queue.add_task_subscriber(self.callback)
queue.add_task_subscriber(self.call_on_receive_workgraph_message)

def finalize_workgraph(self, pk: int) -> t.Optional[ExitCode]:
""""""