Skip to content

Commit

Permalink
📚 Update PwBaseWorkChain example
Browse files Browse the repository at this point in the history
  • Loading branch information
mbercx committed Jan 19, 2024
1 parent cf9d6f0 commit 7af2c12
Showing 1 changed file with 48 additions and 141 deletions.
189 changes: 48 additions & 141 deletions examples/pw_base.py
Original file line number Diff line number Diff line change
@@ -1,157 +1,64 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""An example of a SubmissionController implementation for a small set of PwBaseWorkChains."""

import typing as ty
import warnings

from aiida import load_profile, orm, plugins
from qe_tools import CONSTANTS

from aiida_submission_controller import BaseSubmissionController


class PwBaseSubmissionController(BaseSubmissionController):
"""The implementation of a SubmissionController to run a small set of PwBaseWorkChains."""

WORKFLOW_ENTRY_POINT = "quantumespresso.pw.base"

def __init__(
self,
pw_code_id: ty.Union[str, int],
structure_group_id: ty.Union[str, int],
pseudo_family_id: ty.Union[str, int],
*args,
structure_filters: ty.Optional[ty.Dict[str, ty.Any]] = None,
**kwargs,
):
"""A SubmissionController for PwBaseWorkChains."""
super().__init__(*args, **kwargs)
self._code = orm.load_code(identifier=pw_code_id)
self._process_class = plugins.WorkflowFactory(self.WORKFLOW_ENTRY_POINT)
self._structure_group = orm.load_group(identifier=structure_group_id)
self._structure_filters = structure_filters if structure_filters is not None else {}
self._pseudo_family = orm.load_group(identifier=pseudo_family_id)

def get_extra_unique_keys(self) -> ty.Tuple[str]:
"""Return a tuple of the extra key or keys used to uniquely identify your workchains."""
return ("mpid",)

def get_all_extras_to_submit(self) -> ty.Set[ty.Tuple[str]]:
"""Return a set of all the unique extras to submit."""
pseudo_family_elements = set(self._pseudo_family.elements)

qbuild = orm.QueryBuilder()
qbuild.append(orm.Group, filters={"label": self._structure_group.label}, tag="group")
qbuild.append(
orm.StructureData,
project=["extras.mpid", "attributes.kinds"],
tag="structure",
with_group="group",
filters={"extras": {"has_key": "mpid"}, **self._structure_filters},
)
res = qbuild.all()

all_extras = []
for mpid, kinds in res:
kind_names = set(kind["name"] for kind in kinds)
if kind_names.issubset(pseudo_family_elements):
all_extras.append((mpid,))
all_extras = set(all_extras)

# all_extras = set((mpid,) for mpid in qb.all(flat=True))
return all_extras

def _get_structure_from_extras(self, extras_values: ty.Tuple[str]) -> orm.StructureData:
"""Get a structure from the values of the extras."""
qbuild = orm.QueryBuilder()
qbuild.append(orm.Group, filters={"label": self._structure_group.label}, tag="group")
qbuild.append(
orm.StructureData,
project="*",
tag="structure",
with_group="group",
filters={"extras.mpid": extras_values[0]},
import time

from aiida import load_profile, orm
from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain
from ase.build import bulk

from aiida_submission_controller import FromGroupSubmissionController


class PwBaseSubmissionController(FromGroupSubmissionController):
"""SubmissionController to run PwBaseWorkChains from a group of `StructureData` nodes."""

pw_code: str
"""The label of the `Code` to use for the `PwBaseWorkChain`."""
overrides: dict = {}
"""A dictionary of overrides to pass to `PwBaseWorkChain.get_builder_from_protocol()`."""

def get_inputs_and_processclass_from_extras(self, extras_values):
parent_node = self.get_parent_node_from_extras(extras_values)

if not isinstance(parent_node, orm.StructureData):
raise ValueError("The parent node is not a StructureData node.")

builder = PwBaseWorkChain.get_builder_from_protocol(
code=orm.load_code(self.pw_code),
structure=parent_node,
overrides=self.overrides,
)
structure = qbuild.all(flat=True)[0]
return structure

def get_inputs_and_processclass_from_extras(self, extras_values: ty.Tuple[str]):
"""Construct the inputs and get the process class from the values of the uniquely identifying extras."""
structure = self._get_structure_from_extras(extras_values)
pseudos = self._pseudo_family.get_pseudos(structure=structure)
ecutwfc, ecutrho = self._pseudo_family.get_recommended_cutoffs(structure=structure)
metadata = {
"options": {
"resources": {"num_machines": 1, "num_mpiprocs_per_machine": 1},
"max_wallclock_seconds": 2 * 60,
"withmpi": True,
}
}

inputs = {
"clean_workdir": orm.Bool(True),
"kpoints_distance": orm.Float(0.25),
"pw": {
"structure": structure,
"metadata": metadata,
"code": self._code,
"pseudos": pseudos,
"parameters": orm.Dict(
dict={
"CONTROL": {"calculation": "scf", "verbosity": "low"},
"SYSTEM": {
"ecutwfc": ecutwfc,
"ecutrho": ecutrho,
"nosym": False,
"occupations": "smearing",
"smearing": "gaussian",
"degauss": 0.5 / CONSTANTS.ry_to_ev,
},
"ELECTRONS": {
"conv_thr": 1e-8,
"mixing_beta": 4e-1,
"electron_maxstep": 80,
},
}
),
},
}

return inputs, self._process_class
return builder


def main():
"""Main execution when called as a script."""
warnings.filterwarnings("ignore")

profile = "asc"
pseudo_family_id = "SSSP/1.1/PBE/efficiency"
load_profile()

load_profile(profile)
# To make the example easier to execute, the parent structure group and the workchain group are created here
# and the structure group is populated with some example structures in case it does not exist yet.
structure_group, created = orm.Group.collection.get_or_create("structures")
workchain_group, _ = orm.Group.collection.get_or_create("workchain/base")

if created:
for cell_size in (3.9, 4.0, 4.1, 4.2, 4.3, 4.4):
structure = orm.StructureData(ase=bulk("Al", a=cell_size, cubic=True))
structure.store()
structure_group.add_nodes(structure)

controller = PwBaseSubmissionController(
pw_code_id="pw-6.7MaX_conda",
structure_group_id="structures/mp/2018_10_18",
structure_filters={
"attributes.sites": {"longer": 0, "shorter": 3},
},
pseudo_family_id=pseudo_family_id,
group_label="tests/pw_base",
max_concurrent=2,
unique_extra_keys=("_aiida_hash",),
parent_group_label=structure_group.label,
group_label=workchain_group.label,
max_concurrent=1,
pw_code="pw@localhost", # Replace with the label of a code configured for Quantum ESPRESSO pw.x
)

print("Max concurrent :", controller.max_concurrent)
print("Active slots :", controller.num_active_slots)
print("Available slots:", controller.num_available_slots)
print("Already run :", controller.num_already_run)
print("Still to run :", controller.num_to_run)
print()

print("Submitting...")
run_processes = controller.submit_new_batch(dry_run=False)
for run_process_extras, run_process in run_processes.items():
print(f"{run_process_extras} --> <{run_process}>")
print("Done.")
while True:
controller.submit_new_batch(verbose=True)
time.sleep(30)


if __name__ == "__main__":
Expand Down

0 comments on commit 7af2c12

Please sign in to comment.