Skip to content

Commit

Permalink
Implement skip_orm option for SqlAlchemy Group.remove_nodes (#4214)
Browse files Browse the repository at this point in the history
The current implementation of `Group.remove_nodes` is very slow. For a
group of a few tens of thousands of nodes, removing a thousand can take
more than a day. The same problem exists for `add_nodes` which is why a
shortcut was added to the backend implementation for SqlAlchemy. Here,
we do the same for `remove_nodes`. The `SqlaGroup.remove_nodes` now
accepts a keyword argument `skip_orm` that, when True, will delete the
nodes by directly constructing a delete query on the join table.
  • Loading branch information
sphuber authored Jul 24, 2020
1 parent 3a4eff7 commit bced84e
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 10 deletions.
39 changes: 29 additions & 10 deletions aiida/orm/implementation/sqlalchemy/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,37 +228,56 @@ def check_node(given_node):
# Commit everything as up till now we've just flushed
session.commit()

def remove_nodes(self, nodes):
def remove_nodes(self, nodes, **kwargs):
"""Remove a node or a set of nodes from the group.
:note: all the nodes *and* the group itself have to be stored.
:param nodes: a list of `BackendNode` instance to be added to this group
:param kwargs:
skip_orm: When the flag is set to `True`, the SQLA ORM is skipped and SQLA is used to create a direct SQL
DELETE statement to the group-node relationship table in order to improve speed.
"""
from sqlalchemy import and_
from aiida.backends.sqlalchemy import get_scoped_session
from aiida.backends.sqlalchemy.models.base import Base
from aiida.orm.implementation.sqlalchemy.nodes import SqlaNode

super().remove_nodes(nodes)

# Get dbnodes here ONCE, otherwise each call to dbnodes will re-read the current value in the database
dbnodes = self._dbmodel.dbnodes
skip_orm = kwargs.get('skip_orm', False)

list_nodes = []

for node in nodes:
def check_node(node):
if not isinstance(node, SqlaNode):
raise TypeError('invalid type {}, has to be {}'.format(type(node), SqlaNode))

if node.id is None:
raise ValueError('At least one of the provided nodes is unstored, stopping...')

# If we don't check first, SqlA might issue a DELETE statement for an unexisting key, resulting in an error
if node.dbmodel in dbnodes:
list_nodes.append(node.dbmodel)
list_nodes = []

for node in list_nodes:
dbnodes.remove(node)
with utils.disable_expire_on_commit(get_scoped_session()) as session:
if not skip_orm:
for node in nodes:
check_node(node)

# Check first, if SqlA issues a DELETE statement for an unexisting key it will result in an error
if node.dbmodel in dbnodes:
list_nodes.append(node.dbmodel)

for node in list_nodes:
dbnodes.remove(node)
else:
table = Base.metadata.tables['db_dbgroup_dbnodes']
for node in nodes:
check_node(node)
clause = and_(table.c.dbnode_id == node.id, table.c.dbgroup_id == self.id)
statement = table.delete().where(clause)
session.execute(statement)

sa.get_scoped_session().commit()
session.commit()


class SqlaGroupCollection(BackendGroupCollection):
Expand Down
30 changes: 30 additions & 0 deletions tests/backends/aiida_sqlalchemy/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,33 @@ def test_group_batch_size(self):
group = Group(label='test_batches_' + str(batch_size)).store()
group.backend_entity.add_nodes(nodes, skip_orm=True, batch_size=batch_size)
self.assertEqual(set(_.pk for _ in nodes), set(_.pk for _ in group.nodes))

def test_remove_nodes_bulk(self):
"""Test node removal."""
backend = self.backend

node_01 = Data().store().backend_entity
node_02 = Data().store().backend_entity
node_03 = Data().store().backend_entity
node_04 = Data().store().backend_entity
nodes = [node_01, node_02, node_03]
group = backend.groups.create(label='test_remove_nodes', user=backend.users.create('[email protected]')).store()

# Add initial nodes
group.add_nodes(nodes)
self.assertEqual(set(_.pk for _ in nodes), set(_.pk for _ in group.nodes))

# Remove a node that is not in the group: nothing should happen
group.remove_nodes([node_04], skip_orm=True)
self.assertEqual(set(_.pk for _ in nodes), set(_.pk for _ in group.nodes))

# Remove one Node
nodes.remove(node_03)
group.remove_nodes([node_03], skip_orm=True)
self.assertEqual(set(_.pk for _ in nodes), set(_.pk for _ in group.nodes))

# Remove a list of Nodes and check
nodes.remove(node_01)
nodes.remove(node_02)
group.remove_nodes([node_01, node_02], skip_orm=True)
self.assertEqual(set(_.pk for _ in nodes), set(_.pk for _ in group.nodes))

0 comments on commit bced84e

Please sign in to comment.