diff --git a/src/aiidalab_qe/app/result/__init__.py b/src/aiidalab_qe/app/result/__init__.py index e81a80e50..eabe8469d 100644 --- a/src/aiidalab_qe/app/result/__init__.py +++ b/src/aiidalab_qe/app/result/__init__.py @@ -121,7 +121,7 @@ def _post_render(self): self.toggle_controls.value = "Summary" self.process_monitor = ProcessMonitor( - timeout=0.2, + timeout=0.5, callbacks=[ self._update_status, self._update_state, diff --git a/src/aiidalab_qe/app/result/components/viewer/structure/model.py b/src/aiidalab_qe/app/result/components/viewer/structure/model.py index 3d56ab48a..eb9526357 100644 --- a/src/aiidalab_qe/app/result/components/viewer/structure/model.py +++ b/src/aiidalab_qe/app/result/components/viewer/structure/model.py @@ -9,12 +9,15 @@ class StructureResultsModel(ResultsModel): source = None + @property + def include(self): + return True + def update(self): - super().update() is_relaxed = "relax" in self.properties self.title = "Relaxed structure" if is_relaxed else "Initial structure" self.source = self.outputs if is_relaxed else self.inputs - self.auto_render = not is_relaxed # auto-render initial structure + self.auto_render = not is_relaxed or self.has_results def get_structure(self): try: diff --git a/src/aiidalab_qe/app/result/components/viewer/structure/structure.py b/src/aiidalab_qe/app/result/components/viewer/structure/structure.py index 93a95e5c4..3cae7cb85 100644 --- a/src/aiidalab_qe/app/result/components/viewer/structure/structure.py +++ b/src/aiidalab_qe/app/result/components/viewer/structure/structure.py @@ -9,7 +9,7 @@ def _render(self): if not hasattr(self, "widget"): structure = self._model.get_structure() self.widget = StructureDataViewer(structure=structure) - self.children = [self.widget] + self.results_container.children = [self.widget] # HACK to resize the NGL viewer in cases where it auto-rendered when its # container was not displayed, which leads to a null width. This hack restores diff --git a/src/aiidalab_qe/app/result/components/viewer/viewer.py b/src/aiidalab_qe/app/result/components/viewer/viewer.py index 288671e51..466daeedf 100644 --- a/src/aiidalab_qe/app/result/components/viewer/viewer.py +++ b/src/aiidalab_qe/app/result/components/viewer/viewer.py @@ -12,10 +12,15 @@ class WorkChainResultsViewer(ResultsComponent[WorkChainResultsViewerModel]): def __init__(self, model: WorkChainResultsViewerModel, **kwargs): - super().__init__(model=model, **kwargs) + # NOTE: here we want to add the structure and plugin models to the viewer + # model BEFORE we define the observation of the process uuid. This ensures + # that when the process changes, its reflected in the sub-models prior to + # the logic of the process change event handler. + # TODO avoid exceptions! Ensure sub-model synchronization in general! self.panels: dict[str, ResultsPanel] = {} - self._add_structure_panel() # TODO consider refactoring structure panel as a plugin - self._fetch_plugin_results() + self._add_structure_panel(model) + self._fetch_plugin_results(model) + super().__init__(model=model, **kwargs) def _on_process_change(self, _): self._update_panels() @@ -40,14 +45,10 @@ def _post_render(self): self._set_tabs() def _update_panels(self): - properties = self._model.properties - need_electronic_structure = "bands" in properties and "pdos" in properties self.panels = { - identifier: panel - for identifier, panel in self.panels.items() - if identifier == "structure" - or identifier in properties - or (identifier == "electronic_structure" and need_electronic_structure) + identifier: self.panels[identifier] + for identifier, model in self._model.get_models() + if model.include } def _set_tabs(self): @@ -65,18 +66,18 @@ def _set_tabs(self): if children: self.tabs.selected_index = 0 - def _add_structure_panel(self): + def _add_structure_panel(self, viewer_model: WorkChainResultsViewerModel): structure_model = StructureResultsModel() - structure_model.process_uuid = self._model.process_uuid + structure_model.process_uuid = viewer_model.process_uuid self.structure_results = StructureResultsPanel(model=structure_model) identifier = structure_model.identifier - self._model.add_model(identifier, structure_model) + viewer_model.add_model(identifier, structure_model) self.panels = { identifier: self.structure_results, **self.panels, } - def _fetch_plugin_results(self): + def _fetch_plugin_results(self, viewer_model: WorkChainResultsViewerModel): entries = get_entry_items("aiidalab_qe.properties", "result") for identifier, entry in entries.items(): for key in ("panel", "model"): @@ -86,7 +87,7 @@ def _fetch_plugin_results(self): ) panel = entry["panel"] model = entry["model"]() - self._model.add_model(identifier, model) + viewer_model.add_model(identifier, model) self.panels[identifier] = panel( identifier=identifier, model=model, diff --git a/src/aiidalab_qe/app/static/styles/custom.css b/src/aiidalab_qe/app/static/styles/custom.css index 89975761d..3a0e85130 100644 --- a/src/aiidalab_qe/app/static/styles/custom.css +++ b/src/aiidalab_qe/app/static/styles/custom.css @@ -108,3 +108,7 @@ footer { .p-Accordion-child:has(.qe-app-step-fail) > .p-Collapse-header { background-color: var(--color-failed); } + +.p-TabBar-tab { + min-width: fit-content !important; +} diff --git a/src/aiidalab_qe/common/bands_pdos/bandpdoswidget.py b/src/aiidalab_qe/common/bands_pdos/bandpdoswidget.py index e93653977..123548387 100644 --- a/src/aiidalab_qe/common/bands_pdos/bandpdoswidget.py +++ b/src/aiidalab_qe/common/bands_pdos/bandpdoswidget.py @@ -7,31 +7,34 @@ class BandsPdosWidget(ipw.VBox): - """ - A widget for plotting band structure and projected density of states (PDOS) data. + """A widget for plotting band structure and projected density of states (PDOS). Parameters ---------- - - bands (optional): A node containing band structure data. - - pdos (optional): A node containing PDOS data. + `model`: `BandsPdosModel` + The MVC model containing the data and logic for the widget. Attributes ---------- - - description: HTML description of the widget. - - dos_atoms_group: Dropdown widget to select the grouping of atoms for PDOS plotting. - - dos_plot_group: Dropdown widget to select the type of PDOS contributions to plot. - - selected_atoms: Text widget to select specific atoms for PDOS plotting. - - update_plot_button: Button widget to update the plot. - - download_button: Button widget to download the data. - - project_bands_box: Checkbox widget to choose whether projected bands should be plotted. - - plot_widget: Plotly widget for band structure and PDOS plot. - - bands_widget: Output widget to display the bandsplot widget. + `description`: `ipywidgets.HTML` + HTML description of the widget. + `dos_atoms_group`: `ipywidgets.Dropdown` + Dropdown widget to select the grouping of atoms for PDOS plotting. + `dos_plot_group`: `ipywidgets.Dropdown` + Dropdown widget to select the type of PDOS contributions to plot. + `selected_atoms`: `ipywidgets.Text` + Text widget to select specific atoms for PDOS plotting. + `update_plot_button`: `ipywidgets.Button` + Button widget to update the plot. + `download_button`: `ipywidgets.Button` + Button widget to download the data. + `project_bands_box`: `ipywidgets.Checkbox` + Checkbox widget to choose whether projected bands should be plotted. + `plot`: `plotly.graph_objects.FigureWidget` + Plotly widget for band structure and PDOS plot. """ - def __init__(self, model: BandsPdosModel, bands=None, pdos=None, **kwargs): - if bands is None and pdos is None: - raise ValueError("Either bands or pdos must be provided") - + def __init__(self, model: BandsPdosModel, **kwargs): super().__init__( children=[LoadingWidget("Loading widgets")], **kwargs, @@ -49,9 +52,6 @@ def __init__(self, model: BandsPdosModel, bands=None, pdos=None, **kwargs): self.rendered = False - self._model.bands = bands - self._model.pdos = pdos - def render(self): if self.rendered: return diff --git a/src/aiidalab_qe/common/bands_pdos/model.py b/src/aiidalab_qe/common/bands_pdos/model.py index 690ef10ca..c9f2b72e2 100644 --- a/src/aiidalab_qe/common/bands_pdos/model.py +++ b/src/aiidalab_qe/common/bands_pdos/model.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import json @@ -6,9 +8,12 @@ import traitlets as tl from IPython.display import display +from aiida import orm from aiida.common.extendeddicts import AttributeDict from aiidalab_qe.common.bands_pdos.utils import ( HTML_TAGS, + extract_bands_output, + extract_pdos_output, get_bands_data, get_bands_projections_data, get_pdos_data, @@ -79,6 +84,49 @@ def __init__(self, *args, **kwargs): lambda _: self._has_pdos or self.needs_projections_controls, ) + @classmethod + def from_nodes( + cls, + bands: orm.WorkChainNode | None = None, + pdos: orm.WorkChainNode | None = None, + root: orm.WorkChainNode | None = None, + ): + """Create a `BandsPdosModel` instance from the provided nodes. + + The method attempts to extract the output attribute dictionaries from the + nodes and creates from them an instance of the model. + + Parameters + ---------- + `bands` : `orm.WorkChainNode`, optional + The bands workchain node. + `pdos` : `orm.WorkChainNode`, optional + The PDOS workchain node. + `root`: `orm.WorkChainNode`, optional + The root QE app workchain node. + + Returns + ------- + `BandsPdosModel` + The model instance. + + Raises + ------ + `ValueError` + If neither of the nodes is provided or if the parsing of the nodes fails. + """ + if bands or pdos: + bands_output = extract_bands_output(bands) + pdos_output = extract_pdos_output(pdos) + elif root: + bands_output = extract_bands_output(root) + pdos_output = extract_pdos_output(root) + else: + raise ValueError("At least one of the nodes must be provided") + if bands_output or pdos_output: + return cls(bands=bands_output, pdos=pdos_output) + raise ValueError("Failed to parse at least one node") + def fetch_data(self): """Fetch the data from the nodes.""" if self.bands: diff --git a/src/aiidalab_qe/common/bands_pdos/utils.py b/src/aiidalab_qe/common/bands_pdos/utils.py index d0203b171..c73b20222 100644 --- a/src/aiidalab_qe/common/bands_pdos/utils.py +++ b/src/aiidalab_qe/common/bands_pdos/utils.py @@ -1,10 +1,13 @@ +from __future__ import annotations + import json import re import numpy as np from pymatgen.core.periodic_table import Element -from aiida.orm import ProjectionData +from aiida.common.extendeddicts import AttributeDict +from aiida.orm import ProjectionData, WorkChainNode # Constants for HTML tags HTML_TAGS = { @@ -37,6 +40,56 @@ } +def extract_pdos_output(node: WorkChainNode) -> AttributeDict | None: + """Extract the PDOS output node from the given node. + + Parameters + ---------- + `node`: `WorkChainNode` + The node to extract the PDOS output from. + + Returns + ------- + `AttributeDict | None` + The PDOS output node, if available. + """ + if not node: + return + if node.process_label == "QeAppWorkChain" and "pdos" in node.outputs: + return node.outputs.pdos + if "dos" in node.outputs and "projwfc" in node.outputs: + items = {key: getattr(node.outputs, key) for key in node.outputs} + return AttributeDict(items) + + +def extract_bands_output(node: WorkChainNode) -> AttributeDict | None: + """Extract the bands output node from the given node. + + Parameters + ---------- + `node`: `WorkChainNode` + The node to extract the bands output from. + + Returns + ------- + `AttributeDict | None` + The bands output node, if available. + """ + if not node: + return + if node.process_label == "QeAppWorkChain" and "bands" in node.outputs: + outputs = node.outputs.bands + else: + outputs = node.outputs + return ( + outputs.bands + if "bands" in outputs + else outputs.bands_projwfc + if "bands_projwfc" in outputs + else None + ) + + def get_bands_data(outputs, fermi_energy=None): if "band_structure" not in outputs: return None diff --git a/src/aiidalab_qe/common/panel.py b/src/aiidalab_qe/common/panel.py index 223de660a..561e2b2de 100644 --- a/src/aiidalab_qe/common/panel.py +++ b/src/aiidalab_qe/common/panel.py @@ -507,6 +507,7 @@ class ResultsModel(PanelModel, HasProcess): _this_process_uuid = None auto_render = False + _completed_process = False CSS_MAP = { "finished": "success", @@ -519,9 +520,13 @@ class ResultsModel(PanelModel, HasProcess): "created": "info", } + @property + def include(self): + return self.identifier in self.properties + @property def has_results(self): - node = self._fetch_child_process_node() + node = self.fetch_child_process_node() return node and node.is_finished_ok def update(self): @@ -529,14 +534,32 @@ def update(self): self.auto_render = True def update_process_status_notification(self): - self.process_status_notification = self._get_child_process_status() + if self._completed_process: + self.process_status_notification = "" + return + status = self._get_child_process_status() + self.process_status_notification = status + if "success" in status: + self._completed_process = True + + def fetch_child_process_node(self, which="this") -> orm.ProcessNode | None: + if not self.process_uuid: + return + which = which.lower() + uuid = getattr(self, f"_{which}_process_uuid") + label = getattr(self, f"_{which}_process_label") + if not uuid: + root = self.fetch_process_node() + child = next((c for c in root.called if c.process_label == label), None) + uuid = child.uuid if child else None + return orm.load_node(uuid) if uuid else None # type: ignore - def _get_child_process_status(self, child="this"): - state, exit_message = self._get_child_state_and_exit_message(child) + def _get_child_process_status(self, which="this"): + state, exit_message = self._get_child_state_and_exit_message(which) status = state.upper() if exit_message: status = f"{status} ({exit_message})" - label = "Status" if child == "this" else f"{child.capitalize()} status" + label = "Status" if which == "this" else f"{which.capitalize()} status" alert_class = f"alert-{self.CSS_MAP.get(state, 'info')}" return f"""
@@ -544,9 +567,9 @@ def _get_child_process_status(self, child="this"):
""" - def _get_child_state_and_exit_message(self, child="this"): + def _get_child_state_and_exit_message(self, which="this"): if not ( - (node := self._fetch_child_process_node(child)) + (node := self.fetch_child_process_node(which)) and hasattr(node, "process_state") and node.process_state ): @@ -555,37 +578,13 @@ def _get_child_state_and_exit_message(self, child="this"): return "failed", node.exit_message return node.process_state.value, None - def _get_child_outputs(self, child="this"): - if not (node := self._fetch_child_process_node(child)): + def _get_child_outputs(self, which="this"): + if not (node := self.fetch_child_process_node(which)): outputs = super().outputs - child = child if child != "this" else self.identifier + child = which if which != "this" else self.identifier return getattr(outputs, child) if child in outputs else AttributeDict({}) return AttributeDict({key: getattr(node.outputs, key) for key in node.outputs}) - def _fetch_child_process_node(self, child="this") -> orm.ProcessNode | None: - if not self.process_uuid: - return - child = child.lower() - uuid = getattr(self, f"_{child}_process_uuid") - label = getattr(self, f"_{child}_process_label") - if not uuid: - uuid = ( - orm.QueryBuilder() - .append( - orm.WorkChainNode, - filters={"uuid": self.process_uuid}, - tag="root_process", - ) - .append( - orm.WorkChainNode, - filters={"attributes.process_label": label}, - project="uuid", - with_incoming="root_process", - ) - .first(flat=True) - ) - return orm.load_node(uuid) if uuid else None # type: ignore - RM = t.TypeVar("RM", bound=ResultsModel) @@ -598,12 +597,10 @@ class ResultsPanel(Panel[RM]): It has a update method to update the result in the panel. """ - has_controls = False loading_message = "Loading {identifier} results" def __init__(self, model: RM, **kwargs): super().__init__(model=model, **kwargs) - self._model.observe( self._on_process_change, "process_uuid", @@ -613,19 +610,23 @@ def __init__(self, model: RM, **kwargs): "monitor_counter", ) - self.links = [] - def render(self): if self.rendered: if self._model.identifier == "structure": self._render() return - if self.has_controls or not self._model.has_process: + + if not self._model.has_process: return + + self.results_container = ipw.VBox() + if self._model.auto_render: + self.children = [self.results_container] self._load_results() else: self._render_controls() + self.children += (self.results_container,) def _on_process_change(self, _): self._model.update() @@ -634,14 +635,14 @@ def _on_monitor_counter_change(self, _): self._model.update_process_status_notification() def _on_load_results_click(self, _): + self.load_controls.children = [] self._load_results() def _load_results(self): - self.children = [self.loading_message] + self.results_container.children = [self.loading_message] self._render() self.rendered = True self._post_render() - self.has_controls = False def _render_controls(self): self.process_status_notification = ipw.HTML() @@ -663,22 +664,25 @@ def _render_controls(self): ) self.load_results_button.on_click(self._on_load_results_click) + self.load_controls = ipw.HBox( + children=[] + if self._model.auto_render + else [ + self.load_results_button, + ipw.HTML(""" +
+ Note: Load time may vary depending on the size of the + calculation +
+ """), + ] + ) + self.children = [ self.process_status_notification, - ipw.HBox( - children=[ - self.load_results_button, - ipw.HTML(""" -
- Note: Load time may vary depending on the size of the calculation -
- """), - ] - ), + self.load_controls, ] - self.has_controls = True - def _render(self): raise NotImplementedError() diff --git a/src/aiidalab_qe/plugins/bands/__init__.py b/src/aiidalab_qe/plugins/bands/__init__.py index 3c2a97aca..13ebfed85 100644 --- a/src/aiidalab_qe/plugins/bands/__init__.py +++ b/src/aiidalab_qe/plugins/bands/__init__.py @@ -5,7 +5,6 @@ from .model import BandsConfigurationSettingsModel from .resources import BandsResourceSettingsModel, BandsResourceSettingsPanel -from .result import BandsResultsModel, BandsResultsPanel from .setting import BandsConfigurationSettingsPanel from .workchain import workchain_and_builder @@ -24,10 +23,6 @@ class BandsPluginOutline(PluginOutline): "panel": BandsResourceSettingsPanel, "model": BandsResourceSettingsModel, }, - "result": { - "panel": BandsResultsPanel, - "model": BandsResultsModel, - }, "workchain": workchain_and_builder, "guides": Path(__file__).parent / "guides", } diff --git a/src/aiidalab_qe/plugins/bands/result/__init__.py b/src/aiidalab_qe/plugins/bands/result/__init__.py deleted file mode 100644 index 6d4e06a46..000000000 --- a/src/aiidalab_qe/plugins/bands/result/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .model import BandsResultsModel -from .result import BandsResultsPanel - -__all__ = [ - "BandsResultsModel", - "BandsResultsPanel", -] diff --git a/src/aiidalab_qe/plugins/bands/result/model.py b/src/aiidalab_qe/plugins/bands/result/model.py deleted file mode 100644 index 1f9df2bb9..000000000 --- a/src/aiidalab_qe/plugins/bands/result/model.py +++ /dev/null @@ -1,19 +0,0 @@ -from aiidalab_qe.common.panel import ResultsModel - - -class BandsResultsModel(ResultsModel): - title = "Bands" - identifier = "bands" - - _this_process_label = "BandsWorkChain" - - def get_bands_node(self): - outputs = self._get_child_outputs() - if "bands" in outputs: - return outputs.bands - elif "bands_projwfc" in outputs: - return outputs.bands_projwfc - else: - # If neither 'bands' nor 'bands_projwfc' exist, use 'bands_output' itself - # This is the case for compatibility with older versions of the plugin - return outputs diff --git a/src/aiidalab_qe/plugins/bands/result/result.py b/src/aiidalab_qe/plugins/bands/result/result.py deleted file mode 100644 index ad940024d..000000000 --- a/src/aiidalab_qe/plugins/bands/result/result.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Bands results view widgets""" - -from aiidalab_qe.common.bands_pdos import BandsPdosModel, BandsPdosWidget -from aiidalab_qe.common.panel import ResultsPanel - -from .model import BandsResultsModel - - -class BandsResultsPanel(ResultsPanel[BandsResultsModel]): - def _render(self): - bands_node = self._model.get_bands_node() - model = BandsPdosModel() - widget = BandsPdosWidget(model=model, bands=bands_node) - widget.render() - self.children = [widget] diff --git a/src/aiidalab_qe/plugins/electronic_structure/result/model.py b/src/aiidalab_qe/plugins/electronic_structure/result/model.py index 5f8534f7f..eca444f71 100644 --- a/src/aiidalab_qe/plugins/electronic_structure/result/model.py +++ b/src/aiidalab_qe/plugins/electronic_structure/result/model.py @@ -3,8 +3,6 @@ from aiidalab_qe.common.panel import ResultsModel -# TODO if combined, this model should extend `HasModels`, and effectively -# TODO reduce to a container of Bands and PDOS, similar to its results panel class ElectronicStructureResultsModel(ResultsModel): title = "Electronic Structure" identifier = "electronic_structure" @@ -17,31 +15,55 @@ class ElectronicStructureResultsModel(ResultsModel): _pdos_process_label = "PdosWorkChain" _pdos_process_uuid = None - def get_pdos_node(self): - return self._get_child_outputs("pdos") + _TITLE_MAPPING = { + "bands": "bands", + "pdos": "PDOS", + } - def get_bands_node(self): - outputs = self._get_child_outputs("bands") - if "bands" in outputs: - return outputs.bands - elif "bands_projwfc" in outputs: - return outputs.bands_projwfc - else: - # If neither 'bands' nor 'bands_projwfc' exist, use 'bands_output' itself - # This is the case for compatibility with older versions of the plugin - return outputs + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._completed_processes = set() @property def include(self): - return all(identifier in self.properties for identifier in self.identifiers) + return any(identifier in self.properties for identifier in self.identifiers) @property def has_results(self): - return self._has_bands and self._has_pdos + return self._has_bands or self._has_pdos + + @property + def needs_property_selector(self): + return len(self.identifiers) > 1 + + def update(self): + super().update() + self.identifiers = list( + filter( + lambda identifier: identifier in self.properties, + self.identifiers, + ), + ) + parts = [self._TITLE_MAPPING[identifier] for identifier in self.identifiers] + self.title = f"Electronic {' + '.join(parts)}" def update_process_status_notification(self): - statuses = [self._get_child_process_status(child) for child in self.identifiers] - self.process_status_notification = "\n".join(statuses) + self._status_notifications = [] + for identifier in self.identifiers: + if identifier in self._completed_processes: + continue + status = self._get_child_process_status(identifier) + self._status_notifications.append(status) + if "success" in status: + self._completed_processes.add(identifier) + self.process_status_notification = "\n".join(self._status_notifications) + + def has_partial_results(self, identifier): + if identifier == "bands": + return self._has_bands + elif identifier == "pdos": + return self._has_pdos + return False @property def _has_bands(self): diff --git a/src/aiidalab_qe/plugins/electronic_structure/result/result.py b/src/aiidalab_qe/plugins/electronic_structure/result/result.py index 87eb8d740..c81b13eaf 100644 --- a/src/aiidalab_qe/plugins/electronic_structure/result/result.py +++ b/src/aiidalab_qe/plugins/electronic_structure/result/result.py @@ -1,16 +1,104 @@ """Electronic structure results view widgets""" +import ipywidgets as ipw + from aiidalab_qe.common.bands_pdos import BandsPdosModel, BandsPdosWidget from aiidalab_qe.common.panel import ResultsPanel +from aiidalab_qe.common.widgets import LoadingWidget from .model import ElectronicStructureResultsModel class ElectronicStructureResultsPanel(ResultsPanel[ElectronicStructureResultsModel]): + has_property_selector = False + def _render(self): - bands_node = self._model.get_bands_node() - pdos_node = self._model.get_pdos_node() - model = BandsPdosModel() - widget = BandsPdosWidget(model=model, bands=bands_node, pdos=pdos_node) + self.bands_pdos_container = ipw.VBox() + children = [] + if self._model.needs_property_selector: + children.append(self._render_property_selector()) + self.has_property_selector = True + children.append(self.bands_pdos_container) + self.results_container.children = children + + def _post_render(self): + self._render_property_results() + + def _render_property_selector(self): + apply_button = ipw.Button( + description="Apply selection", + button_style="primary", + icon="pencil", + ) + apply_button.on_click(self._render_property_results) + + property_selector = ipw.HBox(layout=ipw.Layout(grid_gap="10px")) + + self.checkboxes: dict[str, ipw.Checkbox] = {} + for identifier in self._model.identifiers: + checkbox = ipw.Checkbox( + description=self._model._TITLE_MAPPING[identifier], + indent=False, + value=self._model.has_partial_results(identifier), + layout=ipw.Layout(width="fit-content"), + ) + ipw.dlink( + (self._model, "monitor_counter"), + (checkbox, "disabled"), + lambda _, cid=identifier: not self._model.has_partial_results(cid), + ) + ipw.dlink( + (checkbox, "value"), + (apply_button, "disabled"), + lambda _: not any(cb.value for cb in self.checkboxes.values()), + ) + self.checkboxes[identifier] = checkbox + property_selector.children += (checkbox,) + + property_selector.children += (apply_button,) + + return ipw.VBox( + children=[ + ipw.HTML(""" +
+ Select one or more properties to plot: +

+ You can choose to plot only bands, only PDOS, or both + combined in one plot. After making your selection, click + + Apply selection + + to proceed. +

+
+ """), + property_selector, + ] + ) + + def _render_property_results(self, _=None): + node_identifiers = ( + [ + identifier + for identifier, checkbox in self.checkboxes.items() + if checkbox.value + ] + if self.has_property_selector + else self._model.identifiers + ) + self._render_bands_pdos_widget(node_identifiers) + + def _render_bands_pdos_widget(self, node_identifiers): + message = f"Loading {' + '.join(node_identifiers)} results" + self.bands_pdos_container.children = [LoadingWidget(message)] + nodes = { + **{ + identifier: self._model.fetch_child_process_node(identifier) + for identifier in node_identifiers + }, + "root": self._model.fetch_process_node(), + } + model = BandsPdosModel.from_nodes(**nodes) + widget = BandsPdosWidget(model=model) widget.render() - self.children = [widget] + self.bands_pdos_container.children = [widget] diff --git a/src/aiidalab_qe/plugins/pdos/__init__.py b/src/aiidalab_qe/plugins/pdos/__init__.py index b22119506..2d188e926 100644 --- a/src/aiidalab_qe/plugins/pdos/__init__.py +++ b/src/aiidalab_qe/plugins/pdos/__init__.py @@ -2,7 +2,6 @@ from .model import PdosConfigurationSettingsModel from .resources import PdosResourceSettingsModel, PdosResourceSettingsPanel -from .result import PdosResultsModel, PdosResultsPanel from .setting import PdosConfigurationSettingPanel from .workchain import workchain_and_builder @@ -21,9 +20,5 @@ class PdosPluginOutline(PluginOutline): "panel": PdosResourceSettingsPanel, "model": PdosResourceSettingsModel, }, - "result": { - "panel": PdosResultsPanel, - "model": PdosResultsModel, - }, "workchain": workchain_and_builder, } diff --git a/src/aiidalab_qe/plugins/pdos/result/__init__.py b/src/aiidalab_qe/plugins/pdos/result/__init__.py deleted file mode 100644 index 30069fbb8..000000000 --- a/src/aiidalab_qe/plugins/pdos/result/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .model import PdosResultsModel -from .result import PdosResultsPanel - -__all__ = [ - "PdosResultsModel", - "PdosResultsPanel", -] diff --git a/src/aiidalab_qe/plugins/pdos/result/model.py b/src/aiidalab_qe/plugins/pdos/result/model.py deleted file mode 100644 index c010c0f86..000000000 --- a/src/aiidalab_qe/plugins/pdos/result/model.py +++ /dev/null @@ -1,11 +0,0 @@ -from aiidalab_qe.common.panel import ResultsModel - - -class PdosResultsModel(ResultsModel): - title = "PDOS" - identifier = "pdos" - - _this_process_label = "PdosWorkChain" - - def get_pdos_node(self): - return self._get_child_outputs() diff --git a/src/aiidalab_qe/plugins/pdos/result/result.py b/src/aiidalab_qe/plugins/pdos/result/result.py deleted file mode 100644 index eefa4e705..000000000 --- a/src/aiidalab_qe/plugins/pdos/result/result.py +++ /dev/null @@ -1,15 +0,0 @@ -"""PDOS results view widgets""" - -from aiidalab_qe.common.bands_pdos import BandsPdosModel, BandsPdosWidget -from aiidalab_qe.common.panel import ResultsPanel - -from .model import PdosResultsModel - - -class PdosResultsPanel(ResultsPanel[PdosResultsModel]): - def _render(self): - pdos_node = self._model.get_pdos_node() - model = BandsPdosModel() - widget = BandsPdosWidget(model=model, pdos=pdos_node) - widget.render() - self.children = [widget] diff --git a/src/aiidalab_qe/plugins/xas/result/result.py b/src/aiidalab_qe/plugins/xas/result/result.py index a4edb1772..29d13fc81 100644 --- a/src/aiidalab_qe/plugins/xas/result/result.py +++ b/src/aiidalab_qe/plugins/xas/result/result.py @@ -125,7 +125,7 @@ def _render(self): ) self.plot.layout.xaxis.title = "Relative Photon Energy (eV)" - self.children = [ + self.results_container.children = [ ipw.HBox( [ ipw.VBox( diff --git a/src/aiidalab_qe/plugins/xps/result/result.py b/src/aiidalab_qe/plugins/xps/result/result.py index fbd22b89f..cbdee2842 100644 --- a/src/aiidalab_qe/plugins/xps/result/result.py +++ b/src/aiidalab_qe/plugins/xps/result/result.py @@ -148,7 +148,7 @@ def _render(self): self.plot.layout.xaxis.title = "Chemical shift (eV)" self.plot.layout.xaxis.autorange = "reversed" - self.children = [ + self.results_container.children = [ spectra_type, ipw.HBox( children=[ diff --git a/tests/conftest.py b/tests/conftest.py index d163fcda5..867464fc5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -820,29 +820,29 @@ def _generate_qeapp_workchain( inputs = builder._inputs() inputs["relax"]["base_final_scf"] = deepcopy(inputs["relax"]["base"]) - # Setting up inputs for bands_projwfc - inputs["bands"]["bands_projwfc"]["scf"]["pw"] = deepcopy( - inputs["bands"]["bands"]["scf"]["pw"] - ) - inputs["bands"]["bands_projwfc"]["bands"]["pw"] = deepcopy( - inputs["bands"]["bands"]["bands"]["pw"] - ) - inputs["bands"]["bands_projwfc"]["bands"]["pw"]["code"] = inputs["bands"][ - "bands" - ]["bands"]["pw"]["code"] - inputs["bands"]["bands_projwfc"]["scf"]["pw"]["code"] = inputs["bands"][ - "bands" - ]["scf"]["pw"]["code"] - - inputs["bands"]["bands_projwfc"]["projwfc"]["projwfc"]["code"] = fixture_code( - "quantumespresso.projwfc" - ) - inputs["bands"]["bands_projwfc"]["projwfc"]["projwfc"]["parameters"] = Dict( - {"PROJWFC": {"DeltaE": 0.01}} - ).store() - if run_bands: + # Setting up inputs for bands_projwfc + inputs["bands"]["bands_projwfc"]["scf"]["pw"] = deepcopy( + inputs["bands"]["bands"]["scf"]["pw"] + ) + inputs["bands"]["bands_projwfc"]["bands"]["pw"] = deepcopy( + inputs["bands"]["bands"]["bands"]["pw"] + ) + inputs["bands"]["bands_projwfc"]["bands"]["pw"]["code"] = inputs["bands"][ + "bands" + ]["bands"]["pw"]["code"] + inputs["bands"]["bands_projwfc"]["scf"]["pw"]["code"] = inputs["bands"][ + "bands" + ]["scf"]["pw"]["code"] + + inputs["bands"]["bands_projwfc"]["projwfc"]["projwfc"]["code"] = ( + fixture_code("quantumespresso.projwfc") + ) + inputs["bands"]["bands_projwfc"]["projwfc"]["projwfc"]["parameters"] = Dict( + {"PROJWFC": {"DeltaE": 0.01}} + ).store() inputs["properties"].append("bands") + if run_pdos: inputs["properties"].append("pdos") @@ -852,6 +852,7 @@ def _generate_qeapp_workchain( # mock output if relax_type != "none": workchain.out("structure", app.structure_model.input_structure) + if run_pdos: from aiida_quantumespresso.workflows.pdos import PdosWorkChain @@ -863,6 +864,7 @@ def _generate_qeapp_workchain( namespace="pdos", ) ) + if run_bands: from aiidalab_qe.plugins.bands.bands_workchain import BandsWorkChain diff --git a/tests/test_plugins_bands.py b/tests/test_plugins_bands.py index a6067bab4..fb973bc25 100644 --- a/tests/test_plugins_bands.py +++ b/tests/test_plugins_bands.py @@ -2,30 +2,38 @@ def test_result(generate_qeapp_workchain): import plotly.graph_objects as go from aiidalab_qe.common.bands_pdos import BandsPdosWidget - from aiidalab_qe.plugins.bands.result import BandsResultsModel, BandsResultsPanel + from aiidalab_qe.plugins.electronic_structure.result import ( + ElectronicStructureResultsModel, + ElectronicStructureResultsPanel, + ) + + workchain = generate_qeapp_workchain(run_pdos=False) + model = ElectronicStructureResultsModel() + panel = ElectronicStructureResultsPanel(model=model) - workchain = generate_qeapp_workchain() - model = BandsResultsModel() model.process_uuid = workchain.node.uuid - result = BandsResultsPanel(model=model) - result._render() - widget = result.children[0] + assert model.title == "Electronic bands" + assert model.identifiers == ["bands"] + + panel.render() + + assert len(panel.results_container.children) == 1 # only bands, so no controls + + widget = panel.bands_pdos_container.children[0] # type: ignore model = widget._model assert isinstance(widget, BandsPdosWidget) assert isinstance(widget.plot, go.FigureWidget) - # Check if data is correct assert not model.pdos_data assert model.bands_data - assert model.bands_data["pathlabels"] # type: ignore + assert model.bands_data["pathlabels"] - # Check Bands axis assert widget.plot.layout.xaxis.title.text == "k-points" assert widget.plot.layout.yaxis.title.text == "Electronic Bands (eV)" assert isinstance(widget.plot.layout.xaxis.rangeslider, go.layout.xaxis.Rangeslider) - assert model.bands_data["pathlabels"][0] == list(widget.plot.layout.xaxis.ticktext) # type: ignore + assert model.bands_data["pathlabels"][0] == list(widget.plot.layout.xaxis.ticktext) def test_structure_1d(generate_qeapp_workchain, generate_structure_data): diff --git a/tests/test_plugins_electronic_structure.py b/tests/test_plugins_electronic_structure.py index 18f381be6..265162f3b 100644 --- a/tests/test_plugins_electronic_structure.py +++ b/tests/test_plugins_electronic_structure.py @@ -9,11 +9,17 @@ def test_electronic_structure(generate_qeapp_workchain): workchain = generate_qeapp_workchain() model = ElectronicStructureResultsModel() + panel = ElectronicStructureResultsPanel(model=model) model.process_uuid = workchain.node.uuid - result = ElectronicStructureResultsPanel(model=model) - result._render() - widget = result.children[0] + assert model.title == "Electronic bands + PDOS" + assert model.identifiers == ["bands", "pdos"] + + panel.render() + + assert len(panel.results_container.children) == 2 # has controls + + widget = panel.bands_pdos_container.children[0] # type: ignore model = widget._model assert isinstance(widget, BandsPdosWidget) diff --git a/tests/test_plugins_pdos.py b/tests/test_plugins_pdos.py index 8bad6bd03..23e0666e8 100644 --- a/tests/test_plugins_pdos.py +++ b/tests/test_plugins_pdos.py @@ -2,26 +2,31 @@ def test_result(generate_qeapp_workchain): import plotly.graph_objects as go from aiidalab_qe.common.bands_pdos import BandsPdosWidget - from aiidalab_qe.plugins.pdos.result import PdosResultsModel, PdosResultsPanel + from aiidalab_qe.plugins.electronic_structure.result import ( + ElectronicStructureResultsModel, + ElectronicStructureResultsPanel, + ) - workchain = generate_qeapp_workchain() - model = PdosResultsModel() + workchain = generate_qeapp_workchain(run_bands=False) + model = ElectronicStructureResultsModel() + panel = ElectronicStructureResultsPanel(model=model) model.process_uuid = workchain.node.uuid - result = PdosResultsPanel(model=model) - result._render() - widget = result.children[0] + assert model.title == "Electronic PDOS" + assert model.identifiers == ["pdos"] + + panel.render() + + assert len(panel.results_container.children) == 1 # only pdos, so no controls + + widget = panel.bands_pdos_container.children[0] # type: ignore model = widget._model assert isinstance(widget, BandsPdosWidget) assert isinstance(widget.plot, go.FigureWidget) - # Check if data is correct assert not model.bands_data assert model.pdos_data - # Check PDOS settings is not None - - # Check Bands axis assert widget.plot.layout.xaxis.title.text == "Density of states (eV)" assert widget.plot.layout.yaxis.title.text is None diff --git a/tests/test_result.py b/tests/test_result.py index db7de7038..b5932fdbb 100644 --- a/tests/test_result.py +++ b/tests/test_result.py @@ -40,7 +40,7 @@ def test_workchainview(generate_qeapp_workchain): viewer = WorkChainResultsViewer(model=model) model.process_uuid = workchain.node.uuid viewer.render() - assert len(viewer.tabs.children) == 4 + assert len(viewer.tabs.children) == 2 assert viewer.tabs._titles["0"] == "Relaxed structure" # type: ignore