Skip to content

Commit

Permalink
Unify electronic structure result panels (#1039)
Browse files Browse the repository at this point in the history
This PR implements a single unified panel for electronic structure results. The title of the panel's tab is dynamic depending on the selected properties. In addition, if both bands and PDOS are selected, selection controls are added at the top of the panel for the user to choose which property they would like to plot.
  • Loading branch information
edan-bainglass authored Dec 30, 2024
1 parent 452d60d commit 36af095
Show file tree
Hide file tree
Showing 26 changed files with 407 additions and 247 deletions.
2 changes: 1 addition & 1 deletion src/aiidalab_qe/app/result/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _post_render(self):
self.toggle_controls.value = "Summary"

self.process_monitor = ProcessMonitor(
timeout=0.2,
timeout=0.5,
callbacks=[
self._update_status,
self._update_state,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@ class StructureResultsModel(ResultsModel):

source = None

@property
def include(self):
return True

def update(self):
super().update()
is_relaxed = "relax" in self.properties
self.title = "Relaxed structure" if is_relaxed else "Initial structure"
self.source = self.outputs if is_relaxed else self.inputs
self.auto_render = not is_relaxed # auto-render initial structure
self.auto_render = not is_relaxed or self.has_results

def get_structure(self):
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def _render(self):
if not hasattr(self, "widget"):
structure = self._model.get_structure()
self.widget = StructureDataViewer(structure=structure)
self.children = [self.widget]
self.results_container.children = [self.widget]

# HACK to resize the NGL viewer in cases where it auto-rendered when its
# container was not displayed, which leads to a null width. This hack restores
Expand Down
31 changes: 16 additions & 15 deletions src/aiidalab_qe/app/result/components/viewer/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,15 @@

class WorkChainResultsViewer(ResultsComponent[WorkChainResultsViewerModel]):
def __init__(self, model: WorkChainResultsViewerModel, **kwargs):
super().__init__(model=model, **kwargs)
# NOTE: here we want to add the structure and plugin models to the viewer
# model BEFORE we define the observation of the process uuid. This ensures
# that when the process changes, its reflected in the sub-models prior to
# the logic of the process change event handler.
# TODO avoid exceptions! Ensure sub-model synchronization in general!
self.panels: dict[str, ResultsPanel] = {}
self._add_structure_panel() # TODO consider refactoring structure panel as a plugin
self._fetch_plugin_results()
self._add_structure_panel(model)
self._fetch_plugin_results(model)
super().__init__(model=model, **kwargs)

def _on_process_change(self, _):
self._update_panels()
Expand All @@ -40,14 +45,10 @@ def _post_render(self):
self._set_tabs()

def _update_panels(self):
properties = self._model.properties
need_electronic_structure = "bands" in properties and "pdos" in properties
self.panels = {
identifier: panel
for identifier, panel in self.panels.items()
if identifier == "structure"
or identifier in properties
or (identifier == "electronic_structure" and need_electronic_structure)
identifier: self.panels[identifier]
for identifier, model in self._model.get_models()
if model.include
}

def _set_tabs(self):
Expand All @@ -65,18 +66,18 @@ def _set_tabs(self):
if children:
self.tabs.selected_index = 0

def _add_structure_panel(self):
def _add_structure_panel(self, viewer_model: WorkChainResultsViewerModel):
structure_model = StructureResultsModel()
structure_model.process_uuid = self._model.process_uuid
structure_model.process_uuid = viewer_model.process_uuid
self.structure_results = StructureResultsPanel(model=structure_model)
identifier = structure_model.identifier
self._model.add_model(identifier, structure_model)
viewer_model.add_model(identifier, structure_model)
self.panels = {
identifier: self.structure_results,
**self.panels,
}

def _fetch_plugin_results(self):
def _fetch_plugin_results(self, viewer_model: WorkChainResultsViewerModel):
entries = get_entry_items("aiidalab_qe.properties", "result")
for identifier, entry in entries.items():
for key in ("panel", "model"):
Expand All @@ -86,7 +87,7 @@ def _fetch_plugin_results(self):
)
panel = entry["panel"]
model = entry["model"]()
self._model.add_model(identifier, model)
viewer_model.add_model(identifier, model)
self.panels[identifier] = panel(
identifier=identifier,
model=model,
Expand Down
4 changes: 4 additions & 0 deletions src/aiidalab_qe/app/static/styles/custom.css
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,7 @@ footer {
.p-Accordion-child:has(.qe-app-step-fail) > .p-Collapse-header {
background-color: var(--color-failed);
}

.p-TabBar-tab {
min-width: fit-content !important;
}
40 changes: 20 additions & 20 deletions src/aiidalab_qe/common/bands_pdos/bandpdoswidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,34 @@


class BandsPdosWidget(ipw.VBox):
"""
A widget for plotting band structure and projected density of states (PDOS) data.
"""A widget for plotting band structure and projected density of states (PDOS).
Parameters
----------
- bands (optional): A node containing band structure data.
- pdos (optional): A node containing PDOS data.
`model`: `BandsPdosModel`
The MVC model containing the data and logic for the widget.
Attributes
----------
- description: HTML description of the widget.
- dos_atoms_group: Dropdown widget to select the grouping of atoms for PDOS plotting.
- dos_plot_group: Dropdown widget to select the type of PDOS contributions to plot.
- selected_atoms: Text widget to select specific atoms for PDOS plotting.
- update_plot_button: Button widget to update the plot.
- download_button: Button widget to download the data.
- project_bands_box: Checkbox widget to choose whether projected bands should be plotted.
- plot_widget: Plotly widget for band structure and PDOS plot.
- bands_widget: Output widget to display the bandsplot widget.
`description`: `ipywidgets.HTML`
HTML description of the widget.
`dos_atoms_group`: `ipywidgets.Dropdown`
Dropdown widget to select the grouping of atoms for PDOS plotting.
`dos_plot_group`: `ipywidgets.Dropdown`
Dropdown widget to select the type of PDOS contributions to plot.
`selected_atoms`: `ipywidgets.Text`
Text widget to select specific atoms for PDOS plotting.
`update_plot_button`: `ipywidgets.Button`
Button widget to update the plot.
`download_button`: `ipywidgets.Button`
Button widget to download the data.
`project_bands_box`: `ipywidgets.Checkbox`
Checkbox widget to choose whether projected bands should be plotted.
`plot`: `plotly.graph_objects.FigureWidget`
Plotly widget for band structure and PDOS plot.
"""

def __init__(self, model: BandsPdosModel, bands=None, pdos=None, **kwargs):
if bands is None and pdos is None:
raise ValueError("Either bands or pdos must be provided")

def __init__(self, model: BandsPdosModel, **kwargs):
super().__init__(
children=[LoadingWidget("Loading widgets")],
**kwargs,
Expand All @@ -49,9 +52,6 @@ def __init__(self, model: BandsPdosModel, bands=None, pdos=None, **kwargs):

self.rendered = False

self._model.bands = bands
self._model.pdos = pdos

def render(self):
if self.rendered:
return
Expand Down
48 changes: 48 additions & 0 deletions src/aiidalab_qe/common/bands_pdos/model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import base64
import json

Expand All @@ -6,9 +8,12 @@
import traitlets as tl
from IPython.display import display

from aiida import orm
from aiida.common.extendeddicts import AttributeDict
from aiidalab_qe.common.bands_pdos.utils import (
HTML_TAGS,
extract_bands_output,
extract_pdos_output,
get_bands_data,
get_bands_projections_data,
get_pdos_data,
Expand Down Expand Up @@ -79,6 +84,49 @@ def __init__(self, *args, **kwargs):
lambda _: self._has_pdos or self.needs_projections_controls,
)

@classmethod
def from_nodes(
cls,
bands: orm.WorkChainNode | None = None,
pdos: orm.WorkChainNode | None = None,
root: orm.WorkChainNode | None = None,
):
"""Create a `BandsPdosModel` instance from the provided nodes.
The method attempts to extract the output attribute dictionaries from the
nodes and creates from them an instance of the model.
Parameters
----------
`bands` : `orm.WorkChainNode`, optional
The bands workchain node.
`pdos` : `orm.WorkChainNode`, optional
The PDOS workchain node.
`root`: `orm.WorkChainNode`, optional
The root QE app workchain node.
Returns
-------
`BandsPdosModel`
The model instance.
Raises
------
`ValueError`
If neither of the nodes is provided or if the parsing of the nodes fails.
"""
if bands or pdos:
bands_output = extract_bands_output(bands)
pdos_output = extract_pdos_output(pdos)
elif root:
bands_output = extract_bands_output(root)
pdos_output = extract_pdos_output(root)
else:
raise ValueError("At least one of the nodes must be provided")
if bands_output or pdos_output:
return cls(bands=bands_output, pdos=pdos_output)
raise ValueError("Failed to parse at least one node")

def fetch_data(self):
"""Fetch the data from the nodes."""
if self.bands:
Expand Down
55 changes: 54 additions & 1 deletion src/aiidalab_qe/common/bands_pdos/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import json
import re

import numpy as np
from pymatgen.core.periodic_table import Element

from aiida.orm import ProjectionData
from aiida.common.extendeddicts import AttributeDict
from aiida.orm import ProjectionData, WorkChainNode

# Constants for HTML tags
HTML_TAGS = {
Expand Down Expand Up @@ -37,6 +40,56 @@
}


def extract_pdos_output(node: WorkChainNode) -> AttributeDict | None:
"""Extract the PDOS output node from the given node.
Parameters
----------
`node`: `WorkChainNode`
The node to extract the PDOS output from.
Returns
-------
`AttributeDict | None`
The PDOS output node, if available.
"""
if not node:
return
if node.process_label == "QeAppWorkChain" and "pdos" in node.outputs:
return node.outputs.pdos
if "dos" in node.outputs and "projwfc" in node.outputs:
items = {key: getattr(node.outputs, key) for key in node.outputs}
return AttributeDict(items)


def extract_bands_output(node: WorkChainNode) -> AttributeDict | None:
"""Extract the bands output node from the given node.
Parameters
----------
`node`: `WorkChainNode`
The node to extract the bands output from.
Returns
-------
`AttributeDict | None`
The bands output node, if available.
"""
if not node:
return
if node.process_label == "QeAppWorkChain" and "bands" in node.outputs:
outputs = node.outputs.bands
else:
outputs = node.outputs
return (
outputs.bands
if "bands" in outputs
else outputs.bands_projwfc
if "bands_projwfc" in outputs
else None
)


def get_bands_data(outputs, fermi_energy=None):
if "band_structure" not in outputs:
return None
Expand Down
Loading

0 comments on commit 36af095

Please sign in to comment.