diff --git a/src/aiidalab_qe/app/result/__init__.py b/src/aiidalab_qe/app/result/__init__.py
index e81a80e50..eabe8469d 100644
--- a/src/aiidalab_qe/app/result/__init__.py
+++ b/src/aiidalab_qe/app/result/__init__.py
@@ -121,7 +121,7 @@ def _post_render(self):
self.toggle_controls.value = "Summary"
self.process_monitor = ProcessMonitor(
- timeout=0.2,
+ timeout=0.5,
callbacks=[
self._update_status,
self._update_state,
diff --git a/src/aiidalab_qe/app/result/components/viewer/structure/model.py b/src/aiidalab_qe/app/result/components/viewer/structure/model.py
index 3d56ab48a..eb9526357 100644
--- a/src/aiidalab_qe/app/result/components/viewer/structure/model.py
+++ b/src/aiidalab_qe/app/result/components/viewer/structure/model.py
@@ -9,12 +9,15 @@ class StructureResultsModel(ResultsModel):
source = None
+ @property
+ def include(self):
+ return True
+
def update(self):
- super().update()
is_relaxed = "relax" in self.properties
self.title = "Relaxed structure" if is_relaxed else "Initial structure"
self.source = self.outputs if is_relaxed else self.inputs
- self.auto_render = not is_relaxed # auto-render initial structure
+ self.auto_render = not is_relaxed or self.has_results
def get_structure(self):
try:
diff --git a/src/aiidalab_qe/app/result/components/viewer/structure/structure.py b/src/aiidalab_qe/app/result/components/viewer/structure/structure.py
index 93a95e5c4..3cae7cb85 100644
--- a/src/aiidalab_qe/app/result/components/viewer/structure/structure.py
+++ b/src/aiidalab_qe/app/result/components/viewer/structure/structure.py
@@ -9,7 +9,7 @@ def _render(self):
if not hasattr(self, "widget"):
structure = self._model.get_structure()
self.widget = StructureDataViewer(structure=structure)
- self.children = [self.widget]
+ self.results_container.children = [self.widget]
# HACK to resize the NGL viewer in cases where it auto-rendered when its
# container was not displayed, which leads to a null width. This hack restores
diff --git a/src/aiidalab_qe/app/result/components/viewer/viewer.py b/src/aiidalab_qe/app/result/components/viewer/viewer.py
index 288671e51..466daeedf 100644
--- a/src/aiidalab_qe/app/result/components/viewer/viewer.py
+++ b/src/aiidalab_qe/app/result/components/viewer/viewer.py
@@ -12,10 +12,15 @@
class WorkChainResultsViewer(ResultsComponent[WorkChainResultsViewerModel]):
def __init__(self, model: WorkChainResultsViewerModel, **kwargs):
- super().__init__(model=model, **kwargs)
+ # NOTE: here we want to add the structure and plugin models to the viewer
+ # model BEFORE we define the observation of the process uuid. This ensures
+ # that when the process changes, its reflected in the sub-models prior to
+ # the logic of the process change event handler.
+ # TODO avoid exceptions! Ensure sub-model synchronization in general!
self.panels: dict[str, ResultsPanel] = {}
- self._add_structure_panel() # TODO consider refactoring structure panel as a plugin
- self._fetch_plugin_results()
+ self._add_structure_panel(model)
+ self._fetch_plugin_results(model)
+ super().__init__(model=model, **kwargs)
def _on_process_change(self, _):
self._update_panels()
@@ -40,14 +45,10 @@ def _post_render(self):
self._set_tabs()
def _update_panels(self):
- properties = self._model.properties
- need_electronic_structure = "bands" in properties and "pdos" in properties
self.panels = {
- identifier: panel
- for identifier, panel in self.panels.items()
- if identifier == "structure"
- or identifier in properties
- or (identifier == "electronic_structure" and need_electronic_structure)
+ identifier: self.panels[identifier]
+ for identifier, model in self._model.get_models()
+ if model.include
}
def _set_tabs(self):
@@ -65,18 +66,18 @@ def _set_tabs(self):
if children:
self.tabs.selected_index = 0
- def _add_structure_panel(self):
+ def _add_structure_panel(self, viewer_model: WorkChainResultsViewerModel):
structure_model = StructureResultsModel()
- structure_model.process_uuid = self._model.process_uuid
+ structure_model.process_uuid = viewer_model.process_uuid
self.structure_results = StructureResultsPanel(model=structure_model)
identifier = structure_model.identifier
- self._model.add_model(identifier, structure_model)
+ viewer_model.add_model(identifier, structure_model)
self.panels = {
identifier: self.structure_results,
**self.panels,
}
- def _fetch_plugin_results(self):
+ def _fetch_plugin_results(self, viewer_model: WorkChainResultsViewerModel):
entries = get_entry_items("aiidalab_qe.properties", "result")
for identifier, entry in entries.items():
for key in ("panel", "model"):
@@ -86,7 +87,7 @@ def _fetch_plugin_results(self):
)
panel = entry["panel"]
model = entry["model"]()
- self._model.add_model(identifier, model)
+ viewer_model.add_model(identifier, model)
self.panels[identifier] = panel(
identifier=identifier,
model=model,
diff --git a/src/aiidalab_qe/app/static/styles/custom.css b/src/aiidalab_qe/app/static/styles/custom.css
index 89975761d..3a0e85130 100644
--- a/src/aiidalab_qe/app/static/styles/custom.css
+++ b/src/aiidalab_qe/app/static/styles/custom.css
@@ -108,3 +108,7 @@ footer {
.p-Accordion-child:has(.qe-app-step-fail) > .p-Collapse-header {
background-color: var(--color-failed);
}
+
+.p-TabBar-tab {
+ min-width: fit-content !important;
+}
diff --git a/src/aiidalab_qe/common/bands_pdos/bandpdoswidget.py b/src/aiidalab_qe/common/bands_pdos/bandpdoswidget.py
index e93653977..123548387 100644
--- a/src/aiidalab_qe/common/bands_pdos/bandpdoswidget.py
+++ b/src/aiidalab_qe/common/bands_pdos/bandpdoswidget.py
@@ -7,31 +7,34 @@
class BandsPdosWidget(ipw.VBox):
- """
- A widget for plotting band structure and projected density of states (PDOS) data.
+ """A widget for plotting band structure and projected density of states (PDOS).
Parameters
----------
- - bands (optional): A node containing band structure data.
- - pdos (optional): A node containing PDOS data.
+ `model`: `BandsPdosModel`
+ The MVC model containing the data and logic for the widget.
Attributes
----------
- - description: HTML description of the widget.
- - dos_atoms_group: Dropdown widget to select the grouping of atoms for PDOS plotting.
- - dos_plot_group: Dropdown widget to select the type of PDOS contributions to plot.
- - selected_atoms: Text widget to select specific atoms for PDOS plotting.
- - update_plot_button: Button widget to update the plot.
- - download_button: Button widget to download the data.
- - project_bands_box: Checkbox widget to choose whether projected bands should be plotted.
- - plot_widget: Plotly widget for band structure and PDOS plot.
- - bands_widget: Output widget to display the bandsplot widget.
+ `description`: `ipywidgets.HTML`
+ HTML description of the widget.
+ `dos_atoms_group`: `ipywidgets.Dropdown`
+ Dropdown widget to select the grouping of atoms for PDOS plotting.
+ `dos_plot_group`: `ipywidgets.Dropdown`
+ Dropdown widget to select the type of PDOS contributions to plot.
+ `selected_atoms`: `ipywidgets.Text`
+ Text widget to select specific atoms for PDOS plotting.
+ `update_plot_button`: `ipywidgets.Button`
+ Button widget to update the plot.
+ `download_button`: `ipywidgets.Button`
+ Button widget to download the data.
+ `project_bands_box`: `ipywidgets.Checkbox`
+ Checkbox widget to choose whether projected bands should be plotted.
+ `plot`: `plotly.graph_objects.FigureWidget`
+ Plotly widget for band structure and PDOS plot.
"""
- def __init__(self, model: BandsPdosModel, bands=None, pdos=None, **kwargs):
- if bands is None and pdos is None:
- raise ValueError("Either bands or pdos must be provided")
-
+ def __init__(self, model: BandsPdosModel, **kwargs):
super().__init__(
children=[LoadingWidget("Loading widgets")],
**kwargs,
@@ -49,9 +52,6 @@ def __init__(self, model: BandsPdosModel, bands=None, pdos=None, **kwargs):
self.rendered = False
- self._model.bands = bands
- self._model.pdos = pdos
-
def render(self):
if self.rendered:
return
diff --git a/src/aiidalab_qe/common/bands_pdos/model.py b/src/aiidalab_qe/common/bands_pdos/model.py
index 690ef10ca..c9f2b72e2 100644
--- a/src/aiidalab_qe/common/bands_pdos/model.py
+++ b/src/aiidalab_qe/common/bands_pdos/model.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import base64
import json
@@ -6,9 +8,12 @@
import traitlets as tl
from IPython.display import display
+from aiida import orm
from aiida.common.extendeddicts import AttributeDict
from aiidalab_qe.common.bands_pdos.utils import (
HTML_TAGS,
+ extract_bands_output,
+ extract_pdos_output,
get_bands_data,
get_bands_projections_data,
get_pdos_data,
@@ -79,6 +84,49 @@ def __init__(self, *args, **kwargs):
lambda _: self._has_pdos or self.needs_projections_controls,
)
+ @classmethod
+ def from_nodes(
+ cls,
+ bands: orm.WorkChainNode | None = None,
+ pdos: orm.WorkChainNode | None = None,
+ root: orm.WorkChainNode | None = None,
+ ):
+ """Create a `BandsPdosModel` instance from the provided nodes.
+
+ The method attempts to extract the output attribute dictionaries from the
+ nodes and creates from them an instance of the model.
+
+ Parameters
+ ----------
+ `bands` : `orm.WorkChainNode`, optional
+ The bands workchain node.
+ `pdos` : `orm.WorkChainNode`, optional
+ The PDOS workchain node.
+ `root`: `orm.WorkChainNode`, optional
+ The root QE app workchain node.
+
+ Returns
+ -------
+ `BandsPdosModel`
+ The model instance.
+
+ Raises
+ ------
+ `ValueError`
+ If neither of the nodes is provided or if the parsing of the nodes fails.
+ """
+ if bands or pdos:
+ bands_output = extract_bands_output(bands)
+ pdos_output = extract_pdos_output(pdos)
+ elif root:
+ bands_output = extract_bands_output(root)
+ pdos_output = extract_pdos_output(root)
+ else:
+ raise ValueError("At least one of the nodes must be provided")
+ if bands_output or pdos_output:
+ return cls(bands=bands_output, pdos=pdos_output)
+ raise ValueError("Failed to parse at least one node")
+
def fetch_data(self):
"""Fetch the data from the nodes."""
if self.bands:
diff --git a/src/aiidalab_qe/common/bands_pdos/utils.py b/src/aiidalab_qe/common/bands_pdos/utils.py
index d0203b171..c73b20222 100644
--- a/src/aiidalab_qe/common/bands_pdos/utils.py
+++ b/src/aiidalab_qe/common/bands_pdos/utils.py
@@ -1,10 +1,13 @@
+from __future__ import annotations
+
import json
import re
import numpy as np
from pymatgen.core.periodic_table import Element
-from aiida.orm import ProjectionData
+from aiida.common.extendeddicts import AttributeDict
+from aiida.orm import ProjectionData, WorkChainNode
# Constants for HTML tags
HTML_TAGS = {
@@ -37,6 +40,56 @@
}
+def extract_pdos_output(node: WorkChainNode) -> AttributeDict | None:
+ """Extract the PDOS output node from the given node.
+
+ Parameters
+ ----------
+ `node`: `WorkChainNode`
+ The node to extract the PDOS output from.
+
+ Returns
+ -------
+ `AttributeDict | None`
+ The PDOS output node, if available.
+ """
+ if not node:
+ return
+ if node.process_label == "QeAppWorkChain" and "pdos" in node.outputs:
+ return node.outputs.pdos
+ if "dos" in node.outputs and "projwfc" in node.outputs:
+ items = {key: getattr(node.outputs, key) for key in node.outputs}
+ return AttributeDict(items)
+
+
+def extract_bands_output(node: WorkChainNode) -> AttributeDict | None:
+ """Extract the bands output node from the given node.
+
+ Parameters
+ ----------
+ `node`: `WorkChainNode`
+ The node to extract the bands output from.
+
+ Returns
+ -------
+ `AttributeDict | None`
+ The bands output node, if available.
+ """
+ if not node:
+ return
+ if node.process_label == "QeAppWorkChain" and "bands" in node.outputs:
+ outputs = node.outputs.bands
+ else:
+ outputs = node.outputs
+ return (
+ outputs.bands
+ if "bands" in outputs
+ else outputs.bands_projwfc
+ if "bands_projwfc" in outputs
+ else None
+ )
+
+
def get_bands_data(outputs, fermi_energy=None):
if "band_structure" not in outputs:
return None
diff --git a/src/aiidalab_qe/common/panel.py b/src/aiidalab_qe/common/panel.py
index 223de660a..561e2b2de 100644
--- a/src/aiidalab_qe/common/panel.py
+++ b/src/aiidalab_qe/common/panel.py
@@ -507,6 +507,7 @@ class ResultsModel(PanelModel, HasProcess):
_this_process_uuid = None
auto_render = False
+ _completed_process = False
CSS_MAP = {
"finished": "success",
@@ -519,9 +520,13 @@ class ResultsModel(PanelModel, HasProcess):
"created": "info",
}
+ @property
+ def include(self):
+ return self.identifier in self.properties
+
@property
def has_results(self):
- node = self._fetch_child_process_node()
+ node = self.fetch_child_process_node()
return node and node.is_finished_ok
def update(self):
@@ -529,14 +534,32 @@ def update(self):
self.auto_render = True
def update_process_status_notification(self):
- self.process_status_notification = self._get_child_process_status()
+ if self._completed_process:
+ self.process_status_notification = ""
+ return
+ status = self._get_child_process_status()
+ self.process_status_notification = status
+ if "success" in status:
+ self._completed_process = True
+
+ def fetch_child_process_node(self, which="this") -> orm.ProcessNode | None:
+ if not self.process_uuid:
+ return
+ which = which.lower()
+ uuid = getattr(self, f"_{which}_process_uuid")
+ label = getattr(self, f"_{which}_process_label")
+ if not uuid:
+ root = self.fetch_process_node()
+ child = next((c for c in root.called if c.process_label == label), None)
+ uuid = child.uuid if child else None
+ return orm.load_node(uuid) if uuid else None # type: ignore
- def _get_child_process_status(self, child="this"):
- state, exit_message = self._get_child_state_and_exit_message(child)
+ def _get_child_process_status(self, which="this"):
+ state, exit_message = self._get_child_state_and_exit_message(which)
status = state.upper()
if exit_message:
status = f"{status} ({exit_message})"
- label = "Status" if child == "this" else f"{child.capitalize()} status"
+ label = "Status" if which == "this" else f"{which.capitalize()} status"
alert_class = f"alert-{self.CSS_MAP.get(state, 'info')}"
return f"""
@@ -544,9 +567,9 @@ def _get_child_process_status(self, child="this"):
"""
- def _get_child_state_and_exit_message(self, child="this"):
+ def _get_child_state_and_exit_message(self, which="this"):
if not (
- (node := self._fetch_child_process_node(child))
+ (node := self.fetch_child_process_node(which))
and hasattr(node, "process_state")
and node.process_state
):
@@ -555,37 +578,13 @@ def _get_child_state_and_exit_message(self, child="this"):
return "failed", node.exit_message
return node.process_state.value, None
- def _get_child_outputs(self, child="this"):
- if not (node := self._fetch_child_process_node(child)):
+ def _get_child_outputs(self, which="this"):
+ if not (node := self.fetch_child_process_node(which)):
outputs = super().outputs
- child = child if child != "this" else self.identifier
+ child = which if which != "this" else self.identifier
return getattr(outputs, child) if child in outputs else AttributeDict({})
return AttributeDict({key: getattr(node.outputs, key) for key in node.outputs})
- def _fetch_child_process_node(self, child="this") -> orm.ProcessNode | None:
- if not self.process_uuid:
- return
- child = child.lower()
- uuid = getattr(self, f"_{child}_process_uuid")
- label = getattr(self, f"_{child}_process_label")
- if not uuid:
- uuid = (
- orm.QueryBuilder()
- .append(
- orm.WorkChainNode,
- filters={"uuid": self.process_uuid},
- tag="root_process",
- )
- .append(
- orm.WorkChainNode,
- filters={"attributes.process_label": label},
- project="uuid",
- with_incoming="root_process",
- )
- .first(flat=True)
- )
- return orm.load_node(uuid) if uuid else None # type: ignore
-
RM = t.TypeVar("RM", bound=ResultsModel)
@@ -598,12 +597,10 @@ class ResultsPanel(Panel[RM]):
It has a update method to update the result in the panel.
"""
- has_controls = False
loading_message = "Loading {identifier} results"
def __init__(self, model: RM, **kwargs):
super().__init__(model=model, **kwargs)
-
self._model.observe(
self._on_process_change,
"process_uuid",
@@ -613,19 +610,23 @@ def __init__(self, model: RM, **kwargs):
"monitor_counter",
)
- self.links = []
-
def render(self):
if self.rendered:
if self._model.identifier == "structure":
self._render()
return
- if self.has_controls or not self._model.has_process:
+
+ if not self._model.has_process:
return
+
+ self.results_container = ipw.VBox()
+
if self._model.auto_render:
+ self.children = [self.results_container]
self._load_results()
else:
self._render_controls()
+ self.children += (self.results_container,)
def _on_process_change(self, _):
self._model.update()
@@ -634,14 +635,14 @@ def _on_monitor_counter_change(self, _):
self._model.update_process_status_notification()
def _on_load_results_click(self, _):
+ self.load_controls.children = []
self._load_results()
def _load_results(self):
- self.children = [self.loading_message]
+ self.results_container.children = [self.loading_message]
self._render()
self.rendered = True
self._post_render()
- self.has_controls = False
def _render_controls(self):
self.process_status_notification = ipw.HTML()
@@ -663,22 +664,25 @@ def _render_controls(self):
)
self.load_results_button.on_click(self._on_load_results_click)
+ self.load_controls = ipw.HBox(
+ children=[]
+ if self._model.auto_render
+ else [
+ self.load_results_button,
+ ipw.HTML("""
+
+ Note: Load time may vary depending on the size of the
+ calculation
+
+ """),
+ ]
+ )
+
self.children = [
self.process_status_notification,
- ipw.HBox(
- children=[
- self.load_results_button,
- ipw.HTML("""
-
- Note: Load time may vary depending on the size of the calculation
-
- """),
- ]
- ),
+ self.load_controls,
]
- self.has_controls = True
-
def _render(self):
raise NotImplementedError()
diff --git a/src/aiidalab_qe/plugins/bands/__init__.py b/src/aiidalab_qe/plugins/bands/__init__.py
index 3c2a97aca..13ebfed85 100644
--- a/src/aiidalab_qe/plugins/bands/__init__.py
+++ b/src/aiidalab_qe/plugins/bands/__init__.py
@@ -5,7 +5,6 @@
from .model import BandsConfigurationSettingsModel
from .resources import BandsResourceSettingsModel, BandsResourceSettingsPanel
-from .result import BandsResultsModel, BandsResultsPanel
from .setting import BandsConfigurationSettingsPanel
from .workchain import workchain_and_builder
@@ -24,10 +23,6 @@ class BandsPluginOutline(PluginOutline):
"panel": BandsResourceSettingsPanel,
"model": BandsResourceSettingsModel,
},
- "result": {
- "panel": BandsResultsPanel,
- "model": BandsResultsModel,
- },
"workchain": workchain_and_builder,
"guides": Path(__file__).parent / "guides",
}
diff --git a/src/aiidalab_qe/plugins/bands/result/__init__.py b/src/aiidalab_qe/plugins/bands/result/__init__.py
deleted file mode 100644
index 6d4e06a46..000000000
--- a/src/aiidalab_qe/plugins/bands/result/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from .model import BandsResultsModel
-from .result import BandsResultsPanel
-
-__all__ = [
- "BandsResultsModel",
- "BandsResultsPanel",
-]
diff --git a/src/aiidalab_qe/plugins/bands/result/model.py b/src/aiidalab_qe/plugins/bands/result/model.py
deleted file mode 100644
index 1f9df2bb9..000000000
--- a/src/aiidalab_qe/plugins/bands/result/model.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from aiidalab_qe.common.panel import ResultsModel
-
-
-class BandsResultsModel(ResultsModel):
- title = "Bands"
- identifier = "bands"
-
- _this_process_label = "BandsWorkChain"
-
- def get_bands_node(self):
- outputs = self._get_child_outputs()
- if "bands" in outputs:
- return outputs.bands
- elif "bands_projwfc" in outputs:
- return outputs.bands_projwfc
- else:
- # If neither 'bands' nor 'bands_projwfc' exist, use 'bands_output' itself
- # This is the case for compatibility with older versions of the plugin
- return outputs
diff --git a/src/aiidalab_qe/plugins/bands/result/result.py b/src/aiidalab_qe/plugins/bands/result/result.py
deleted file mode 100644
index ad940024d..000000000
--- a/src/aiidalab_qe/plugins/bands/result/result.py
+++ /dev/null
@@ -1,15 +0,0 @@
-"""Bands results view widgets"""
-
-from aiidalab_qe.common.bands_pdos import BandsPdosModel, BandsPdosWidget
-from aiidalab_qe.common.panel import ResultsPanel
-
-from .model import BandsResultsModel
-
-
-class BandsResultsPanel(ResultsPanel[BandsResultsModel]):
- def _render(self):
- bands_node = self._model.get_bands_node()
- model = BandsPdosModel()
- widget = BandsPdosWidget(model=model, bands=bands_node)
- widget.render()
- self.children = [widget]
diff --git a/src/aiidalab_qe/plugins/electronic_structure/result/model.py b/src/aiidalab_qe/plugins/electronic_structure/result/model.py
index 5f8534f7f..eca444f71 100644
--- a/src/aiidalab_qe/plugins/electronic_structure/result/model.py
+++ b/src/aiidalab_qe/plugins/electronic_structure/result/model.py
@@ -3,8 +3,6 @@
from aiidalab_qe.common.panel import ResultsModel
-# TODO if combined, this model should extend `HasModels`, and effectively
-# TODO reduce to a container of Bands and PDOS, similar to its results panel
class ElectronicStructureResultsModel(ResultsModel):
title = "Electronic Structure"
identifier = "electronic_structure"
@@ -17,31 +15,55 @@ class ElectronicStructureResultsModel(ResultsModel):
_pdos_process_label = "PdosWorkChain"
_pdos_process_uuid = None
- def get_pdos_node(self):
- return self._get_child_outputs("pdos")
+ _TITLE_MAPPING = {
+ "bands": "bands",
+ "pdos": "PDOS",
+ }
- def get_bands_node(self):
- outputs = self._get_child_outputs("bands")
- if "bands" in outputs:
- return outputs.bands
- elif "bands_projwfc" in outputs:
- return outputs.bands_projwfc
- else:
- # If neither 'bands' nor 'bands_projwfc' exist, use 'bands_output' itself
- # This is the case for compatibility with older versions of the plugin
- return outputs
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._completed_processes = set()
@property
def include(self):
- return all(identifier in self.properties for identifier in self.identifiers)
+ return any(identifier in self.properties for identifier in self.identifiers)
@property
def has_results(self):
- return self._has_bands and self._has_pdos
+ return self._has_bands or self._has_pdos
+
+ @property
+ def needs_property_selector(self):
+ return len(self.identifiers) > 1
+
+ def update(self):
+ super().update()
+ self.identifiers = list(
+ filter(
+ lambda identifier: identifier in self.properties,
+ self.identifiers,
+ ),
+ )
+ parts = [self._TITLE_MAPPING[identifier] for identifier in self.identifiers]
+ self.title = f"Electronic {' + '.join(parts)}"
def update_process_status_notification(self):
- statuses = [self._get_child_process_status(child) for child in self.identifiers]
- self.process_status_notification = "\n".join(statuses)
+ self._status_notifications = []
+ for identifier in self.identifiers:
+ if identifier in self._completed_processes:
+ continue
+ status = self._get_child_process_status(identifier)
+ self._status_notifications.append(status)
+ if "success" in status:
+ self._completed_processes.add(identifier)
+ self.process_status_notification = "\n".join(self._status_notifications)
+
+ def has_partial_results(self, identifier):
+ if identifier == "bands":
+ return self._has_bands
+ elif identifier == "pdos":
+ return self._has_pdos
+ return False
@property
def _has_bands(self):
diff --git a/src/aiidalab_qe/plugins/electronic_structure/result/result.py b/src/aiidalab_qe/plugins/electronic_structure/result/result.py
index 87eb8d740..c81b13eaf 100644
--- a/src/aiidalab_qe/plugins/electronic_structure/result/result.py
+++ b/src/aiidalab_qe/plugins/electronic_structure/result/result.py
@@ -1,16 +1,104 @@
"""Electronic structure results view widgets"""
+import ipywidgets as ipw
+
from aiidalab_qe.common.bands_pdos import BandsPdosModel, BandsPdosWidget
from aiidalab_qe.common.panel import ResultsPanel
+from aiidalab_qe.common.widgets import LoadingWidget
from .model import ElectronicStructureResultsModel
class ElectronicStructureResultsPanel(ResultsPanel[ElectronicStructureResultsModel]):
+ has_property_selector = False
+
def _render(self):
- bands_node = self._model.get_bands_node()
- pdos_node = self._model.get_pdos_node()
- model = BandsPdosModel()
- widget = BandsPdosWidget(model=model, bands=bands_node, pdos=pdos_node)
+ self.bands_pdos_container = ipw.VBox()
+ children = []
+ if self._model.needs_property_selector:
+ children.append(self._render_property_selector())
+ self.has_property_selector = True
+ children.append(self.bands_pdos_container)
+ self.results_container.children = children
+
+ def _post_render(self):
+ self._render_property_results()
+
+ def _render_property_selector(self):
+ apply_button = ipw.Button(
+ description="Apply selection",
+ button_style="primary",
+ icon="pencil",
+ )
+ apply_button.on_click(self._render_property_results)
+
+ property_selector = ipw.HBox(layout=ipw.Layout(grid_gap="10px"))
+
+ self.checkboxes: dict[str, ipw.Checkbox] = {}
+ for identifier in self._model.identifiers:
+ checkbox = ipw.Checkbox(
+ description=self._model._TITLE_MAPPING[identifier],
+ indent=False,
+ value=self._model.has_partial_results(identifier),
+ layout=ipw.Layout(width="fit-content"),
+ )
+ ipw.dlink(
+ (self._model, "monitor_counter"),
+ (checkbox, "disabled"),
+ lambda _, cid=identifier: not self._model.has_partial_results(cid),
+ )
+ ipw.dlink(
+ (checkbox, "value"),
+ (apply_button, "disabled"),
+ lambda _: not any(cb.value for cb in self.checkboxes.values()),
+ )
+ self.checkboxes[identifier] = checkbox
+ property_selector.children += (checkbox,)
+
+ property_selector.children += (apply_button,)
+
+ return ipw.VBox(
+ children=[
+ ipw.HTML("""
+
+
Select one or more properties to plot:
+
+ You can choose to plot only bands, only PDOS, or both
+ combined in one plot. After making your selection, click
+
+ Apply selection
+
+ to proceed.
+
+
+ """),
+ property_selector,
+ ]
+ )
+
+ def _render_property_results(self, _=None):
+ node_identifiers = (
+ [
+ identifier
+ for identifier, checkbox in self.checkboxes.items()
+ if checkbox.value
+ ]
+ if self.has_property_selector
+ else self._model.identifiers
+ )
+ self._render_bands_pdos_widget(node_identifiers)
+
+ def _render_bands_pdos_widget(self, node_identifiers):
+ message = f"Loading {' + '.join(node_identifiers)} results"
+ self.bands_pdos_container.children = [LoadingWidget(message)]
+ nodes = {
+ **{
+ identifier: self._model.fetch_child_process_node(identifier)
+ for identifier in node_identifiers
+ },
+ "root": self._model.fetch_process_node(),
+ }
+ model = BandsPdosModel.from_nodes(**nodes)
+ widget = BandsPdosWidget(model=model)
widget.render()
- self.children = [widget]
+ self.bands_pdos_container.children = [widget]
diff --git a/src/aiidalab_qe/plugins/pdos/__init__.py b/src/aiidalab_qe/plugins/pdos/__init__.py
index b22119506..2d188e926 100644
--- a/src/aiidalab_qe/plugins/pdos/__init__.py
+++ b/src/aiidalab_qe/plugins/pdos/__init__.py
@@ -2,7 +2,6 @@
from .model import PdosConfigurationSettingsModel
from .resources import PdosResourceSettingsModel, PdosResourceSettingsPanel
-from .result import PdosResultsModel, PdosResultsPanel
from .setting import PdosConfigurationSettingPanel
from .workchain import workchain_and_builder
@@ -21,9 +20,5 @@ class PdosPluginOutline(PluginOutline):
"panel": PdosResourceSettingsPanel,
"model": PdosResourceSettingsModel,
},
- "result": {
- "panel": PdosResultsPanel,
- "model": PdosResultsModel,
- },
"workchain": workchain_and_builder,
}
diff --git a/src/aiidalab_qe/plugins/pdos/result/__init__.py b/src/aiidalab_qe/plugins/pdos/result/__init__.py
deleted file mode 100644
index 30069fbb8..000000000
--- a/src/aiidalab_qe/plugins/pdos/result/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from .model import PdosResultsModel
-from .result import PdosResultsPanel
-
-__all__ = [
- "PdosResultsModel",
- "PdosResultsPanel",
-]
diff --git a/src/aiidalab_qe/plugins/pdos/result/model.py b/src/aiidalab_qe/plugins/pdos/result/model.py
deleted file mode 100644
index c010c0f86..000000000
--- a/src/aiidalab_qe/plugins/pdos/result/model.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from aiidalab_qe.common.panel import ResultsModel
-
-
-class PdosResultsModel(ResultsModel):
- title = "PDOS"
- identifier = "pdos"
-
- _this_process_label = "PdosWorkChain"
-
- def get_pdos_node(self):
- return self._get_child_outputs()
diff --git a/src/aiidalab_qe/plugins/pdos/result/result.py b/src/aiidalab_qe/plugins/pdos/result/result.py
deleted file mode 100644
index eefa4e705..000000000
--- a/src/aiidalab_qe/plugins/pdos/result/result.py
+++ /dev/null
@@ -1,15 +0,0 @@
-"""PDOS results view widgets"""
-
-from aiidalab_qe.common.bands_pdos import BandsPdosModel, BandsPdosWidget
-from aiidalab_qe.common.panel import ResultsPanel
-
-from .model import PdosResultsModel
-
-
-class PdosResultsPanel(ResultsPanel[PdosResultsModel]):
- def _render(self):
- pdos_node = self._model.get_pdos_node()
- model = BandsPdosModel()
- widget = BandsPdosWidget(model=model, pdos=pdos_node)
- widget.render()
- self.children = [widget]
diff --git a/src/aiidalab_qe/plugins/xas/result/result.py b/src/aiidalab_qe/plugins/xas/result/result.py
index a4edb1772..29d13fc81 100644
--- a/src/aiidalab_qe/plugins/xas/result/result.py
+++ b/src/aiidalab_qe/plugins/xas/result/result.py
@@ -125,7 +125,7 @@ def _render(self):
)
self.plot.layout.xaxis.title = "Relative Photon Energy (eV)"
- self.children = [
+ self.results_container.children = [
ipw.HBox(
[
ipw.VBox(
diff --git a/src/aiidalab_qe/plugins/xps/result/result.py b/src/aiidalab_qe/plugins/xps/result/result.py
index fbd22b89f..cbdee2842 100644
--- a/src/aiidalab_qe/plugins/xps/result/result.py
+++ b/src/aiidalab_qe/plugins/xps/result/result.py
@@ -148,7 +148,7 @@ def _render(self):
self.plot.layout.xaxis.title = "Chemical shift (eV)"
self.plot.layout.xaxis.autorange = "reversed"
- self.children = [
+ self.results_container.children = [
spectra_type,
ipw.HBox(
children=[
diff --git a/tests/conftest.py b/tests/conftest.py
index d163fcda5..867464fc5 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -820,29 +820,29 @@ def _generate_qeapp_workchain(
inputs = builder._inputs()
inputs["relax"]["base_final_scf"] = deepcopy(inputs["relax"]["base"])
- # Setting up inputs for bands_projwfc
- inputs["bands"]["bands_projwfc"]["scf"]["pw"] = deepcopy(
- inputs["bands"]["bands"]["scf"]["pw"]
- )
- inputs["bands"]["bands_projwfc"]["bands"]["pw"] = deepcopy(
- inputs["bands"]["bands"]["bands"]["pw"]
- )
- inputs["bands"]["bands_projwfc"]["bands"]["pw"]["code"] = inputs["bands"][
- "bands"
- ]["bands"]["pw"]["code"]
- inputs["bands"]["bands_projwfc"]["scf"]["pw"]["code"] = inputs["bands"][
- "bands"
- ]["scf"]["pw"]["code"]
-
- inputs["bands"]["bands_projwfc"]["projwfc"]["projwfc"]["code"] = fixture_code(
- "quantumespresso.projwfc"
- )
- inputs["bands"]["bands_projwfc"]["projwfc"]["projwfc"]["parameters"] = Dict(
- {"PROJWFC": {"DeltaE": 0.01}}
- ).store()
-
if run_bands:
+ # Setting up inputs for bands_projwfc
+ inputs["bands"]["bands_projwfc"]["scf"]["pw"] = deepcopy(
+ inputs["bands"]["bands"]["scf"]["pw"]
+ )
+ inputs["bands"]["bands_projwfc"]["bands"]["pw"] = deepcopy(
+ inputs["bands"]["bands"]["bands"]["pw"]
+ )
+ inputs["bands"]["bands_projwfc"]["bands"]["pw"]["code"] = inputs["bands"][
+ "bands"
+ ]["bands"]["pw"]["code"]
+ inputs["bands"]["bands_projwfc"]["scf"]["pw"]["code"] = inputs["bands"][
+ "bands"
+ ]["scf"]["pw"]["code"]
+
+ inputs["bands"]["bands_projwfc"]["projwfc"]["projwfc"]["code"] = (
+ fixture_code("quantumespresso.projwfc")
+ )
+ inputs["bands"]["bands_projwfc"]["projwfc"]["projwfc"]["parameters"] = Dict(
+ {"PROJWFC": {"DeltaE": 0.01}}
+ ).store()
inputs["properties"].append("bands")
+
if run_pdos:
inputs["properties"].append("pdos")
@@ -852,6 +852,7 @@ def _generate_qeapp_workchain(
# mock output
if relax_type != "none":
workchain.out("structure", app.structure_model.input_structure)
+
if run_pdos:
from aiida_quantumespresso.workflows.pdos import PdosWorkChain
@@ -863,6 +864,7 @@ def _generate_qeapp_workchain(
namespace="pdos",
)
)
+
if run_bands:
from aiidalab_qe.plugins.bands.bands_workchain import BandsWorkChain
diff --git a/tests/test_plugins_bands.py b/tests/test_plugins_bands.py
index a6067bab4..fb973bc25 100644
--- a/tests/test_plugins_bands.py
+++ b/tests/test_plugins_bands.py
@@ -2,30 +2,38 @@ def test_result(generate_qeapp_workchain):
import plotly.graph_objects as go
from aiidalab_qe.common.bands_pdos import BandsPdosWidget
- from aiidalab_qe.plugins.bands.result import BandsResultsModel, BandsResultsPanel
+ from aiidalab_qe.plugins.electronic_structure.result import (
+ ElectronicStructureResultsModel,
+ ElectronicStructureResultsPanel,
+ )
+
+ workchain = generate_qeapp_workchain(run_pdos=False)
+ model = ElectronicStructureResultsModel()
+ panel = ElectronicStructureResultsPanel(model=model)
- workchain = generate_qeapp_workchain()
- model = BandsResultsModel()
model.process_uuid = workchain.node.uuid
- result = BandsResultsPanel(model=model)
- result._render()
- widget = result.children[0]
+ assert model.title == "Electronic bands"
+ assert model.identifiers == ["bands"]
+
+ panel.render()
+
+ assert len(panel.results_container.children) == 1 # only bands, so no controls
+
+ widget = panel.bands_pdos_container.children[0] # type: ignore
model = widget._model
assert isinstance(widget, BandsPdosWidget)
assert isinstance(widget.plot, go.FigureWidget)
- # Check if data is correct
assert not model.pdos_data
assert model.bands_data
- assert model.bands_data["pathlabels"] # type: ignore
+ assert model.bands_data["pathlabels"]
- # Check Bands axis
assert widget.plot.layout.xaxis.title.text == "k-points"
assert widget.plot.layout.yaxis.title.text == "Electronic Bands (eV)"
assert isinstance(widget.plot.layout.xaxis.rangeslider, go.layout.xaxis.Rangeslider)
- assert model.bands_data["pathlabels"][0] == list(widget.plot.layout.xaxis.ticktext) # type: ignore
+ assert model.bands_data["pathlabels"][0] == list(widget.plot.layout.xaxis.ticktext)
def test_structure_1d(generate_qeapp_workchain, generate_structure_data):
diff --git a/tests/test_plugins_electronic_structure.py b/tests/test_plugins_electronic_structure.py
index 18f381be6..265162f3b 100644
--- a/tests/test_plugins_electronic_structure.py
+++ b/tests/test_plugins_electronic_structure.py
@@ -9,11 +9,17 @@ def test_electronic_structure(generate_qeapp_workchain):
workchain = generate_qeapp_workchain()
model = ElectronicStructureResultsModel()
+ panel = ElectronicStructureResultsPanel(model=model)
model.process_uuid = workchain.node.uuid
- result = ElectronicStructureResultsPanel(model=model)
- result._render()
- widget = result.children[0]
+ assert model.title == "Electronic bands + PDOS"
+ assert model.identifiers == ["bands", "pdos"]
+
+ panel.render()
+
+ assert len(panel.results_container.children) == 2 # has controls
+
+ widget = panel.bands_pdos_container.children[0] # type: ignore
model = widget._model
assert isinstance(widget, BandsPdosWidget)
diff --git a/tests/test_plugins_pdos.py b/tests/test_plugins_pdos.py
index 8bad6bd03..23e0666e8 100644
--- a/tests/test_plugins_pdos.py
+++ b/tests/test_plugins_pdos.py
@@ -2,26 +2,31 @@ def test_result(generate_qeapp_workchain):
import plotly.graph_objects as go
from aiidalab_qe.common.bands_pdos import BandsPdosWidget
- from aiidalab_qe.plugins.pdos.result import PdosResultsModel, PdosResultsPanel
+ from aiidalab_qe.plugins.electronic_structure.result import (
+ ElectronicStructureResultsModel,
+ ElectronicStructureResultsPanel,
+ )
- workchain = generate_qeapp_workchain()
- model = PdosResultsModel()
+ workchain = generate_qeapp_workchain(run_bands=False)
+ model = ElectronicStructureResultsModel()
+ panel = ElectronicStructureResultsPanel(model=model)
model.process_uuid = workchain.node.uuid
- result = PdosResultsPanel(model=model)
- result._render()
- widget = result.children[0]
+ assert model.title == "Electronic PDOS"
+ assert model.identifiers == ["pdos"]
+
+ panel.render()
+
+ assert len(panel.results_container.children) == 1 # only pdos, so no controls
+
+ widget = panel.bands_pdos_container.children[0] # type: ignore
model = widget._model
assert isinstance(widget, BandsPdosWidget)
assert isinstance(widget.plot, go.FigureWidget)
- # Check if data is correct
assert not model.bands_data
assert model.pdos_data
- # Check PDOS settings is not None
-
- # Check Bands axis
assert widget.plot.layout.xaxis.title.text == "Density of states (eV)"
assert widget.plot.layout.yaxis.title.text is None
diff --git a/tests/test_result.py b/tests/test_result.py
index db7de7038..b5932fdbb 100644
--- a/tests/test_result.py
+++ b/tests/test_result.py
@@ -40,7 +40,7 @@ def test_workchainview(generate_qeapp_workchain):
viewer = WorkChainResultsViewer(model=model)
model.process_uuid = workchain.node.uuid
viewer.render()
- assert len(viewer.tabs.children) == 4
+ assert len(viewer.tabs.children) == 2
assert viewer.tabs._titles["0"] == "Relaxed structure" # type: ignore