Skip to content

Commit

Permalink
Merge branch 'main' into feature/scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 authored Aug 27, 2024
2 parents d43816f + cff769e commit b2f6ad8
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 26 deletions.
2 changes: 1 addition & 1 deletion aiida_workgraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@
from .task import Task
from .decorator import task, build_task

__version__ = "0.3.22"
__version__ = "0.3.24"

__all__ = ["WorkGraph", "Task", "task", "build_task"]
1 change: 0 additions & 1 deletion aiida_workgraph/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ def new(

# build the socket on the fly if the identifier is a callable
if callable(identifier):
print("identifier is callable", identifier)
identifier = build_socket_from_AiiDA(identifier)
# Call the original new method
return super().new(identifier, name, **kwargs)
58 changes: 36 additions & 22 deletions aiida_workgraph/engine/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ def define(cls, spec: WorkChainSpec) -> None:

spec.outputs.dynamic = True

spec.output_namespace("new_data", dynamic=True)
spec.output(
"execution_count",
valid_type=orm.Int,
Expand Down Expand Up @@ -481,7 +480,7 @@ def setup(self) -> None:
self.init_ctx(wgdata)
#
self.ctx._msgs = []
self.ctx._execution_count = 0
self.ctx._execution_count = 1
# init task results
self.set_task_results()
# data not to be persisted, because they are not serializable
Expand Down Expand Up @@ -667,7 +666,6 @@ def kill_task(self, name: str) -> None:
self.logger.error(f"Error in killing task {name}: {e}")

def continue_workgraph(self) -> None:
print("Continue workgraph.")
self.report("Continue workgraph.")
# self.update_workgraph_from_base()
task_to_run = []
Expand Down Expand Up @@ -1114,29 +1112,44 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None
self.set_task_state_info(name, "process", process)
self.to_context(**{name: process})
elif task["metadata"]["node_type"].upper() in ["WHILE"]:
# check the conditions of the while task
should_run = self.should_run_while_task(name)
if not should_run:
self.set_task_state_info(name, "state", "FINISHED")
self.set_tasks_state(self.ctx._tasks[name]["children"], "SKIPPED")
self.update_parent_task_state(name)
self.report(
f"While Task {name}: Condition not fullilled, task finished. Skip all its children."
)
# TODO refactor this for while, if and zone
# in case of an empty zone, it will finish immediately
if self.are_childen_finished(name)[0]:
self.update_while_task_state(name)
else:
task["execution_count"] += 1
self.set_task_state_info(name, "state", "RUNNING")
self.continue_workgraph()
# check the conditions of the while task
should_run = self.should_run_while_task(name)
if not should_run:
self.set_task_state_info(name, "state", "FINISHED")
self.set_tasks_state(
self.ctx._tasks[name]["children"], "SKIPPED"
)
self.update_parent_task_state(name)
self.report(
f"While Task {name}: Condition not fullilled, task finished. Skip all its children."
)
else:
task["execution_count"] += 1
self.set_task_state_info(name, "state", "RUNNING")
self.continue_workgraph()
elif task["metadata"]["node_type"].upper() in ["IF"]:
should_run = self.should_run_if_task(name)
if should_run:
self.set_task_state_info(name, "state", "RUNNING")
else:
self.set_tasks_state(task["children"], "SKIPPED")
# in case of an empty zone, it will finish immediately
if self.are_childen_finished(name)[0]:
self.update_zone_task_state(name)
else:
should_run = self.should_run_if_task(name)
if should_run:
self.set_task_state_info(name, "state", "RUNNING")
else:
self.set_tasks_state(task["children"], "SKIPPED")
self.update_zone_task_state(name)
self.continue_workgraph()
elif task["metadata"]["node_type"].upper() in ["ZONE"]:
self.set_task_state_info(name, "state", "RUNNING")
# in case of an empty zone, it will finish immediately
if self.are_childen_finished(name)[0]:
self.update_zone_task_state(name)
else:
self.set_task_state_info(name, "state", "RUNNING")
self.continue_workgraph()
elif task["metadata"]["node_type"].upper() in ["FROM_CONTEXT"]:
# get the results from the context
Expand Down Expand Up @@ -1484,7 +1497,8 @@ def finalize(self) -> t.Optional[ExitCode]:
)
self.out_many(group_outputs)
# output the new data
self.out("new_data", self.ctx._new_data)
if self.ctx._new_data:
self.out("new_data", self.ctx._new_data)
self.out("execution_count", orm.Int(self.ctx._execution_count).store())
self.report("Finalize workgraph.")
for _, task in self.ctx._tasks.items():
Expand Down
2 changes: 1 addition & 1 deletion docs/gallery/autogen/quick_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def multiply(x, y):
# export the workgraph to html file so that it can be visualized in a browser
wg.to_html()
# comment out the following line to visualize the workgraph in jupyter-notebook
wg
# wg


######################################################################
Expand Down
9 changes: 9 additions & 0 deletions tests/test_if.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,12 @@ def test_if_task(decorated_add, decorated_multiply, decorated_compare):
add3 = wg.add_task(decorated_add, name="add3", x=select1.outputs["result"], y=1)
wg.run()
assert add3.outputs["result"].value == 5


def test_empty_if_task():
"""Test the If task with no children."""

wg = WorkGraph("test_empty_if")
wg.add_task("If", name="if_true")
wg.run()
assert wg.state == "FINISHED"
2 changes: 1 addition & 1 deletion tests/test_while.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_while_workgraph(decorated_add, decorated_multiply, decorated_compare):
add1.set_context({"result": "n"})
wg.add_link(multiply1.outputs["result"], add1.inputs["x"])
wg.submit(wait=True, timeout=100)
assert wg.execution_count == 3
assert wg.execution_count == 4
assert wg.tasks["add1"].outputs["result"].value == 29


Expand Down

0 comments on commit b2f6ad8

Please sign in to comment.