Skip to content

Commit

Permalink
🐛 Allow specifying nested extras
Browse files Browse the repository at this point in the history
Currently, any nested extras can not be used for the `BaseSubmissionController`.
Using e.g. `source.database` as one of the extras would fail, since when
adding the extras to the submitted work chain the controller attempts to
directly pass the `source.database` string as an extra key, which is not
accepted by AiiDA.

Here add the `get_extras_dict` utility function to convert the list of
(possibly nested) extras into a proper dictionary to pass to the
`set_extra_many` call on the work chain. This makes using nested extras for the
submission controllers possible.
  • Loading branch information
mbercx committed Dec 20, 2023
1 parent 201cdac commit fcfd545
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions aiida_submission_controller/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,25 @@
CMDLINE_LOGGER = logging.getLogger("verdi")


def get_extras_dict(extras_keys, workchain_extras):
"""Return a dictionary of extras from a list of keys and a list of values."""

def add_to_nested_dict(nested_dict, key, value):
if "." in key:
first_key, remaining_keys = key.split(".", 1)
nested_dict.setdefault(first_key, {})
add_to_nested_dict(nested_dict[first_key], remaining_keys, value)
else:
nested_dict.setdefault(key, value)

extras_dict = {}

for key, value in zip(extras_keys, workchain_extras):
add_to_nested_dict(extras_dict, key, value)

return extras_dict


def validate_group_exists(value: str) -> str:
"""Validator that makes sure the ``Group`` with the provided label exists."""
try:
Expand Down Expand Up @@ -216,8 +235,8 @@ def submit_new_batch(self, dry_run=False, sort=False, verbose=False):
CMDLINE_LOGGER.error(f"Failed to submit work chain for extras <{workchain_extras}>: {exc}")
else:
CMDLINE_LOGGER.report(f"Submitted work chain <{wc_node}> for extras <{workchain_extras}>.")
# Add extras, and put in group
wc_node.set_extra_many(dict(zip(self.get_extra_unique_keys(), workchain_extras)))

wc_node.set_extra_many(get_extras_dict(self.get_extra_unique_keys(), workchain_extras))
self.group.add_nodes([wc_node])
submitted[workchain_extras] = wc_node

Expand Down

0 comments on commit fcfd545

Please sign in to comment.