diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index fd971af4..a2ca7596 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -1,6 +1,6 @@ name: ci -on: [pull_request] +on: [push, pull_request] jobs: diff --git a/README.md b/README.md index 4045b477..04870412 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # AiiDA-WorkGraph [![PyPI version](https://badge.fury.io/py/aiida-workgraph.svg)](https://badge.fury.io/py/aiida-workgraph) [![Unit test](https://github.com/aiidateam/aiida-workgraph/actions/workflows/ci.yaml/badge.svg)](https://github.com/aiidateam/aiida-workgraph/actions/workflows/ci.yaml) -[![codecov](https://codecov.io/gh/superstar54/aiida-workgraph/branch/main/graph/badge.svg)](https://codecov.io/gh/superstar54/aiida-workgraph) +[![codecov](https://codecov.io/gh/aiidateam/aiida-workgraph/branch/main/graph/badge.svg)](https://codecov.io/gh/aiidateam/aiida-workgraph) [![Docs status](https://readthedocs.org/projects/aiida-workgraph/badge)](http://aiida-workgraph.readthedocs.io/) Efficiently design and manage flexible workflows with AiiDA, featuring an interactive GUI, checkpoints, provenance tracking, error-resistant, and remote execution capabilities. diff --git a/aiida_workgraph/calculations/__init__.py b/aiida_workgraph/calculations/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/aiida_workgraph/calculations/python.py b/aiida_workgraph/calculations/python.py deleted file mode 100644 index b89a0862..00000000 --- a/aiida_workgraph/calculations/python.py +++ /dev/null @@ -1,324 +0,0 @@ -"""Calcjob to run a Python function on a remote computer.""" -from __future__ import annotations - -import pathlib -import typing as t - -from aiida.common.datastructures import CalcInfo, CodeInfo -from aiida.common.folders import Folder -from aiida.common.extendeddicts import AttributeDict -from aiida.engine import CalcJob, CalcJobProcessSpec -from aiida.orm import ( - Data, - SinglefileData, - Str, - List, - FolderData, - RemoteData, - to_aiida_type, -) -from aiida_workgraph.orm.function_data import PickledFunction, to_pickled_function - - -__all__ = ("PythonJob",) - - -class PythonJob(CalcJob): - """Calcjob to run a Python function on a remote computer.""" - - _internal_retrieve_list = [] - _retrieve_singlefile_list = [] - _retrieve_temporary_list = [] - - _DEFAULT_INPUT_FILE = "script.py" - _DEFAULT_OUTPUT_FILE = "aiida.out" - _DEFAULT_PARENT_FOLDER_NAME = "./parent_folder/" - - _default_parser = "workgraph.python" - - @classmethod - def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override] - """Define the process specification, including its inputs, outputs and known exit codes. - - :param spec: the calculation job process spec to define. 
- """ - super().define(spec) - spec.input( - "function", - valid_type=PickledFunction, - serializer=to_pickled_function, - required=False, - ) - spec.input( - "function_source_code", - valid_type=Str, - serializer=to_aiida_type, - required=False, - ) - spec.input( - "function_name", valid_type=Str, serializer=to_aiida_type, required=False - ) - spec.input( - "process_label", valid_type=Str, serializer=to_aiida_type, required=False - ) - spec.input_namespace( - "function_kwargs", valid_type=Data, required=False - ) # , serializer=serialize_to_aiida_nodes) - spec.input( - "function_outputs", - valid_type=List, - default=lambda: List(), - required=False, - serializer=to_aiida_type, - help="The information of the output ports", - ) - spec.input( - "parent_folder", - valid_type=(RemoteData, FolderData, SinglefileData), - required=False, - help="Use a local or remote folder as parent folder (for restarts and similar)", - ) - spec.input( - "parent_folder_name", - valid_type=Str, - required=False, - serializer=to_aiida_type, - help="""Default name of the subfolder that you want to create in the working directory, - in which you want to place the files taken from parent_folder""", - ) - spec.input( - "parent_output_folder", - valid_type=Str, - default=None, - required=False, - serializer=to_aiida_type, - help="Name of the subfolder inside 'parent_folder' from which you want to copy the files", - ) - spec.input_namespace( - "upload_files", - valid_type=(FolderData, SinglefileData), - required=False, - help="The folder/files to upload", - ) - spec.input_namespace( - "copy_files", - valid_type=(RemoteData,), - required=False, - help="The folder/files to copy from the remote computer", - ) - spec.input( - "additional_retrieve_list", - valid_type=List, - default=None, - required=False, - serializer=to_aiida_type, - help="The names of the files to retrieve", - ) - spec.outputs.dynamic = True - # set default options (optional) - spec.inputs["metadata"]["options"]["parser_name"].default = "workgraph.python" - spec.inputs["metadata"]["options"]["input_filename"].default = "script.py" - spec.inputs["metadata"]["options"]["output_filename"].default = "aiida.out" - spec.inputs["metadata"]["options"]["resources"].default = { - "num_machines": 1, - "num_mpiprocs_per_machine": 1, - } - # start exit codes - marker for docs - spec.exit_code( - 310, - "ERROR_READING_OUTPUT_FILE", - invalidates_cache=True, - message="The output file could not be read.", - ) - spec.exit_code( - 320, - "ERROR_INVALID_OUTPUT", - invalidates_cache=True, - message="The output file contains invalid output.", - ) - spec.exit_code( - 321, - "ERROR_RESULT_OUTPUT_MISMATCH", - invalidates_cache=True, - message="The number of results does not match the number of outputs.", - ) - - def _build_process_label(self) -> str: - """Use the function name as the process label. - - :returns: The process label to use for ``ProcessNode`` instances. - """ - if "process_label" in self.inputs: - return self.inputs.process_label.value - else: - data = self.get_function_data() - return f"PythonJob<{data['name']}>" - - def on_create(self) -> None: - """Called when a Process is created.""" - - super().on_create() - self.node.label = self._build_process_label() - - def get_function_data(self) -> dict[str, t.Any]: - """Get the function data. - - :returns: The function data. 
- """ - if "function" in self.inputs: - metadata = self.inputs.function.metadata - metadata["source_code"] = ( - metadata["import_statements"] - + "\n" - + metadata["source_code_without_decorator"] - ) - return metadata - else: - return { - "source_code": self.inputs.function_source_code.value, - "name": self.inputs.function_name.value, - } - - def prepare_for_submission(self, folder: Folder) -> CalcInfo: - """Prepare the calculation for submission. - - 1) Write the python script to the folder. - 2) Write the inputs to a pickle file and save it to the folder. - - :param folder: A temporary folder on the local file system. - :returns: A :class:`aiida.common.datastructures.CalcInfo` instance. - """ - import cloudpickle as pickle - - dirpath = pathlib.Path(folder._abspath) - inputs: dict[str, t.Any] - - if self.inputs.function_kwargs: - inputs = dict(self.inputs.function_kwargs) - else: - inputs = {} - if "parent_folder_name" in self.inputs: - parent_folder_name = self.inputs.parent_folder_name.value - else: - parent_folder_name = self._DEFAULT_PARENT_FOLDER_NAME - function_data = self.get_function_data() - # create python script to run the function - script = f""" -import pickle - -# define the function -{function_data["source_code"]} - -# load the inputs from the pickle file -with open('inputs.pickle', 'rb') as handle: - inputs = pickle.load(handle) - -# run the function -result = {function_data["name"]}(**inputs) -# save the result as a pickle file -with open('results.pickle', 'wb') as handle: - pickle.dump(result, handle) -""" - # write the script to the folder - with folder.open(self.options.input_filename, "w", encoding="utf8") as handle: - handle.write(script) - # symlink = settings.pop('PARENT_FOLDER_SYMLINK', False) - symlink = True - - remote_copy_list = [] - local_copy_list = [] - remote_symlink_list = [] - remote_list = remote_symlink_list if symlink else remote_copy_list - - source = self.inputs.get("parent_folder", None) - - if source is not None: - if isinstance(source, RemoteData): - dirpath = pathlib.Path(source.get_remote_path()) - if self.inputs.parent_output_folder is not None: - dirpath = ( - pathlib.Path(source.get_remote_path()) - / self.inputs.parent_output_folder.value - ) - remote_list.append( - ( - source.computer.uuid, - str(dirpath), - parent_folder_name, - ) - ) - elif isinstance(source, FolderData): - dirname = ( - self.inputs.parent_output_folder.value - if self.inputs.parent_output_folder is not None - else "" - ) - local_copy_list.append((source.uuid, dirname, parent_folder_name)) - elif isinstance(source, SinglefileData): - local_copy_list.append((source.uuid, source.filename, source.filename)) - if "upload_files" in self.inputs: - upload_files = self.inputs.upload_files - for key, source in upload_files.items(): - # replace "_dot_" with "." in the key - key = key.replace("_dot_", ".") - if isinstance(source, FolderData): - local_copy_list.append((source.uuid, "", key)) - elif isinstance(source, SinglefileData): - local_copy_list.append( - (source.uuid, source.filename, source.filename) - ) - else: - raise ValueError( - f"""Input folder/file: {source} is not supported. -Only AiiDA SinglefileData and FolderData are allowed.""" - ) - if "copy_files" in self.inputs: - copy_files = self.inputs.copy_files - for key, source in copy_files.items(): - # replace "_dot_" with "." 
in the key - key = key.replace("_dot_", ".") - dirpath = pathlib.Path(source.get_remote_path()) - remote_list.append((source.computer.uuid, str(dirpath), key)) - # create pickle file for the inputs - input_values = {} - for key, value in inputs.items(): - if isinstance(value, Data) and hasattr(value, "value"): - # get the value of the pickled data - input_values[key] = value.value - # TODO: should check this recursively - elif isinstance(value, (AttributeDict, dict)): - # if the value is an AttributeDict, use recursively - input_values[key] = {k: v.value for k, v in value.items()} - else: - raise ValueError( - f"Input data {value} is not supported. Only AiiDA data Node with a value attribute is allowed. " - ) - # save the value as a pickle file, the path is absolute - filename = "inputs.pickle" - dirpath = pathlib.Path(folder._abspath) - with folder.open(filename, "wb") as handle: - pickle.dump(input_values, handle) - # create a singlefiledata object for the pickled data - file_data = SinglefileData(file=f"{dirpath}/{filename}") - file_data.store() - local_copy_list.append((file_data.uuid, file_data.filename, filename)) - - codeinfo = CodeInfo() - codeinfo.stdin_name = self.options.input_filename - codeinfo.stdout_name = self.options.output_filename - codeinfo.code_uuid = self.inputs.code.uuid - - calcinfo = CalcInfo() - calcinfo.codes_info = [codeinfo] - calcinfo.local_copy_list = local_copy_list - calcinfo.remote_copy_list = remote_copy_list - calcinfo.remote_symlink_list = remote_symlink_list - calcinfo.retrieve_list = ["results.pickle", self.options.output_filename] - if self.inputs.additional_retrieve_list is not None: - calcinfo.retrieve_list += self.inputs.additional_retrieve_list.get_list() - calcinfo.retrieve_list += self._internal_retrieve_list - - calcinfo.retrieve_temporary_list = self._retrieve_temporary_list - calcinfo.retrieve_singlefile_list = self._retrieve_singlefile_list - - return calcinfo diff --git a/aiida_workgraph/calculations/python_parser.py b/aiida_workgraph/calculations/python_parser.py deleted file mode 100644 index 17d71141..00000000 --- a/aiida_workgraph/calculations/python_parser.py +++ /dev/null @@ -1,135 +0,0 @@ -"""Parser for an `PythonJob` job.""" -from aiida.parsers.parser import Parser -from aiida_workgraph.orm import general_serializer -from aiida.engine import ExitCode - - -class PythonParser(Parser): - """Parser for an `PythonJob` job.""" - - def parse(self, **kwargs): - """Parse the contents of the output files stored in the `retrieved` output node. - - The function_outputs could be a namespce, e.g., - function_outputs=[ - {"identifier": "workgraph.namespace", "name": "add_multiply"}, - {"name": "add_multiply.add"}, - {"name": "add_multiply.multiply"}, - {"name": "minus"}, - ] - """ - import pickle - - function_outputs = self.node.inputs.function_outputs.get_list() - # function_outputs exclude ['_wait', '_outputs', 'remote_folder', 'remote_stash', 'retrieved'] - self.output_list = [ - data - for data in function_outputs - if data["name"] - not in [ - "_wait", - "_outputs", - "remote_folder", - "remote_stash", - "retrieved", - "exit_code", - ] - ] - # first we remove nested outputs, e.g., "add_multiply.add" - top_level_output_list = [ - output for output in self.output_list if "." 
not in output["name"] - ] - exit_code = 0 - try: - with self.retrieved.base.repository.open("results.pickle", "rb") as handle: - results = pickle.load(handle) - if isinstance(results, tuple): - if len(top_level_output_list) != len(results): - self.exit_codes.ERROR_RESULT_OUTPUT_MISMATCH - for i in range(len(top_level_output_list)): - top_level_output_list[i]["value"] = self.serialize_output( - results[i], top_level_output_list[i] - ) - elif isinstance(results, dict) and len(top_level_output_list) > 1: - # pop the exit code if it exists - exit_code = results.pop("exit_code", 0) - for output in top_level_output_list: - if output.get("required", False): - if output["name"] not in results: - self.exit_codes.ERROR_MISSING_OUTPUT - output["value"] = self.serialize_output( - results.pop(output["name"]), output - ) - # if there are any remaining results, raise an warning - if results: - self.logger.warning( - f"Found extra results that are not included in the output: {results.keys()}" - ) - elif isinstance(results, dict) and len(top_level_output_list) == 1: - exit_code = results.pop("exit_code", 0) - # if output name in results, use it - if top_level_output_list[0]["name"] in results: - top_level_output_list[0]["value"] = self.serialize_output( - results[top_level_output_list[0]["name"]], - top_level_output_list[0], - ) - # otherwise, we assume the results is the output - else: - top_level_output_list[0]["value"] = self.serialize_output( - results, top_level_output_list[0] - ) - elif len(top_level_output_list) == 1: - # otherwise, we assume the results is the output - top_level_output_list[0]["value"] = self.serialize_output( - results, top_level_output_list[0] - ) - else: - raise ValueError( - "The number of results does not match the number of outputs." - ) - for output in top_level_output_list: - self.out(output["name"], output["value"]) - if exit_code: - if isinstance(exit_code, dict): - exit_code = ExitCode(exit_code["status"], exit_code["message"]) - elif isinstance(exit_code, int): - exit_code = ExitCode(exit_code) - return exit_code - except OSError: - return self.exit_codes.ERROR_READING_OUTPUT_FILE - except ValueError as exception: - self.logger.error(exception) - return self.exit_codes.ERROR_INVALID_OUTPUT - - def find_output(self, name): - """Find the output with the given name.""" - for output in self.output_list: - if output["name"] == name: - return output - return None - - def serialize_output(self, result, output): - """Serialize outputs.""" - - name = output["name"] - if output.get("identifier", "Any").upper() == "WORKGRAPH.NAMESPACE": - if isinstance(result, dict): - serialized_result = {} - for key, value in result.items(): - full_name = f"{name}.{key}" - full_name_output = self.find_output(full_name) - if ( - full_name_output - and full_name_output.get("identifier", "Any").upper() - == "WORKGRAPH.NAMESPACE" - ): - serialized_result[key] = self.serialize_output( - value, full_name_output - ) - else: - serialized_result[key] = general_serializer(value) - return serialized_result - else: - self.exit_codes.ERROR_INVALID_OUTPUT - else: - return general_serializer(result) diff --git a/aiida_workgraph/decorator.py b/aiida_workgraph/decorator.py index d9fc24ab..c5551b48 100644 --- a/aiida_workgraph/decorator.py +++ b/aiida_workgraph/decorator.py @@ -281,7 +281,7 @@ def build_task_from_AiiDA( def build_pythonjob_task(func: Callable) -> Task: """Build PythonJob task from function.""" - from aiida_workgraph.calculations.python import PythonJob + from aiida_pythonjob import PythonJob from 
copy import deepcopy # if the function is not a task, build a task from the function @@ -297,16 +297,14 @@ def build_pythonjob_task(func: Callable) -> Task: } _, tdata_py = build_task_from_AiiDA(tdata) tdata = deepcopy(func.tdata) - function_kwargs = [ + function_inputs = [ name for name in tdata["inputs"] if name not in ["_wait", "_outputs"] ] # merge the inputs and outputs from the PythonJob task to the function task # skip the already existed inputs and outputs for input in [ {"identifier": "workgraph.string", "name": "computer"}, - {"identifier": "workgraph.string", "name": "code_label"}, - {"identifier": "workgraph.string", "name": "code_path"}, - {"identifier": "workgraph.string", "name": "prepend_text"}, + {"identifier": "workgraph.any", "name": "command_info"}, ]: input["list_index"] = len(tdata["inputs"]) + 1 tdata["inputs"][input["name"]] = input @@ -325,7 +323,7 @@ def build_pythonjob_task(func: Callable) -> Task: tdata["inputs"]["copy_files"]["link_limit"] = 1e6 # append the kwargs of the PythonJob task to the function task kwargs = tdata["kwargs"] - kwargs.extend(["computer", "code_label", "code_path", "prepend_text"]) + kwargs.extend(["computer", "command_info"]) kwargs.extend(tdata_py["kwargs"]) tdata["kwargs"] = kwargs tdata["metadata"]["task_type"] = "PYTHONJOB" @@ -336,7 +334,7 @@ def build_pythonjob_task(func: Callable) -> Task: } task = create_task(tdata) task.is_aiida_component = True - task.function_kwargs = function_kwargs + task.function_inputs = function_inputs return task, tdata @@ -559,8 +557,8 @@ def decorator_task( identifier: Optional[str] = None, task_type: str = "Normal", properties: Optional[List[Tuple[str, str]]] = None, - inputs: Optional[List[Tuple[str, str]]] = None, - outputs: Optional[List[Tuple[str, str]]] = None, + inputs: Optional[List[str | dict]] = None, + outputs: Optional[List[str | dict]] = None, error_handlers: Optional[List[Dict[str, Any]]] = None, catalog: str = "Others", ) -> Callable: @@ -576,6 +574,12 @@ def decorator_task( outputs (list): task outputs """ + if inputs: + inputs = validate_task_inout(inputs, "inputs") + + if outputs: + outputs = validate_task_inout(outputs, "outputs") + def decorator(func): nonlocal identifier, task_type diff --git a/aiida_workgraph/engine/awaitable_manager.py b/aiida_workgraph/engine/awaitable_manager.py new file mode 100644 index 00000000..414933a5 --- /dev/null +++ b/aiida_workgraph/engine/awaitable_manager.py @@ -0,0 +1,234 @@ +from __future__ import annotations + +import functools +from aiida.orm import ProcessNode +from aiida.engine.processes.workchains.awaitable import ( + Awaitable, + AwaitableAction, + AwaitableTarget, + construct_awaitable, +) +from aiida.orm import load_node +from aiida.common import exceptions +from typing import Any, List +import logging + + +class AwaitableManager: + """Handles awaitable objects and their resolutions.""" + + def __init__( + self, _awaitables, runner, logger: logging.Logger, process, ctx_manager + ): + self.runner = runner + self.logger = logger + self.process = process + self.ctx_manager = ctx_manager + self.ctx = ctx_manager.ctx + # awaitables that are persisted + self._awaitables: List[Awaitable] = _awaitables + # awaitables that are not persisted, because they are not serializable + # but don't worry, because we re-register them when loading the process + self.not_persisted_awaitables = {} + self.ctx._awaitable_actions = [] + + def insert_awaitable(self, awaitable: Awaitable) -> None: + """Insert an awaitable that should be terminated before before 
continuing to the next step. + + :param awaitable: the thing to await + """ + ctx, key = self.ctx_manager.resolve_nested_context(awaitable.key) + + # Already assign the awaitable itself to the location in the context container where it is supposed to end up + # once it is resolved. This is especially important for the `APPEND` action, since it needs to maintain the + # order, but the awaitables will not necessarily be resolved in the order in which they are added. By using the + # awaitable as a placeholder, in the `_resolve_awaitable`, it can be found and replaced by the resolved value. + if awaitable.action == AwaitableAction.ASSIGN: + ctx[key] = awaitable + elif awaitable.action == AwaitableAction.APPEND: + ctx.setdefault(key, []).append(awaitable) + else: + raise AssertionError(f"Unsupported awaitable action: {awaitable.action}") + + self._awaitables.append( + awaitable + ) # add only if everything went ok, otherwise we end up in an inconsistent state + self.update_process_status() + + def resolve_awaitable(self, awaitable: Awaitable, value: Any) -> None: + """Resolve an awaitable. + + Precondition: must be an awaitable that was previously inserted. + + :param awaitable: the awaitable to resolve + :param value: the value to assign to the awaitable + """ + ctx, key = self.ctx_manager.resolve_nested_context(awaitable.key) + + if awaitable.action == AwaitableAction.ASSIGN: + ctx[key] = value + elif awaitable.action == AwaitableAction.APPEND: + # Find the same awaitable inserted in the context + container = ctx[key] + for index, placeholder in enumerate(container): + if ( + isinstance(placeholder, Awaitable) + and placeholder.pk == awaitable.pk + ): + container[index] = value + break + else: + raise AssertionError( + f"Awaitable `{awaitable.pk} was not in `ctx.{awaitable.key}`" + ) + else: + raise AssertionError(f"Unsupported awaitable action: {awaitable.action}") + + awaitable.resolved = True + # remove awaitabble from the list + self._awaitables = [a for a in self._awaitables if a.pk != awaitable.pk] + + if not self.process.has_terminated(): + # the process may be terminated, for example, if the process was killed or excepted + # then we should not try to update it + self.update_process_status() + + def update_process_status(self) -> None: + """Set the process status with a message accounting the current sub processes that we are waiting for.""" + if self._awaitables: + status = f"Waiting for child processes: {', '.join([str(_.pk) for _ in self._awaitables])}" + self.process.node.set_process_status(status) + else: + self.process.node.set_process_status(None) + + def action_awaitables(self) -> None: + """Handle the awaitables that are currently registered with the work chain. 
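+        Awaitables may wrap either an AiiDA process node or a plain asyncio task (used for awaitable and monitor tasks).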
+ + Depending on the class type of the awaitable's target a different callback + function will be bound with the awaitable and the runner will be asked to + call it when the target is completed + """ + for awaitable in self._awaitables: + # if the waitable already has a callback, skip + if awaitable.pk in self.ctx._awaitable_actions: + continue + if awaitable.target == AwaitableTarget.PROCESS: + callback = functools.partial( + self.process.call_soon, self.on_awaitable_finished, awaitable + ) + self.runner.call_on_process_finish(awaitable.pk, callback) + self.ctx._awaitable_actions.append(awaitable.pk) + elif awaitable.target == "asyncio.tasks.Task": + # this is a awaitable task, the callback function is already set + self.ctx._awaitable_actions.append(awaitable.pk) + else: + assert f"invalid awaitable target '{awaitable.target}'" + + def on_awaitable_finished(self, awaitable: Awaitable) -> None: + """Callback function, for when an awaitable process instance is completed. + + The awaitable will be effectuated on the context of the work chain and removed from the internal list. If all + awaitables have been dealt with, the work chain process is resumed. + + :param awaitable: an Awaitable instance + """ + self.logger.debug(f"Awaitable {awaitable.key} finished.") + + if isinstance(awaitable.pk, int): + self.logger.info( + "received callback that awaitable with key {} and pk {} has terminated".format( + awaitable.key, awaitable.pk + ) + ) + try: + node = load_node(awaitable.pk) + except (exceptions.MultipleObjectsError, exceptions.NotExistent): + raise ValueError( + f"provided pk<{awaitable.pk}> could not be resolved to a valid Node instance" + ) + + if awaitable.outputs: + value = { + entry.link_label: entry.node + for entry in node.base.links.get_outgoing() + } + else: + value = node # type: ignore + else: + # In this case, the pk and key are the same. 
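+            # The awaitable here wraps an asyncio task (an awaitable or monitor task), so there is no process node to load.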
+ self.logger.info( + "received callback that awaitable {} has terminated".format( + awaitable.key + ) + ) + try: + # if awaitable is cancelled, the result is None + if awaitable.cancelled(): + self.process.task_manager.set_task_state_info( + awaitable.key, "state", "KILLED" + ) + # set child tasks state to SKIPPED + self.process.task_manager.set_tasks_state( + self.ctx._connectivity["child_node"][awaitable.key], + "SKIPPED", + ) + self.process.report(f"Task: {awaitable.key} cancelled.") + else: + results = awaitable.result() + self.process.task_manager.update_normal_task_state( + awaitable.key, results + ) + except Exception as e: + self.logger.error(f"Error in awaitable {awaitable.key}: {e}") + self.process.task_manager.set_task_state_info( + awaitable.key, "state", "FAILED" + ) + # set child tasks state to SKIPPED + self.process.task_manager.set_tasks_state( + self.ctx._connectivity["child_node"][awaitable.key], + "SKIPPED", + ) + self.process.report(f"Task: {awaitable.key} failed: {e}") + self.process.error_handler_manager.run_error_handlers(awaitable.key) + value = None + + self.resolve_awaitable(awaitable, value) + + # node finished, update the task state and result + # udpate the task state + self.process.task_manager.update_task_state(awaitable.key) + # try to resume the workgraph, if the workgraph is already resumed + # by other awaitable, this will not work + try: + self.process.resume() + except Exception as e: + print(e) + + def construct_awaitable_function( + self, name: str, awaitable_target: Awaitable + ) -> None: + """Construct the awaitable function.""" + awaitable = Awaitable( + **{ + "pk": name, + "action": AwaitableAction.ASSIGN, + "target": "asyncio.tasks.Task", + "outputs": False, + } + ) + awaitable_target.key = name + awaitable_target.pk = name + awaitable_target.action = AwaitableAction.ASSIGN + awaitable_target.add_done_callback(self.on_awaitable_finished) + return awaitable + + def to_context(self, **kwargs: Awaitable | ProcessNode) -> None: + """Add a dictionary of awaitables to the context. + + This is a convenience method that provides syntactic sugar, for a user to add multiple intersteps that will + assign a certain value to the corresponding key in the context of the work graph. + """ + for key, value in kwargs.items(): + awaitable = construct_awaitable(value) + awaitable.key = key + self.insert_awaitable(awaitable) diff --git a/aiida_workgraph/engine/context_manager.py b/aiida_workgraph/engine/context_manager.py new file mode 100644 index 00000000..67b3740c --- /dev/null +++ b/aiida_workgraph/engine/context_manager.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from aiida.common.extendeddicts import AttributeDict +from typing import Any +from aiida_workgraph.utils import get_nested_dict +import logging + + +class ContextManager: + """Manages the context for the WorkGraphEngine.""" + + _CONTEXT = "CONTEXT" + + def __init__(self, _context, process, logger: logging.Logger): + self.process = process + self._context = _context + self.logger = logger + + @property + def ctx(self) -> AttributeDict: + """Access the context.""" + return self._context + + def resolve_nested_context(self, key: str) -> tuple[AttributeDict, str]: + """ + Returns a reference to a sub-dictionary of the context and the last key, + after resolving a potentially segmented key where required sub-dictionaries are created as needed. 
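+        For example, the key "a.b.c" resolves to the sub-dictionary at self.ctx.a.b (creating "a" and "b" as needed) and the final key "c".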
+ + :param key: A key into the context, where words before a dot are interpreted as a key for a sub-dictionary + """ + ctx = self.ctx + ctx_path = key.split(".") + + for index, path in enumerate(ctx_path[:-1]): + try: + ctx = ctx[path] + except KeyError: # see below why this is the only exception we have to catch here + ctx[ + path + ] = AttributeDict() # create the sub-dict and update the context + ctx = ctx[path] + continue + + # Notes: + # * the first ctx (self.ctx) is guaranteed to be an AttributeDict, hence the post-"dereference" checking + # * the values can be many different things: on insertion they are either AtrributeDict, List or Awaitables + # (subclasses of AttributeDict) but after resolution of an Awaitable this will be the value itself + # * assumption: a resolved value is never a plain AttributeDict, on the other hand if a resolved Awaitable + # would be an AttributeDict we can append things to it since the order of tasks is maintained. + if type(ctx) != AttributeDict: # pylint: disable=C0123 + raise ValueError( + f"Can not update the context for key `{key}`: " + f' found instance of `{type(ctx)}` at `{".".join(ctx_path[:index + 1])}`, expected AttributeDict' + ) + + return ctx, ctx_path[-1] + + def update_context_variable(self, value: Any) -> Any: + """Replace placeholders in the value with actual context values.""" + if isinstance(value, dict): + return {k: self.update_context_variable(v) for k, v in value.items()} + elif ( + isinstance(value, str) + and value.strip().startswith("{{") + and value.strip().endswith("}}") + ): + name = value[2:-2].strip() + return get_nested_dict(self.ctx, name) + return value diff --git a/aiida_workgraph/engine/error_handler_manager.py b/aiida_workgraph/engine/error_handler_manager.py new file mode 100644 index 00000000..a1ec80c7 --- /dev/null +++ b/aiida_workgraph/engine/error_handler_manager.py @@ -0,0 +1,54 @@ +from __future__ import annotations + + +class ErrorHandlerManager: + def __init__(self, process, ctx_manager, logger): + self.process = process + self.ctx_manager = ctx_manager + self.ctx = ctx_manager.ctx + self.logger = logger + + def run_error_handlers(self, task_name: str) -> None: + """Run error handlers for a task.""" + + node = self.process.task_manager.get_task_state_info(task_name, "process") + if not node or not node.exit_status: + return + # error_handlers from the task + for _, data in self.ctx._tasks[task_name]["error_handlers"].items(): + if node.exit_status in data.get("exit_codes", []): + handler = data["handler"] + self.run_error_handler(handler, data, task_name) + return + # error_handlers from the workgraph + for _, data in self.ctx._error_handlers.items(): + if node.exit_code.status in data["tasks"].get(task_name, {}).get( + "exit_codes", [] + ): + handler = data["handler"] + metadata = data["tasks"][task_name] + self.run_error_handler(handler, metadata, task_name) + return + + def run_error_handler(self, handler: dict, metadata: dict, task_name: str) -> None: + """Run the error handler for a task.""" + from inspect import signature + from aiida_workgraph.utils import get_executor + + handler, _ = get_executor(handler) + handler_sig = signature(handler) + metadata.setdefault("retry", 0) + self.process.report(f"Run error handler: {handler.__name__}") + if metadata["retry"] < metadata["max_retries"]: + task = self.process.task_manager.get_task(task_name) + try: + if "engine" in handler_sig.parameters: + msg = handler(task, engine=self, **metadata.get("kwargs", {})) + else: + msg = handler(task, 
**metadata.get("kwargs", {})) + self.process.task_manager.update_task(task) + if msg: + self.process.report(msg) + metadata["retry"] += 1 + except Exception as e: + self.process.report(f"Error in running error handler: {e}") diff --git a/aiida_workgraph/engine/task_manager.py b/aiida_workgraph/engine/task_manager.py new file mode 100644 index 00000000..c8f5153a --- /dev/null +++ b/aiida_workgraph/engine/task_manager.py @@ -0,0 +1,946 @@ +from __future__ import annotations + +from aiida.orm import ProcessNode, Data +from aiida.engine import run_get_node +from typing import Any, Dict, List, Optional, Tuple, Callable, Union, Sequence +from aiida_workgraph.utils import create_and_pause_process +from aiida_workgraph.task import Task +from aiida_workgraph.utils import get_nested_dict +from aiida.orm.utils.serialize import deserialize_unsafe, serialize +import asyncio +from aiida.engine.processes.exit_code import ExitCode +from aiida_workgraph.executors.monitors import monitor + +MAX_NUMBER_AWAITABLES_MSG = "The maximum number of subprocesses has been reached: {}. Cannot launch the job: {}." + +process_task_types = [ + "CALCJOB", + "WORKCHAIN", + "GRAPH_BUILDER", + "WORKGRAPH", + "PYTHONJOB", + "SHELLJOB", +] + + +class TaskManager: + """Manages task execution, state updates, and error handling.""" + + def __init__(self, ctx_manager, logger, runner, process, awaitable_manager): + self.ctx_manager = ctx_manager + self.ctx = ctx_manager.ctx + self.logger = logger + self.runner = runner + self.process = process + self.awaitable_manager = awaitable_manager + + def get_task(self, name: str): + """Get task from the context.""" + task = Task.from_dict(self.ctx._tasks[name]) + # update task results + for output in task.outputs: + output.value = get_nested_dict( + self.ctx._tasks[name]["results"], + output.name, + default=output.value, + ) + return task + + def reset_task( + self, + name: str, + reset_process: bool = True, + recursive: bool = True, + reset_execution_count: bool = True, + ) -> None: + """Reset task state and remove it from the executed task. + If recursive is True, reset its child tasks.""" + + self.set_task_state_info(name, "state", "PLANNED") + if reset_process: + self.set_task_state_info(name, "process", None) + self.remove_executed_task(name) + # self.logger.debug(f"Task {name} action: RESET.") + # if the task is a while task, reset its child tasks + if self.ctx._tasks[name]["metadata"]["node_type"].upper() == "WHILE": + if reset_execution_count: + self.ctx._tasks[name]["execution_count"] = 0 + for child_task in self.ctx._tasks[name]["children"]: + self.reset_task(child_task, reset_process=False, recursive=False) + elif self.ctx._tasks[name]["metadata"]["node_type"].upper() in [ + "IF", + "ZONE", + ]: + for child_task in self.ctx._tasks[name]["children"]: + self.reset_task(child_task, reset_process=False, recursive=False) + if recursive: + # reset its child tasks + names = self.ctx._connectivity["child_node"][name] + for name in names: + self.reset_task(name, recursive=False) + + def remove_executed_task(self, name: str) -> None: + """Remove labels with name from executed tasks.""" + self.ctx._executed_tasks = [ + label for label in self.ctx._executed_tasks if label.split(".")[0] != name + ] + + def is_task_ready_to_run(self, name: str) -> Tuple[bool, Optional[str]]: + """Check if the task ready to run. + For normal task and a zone task, we need to check its input tasks in the connectivity["zone"]. + For task inside a zone, we need to check if the zone (parent task) is ready. 
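+        Returns a tuple of (ready, [input_tasks_finished, parent_zone_running]).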
+ """ + parent_task = self.ctx._tasks[name]["parent_task"] + # input_tasks, parent_task, conditions + parent_states = [True, True] + # if the task belongs to a parent zone + if parent_task[0]: + state = self.get_task_state_info(parent_task[0], "state") + if state not in ["RUNNING"]: + parent_states[1] = False + # check the input tasks of the zone + # check if the zone input tasks are ready + for child_task_name in self.ctx._connectivity["zone"][name]["input_tasks"]: + if self.get_task_state_info(child_task_name, "state") not in [ + "FINISHED", + "SKIPPED", + "FAILED", + ]: + parent_states[0] = False + break + + return all(parent_states), parent_states + + def set_task_results(self) -> None: + for name, task in self.ctx._tasks.items(): + if self.get_task_state_info(name, "action").upper() == "RESET": + self.reset_task(task["name"]) + self.update_task_state(name) + + def task_set_context(self, name: str) -> None: + """Export task results to the context based on context mapping.""" + from aiida_workgraph.utils import update_nested_dict + + items = self.ctx._tasks[name]["context_mapping"] + for key, result_name in items.items(): + result = self.ctx._tasks[name]["results"][result_name] + update_nested_dict(self.ctx, key, result) + + def get_task_state_info(self, name: str, key: str) -> str: + """Get task state info from ctx.""" + + value = self.ctx._tasks[name].get(key, None) + if key == "process" and value is not None: + value = deserialize_unsafe(value) + return value + + def set_task_state_info(self, name: str, key: str, value: any) -> None: + """Set task state info to ctx and base.extras. + We task state to the base.extras, so that we can access outside the engine""" + + if key == "process": + value = serialize(value) + self.process.node.base.extras.set(f"_task_{key}_{name}", value) + else: + self.process.node.base.extras.set(f"_task_{key}_{name}", value) + self.ctx._tasks[name][key] = value + + def set_tasks_state( + self, tasks: Union[List[str], Sequence[str]], value: str + ) -> None: + """Set tasks state""" + for name in tasks: + self.set_task_state_info(name, "state", value) + if "children" in self.ctx._tasks[name]: + self.set_tasks_state(self.ctx._tasks[name]["children"], value) + + def is_workgraph_finished(self) -> bool: + """Check if the workgraph is finished. + For `while` workgraph, we need check its conditions""" + is_finished = True + failed_tasks = [] + for name, task in self.ctx._tasks.items(): + # self.update_task_state(name) + if self.get_task_state_info(task["name"], "state") in [ + "RUNNING", + "CREATED", + "PLANNED", + "READY", + ]: + is_finished = False + elif self.get_task_state_info(task["name"], "state") == "FAILED": + failed_tasks.append(name) + if is_finished: + if self.ctx._workgraph["workgraph_type"].upper() == "WHILE": + should_run = self.check_while_conditions() + is_finished = not should_run + if self.ctx._workgraph["workgraph_type"].upper() == "FOR": + should_run = self.check_for_conditions() + is_finished = not should_run + if is_finished and len(failed_tasks) > 0: + message = f"WorkGraph finished, but tasks: {failed_tasks} failed. Thus all their child tasks are skipped." 
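+                # The workgraph still finishes when tasks fail; the failure is surfaced through exit code 302 carrying this message.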
+ self.process.report(message) + result = ExitCode(302, message) + else: + result = None + return is_finished, result + + def continue_workgraph(self) -> None: + self.process.report("Continue workgraph.") + # self.update_workgraph_from_base() + task_to_run = [] + for name, task in self.ctx._tasks.items(): + # update task state + if ( + self.get_task_state_info(task["name"], "state") + in [ + "CREATED", + "RUNNING", + "FINISHED", + "FAILED", + "SKIPPED", + ] + or name in self.ctx._executed_tasks + ): + continue + ready, _ = self.is_task_ready_to_run(name) + if ready: + task_to_run.append(name) + # + self.process.report("tasks ready to run: {}".format(",".join(task_to_run))) + self.run_tasks(task_to_run) + + def run_tasks(self, names: List[str], continue_workgraph: bool = True) -> None: + """Run tasks. + Task type includes: Node, Data, CalcFunction, WorkFunction, CalcJob, WorkChain, GraphBuilder, + WorkGraph, PythonJob, ShellJob, While, If, Zone, GetContext, SetContext, Normal. + + """ + from aiida_workgraph.utils import ( + get_executor, + update_nested_dict_with_special_keys, + ) + + for name in names: + # skip if the max number of awaitables is reached + task = self.ctx._tasks[name] + if task["metadata"]["node_type"].upper() in process_task_types: + if len(self.process._awaitables) >= self.ctx._max_number_awaitables: + print( + MAX_NUMBER_AWAITABLES_MSG.format( + self.ctx._max_number_awaitables, name + ) + ) + continue + # skip if the task is already executed + # or if the task is in a skippped state + if name in self.ctx._executed_tasks or self.get_task_state_info( + name, "state" + ) in ["SKIPPED"]: + continue + self.ctx._executed_tasks.append(name) + print("-" * 60) + + self.process.report( + f"Run task: {name}, type: {task['metadata']['node_type']}" + ) + executor, _ = get_executor(task["executor"]) + args, kwargs, var_args, var_kwargs, args_dict = self.get_inputs(name) + for i, key in enumerate(self.ctx._tasks[name]["args"]): + kwargs[key] = args[i] + # update the port namespace + kwargs = update_nested_dict_with_special_keys(kwargs) + # kwargs["meta.label"] = name + # output must be a Data type or a mapping of {string: Data} + task["results"] = {} + task_type = task["metadata"]["node_type"].upper() + if task_type == "NODE": + self.execute_node_task( + name, executor, kwargs, var_args, var_kwargs, continue_workgraph + ) + elif task_type == "DATA": + self.execute_data_task(name, executor, args, kwargs, continue_workgraph) + elif task_type in ["CALCFUNCTION", "WORKFUNCTION"]: + self.execute_function_task( + name, executor, kwargs, var_kwargs, continue_workgraph + ) + elif task_type in ["CALCJOB", "WORKCHAIN"]: + self.execute_process_task(name, executor, kwargs) + elif task_type == "PYTHONJOB": + self.execute_python_job_task(task, kwargs, var_kwargs) + elif task_type == "GRAPH_BUILDER": + self.execute_graph_builder_task( + task, executor, kwargs, var_args, var_kwargs + ) + elif task_type in ["WORKGRAPH"]: + self.execute_workgraph_task(task, kwargs) + elif task_type == "SHELLJOB": + self.execute_shell_job_task(task, kwargs) + elif task_type == "WHILE": + self.execute_while_task(task) + elif task_type == "IF": + self.execute_if_task(task) + elif task_type == "ZONE": + self.execute_zone_task(task) + elif task_type == "GET_CONTEXT": + self.execute_get_context_task(task, kwargs) + elif task_type == "SET_CONTEXT": + self.execute_set_context_task(task, kwargs) + elif task_type == "AWAITABLE": + self.execute_awaitable_task( + task, executor, args, kwargs, var_args, var_kwargs + ) + elif 
task_type == "MONITOR": + self.execute_monitor_task( + task, executor, args, kwargs, var_args, var_kwargs + ) + elif task_type == "NORMAL": + self.execute_normal_task( + task, + executor, + args, + kwargs, + var_args, + var_kwargs, + continue_workgraph, + ) + else: + self.process.report(f"Unknown task type {task_type}") + return self.process.exit_codes.UNKNOWN_TASK_TYPE + + def execute_node_task( + self, name, executor, kwargs, var_args, var_kwargs, continue_workgraph + ): + """Execute a NODE task.""" + results = self.run_executor(executor, [], kwargs, var_args, var_kwargs) + self.set_task_state_info(name, "process", results) + self.update_task_state(name) + if continue_workgraph: + self.continue_workgraph() + + def execute_data_task(self, name, executor, args, kwargs, continue_workgraph): + """Execute a DATA task.""" + from aiida_workgraph.utils import create_data_node + + for key in self.ctx._tasks[name]["args"]: + kwargs.pop(key, None) + results = create_data_node(executor, args, kwargs) + self.set_task_state_info(name, "process", results) + self.update_task_state(name) + self.ctx._new_data[name] = results + if continue_workgraph: + self.continue_workgraph() + + def execute_function_task( + self, name, executor, kwargs, var_kwargs, continue_workgraph + ): + """Execute a CalcFunction or WorkFunction task.""" + kwargs.setdefault("metadata", {}) + kwargs["metadata"].update({"call_link_label": name}) + try: + # since aiida 2.5.0, we need to use args_dict to pass the args to the run_get_node + if var_kwargs is None: + results, process = run_get_node(executor, **kwargs) + else: + results, process = run_get_node(executor, **kwargs, **var_kwargs) + process.label = name + self.set_task_state_info(name, "process", process) + self.update_task_state(name) + except Exception as e: + self.logger.error(f"Error in task {name}: {e}") + self.update_task_state(name, success=False) + # exclude the current tasks from the next run + if continue_workgraph: + self.continue_workgraph() + + def execute_process_task(self, name, executor, kwargs): + """Execute a CalcJob or WorkChain task.""" + # process = run_get_node(executor, *args, **kwargs) + kwargs.setdefault("metadata", {}) + kwargs["metadata"].update({"call_link_label": name}) + # transfer the args to kwargs + if self.get_task_state_info(name, "action").upper() == "PAUSE": + self.set_task_state_info(name, "action", "") + self.process.report(f"Task {name} is created and paused.") + process = create_and_pause_process( + self.runner, + executor, + kwargs, + state_msg="Paused through WorkGraph", + ) + self.set_task_state_info(name, "state", "CREATED") + process = process.node + else: + process = self.process.submit(executor, **kwargs) + self.set_task_state_info(name, "state", "RUNNING") + process.label = name + self.set_task_state_info(name, "process", process) + self.awaitable_manager.to_context(**{name: process}) + + def execute_graph_builder_task(self, task, executor, kwargs, var_args, var_kwargs): + """Execute a GraphBuilder task.""" + name = task["name"] + wg = self.run_executor(executor, [], kwargs, var_args, var_kwargs) + wg.name = name + wg.group_outputs = self.ctx._tasks[name]["metadata"]["group_outputs"] + wg.parent_uuid = self.process.node.uuid + inputs = wg.prepare_inputs(metadata={"call_link_label": name}) + process = self.process.submit(self.process.__class__, inputs=inputs) + self.set_task_state_info(name, "process", process) + self.set_task_state_info(name, "state", "RUNNING") + self.awaitable_manager.to_context(**{name: process}) + + def 
execute_workgraph_task(self, task, kwargs): + from .utils import prepare_for_workgraph_task + + name = task["name"] + inputs, _ = prepare_for_workgraph_task(task, kwargs) + process = self.process.submit(self.process.__class__, inputs=inputs) + self.set_task_state_info(name, "process", process) + self.set_task_state_info(name, "state", "RUNNING") + self.awaitable_manager.to_context(**{name: process}) + + def execute_python_job_task(self, task, kwargs, var_kwargs): + """Execute a PythonJob task.""" + from aiida_pythonjob import PythonJob + from .utils import prepare_for_python_task + + name = task["name"] + inputs = prepare_for_python_task(task, kwargs, var_kwargs) + # since aiida 2.5.0, we can pass inputs directly to the submit, no need to use **inputs + if self.get_task_state_info(name, "action").upper() == "PAUSE": + self.set_task_state_info(name, "action", "") + self.process.report(f"Task {name} is created and paused.") + process = create_and_pause_process( + self.runner, + PythonJob, + inputs, + state_msg="Paused through WorkGraph", + ) + self.set_task_state_info(name, "state", "CREATED") + process = process.node + else: + process = self.process.submit(PythonJob, **inputs) + self.set_task_state_info(name, "state", "RUNNING") + process.label = name + self.set_task_state_info(name, "process", process) + self.awaitable_manager.to_context(**{name: process}) + + def execute_shell_job_task(self, task, kwargs): + """Execute a ShellJob task.""" + from aiida_shell.calculations.shell import ShellJob + from .utils import prepare_for_shell_task + + name = task["name"] + inputs = prepare_for_shell_task(task, kwargs) + if self.get_task_state_info(name, "action").upper() == "PAUSE": + self.set_task_state_info(name, "action", "") + self.process.report(f"Task {name} is created and paused.") + process = create_and_pause_process( + self.runner, + ShellJob, + inputs, + state_msg="Paused through WorkGraph", + ) + self.set_task_state_info(name, "state", "CREATED") + process = process.node + else: + process = self.process.submit(ShellJob, **inputs) + self.set_task_state_info(name, "state", "RUNNING") + process.label = name + self.set_task_state_info(name, "process", process) + self.awaitable_manager.to_context(**{name: process}) + + def execute_while_task(self, task): + """Execute a WHILE task.""" + # TODO refactor this for while, if and zone + # in case of an empty zone, it will finish immediately + name = task["name"] + if self.are_childen_finished(name)[0]: + self.update_while_task_state(name) + else: + # check the conditions of the while task + should_run = self.should_run_while_task(name) + if not should_run: + self.set_task_state_info(name, "state", "FINISHED") + self.set_tasks_state(self.ctx._tasks[name]["children"], "SKIPPED") + self.update_parent_task_state(name) + self.process.report( + f"While Task {name}: Condition not fullilled, task finished. Skip all its children." 
+ ) + else: + task["execution_count"] += 1 + self.set_task_state_info(name, "state", "RUNNING") + self.continue_workgraph() + + def execute_if_task(self, task): + # in case of an empty zone, it will finish immediately + name = task["name"] + if self.are_childen_finished(name)[0]: + self.update_zone_task_state(name) + else: + should_run = self.should_run_if_task(name) + if should_run: + self.set_task_state_info(name, "state", "RUNNING") + else: + self.set_tasks_state(task["children"], "SKIPPED") + self.update_zone_task_state(name) + self.continue_workgraph() + + def execute_zone_task(self, task): + # in case of an empty zone, it will finish immediately + name = task["name"] + if self.are_childen_finished(name)[0]: + self.update_zone_task_state(name) + else: + self.set_task_state_info(name, "state", "RUNNING") + self.continue_workgraph() + + def execute_get_context_task(self, task, kwargs): + # get the results from the context + name = task["name"] + results = {"result": getattr(self.ctx, kwargs["key"])} + task["results"] = results + self.set_task_state_info(name, "state", "FINISHED") + self.update_parent_task_state(name) + self.continue_workgraph() + + def execute_set_context_task(self, task, kwargs): + name = task["name"] + # get the results from the context + setattr(self.ctx, kwargs["key"], kwargs["value"]) + self.set_task_state_info(name, "state", "FINISHED") + self.update_parent_task_state(name) + self.continue_workgraph() + + def execute_awaitable_task( + self, task, executor, args, kwargs, var_args, var_kwargs + ): + name = task["name"] + for key in task["args"]: + kwargs.pop(key, None) + awaitable_target = asyncio.ensure_future( + self.run_executor(executor, args, kwargs, var_args, var_kwargs), + loop=self.process.loop, + ) + awaitable = self.awaitable_manager.construct_awaitable_function( + name, awaitable_target + ) + self.set_task_state_info(name, "state", "RUNNING") + self.awaitable_manager.to_context(**{name: awaitable}) + + def execute_monitor_task(self, task, executor, args, kwargs, var_args, var_kwargs): + name = task["name"] + for key in self.ctx._tasks[name]["args"]: + kwargs.pop(key, None) + # add function and interval to the args + args = [ + executor, + kwargs.pop("interval", 1), + kwargs.pop("timeout", 3600), + *args, + ] + awaitable_target = asyncio.ensure_future( + self.run_executor(monitor, args, kwargs, var_args, var_kwargs), + loop=self.process.loop, + ) + awaitable = self.awaitable_manager.construct_awaitable_function( + name, awaitable_target + ) + self.set_task_state_info(name, "state", "RUNNING") + # save the awaitable to the temp, so that we can kill it if needed + self.awaitable_manager.not_persisted_awaitables[name] = awaitable_target + self.awaitable_manager.to_context(**{name: awaitable}) + + def execute_normal_task( + self, task, executor, args, kwargs, var_args, var_kwargs, continue_workgraph + ): + # Normal task is created by decoratoring a function with @task() + name = task["name"] + if "context" in task["kwargs"]: + self.ctx.task_name = name + kwargs.update({"context": self.ctx}) + for key in self.ctx._tasks[name]["args"]: + kwargs.pop(key, None) + try: + results = self.run_executor(executor, args, kwargs, var_args, var_kwargs) + self.update_normal_task_state(name, results) + except Exception as e: + self.logger.error(f"Error in task {name}: {e}") + self.update_normal_task_state(name, results=None, success=False) + if continue_workgraph: + self.continue_workgraph() + + def get_inputs( + self, name: str + ) -> Tuple[ + List[Any], + Dict[str, Any], + 
Optional[List[Any]], + Optional[Dict[str, Any]], + Dict[str, Any], + ]: + """Get input based on the links.""" + + args = [] + args_dict = {} + kwargs = {} + var_args = None + var_kwargs = None + task = self.ctx._tasks[name] + properties = task.get("properties", {}) + inputs = {} + for name, input in task["inputs"].items(): + # print(f"input: {input['name']}") + if len(input["links"]) == 0: + inputs[name] = self.ctx_manager.update_context_variable( + input["property"]["value"] + ) + elif len(input["links"]) == 1: + link = input["links"][0] + if self.ctx._tasks[link["from_node"]]["results"] is None: + inputs[name] = None + else: + # handle the special socket _wait, _outputs + if link["from_socket"] == "_wait": + continue + elif link["from_socket"] == "_outputs": + inputs[name] = self.ctx._tasks[link["from_node"]]["results"] + else: + inputs[name] = get_nested_dict( + self.ctx._tasks[link["from_node"]]["results"], + link["from_socket"], + ) + # handle the case of multiple outputs + elif len(input["links"]) > 1: + value = {} + for link in input["links"]: + item_name = f'{link["from_node"]}_{link["from_socket"]}' + # handle the special socket _wait, _outputs + if link["from_socket"] == "_wait": + continue + if self.ctx._tasks[link["from_node"]]["results"] is None: + value[item_name] = None + else: + value[item_name] = self.ctx._tasks[link["from_node"]][ + "results" + ][link["from_socket"]] + inputs[name] = value + for name in task.get("args", []): + if name in inputs: + args.append(inputs[name]) + args_dict[name] = inputs[name] + else: + value = self.ctx_manager.update_context_variable( + properties[name]["value"] + ) + args.append(value) + args_dict[name] = value + for name in task.get("kwargs", []): + if name in inputs: + kwargs[name] = inputs[name] + else: + value = self.ctx_manager.update_context_variable( + properties[name]["value"] + ) + kwargs[name] = value + if task["var_args"] is not None: + name = task["var_args"] + if name in inputs: + var_args = inputs[name] + else: + value = self.ctx_manager.update_context_variable( + properties[name]["value"] + ) + var_args = value + if task["var_kwargs"] is not None: + name = task["var_kwargs"] + if name in inputs: + var_kwargs = inputs[name] + else: + value = self.ctx_manager.update_context_variable( + properties[name]["value"] + ) + var_kwargs = value + return args, kwargs, var_args, var_kwargs, args_dict + + def update_task_state(self, name: str, success=True) -> None: + """Update task state when the task is finished.""" + task = self.ctx._tasks[name] + if success: + node = self.get_task_state_info(name, "process") + if isinstance(node, ProcessNode): + # print(f"set task result: {name} process") + state = node.process_state.value.upper() + if node.is_finished_ok: + self.set_task_state_info(task["name"], "state", state) + if task["metadata"]["node_type"].upper() == "WORKGRAPH": + # expose the outputs of all the tasks in the workgraph + task["results"] = {} + outgoing = node.base.links.get_outgoing() + for link in outgoing.all(): + if isinstance(link.node, ProcessNode) and getattr( + link.node, "process_state", False + ): + task["results"][link.link_label] = link.node.outputs + else: + task["results"] = node.outputs + # self.ctx._new_data[name] = task["results"] + self.set_task_state_info(task["name"], "state", "FINISHED") + self.task_set_context(name) + self.process.report(f"Task: {name} finished.") + # all other states are considered as failed + else: + task["results"] = node.outputs + self.on_task_failed(name) + elif isinstance(node, Data): + 
# + output_name = [ + output_name + for output_name in list(task["outputs"].keys()) + if output_name not in ["_wait", "_outputs"] + ][0] + task["results"] = {output_name: node} + self.set_task_state_info(task["name"], "state", "FINISHED") + self.task_set_context(name) + self.process.report(f"Task: {name} finished.") + else: + task.setdefault("results", None) + else: + self.on_task_failed(name) + self.update_parent_task_state(name) + + def on_task_failed(self, name: str) -> None: + """Handle the case where a task has failed.""" + self.set_task_state_info(name, "state", "FAILED") + self.set_tasks_state(self.ctx._connectivity["child_node"][name], "SKIPPED") + self.process.report(f"Task: {name} failed.") + self.process.error_handler_manager.run_error_handlers(name) + + def update_task(self, task: Task): + """Update task in the context. + This is used in error handlers to update the task parameters.""" + tdata = task.to_dict() + self.ctx._tasks[task.name]["properties"] = tdata["properties"] + self.ctx._tasks[task.name]["inputs"] = tdata["inputs"] + self.reset_task(task.name) + + def update_normal_task_state(self, name, results, success=True): + """Set the results of a normal task. + A normal task is created by decorating a function with @task(). + """ + from aiida_workgraph.utils import get_sorted_names + + if success: + task = self.ctx._tasks[name] + if isinstance(results, tuple): + if len(task["outputs"]) != len(results): + return self.exit_codes.OUTPUS_NOT_MATCH_RESULTS + output_names = get_sorted_names(task["outputs"]) + for i, output_name in enumerate(output_names): + task["results"][output_name] = results[i] + elif isinstance(results, dict): + task["results"] = results + else: + output_name = [ + output_name + for output_name in list(task["outputs"].keys()) + if output_name not in ["_wait", "_outputs"] + ][0] + task["results"][output_name] = results + self.task_set_context(name) + self.set_task_state_info(name, "state", "FINISHED") + self.process.report(f"Task: {name} finished.") + else: + self.on_task_failed(name) + self.update_parent_task_state(name) + + def update_parent_task_state(self, name: str) -> None: + """Update parent task state.""" + parent_task = self.ctx._tasks[name]["parent_task"] + if parent_task[0]: + task_type = self.ctx._tasks[parent_task[0]]["metadata"]["node_type"].upper() + if task_type == "WHILE": + self.update_while_task_state(parent_task[0]) + elif task_type == "IF": + self.update_zone_task_state(parent_task[0]) + elif task_type == "ZONE": + self.update_zone_task_state(parent_task[0]) + + def update_while_task_state(self, name: str) -> None: + """Update while task state.""" + finished, _ = self.are_childen_finished(name) + + if finished: + self.process.report( + f"Wihle Task {name}: this iteration finished. Try to reset for the next iteration." 
+ ) + # reset the condition tasks + for link in self.ctx._tasks[name]["inputs"]["conditions"]["links"]: + self.reset_task(link["from_node"], recursive=False) + # reset the task and all its children, so that the task can run again + # do not reset the execution count + self.reset_task(name, reset_execution_count=False) + + def update_zone_task_state(self, name: str) -> None: + """Update zone task state.""" + finished, _ = self.are_childen_finished(name) + if finished: + self.set_task_state_info(name, "state", "FINISHED") + self.process.report(f"Task: {name} finished.") + self.update_parent_task_state(name) + + def should_run_while_task(self, name: str) -> tuple[bool, Any]: + """Check if the while task should run.""" + # check the conditions of the while task + not_excess_max_iterations = ( + self.ctx._tasks[name]["execution_count"] + < self.ctx._tasks[name]["inputs"]["max_iterations"]["property"]["value"] + ) + conditions = [not_excess_max_iterations] + _, kwargs, _, _, _ = self.get_inputs(name) + if isinstance(kwargs["conditions"], list): + for condition in kwargs["conditions"]: + value = get_nested_dict(self.ctx, condition) + conditions.append(value) + elif isinstance(kwargs["conditions"], dict): + for _, value in kwargs["conditions"].items(): + conditions.append(value) + else: + conditions.append(kwargs["conditions"]) + return False not in conditions + + def should_run_if_task(self, name: str) -> tuple[bool, Any]: + """Check if the IF task should run.""" + _, kwargs, _, _, _ = self.get_inputs(name) + flag = kwargs["conditions"] + if kwargs["invert_condition"]: + return not flag + return flag + + def are_childen_finished(self, name: str) -> tuple[bool, Any]: + """Check if the child tasks are finished.""" + task = self.ctx._tasks[name] + finished = True + for name in task["children"]: + if self.get_task_state_info(name, "state") not in [ + "FINISHED", + "SKIPPED", + "FAILED", + ]: + finished = False + break + return finished, None + + def run_executor( + self, + executor: Callable, + args: List[Any], + kwargs: Dict[str, Any], + var_args: Optional[List[Any]], + var_kwargs: Optional[Dict[str, Any]], + ) -> Any: + if var_kwargs is None: + return executor(*args, **kwargs) + else: + return executor(*args, **kwargs, **var_kwargs) + + def apply_task_actions(self, msg: dict) -> None: + """Apply task actions to the workgraph.""" + action = msg["action"] + tasks = msg["tasks"] + self.process.report(f"Action: {action}. {tasks}") + if action.upper() == "RESET": + for name in tasks: + self.reset_task(name) + elif action.upper() == "PAUSE": + for name in tasks: + self.pause_task(name) + elif action.upper() == "PLAY": + for name in tasks: + self.play_task(name) + elif action.upper() == "SKIP": + for name in tasks: + self.skip_task(name) + elif action.upper() == "KILL": + for name in tasks: + self.kill_task(name) + + def pause_task(self, name: str) -> None: + """Pause task.""" + self.set_task_state_info(name, "action", "PAUSE") + self.process.report(f"Task {name} action: PAUSE.") + + def play_task(self, name: str) -> None: + """Play task.""" + self.set_task_state_info(name, "action", "") + self.process.report(f"Task {name} action: PLAY.") + + def skip_task(self, name: str) -> None: + """Skip task.""" + self.set_task_state_info(name, "state", "SKIPPED") + self.process.report(f"Task {name} action: SKIP.") + + def kill_task(self, name: str) -> None: + """Kill task. + This is used to kill the awaitable and monitor task. 
+ """ + if self.get_task_state_info(name, "state") in ["RUNNING"]: + if self.ctx._tasks[name]["metadata"]["node_type"].upper() in [ + "AWAITABLE", + "MONITOR", + ]: + try: + self.awaitable_manager.not_persisted_awaitables[name].cancel() + self.set_task_state_info(name, "state", "KILLED") + self.process.report(f"Task {name} action: KILLED.") + except Exception as e: + self.logger.error(f"Error in killing task {name}: {e}") + + def check_while_conditions(self) -> bool: + """Check while conditions. + Run all condition tasks and check if all the conditions are True. + """ + self.process.report("Check while conditions.") + if self.ctx._execution_count >= self.ctx._max_iteration: + self.process.report("Max iteration reached.") + return False + condition_tasks = [] + for c in self.ctx._workgraph["conditions"]: + task_name, socket_name = c.split(".") + if task_name != "context": + condition_tasks.append(task_name) + self.reset_task(task_name) + self.run_tasks(condition_tasks, continue_workgraph=False) + conditions = [] + for c in self.ctx._workgraph["conditions"]: + task_name, socket_name = c.split(".") + if task_name == "context": + conditions.append(self.ctx[socket_name]) + else: + conditions.append(self.ctx._tasks[task_name]["results"][socket_name]) + should_run = False not in conditions + if should_run: + self.reset() + self.set_tasks_state(condition_tasks, "SKIPPED") + return should_run + + def check_for_conditions(self) -> bool: + condition_tasks = [c[0] for c in self.ctx._workgraph["conditions"]] + self.run_tasks(condition_tasks) + conditions = [self.ctx._count < len(self.ctx._sequence)] + [ + self.ctx._tasks[c[0]]["results"][c[1]] + for c in self.ctx._workgraph["conditions"] + ] + should_run = False not in conditions + if should_run: + self.reset() + self.set_tasks_state(condition_tasks, "SKIPPED") + self.ctx["i"] = self.ctx._sequence[self.ctx._count] + self.ctx._count += 1 + return should_run + + def reset(self) -> None: + self.ctx._execution_count += 1 + self.set_tasks_state(self.ctx._tasks.keys(), "PLANNED") + self.ctx._executed_tasks = [] diff --git a/aiida_workgraph/engine/utils.py b/aiida_workgraph/engine/utils.py index c14b4915..9d86efbd 100644 --- a/aiida_workgraph/engine/utils.py +++ b/aiida_workgraph/engine/utils.py @@ -1,4 +1,3 @@ -from aiida_workgraph.orm.serializer import serialize_to_aiida_nodes from aiida import orm from aiida.common.extendeddicts import AttributeDict @@ -27,105 +26,87 @@ def prepare_for_workgraph_task(task: dict, kwargs: dict) -> tuple: return inputs, wgdata +def sort_socket_data(socket_data: dict) -> dict: + """Sort the socket data by the list_index""" + data = [ + {"name": data["name"], "identifier": data["identifier"]} + for data, _ in sorted( + ((data, data["list_index"]) for data in socket_data.values()), + key=lambda x: x[1], + ) + ] + return data + + + def prepare_for_python_task(task: dict, kwargs: dict, var_kwargs: dict) -> dict: """Prepare the inputs for PythonJob""" - from aiida_workgraph.utils import get_or_create_code - import os + from aiida_pythonjob import prepare_pythonjob_inputs # get the names kwargs for the PythonJob, which are the inputs before _wait - function_kwargs = kwargs.pop("function_kwargs", {}) - # TODO better way to find the function_kwargs - input_names = [ - name - for name, _ in sorted( - ((name, input["list_index"]) for name, input in task["inputs"].items()), - key=lambda x: x[1], - ) - ] - for name in input_names: - if name == "_wait": + function_inputs = kwargs.pop("function_inputs", {}) + sorted_inputs =
sort_socket_data(task["inputs"]) + # TODO better way to find the function_inputs + for input in sorted_inputs: + if input["name"] == "_wait": break - function_kwargs[name] = kwargs.pop(name, None) + function_inputs[input["name"]] = kwargs.pop(input["name"], None) + # if the var_kwargs is not None, we need to pop the var_kwargs from the kwargs - # then update the function_kwargs if var_kwargs is not None + # then update the function_inputs if var_kwargs is not None if task["var_kwargs"] is not None: - function_kwargs.pop(task["var_kwargs"], None) + function_inputs.pop(task["var_kwargs"], None) if var_kwargs: # var_kwargs can be AttributeDict if it get data from the previous task output if isinstance(var_kwargs, (dict, AttributeDict)): - function_kwargs.update(var_kwargs) + function_inputs.update(var_kwargs) # otherwise, it should be a Data node elif isinstance(var_kwargs, orm.Data): - function_kwargs.update(var_kwargs.value) + function_inputs.update(var_kwargs.value) else: raise ValueError(f"Invalid var_kwargs type: {type(var_kwargs)}") # setup code code = kwargs.pop("code", None) - computer = kwargs.pop("computer", None) - code_label = kwargs.pop("code_label", None) - code_path = kwargs.pop("code_path", None) - prepend_text = kwargs.pop("prepend_text", None) + computer = kwargs.pop("computer", "localhost") + command_info = kwargs.pop("command_info", {}) upload_files = kwargs.pop("upload_files", {}) - new_upload_files = {} - # change the string in the upload files to SingleFileData, or FolderData - for key, source in upload_files.items(): - # only alphanumeric and underscores are allowed in the key - # replace all "." with "_dot_" - new_key = key.replace(".", "_dot_") - if isinstance(source, str): - if os.path.isfile(source): - new_upload_files[new_key] = orm.SinglefileData(file=source) - elif os.path.isdir(source): - new_upload_files[new_key] = orm.FolderData(tree=source) - elif isinstance(source, (orm.SinglefileData, orm.FolderData)): - new_upload_files[new_key] = source - else: - raise ValueError(f"Invalid upload file type: {type(source)}, {source}") - # - if code is None: - code = get_or_create_code( - computer=computer if computer else "localhost", - code_label=code_label if code_label else "python3", - code_path=code_path if code_path else None, - prepend_text=prepend_text if prepend_text else None, - ) + metadata = kwargs.pop("metadata", {}) metadata.update({"call_link_label": task["name"]}) # get the source code of the function - function_name = task["executor"]["name"] - if task["executor"].get("is_pickle", False): - function_source_code = ( - task["executor"]["import_statements"] - + "\n" - + task["executor"]["source_code_without_decorator"] - ) - else: - function_source_code = ( - f"from {task['executor']['module']} import {function_name}" - ) + executor = task["executor"] # outputs - function_outputs = [ - output + outputs = [ + {"name": output["name"], "identifier": output["identifier"]} for output, _ in sorted( ((output, output["list_index"]) for output in task["outputs"].values()), key=lambda x: x[1], ) ] - # serialize the kwargs into AiiDA Data - function_kwargs = serialize_to_aiida_nodes(function_kwargs) - # transfer the args to kwargs - inputs = { - "process_label": f"PythonJob<{task['name']}>", - "function_source_code": orm.Str(function_source_code), - "function_name": orm.Str(function_name), - "code": code, - "function_kwargs": function_kwargs, - "upload_files": new_upload_files, - "function_outputs": orm.List(function_outputs), - "metadata": metadata, + # only the 
output before _wait is the function_outputs + function_outputs = [] + for output in outputs: + if output["name"] == "_wait": + break + # if the output is WORKGRAPH.NAMESPACE, we need to change it to NAMESPACE + if output["identifier"].upper() == "WORKGRAPH.NAMESPACE": + function_outputs.append({"name": output["name"], "identifier": "NAMESPACE"}) + else: + function_outputs.append(output) + + inputs = prepare_pythonjob_inputs( + pickled_function=executor, + function_inputs=function_inputs, + function_outputs=function_outputs, + code=code, + command_info=command_info, + computer=computer, + metadata=metadata, + upload_files=upload_files, + process_label=f"PythonJob<{task['name']}>", **kwargs, - } + ) + return inputs diff --git a/aiida_workgraph/engine/workgraph.py b/aiida_workgraph/engine/workgraph.py index bebe7086..f4145a9d 100644 --- a/aiida_workgraph/engine/workgraph.py +++ b/aiida_workgraph/engine/workgraph.py @@ -1,9 +1,7 @@ """AiiDA workflow components: WorkGraph.""" from __future__ import annotations -import asyncio import collections.abc -import functools import logging import typing as t @@ -13,28 +11,21 @@ from plumpy.workchains import _PropagateReturn import kiwipy -from aiida.common import exceptions from aiida.common.extendeddicts import AttributeDict from aiida.common.lang import override from aiida import orm -from aiida.orm import load_node, Node, ProcessNode, WorkChainNode -from aiida.orm.utils.serialize import deserialize_unsafe, serialize +from aiida.orm import Node, WorkChainNode +from aiida.orm.utils.serialize import deserialize_unsafe from aiida.engine.processes.exit_code import ExitCode from aiida.engine.processes.process import Process - -from aiida.engine.processes.workchains.awaitable import ( - Awaitable, - AwaitableAction, - AwaitableTarget, - construct_awaitable, -) from aiida.engine.processes.workchains.workchain import Protect, WorkChainSpec -from aiida.engine import run_get_node -from aiida_workgraph.utils import create_and_pause_process -from aiida_workgraph.task import Task from aiida_workgraph.utils import get_nested_dict, update_nested_dict -from aiida_workgraph.executors.monitors import monitor +from .context_manager import ContextManager +from .awaitable_manager import AwaitableManager +from .task_manager import TaskManager +from .error_handler_manager import ErrorHandlerManager +from aiida.engine.processes.workchains.awaitable import Awaitable if t.TYPE_CHECKING: from aiida.engine.runners import Runner # pylint: disable=unused-import @@ -42,9 +33,6 @@ __all__ = "WorkGraph" -MAX_NUMBER_AWAITABLES_MSG = "The maximum number of subprocesses has been reached: {}. Cannot launch the job: {}." 
- - @auto_persist("_awaitables") class WorkGraphEngine(Process, metaclass=Protect): """The `WorkGraph` class is used to construct workflows in AiiDA.""" @@ -71,9 +59,20 @@ def __init__( """ super().__init__(inputs, logger, runner, enable_persistence=enable_persistence) - self._awaitables: list[Awaitable] = [] self._context = AttributeDict() + self.ctx_manager = ContextManager( + self._context, process=self, logger=self.logger + ) + self.awaitable_manager = AwaitableManager( + self._awaitables, self.runner, self.logger, self, self.ctx_manager + ) + self.task_manager = TaskManager( + self.ctx_manager, self.logger, self.runner, self, self.awaitable_manager + ) + self.error_handler_manager = ErrorHandlerManager( + self, self.ctx_manager, self.logger + ) @classmethod def define(cls, spec: WorkChainSpec) -> None: @@ -143,141 +142,38 @@ def load_instance_state( super().load_instance_state(saved_state, load_context) # Load the context self._context = saved_state[self._CONTEXT] - self._temp = {"awaitables": {}} - self.set_logger(self.node.logger) - + # TODO I don't know why we need to reinitialize the context, awaitables, and task_manager + # Need to initialize the context, awaitables, and task_manager + self.ctx_manager = ContextManager( + self._context, process=self, logger=self.logger + ) + self.awaitable_manager = AwaitableManager( + self._awaitables, self.runner, self.logger, self, self.ctx_manager + ) + self.task_manager = TaskManager( + self.ctx_manager, self.logger, self.runner, self, self.awaitable_manager + ) + self.error_handler_manager = ErrorHandlerManager( + self, self.ctx_manager, self.logger + ) + # "_awaitables" is auto persisted. if self._awaitables: # For the "ascyncio.tasks.Task" awaitable, because there are only in-memory, # we need to reset the tasks and so that they can be re-run again. should_resume = False for awaitable in self._awaitables: if awaitable.target == "asyncio.tasks.Task": - self._resolve_awaitable(awaitable, None) + self.awaitable_manager.resolve_awaitable(awaitable, None) self.report(f"reset awaitable task: {awaitable.key}") - self.reset_task(awaitable.key) + self.task_manager.reset_task(awaitable.key) should_resume = True if should_resume: - self._update_process_status() + self.awaitable_manager.update_process_status() self.resume() # For other awaitables, because they exist in the db, we only need to re-register the callbacks self.ctx._awaitable_actions = [] - self._action_awaitables() - - def _resolve_nested_context(self, key: str) -> tuple[AttributeDict, str]: - """ - Returns a reference to a sub-dictionary of the context and the last key, - after resolving a potentially segmented key where required sub-dictionaries are created as needed. 
- - :param key: A key into the context, where words before a dot are interpreted as a key for a sub-dictionary - """ - ctx = self.ctx - ctx_path = key.split(".") - - for index, path in enumerate(ctx_path[:-1]): - try: - ctx = ctx[path] - except KeyError: # see below why this is the only exception we have to catch here - ctx[ - path - ] = AttributeDict() # create the sub-dict and update the context - ctx = ctx[path] - continue - - # Notes: - # * the first ctx (self.ctx) is guaranteed to be an AttributeDict, hence the post-"dereference" checking - # * the values can be many different things: on insertion they are either AtrributeDict, List or Awaitables - # (subclasses of AttributeDict) but after resolution of an Awaitable this will be the value itself - # * assumption: a resolved value is never a plain AttributeDict, on the other hand if a resolved Awaitable - # would be an AttributeDict we can append things to it since the order of tasks is maintained. - if type(ctx) != AttributeDict: # pylint: disable=C0123 - raise ValueError( - f"Can not update the context for key `{key}`: " - f' found instance of `{type(ctx)}` at `{".".join(ctx_path[:index + 1])}`, expected AttributeDict' - ) - - return ctx, ctx_path[-1] - - def _insert_awaitable(self, awaitable: Awaitable) -> None: - """Insert an awaitable that should be terminated before before continuing to the next step. - - :param awaitable: the thing to await - """ - ctx, key = self._resolve_nested_context(awaitable.key) - - # Already assign the awaitable itself to the location in the context container where it is supposed to end up - # once it is resolved. This is especially important for the `APPEND` action, since it needs to maintain the - # order, but the awaitables will not necessarily be resolved in the order in which they are added. By using the - # awaitable as a placeholder, in the `_resolve_awaitable`, it can be found and replaced by the resolved value. - if awaitable.action == AwaitableAction.ASSIGN: - ctx[key] = awaitable - elif awaitable.action == AwaitableAction.APPEND: - ctx.setdefault(key, []).append(awaitable) - else: - raise AssertionError(f"Unsupported awaitable action: {awaitable.action}") - - self._awaitables.append( - awaitable - ) # add only if everything went ok, otherwise we end up in an inconsistent state - self._update_process_status() - - def _resolve_awaitable(self, awaitable: Awaitable, value: t.Any) -> None: - """Resolve an awaitable. - - Precondition: must be an awaitable that was previously inserted. 
- - :param awaitable: the awaitable to resolve - :param value: the value to assign to the awaitable - """ - ctx, key = self._resolve_nested_context(awaitable.key) - - if awaitable.action == AwaitableAction.ASSIGN: - ctx[key] = value - elif awaitable.action == AwaitableAction.APPEND: - # Find the same awaitable inserted in the context - container = ctx[key] - for index, placeholder in enumerate(container): - if ( - isinstance(placeholder, Awaitable) - and placeholder.pk == awaitable.pk - ): - container[index] = value - break - else: - raise AssertionError( - f"Awaitable `{awaitable.pk} was not in `ctx.{awaitable.key}`" - ) - else: - raise AssertionError(f"Unsupported awaitable action: {awaitable.action}") - - awaitable.resolved = True - # remove awaitabble from the list - self._awaitables = [a for a in self._awaitables if a.pk != awaitable.pk] - - if not self.has_terminated(): - # the process may be terminated, for example, if the process was killed or excepted - # then we should not try to update it - self._update_process_status() - - @Protect.final - def to_context(self, **kwargs: Awaitable | ProcessNode) -> None: - """Add a dictionary of awaitables to the context. - - This is a convenience method that provides syntactic sugar, for a user to add multiple intersteps that will - assign a certain value to the corresponding key in the context of the work graph. - """ - for key, value in kwargs.items(): - awaitable = construct_awaitable(value) - awaitable.key = key - self._insert_awaitable(awaitable) - - def _update_process_status(self) -> None: - """Set the process status with a message accounting the current sub processes that we are waiting for.""" - if self._awaitables: - status = f"Waiting for child processes: {', '.join([str(_.pk) for _ in self._awaitables])}" - self.node.set_process_status(status) - else: - self.node.set_process_status(None) + self.awaitable_manager.action_awaitables() @override def run(self) -> t.Any: @@ -297,11 +193,11 @@ def _do_step(self) -> t.Any: result: t.Any = None try: - self.continue_workgraph() + self.task_manager.continue_workgraph() except _PropagateReturn as exception: finished, result = True, exception.exit_code else: - finished, result = self.is_workgraph_finished() + finished, result = self.task_manager.is_workgraph_finished() # If the workgraph is finished or the result is an ExitCode, we exit by returning if finished: @@ -349,106 +245,11 @@ def on_wait(self, awaitables: t.Sequence[t.Awaitable]): """Entering the WAITING state.""" super().on_wait(awaitables) if self._awaitables: - self._action_awaitables() + self.awaitable_manager.action_awaitables() self.report("Process status: {}".format(self.node.process_status)) else: self.call_soon(self.resume) - def _action_awaitables(self) -> None: - """Handle the awaitables that are currently registered with the work chain. 
- - Depending on the class type of the awaitable's target a different callback - function will be bound with the awaitable and the runner will be asked to - call it when the target is completed - """ - for awaitable in self._awaitables: - # if the waitable already has a callback, skip - if awaitable.pk in self.ctx._awaitable_actions: - continue - if awaitable.target == AwaitableTarget.PROCESS: - callback = functools.partial( - self.call_soon, self._on_awaitable_finished, awaitable - ) - self.runner.call_on_process_finish(awaitable.pk, callback) - self.ctx._awaitable_actions.append(awaitable.pk) - elif awaitable.target == "asyncio.tasks.Task": - # this is a awaitable task, the callback function is already set - self.ctx._awaitable_actions.append(awaitable.pk) - else: - assert f"invalid awaitable target '{awaitable.target}'" - - def _on_awaitable_finished(self, awaitable: Awaitable) -> None: - """Callback function, for when an awaitable process instance is completed. - - The awaitable will be effectuated on the context of the work chain and removed from the internal list. If all - awaitables have been dealt with, the work chain process is resumed. - - :param awaitable: an Awaitable instance - """ - self.logger.debug(f"Awaitable {awaitable.key} finished.") - - if isinstance(awaitable.pk, int): - self.logger.info( - "received callback that awaitable with key {} and pk {} has terminated".format( - awaitable.key, awaitable.pk - ) - ) - try: - node = load_node(awaitable.pk) - except (exceptions.MultipleObjectsError, exceptions.NotExistent): - raise ValueError( - f"provided pk<{awaitable.pk}> could not be resolved to a valid Node instance" - ) - - if awaitable.outputs: - value = { - entry.link_label: entry.node - for entry in node.base.links.get_outgoing() - } - else: - value = node # type: ignore - else: - # In this case, the pk and key are the same. 
- self.logger.info( - "received callback that awaitable {} has terminated".format( - awaitable.key - ) - ) - try: - # if awaitable is cancelled, the result is None - if awaitable.cancelled(): - self.set_task_state_info(awaitable.key, "state", "KILLED") - # set child tasks state to SKIPPED - self.set_tasks_state( - self.ctx._connectivity["child_node"][awaitable.key], "SKIPPED" - ) - self.report(f"Task: {awaitable.key} cancelled.") - else: - results = awaitable.result() - self.update_normal_task_state(awaitable.key, results) - except Exception as e: - self.logger.error(f"Error in awaitable {awaitable.key}: {e}") - self.set_task_state_info(awaitable.key, "state", "FAILED") - # set child tasks state to SKIPPED - self.set_tasks_state( - self.ctx._connectivity["child_node"][awaitable.key], "SKIPPED" - ) - self.report(f"Task: {awaitable.key} failed.") - self.run_error_handlers(awaitable.key) - value = None - - self._resolve_awaitable(awaitable, value) - - # node finished, update the task state and result - # udpate the task state - self.update_task_state(awaitable.key) - # try to resume the workgraph, if the workgraph is already resumed - # by other awaitable, this will not work - try: - self.resume() - except Exception as e: - print(e) - def _build_process_label(self) -> str: """Use the workgraph name as the process label.""" return f"WorkGraph<{self.inputs.wg['name']}>" @@ -481,20 +282,18 @@ def setup(self) -> None: self.ctx._msgs = [] self.ctx._execution_count = 1 # init task results - self.set_task_results() - # data not to be persisted, because they are not serializable - self._temp = {"awaitables": {}} + self.task_manager.set_task_results() # while workgraph if self.ctx._workgraph["workgraph_type"].upper() == "WHILE": self.ctx._max_iteration = self.ctx._workgraph.get("max_iteration", 1000) - should_run = self.check_while_conditions() + should_run = self.task_manager.check_while_conditions() if not should_run: - self.set_tasks_state(self.ctx._tasks.keys(), "SKIPPED") + self.task_manager.set_tasks_state(self.ctx._tasks.keys(), "SKIPPED") # for workgraph if self.ctx._workgraph["workgraph_type"].upper() == "FOR": - should_run = self.check_for_conditions() + should_run = self.task_manager.check_for_conditions() if not should_run: - self.set_tasks_state(self.ctx._tasks.keys(), "SKIPPED") + self.task_manager.set_tasks_state(self.ctx._tasks.keys(), "SKIPPED") def setup_ctx_workgraph(self, wgdata: t.Dict[str, t.Any]) -> None: """setup the workgraph in the context.""" @@ -530,45 +329,6 @@ def update_workgraph_from_base(self) -> None: task["results"] = self.ctx._tasks[name].get("results") self.setup_ctx_workgraph(wgdata) - def get_task(self, name: str): - """Get task from the context.""" - task = Task.from_dict(self.ctx._tasks[name]) - # update task results - for output in task.outputs: - output.value = get_nested_dict( - self.ctx._tasks[name]["results"], - output.name, - default=output.value, - ) - return task - - def update_task(self, task: Task): - """Update task in the context. 
- This is used in error handlers to update the task parameters.""" - tdata = task.to_dict() - self.ctx._tasks[task.name]["properties"] = tdata["properties"] - self.ctx._tasks[task.name]["inputs"] = tdata["inputs"] - self.reset_task(task.name) - - def get_task_state_info(self, name: str, key: str) -> str: - """Get task state info from ctx.""" - - value = self.ctx._tasks[name].get(key, None) - if key == "process" and value is not None: - value = deserialize_unsafe(value) - return value - - def set_task_state_info(self, name: str, key: str, value: any) -> None: - """Set task state info to ctx and base.extras. - We task state to the base.extras, so that we can access outside the engine""" - - if key == "process": - value = serialize(value) - self.node.base.extras.set(f"_task_{key}_{name}", value) - else: - self.node.base.extras.set(f"_task_{key}_{name}", value) - self.ctx._tasks[name][key] = value - def init_ctx(self, wgdata: t.Dict[str, t.Any]) -> None: """Init the context from the workgraph data.""" from aiida_workgraph.utils import update_nested_dict @@ -584,899 +344,13 @@ def init_ctx(self, wgdata: t.Dict[str, t.Any]) -> None: # set up the workgraph self.setup_ctx_workgraph(wgdata) - def set_task_results(self) -> None: - for name, task in self.ctx._tasks.items(): - if self.get_task_state_info(name, "action").upper() == "RESET": - self.reset_task(task["name"]) - self.update_task_state(name) - def apply_action(self, msg: dict) -> None: if msg["catalog"] == "task": - self.apply_task_actions(msg) + self.task_manager.apply_task_actions(msg) else: self.report(f"Unknow message type {msg}") - def apply_task_actions(self, msg: dict) -> None: - """Apply task actions to the workgraph.""" - action = msg["action"] - tasks = msg["tasks"] - self.report(f"Action: {action}. {tasks}") - if action.upper() == "RESET": - for name in tasks: - self.reset_task(name) - elif action.upper() == "PAUSE": - for name in tasks: - self.pause_task(name) - elif action.upper() == "PLAY": - for name in tasks: - self.play_task(name) - elif action.upper() == "SKIP": - for name in tasks: - self.skip_task(name) - elif action.upper() == "KILL": - for name in tasks: - self.kill_task(name) - - def reset_task( - self, - name: str, - reset_process: bool = True, - recursive: bool = True, - reset_execution_count: bool = True, - ) -> None: - """Reset task state and remove it from the executed task. 
- If recursive is True, reset its child tasks.""" - - self.set_task_state_info(name, "state", "PLANNED") - if reset_process: - self.set_task_state_info(name, "process", None) - self.remove_executed_task(name) - # self.logger.debug(f"Task {name} action: RESET.") - # if the task is a while task, reset its child tasks - if self.ctx._tasks[name]["metadata"]["node_type"].upper() == "WHILE": - if reset_execution_count: - self.ctx._tasks[name]["execution_count"] = 0 - for child_task in self.ctx._tasks[name]["children"]: - self.reset_task(child_task, reset_process=False, recursive=False) - elif self.ctx._tasks[name]["metadata"]["node_type"].upper() in ["IF", "ZONE"]: - for child_task in self.ctx._tasks[name]["children"]: - self.reset_task(child_task, reset_process=False, recursive=False) - if recursive: - # reset its child tasks - names = self.ctx._connectivity["child_node"][name] - for name in names: - self.reset_task(name, recursive=False) - - def pause_task(self, name: str) -> None: - """Pause task.""" - self.set_task_state_info(name, "action", "PAUSE") - self.report(f"Task {name} action: PAUSE.") - - def play_task(self, name: str) -> None: - """Play task.""" - self.set_task_state_info(name, "action", "") - self.report(f"Task {name} action: PLAY.") - - def skip_task(self, name: str) -> None: - """Skip task.""" - self.set_task_state_info(name, "state", "SKIPPED") - self.report(f"Task {name} action: SKIP.") - - def kill_task(self, name: str) -> None: - """Kill task. - This is used to kill the awaitable and monitor task. - """ - if self.get_task_state_info(name, "state") in ["RUNNING"]: - if self.ctx._tasks[name]["metadata"]["node_type"].upper() in [ - "AWAITABLE", - "MONITOR", - ]: - try: - self._temp["awaitables"][name].cancel() - self.set_task_state_info(name, "state", "KILLED") - self.report(f"Task {name} action: KILLED.") - except Exception as e: - self.logger.error(f"Error in killing task {name}: {e}") - - def continue_workgraph(self) -> None: - self.report("Continue workgraph.") - # self.update_workgraph_from_base() - task_to_run = [] - for name, task in self.ctx._tasks.items(): - # update task state - if ( - self.get_task_state_info(task["name"], "state") - in [ - "CREATED", - "RUNNING", - "FINISHED", - "FAILED", - "SKIPPED", - ] - or name in self.ctx._executed_tasks - ): - continue - ready, _ = self.is_task_ready_to_run(name) - if ready: - task_to_run.append(name) - # - self.report("tasks ready to run: {}".format(",".join(task_to_run))) - self.run_tasks(task_to_run) - - def update_task_state(self, name: str, success=True) -> None: - """Update task state when the task is finished.""" - task = self.ctx._tasks[name] - if success: - node = self.get_task_state_info(name, "process") - if isinstance(node, orm.ProcessNode): - # print(f"set task result: {name} process") - state = node.process_state.value.upper() - if node.is_finished_ok: - self.set_task_state_info(task["name"], "state", state) - if task["metadata"]["node_type"].upper() == "WORKGRAPH": - # expose the outputs of all the tasks in the workgraph - task["results"] = {} - outgoing = node.base.links.get_outgoing() - for link in outgoing.all(): - if isinstance(link.node, ProcessNode) and getattr( - link.node, "process_state", False - ): - task["results"][link.link_label] = link.node.outputs - else: - task["results"] = node.outputs - # self.ctx._new_data[name] = task["results"] - self.set_task_state_info(task["name"], "state", "FINISHED") - self.task_set_context(name) - self.report(f"Task: {name} finished.") - # all other states are 
considered as failed - else: - task["results"] = node.outputs - self.on_task_failed(name) - elif isinstance(node, orm.Data): - # - output_name = [ - output_name - for output_name in list(task["outputs"].keys()) - if output_name not in ["_wait", "_outputs"] - ][0] - task["results"] = {output_name: node} - self.set_task_state_info(task["name"], "state", "FINISHED") - self.task_set_context(name) - self.report(f"Task: {name} finished.") - else: - task.setdefault("results", None) - else: - self.on_task_failed(name) - self.update_parent_task_state(name) - - def on_task_failed(self, name: str) -> None: - """Handle the case where a task has failed.""" - self.set_task_state_info(name, "state", "FAILED") - self.set_tasks_state(self.ctx._connectivity["child_node"][name], "SKIPPED") - self.report(f"Task: {name} failed.") - self.run_error_handlers(name) - - def update_normal_task_state(self, name, results, success=True): - """Set the results of a normal task. - A normal task is created by decorating a function with @task(). - """ - from aiida_workgraph.utils import get_sorted_names - - if success: - task = self.ctx._tasks[name] - if isinstance(results, tuple): - if len(task["outputs"]) != len(results): - return self.exit_codes.OUTPUS_NOT_MATCH_RESULTS - output_names = get_sorted_names(task["outputs"]) - for i, output_name in enumerate(output_names): - task["results"][output_name] = results[i] - elif isinstance(results, dict): - task["results"] = results - else: - output_name = [ - output_name - for output_name in list(task["outputs"].keys()) - if output_name not in ["_wait", "_outputs"] - ][0] - task["results"][output_name] = results - self.task_set_context(name) - self.set_task_state_info(name, "state", "FINISHED") - self.report(f"Task: {name} finished.") - else: - self.on_task_failed(name) - self.update_parent_task_state(name) - - def update_parent_task_state(self, name: str) -> None: - """Update parent task state.""" - parent_task = self.ctx._tasks[name]["parent_task"] - if parent_task[0]: - task_type = self.ctx._tasks[parent_task[0]]["metadata"]["node_type"].upper() - if task_type == "WHILE": - self.update_while_task_state(parent_task[0]) - elif task_type == "IF": - self.update_zone_task_state(parent_task[0]) - elif task_type == "ZONE": - self.update_zone_task_state(parent_task[0]) - - def update_while_task_state(self, name: str) -> None: - """Update while task state.""" - finished, _ = self.are_childen_finished(name) - - if finished: - self.report( - f"Wihle Task {name}: this iteration finished. Try to reset for the next iteration." 
- ) - # reset the condition tasks - for link in self.ctx._tasks[name]["inputs"]["conditions"]["links"]: - self.reset_task(link["from_node"], recursive=False) - # reset the task and all its children, so that the task can run again - # do not reset the execution count - self.reset_task(name, reset_execution_count=False) - - def update_zone_task_state(self, name: str) -> None: - """Update zone task state.""" - finished, _ = self.are_childen_finished(name) - if finished: - self.set_task_state_info(name, "state", "FINISHED") - self.report(f"Task: {name} finished.") - self.update_parent_task_state(name) - - def should_run_while_task(self, name: str) -> tuple[bool, t.Any]: - """Check if the while task should run.""" - # check the conditions of the while task - not_excess_max_iterations = ( - self.ctx._tasks[name]["execution_count"] - < self.ctx._tasks[name]["inputs"]["max_iterations"]["property"]["value"] - ) - conditions = [not_excess_max_iterations] - _, kwargs, _, _, _ = self.get_inputs(name) - if isinstance(kwargs["conditions"], list): - for condition in kwargs["conditions"]: - value = get_nested_dict(self.ctx, condition) - conditions.append(value) - elif isinstance(kwargs["conditions"], dict): - for _, value in kwargs["conditions"].items(): - conditions.append(value) - else: - conditions.append(kwargs["conditions"]) - return False not in conditions - - def should_run_if_task(self, name: str) -> tuple[bool, t.Any]: - """Check if the IF task should run.""" - _, kwargs, _, _, _ = self.get_inputs(name) - flag = kwargs["conditions"] - if kwargs["invert_condition"]: - return not flag - return flag - - def are_childen_finished(self, name: str) -> tuple[bool, t.Any]: - """Check if the child tasks are finished.""" - task = self.ctx._tasks[name] - finished = True - for name in task["children"]: - if self.get_task_state_info(name, "state") not in [ - "FINISHED", - "SKIPPED", - "FAILED", - ]: - finished = False - break - return finished, None - - def run_error_handlers(self, task_name: str) -> None: - """Run error handler for a task.""" - - node = self.get_task_state_info(task_name, "process") - if not node or not node.exit_status: - return - # error_handlers from the task - for _, data in self.ctx._tasks[task_name]["error_handlers"].items(): - if node.exit_status in data.get("exit_codes", []): - handler = data["handler"] - self.run_error_handler(handler, data, task_name) - return - # error_handlers from the workgraph - for _, data in self.ctx._error_handlers.items(): - if node.exit_code.status in data["tasks"].get(task_name, {}).get( - "exit_codes", [] - ): - handler = data["handler"] - metadata = data["tasks"][task_name] - self.run_error_handler(handler, metadata, task_name) - return - - def run_error_handler(self, handler: dict, metadata: dict, task_name: str) -> None: - from inspect import signature - from aiida_workgraph.utils import get_executor - - handler, _ = get_executor(handler) - handler_sig = signature(handler) - metadata.setdefault("retry", 0) - self.report(f"Run error handler: {handler.__name__}") - if metadata["retry"] < metadata["max_retries"]: - task = self.get_task(task_name) - try: - if "engine" in handler_sig.parameters: - msg = handler(task, engine=self, **metadata.get("kwargs", {})) - else: - msg = handler(task, **metadata.get("kwargs", {})) - self.update_task(task) - if msg: - self.report(msg) - metadata["retry"] += 1 - except Exception as e: - self.report(f"Error in running error handler: {e}") - - def is_workgraph_finished(self) -> bool: - """Check if the workgraph is finished. 
- For `while` workgraph, we need check its conditions""" - is_finished = True - failed_tasks = [] - for name, task in self.ctx._tasks.items(): - # self.update_task_state(name) - if self.get_task_state_info(task["name"], "state") in [ - "RUNNING", - "CREATED", - "PLANNED", - "READY", - ]: - is_finished = False - elif self.get_task_state_info(task["name"], "state") == "FAILED": - failed_tasks.append(name) - if is_finished: - if self.ctx._workgraph["workgraph_type"].upper() == "WHILE": - should_run = self.check_while_conditions() - is_finished = not should_run - if self.ctx._workgraph["workgraph_type"].upper() == "FOR": - should_run = self.check_for_conditions() - is_finished = not should_run - if is_finished and len(failed_tasks) > 0: - message = f"WorkGraph finished, but tasks: {failed_tasks} failed. Thus all their child tasks are skipped." - self.report(message) - result = ExitCode(302, message) - else: - result = None - return is_finished, result - - def check_while_conditions(self) -> bool: - """Check while conditions. - Run all condition tasks and check if all the conditions are True. - """ - self.report("Check while conditions.") - if self.ctx._execution_count >= self.ctx._max_iteration: - self.report("Max iteration reached.") - return False - condition_tasks = [] - for c in self.ctx._workgraph["conditions"]: - task_name, socket_name = c.split(".") - if "task_name" != "context": - condition_tasks.append(task_name) - self.reset_task(task_name) - self.run_tasks(condition_tasks, continue_workgraph=False) - conditions = [] - for c in self.ctx._workgraph["conditions"]: - task_name, socket_name = c.split(".") - if task_name == "context": - conditions.append(self.ctx[socket_name]) - else: - conditions.append(self.ctx._tasks[task_name]["results"][socket_name]) - should_run = False not in conditions - if should_run: - self.reset() - self.set_tasks_state(condition_tasks, "SKIPPED") - return should_run - - def check_for_conditions(self) -> bool: - condition_tasks = [c[0] for c in self.ctx._workgraph["conditions"]] - self.run_tasks(condition_tasks) - conditions = [self.ctx._count < len(self.ctx._sequence)] + [ - self.ctx._tasks[c[0]]["results"][c[1]] - for c in self.ctx._workgraph["conditions"] - ] - should_run = False not in conditions - if should_run: - self.reset() - self.set_tasks_state(condition_tasks, "SKIPPED") - self.ctx["i"] = self.ctx._sequence[self.ctx._count] - self.ctx._count += 1 - return should_run - - def remove_executed_task(self, name: str) -> None: - """Remove labels with name from executed tasks.""" - self.ctx._executed_tasks = [ - label for label in self.ctx._executed_tasks if label.split(".")[0] != name - ] - - def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None: - """Run tasks. - Task type includes: Node, Data, CalcFunction, WorkFunction, CalcJob, WorkChain, GraphBuilder, - WorkGraph, PythonJob, ShellJob, While, If, Zone, FromContext, ToContext, Normal. - - Here we use ToContext to pass the results of the run to the next step. - This will force the engine to wait for all the submitted processes to - finish before continuing to the next step. 
- """ - from aiida_workgraph.utils import ( - get_executor, - create_data_node, - update_nested_dict_with_special_keys, - ) - - for name in names: - # skip if the max number of awaitables is reached - task = self.ctx._tasks[name] - if task["metadata"]["node_type"].upper() in [ - "CALCJOB", - "WORKCHAIN", - "GRAPH_BUILDER", - "WORKGRAPH", - "PYTHONJOB", - "SHELLJOB", - ]: - if len(self._awaitables) >= self.ctx._max_number_awaitables: - print( - MAX_NUMBER_AWAITABLES_MSG.format( - self.ctx._max_number_awaitables, name - ) - ) - continue - # skip if the task is already executed - # or if the task is in a skippped state - if name in self.ctx._executed_tasks or self.get_task_state_info( - name, "state" - ) in ["SKIPPED"]: - continue - self.ctx._executed_tasks.append(name) - print("-" * 60) - - self.report(f"Run task: {name}, type: {task['metadata']['node_type']}") - executor, _ = get_executor(task["executor"]) - # print("executor: ", executor) - args, kwargs, var_args, var_kwargs, args_dict = self.get_inputs(name) - for i, key in enumerate(self.ctx._tasks[name]["args"]): - kwargs[key] = args[i] - # update the port namespace - kwargs = update_nested_dict_with_special_keys(kwargs) - # print("args: ", args) - # print("kwargs: ", kwargs) - # print("var_kwargs: ", var_kwargs) - # kwargs["meta.label"] = name - # output must be a Data type or a mapping of {string: Data} - task["results"] = {} - if task["metadata"]["node_type"].upper() == "NODE": - results = self.run_executor(executor, [], kwargs, var_args, var_kwargs) - self.set_task_state_info(name, "process", results) - self.update_task_state(name) - if continue_workgraph: - self.continue_workgraph() - elif task["metadata"]["node_type"].upper() == "DATA": - for key in self.ctx._tasks[name]["args"]: - kwargs.pop(key, None) - results = create_data_node(executor, args, kwargs) - self.set_task_state_info(name, "process", results) - self.update_task_state(name) - self.ctx._new_data[name] = results - if continue_workgraph: - self.continue_workgraph() - elif task["metadata"]["node_type"].upper() in [ - "CALCFUNCTION", - "WORKFUNCTION", - ]: - kwargs.setdefault("metadata", {}) - kwargs["metadata"].update({"call_link_label": name}) - try: - # since aiida 2.5.0, we need to use args_dict to pass the args to the run_get_node - if var_kwargs is None: - results, process = run_get_node(executor, **kwargs) - else: - results, process = run_get_node( - executor, **kwargs, **var_kwargs - ) - process.label = name - # print("results: ", results) - self.set_task_state_info(name, "process", process) - self.update_task_state(name) - except Exception as e: - self.logger.error(f"Error in task {name}: {e}") - self.update_task_state(name, success=False) - # exclude the current tasks from the next run - if continue_workgraph: - self.continue_workgraph() - elif task["metadata"]["node_type"].upper() in ["CALCJOB", "WORKCHAIN"]: - # process = run_get_node(executor, *args, **kwargs) - kwargs.setdefault("metadata", {}) - kwargs["metadata"].update({"call_link_label": name}) - # transfer the args to kwargs - if self.get_task_state_info(name, "action").upper() == "PAUSE": - self.set_task_state_info(name, "action", "") - self.report(f"Task {name} is created and paused.") - process = create_and_pause_process( - self.runner, - executor, - kwargs, - state_msg="Paused through WorkGraph", - ) - self.set_task_state_info(name, "state", "CREATED") - process = process.node - else: - process = self.submit(executor, **kwargs) - self.set_task_state_info(name, "state", "RUNNING") - process.label = 
name - self.set_task_state_info(name, "process", process) - self.to_context(**{name: process}) - elif task["metadata"]["node_type"].upper() in ["GRAPH_BUILDER"]: - wg = self.run_executor(executor, [], kwargs, var_args, var_kwargs) - wg.name = name - wg.group_outputs = self.ctx._tasks[name]["metadata"]["group_outputs"] - wg.parent_uuid = self.node.uuid - inputs = wg.prepare_inputs(metadata={"call_link_label": name}) - process = self.submit(WorkGraphEngine, inputs=inputs) - self.set_task_state_info(name, "process", process) - self.set_task_state_info(name, "state", "RUNNING") - self.to_context(**{name: process}) - elif task["metadata"]["node_type"].upper() in ["WORKGRAPH"]: - from .utils import prepare_for_workgraph_task - - inputs, _ = prepare_for_workgraph_task(task, kwargs) - process = self.submit(WorkGraphEngine, inputs=inputs) - self.set_task_state_info(name, "process", process) - self.set_task_state_info(name, "state", "RUNNING") - self.to_context(**{name: process}) - elif task["metadata"]["node_type"].upper() in ["PYTHONJOB"]: - from aiida_workgraph.calculations.python import PythonJob - from .utils import prepare_for_python_task - - inputs = prepare_for_python_task(task, kwargs, var_kwargs) - # since aiida 2.5.0, we can pass inputs directly to the submit, no need to use **inputs - if self.get_task_state_info(name, "action").upper() == "PAUSE": - self.set_task_state_info(name, "action", "") - self.report(f"Task {name} is created and paused.") - process = create_and_pause_process( - self.runner, - PythonJob, - inputs, - state_msg="Paused through WorkGraph", - ) - self.set_task_state_info(name, "state", "CREATED") - process = process.node - else: - process = self.submit(PythonJob, **inputs) - self.set_task_state_info(name, "state", "RUNNING") - process.label = name - self.set_task_state_info(name, "process", process) - self.to_context(**{name: process}) - elif task["metadata"]["node_type"].upper() in ["SHELLJOB"]: - from aiida_shell.calculations.shell import ShellJob - from .utils import prepare_for_shell_task - - inputs = prepare_for_shell_task(task, kwargs) - if self.get_task_state_info(name, "action").upper() == "PAUSE": - self.set_task_state_info(name, "action", "") - self.report(f"Task {name} is created and paused.") - process = create_and_pause_process( - self.runner, - ShellJob, - inputs, - state_msg="Paused through WorkGraph", - ) - self.set_task_state_info(name, "state", "CREATED") - process = process.node - else: - process = self.submit(ShellJob, **inputs) - self.set_task_state_info(name, "state", "RUNNING") - process.label = name - self.set_task_state_info(name, "process", process) - self.to_context(**{name: process}) - elif task["metadata"]["node_type"].upper() in ["WHILE"]: - # TODO refactor this for while, if and zone - # in case of an empty zone, it will finish immediately - if self.are_childen_finished(name)[0]: - self.update_while_task_state(name) - else: - # check the conditions of the while task - should_run = self.should_run_while_task(name) - if not should_run: - self.set_task_state_info(name, "state", "FINISHED") - self.set_tasks_state( - self.ctx._tasks[name]["children"], "SKIPPED" - ) - self.update_parent_task_state(name) - self.report( - f"While Task {name}: Condition not fullilled, task finished. Skip all its children." 
- ) - else: - task["execution_count"] += 1 - self.set_task_state_info(name, "state", "RUNNING") - self.continue_workgraph() - elif task["metadata"]["node_type"].upper() in ["IF"]: - # in case of an empty zone, it will finish immediately - if self.are_childen_finished(name)[0]: - self.update_zone_task_state(name) - else: - should_run = self.should_run_if_task(name) - if should_run: - self.set_task_state_info(name, "state", "RUNNING") - else: - self.set_tasks_state(task["children"], "SKIPPED") - self.update_zone_task_state(name) - self.continue_workgraph() - elif task["metadata"]["node_type"].upper() in ["ZONE"]: - # in case of an empty zone, it will finish immediately - if self.are_childen_finished(name)[0]: - self.update_zone_task_state(name) - else: - self.set_task_state_info(name, "state", "RUNNING") - self.continue_workgraph() - elif task["metadata"]["node_type"].upper() in ["FROM_CONTEXT"]: - # get the results from the context - results = {"result": getattr(self.ctx, kwargs["key"])} - task["results"] = results - self.set_task_state_info(name, "state", "FINISHED") - self.update_parent_task_state(name) - self.continue_workgraph() - elif task["metadata"]["node_type"].upper() in ["TO_CONTEXT"]: - # get the results from the context - setattr(self.ctx, kwargs["key"], kwargs["value"]) - self.set_task_state_info(name, "state", "FINISHED") - self.update_parent_task_state(name) - self.continue_workgraph() - elif task["metadata"]["node_type"].upper() in ["AWAITABLE"]: - for key in self.ctx._tasks[name]["args"]: - kwargs.pop(key, None) - awaitable_target = asyncio.ensure_future( - self.run_executor(executor, args, kwargs, var_args, var_kwargs), - loop=self.loop, - ) - awaitable = self.construct_awaitable_function(name, awaitable_target) - self.set_task_state_info(name, "state", "RUNNING") - self.to_context(**{name: awaitable}) - elif task["metadata"]["node_type"].upper() in ["MONITOR"]: - - for key in self.ctx._tasks[name]["args"]: - kwargs.pop(key, None) - # add function and interval to the args - args = [ - executor, - kwargs.pop("interval", 1), - kwargs.pop("timeout", 3600), - *args, - ] - awaitable_target = asyncio.ensure_future( - self.run_executor(monitor, args, kwargs, var_args, var_kwargs), - loop=self.loop, - ) - awaitable = self.construct_awaitable_function(name, awaitable_target) - self.set_task_state_info(name, "state", "RUNNING") - # save the awaitable to the temp, so that we can kill it if needed - self._temp["awaitables"][name] = awaitable_target - self.to_context(**{name: awaitable}) - elif task["metadata"]["node_type"].upper() in ["NORMAL"]: - # Normal task is created by decoratoring a function with @task() - if "context" in task["kwargs"]: - self.ctx.task_name = name - kwargs.update({"context": self.ctx}) - for key in self.ctx._tasks[name]["args"]: - kwargs.pop(key, None) - try: - results = self.run_executor( - executor, args, kwargs, var_args, var_kwargs - ) - self.update_normal_task_state(name, results) - except Exception as e: - self.logger.error(f"Error in task {name}: {e}") - self.update_normal_task_state(name, results=None, success=False) - if continue_workgraph: - self.continue_workgraph() - else: - # self.report("Unknow task type {}".format(task["metadata"]["node_type"])) - return self.exit_codes.UNKNOWN_TASK_TYPE - - def construct_awaitable_function( - self, name: str, awaitable_target: Awaitable - ) -> None: - """Construct the awaitable function.""" - awaitable = Awaitable( - **{ - "pk": name, - "action": AwaitableAction.ASSIGN, - "target": "asyncio.tasks.Task", - 
"outputs": False, - } - ) - awaitable_target.key = name - awaitable_target.pk = name - awaitable_target.action = AwaitableAction.ASSIGN - awaitable_target.add_done_callback(self._on_awaitable_finished) - return awaitable - - def get_inputs( - self, name: str - ) -> t.Tuple[ - t.List[t.Any], - t.Dict[str, t.Any], - t.Optional[t.List[t.Any]], - t.Optional[t.Dict[str, t.Any]], - t.Dict[str, t.Any], - ]: - """Get input based on the links.""" - - args = [] - args_dict = {} - kwargs = {} - var_args = None - var_kwargs = None - task = self.ctx._tasks[name] - properties = task.get("properties", {}) - inputs = {} - for name, input in task["inputs"].items(): - # print(f"input: {input['name']}") - if len(input["links"]) == 0: - inputs[name] = self.update_context_variable(input["property"]["value"]) - elif len(input["links"]) == 1: - link = input["links"][0] - if self.ctx._tasks[link["from_node"]]["results"] is None: - inputs[name] = None - else: - # handle the special socket _wait, _outputs - if link["from_socket"] == "_wait": - continue - elif link["from_socket"] == "_outputs": - inputs[name] = self.ctx._tasks[link["from_node"]]["results"] - else: - inputs[name] = get_nested_dict( - self.ctx._tasks[link["from_node"]]["results"], - link["from_socket"], - ) - # handle the case of multiple outputs - elif len(input["links"]) > 1: - value = {} - for link in input["links"]: - item_name = f'{link["from_node"]}_{link["from_socket"]}' - # handle the special socket _wait, _outputs - if link["from_socket"] == "_wait": - continue - if self.ctx._tasks[link["from_node"]]["results"] is None: - value[item_name] = None - else: - value[item_name] = self.ctx._tasks[link["from_node"]][ - "results" - ][link["from_socket"]] - inputs[name] = value - for name in task.get("args", []): - if name in inputs: - args.append(inputs[name]) - args_dict[name] = inputs[name] - else: - value = self.update_context_variable(properties[name]["value"]) - args.append(value) - args_dict[name] = value - for name in task.get("kwargs", []): - if name in inputs: - kwargs[name] = inputs[name] - else: - value = self.update_context_variable(properties[name]["value"]) - kwargs[name] = value - if task["var_args"] is not None: - name = task["var_args"] - if name in inputs: - var_args = inputs[name] - else: - value = self.update_context_variable(properties[name]["value"]) - var_args = value - if task["var_kwargs"] is not None: - name = task["var_kwargs"] - if name in inputs: - var_kwargs = inputs[name] - else: - value = self.update_context_variable(properties[name]["value"]) - var_kwargs = value - return args, kwargs, var_args, var_kwargs, args_dict - - def update_context_variable(self, value: t.Any) -> t.Any: - # replace context variables - - """Get value from context.""" - if isinstance(value, dict): - for key, sub_value in value.items(): - value[key] = self.update_context_variable(sub_value) - elif ( - isinstance(value, str) - and value.strip().startswith("{{") - and value.strip().endswith("}}") - ): - name = value[2:-2].strip() - return get_nested_dict(self.ctx, name) - return value - - def task_set_context(self, name: str) -> None: - """Export task result to context.""" - from aiida_workgraph.utils import update_nested_dict - - items = self.ctx._tasks[name]["context_mapping"] - for key, value in items.items(): - result = self.ctx._tasks[name]["results"][key] - update_nested_dict(self.ctx, value, result) - - def is_task_ready_to_run(self, name: str) -> t.Tuple[bool, t.Optional[str]]: - """Check if the task ready to run. 
- For normal task and a zone task, we need to check its input tasks in the connectivity["zone"]. - For task inside a zone, we need to check if the zone (parent task) is ready. - """ - parent_task = self.ctx._tasks[name]["parent_task"] - # input_tasks, parent_task, conditions - parent_states = [True, True] - # if the task belongs to a parent zone - if parent_task[0]: - state = self.get_task_state_info(parent_task[0], "state") - if state not in ["RUNNING"]: - parent_states[1] = False - # check the input tasks of the zone - # check if the zone input tasks are ready - for child_task_name in self.ctx._connectivity["zone"][name]["input_tasks"]: - if self.get_task_state_info(child_task_name, "state") not in [ - "FINISHED", - "SKIPPED", - "FAILED", - ]: - parent_states[0] = False - break - - return all(parent_states), parent_states - - def reset(self) -> None: - self.ctx._execution_count += 1 - self.set_tasks_state(self.ctx._tasks.keys(), "PLANNED") - self.ctx._executed_tasks = [] - - def set_tasks_state( - self, tasks: t.Union[t.List[str], t.Sequence[str]], value: str - ) -> None: - """Set tasks state""" - for name in tasks: - self.set_task_state_info(name, "state", value) - if "children" in self.ctx._tasks[name]: - self.set_tasks_state(self.ctx._tasks[name]["children"], value) - - def run_executor( - self, - executor: t.Callable, - args: t.List[t.Any], - kwargs: t.Dict[str, t.Any], - var_args: t.Optional[t.List[t.Any]], - var_kwargs: t.Optional[t.Dict[str, t.Any]], - ) -> t.Any: - if var_kwargs is None: - return executor(*args, **kwargs) - else: - return executor(*args, **kwargs, **var_kwargs) - - def save_results_to_extras(self, name: str) -> None: - """Save the results to the base.extras. - For the outputs of a Normal task, they are not saved to the database like the calcjob or workchain. - One temporary solution is to save the results to the base.extras. 
In order to do this, we need to - serialize the results - """ - from aiida_workgraph.utils import get_executor - - results = self.ctx._tasks[name]["results"] - if results is None: - return - datas = {} - for key, value in results.items(): - # find outptus sockets with the name as key - output = [ - output - for output in self.ctx._tasks[name]["outputs"] - if output["name"] == key - ] - if len(output) == 0: - continue - output = output[0] - Executor, _ = get_executor(output["serialize"]) - datas[key] = Executor(value) - self.node.set_extra(f"nodes__results__{name}", datas) - def message_receive( self, _comm: kiwipy.Communicator, msg: t.Dict[str, t.Any] ) -> t.Any: @@ -1554,5 +428,5 @@ def finalize(self) -> t.Optional[ExitCode]: self.out("execution_count", orm.Int(self.ctx._execution_count).store()) self.report("Finalize workgraph.") for _, task in self.ctx._tasks.items(): - if self.get_task_state_info(task["name"], "state") == "FAILED": + if self.task_manager.get_task_state_info(task["name"], "state") == "FAILED": return self.exit_codes.TASK_FAILED diff --git a/aiida_workgraph/executors/qe.py b/aiida_workgraph/executors/qe.py deleted file mode 100644 index e803ba17..00000000 --- a/aiida_workgraph/executors/qe.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import Dict -from aiida_workgraph import task -from aiida.orm import StructureData, UpfData - - -@task( - inputs=[ - {"identifier": "workgraph.string", "name": "pseudo_family"}, - {"identifier": StructureData, "name": "structure"}, - ], - outputs=[{"identifier": UpfData, "name": "Pseudo"}], -) -def get_pseudo_from_structure( - pseudo_family: str, structure: StructureData -) -> Dict[str, UpfData]: - """for input_namespace""" - from aiida.orm import Group, QueryBuilder - - pseudo_group = ( - QueryBuilder().append(Group, filters={"label": pseudo_family}).one()[0] - ) - elements = [kind.name for kind in structure.kinds] - pseudos = {} - for ele in elements: - for n in pseudo_group.nodes: - if ele == n.element: - pseudos[ele] = n - return {"Pseudo": pseudos} diff --git a/aiida_workgraph/orm/__init__.py b/aiida_workgraph/orm/__init__.py index e478c56f..b10b96fe 100644 --- a/aiida_workgraph/orm/__init__.py +++ b/aiida_workgraph/orm/__init__.py @@ -1,8 +1,3 @@ -from .general_data import GeneralData -from .serializer import serialize_to_aiida_nodes, general_serializer +from .pickled_data import PickledData -__all__ = ( - "GeneralData", - "serialize_to_aiida_nodes", - "general_serializer", -) +__all__ = ("PickledData",) diff --git a/aiida_workgraph/orm/atoms.py b/aiida_workgraph/orm/atoms.py deleted file mode 100644 index 15cc4b87..00000000 --- a/aiida_workgraph/orm/atoms.py +++ /dev/null @@ -1,53 +0,0 @@ -from aiida.orm import Data -from ase import Atoms -from ase.db.row import atoms2dict -import numpy as np - -__all__ = ("AtomsData",) - - -class AtomsData(Data): - """Data to represent a ASE Atoms.""" - - _cached_atoms = None - - def __init__(self, value=None, **kwargs): - """Initialise a `AtomsData` node instance. 
- - :param value: ASE Atoms instance to initialise the `AtomsData` node from - """ - atoms = value or Atoms() - super().__init__(**kwargs) - data, keys = self.atoms2dict(atoms) - self.base.attributes.set_many(data) - self.base.attributes.set("keys", keys) - - @classmethod - def atoms2dict(cls, atoms): - data = atoms2dict(atoms) - data.pop("unique_id") - keys = list(data.keys()) - formula = atoms.get_chemical_formula() - data = cls._convert_numpy_to_native(data) - data["formula"] = formula - data["symbols"] = atoms.get_chemical_symbols() - return data, keys - - @classmethod - def _convert_numpy_to_native(self, data): - """Convert numpy types to Python native types for JSON compatibility.""" - for key, value in data.items(): - if isinstance(value, np.bool_): - data[key] = bool(value) - elif isinstance(value, np.ndarray): - data[key] = value.tolist() - elif isinstance(value, np.generic): - data[key] = value.item() - return data - - @property - def value(self): - keys = self.base.attributes.get("keys") - data = self.base.attributes.get_many(keys) - data = dict(zip(keys, data)) - return Atoms(**data) diff --git a/aiida_workgraph/orm/function_data.py b/aiida_workgraph/orm/function_data.py index 397e52f7..573f6663 100644 --- a/aiida_workgraph/orm/function_data.py +++ b/aiida_workgraph/orm/function_data.py @@ -1,10 +1,10 @@ import inspect import textwrap from typing import Callable, Dict, Any, get_type_hints, _SpecialForm -from .general_data import GeneralData +from .pickled_data import PickledData -class PickledFunction(GeneralData): +class PickledFunction(PickledData): """Data class to represent a pickled Python function.""" def __init__(self, value=None, **kwargs): diff --git a/aiida_workgraph/orm/general_data.py b/aiida_workgraph/orm/pickled_data.py similarity index 88% rename from aiida_workgraph/orm/general_data.py rename to aiida_workgraph/orm/pickled_data.py index bcf86b1e..a4922576 100644 --- a/aiida_workgraph/orm/general_data.py +++ b/aiida_workgraph/orm/pickled_data.py @@ -6,27 +6,15 @@ from pickle import UnpicklingError -class Dict(orm.Dict): - @property - def value(self): - return self.get_dict() - - -class List(orm.List): - @property - def value(self): - return self.get_list() - - -class GeneralData(orm.Data): +class PickledData(orm.Data): """Data to represent a pickled value using cloudpickle.""" FILENAME = "value.pkl" # Class attribute to store the filename def __init__(self, value=None, **kwargs): - """Initialize a `GeneralData` node instance. + """Initialize a `PickledData` node instance. - :param value: raw Python value to initialize the `GeneralData` node from. + :param value: raw Python value to initialize the `PickledData` node from. 
""" super().__init__(**kwargs) self.set_value(value) diff --git a/aiida_workgraph/orm/serializer.py b/aiida_workgraph/orm/serializer.py deleted file mode 100644 index 32726af6..00000000 --- a/aiida_workgraph/orm/serializer.py +++ /dev/null @@ -1,121 +0,0 @@ -from .general_data import GeneralData -from aiida import orm, common -from importlib.metadata import entry_points -from typing import Any -from aiida_workgraph.config import load_config -import sys - - -def get_serializer_from_entry_points() -> dict: - """Retrieve the serializer from the entry points.""" - # import time - - # ts = time.time() - configs = load_config() - serializers = configs.get("serializers", {}) - excludes = serializers.get("excludes", []) - # Retrieve the entry points for 'aiida.data' and store them in a dictionary - eps = entry_points() - if sys.version_info >= (3, 10): - group = eps.select(group="aiida.data") - else: - group = eps.get("aiida.data", []) - eps = {} - for ep in group: - # split the entry point name by first ".", and check the last part - key = ep.name.split(".", 1)[-1] - # skip key without "." because it is not a module name for a data type - if "." not in key or key in excludes: - continue - eps.setdefault(key, []) - eps[key].append(ep) - - # print("Time to load entry points: ", time.time() - ts) - # check if there are duplicates - for key, value in eps.items(): - if len(value) > 1: - if key in serializers: - [ep for ep in value if ep.name == serializers[key]] - eps[key] = [ep for ep in value if ep.name == serializers[key]] - if not eps[key]: - raise ValueError( - f"Entry point {serializers[key]} not found for {key}" - ) - else: - msg = f"Duplicate entry points for {key}: {[ep.name for ep in value]}" - raise ValueError(msg) - return eps - - -eps = get_serializer_from_entry_points() - - -def serialize_to_aiida_nodes(inputs: dict = None) -> dict: - """Serialize the inputs to a dictionary of AiiDA data nodes. - - Args: - inputs (dict): The inputs to be serialized. - - Returns: - dict: The serialized inputs. - """ - new_inputs = {} - # save all kwargs to inputs port - for key, data in inputs.items(): - new_inputs[key] = general_serializer(data) - return new_inputs - - -def clean_dict_key(data): - """Replace "." 
with "__dot__" in the keys of a dictionary.""" - if isinstance(data, dict): - return {k.replace(".", "__dot__"): clean_dict_key(v) for k, v in data.items()} - return data - - -def general_serializer(data: Any, check_value=True) -> orm.Node: - """Serialize the data to an AiiDA data node.""" - if isinstance(data, orm.Data): - if check_value and not hasattr(data, "value"): - raise ValueError("Only AiiDA data Node with a value attribute is allowed.") - return data - elif isinstance(data, common.extendeddicts.AttributeDict): - # if the data is an AttributeDict, use it directly - return data - # if is string with syntax {{}}, this is a port will read data from ctx - elif isinstance(data, str) and data.startswith("{{") and data.endswith("}}"): - return data - # if data is a class instance, get its __module__ and class name as a string - # for example, an Atoms will have ase.atoms.Atoms - else: - data = clean_dict_key(data) - # try to get the serializer from the entry points - data_type = type(data) - ep_key = f"{data_type.__module__}.{data_type.__name__}" - # search for the key in the entry points - if ep_key in eps: - try: - new_node = eps[ep_key][0].load()(data) - except Exception as e: - raise ValueError(f"Error in serializing {ep_key}: {e}") - finally: - # try to save the node to da - try: - new_node.store() - return new_node - except Exception: - # try to serialize the value as a GeneralData - try: - new_node = GeneralData(data) - new_node.store() - return new_node - except Exception as e: - raise ValueError(f"Error in serializing {ep_key}: {e}") - else: - # try to serialize the data as a GeneralData - try: - new_node = GeneralData(data) - new_node.store() - return new_node - except Exception as e: - raise ValueError(f"Error in serializing {ep_key}: {e}") diff --git a/aiida_workgraph/sockets/builtins.py b/aiida_workgraph/sockets/builtins.py index 692fb75e..9e946982 100644 --- a/aiida_workgraph/sockets/builtins.py +++ b/aiida_workgraph/sockets/builtins.py @@ -1,4 +1,3 @@ -from typing import Optional, Any from aiida_workgraph.socket import TaskSocket @@ -6,197 +5,88 @@ class SocketAny(TaskSocket): """Any socket.""" identifier: str = "workgraph.any" - - def __init__( - self, name, node=None, type="INPUT", index=0, uuid=None, **kwargs - ) -> None: - super().__init__(name, node, type, index, uuid=uuid) - self.add_property("workgraph.any", name, **kwargs) + property_identifier: str = "workgraph.any" class SocketNamespace(TaskSocket): """Namespace socket.""" identifier: str = "workgraph.namespace" - - def __init__( - self, - name: str, - node: Optional[Any] = None, - type: str = "INPUT", - index: int = 0, - uuid: Optional[str] = None, - **kwargs: Any - ) -> None: - super().__init__(name, node, type, index, uuid=uuid) - # Set the default value to an empty dictionary - kwargs.setdefault("default", {}) - self.add_property("workgraph.any", name, **kwargs) + property_identifier: str = "workgraph.any" class SocketFloat(TaskSocket): """Float socket.""" identifier: str = "workgraph.float" - - def __init__( - self, name, node=None, type="INPUT", index=0, uuid=None, **kwargs - ) -> None: - super().__init__(name, node, type, index, uuid=uuid) - self.add_property("workgraph.float", name, **kwargs) + property_identifier: str = "workgraph.float" class SocketInt(TaskSocket): """Int socket.""" identifier: str = "workgraph.int" - - def __init__( - self, name, node=None, type="INPUT", index=0, uuid=None, **kwargs - ) -> None: - super().__init__(name, node, type, index, uuid=uuid) - self.add_property("workgraph.int", 
name, **kwargs) + property_identifier: str = "workgraph.int" class SocketString(TaskSocket): """String socket.""" identifier: str = "workgraph.string" - - def __init__( - self, name, node=None, type="INPUT", index=0, uuid=None, **kwargs - ) -> None: - super().__init__(name, node, type, index, uuid=uuid) - self.add_property("workgraph.string", name, **kwargs) + property_identifier: str = "workgraph.string" class SocketBool(TaskSocket): """Bool socket.""" identifier: str = "workgraph.bool" - - def __init__( - self, name, node=None, type="INPUT", index=0, uuid=None, **kwargs - ) -> None: - super().__init__(name, node, type, index, uuid=uuid) - self.add_property("workgraph.bool", name, **kwargs) + property_identifier: str = "workgraph.bool" class SocketAiiDAFloat(TaskSocket): """AiiDAFloat socket.""" identifier: str = "workgraph.aiida_float" - - def __init__( - self, - name: str, - node: Optional[Any] = None, - type: str = "INPUT", - index: int = 0, - uuid: Optional[str] = None, - **kwargs: Any - ) -> None: - super().__init__(name, node, type, index, uuid=uuid) - self.add_property("workgraph.aiida_float", name, **kwargs) + property_identifier: str = "workgraph.aiida_float" class SocketAiiDAInt(TaskSocket): """AiiDAInt socket.""" identifier: str = "workgraph.aiida_int" - - def __init__( - self, - name: str, - node: Optional[Any] = None, - type: str = "INPUT", - index: int = 0, - uuid: Optional[str] = None, - **kwargs: Any - ) -> None: - super().__init__(name, node, type, index, uuid=uuid) - self.add_property("workgraph.aiida_int", name, **kwargs) + property_identifier: str = "workgraph.aiida_int" class SocketAiiDAString(TaskSocket): """AiiDAString socket.""" identifier: str = "workgraph.aiida_string" - - def __init__( - self, - name: str, - node: Optional[Any] = None, - type: str = "INPUT", - index: int = 0, - uuid: Optional[str] = None, - **kwargs: Any - ) -> None: - super().__init__(name, node, type, index, uuid=uuid) - self.add_property("workgraph.aiida_string", name, **kwargs) + property_identifier: str = "workgraph.aiida_string" class SocketAiiDABool(TaskSocket): """AiiDABool socket.""" identifier: str = "workgraph.aiida_bool" - - def __init__( - self, - name: str, - node: Optional[Any] = None, - type: str = "INPUT", - index: int = 0, - uuid: Optional[str] = None, - **kwargs: Any - ) -> None: - super().__init__(name, node, type, index, uuid=uuid) - self.add_property("workgraph.aiida_bool", name, **kwargs) + property_identifier: str = "workgraph.aiida_bool" class SocketAiiDAIntVector(TaskSocket): """Socket with a AiiDAIntVector property.""" identifier: str = "workgraph.aiida_int_vector" - - def __init__( - self, - name: str, - node: Optional[Any] = None, - type: str = "INPUT", - index: int = 0, - uuid: Optional[str] = None, - **kwargs: Any - ) -> None: - super().__init__(name, node, type, index, uuid=uuid) - self.add_property("workgraph.aiida_int_vector", name, **kwargs) + property_identifier: str = "workgraph.aiida_int_vector" class SocketAiiDAFloatVector(TaskSocket): """Socket with a FloatVector property.""" identifier: str = "workgraph.aiida_float_vector" - - def __init__( - self, - name: str, - node: Optional[Any] = None, - type: str = "INPUT", - index: int = 0, - uuid: Optional[str] = None, - **kwargs: Any - ) -> None: - super().__init__(name, node, type, index, uuid=uuid) - self.add_property("workgraph.aiida_float_vector", name, **kwargs) + property_identifier: str = "workgraph.aiida_float_vector" class SocketStructureData(TaskSocket): """Any socket.""" identifier: str = 
"workgraph.aiida_structuredata" - - def __init__( - self, name, node=None, type="INPUT", index=0, uuid=None, **kwargs - ) -> None: - super().__init__(name, node, type, index, uuid=uuid) - self.add_property("workgraph.aiida_structuredata", name, **kwargs) + property_identifier: str = "workgraph.aiida_structuredata" diff --git a/aiida_workgraph/task.py b/aiida_workgraph/task.py index f0943e7e..326f2b50 100644 --- a/aiida_workgraph/task.py +++ b/aiida_workgraph/task.py @@ -60,10 +60,10 @@ def __init__( self.action = "" self.show_socket_depth = 0 - def to_dict(self) -> Dict[str, Any]: + def to_dict(self, short: bool = False) -> Dict[str, Any]: from aiida.orm.utils.serialize import serialize - tdata = super().to_dict() + tdata = super().to_dict(short=short) tdata["context_mapping"] = self.context_mapping tdata["wait"] = [task.name for task in self.waiting_on] tdata["children"] = [] @@ -77,9 +77,11 @@ def to_dict(self) -> Dict[str, Any]: return tdata def set_context(self, context: Dict[str, Any]) -> None: - """Update the context mappings for this task.""" - # all keys should belong to the outputs.keys() - remain_keys = set(context.keys()).difference(self.outputs.keys()) + """Set the output of the task as a value in the context. + key is the context key, value is the output key. + """ + # all values should belong to the outputs.keys() + remain_keys = set(context.values()).difference(self.outputs.keys()) if remain_keys: msg = f"Keys {remain_keys} are not in the outputs of this task." raise ValueError(msg) diff --git a/aiida_workgraph/tasks/builtins.py b/aiida_workgraph/tasks/builtins.py index 3e771a80..4c563a82 100644 --- a/aiida_workgraph/tasks/builtins.py +++ b/aiida_workgraph/tasks/builtins.py @@ -23,8 +23,8 @@ def create_sockets(self) -> None: inp.link_limit = 100000 self.outputs.new("workgraph.any", "_wait") - def to_dict(self) -> Dict[str, Any]: - tdata = super().to_dict() + def to_dict(self, short: bool = False) -> Dict[str, Any]: + tdata = super().to_dict(short=short) tdata["children"] = [task.name for task in self.children] return tdata @@ -95,12 +95,12 @@ def create_sockets(self) -> None: self.outputs.new("workgraph.any", "result") -class ToContext(Task): - """ToContext""" +class SetContext(Task): + """SetContext""" - identifier = "workgraph.to_context" - name = "ToContext" - node_type = "TO_CONTEXT" + identifier = "workgraph.set_context" + name = "SetContext" + node_type = "SET_CONTEXT" catalog = "Control" args = ["key", "value"] @@ -114,12 +114,12 @@ def create_sockets(self) -> None: self.outputs.new("workgraph.any", "_wait") -class FromContext(Task): - """FromContext""" +class GetContext(Task): + """GetContext""" - identifier = "workgraph.from_context" - name = "FromContext" - node_type = "FROM_CONTEXT" + identifier = "workgraph.get_context" + name = "GetContext" + node_type = "GET_CONTEXT" catalog = "Control" args = ["key"] @@ -164,7 +164,9 @@ class AiiDAFloat(Task): args = ["value"] def create_sockets(self) -> None: - self.inputs.new("workgraph.aiida_float", "value", default=0.0) + self.inputs.new( + "workgraph.aiida_float", "value", property_data={"default": 0.0} + ) self.outputs.new("workgraph.aiida_float", "result") diff --git a/aiida_workgraph/tasks/pythonjob.py b/aiida_workgraph/tasks/pythonjob.py index e204f68a..0b74ff31 100644 --- a/aiida_workgraph/tasks/pythonjob.py +++ b/aiida_workgraph/tasks/pythonjob.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List from aiida import orm -from aiida_workgraph.orm.serializer import general_serializer +from 
aiida_pythonjob.data.serializer import general_serializer from aiida_workgraph.task import Task @@ -9,24 +9,24 @@ class PythonJob(Task): identifier = "workgraph.pythonjob" - function_kwargs: List = None + function_inputs: List = None def update_from_dict(self, data: Dict[str, Any], **kwargs) -> "PythonJob": """Overwrite the update_from_dict method to handle the PythonJob data.""" - self.function_kwargs = data.get("function_kwargs", []) + self.function_inputs = data.get("function_inputs", []) self.deserialize_pythonjob_data(data) super().update_from_dict(data) - def to_dict(self) -> Dict[str, Any]: - data = super().to_dict() - data["function_kwargs"] = self.function_kwargs + def to_dict(self, short: bool = False) -> Dict[str, Any]: + data = super().to_dict(short=short) + data["function_inputs"] = self.function_inputs return data @classmethod def serialize_pythonjob_data(cls, tdata: Dict[str, Any]): """Serialize the properties for PythonJob.""" - input_kwargs = tdata.get("function_kwargs", []) + input_kwargs = tdata.get("function_inputs", []) for name in input_kwargs: tdata["inputs"][name]["property"]["value"] = cls.serialize_socket_data( tdata["inputs"][name] @@ -45,7 +45,7 @@ def deserialize_pythonjob_data(cls, tdata: Dict[str, Any]) -> None: Returns: Dict[str, Any]: The processed data dictionary. """ - input_kwargs = tdata.get("function_kwargs", []) + input_kwargs = tdata.get("function_inputs", []) for name in input_kwargs: if name in tdata["inputs"]: diff --git a/aiida_workgraph/utils/__init__.py b/aiida_workgraph/utils/__init__.py index 67eb11f1..bee8cf6a 100644 --- a/aiida_workgraph/utils/__init__.py +++ b/aiida_workgraph/utils/__init__.py @@ -497,145 +497,6 @@ def serialize_properties(wgdata): prop["value"] = PickledLocalFunction(prop["value"]).store() -def generate_bash_to_create_python_env( - name: str, - pip: list = None, - conda: dict = None, - modules: list = None, - python_version: str = None, - variables: dict = None, - shell: str = "posix", -): - """ - Generates a bash script for creating or updating a Python environment on a remote computer. - If python_version is None, it uses the Python version from the local environment. - Conda is a dictionary that can include 'channels' and 'dependencies'. - """ - import sys - - pip = pip or [] - conda_channels = conda.get("channels", []) if conda else [] - conda_dependencies = conda.get("dependencies", []) if conda else [] - # Determine the Python version from the local environment if not provided - local_python_version = f"{sys.version_info.major}.{sys.version_info.minor}" - desired_python_version = ( - python_version if python_version is not None else local_python_version - ) - - # Start of the script - script = "#!/bin/bash\n\n" - - # Load modules if provided - if modules: - script += "# Load specified system modules\n" - for module in modules: - script += f"module load {module}\n" - - # Conda shell hook initialization for proper conda activation - script += "# Initialize Conda for this shell\n" - script += f'eval "$(conda shell.{shell} hook)"\n' - - script += "# Setup the Python environment\n" - script += "if ! 
conda info --envs | grep -q ^{name}$; then\n" - script += " # Environment does not exist, create it\n" - if conda_dependencies: - dependencies_string = " ".join(conda_dependencies) - script += f" conda create -y -n {name} python={desired_python_version} {dependencies_string}\n" - else: - script += f" conda create -y -n {name} python={desired_python_version}\n" - script += "fi\n" - if conda_channels: - script += "EXISTING_CHANNELS=$(conda config --show channels)\n" - script += "for CHANNEL in " + " ".join(conda_channels) + ";\n" - script += "do\n" - script += ' if ! echo "$EXISTING_CHANNELS" | grep -q $CHANNEL; then\n' - script += " conda config --prepend channels $CHANNEL\n" - script += " fi\n" - script += "done\n" - script += f"conda activate {name}\n" - - # Install pip packages - if pip: - script += f"pip install {' '.join(pip)}\n" - - # Set environment variables - if variables: - for var, value in variables.items(): - script += f"export {var}='{value}'\n" - - # End of the script - script += "echo 'Environment setup is complete.'\n" - - return script - - -def create_conda_env( - computer: Union[str, orm.Computer], - name: str, - pip: list = None, - conda: list = None, - modules: list = None, - python_version: str = None, - variables: dict = None, - shell: str = "posix", -) -> tuple: - """Test that there is no unexpected output from the connection.""" - # Execute a command that should not return any error, except ``NotImplementedError`` - # since not all transport plugins implement remote command execution. - from aiida.common.exceptions import NotExistent - from aiida import orm - - user = orm.User.collection.get_default() - if isinstance(computer, str): - computer = orm.load_computer(computer) - try: - authinfo = computer.get_authinfo(user) - except NotExistent: - raise f"Computer<{computer.label}> is not yet configured for user<{user.email}>" - - scheduler = authinfo.computer.get_scheduler() - transport = authinfo.get_transport() - - script = generate_bash_to_create_python_env( - name, pip, conda, modules, python_version, variables, shell - ) - with transport: - scheduler.set_transport(transport) - try: - retval, stdout, stderr = transport.exec_command_wait(script) - except NotImplementedError: - return ( - True, - f"Skipped, remote command execution is not implemented for the " - f"`{computer.transport_type}` transport plugin", - ) - - if retval != 0: - return ( - False, - f"The command `echo -n` returned a non-zero return code ({retval})", - ) - - template = """ -We detected an error while creating the environemnt on the remote computer, as shown between the bars -============================================================================================= -{} -============================================================================================= -Please check! - """ - if stderr: - return False, template.format(stderr) - - if stdout: - # the last line is the echo 'Environment setup is complete.' - if not stdout.strip().endswith("Environment setup is complete."): - return False, template.format(stdout) - else: - return True, "Environment setup is complete." - - return True, None - - def create_and_pause_process( runner: Runner = None, process_class: Callable = None, @@ -771,21 +632,25 @@ def validate_task_inout(inout_list: list[str | dict], list_type: str) -> list[di if the former convert them to a list of `dict`s with `name` as the key. :param inout_list: The input/output list to be validated. - :param list_type: "input" or "output" to indicate what is to be validated. 
- :raises TypeError: If a list of mixed or wrong types is provided to the task + :param list_type: "inputs" or "outputs" to indicate what is to be validated for better error message. + :raises TypeError: If wrong types are provided to the task :return: Processed `inputs`/`outputs` list. """ - if all(isinstance(item, str) for item in inout_list): - return [{"name": item} for item in inout_list] - elif all(isinstance(item, dict) for item in inout_list): - return inout_list - elif not all(isinstance(item, dict) for item in inout_list): + if not all(isinstance(item, (dict, str)) for item in inout_list): raise TypeError( - f"Provide either a list of `str` or `dict` as `{list_type}`, not mixed types." + f"Wrong type provided in the `{list_type}` list to the task, must be either `str` or `dict`." ) - else: - raise TypeError(f"Wrong type provided in the `{list_type}` list to the task.") + + processed_inout_list = [] + + for item in inout_list: + if isinstance(item, str): + processed_inout_list.append({"name": item}) + elif isinstance(item, dict): + processed_inout_list.append(item) + + return processed_inout_list def filter_keys_namespace_depth( diff --git a/aiida_workgraph/web/frontend/src/rete/customization.ts b/aiida_workgraph/web/frontend/src/rete/customization.ts deleted file mode 100644 index aec6f6ed..00000000 --- a/aiida_workgraph/web/frontend/src/rete/customization.ts +++ /dev/null @@ -1,97 +0,0 @@ -import { NodeEditor, GetSchemes, ClassicPreset } from 'rete'; - -import { AreaExtensions, AreaPlugin } from 'rete-area-plugin'; -import { - ConnectionPlugin, - Presets as ConnectionPresets, -} from 'rete-connection-plugin'; -import { - ReactPlugin, - ReactArea2D, - Presets as ReactPresets, -} from 'rete-react-plugin'; -import { createRoot } from 'react-dom/client'; - -import { CustomNode } from './customization/CustomNode'; -import { StyledNode } from './customization/StyledNode'; -import { CustomSocket } from './customization/CustomSocket'; -import { CustomConnection } from './customization/CustomConnection'; - -import { addCustomBackground } from './customization/custom-background'; - -type Schemes = GetSchemes< - ClassicPreset.Node, - ClassicPreset.Connection ->; -type AreaExtra = ReactArea2D; - -const socket = new ClassicPreset.Socket('socket'); - -export async function createEditor(container: HTMLElement) { - const editor = new NodeEditor(); - const area = new AreaPlugin(container); - const connection = new ConnectionPlugin(); - const reactRender = new ReactPlugin({ createRoot }); - - AreaExtensions.selectableNodes(area, AreaExtensions.selector(), { - accumulating: AreaExtensions.accumulateOnCtrl(), - }); - - reactRender.addPreset( - ReactPresets.classic.setup({ - customize: { - node(context) { - if (context.payload.label === 'Fully customized') { - return CustomNode; - } - if (context.payload.label === 'Override styles') { - return StyledNode; - } - return ReactPresets.classic.Node; - }, - socket() { - return CustomSocket; - }, - connection() { - return CustomConnection; - }, - }, - }) - ); - - connection.addPreset(ConnectionPresets.classic.setup()); - - addCustomBackground(area); - - editor.use(area); - area.use(connection); - area.use(reactRender); - - // AreaExtensions.simpleNodesOrder(area); - - const aLabel = 'Override styles'; - const bLabel = 'Fully customized'; - - const a = new ClassicPreset.Node(aLabel); - a.addOutput('a', new ClassicPreset.Output(socket)); - a.addInput('a', new ClassicPreset.Input(socket)); - await editor.addNode(a); - - const b = new 
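The reworked `validate_task_inout` above now accepts plain strings, dicts, or a mix of both, as described in the updated docstring. A short usage sketch of the expected behaviour:

```python
# Expected behaviour of the reworked validate_task_inout helper.
from aiida_workgraph.utils import validate_task_inout

assert validate_task_inout(["sum", "diff"], "outputs") == [
    {"name": "sum"},
    {"name": "diff"},
]

# Mixed str/dict lists are now allowed; dict entries pass through unchanged.
assert validate_task_inout(
    ["norm", {"name": "norm2", "identifier": "workgraph.Any"}], "outputs"
) == [{"name": "norm"}, {"name": "norm2", "identifier": "workgraph.Any"}]

# Anything that is neither str nor dict raises a TypeError.
try:
    validate_task_inout([42], "inputs")
except TypeError as exc:
    print(exc)
```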
ClassicPreset.Node(bLabel); - b.addOutput('a', new ClassicPreset.Output(socket)); - b.addInput('a', new ClassicPreset.Input(socket)); - await editor.addNode(b); - - await area.translate(a.id, { x: 0, y: 0 }); - await area.translate(b.id, { x: 300, y: 0 }); - - await editor.addConnection(new ClassicPreset.Connection(a, 'a', b, 'a')); - - setTimeout(() => { - AreaExtensions.zoomAt(area, editor.getNodes()); - }, 300); - - return { - destroy: () => area.destroy(), - }; -} diff --git a/aiida_workgraph/web/frontend/src/rete/customization/CustomConnection.tsx b/aiida_workgraph/web/frontend/src/rete/customization/CustomConnection.tsx deleted file mode 100644 index a5696af2..00000000 --- a/aiida_workgraph/web/frontend/src/rete/customization/CustomConnection.tsx +++ /dev/null @@ -1,36 +0,0 @@ -import * as React from "react"; -import styled from "styled-components"; -import { ClassicScheme, Presets } from "rete-react-plugin"; - -const { useConnection } = Presets.classic; - -const Svg = styled.svg` - overflow: visible !important; - position: absolute; - pointer-events: none; - width: 9999px; - height: 9999px; -`; - -const Path = styled.path<{ styles?: (props: any) => any }>` - fill: none; - stroke-width: 5px; - stroke: black; - pointer-events: auto; - ${(props) => props.styles && props.styles(props)} -`; - -export function CustomConnection(props: { - data: ClassicScheme["Connection"] & { isLoop?: boolean }; - styles?: () => any; -}) { - const { path } = useConnection(); - - if (!path) return null; - - return ( - - - - ); -} diff --git a/aiida_workgraph/web/frontend/src/rete/customization/CustomNode.tsx b/aiida_workgraph/web/frontend/src/rete/customization/CustomNode.tsx deleted file mode 100644 index 49c4901f..00000000 --- a/aiida_workgraph/web/frontend/src/rete/customization/CustomNode.tsx +++ /dev/null @@ -1,187 +0,0 @@ -import * as React from "react"; -import { ClassicScheme, RenderEmit, Presets } from "rete-react-plugin"; -import styled, { css } from "styled-components"; -import { $nodewidth, $socketmargin, $socketsize } from "./vars"; - -const { RefSocket, RefControl } = Presets.classic; - -type NodeExtraData = { width?: number; height?: number }; - -export const NodeStyles = styled.div< - NodeExtraData & { selected: boolean; styles?: (props: any) => any } ->` - background: black; - border: 2px solid grey; - border-radius: 10px; - cursor: pointer; - box-sizing: border-box; - width: ${(props) => - Number.isFinite(props.width) ? `${props.width}px` : `${$nodewidth}px`}; - height: ${(props) => - Number.isFinite(props.height) ? 
`${props.height}px` : "auto"}; - padding-bottom: 6px; - position: relative; - user-select: none; - &:hover { - background: #333; - } - ${(props) => - props.selected && - css` - border-color: red; - `} - .title { - color: white; - font-family: sans-serif; - font-size: 18px; - padding: 8px; - } - .output { - text-align: right; - } - .input { - text-align: left; - } - .output-socket { - text-align: right; - margin-right: -1px; - display: inline-block; - } - .input-socket { - text-align: left; - margin-left: -1px; - display: inline-block; - } - .input-title, - .output-title { - vertical-align: middle; - color: white; - display: inline-block; - font-family: sans-serif; - font-size: 14px; - margin: ${$socketmargin}px; - line-height: ${$socketsize}px; - } - .input-control { - z-index: 1; - width: calc(100% - ${$socketsize + 2 * $socketmargin}px); - vertical-align: middle; - display: inline-block; - } - .control { - display: block; - padding: ${$socketmargin}px ${$socketsize / 2 + $socketmargin}px; - } - ${(props) => props.styles && props.styles(props)} -`; - -function sortByIndex( - entries: T -) { - entries.sort((a, b) => { - const ai = a[1]?.index || 0; - const bi = b[1]?.index || 0; - - return ai - bi; - }); -} - -type Props = { - data: S["Node"] & NodeExtraData; - styles?: () => any; - emit: RenderEmit; -}; -export type NodeComponent = ( - props: Props -) => JSX.Element; - -export function CustomNode(props: Props) { - const inputs = Object.entries(props.data.inputs); - const outputs = Object.entries(props.data.outputs); - const controls = Object.entries(props.data.controls); - const selected = props.data.selected || false; - const { id, label, width, height } = props.data; - - sortByIndex(inputs); - sortByIndex(outputs); - sortByIndex(controls); - - return ( - -
-    <NodeStyles
-      selected={selected}
-      width={width}
-      height={height}
-      styles={props.styles}
-      data-testid="node"
-    >
-      <div className="title" data-testid="title">
-        {label}
-      </div>
-      {/* Outputs */}
-      {outputs.map(
-        ([key, output]) =>
-          output && (
-            <div className="output" key={key} data-testid={`output-${key}`}>
-              <div className="output-title" data-testid="output-title">
-                {output?.label}
-              </div>
-              <RefSocket
-                name="output-socket"
-                side="output"
-                emit={props.emit}
-                socketKey={key}
-                nodeId={id}
-                payload={output.socket}
-                data-testid="output-socket"
-              />
-            </div>
-          )
-      )}
-      {/* Controls */}
-      {controls.map(([key, control]) => {
-        return control ? (
-          <RefControl
-            key={key}
-            name="control"
-            emit={props.emit}
-            payload={control}
-            data-testid={`control-${key}`}
-          />
-        ) : null;
-      })}
-      {/* Inputs */}
-      {inputs.map(
-        ([key, input]) =>
-          input && (
-            <div className="input" key={key} data-testid={`input-${key}`}>
-              <RefSocket
-                name="input-socket"
-                side="input"
-                emit={props.emit}
-                socketKey={key}
-                nodeId={id}
-                payload={input.socket}
-                data-testid="input-socket"
-              />
-              {input && (!input.control || !input.showControl) && (
-                <div className="input-title" data-testid="input-title">
-                  {input?.label}
-                </div>
-              )}
-              {input?.control && input?.showControl && (
-                <span className="input-control">
-                  <RefControl
-                    key={key}
-                    name="input-control"
-                    emit={props.emit}
-                    payload={input.control}
-                    data-testid="input-control"
-                  />
-                </span>
-              )}
-            </div>
-          )
-      )}
-    </NodeStyles>
- ); -} diff --git a/aiida_workgraph/web/frontend/src/rete/customization/CustomSocket.tsx b/aiida_workgraph/web/frontend/src/rete/customization/CustomSocket.tsx deleted file mode 100644 index d4f7c83d..00000000 --- a/aiida_workgraph/web/frontend/src/rete/customization/CustomSocket.tsx +++ /dev/null @@ -1,25 +0,0 @@ -import * as React from "react"; -import { ClassicPreset } from "rete"; -import styled from "styled-components"; -import { $socketsize } from "./vars"; - -const Styles = styled.div` - display: inline-block; - cursor: pointer; - border: 1px solid grey; - width: ${$socketsize}px; - height: ${$socketsize * 2}px; - vertical-align: middle; - background: #fff; - z-index: 2; - box-sizing: border-box; - &:hover { - background: #ddd; - } -`; - -export function CustomSocket(props: { - data: T; -}) { - return ; -} diff --git a/aiida_workgraph/web/frontend/src/rete/customization/StyledNode.tsx b/aiida_workgraph/web/frontend/src/rete/customization/StyledNode.tsx deleted file mode 100644 index cf4e836a..00000000 --- a/aiida_workgraph/web/frontend/src/rete/customization/StyledNode.tsx +++ /dev/null @@ -1,28 +0,0 @@ -import { Presets } from "rete-react-plugin"; -import { css } from "styled-components"; - -const styles = css<{ selected?: boolean }>` - background: #ebebeb; - border-color: #646464; - .title { - color: #646464; - } - &:hover { - background: #f2f2f2; - } - .output-socket { - margin-right: -1px; - } - .input-socket { - margin-left: -1px; - } - ${(props) => - props.selected && - css` - border-color: red; - `} -`; - -export function StyledNode(props: any) { - return styles} {...props} />; -} diff --git a/aiida_workgraph/web/frontend/src/rete/customization/background.css b/aiida_workgraph/web/frontend/src/rete/customization/background.css deleted file mode 100644 index ff1181f4..00000000 --- a/aiida_workgraph/web/frontend/src/rete/customization/background.css +++ /dev/null @@ -1,21 +0,0 @@ -.fill-area { - display: table; - z-index: -1; - position: absolute; - top: -320000px; - left: -320000px; - width: 640000px; - height: 640000px; -} - -.background { - background-color: #ffffff; - opacity: 1; - background-image: linear-gradient(#f1f1f1 3.2px, transparent 3.2px), - linear-gradient(90deg, #f1f1f1 3.2px, transparent 3.2px), - linear-gradient(#f1f1f1 1.6px, transparent 1.6px), - linear-gradient(90deg, #f1f1f1 1.6px, #ffffff 1.6px); - background-size: 80px 80px, 80px 80px, 16px 16px, 16px 16px; - background-position: -3.2px -3.2px, -3.2px -3.2px, -1.6px -1.6px, - -1.6px -1.6px; -} diff --git a/aiida_workgraph/web/frontend/src/rete/customization/custom-background.ts b/aiida_workgraph/web/frontend/src/rete/customization/custom-background.ts deleted file mode 100644 index 13861856..00000000 --- a/aiida_workgraph/web/frontend/src/rete/customization/custom-background.ts +++ /dev/null @@ -1,14 +0,0 @@ -import { BaseSchemes } from "rete"; -import { AreaPlugin } from "rete-area-plugin"; -import './background.css' - -export function addCustomBackground( - area: AreaPlugin -) { - const background = document.createElement("div"); - - background.classList.add("background"); - background.classList.add("fill-area"); - - area.area.content.add(background); -} diff --git a/aiida_workgraph/web/frontend/src/rete/customization/vars.ts b/aiida_workgraph/web/frontend/src/rete/customization/vars.ts deleted file mode 100644 index c79b882e..00000000 --- a/aiida_workgraph/web/frontend/src/rete/customization/vars.ts +++ /dev/null @@ -1,3 +0,0 @@ -export const $nodewidth = 200; -export const $socketmargin = 6; -export 
const $socketsize = 16; diff --git a/aiida_workgraph/web/frontend/src/rete/index.ts b/aiida_workgraph/web/frontend/src/rete/index.ts deleted file mode 100644 index ba4ed726..00000000 --- a/aiida_workgraph/web/frontend/src/rete/index.ts +++ /dev/null @@ -1,20 +0,0 @@ -import { createEditor as createDefaultEditor } from './default' -import { createEditor as createCustomEditor } from './customization' - -const factory = { - 'default': createDefaultEditor, - 'customization': createCustomEditor, -} -// eslint-disable-next-line no-restricted-globals, no-undef -const query = typeof location !== 'undefined' && new URLSearchParams(location.search) -const name = ((query && query.get('template')) || 'default') as keyof typeof factory - -const createEditor = factory[name] - -if (!createEditor) { - throw new Error(`template with name ${name} not found`) -} - -export { - createEditor -} diff --git a/docs/gallery/autogen/quick_start.py b/docs/gallery/autogen/quick_start.py index 2417c8c5..91eab711 100644 --- a/docs/gallery/autogen/quick_start.py +++ b/docs/gallery/autogen/quick_start.py @@ -168,14 +168,14 @@ def multiply(x, y): wg = WorkGraph("second_workflow") -# You might need to adapt the code_label to python3 if you use this as your default python -wg.add_task("PythonJob", function=add, name="add", code_label="python") +# You might need to adapt the label to python3 if you use this as your default python +wg.add_task("PythonJob", function=add, name="add", command_info={"label": "python"}) wg.add_task( "PythonJob", function=multiply, name="multiply", x=wg.tasks["add"].outputs[0], - code_label="python", + command_info={"label": "python"}, ) # export the workgraph to html file so that it can be visualized in a browser @@ -195,7 +195,7 @@ def multiply(x, y): # # **Data**: Users can (and is recoomaneded) use normal Python data as # input. The workgraph will transfer the data to AiiDA data -# (``GeneralData``) using pickle. +# (``PickledData``) using pickle. # # **Python Version**: since pickle is used to store and load data, the # Python version on the remote computer should match the one used in the @@ -275,14 +275,14 @@ def multiply(x, y): wg = WorkGraph("third_workflow") -# You might need to adapt the code_label to python3 if you use this as your default python -wg.add_task("PythonJob", function=add, name="add", code_label="python") +# You might need to adapt the label to python3 if you use this as your default python +wg.add_task("PythonJob", function=add, name="add", command_info={"label": "python"}) wg.add_task( "PythonJob", function=multiply, name="multiply", parent_folder=wg.tasks["add"].outputs["remote_folder"], - code_label="python", + command_info={"label": "python"}, ) wg.to_html() diff --git a/docs/gallery/concept/autogen/task.py b/docs/gallery/concept/autogen/task.py index 0890fd4f..35eec338 100644 --- a/docs/gallery/concept/autogen/task.py +++ b/docs/gallery/concept/autogen/task.py @@ -54,15 +54,27 @@ def multiply(x, y): ###################################################################### # If you want to change the name of the output ports, or if there are more # than one output. You can define the outputs explicitly. For example: -# ``{"name": "sum", "identifier": "workgraph.Any"}``, where the ``identifier`` -# indicates the data type. The data type tell the code how to display the -# port in the GUI, validate the data, and serialize data into database. We -# use ``workgraph.Any`` for any data type. 
For the moment, the data validation is + + +# define the outputs explicitly +@task(outputs=["sum", "diff"]) +def add_minus(x, y): + return {"sum": x + y, "difference": x - y} + + +print("Inputs:", add_minus.task().inputs.keys()) +print("Outputs:", add_minus.task().outputs.keys()) + +###################################################################### +# One can also add an ``identifier`` to indicates the data type. The data +# type tell the code how to display the port in the GUI, validate the data, +# and serialize data into database. +# We use ``workgraph.Any`` for any data type. For the moment, the data validation is # experimentally supported, and the GUI display is not implemented. Thus, # I suggest you to always ``workgraph.Any`` for the port. # -# define add calcfunction task +# define the outputs with identifier @task( outputs=[ {"name": "sum", "identifier": "workgraph.Any"}, @@ -73,10 +85,6 @@ def add_minus(x, y): return {"sum": x + y, "difference": x - y} -print("Inputs:", add_minus.task().inputs.keys()) -print("Outputs:", add_minus.task().outputs.keys()) - - ###################################################################### # Then, one can use the task inside the WorkGraph: # @@ -125,9 +133,32 @@ def add_minus(x, y): if "." not in output.name: print(f" - {output.name}") +###################################################################### +# For specifying the outputs, the most explicit way is to provide a list of dictionaries, as shown above. In addition, +# as a shortcut, it is also possible to pass a list of strings. In that case, WorkGraph will internally convert the list +# of strings into a list of dictionaries in which case, each ``name`` key will be assigned each passed string value. +# Furthermore, also a mixed list of string and dict elements can be passed, which can be useful in cases where multiple +# outputs should be specified, but more detailed properties are only required for some of the outputs. The above also +# applies for the ``outputs`` argument of the ``@task`` decorator introduced earlier, as well as the ``inputs``, given +# that they are explicitly specified rather than derived from the signature of the ``Callable``. Finally, all lines +# below are valid specifiers for the ``outputs`` of the ``build_task`: +# + +NormTask = build_task(norm, outputs=["norm"]) +NormTask = build_task(norm, outputs=["norm", "norm2"]) +NormTask = build_task( + norm, outputs=["norm", {"name": "norm2", "identifier": "workgraph.Any"}] +) +NormTask = build_task( + norm, + outputs=[ + {"name": "norm", "identifier": "workgraph.Any"}, + {"name": "norm2", "identifier": "workgraph.Any"}, + ], +) ###################################################################### -# One can use these AiiDA component direclty in the WorkGraph. The inputs +# One can use these AiiDA component directly in the WorkGraph. The inputs # and outputs of the task is automatically generated based on the input # and output port of the AiiDA component. In case of ``calcfunction``, the # default output is ``result``. 
If there are more than one output task, diff --git a/docs/gallery/howto/autogen/aggregate.py b/docs/gallery/howto/autogen/aggregate.py index 96ef206f..fc3138fc 100644 --- a/docs/gallery/howto/autogen/aggregate.py +++ b/docs/gallery/howto/autogen/aggregate.py @@ -255,7 +255,7 @@ def generator_loop(nb_iterations: Int): wg = WorkGraph() for i in range(nb_iterations.value): # this can be chosen as wanted generator_task = wg.add_task(generator, name=f"generator{i}", seed=Int(i)) - generator_task.set_context({"result": f"generated.seed{i}"}) + generator_task.set_context({f"generated.seed{i}": "result"}) return wg diff --git a/docs/gallery/howto/autogen/graph_builder.py b/docs/gallery/howto/autogen/graph_builder.py index 89ec4cb3..f67d1a04 100644 --- a/docs/gallery/howto/autogen/graph_builder.py +++ b/docs/gallery/howto/autogen/graph_builder.py @@ -197,7 +197,7 @@ def for_loop(nb_iterations: Int): # of the graph builder decorator. # Put result of the task to the context under the name task_out - task.set_context({"result": "task_out"}) + task.set_context({"task_out": "result"}) # If want to know more about the usage of the context please refer to the # context howto in the documentation return wg @@ -244,7 +244,7 @@ def if_then_else(i: Int): task = wg.add_task(modulo_two, x=i) # same concept as before, please read the for loop example for explanation - task.set_context({"result": "task_out"}) + task.set_context({"task_out": "result"}) return wg diff --git a/docs/gallery/howto/autogen/parallel.py b/docs/gallery/howto/autogen/parallel.py index 97ce88a3..240c22b4 100644 --- a/docs/gallery/howto/autogen/parallel.py +++ b/docs/gallery/howto/autogen/parallel.py @@ -94,7 +94,7 @@ def multiply_parallel_gather(X, y): multiply1 = wg.add_task(multiply, x=value, y=y) # add result of multiply1 to `self.context.mul` # self.context.mul is a dict {"a": value1, "b": value2, "c": value3} - multiply1.set_context({"result": f"mul.{key}"}) + multiply1.set_context({f"mul.{key}": "result"}) return wg diff --git a/docs/gallery/tutorial/autogen/eos.py b/docs/gallery/tutorial/autogen/eos.py index 8d88a42a..ae457d6b 100644 --- a/docs/gallery/tutorial/autogen/eos.py +++ b/docs/gallery/tutorial/autogen/eos.py @@ -54,7 +54,7 @@ def all_scf(structures, scf_inputs): pw1 = wg.add_task(PwCalculation, name=f"pw1_{key}", structure=structure) pw1.set(scf_inputs) # save the output parameters to the context - pw1.set_context({"output_parameters": f"result.{key}"}) + pw1.set_context({f"result.{key}": "output_parameters"}) return wg diff --git a/docs/gallery/tutorial/autogen/zero_to_hero.py b/docs/gallery/tutorial/autogen/zero_to_hero.py index 664ca9a2..151e1981 100644 --- a/docs/gallery/tutorial/autogen/zero_to_hero.py +++ b/docs/gallery/tutorial/autogen/zero_to_hero.py @@ -426,7 +426,7 @@ def all_scf(structures, scf_inputs): pw1 = wg.add_task(PwCalculation, name=f"pw1_{key}", structure=structure) pw1.set(scf_inputs) # save the output parameters to the context - pw1.set_context({"output_parameters": f"result.{key}"}) + pw1.set_context({f"result.{key}": "output_parameters"}) return wg diff --git a/docs/source/blog/aiida_python.ipynb b/docs/source/blog/aiida_python.ipynb index 63ae3470..2ff5f64f 100644 --- a/docs/source/blog/aiida_python.ipynb +++ b/docs/source/blog/aiida_python.ipynb @@ -122,7 +122,7 @@ "\n", "**Computer**: Users can designate the remote computer where the job will be executed. 
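All of the documentation updates above reflect the reversed `set_context` mapping introduced in `task.py`: the dictionary now maps a context key to the name of the task output stored under it. A minimal sketch, assuming the `add` calcfunction pattern used elsewhere in the docs:

```python
# Minimal sketch of the new set_context convention: {context_key: output_name}.
from aiida_workgraph import WorkGraph, task


@task.calcfunction()
def add(x, y):
    return x + y


wg = WorkGraph("set_context_example")
add1 = wg.add_task(add, name="add1", x=1, y=2)
# Store the task's default "result" output under ctx.sum1 once it finishes;
# before this change the mapping was written the other way around.
add1.set_context({"sum1": "result"})
```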
This action will create an AiiDA code, `python3@computer`, if it does not already exist.\n", "\n", - "**Data**: It is recommended that users employ standard Python data types as inputs. The WorkGraph is responsible for transferring and serializing this data to AiiDA-compatible formats. During serialization, the WorkGraph searches for a corresponding AiiDA data entry point based on the module and class names (e.g., `ase.atoms.Atoms`). If an appropriate entry point is found, it is utilized for serialization. If no entry point is found, the data is serialized into binary format using GeneralData (pickle).\n", + "**Data**: It is recommended that users employ standard Python data types as inputs. The WorkGraph is responsible for transferring and serializing this data to AiiDA-compatible formats. During serialization, the WorkGraph searches for a corresponding AiiDA data entry point based on the module and class names (e.g., `ase.atoms.Atoms`). If an appropriate entry point is found, it is utilized for serialization. If no entry point is found, the data is serialized into binary format using PickledData (pickle).\n", "\n", "**Python Version**: To ensure compatibility, the Python version on the remote computer should match the version used on the localhost. Users can create a matching virtual environment using Conda. It's essential to activate this environment prior to executing the script.\n", "\n", diff --git a/docs/source/built-in/html/PythonJob_parent_folder.html b/docs/source/built-in/html/PythonJob_parent_folder.html index 033b89ea..38cde52e 100644 --- a/docs/source/built-in/html/PythonJob_parent_folder.html +++ b/docs/source/built-in/html/PythonJob_parent_folder.html @@ -61,7 +61,7 @@ const { RenderUtils } = ReteRenderUtils; const styled = window.styled; - const workgraphData = {"name": "PythonJob_parent_folder", "uuid": "429664fa-71ab-11ef-9079-906584de3e5b", "state": "CREATED", "nodes": {"add": {"label": "add", "node_type": "PYTHONJOB", "inputs": [{"name": "x"}, {"name": "y"}], "properties": {"x": null, "y": null, "_wait": null, "computer": null, "code_label": null, "code_path": null, "prepend_text": null, "metadata": null, "metadata.store_provenance": null, "metadata.description": null, "metadata.label": null, "metadata.call_link_label": null, "metadata.disable_cache": null, "metadata.dry_run": null, "metadata.computer": null, "metadata.options": null, "metadata.options.input_filename": null, "metadata.options.output_filename": null, "metadata.options.submit_script_filename": null, "metadata.options.scheduler_stdout": null, "metadata.options.scheduler_stderr": null, "metadata.options.resources": null, "metadata.options.max_wallclock_seconds": null, "metadata.options.custom_scheduler_commands": null, "metadata.options.queue_name": null, "metadata.options.rerunnable": null, "metadata.options.account": null, "metadata.options.qos": null, "metadata.options.withmpi": null, "metadata.options.mpirun_extra_params": null, "metadata.options.import_sys_environment": null, "metadata.options.environment_variables": null, "metadata.options.environment_variables_double_quotes": null, "metadata.options.priority": null, "metadata.options.max_memory_kb": null, "metadata.options.prepend_text": null, "metadata.options.append_text": null, "metadata.options.parser_name": null, "metadata.options.additional_retrieve_list": null, "metadata.options.stash": null, "metadata.options.stash.target_base": null, "metadata.options.stash.source_list": null, "metadata.options.stash.stash_mode": null, "code": null, 
"monitors": null, "remote_folder": null, "function": null, "function_source_code": null, "function_name": null, "process_label": null, "function_kwargs": null, "function_outputs": null, "parent_folder": null, "parent_folder_name": null, "parent_output_folder": null, "upload_files": null, "copy_files": null, "additional_retrieve_list": null}, "outputs": [{"name": "remote_folder"}], "position": [30, 30], "children": []}, "multiply": {"label": "multiply", "node_type": "PYTHONJOB", "inputs": [{"name": "x"}, {"name": "y"}, {"name": "parent_folder"}], "properties": {"x": null, "y": null, "_wait": null, "computer": null, "code_label": null, "code_path": null, "prepend_text": null, "metadata": null, "metadata.store_provenance": null, "metadata.description": null, "metadata.label": null, "metadata.call_link_label": null, "metadata.disable_cache": null, "metadata.dry_run": null, "metadata.computer": null, "metadata.options": null, "metadata.options.input_filename": null, "metadata.options.output_filename": null, "metadata.options.submit_script_filename": null, "metadata.options.scheduler_stdout": null, "metadata.options.scheduler_stderr": null, "metadata.options.resources": null, "metadata.options.max_wallclock_seconds": null, "metadata.options.custom_scheduler_commands": null, "metadata.options.queue_name": null, "metadata.options.rerunnable": null, "metadata.options.account": null, "metadata.options.qos": null, "metadata.options.withmpi": null, "metadata.options.mpirun_extra_params": null, "metadata.options.import_sys_environment": null, "metadata.options.environment_variables": null, "metadata.options.environment_variables_double_quotes": null, "metadata.options.priority": null, "metadata.options.max_memory_kb": null, "metadata.options.prepend_text": null, "metadata.options.append_text": null, "metadata.options.parser_name": null, "metadata.options.additional_retrieve_list": null, "metadata.options.stash": null, "metadata.options.stash.target_base": null, "metadata.options.stash.source_list": null, "metadata.options.stash.stash_mode": null, "code": null, "monitors": null, "remote_folder": null, "function": null, "function_source_code": null, "function_name": null, "process_label": null, "function_kwargs": null, "function_outputs": null, "parent_folder": null, "parent_folder_name": null, "parent_output_folder": null, "upload_files": null, "copy_files": null, "additional_retrieve_list": null}, "outputs": [], "position": [60, 60], "children": []}}, "links": [{"from_socket": "remote_folder", "from_node": "add", "to_socket": "parent_folder", "to_node": "multiply", "state": false}]} + const workgraphData = {"name": "PythonJob_parent_folder", "uuid": "429664fa-71ab-11ef-9079-906584de3e5b", "state": "CREATED", "nodes": {"add": {"label": "add", "node_type": "PYTHONJOB", "inputs": [{"name": "x"}, {"name": "y"}], "properties": {"x": null, "y": null, "_wait": null, "computer": null, "code_label": null, "code_path": null, "prepend_text": null, "metadata": null, "metadata.store_provenance": null, "metadata.description": null, "metadata.label": null, "metadata.call_link_label": null, "metadata.disable_cache": null, "metadata.dry_run": null, "metadata.computer": null, "metadata.options": null, "metadata.options.input_filename": null, "metadata.options.output_filename": null, "metadata.options.submit_script_filename": null, "metadata.options.scheduler_stdout": null, "metadata.options.scheduler_stderr": null, "metadata.options.resources": null, "metadata.options.max_wallclock_seconds": null, 
"metadata.options.custom_scheduler_commands": null, "metadata.options.queue_name": null, "metadata.options.rerunnable": null, "metadata.options.account": null, "metadata.options.qos": null, "metadata.options.withmpi": null, "metadata.options.mpirun_extra_params": null, "metadata.options.import_sys_environment": null, "metadata.options.environment_variables": null, "metadata.options.environment_variables_double_quotes": null, "metadata.options.priority": null, "metadata.options.max_memory_kb": null, "metadata.options.prepend_text": null, "metadata.options.append_text": null, "metadata.options.parser_name": null, "metadata.options.additional_retrieve_list": null, "metadata.options.stash": null, "metadata.options.stash.target_base": null, "metadata.options.stash.source_list": null, "metadata.options.stash.stash_mode": null, "code": null, "monitors": null, "remote_folder": null, "function": null, "function_source_code": null, "function_name": null, "process_label": null, "function_inputs": null, "function_outputs": null, "parent_folder": null, "parent_folder_name": null, "parent_output_folder": null, "upload_files": null, "copy_files": null, "additional_retrieve_list": null}, "outputs": [{"name": "remote_folder"}], "position": [30, 30], "children": []}, "multiply": {"label": "multiply", "node_type": "PYTHONJOB", "inputs": [{"name": "x"}, {"name": "y"}, {"name": "parent_folder"}], "properties": {"x": null, "y": null, "_wait": null, "computer": null, "code_label": null, "code_path": null, "prepend_text": null, "metadata": null, "metadata.store_provenance": null, "metadata.description": null, "metadata.label": null, "metadata.call_link_label": null, "metadata.disable_cache": null, "metadata.dry_run": null, "metadata.computer": null, "metadata.options": null, "metadata.options.input_filename": null, "metadata.options.output_filename": null, "metadata.options.submit_script_filename": null, "metadata.options.scheduler_stdout": null, "metadata.options.scheduler_stderr": null, "metadata.options.resources": null, "metadata.options.max_wallclock_seconds": null, "metadata.options.custom_scheduler_commands": null, "metadata.options.queue_name": null, "metadata.options.rerunnable": null, "metadata.options.account": null, "metadata.options.qos": null, "metadata.options.withmpi": null, "metadata.options.mpirun_extra_params": null, "metadata.options.import_sys_environment": null, "metadata.options.environment_variables": null, "metadata.options.environment_variables_double_quotes": null, "metadata.options.priority": null, "metadata.options.max_memory_kb": null, "metadata.options.prepend_text": null, "metadata.options.append_text": null, "metadata.options.parser_name": null, "metadata.options.additional_retrieve_list": null, "metadata.options.stash": null, "metadata.options.stash.target_base": null, "metadata.options.stash.source_list": null, "metadata.options.stash.stash_mode": null, "code": null, "monitors": null, "remote_folder": null, "function": null, "function_source_code": null, "function_name": null, "process_label": null, "function_inputs": null, "function_outputs": null, "parent_folder": null, "parent_folder_name": null, "parent_output_folder": null, "upload_files": null, "copy_files": null, "additional_retrieve_list": null}, "outputs": [], "position": [60, 60], "children": []}}, "links": [{"from_socket": "remote_folder", "from_node": "add", "to_socket": "parent_folder", "to_node": "multiply", "state": false}]} // Define Schemes to use in vanilla JS const Schemes = { diff --git 
a/docs/source/built-in/html/PythonJob_shell_command.html b/docs/source/built-in/html/PythonJob_shell_command.html index 20579f7f..6bbfb95a 100644 --- a/docs/source/built-in/html/PythonJob_shell_command.html +++ b/docs/source/built-in/html/PythonJob_shell_command.html @@ -61,7 +61,7 @@ const { RenderUtils } = ReteRenderUtils; const styled = window.styled; - const workgraphData = {"name": "PythonJob_shell_command", "uuid": "57729aec-71ab-11ef-9079-906584de3e5b", "state": "CREATED", "nodes": {"add": {"label": "add", "node_type": "PYTHONJOB", "inputs": [{"name": "x"}, {"name": "y"}], "properties": {"x": null, "y": null, "_wait": null, "computer": null, "code_label": null, "code_path": null, "prepend_text": null, "metadata": null, "metadata.store_provenance": null, "metadata.description": null, "metadata.label": null, "metadata.call_link_label": null, "metadata.disable_cache": null, "metadata.dry_run": null, "metadata.computer": null, "metadata.options": null, "metadata.options.input_filename": null, "metadata.options.output_filename": null, "metadata.options.submit_script_filename": null, "metadata.options.scheduler_stdout": null, "metadata.options.scheduler_stderr": null, "metadata.options.resources": null, "metadata.options.max_wallclock_seconds": null, "metadata.options.custom_scheduler_commands": null, "metadata.options.queue_name": null, "metadata.options.rerunnable": null, "metadata.options.account": null, "metadata.options.qos": null, "metadata.options.withmpi": null, "metadata.options.mpirun_extra_params": null, "metadata.options.import_sys_environment": null, "metadata.options.environment_variables": null, "metadata.options.environment_variables_double_quotes": null, "metadata.options.priority": null, "metadata.options.max_memory_kb": null, "metadata.options.prepend_text": null, "metadata.options.append_text": null, "metadata.options.parser_name": null, "metadata.options.additional_retrieve_list": null, "metadata.options.stash": null, "metadata.options.stash.target_base": null, "metadata.options.stash.source_list": null, "metadata.options.stash.stash_mode": null, "code": null, "monitors": null, "remote_folder": null, "function": null, "function_source_code": null, "function_name": null, "process_label": null, "function_kwargs": null, "function_outputs": null, "parent_folder": null, "parent_folder_name": null, "parent_output_folder": null, "upload_files": null, "copy_files": null, "additional_retrieve_list": null}, "outputs": [{"name": "result"}], "position": [30, 30], "children": []}, "multiply": {"label": "multiply", "node_type": "PYTHONJOB", "inputs": [{"name": "x"}, {"name": "y"}, {"name": "x"}], "properties": {"x": null, "y": null, "_wait": null, "computer": null, "code_label": null, "code_path": null, "prepend_text": null, "metadata": null, "metadata.store_provenance": null, "metadata.description": null, "metadata.label": null, "metadata.call_link_label": null, "metadata.disable_cache": null, "metadata.dry_run": null, "metadata.computer": null, "metadata.options": null, "metadata.options.input_filename": null, "metadata.options.output_filename": null, "metadata.options.submit_script_filename": null, "metadata.options.scheduler_stdout": null, "metadata.options.scheduler_stderr": null, "metadata.options.resources": null, "metadata.options.max_wallclock_seconds": null, "metadata.options.custom_scheduler_commands": null, "metadata.options.queue_name": null, "metadata.options.rerunnable": null, "metadata.options.account": null, "metadata.options.qos": null, "metadata.options.withmpi": 
null, "metadata.options.mpirun_extra_params": null, "metadata.options.import_sys_environment": null, "metadata.options.environment_variables": null, "metadata.options.environment_variables_double_quotes": null, "metadata.options.priority": null, "metadata.options.max_memory_kb": null, "metadata.options.prepend_text": null, "metadata.options.append_text": null, "metadata.options.parser_name": null, "metadata.options.additional_retrieve_list": null, "metadata.options.stash": null, "metadata.options.stash.target_base": null, "metadata.options.stash.source_list": null, "metadata.options.stash.stash_mode": null, "code": null, "monitors": null, "remote_folder": null, "function": null, "function_source_code": null, "function_name": null, "process_label": null, "function_kwargs": null, "function_outputs": null, "parent_folder": null, "parent_folder_name": null, "parent_output_folder": null, "upload_files": null, "copy_files": null, "additional_retrieve_list": null}, "outputs": [], "position": [60, 60], "children": []}}, "links": [{"from_socket": "result", "from_node": "add", "to_socket": "x", "to_node": "multiply", "state": false}]} + const workgraphData = {"name": "PythonJob_shell_command", "uuid": "57729aec-71ab-11ef-9079-906584de3e5b", "state": "CREATED", "nodes": {"add": {"label": "add", "node_type": "PYTHONJOB", "inputs": [{"name": "x"}, {"name": "y"}], "properties": {"x": null, "y": null, "_wait": null, "computer": null, "code_label": null, "code_path": null, "prepend_text": null, "metadata": null, "metadata.store_provenance": null, "metadata.description": null, "metadata.label": null, "metadata.call_link_label": null, "metadata.disable_cache": null, "metadata.dry_run": null, "metadata.computer": null, "metadata.options": null, "metadata.options.input_filename": null, "metadata.options.output_filename": null, "metadata.options.submit_script_filename": null, "metadata.options.scheduler_stdout": null, "metadata.options.scheduler_stderr": null, "metadata.options.resources": null, "metadata.options.max_wallclock_seconds": null, "metadata.options.custom_scheduler_commands": null, "metadata.options.queue_name": null, "metadata.options.rerunnable": null, "metadata.options.account": null, "metadata.options.qos": null, "metadata.options.withmpi": null, "metadata.options.mpirun_extra_params": null, "metadata.options.import_sys_environment": null, "metadata.options.environment_variables": null, "metadata.options.environment_variables_double_quotes": null, "metadata.options.priority": null, "metadata.options.max_memory_kb": null, "metadata.options.prepend_text": null, "metadata.options.append_text": null, "metadata.options.parser_name": null, "metadata.options.additional_retrieve_list": null, "metadata.options.stash": null, "metadata.options.stash.target_base": null, "metadata.options.stash.source_list": null, "metadata.options.stash.stash_mode": null, "code": null, "monitors": null, "remote_folder": null, "function": null, "function_source_code": null, "function_name": null, "process_label": null, "function_inputs": null, "function_outputs": null, "parent_folder": null, "parent_folder_name": null, "parent_output_folder": null, "upload_files": null, "copy_files": null, "additional_retrieve_list": null}, "outputs": [{"name": "result"}], "position": [30, 30], "children": []}, "multiply": {"label": "multiply", "node_type": "PYTHONJOB", "inputs": [{"name": "x"}, {"name": "y"}, {"name": "x"}], "properties": {"x": null, "y": null, "_wait": null, "computer": null, "code_label": null, "code_path": null, 
"prepend_text": null, "metadata": null, "metadata.store_provenance": null, "metadata.description": null, "metadata.label": null, "metadata.call_link_label": null, "metadata.disable_cache": null, "metadata.dry_run": null, "metadata.computer": null, "metadata.options": null, "metadata.options.input_filename": null, "metadata.options.output_filename": null, "metadata.options.submit_script_filename": null, "metadata.options.scheduler_stdout": null, "metadata.options.scheduler_stderr": null, "metadata.options.resources": null, "metadata.options.max_wallclock_seconds": null, "metadata.options.custom_scheduler_commands": null, "metadata.options.queue_name": null, "metadata.options.rerunnable": null, "metadata.options.account": null, "metadata.options.qos": null, "metadata.options.withmpi": null, "metadata.options.mpirun_extra_params": null, "metadata.options.import_sys_environment": null, "metadata.options.environment_variables": null, "metadata.options.environment_variables_double_quotes": null, "metadata.options.priority": null, "metadata.options.max_memory_kb": null, "metadata.options.prepend_text": null, "metadata.options.append_text": null, "metadata.options.parser_name": null, "metadata.options.additional_retrieve_list": null, "metadata.options.stash": null, "metadata.options.stash.target_base": null, "metadata.options.stash.source_list": null, "metadata.options.stash.stash_mode": null, "code": null, "monitors": null, "remote_folder": null, "function": null, "function_source_code": null, "function_name": null, "process_label": null, "function_inputs": null, "function_outputs": null, "parent_folder": null, "parent_folder_name": null, "parent_output_folder": null, "upload_files": null, "copy_files": null, "additional_retrieve_list": null}, "outputs": [], "position": [60, 60], "children": []}}, "links": [{"from_socket": "result", "from_node": "add", "to_socket": "x", "to_node": "multiply", "state": false}]} // Define Schemes to use in vanilla JS const Schemes = { diff --git a/docs/source/built-in/html/atomization_energy.html b/docs/source/built-in/html/atomization_energy.html index fb8db85f..b837fcff 100644 --- a/docs/source/built-in/html/atomization_energy.html +++ b/docs/source/built-in/html/atomization_energy.html @@ -61,7 +61,7 @@ const { RenderUtils } = ReteRenderUtils; const styled = window.styled; - const workgraphData = {"name": "atomization_energy", "uuid": "4e03aad2-71ab-11ef-9079-906584de3e5b", "state": "CREATED", "nodes": {"emt_atom": {"label": "emt_atom", "node_type": "PYTHONJOB", "inputs": [{"name": "atoms"}], "properties": {"atoms": null, "_wait": null, "computer": null, "code_label": null, "code_path": null, "prepend_text": null, "metadata": null, "metadata.store_provenance": null, "metadata.description": null, "metadata.label": null, "metadata.call_link_label": null, "metadata.disable_cache": null, "metadata.dry_run": null, "metadata.computer": null, "metadata.options": null, "metadata.options.input_filename": null, "metadata.options.output_filename": null, "metadata.options.submit_script_filename": null, "metadata.options.scheduler_stdout": null, "metadata.options.scheduler_stderr": null, "metadata.options.resources": null, "metadata.options.max_wallclock_seconds": null, "metadata.options.custom_scheduler_commands": null, "metadata.options.queue_name": null, "metadata.options.rerunnable": null, "metadata.options.account": null, "metadata.options.qos": null, "metadata.options.withmpi": null, "metadata.options.mpirun_extra_params": null, 
"metadata.options.import_sys_environment": null, "metadata.options.environment_variables": null, "metadata.options.environment_variables_double_quotes": null, "metadata.options.priority": null, "metadata.options.max_memory_kb": null, "metadata.options.prepend_text": null, "metadata.options.append_text": null, "metadata.options.parser_name": null, "metadata.options.additional_retrieve_list": null, "metadata.options.stash": null, "metadata.options.stash.target_base": null, "metadata.options.stash.source_list": null, "metadata.options.stash.stash_mode": null, "code": null, "monitors": null, "remote_folder": null, "function": null, "function_source_code": null, "function_name": null, "process_label": null, "function_kwargs": null, "function_outputs": null, "parent_folder": null, "parent_folder_name": null, "parent_output_folder": null, "upload_files": null, "copy_files": null, "additional_retrieve_list": null}, "outputs": [{"name": "result"}], "position": [30, 30], "children": []}, "emt_mol": {"label": "emt_mol", "node_type": "PYTHONJOB", "inputs": [{"name": "atoms"}], "properties": {"atoms": null, "_wait": null, "computer": null, "code_label": null, "code_path": null, "prepend_text": null, "metadata": null, "metadata.store_provenance": null, "metadata.description": null, "metadata.label": null, "metadata.call_link_label": null, "metadata.disable_cache": null, "metadata.dry_run": null, "metadata.computer": null, "metadata.options": null, "metadata.options.input_filename": null, "metadata.options.output_filename": null, "metadata.options.submit_script_filename": null, "metadata.options.scheduler_stdout": null, "metadata.options.scheduler_stderr": null, "metadata.options.resources": null, "metadata.options.max_wallclock_seconds": null, "metadata.options.custom_scheduler_commands": null, "metadata.options.queue_name": null, "metadata.options.rerunnable": null, "metadata.options.account": null, "metadata.options.qos": null, "metadata.options.withmpi": null, "metadata.options.mpirun_extra_params": null, "metadata.options.import_sys_environment": null, "metadata.options.environment_variables": null, "metadata.options.environment_variables_double_quotes": null, "metadata.options.priority": null, "metadata.options.max_memory_kb": null, "metadata.options.prepend_text": null, "metadata.options.append_text": null, "metadata.options.parser_name": null, "metadata.options.additional_retrieve_list": null, "metadata.options.stash": null, "metadata.options.stash.target_base": null, "metadata.options.stash.source_list": null, "metadata.options.stash.stash_mode": null, "code": null, "monitors": null, "remote_folder": null, "function": null, "function_source_code": null, "function_name": null, "process_label": null, "function_kwargs": null, "function_outputs": null, "parent_folder": null, "parent_folder_name": null, "parent_output_folder": null, "upload_files": null, "copy_files": null, "additional_retrieve_list": null}, "outputs": [{"name": "result"}], "position": [60, 60], "children": []}, "atomization_energy": {"label": "atomization_energy", "node_type": "PYTHONJOB", "inputs": [{"name": "mol"}, {"name": "energy_molecule"}, {"name": "energy_atom"}, {"name": "energy_atom"}, {"name": "energy_molecule"}], "properties": {"mol": null, "energy_molecule": null, "energy_atom": null, "_wait": null, "computer": null, "code_label": null, "code_path": null, "prepend_text": null, "metadata": null, "metadata.store_provenance": null, "metadata.description": null, "metadata.label": null, "metadata.call_link_label": null, 
"metadata.disable_cache": null, "metadata.dry_run": null, "metadata.computer": null, "metadata.options": null, "metadata.options.input_filename": null, "metadata.options.output_filename": null, "metadata.options.submit_script_filename": null, "metadata.options.scheduler_stdout": null, "metadata.options.scheduler_stderr": null, "metadata.options.resources": null, "metadata.options.max_wallclock_seconds": null, "metadata.options.custom_scheduler_commands": null, "metadata.options.queue_name": null, "metadata.options.rerunnable": null, "metadata.options.account": null, "metadata.options.qos": null, "metadata.options.withmpi": null, "metadata.options.mpirun_extra_params": null, "metadata.options.import_sys_environment": null, "metadata.options.environment_variables": null, "metadata.options.environment_variables_double_quotes": null, "metadata.options.priority": null, "metadata.options.max_memory_kb": null, "metadata.options.prepend_text": null, "metadata.options.append_text": null, "metadata.options.parser_name": null, "metadata.options.additional_retrieve_list": null, "metadata.options.stash": null, "metadata.options.stash.target_base": null, "metadata.options.stash.source_list": null, "metadata.options.stash.stash_mode": null, "code": null, "monitors": null, "remote_folder": null, "function": null, "function_source_code": null, "function_name": null, "process_label": null, "function_kwargs": null, "function_outputs": null, "parent_folder": null, "parent_folder_name": null, "parent_output_folder": null, "upload_files": null, "copy_files": null, "additional_retrieve_list": null}, "outputs": [], "position": [90, 90], "children": []}}, "links": [{"from_socket": "result", "from_node": "emt_atom", "to_socket": "energy_atom", "to_node": "atomization_energy", "state": false}, {"from_socket": "result", "from_node": "emt_mol", "to_socket": "energy_molecule", "to_node": "atomization_energy", "state": false}]} + const workgraphData = {"name": "atomization_energy", "uuid": "4e03aad2-71ab-11ef-9079-906584de3e5b", "state": "CREATED", "nodes": {"emt_atom": {"label": "emt_atom", "node_type": "PYTHONJOB", "inputs": [{"name": "atoms"}], "properties": {"atoms": null, "_wait": null, "computer": null, "code_label": null, "code_path": null, "prepend_text": null, "metadata": null, "metadata.store_provenance": null, "metadata.description": null, "metadata.label": null, "metadata.call_link_label": null, "metadata.disable_cache": null, "metadata.dry_run": null, "metadata.computer": null, "metadata.options": null, "metadata.options.input_filename": null, "metadata.options.output_filename": null, "metadata.options.submit_script_filename": null, "metadata.options.scheduler_stdout": null, "metadata.options.scheduler_stderr": null, "metadata.options.resources": null, "metadata.options.max_wallclock_seconds": null, "metadata.options.custom_scheduler_commands": null, "metadata.options.queue_name": null, "metadata.options.rerunnable": null, "metadata.options.account": null, "metadata.options.qos": null, "metadata.options.withmpi": null, "metadata.options.mpirun_extra_params": null, "metadata.options.import_sys_environment": null, "metadata.options.environment_variables": null, "metadata.options.environment_variables_double_quotes": null, "metadata.options.priority": null, "metadata.options.max_memory_kb": null, "metadata.options.prepend_text": null, "metadata.options.append_text": null, "metadata.options.parser_name": null, "metadata.options.additional_retrieve_list": null, "metadata.options.stash": null, 
"metadata.options.stash.target_base": null, "metadata.options.stash.source_list": null, "metadata.options.stash.stash_mode": null, "code": null, "monitors": null, "remote_folder": null, "function": null, "function_source_code": null, "function_name": null, "process_label": null, "function_inputs": null, "function_outputs": null, "parent_folder": null, "parent_folder_name": null, "parent_output_folder": null, "upload_files": null, "copy_files": null, "additional_retrieve_list": null}, "outputs": [{"name": "result"}], "position": [30, 30], "children": []}, "emt_mol": {"label": "emt_mol", "node_type": "PYTHONJOB", "inputs": [{"name": "atoms"}], "properties": {"atoms": null, "_wait": null, "computer": null, "code_label": null, "code_path": null, "prepend_text": null, "metadata": null, "metadata.store_provenance": null, "metadata.description": null, "metadata.label": null, "metadata.call_link_label": null, "metadata.disable_cache": null, "metadata.dry_run": null, "metadata.computer": null, "metadata.options": null, "metadata.options.input_filename": null, "metadata.options.output_filename": null, "metadata.options.submit_script_filename": null, "metadata.options.scheduler_stdout": null, "metadata.options.scheduler_stderr": null, "metadata.options.resources": null, "metadata.options.max_wallclock_seconds": null, "metadata.options.custom_scheduler_commands": null, "metadata.options.queue_name": null, "metadata.options.rerunnable": null, "metadata.options.account": null, "metadata.options.qos": null, "metadata.options.withmpi": null, "metadata.options.mpirun_extra_params": null, "metadata.options.import_sys_environment": null, "metadata.options.environment_variables": null, "metadata.options.environment_variables_double_quotes": null, "metadata.options.priority": null, "metadata.options.max_memory_kb": null, "metadata.options.prepend_text": null, "metadata.options.append_text": null, "metadata.options.parser_name": null, "metadata.options.additional_retrieve_list": null, "metadata.options.stash": null, "metadata.options.stash.target_base": null, "metadata.options.stash.source_list": null, "metadata.options.stash.stash_mode": null, "code": null, "monitors": null, "remote_folder": null, "function": null, "function_source_code": null, "function_name": null, "process_label": null, "function_inputs": null, "function_outputs": null, "parent_folder": null, "parent_folder_name": null, "parent_output_folder": null, "upload_files": null, "copy_files": null, "additional_retrieve_list": null}, "outputs": [{"name": "result"}], "position": [60, 60], "children": []}, "atomization_energy": {"label": "atomization_energy", "node_type": "PYTHONJOB", "inputs": [{"name": "mol"}, {"name": "energy_molecule"}, {"name": "energy_atom"}, {"name": "energy_atom"}, {"name": "energy_molecule"}], "properties": {"mol": null, "energy_molecule": null, "energy_atom": null, "_wait": null, "computer": null, "code_label": null, "code_path": null, "prepend_text": null, "metadata": null, "metadata.store_provenance": null, "metadata.description": null, "metadata.label": null, "metadata.call_link_label": null, "metadata.disable_cache": null, "metadata.dry_run": null, "metadata.computer": null, "metadata.options": null, "metadata.options.input_filename": null, "metadata.options.output_filename": null, "metadata.options.submit_script_filename": null, "metadata.options.scheduler_stdout": null, "metadata.options.scheduler_stderr": null, "metadata.options.resources": null, "metadata.options.max_wallclock_seconds": null, 
"metadata.options.custom_scheduler_commands": null, "metadata.options.queue_name": null, "metadata.options.rerunnable": null, "metadata.options.account": null, "metadata.options.qos": null, "metadata.options.withmpi": null, "metadata.options.mpirun_extra_params": null, "metadata.options.import_sys_environment": null, "metadata.options.environment_variables": null, "metadata.options.environment_variables_double_quotes": null, "metadata.options.priority": null, "metadata.options.max_memory_kb": null, "metadata.options.prepend_text": null, "metadata.options.append_text": null, "metadata.options.parser_name": null, "metadata.options.additional_retrieve_list": null, "metadata.options.stash": null, "metadata.options.stash.target_base": null, "metadata.options.stash.source_list": null, "metadata.options.stash.stash_mode": null, "code": null, "monitors": null, "remote_folder": null, "function": null, "function_source_code": null, "function_name": null, "process_label": null, "function_inputs": null, "function_outputs": null, "parent_folder": null, "parent_folder_name": null, "parent_output_folder": null, "upload_files": null, "copy_files": null, "additional_retrieve_list": null}, "outputs": [], "position": [90, 90], "children": []}}, "links": [{"from_socket": "result", "from_node": "emt_atom", "to_socket": "energy_atom", "to_node": "atomization_energy", "state": false}, {"from_socket": "result", "from_node": "emt_mol", "to_socket": "energy_molecule", "to_node": "atomization_energy", "state": false}]} // Define Schemes to use in vanilla JS const Schemes = { diff --git a/docs/source/built-in/html/first_workflow.html b/docs/source/built-in/html/first_workflow.html index 23fdff0e..3d216323 100644 --- a/docs/source/built-in/html/first_workflow.html +++ b/docs/source/built-in/html/first_workflow.html @@ -61,7 +61,7 @@ const { RenderUtils } = ReteRenderUtils; const styled = window.styled; - const workgraphData = {"name": "first_workflow", "uuid": "3ae129ca-71ab-11ef-9079-906584de3e5b", "state": "CREATED", "nodes": {"add": {"label": "add", "node_type": "PYTHONJOB", "inputs": [{"name": "x"}, {"name": "y"}], "properties": {"x": null, "y": null, "_wait": null, "computer": null, "code_label": null, "code_path": null, "prepend_text": null, "metadata": null, "metadata.store_provenance": null, "metadata.description": null, "metadata.label": null, "metadata.call_link_label": null, "metadata.disable_cache": null, "metadata.dry_run": null, "metadata.computer": null, "metadata.options": null, "metadata.options.input_filename": null, "metadata.options.output_filename": null, "metadata.options.submit_script_filename": null, "metadata.options.scheduler_stdout": null, "metadata.options.scheduler_stderr": null, "metadata.options.resources": null, "metadata.options.max_wallclock_seconds": null, "metadata.options.custom_scheduler_commands": null, "metadata.options.queue_name": null, "metadata.options.rerunnable": null, "metadata.options.account": null, "metadata.options.qos": null, "metadata.options.withmpi": null, "metadata.options.mpirun_extra_params": null, "metadata.options.import_sys_environment": null, "metadata.options.environment_variables": null, "metadata.options.environment_variables_double_quotes": null, "metadata.options.priority": null, "metadata.options.max_memory_kb": null, "metadata.options.prepend_text": null, "metadata.options.append_text": null, "metadata.options.parser_name": null, "metadata.options.additional_retrieve_list": null, "metadata.options.stash": null, "metadata.options.stash.target_base": 
null, "metadata.options.stash.source_list": null, "metadata.options.stash.stash_mode": null, "code": null, "monitors": null, "remote_folder": null, "function": null, "function_source_code": null, "function_name": null, "process_label": null, "function_kwargs": null, "function_outputs": null, "parent_folder": null, "parent_folder_name": null, "parent_output_folder": null, "upload_files": null, "copy_files": null, "additional_retrieve_list": null}, "outputs": [{"name": "result"}], "position": [30, 30], "children": []}, "multiply": {"label": "multiply", "node_type": "PYTHONJOB", "inputs": [{"name": "x"}, {"name": "y"}, {"name": "x"}], "properties": {"x": null, "y": null, "_wait": null, "computer": null, "code_label": null, "code_path": null, "prepend_text": null, "metadata": null, "metadata.store_provenance": null, "metadata.description": null, "metadata.label": null, "metadata.call_link_label": null, "metadata.disable_cache": null, "metadata.dry_run": null, "metadata.computer": null, "metadata.options": null, "metadata.options.input_filename": null, "metadata.options.output_filename": null, "metadata.options.submit_script_filename": null, "metadata.options.scheduler_stdout": null, "metadata.options.scheduler_stderr": null, "metadata.options.resources": null, "metadata.options.max_wallclock_seconds": null, "metadata.options.custom_scheduler_commands": null, "metadata.options.queue_name": null, "metadata.options.rerunnable": null, "metadata.options.account": null, "metadata.options.qos": null, "metadata.options.withmpi": null, "metadata.options.mpirun_extra_params": null, "metadata.options.import_sys_environment": null, "metadata.options.environment_variables": null, "metadata.options.environment_variables_double_quotes": null, "metadata.options.priority": null, "metadata.options.max_memory_kb": null, "metadata.options.prepend_text": null, "metadata.options.append_text": null, "metadata.options.parser_name": null, "metadata.options.additional_retrieve_list": null, "metadata.options.stash": null, "metadata.options.stash.target_base": null, "metadata.options.stash.source_list": null, "metadata.options.stash.stash_mode": null, "code": null, "monitors": null, "remote_folder": null, "function": null, "function_source_code": null, "function_name": null, "process_label": null, "function_kwargs": null, "function_outputs": null, "parent_folder": null, "parent_folder_name": null, "parent_output_folder": null, "upload_files": null, "copy_files": null, "additional_retrieve_list": null}, "outputs": [], "position": [60, 60], "children": []}}, "links": [{"from_socket": "result", "from_node": "add", "to_socket": "x", "to_node": "multiply", "state": false}]} + const workgraphData = {"name": "first_workflow", "uuid": "3ae129ca-71ab-11ef-9079-906584de3e5b", "state": "CREATED", "nodes": {"add": {"label": "add", "node_type": "PYTHONJOB", "inputs": [{"name": "x"}, {"name": "y"}], "properties": {"x": null, "y": null, "_wait": null, "computer": null, "code_label": null, "code_path": null, "prepend_text": null, "metadata": null, "metadata.store_provenance": null, "metadata.description": null, "metadata.label": null, "metadata.call_link_label": null, "metadata.disable_cache": null, "metadata.dry_run": null, "metadata.computer": null, "metadata.options": null, "metadata.options.input_filename": null, "metadata.options.output_filename": null, "metadata.options.submit_script_filename": null, "metadata.options.scheduler_stdout": null, "metadata.options.scheduler_stderr": null, "metadata.options.resources": null, 
"metadata.options.max_wallclock_seconds": null, "metadata.options.custom_scheduler_commands": null, "metadata.options.queue_name": null, "metadata.options.rerunnable": null, "metadata.options.account": null, "metadata.options.qos": null, "metadata.options.withmpi": null, "metadata.options.mpirun_extra_params": null, "metadata.options.import_sys_environment": null, "metadata.options.environment_variables": null, "metadata.options.environment_variables_double_quotes": null, "metadata.options.priority": null, "metadata.options.max_memory_kb": null, "metadata.options.prepend_text": null, "metadata.options.append_text": null, "metadata.options.parser_name": null, "metadata.options.additional_retrieve_list": null, "metadata.options.stash": null, "metadata.options.stash.target_base": null, "metadata.options.stash.source_list": null, "metadata.options.stash.stash_mode": null, "code": null, "monitors": null, "remote_folder": null, "function": null, "function_source_code": null, "function_name": null, "process_label": null, "function_inputs": null, "function_outputs": null, "parent_folder": null, "parent_folder_name": null, "parent_output_folder": null, "upload_files": null, "copy_files": null, "additional_retrieve_list": null}, "outputs": [{"name": "result"}], "position": [30, 30], "children": []}, "multiply": {"label": "multiply", "node_type": "PYTHONJOB", "inputs": [{"name": "x"}, {"name": "y"}, {"name": "x"}], "properties": {"x": null, "y": null, "_wait": null, "computer": null, "code_label": null, "code_path": null, "prepend_text": null, "metadata": null, "metadata.store_provenance": null, "metadata.description": null, "metadata.label": null, "metadata.call_link_label": null, "metadata.disable_cache": null, "metadata.dry_run": null, "metadata.computer": null, "metadata.options": null, "metadata.options.input_filename": null, "metadata.options.output_filename": null, "metadata.options.submit_script_filename": null, "metadata.options.scheduler_stdout": null, "metadata.options.scheduler_stderr": null, "metadata.options.resources": null, "metadata.options.max_wallclock_seconds": null, "metadata.options.custom_scheduler_commands": null, "metadata.options.queue_name": null, "metadata.options.rerunnable": null, "metadata.options.account": null, "metadata.options.qos": null, "metadata.options.withmpi": null, "metadata.options.mpirun_extra_params": null, "metadata.options.import_sys_environment": null, "metadata.options.environment_variables": null, "metadata.options.environment_variables_double_quotes": null, "metadata.options.priority": null, "metadata.options.max_memory_kb": null, "metadata.options.prepend_text": null, "metadata.options.append_text": null, "metadata.options.parser_name": null, "metadata.options.additional_retrieve_list": null, "metadata.options.stash": null, "metadata.options.stash.target_base": null, "metadata.options.stash.source_list": null, "metadata.options.stash.stash_mode": null, "code": null, "monitors": null, "remote_folder": null, "function": null, "function_source_code": null, "function_name": null, "process_label": null, "function_inputs": null, "function_outputs": null, "parent_folder": null, "parent_folder_name": null, "parent_output_folder": null, "upload_files": null, "copy_files": null, "additional_retrieve_list": null}, "outputs": [], "position": [60, 60], "children": []}}, "links": [{"from_socket": "result", "from_node": "add", "to_socket": "x", "to_node": "multiply", "state": false}]} // Define Schemes to use in vanilla JS const Schemes = { diff --git 
a/docs/source/built-in/html/pythonjob_eos_emt.html b/docs/source/built-in/html/pythonjob_eos_emt.html index 53ea1df6..3117b3e7 100644 --- a/docs/source/built-in/html/pythonjob_eos_emt.html +++ b/docs/source/built-in/html/pythonjob_eos_emt.html @@ -61,7 +61,7 @@ const { RenderUtils } = ReteRenderUtils; const styled = window.styled; - const workgraphData = {"name": "pythonjob_eos_emt", "uuid": "93ea2a80-71ab-11ef-9079-906584de3e5b", "state": "CREATED", "nodes": {"scale_atoms": {"label": "scale_atoms", "node_type": "PYTHONJOB", "inputs": [{"name": "atoms"}, {"name": "scales"}], "properties": {"atoms": null, "scales": null, "_wait": null, "computer": null, "code_label": null, "code_path": null, "prepend_text": null, "metadata": null, "metadata.store_provenance": null, "metadata.description": null, "metadata.label": null, "metadata.call_link_label": null, "metadata.disable_cache": null, "metadata.dry_run": null, "metadata.computer": null, "metadata.options": null, "metadata.options.input_filename": null, "metadata.options.output_filename": null, "metadata.options.submit_script_filename": null, "metadata.options.scheduler_stdout": null, "metadata.options.scheduler_stderr": null, "metadata.options.resources": null, "metadata.options.max_wallclock_seconds": null, "metadata.options.custom_scheduler_commands": null, "metadata.options.queue_name": null, "metadata.options.rerunnable": null, "metadata.options.account": null, "metadata.options.qos": null, "metadata.options.withmpi": null, "metadata.options.mpirun_extra_params": null, "metadata.options.import_sys_environment": null, "metadata.options.environment_variables": null, "metadata.options.environment_variables_double_quotes": null, "metadata.options.priority": null, "metadata.options.max_memory_kb": null, "metadata.options.prepend_text": null, "metadata.options.append_text": null, "metadata.options.parser_name": null, "metadata.options.additional_retrieve_list": null, "metadata.options.stash": null, "metadata.options.stash.target_base": null, "metadata.options.stash.source_list": null, "metadata.options.stash.stash_mode": null, "code": null, "monitors": null, "remote_folder": null, "function": null, "function_source_code": null, "function_name": null, "process_label": null, "function_kwargs": null, "function_outputs": null, "parent_folder": null, "parent_folder_name": null, "parent_output_folder": null, "upload_files": null, "copy_files": null, "additional_retrieve_list": null}, "outputs": [{"name": "scaled_atoms"}, {"name": "volumes"}], "position": [30, 30], "children": []}, "calculate_enegies": {"label": "calculate_enegies", "node_type": "GRAPH_BUILDER", "inputs": [{"name": "scaled_atoms"}, {"name": "scaled_atoms"}], "properties": {"scaled_atoms": null, "_wait": null}, "outputs": [{"name": "results"}], "position": [60, 60], "children": []}, "fit_eos": {"label": "fit_eos", "node_type": "PYTHONJOB", "inputs": [{"name": "volumes"}, {"name": "emt_results"}, {"name": "volumes"}, {"name": "emt_results"}], "properties": {"volumes": null, "emt_results": null, "_wait": null, "computer": null, "code_label": null, "code_path": null, "prepend_text": null, "metadata": null, "metadata.store_provenance": null, "metadata.description": null, "metadata.label": null, "metadata.call_link_label": null, "metadata.disable_cache": null, "metadata.dry_run": null, "metadata.computer": null, "metadata.options": null, "metadata.options.input_filename": null, "metadata.options.output_filename": null, "metadata.options.submit_script_filename": null, 
"metadata.options.scheduler_stdout": null, "metadata.options.scheduler_stderr": null, "metadata.options.resources": null, "metadata.options.max_wallclock_seconds": null, "metadata.options.custom_scheduler_commands": null, "metadata.options.queue_name": null, "metadata.options.rerunnable": null, "metadata.options.account": null, "metadata.options.qos": null, "metadata.options.withmpi": null, "metadata.options.mpirun_extra_params": null, "metadata.options.import_sys_environment": null, "metadata.options.environment_variables": null, "metadata.options.environment_variables_double_quotes": null, "metadata.options.priority": null, "metadata.options.max_memory_kb": null, "metadata.options.prepend_text": null, "metadata.options.append_text": null, "metadata.options.parser_name": null, "metadata.options.additional_retrieve_list": null, "metadata.options.stash": null, "metadata.options.stash.target_base": null, "metadata.options.stash.source_list": null, "metadata.options.stash.stash_mode": null, "code": null, "monitors": null, "remote_folder": null, "function": null, "function_source_code": null, "function_name": null, "process_label": null, "function_kwargs": null, "function_outputs": null, "parent_folder": null, "parent_folder_name": null, "parent_output_folder": null, "upload_files": null, "copy_files": null, "additional_retrieve_list": null}, "outputs": [], "position": [90, 90], "children": []}}, "links": [{"from_socket": "scaled_atoms", "from_node": "scale_atoms", "to_socket": "scaled_atoms", "to_node": "calculate_enegies", "state": false}, {"from_socket": "volumes", "from_node": "scale_atoms", "to_socket": "volumes", "to_node": "fit_eos", "state": false}, {"from_socket": "results", "from_node": "calculate_enegies", "to_socket": "emt_results", "to_node": "fit_eos", "state": false}]} + const workgraphData = {"name": "pythonjob_eos_emt", "uuid": "93ea2a80-71ab-11ef-9079-906584de3e5b", "state": "CREATED", "nodes": {"scale_atoms": {"label": "scale_atoms", "node_type": "PYTHONJOB", "inputs": [{"name": "atoms"}, {"name": "scales"}], "properties": {"atoms": null, "scales": null, "_wait": null, "computer": null, "code_label": null, "code_path": null, "prepend_text": null, "metadata": null, "metadata.store_provenance": null, "metadata.description": null, "metadata.label": null, "metadata.call_link_label": null, "metadata.disable_cache": null, "metadata.dry_run": null, "metadata.computer": null, "metadata.options": null, "metadata.options.input_filename": null, "metadata.options.output_filename": null, "metadata.options.submit_script_filename": null, "metadata.options.scheduler_stdout": null, "metadata.options.scheduler_stderr": null, "metadata.options.resources": null, "metadata.options.max_wallclock_seconds": null, "metadata.options.custom_scheduler_commands": null, "metadata.options.queue_name": null, "metadata.options.rerunnable": null, "metadata.options.account": null, "metadata.options.qos": null, "metadata.options.withmpi": null, "metadata.options.mpirun_extra_params": null, "metadata.options.import_sys_environment": null, "metadata.options.environment_variables": null, "metadata.options.environment_variables_double_quotes": null, "metadata.options.priority": null, "metadata.options.max_memory_kb": null, "metadata.options.prepend_text": null, "metadata.options.append_text": null, "metadata.options.parser_name": null, "metadata.options.additional_retrieve_list": null, "metadata.options.stash": null, "metadata.options.stash.target_base": null, "metadata.options.stash.source_list": null, 
"metadata.options.stash.stash_mode": null, "code": null, "monitors": null, "remote_folder": null, "function": null, "function_source_code": null, "function_name": null, "process_label": null, "function_inputs": null, "function_outputs": null, "parent_folder": null, "parent_folder_name": null, "parent_output_folder": null, "upload_files": null, "copy_files": null, "additional_retrieve_list": null}, "outputs": [{"name": "scaled_atoms"}, {"name": "volumes"}], "position": [30, 30], "children": []}, "calculate_enegies": {"label": "calculate_enegies", "node_type": "GRAPH_BUILDER", "inputs": [{"name": "scaled_atoms"}, {"name": "scaled_atoms"}], "properties": {"scaled_atoms": null, "_wait": null}, "outputs": [{"name": "results"}], "position": [60, 60], "children": []}, "fit_eos": {"label": "fit_eos", "node_type": "PYTHONJOB", "inputs": [{"name": "volumes"}, {"name": "emt_results"}, {"name": "volumes"}, {"name": "emt_results"}], "properties": {"volumes": null, "emt_results": null, "_wait": null, "computer": null, "code_label": null, "code_path": null, "prepend_text": null, "metadata": null, "metadata.store_provenance": null, "metadata.description": null, "metadata.label": null, "metadata.call_link_label": null, "metadata.disable_cache": null, "metadata.dry_run": null, "metadata.computer": null, "metadata.options": null, "metadata.options.input_filename": null, "metadata.options.output_filename": null, "metadata.options.submit_script_filename": null, "metadata.options.scheduler_stdout": null, "metadata.options.scheduler_stderr": null, "metadata.options.resources": null, "metadata.options.max_wallclock_seconds": null, "metadata.options.custom_scheduler_commands": null, "metadata.options.queue_name": null, "metadata.options.rerunnable": null, "metadata.options.account": null, "metadata.options.qos": null, "metadata.options.withmpi": null, "metadata.options.mpirun_extra_params": null, "metadata.options.import_sys_environment": null, "metadata.options.environment_variables": null, "metadata.options.environment_variables_double_quotes": null, "metadata.options.priority": null, "metadata.options.max_memory_kb": null, "metadata.options.prepend_text": null, "metadata.options.append_text": null, "metadata.options.parser_name": null, "metadata.options.additional_retrieve_list": null, "metadata.options.stash": null, "metadata.options.stash.target_base": null, "metadata.options.stash.source_list": null, "metadata.options.stash.stash_mode": null, "code": null, "monitors": null, "remote_folder": null, "function": null, "function_source_code": null, "function_name": null, "process_label": null, "function_inputs": null, "function_outputs": null, "parent_folder": null, "parent_folder_name": null, "parent_output_folder": null, "upload_files": null, "copy_files": null, "additional_retrieve_list": null}, "outputs": [], "position": [90, 90], "children": []}}, "links": [{"from_socket": "scaled_atoms", "from_node": "scale_atoms", "to_socket": "scaled_atoms", "to_node": "calculate_enegies", "state": false}, {"from_socket": "volumes", "from_node": "scale_atoms", "to_socket": "volumes", "to_node": "fit_eos", "state": false}, {"from_socket": "results", "from_node": "calculate_enegies", "to_socket": "emt_results", "to_node": "fit_eos", "state": false}]} // Define Schemes to use in vanilla JS const Schemes = { diff --git a/docs/source/built-in/pythonjob.ipynb b/docs/source/built-in/pythonjob.ipynb index 64072029..a7e286ad 100644 --- a/docs/source/built-in/pythonjob.ipynb +++ b/docs/source/built-in/pythonjob.ipynb @@ -8,7 +8,7 
@@ "# PythonJob\n", "## Introduction\n", "\n", - "The `PythonJob` is a built-in task that allows users to run Python functions on a remote computer. It is designed to enable users from non-AiiDA communities to run their Python functions remotely and construct workflows with checkpoints, maintaining all data provenance. For instance, users can use ASE's calculator to run a DFT calculation on a remote computer directly. Users only need to write normal Python code, and the WorkGraph will handle the data transformation to AiiDA data.\n", + "The `PythonJob` is a built-in task, which uses the [aiida-pythonjob](https://aiida-pythonjob.readthedocs.io/en/latest/) package to run Python functions on a remote computer. It is designed to enable users from non-AiiDA communities to run their Python functions remotely and construct workflows with checkpoints, maintaining all data provenance. For instance, users can use ASE's calculator to run a DFT calculation on a remote computer directly. Users only need to write normal Python code, and the WorkGraph will handle the data transformation to AiiDA data.\n", "\n", "### Key Features\n", "\n", @@ -142,7 +142,7 @@ "One can use the `create_conda_env` function to create a conda environment on the remote computer. The function will create a conda environment with the specified packages and modules. The function will update the packages if the environment already exists.\n", "\n", "```python\n", - "from aiida_workgraph.utils import create_conda_env\n", + "from aiida_pythonjob.utils import create_conda_env\n", "# create a conda environment on remote computer\n", "create_conda_env(\"merlin6\", \"test_pythonjob\", modules=[\"anaconda\"],\n", " pip=[\"numpy\", \"matplotlib\"],\n", @@ -285,7 +285,7 @@ "\n", "\n", "INPUT_CALC\n", - "function_kwargs__x\n", + "function_inputs__x\n", "\n", "\n", "\n", @@ -819,7 +819,7 @@ "\n", "\n", "INPUT_CALC\n", - "function_kwargs__energy_molecule\n", + "function_inputs__energy_molecule\n", "\n", "\n", "\n", @@ -827,7 +827,7 @@ "\n", "\n", "INPUT_CALC\n", - "function_kwargs__energy_atom\n", + "function_inputs__energy_atom\n", "\n", "\n", "\n", @@ -1128,7 +1128,7 @@ "\n", "\n", "INPUT_CALC\n", - "function_kwargs__x\n", + "function_inputs__x\n", "\n", "\n", "\n", @@ -1511,7 +1511,7 @@ " emt1 = wg.add_task(\"PythonJob\", function=emt, name=f\"emt1_{key}\", atoms=atoms)\n", " emt1.set({\"computer\": \"localhost\"})\n", " # save the output parameters to the context\n", - " emt1.set_context({\"result\": f\"results.{key}\"})\n", + " emt1.set_context({f\"results.{key}\": \"result\"})\n", " return wg\n", "\n", "\n", @@ -1805,7 +1805,7 @@ "\n", "\n", "INPUT_CALC\n", - "function_kwargs__atoms\n", + "function_inputs__atoms\n", "\n", "\n", "\n", @@ -1829,7 +1829,7 @@ "\n", "\n", "INPUT_CALC\n", - "function_kwargs__atoms\n", + "function_inputs__atoms\n", "\n", "\n", "\n", @@ -1853,7 +1853,7 @@ "\n", "\n", "INPUT_CALC\n", - "function_kwargs__atoms\n", + "function_inputs__atoms\n", "\n", "\n", "\n", @@ -1861,7 +1861,7 @@ "\n", "\n", "INPUT_CALC\n", - "function_kwargs__volumes\n", + "function_inputs__volumes\n", "\n", "\n", "\n", @@ -2060,7 +2060,7 @@ "\n", "\n", "INPUT_CALC\n", - "function_kwargs__emt_results__s_0\n", + "function_inputs__emt_results__s_0\n", "\n", "\n", "\n", @@ -2068,7 +2068,7 @@ "\n", "\n", "INPUT_CALC\n", - "function_kwargs__emt_results__s_1\n", + "function_inputs__emt_results__s_1\n", "\n", "\n", "\n", @@ -2076,7 +2076,7 @@ "\n", "\n", "INPUT_CALC\n", - "function_kwargs__emt_results__s_2\n", + "function_inputs__emt_results__s_2\n", 
"\n", "\n", "\n", @@ -2360,9 +2360,9 @@ "\n", "\n", "## Define your data serializer\n", - "Workgraph search data serializer from the `aiida.data` entry point by the module name and class name (e.g., `ase.atoms.Atoms`). \n", + "PythonJob search data serializer from the `aiida.data` entry point by the module name and class name (e.g., `ase.atoms.Atoms`). \n", "\n", - "In order to let the workgraph find the serializer, you must register the AiiDA data with the following format:\n", + "In order to let the PythonJob find the serializer, you must register the AiiDA data with the following format:\n", "```\n", "[project.entry-points.\"aiida.data\"]\n", "abc.ase.atoms.Atoms = \"abc.xyz:MyAtomsData\"\n", @@ -2371,7 +2371,7 @@ "\n", "\n", "### Avoid duplicate data serializer\n", - "If you have multiple plugins that register the same data serializer, the workgraph will raise an error. You can avoid this by selecting the plugin that you want to use in the configuration file.\n", + "If you have multiple plugins that register the same data serializer, the PythonJob will raise an error. You can avoid this by selecting the plugin that you want to use in the configuration file.\n", "\n", "```json\n", "{\n", @@ -2381,96 +2381,8 @@ "}\n", "```\n", "\n", - "Save the configuration file as `workgraph.json` in the aiida configuration directory (by default, `~/.aiida` directory).\n", - "\n", - "\n", - "## Use PythonJob outside WorkGraph\n", - "One can use the `PythonJob` task outside the WorkGraph to run a Python function on a remote computer. For example, in a `WorkChain` or run a single `CalcJob` calculation.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "9a1fa5e6", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Result: 3\n" - ] - } - ], - "source": [ - "from aiida import orm, load_profile\n", - "from aiida.engine import run_get_node\n", - "from aiida_workgraph.calculations.python import PythonJob\n", - "\n", - "load_profile()\n", - "\n", - "python_code = orm.load_code(\"python3@localhost\")\n", - "\n", - "def add(x, y):\n", - " return x + y\n", - "\n", - "result, node = run_get_node(PythonJob, code=python_code,\n", - " function=add,\n", - " function_kwargs = {\"x\": orm.Int(1), \"y\": orm.Int(2)},\n", - " function_outputs=[{\"name\": \"add\"}])\n", - "\n", - "print(\"Result: \", result[\"add\"].value)\n" - ] - }, - { - "cell_type": "markdown", - "id": "4fb22545", - "metadata": {}, - "source": [ - "You can see more details on any process, including its inputs and outputs, using the verdi command:" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "86e74979", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[22mProperty Value\n", - "----------- ------------------------------------\n", - "type PythonJob\n", - "state Finished [0]\n", - "pk 151415\n", - "uuid ff25998c-98d9-4d56-995a-fe9ecd66468a\n", - "label PythonJob\n", - "description\n", - "ctime 2024-09-13 10:46:05.231456+02:00\n", - "mtime 2024-09-13 10:46:08.263554+02:00\n", - "computer [1] localhost\n", - "\n", - "Inputs PK Type\n", - "---------------- ------ ---------------\n", - "function_kwargs\n", - " x 151412 Int\n", - " y 151413 Int\n", - "code 42316 InstalledCode\n", - "function 151411 PickledFunction\n", - "function_outputs 151414 List\n", - "\n", - "Outputs PK Type\n", - "------------- ------ ----------\n", - "add 151419 Int\n", - "remote_folder 151417 RemoteData\n", - "retrieved 151418 
FolderData\u001b[0m\n" - ] - } - ], - "source": [ - "%verdi process show {node.pk}" + "Save the configuration file as `pythonjob.json` in the aiida configuration directory (by default, `~/.aiida` directory).\n", + "\n" ] } ], diff --git a/docs/source/development/python_task.ipynb b/docs/source/development/python_task.ipynb index cd491c9d..e5a0777a 100644 --- a/docs/source/development/python_task.ipynb +++ b/docs/source/development/python_task.ipynb @@ -36,7 +36,7 @@ "### About the data\n", "For a `CalcJob`, the input data needs to be an AiiDA data node; however, we don't require the user to install AiiDA or the same Python environment on a remote computer. This means we should pass normal Python data as arguments when running the Python function on the remote computer. The `WorkGraphEngine` will handle this data transformation when preparing and launching the `CalcJob`.\n", "\n", - "All AiiDA data that will be passed to the function should have a `value` attribute, which corresponds to its raw Python data. The `GeneralData`, `Int`, `Float`, `Str`, `Bool` fulfill this requirement, while `List`, `Dict` and `StructureData` are not.\n", + "All AiiDA data that will be passed to the function should have a `value` attribute, which corresponds to its raw Python data. The `PickledData`, `Int`, `Float`, `Str`, `Bool` fulfill this requirement, while `List`, `Dict` and `StructureData` are not.\n", "\n", "### Inputs and Outputs:\n", "Inputs for each task are pickled into the `inputs.pickle` file.\n", diff --git a/docs/source/howto/context.ipynb b/docs/source/howto/context.ipynb index 016e36b0..9ccbca59 100644 --- a/docs/source/howto/context.ipynb +++ b/docs/source/howto/context.ipynb @@ -14,11 +14,11 @@ "metadata": {}, "source": [ "## Introduction\n", - "In AiiDA workflow, the context is a internal data container that can hold and pass information between steps. It is very usefull for complex workflows.\n", + "In AiiDA workflow, the context is a internal container that can hold data that shared between different tasks. It is very usefull for complex workflows.\n", "\n", "## Pass data to context\n", "\n", - "There are three ways to pass data to context.\n", + "There are three ways to set data to context.\n", "\n", "- Initialize the context data when creating the WorkGraph.\n", " ```python\n", @@ -27,7 +27,7 @@ " wg.context = {\"x\": Int(2), \"data.y\": Int(3)}\n", " ```\n", "\n", - "- Export the task result to context.\n", + "- Set the task result to context when the task is done.\n", " ```python\n", " # define add task\n", " @task.calcfunction()\n", @@ -35,13 +35,13 @@ " return x + y\n", " add1 = wg.add_task(add, \"add1\", x=2, y=3)\n", " # set result of add1 to context.sum\n", - " add1.set_context({\"result\": \"sum\"})\n", + " add1.set_context({\"sum\": \"result\"})\n", " ```\n", "\n", - "- Use the `to_context` task to save the result to context.\n", + "- Use the `set_context` task to set either the task result or a constant value to the context.\n", "\n", " ```python\n", - " wg.add_task(\"workgraph.to_context\", name=\"to_ctx1\", key=\"sum\", value=add1.outputs[\"result\"])\n", + " wg.add_task(\"workgraph.set_context\", name=\"set_ctx1\", key=\"sum\", value=add1.outputs[\"result\"])\n", " ```\n", "\n", "\n", @@ -49,8 +49,8 @@ "To organize the context data (e.g. group data), The keys may contain dots `.`, which will creating dictionary in the context. 
Here is an example, to group the results of all add tasks to `context.sum`:\n", "\n", "```python\n", - "add1.set_context({\"result\": \"sum.add1\"})\n", - "add2.set_context({\"result\": \"sum.add2\"})\n", + "add1.set_context({\"sum.add1\": \"result\"})\n", + "add2.set_context({\"sum.add2\": \"result\"})\n", "```\n", "here, `context.sum` will be:\n", "```python\n", @@ -75,14 +75,14 @@ " nt = WorkGraph(\"while_workgraph\")\n", " add1 = wg.add_task(add, x=2, y=3)\n", " add2 = wg.add_task(add, x=2, y=3)\n", - " add1.set_context({\"result\": \"sum.add1\"})\n", - " add2.set_context({\"result\": \"sum.add2\"})\n", + " add1.set_context({\"sum.add1\": \"result\"})\n", + " add2.set_context({\"sum.add2\": \"result\"})\n", " ```\n", "\n", - "- One can use the `from_context` task to get the data from context. **This task will be shown in the GUI**\n", + "- One can use the `get_context` task to get the data from context. **This task will be shown in the GUI**\n", "\n", " ```python\n", - " wg.add_task(\"workgraph.from_context\", name=\"from_ctx1\", key=\"sum.add1\")\n", + " wg.add_task(\"workgraph.get_context\", name=\"get_ctx1\", key=\"sum.add1\")\n", " ```" ] }, @@ -136,10 +136,10 @@ "wg = WorkGraph(name=\"test_workgraph_ctx\")\n", "# Set the context of the workgraph\n", "wg.context = {\"x\": 2, \"data.y\": 3}\n", - "from_ctx1 = wg.add_task(\"workgraph.from_context\", name=\"from_ctx1\", key=\"x\")\n", - "add1 = wg.add_task(add, \"add1\", x=from_ctx1.outputs[\"result\"],\n", + "get_ctx1 = wg.add_task(\"workgraph.get_context\", name=\"get_ctx1\", key=\"x\")\n", + "add1 = wg.add_task(add, \"add1\", x=get_ctx1.outputs[\"result\"],\n", " y=\"{{data.y}}\")\n", - "to_ctx1 = wg.add_task(\"workgraph.to_context\", name=\"to_ctx1\", key=\"x\",\n", + "set_ctx1 = wg.add_task(\"workgraph.set_context\", name=\"set_ctx1\", key=\"x\",\n", " value=add1.outputs[\"result\"])\n", "wg.to_html()\n", "# wg" @@ -150,7 +150,7 @@ "id": "f6969061", "metadata": {}, "source": [ - "As shown in the GUI, the `from_context` task and `to_context` tasks are shown in the GUI. However, the context variable using the `set_context` method or `{{}}` is not shown in the GUI." + "As shown in the GUI, the `get_context` task and `to_context` tasks are shown in the GUI. However, the context variable using the `set_context` method or `{{}}` is not shown in the GUI." 
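Taken together, the context.ipynb hunks above rename the context helper tasks (workgraph.from_context → workgraph.get_context, workgraph.to_context → workgraph.set_context) and flip the argument order of Task.set_context to {context_key: output_name}. As a reading aid, here is a minimal, construction-only sketch that combines just the calls shown in those hunks; the profile setup line is the only addition, and nothing beyond the notebook snippets above is assumed:

```python
from aiida import load_profile
from aiida_workgraph import WorkGraph, task

load_profile()  # assumes a configured AiiDA profile, as in the notebook

# A calcfunction task, as defined in the hunks above.
@task.calcfunction()
def add(x, y):
    return x + y

wg = WorkGraph(name="test_workgraph_ctx")
# Seed the context when creating the WorkGraph; dotted keys create nested dictionaries.
wg.context = {"x": 2, "data.y": 3}
# Read a value from the context with the renamed get_context task (shown as a node in the GUI).
get_ctx1 = wg.add_task("workgraph.get_context", name="get_ctx1", key="x")
add1 = wg.add_task(add, "add1", x=get_ctx1.outputs["result"], y="{{data.y}}")
# New argument order: {context_key: output_name}; the result lands in context.sum.add1.
add1.set_context({"sum.add1": "result"})
# Or write a value explicitly with the renamed set_context task.
set_ctx1 = wg.add_task("workgraph.set_context", name="set_ctx1", key="x",
                       value=add1.outputs["result"])
```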
] }, { diff --git a/docs/source/howto/for.ipynb b/docs/source/howto/for.ipynb index e3d4df4d..6fda2663 100644 --- a/docs/source/howto/for.ipynb +++ b/docs/source/howto/for.ipynb @@ -103,7 +103,7 @@ " multiply1 = wg.add_task(multiply, name=\"multiply1\", x=\"{{ i }}\", y=2)\n", " add1 = wg.add_task(add, name=\"add1\", x=\"{{ total }}\")\n", " # update the context variable\n", - " add1.set_context({\"result\": \"total\"})\n", + " add1.set_context({\"total\": \"result\"})\n", " wg.add_link(multiply1.outputs[\"result\"], add1.inputs[\"y\"])\n", " # don't forget to return the workgraph\n", " return wg" diff --git a/docs/source/howto/html/test_workgraph_ctx.html b/docs/source/howto/html/test_workgraph_ctx.html index 0c3ffd7f..08f8b000 100644 --- a/docs/source/howto/html/test_workgraph_ctx.html +++ b/docs/source/howto/html/test_workgraph_ctx.html @@ -61,7 +61,7 @@ const { RenderUtils } = ReteRenderUtils; const styled = window.styled; - const workgraphData = {"name": "test_workgraph_ctx", "uuid": "50bec156-5a60-11ef-888c-906584de3e5b", "state": "CREATED", "nodes": {"from_ctx1": {"label": "from_ctx1", "node_type": "FROM_CONTEXT", "inputs": [{"name": "key", "identifier": "workgraph.any", "uuid": "50c8610c-5a60-11ef-888c-906584de3e5b", "node_uuid": "50c85f5e-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}], "outputs": [{"name": "result"}], "position": [30, 30], "children": []}, "add1": {"label": "add1", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any", "uuid": "50d07e78-5a60-11ef-888c-906584de3e5b", "node_uuid": "50d0789c-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [{"from_node": "from_ctx1", "from_socket": "result", "from_socket_uuid": "50c861fc-5a60-11ef-888c-906584de3e5b"}], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y", "identifier": "workgraph.any", "uuid": "50d07edc-5a60-11ef-888c-906584de3e5b", "node_uuid": "50d0789c-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "x"}], "outputs": [{"name": "result"}], "position": [60, 60], "children": []}, "to_ctx1": {"label": "to_ctx1", "node_type": "TO_CONTEXT", "inputs": [{"name": "key", "identifier": "workgraph.any", "uuid": "50d858be-5a60-11ef-888c-906584de3e5b", "node_uuid": "50d856f2-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "value", "identifier": "workgraph.any", "uuid": "50d8599a-5a60-11ef-888c-906584de3e5b", "node_uuid": "50d856f2-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [{"from_node": "add1", "from_socket": "result", "from_socket_uuid": "50d07fa4-5a60-11ef-888c-906584de3e5b"}], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "value"}], "outputs": [], "position": [90, 90], "children": []}}, "links": [{"from_socket": "result", "from_node": 
"from_ctx1", "from_socket_uuid": "50c861fc-5a60-11ef-888c-906584de3e5b", "to_socket": "x", "to_node": "add1", "state": false}, {"from_socket": "result", "from_node": "add1", "from_socket_uuid": "50d07fa4-5a60-11ef-888c-906584de3e5b", "to_socket": "value", "to_node": "to_ctx1", "state": false}]} + const workgraphData = {"name": "test_workgraph_ctx", "uuid": "50bec156-5a60-11ef-888c-906584de3e5b", "state": "CREATED", "nodes": {"get_ctx1": {"label": "get_ctx1", "node_type": "GET_CONTEXT", "inputs": [{"name": "key", "identifier": "workgraph.any", "uuid": "50c8610c-5a60-11ef-888c-906584de3e5b", "node_uuid": "50c85f5e-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}], "outputs": [{"name": "result"}], "position": [30, 30], "children": []}, "add1": {"label": "add1", "node_type": "CALCFUNCTION", "inputs": [{"name": "x", "identifier": "workgraph.any", "uuid": "50d07e78-5a60-11ef-888c-906584de3e5b", "node_uuid": "50d0789c-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [{"from_node": "get_ctx1", "from_socket": "result", "from_socket_uuid": "50c861fc-5a60-11ef-888c-906584de3e5b"}], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "y", "identifier": "workgraph.any", "uuid": "50d07edc-5a60-11ef-888c-906584de3e5b", "node_uuid": "50d0789c-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "x"}], "outputs": [{"name": "result"}], "position": [60, 60], "children": []}, "set_ctx1": {"label": "set_ctx1", "node_type": "SET_CONTEXT", "inputs": [{"name": "key", "identifier": "workgraph.any", "uuid": "50d858be-5a60-11ef-888c-906584de3e5b", "node_uuid": "50d856f2-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "value", "identifier": "workgraph.any", "uuid": "50d8599a-5a60-11ef-888c-906584de3e5b", "node_uuid": "50d856f2-5a60-11ef-888c-906584de3e5b", "type": "INPUT", "link_limit": 1, "links": [{"from_node": "add1", "from_socket": "result", "from_socket_uuid": "50d07fa4-5a60-11ef-888c-906584de3e5b"}], "serialize": {"path": "node_graph.serializer", "name": "serialize_pickle"}, "deserialize": {"path": "node_graph.serializer", "name": "deserialize_pickle"}}, {"name": "value"}], "outputs": [], "position": [90, 90], "children": []}}, "links": [{"from_socket": "result", "from_node": "get_ctx1", "from_socket_uuid": "50c861fc-5a60-11ef-888c-906584de3e5b", "to_socket": "x", "to_node": "add1", "state": false}, {"from_socket": "result", "from_node": "add1", "from_socket_uuid": "50d07fa4-5a60-11ef-888c-906584de3e5b", "to_socket": "value", "to_node": "set_ctx1", "state": false}]} // Define Schemes to use in vanilla JS const Schemes = { diff --git a/docs/source/howto/monitor.ipynb b/docs/source/howto/monitor.ipynb index 3104ea86..0a4d1361 100644 --- a/docs/source/howto/monitor.ipynb +++ b/docs/source/howto/monitor.ipynb @@ -111,12 +111,10 @@ "monitor2 = wg.add_task(\"workgraph.file_monitor\", filepath=\"/tmp/test.txt\")\n", "```\n", 
"\n", - "## Awaitable Task Decorator\n", + "### Awaitable Task Decorator\n", "\n", "The `awaitable` decorator allows for the integration of `asyncio` within tasks, letting users control asynchronous functions.\n", "\n", - "### General Awaitable Task\n", - "\n", "Define and use an awaitable task within the WorkGraph.\n", "\n" ] @@ -204,7 +202,7 @@ "id": "1ae83d3f", "metadata": {}, "source": [ - "## Kill the monitor task\n", + "### Kill the monitor task\n", "\n", "One can kill a running monitor task by using the following command:\n", "\n", @@ -220,7 +218,7 @@ "\n", "The awaitable task lets the WorkGraph enter a `Waiting` state, yielding control to the asyncio event loop. This enables other tasks to run concurrently, although long-running calculations may delay the execution of awaitable tasks.\n", "\n", - "## Conclusion\n", + "### Conclusion\n", "\n", "These enhancements provide powerful tools for managing dependencies and asynchronous operations within WorkGraph, offering greater flexibility and efficiency in task execution." ] diff --git a/docs/source/howto/parallel.ipynb b/docs/source/howto/parallel.ipynb index d99041ae..5cdad860 100644 --- a/docs/source/howto/parallel.ipynb +++ b/docs/source/howto/parallel.ipynb @@ -401,7 +401,7 @@ " multiply1 = wg.add_task(multiply, x=value, y=y)\n", " # add result of multiply1 to `self.context.mul`\n", " # self.context.mul is a dict {\"a\": value1, \"b\": value2, \"c\": value3}\n", - " multiply1.set_context({\"result\": f\"mul.{key}\"})\n", + " multiply1.set_context({f\"mul.{key}\": \"result\"})\n", " return wg\n", "\n", "@task.calcfunction()\n", diff --git a/docs/source/howto/waiting_on.ipynb b/docs/source/howto/waiting_on.ipynb index bd41045b..f9c2e2c7 100644 --- a/docs/source/howto/waiting_on.ipynb +++ b/docs/source/howto/waiting_on.ipynb @@ -105,9 +105,9 @@ "\n", "wg = WorkGraph(\"test_wait\")\n", "add1 = wg.add_task(add, name=\"add1\", x=1, y=1)\n", - "add1.set_context({\"result\": \"data.add1\"})\n", + "add1.set_context({\"data.add1\": \"result\"})\n", "add2 = wg.add_task(add, name=\"add2\", x=2, y=2)\n", - "add2.set_context({\"result\": \"data.add2\"})\n", + "add2.set_context({\"data.add2\": \"result\"})\n", "# let sum task wait for add1 and add2, and the `data` in the context is ready\n", "sum3 = wg.add_task(sum, name=\"sum1\", datas=\"{{data}}\")\n", "sum3.waiting_on.add([\"add1\", \"add2\"])\n", diff --git a/docs/source/howto/while.ipynb b/docs/source/howto/while.ipynb index 1330679d..af3eff2e 100644 --- a/docs/source/howto/while.ipynb +++ b/docs/source/howto/while.ipynb @@ -126,7 +126,7 @@ "# set a context variable before running.\n", "wg.context = {\"should_run\": True}\n", "add1 = wg.add_task(add, name=\"add1\", x=1, y=1)\n", - "add1.set_context({\"result\": \"n\"})\n", + "add1.set_context({\"n\": \"result\"})\n", "#---------------------------------------------------------------------\n", "# Create the while tasks\n", "compare1 = wg.add_task(compare, name=\"compare1\", x=\"{{n}}\", y=50)\n", @@ -138,7 +138,7 @@ " x=add2.outputs[\"result\"],\n", " y=2)\n", "# update the context variable\n", - "multiply1.set_context({\"result\": \"n\"})\n", + "multiply1.set_context({\"n\": \"result\"})\n", "while1.children.add([\"add2\", \"multiply1\"])\n", "#---------------------------------------------------------------------\n", "add3 = wg.add_task(add, name=\"add3\", x=1, y=1)\n", @@ -845,7 +845,7 @@ " multiply1 = wg.add_task(multiply, name=\"multiply1\", x=add1.outputs[\"result\"],\n", " y=2)\n", " # update the context variable\n", - " 
multiply1.set_context({\"result\": \"n\"})\n", + " multiply1.set_context({\"n\": \"result\"})\n", " return wg" ] }, diff --git a/docs/source/tutorial/eos.ipynb b/docs/source/tutorial/eos.ipynb index 7e71d11c..c4fe5a26 100644 --- a/docs/source/tutorial/eos.ipynb +++ b/docs/source/tutorial/eos.ipynb @@ -48,7 +48,7 @@ " pw1 = wg.add_task(PwCalculation, name=f\"pw1_{key}\", structure=structure)\n", " pw1.set(scf_inputs)\n", " # save the output parameters to the context\n", - " pw1.set_context({\"output_parameters\": f\"result.{key}\"})\n", + " pw1.set_context({f\"result.{key}\": \"output_parameters\"})\n", " return wg\n", "\n", "\n", diff --git a/docs/source/tutorial/zero_to_hero.ipynb b/docs/source/tutorial/zero_to_hero.ipynb index 0bebdf09..a643a822 100644 --- a/docs/source/tutorial/zero_to_hero.ipynb +++ b/docs/source/tutorial/zero_to_hero.ipynb @@ -1224,7 +1224,7 @@ " pw1 = wg.add_task(PwCalculation, name=f\"pw1_{key}\", structure=structure)\n", " pw1.set(scf_inputs)\n", " # save the output parameters to the context\n", - " pw1.set_context({\"output_parameters\": f\"result.{key}\"})\n", + " pw1.set_context({f\"result.{key}\": \"output_parameters\"})\n", " return wg\n", "\n", "\n", diff --git a/examples/example.ipynb b/examples/example.ipynb deleted file mode 100644 index 41318e31..00000000 --- a/examples/example.ipynb +++ /dev/null @@ -1,109 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Example\n", - "\n", - "## Node from WorkGraph\n", - "Create a node from a WorkGraph" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['parameters']\n", - "['parameters', 'output_parameters']\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "a40ea66c06b14a798590cce3071ee554", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "NodeGraphWidget(settings={'minimap': True}, style={'width': '90%', 'height': '600px'}, value={'name': 'WorkGra…" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from workgraph_collections.qe.bands import bands_workgraph\n", - "from aiida import load_profile\n", - "load_profile()\n", - "\n", - "bands_wg = bands_workgraph(run_relax=True)\n", - "bands_wg" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "2ce656b7e3da43a6952bca4c95afa3eb", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "NodeGraphWidget(settings={'minmap': False}, style={'width': '40%', 'height': '600px'}, value={'nodes': {'PdosW…" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from aiida_quantumespresso.workflows.pdos import PdosWorkChain\n", - "from aiida_workgraph import WorkGraph\n", - "wg = WorkGraph()\n", - "wg.add_task(PdosWorkChain)\n", - "wg.tasks[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "aiida", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.0" - } - }, - 
"nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/examples_widget.ipynb b/examples/examples_widget.ipynb deleted file mode 100644 index 019d8d4c..00000000 --- a/examples/examples_widget.ipynb +++ /dev/null @@ -1,179 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## WorkGraph example" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "from aiida.engine import calcfunction\n", - "from aiida_workgraph import WorkGraph\n", - "\n", - "# define add calcfunction task\n", - "@calcfunction\n", - "def add(x, y):\n", - " return x + y\n", - "\n", - "# define multiply calcfunction task\n", - "@calcfunction\n", - "def multiply(x, y):\n", - " return x*y\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Create WorkGraph in one shot" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "wg = WorkGraph(\"test_add_multiply\")\n", - "wg.add_task(add, name=\"add1\")\n", - "wg.add_task(multiply, name=\"multiply1\")\n", - "wg.add_link(wg.tasks[\"add1\"].outputs[0], wg.tasks[\"multiply1\"].inputs[\"x\"])\n", - "wg" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "wg.tasks[\"multiply1\"].position\n", - "from aiida import load_profile, orm\n", - "load_profile()\n", - "wg.submit(wait=True, inputs = {\"add1\": {\"x\": 1, \"y\": 2}, \"multiply1\": {\"y\": 3}})\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Create the workgraph step by step\n", - "First, create a empty workgraph:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "from aiida_workgraph import WorkGraph\n", - "wg = WorkGraph(\"test_add_multiply\")\n", - "wg" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "wg.add_task(add, name=\"add1\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "wg.add_task(multiply, name=\"multiply1\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "wg.add_link(wg.tasks[\"add1\"].outputs[0], wg.tasks[\"multiply1\"].inputs[\"x\"])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "wg.tasks.delete(\"add1\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "wg.add_task(add, name=\"add1\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "wg.add_link(wg.tasks[\"add1\"].outputs[0], wg.tasks[\"multiply1\"].inputs[\"x\"])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "wg.links.delete(\"add1.result -> multiply1.x\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "wg.add_link(wg.tasks[\"add1\"].outputs[0], wg.tasks[\"multiply1\"].inputs[\"x\"])" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - 
"mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.0" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/pyproject.toml b/pyproject.toml index ffedd88d..18e33905 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,10 +28,11 @@ dependencies = [ "numpy~=1.21", "scipy", "ase", - "node-graph==0.1.2", + "node-graph==0.1.3", "aiida-core>=2.3", "cloudpickle", "aiida-shell~=0.8", + "aiida-pythonjob==0.1.3", "fastapi", "uvicorn", "pydantic_settings", @@ -75,26 +76,13 @@ workgraph = "aiida_workgraph.cli.cmd_workgraph:workgraph" [project.entry-points."aiida.cmdline"] "workgraph" = "aiida_workgraph.cli.cmd_workgraph:workgraph" -[project.entry-points."aiida.calculations"] -"workgraph.python" = "aiida_workgraph.calculations.python:PythonJob" - -[project.entry-points."aiida.parsers"] -"workgraph.python" = "aiida_workgraph.calculations.python_parser:PythonParser" - [project.entry-points.'aiida.workflows'] "workgraph.engine" = "aiida_workgraph.engine.workgraph:WorkGraphEngine" [project.entry-points."aiida.data"] -"workgraph.general" = "aiida_workgraph.orm.general_data:GeneralData" +"workgraph.pickled_data" = "aiida_workgraph.orm.pickled_data:PickledData" "workgraph.pickled_function" = "aiida_workgraph.orm.function_data:PickledFunction" "workgraph.pickled_local_function" = "aiida_workgraph.orm.function_data:PickledLocalFunction" -"workgraph.ase.atoms.Atoms" = "aiida_workgraph.orm.atoms:AtomsData" -"workgraph.builtins.int" = "aiida.orm.nodes.data.int:Int" -"workgraph.builtins.float" = "aiida.orm.nodes.data.float:Float" -"workgraph.builtins.str" = "aiida.orm.nodes.data.str:Str" -"workgraph.builtins.bool" = "aiida.orm.nodes.data.bool:Bool" -"workgraph.builtins.list"="aiida_workgraph.orm.general_data:List" -"workgraph.builtins.dict"="aiida_workgraph.orm.general_data:Dict" [project.entry-points."aiida.node"] @@ -106,8 +94,8 @@ workgraph = "aiida_workgraph.cli.cmd_workgraph:workgraph" "workgraph.if" = "aiida_workgraph.tasks.builtins:If" "workgraph.select" = "aiida_workgraph.tasks.builtins:Select" "workgraph.gather" = "aiida_workgraph.tasks.builtins:Gather" -"workgraph.to_context" = "aiida_workgraph.tasks.builtins:ToContext" -"workgraph.from_context" = "aiida_workgraph.tasks.builtins:FromContext" +"workgraph.set_context" = "aiida_workgraph.tasks.builtins:SetContext" +"workgraph.get_context" = "aiida_workgraph.tasks.builtins:GetContext" "workgraph.time_monitor" = "aiida_workgraph.tasks.monitors:TimeMonitor" "workgraph.file_monitor" = "aiida_workgraph.tasks.monitors:FileMonitor" "workgraph.task_monitor" = "aiida_workgraph.tasks.monitors:TaskMonitor" diff --git a/tests/test_ctx.py b/tests/test_ctx.py index 9ec1bee5..5f8b66bf 100644 --- a/tests/test_ctx.py +++ b/tests/test_ctx.py @@ -15,26 +15,26 @@ def test_workgraph_ctx(decorated_add: Callable) -> None: wg.context = {"x": Float(2), "data.y": Float(3), "array": array} add1 = wg.add_task(decorated_add, "add1", x="{{ x }}", y="{{ data.y }}") wg.add_task( - "workgraph.to_context", name="to_ctx1", key="x", value=add1.outputs["result"] + "workgraph.set_context", name="set_ctx1", key="x", value=add1.outputs["result"] ) - from_ctx1 = wg.add_task("workgraph.from_context", name="from_ctx1", key="x") + get_ctx1 = wg.add_task("workgraph.get_context", name="get_ctx1", key="x") # test the task can wait for another task - from_ctx1.waiting_on.add(add1) - add2 = wg.add_task(decorated_add, "add2", x=from_ctx1.outputs["result"], y=1) + get_ctx1.waiting_on.add(add1) + add2 = 
wg.add_task(decorated_add, "add2", x=get_ctx1.outputs["result"], y=1) wg.run() assert add2.outputs["result"].value == 6 -def test_node_to_ctx(decorated_add: Callable) -> None: +def test_task_set_ctx(decorated_add: Callable) -> None: """Set/get data to/from context.""" - wg = WorkGraph(name="test_node_to_ctx") + wg = WorkGraph(name="test_node_set_ctx") add1 = wg.add_task(decorated_add, "add1", x=Float(2).store(), y=Float(3).store()) try: - add1.set_context({"resul": "sum"}) + add1.set_context({"sum": "resul"}) except ValueError as e: assert str(e) == "Keys {'resul'} are not in the outputs of this task." - add1.set_context({"result": "sum"}) + add1.set_context({"sum": "result"}) add2 = wg.add_task(decorated_add, "add2", y="{{ sum }}") wg.add_link(add1.outputs[0], add2.inputs["x"]) wg.submit(wait=True) diff --git a/tests/test_data.py b/tests/test_data.py index 046aef3c..e22b28f9 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1,10 +1,11 @@ -def test_AtomsData(): - from aiida_workgraph.orm.atoms import AtomsData - from ase.build import bulk +def test_PickledData(): + from aiida_workgraph.orm.pickled_data import PickledData - atoms = bulk("Si") - data = AtomsData(atoms) - data.store() - assert data.value == atoms - assert data.base.attributes.get("formula") == "Si2" - assert data.base.attributes.get("pbc") == [True, True, True] + class CustomData: + def __init__(self, a): + self.a = a + + data = CustomData(a=1) + pickled_data = PickledData(data) + pickled_data.store() + assert pickled_data.value.a == 1 diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 2d96c708..ff2237e6 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -3,6 +3,18 @@ from typing import Callable +def test_custom_outputs(): + """Test custom outputs.""" + + @task(outputs=["sum", {"name": "product", "identifier": "workgraph.any"}]) + def add_multiply(x, y): + return {"sum": x + y, "product": x * y} + + n = add_multiply.task() + assert "sum" in n.outputs.keys() + assert "product" in n.outputs.keys() + + @pytest.fixture(params=["decorator_factory", "decorator"]) def task_calcfunction(request): if request.param == "decorator_factory": diff --git a/tests/test_for.py b/tests/test_for.py index 8f347ebc..fdd6fa47 100644 --- a/tests/test_for.py +++ b/tests/test_for.py @@ -21,7 +21,7 @@ def add_multiply_for(sequence): ) add1 = wg.add_task(decorated_add, name="add1", x="{{ total }}") # update the context variable - add1.set_context({"result": "total"}) + add1.set_context({"total": "result"}) wg.add_link(multiply1.outputs["result"], add1.inputs["y"]) # don't forget to return the workgraph return wg diff --git a/tests/test_python.py b/tests/test_pythonjob.py similarity index 92% rename from tests/test_python.py rename to tests/test_pythonjob.py index 014a52da..d11ccf3c 100644 --- a/tests/test_python.py +++ b/tests/test_pythonjob.py @@ -28,7 +28,7 @@ def multiply(x: Any, y: Any) -> Any: x=1, y=2, computer="localhost", - code_label=python_executable_path, + command_info={"label": python_executable_path}, ) wg.add_task( decorted_multiply, @@ -36,7 +36,7 @@ def multiply(x: Any, y: Any) -> Any: x=wg.tasks["add1"].outputs["sum"], y=3, computer="localhost", - code_label=python_executable_path, + command_info={"label": python_executable_path}, ) # wg.submit(wait=True) wg.run() @@ -59,7 +59,7 @@ def test_importable_function(fixture_localhost, python_executable_path): x=1, y=2, computer="localhost", - code_label=python_executable_path, + command_info={"label": python_executable_path}, ) wg.run() assert 
wg.tasks["add"].outputs["result"].value.value == 3 @@ -83,7 +83,7 @@ def add(x, y=1, **kwargs): "y": 2, "kwargs": {"m": 2, "n": 3}, "computer": "localhost", - "code_label": python_executable_path, + "command_info": {"label": python_executable_path}, }, }, ) @@ -151,7 +151,7 @@ def add(x, y): y=2, # code=code, computer="localhost", - code_label=python_executable_path, + command_info={"label": python_executable_path}, ) wg.run() assert wg.tasks["add"].outputs["sum"].value.value == 3 @@ -193,7 +193,7 @@ def myfunc(x, y): "x": 1.0, "y": 2.0, "computer": "localhost", - "code_label": python_executable_path, + "command_info": {"label": python_executable_path}, } }, ) @@ -251,17 +251,17 @@ def myfunc3(x, y): "x": 1.0, "y": 2.0, "computer": "localhost", - "code_label": python_executable_path, + "command_info": {"label": python_executable_path}, }, "myfunc2": { "y": 3.0, "computer": "localhost", - "code_label": python_executable_path, + "command_info": {"label": python_executable_path}, }, "myfunc3": { "y": 4.0, "computer": "localhost", - "code_label": python_executable_path, + "command_info": {"label": python_executable_path}, }, } wg.run(inputs=inputs) @@ -301,13 +301,13 @@ def multiply(x, y): "x": 2, "y": 3, "computer": "localhost", - "code_label": python_executable_path, + "command_info": {"label": python_executable_path}, }, "multiply": { "x": 3, "y": 4, "computer": "localhost", - "code_label": python_executable_path, + "command_info": {"label": python_executable_path}, }, }, wait=True, @@ -349,7 +349,7 @@ def add(): inputs={ "add": { "computer": "localhost", - "code_label": python_executable_path, + "command_info": {"label": python_executable_path}, "upload_files": { "input.txt": input_file, "inputs_folder": input_folder, @@ -404,19 +404,19 @@ def multiply(x_folder_name, y_folder_name): "x": 2, "y": 3, "computer": "localhost", - "code_label": python_executable_path, + "command_info": {"label": python_executable_path}, }, "add2": { "x": 2, "y": 3, "computer": "localhost", - "code_label": python_executable_path, + "command_info": {"label": python_executable_path}, }, "multiply": { "x_folder_name": "add1_remote_folder", "y_folder_name": "add2_remote_folder", "computer": "localhost", - "code_label": python_executable_path, + "command_info": {"label": python_executable_path}, }, }, wait=True, @@ -442,7 +442,7 @@ def add(x, y): "x": 2, "y": 3, "computer": "localhost", - "code_label": python_executable_path, + "command_info": {"label": python_executable_path}, "metadata": { "options": { "additional_retrieve_list": ["result.txt"], @@ -475,7 +475,7 @@ def make_supercell(atoms: Atoms, dim: int) -> Atoms: dim=2, name="make_supercell", computer="localhost", - code_label=python_executable_path, + command_info={"label": python_executable_path}, ) # ------------------------- Submit the calculation ------------------- wg.submit(wait=True) @@ -501,7 +501,7 @@ def add(x: str, y: str) -> str: "x": "Hello, ", "y": "World!", "computer": "localhost", - "code_label": python_executable_path, + "command_info": {"label": python_executable_path}, }, }, # wait=True, @@ -545,7 +545,7 @@ def add(x: array, y: array) -> array: x=array([1, 1]), y=array([1, -2]), computer="localhost", - code_label=python_executable_path, + command_info={"label": python_executable_path}, ) wg.run() # the first task should have exit status 410 diff --git a/tests/test_serializer.py b/tests/test_serializer.py index df509307..e69de29b 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -1,12 +0,0 @@ -import aiida - - -def 
test_python_job(): - """Test a simple python node.""" - from aiida_workgraph.orm import GeneralData, serialize_to_aiida_nodes - - inputs = {"a": 1, "b": 2.0, "c": set()} - new_inputs = serialize_to_aiida_nodes(inputs) - assert isinstance(new_inputs["a"], aiida.orm.Int) - assert isinstance(new_inputs["b"], aiida.orm.Float) - assert isinstance(new_inputs["c"], GeneralData) diff --git a/tests/test_socket.py b/tests/test_socket.py index 6e050158..fea72d4d 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -109,7 +109,7 @@ def test(a, b=1, **kwargs): test1 = test.node() assert test1.inputs["kwargs"].link_limit == 1e6 assert test1.inputs["kwargs"].identifier == "workgraph.namespace" - assert test1.inputs["kwargs"].property.value == {} + assert test1.inputs["kwargs"].property.value is None @pytest.mark.parametrize( diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 74ef19e9..d2c5b916 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -34,7 +34,7 @@ def test_build_task_from_workgraph( wg = WorkGraph("build_task_from_workgraph") add1_task = wg.add_task(decorated_add, name="add1", x=1, y=3) wg_task = wg.add_task(wg_calcfunction, name="wg_calcfunction") - assert wg_task.inputs["sumdiff1"].value == {} + assert wg_task.inputs["sumdiff1"].value is None wg.add_task(decorated_add, name="add2", y=3) wg.add_link(add1_task.outputs["result"], wg_task.inputs["sumdiff1.x"]) wg.add_link(wg_task.outputs["sumdiff2.sum"], wg.tasks["add2"].inputs["x"]) diff --git a/tests/test_utils.py b/tests/test_utils.py index 8e42db34..8ecaaab6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,66 +3,49 @@ from aiida_workgraph.utils import validate_task_inout +def test_validate_task_inout_empty_list(): + """Test validation with an empty list.""" + input_list = [] + result = validate_task_inout(input_list, "inputs") + assert result == [] + + def test_validate_task_inout_str_list(): """Test validation with a list of strings.""" input_list = ["task1", "task2"] - result = validate_task_inout(input_list, "input") + result = validate_task_inout(input_list, "inputs") assert result == [{"name": "task1"}, {"name": "task2"}] def test_validate_task_inout_dict_list(): """Test validation with a list of dictionaries.""" input_list = [{"name": "task1"}, {"name": "task2"}] - result = validate_task_inout(input_list, "input") + result = validate_task_inout(input_list, "inputs") assert result == input_list -@pytest.mark.parametrize( - "input_list, list_type, expected_error", - [ - # Mixed types error cases - ( - ["task1", {"name": "task2"}], - "input", - "Provide either a list of `str` or `dict` as `input`, not mixed types.", - ), - ( - [{"name": "task1"}, "task2"], - "output", - "Provide either a list of `str` or `dict` as `output`, not mixed types.", - ), - # Empty list cases - ([], "input", None), - ([], "output", None), - ], -) -def test_validate_task_inout_mixed_types(input_list, list_type, expected_error): - """Test error handling for mixed type lists.""" - if expected_error: - with pytest.raises(TypeError) as excinfo: - validate_task_inout(input_list, list_type) - assert str(excinfo.value) == expected_error - else: - # For empty lists, no error should be raised - result = validate_task_inout(input_list, list_type) - assert result == [] +def test_validate_task_inout_mixed_list(): + """Test validation with a mixed list of strings and dictionaries.""" + input_list = ["task1", {"name": "task2"}] + result = validate_task_inout(input_list, "inputs") + assert result == [{"name": "task1"}, {"name": "task2"}]
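The reworked `test_utils.py` pins down the new behaviour of `validate_task_inout`: plain strings are normalised to `{"name": ...}` dicts, mixed string/dict lists are now accepted, and invalid element types raise a `TypeError` whose message contains "Wrong type provided". A quick sketch of that contract, mirroring the asserts above (anything beyond them is an assumption):

```python
from aiida_workgraph.utils import validate_task_inout

# strings are wrapped into {"name": ...} dicts, existing dicts pass through
assert validate_task_inout(["task1", {"name": "task2"}], "inputs") == [
    {"name": "task1"},
    {"name": "task2"},
]

# element types other than str/dict are rejected
try:
    validate_task_inout([1, 2, 3], "inputs")
except TypeError as exc:
    assert "Wrong type provided" in str(exc)
```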
@pytest.mark.parametrize( "input_list, list_type", [ # Invalid type cases - ([1, 2, 3], "input"), - ([None, None], "output"), - ([True, False], "input"), - (["task", 123], "output"), + ([1, 2, 3], "inputs"), + ([None, None], "outputs"), + ([True, False], "inputs"), + (["task", 123], "outputs"), ], ) def test_validate_task_inout_invalid_types(input_list, list_type): """Test error handling for completely invalid type lists.""" with pytest.raises(TypeError) as excinfo: validate_task_inout(input_list, list_type) - assert "Provide either a list of" in str(excinfo.value) + assert "Wrong type provided" in str(excinfo.value) def test_validate_task_inout_dict_with_extra_keys(): @@ -71,5 +54,5 @@ def test_validate_task_inout_dict_with_extra_keys(): {"name": "task1", "description": "first task"}, {"name": "task2", "priority": "high"}, ] - result = validate_task_inout(input_list, "input") + result = validate_task_inout(input_list, "inputs") assert result == input_list diff --git a/tests/test_while.py b/tests/test_while.py index 90f5280d..6c677e85 100644 --- a/tests/test_while.py +++ b/tests/test_while.py @@ -43,7 +43,7 @@ def raw_python_code(): "l": 1, } add1 = wg.add_task(decorated_add, name="add1", x=1, y=1) - add1.set_context({"result": "n"}) + add1.set_context({"n": "result"}) # --------------------------------------------------------------------- # the `result` of compare1 taskis used as condition compare1 = wg.add_task(decorated_compare, name="compare1", x="{{m}}", y=10) @@ -57,7 +57,7 @@ def raw_python_code(): ) add21.waiting_on.add("add1") add22 = wg.add_task(decorated_add, name="add22", x=add21.outputs["result"], y=1) - add22.set_context({"result": "n"}) + add22.set_context({"n": "result"}) while2.children.add(["add21", "add22"]) # --------------------------------------------------------------------- compare3 = wg.add_task(decorated_compare, name="compare3", x="{{l}}", y=5) @@ -67,13 +67,13 @@ def raw_python_code(): add31 = wg.add_task(decorated_add, name="add31", x="{{l}}", y=1) add31.waiting_on.add("add22") add32 = wg.add_task(decorated_add, name="add32", x=add31.outputs["result"], y=1) - add32.set_context({"result": "l"}) + add32.set_context({"l": "result"}) while3.children.add(["add31", "add32"]) # --------------------------------------------------------------------- add12 = wg.add_task( decorated_add, name="add12", x="{{m}}", y=add32.outputs["result"] ) - add12.set_context({"result": "m"}) + add12.set_context({"m": "result"}) while1.children.add(["add11", "while2", "while3", "add12", "compare2", "compare3"]) # --------------------------------------------------------------------- add2 = wg.add_task( @@ -101,7 +101,7 @@ def test_while_workgraph(decorated_add, decorated_multiply, decorated_compare): decorated_multiply, name="multiply1", x="{{ n }}", y=orm.Int(2) ) add1 = wg.add_task(decorated_add, name="add1", y=3) - add1.set_context({"result": "n"}) + add1.set_context({"n": "result"}) wg.add_link(multiply1.outputs["result"], add1.inputs["x"]) wg.submit(wait=True, timeout=100) assert wg.execution_count == 4 @@ -125,7 +125,7 @@ def my_while(n=0, limit=100): decorated_multiply, name="multiply1", x="{{ n }}", y=orm.Int(2) ) add1 = wg.add_task(decorated_add, name="add1", y=3) - add1.set_context({"result": "n"}) + add1.set_context({"n": "result"}) wg.add_link(multiply1.outputs["result"], add1.inputs["x"]) return wg
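Across all the notebooks and tests touched by this patch, the `set_context` migration follows one rule: the mapping is now `{context_key: output_socket_name}` rather than the reverse. A small sketch of the pattern with dynamically built keys, as used in the parallel and EOS examples (the `add` task, the loop values, and the workgraph name are illustrative):

```python
from aiida_workgraph import WorkGraph, task

@task.calcfunction()
def add(x, y):
    return x + y

wg = WorkGraph("dynamic_context_keys")
for key, value in {"a": 1, "b": 2, "c": 3}.items():
    add1 = wg.add_task(add, name=f"add_{key}", x=value, y=10)
    # context key (here built with an f-string) first, output socket name second
    add1.set_context({f"sum.{key}": "result"})
# after running, context.sum collects one result per key, e.g. {"a": ..., "b": ..., "c": ...}
```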