Skip to content

Commit

Permalink
Changes for pydantic v2
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Jul 19, 2024
1 parent f3f6dac commit cc642d0
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 22 deletions.
14 changes: 9 additions & 5 deletions aiida_submission_controller/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from aiida import engine, orm
from aiida.common import NotExistent
from pydantic import BaseModel, validator
from pydantic import BaseModel, validator, field_validator
from rich import print
from rich.console import Console
from rich.table import Table
Expand All @@ -33,7 +33,7 @@ def add_to_nested_dict(nested_dict, key, value):
return extras_dict


def validate_group_exists(value: str) -> str:
def _validate_group_exists(value: str) -> str:
"""Validator that makes sure the ``Group`` with the provided label exists."""
try:
orm.Group.collection.get(label=value)
Expand All @@ -56,12 +56,15 @@ class BaseSubmissionController(BaseModel):
unique_extra_keys: Optional[tuple] = None
"""Tuple of keys defined in the extras that uniquely define each process to be run."""

_validate_group_exists = validator("group_label", allow_reuse=True)(validate_group_exists)
@field_validator('group_label')
@classmethod
def validate_group_exists(cls, v: str) -> str:
return _validate_group_exists(v)

@property
def group(self):
"""Return the AiiDA ORM Group instance that is managed by this class."""
return orm.Group.objects.get(label=self.group_label)
return orm.Group.collection.get(label=self.group_label)

def get_query(self, process_projections, only_active=False):
"""Return a QueryBuilder object to get all processes in the group associated to this.
Expand Down Expand Up @@ -233,10 +236,11 @@ def submit_new_batch(self, dry_run=False, sort=False, verbose=False):

except Exception as exc:
CMDLINE_LOGGER.error(f"Failed to submit work chain for extras <{workchain_extras}>: {exc}")
raise
else:
CMDLINE_LOGGER.report(f"Submitted work chain <{wc_node}> for extras <{workchain_extras}>.")

wc_node.set_extra_many(get_extras_dict(self.get_extra_unique_keys(), workchain_extras))
wc_node.base.extras.set_many(get_extras_dict(self.get_extra_unique_keys(), workchain_extras))
self.group.add_nodes([wc_node])
submitted[workchain_extras] = wc_node

Expand Down
33 changes: 18 additions & 15 deletions aiida_submission_controller/from_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from typing import Optional

from aiida import orm
from pydantic import validator
from pydantic import field_validator

from .base import BaseSubmissionController, validate_group_exists
from .base import BaseSubmissionController, _validate_group_exists


class FromGroupSubmissionController(BaseSubmissionController): # pylint: disable=abstract-method
Expand All @@ -22,27 +22,30 @@ class FromGroupSubmissionController(BaseSubmissionController): # pylint: disabl
order_by: Optional[dict] = None
"""Ordering applied to the query of the nodes in the parent group."""

_validate_group_exists = validator("parent_group_label", allow_reuse=True)(validate_group_exists)
@field_validator('group_label')
@classmethod
def validate_group_exists(cls, v: str) -> str:
return _validate_group_exists(v)

@property
def parent_group(self):
"""Return the AiiDA ORM Group instance of the parent group."""
return orm.Group.objects.get(label=self.parent_group_label)
return orm.Group.collection.get(label=self.parent_group_label)

def get_parent_node_from_extras(self, extras_values):
"""Return the Node instance (in the parent group) from the (unique) extras identifying it."""
extras_projections = self.get_process_extra_projections()
assert len(extras_values) == len(extras_projections), f"The extras must be of length {len(extras_projections)}"
filters = dict(zip(extras_projections, extras_values))

qbuild = orm.QueryBuilder()
qbuild.append(orm.Group, filters={"id": self.parent_group.pk}, tag="group")
qbuild.append(orm.Node, project="*", filters=filters, tag="process", with_group="group")
qbuild.limit(2)
results = qbuild.all(flat=True)
qb = orm.QueryBuilder()
qb.append(orm.Group, filters={"id": self.parent_group.pk}, tag="group")
qb.append(orm.Node, project="*", filters=filters, tag="process", with_group="group")
qb.limit(2)
results = qb.all(flat=True)
if len(results) != 1:
raise ValueError(
"I would have expected only 1 result for extras={extras}, I found {'>1' if len(qbuild) else '0'}"
"I would have expected only 1 result for extras={extras}, I found {'>1' if len(qb) else '0'}"
)
return results[0]

Expand All @@ -57,9 +60,9 @@ def get_all_extras_to_submit(self):
"""
extras_projections = self.get_process_extra_projections()

qbuild = orm.QueryBuilder()
qbuild.append(orm.Group, filters={"id": self.parent_group.pk}, tag="group")
qbuild.append(
qb = orm.QueryBuilder()
qb.append(orm.Group, filters={"id": self.parent_group.pk}, tag="group")
qb.append(
orm.Node,
project=extras_projections,
filters=self.filters,
Expand All @@ -68,9 +71,9 @@ def get_all_extras_to_submit(self):
)

if self.order_by is not None:
qbuild.order_by(self.order_by)
qb.order_by(self.order_by)

results = qbuild.all()
results = qb.all()

# I return a set of results as required by the API
# First, however, convert to a list of tuples otherwise
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ classifiers = [
requires-python = ">=3.6"

dependencies = [
"aiida-core>=1.0",
"pydantic~=1.10.4",
"aiida-core~=2.5",
"rich",
]

Expand Down

0 comments on commit cc642d0

Please sign in to comment.