Skip to content

Commit

Permalink
Add checkpoint and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Aug 29, 2024
1 parent 7916e27 commit c36d72b
Show file tree
Hide file tree
Showing 3 changed files with 327 additions and 3 deletions.
7 changes: 4 additions & 3 deletions aiida_workgraph/engine/scheduler/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from aiida.manage.manager import get_manager
from aiida.common.exceptions import ConfigurationError
import os
from typing import Optional

WORKGRAPH_BIN = shutil.which("workgraph")

Expand Down Expand Up @@ -54,8 +55,8 @@ def filepaths(self):
},
},
"daemon": {
"log": str(DAEMON_LOG_DIR / f"aiida-{self.profile.name}.log"),
"pid": str(DAEMON_DIR / f"aiida-{self.profile.name}.pid"),
"log": str(DAEMON_LOG_DIR / f"aiida-scheduler-{self.profile.name}.log"),
"pid": str(DAEMON_DIR / f"aiida-scheduler-{self.profile.name}.pid"),
},
}

Expand Down Expand Up @@ -196,7 +197,7 @@ def _start_daemon(self, foreground: bool = False) -> None:
pidfile.unlink()


def get_scheduler_client(profile_name: str | None = None) -> "SchedulerClient":
def get_scheduler_client(profile_name: Optional[str] = None) -> "SchedulerClient":
"""Return the daemon client for the given profile or the current profile if not specified.
:param profile_name: Optional profile name to load.
Expand Down
23 changes: 23 additions & 0 deletions aiida_workgraph/engine/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,18 @@ def load_instance_state(
self.resume()
# For other awaitables, because they exist in the db, we only need to re-register the callbacks
self._action_awaitables()
# load checkpoint
launched_workgraphs = self.node.base.extras.get("_launched_workgraphs", [])
for pk in launched_workgraphs:
print("load workgraph: ", pk)
node = load_node(pk)
wgdata = node.base.extras.get("_workgraph", None)
if wgdata is None:
self.launch_workgraph(pk)
else:
self.ctx._workgraph[pk] = deserialize_unsafe(wgdata)
print("continue workgraph: ", pk)
self.continue_workgraph(pk)

def _resolve_nested_context(self, key: str) -> tuple[AttributeDict, str]:
"""
Expand Down Expand Up @@ -478,6 +490,7 @@ def setup(self) -> None:
"""Setup the variables in the context."""
# track if the awaitable callback is added to the runner

self.ctx.launched_workgraphs = []
self.ctx._workgraph = {}
self.ctx._max_number_awaitables = 10000
awaitable = Awaitable(
Expand All @@ -501,6 +514,9 @@ def launch_workgraph(self, pk: str) -> None:
"""Launch the workgraph."""
# create the workgraph process
self.report(f"Launch workgraph: {pk}")
# append the pk to the self.node.base.extras
self.ctx.launched_workgraphs.append(pk)
self.node.base.extras.set("_launched_workgraphs", self.ctx.launched_workgraphs)
self.init_ctx_workgraph(pk)
self.ctx._workgraph[pk]["_node"].set_process_state(Running.LABEL)
self.init_task_results(pk)
Expand Down Expand Up @@ -795,6 +811,7 @@ def update_task_state(
task.setdefault("results", None)

self.update_parent_task_state(pk, name)
self.save_workgraph_checkpoint(pk)
if continue_workgraph:
self.continue_workgraph(pk)

Expand All @@ -817,6 +834,12 @@ def set_normal_task_results(self, pk, name, results):
self.report(f"Workgraph: {pk}, Task: {name} finished.")
self.update_parent_task_state(pk, name)

def save_workgraph_checkpoint(self, pk: int) -> None:
    """Save the workgraph checkpoint.

    Serialize the in-memory state held in ``self.ctx._workgraph[pk]`` and
    store it in the ``_checkpoint`` extra of that workgraph's ``_node``.

    :param pk: Key of the workgraph in ``self.ctx._workgraph``.
    """
    wgdata = self.ctx._workgraph[pk]
    # Go through the ``base.extras`` namespace, like ``launch_workgraph`` does
    # for ``_launched_workgraphs``; ``Node.set_extra`` is the deprecated spelling.
    # NOTE(review): ``load_instance_state`` restores from the ``_workgraph``
    # extra, not ``_checkpoint`` - confirm the two keys are meant to differ.
    wgdata["_node"].base.extras.set("_checkpoint", serialize(wgdata))

def update_parent_task_state(self, pk, name: str) -> None:
"""Update parent task state."""
parent_task = self.ctx._workgraph[pk]["_tasks"][name]["parent_task"]
Expand Down
Loading

0 comments on commit c36d72b

Please sign in to comment.