Skip to content

Commit

Permalink
Hide metadata inputs for tasks in HTML visualization. (#346)
Browse files Browse the repository at this point in the history
Add a parameter `show_socket_depth` to control the level of the input sockets to be shown in the GUI.
  • Loading branch information
GeigerJ2 authored Dec 2, 2024
1 parent 06612f8 commit 512d1a8
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,4 @@ dmypy.json
tests/work
/tests/**/*.png
/tests/**/*txt
.vscode
.vscode/
13 changes: 10 additions & 3 deletions aiida_workgraph/task.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from node_graph.node import Node as GraphNode
from aiida_workgraph import USE_WIDGET
from aiida_workgraph.properties import property_pool
Expand Down Expand Up @@ -56,6 +58,7 @@ def __init__(
self._widget = None
self.state = "PLANNED"
self.action = ""
self.show_socket_depth = 0

def to_dict(self, short: bool = False) -> Dict[str, Any]:
from aiida.orm.utils.serialize import serialize
Expand Down Expand Up @@ -174,18 +177,22 @@ def _repr_mimebundle_(self, *args: Any, **kwargs: Any) -> any:
print(WIDGET_INSTALLATION_MESSAGE)
return
# if ipywdigets > 8.0.0, use _repr_mimebundle_ instead of _ipython_display_
self._widget.from_node(self)
self._widget.from_node(self, show_socket_depth=self.show_socket_depth)
if hasattr(self._widget, "_repr_mimebundle_"):
return self._widget._repr_mimebundle_(*args, **kwargs)
else:
return self._widget._ipython_display_(*args, **kwargs)

def to_html(self, output: str = None, **kwargs):
def to_html(
self, output: str = None, show_socket_depth: Optional[int] = None, **kwargs
):
"""Write a standalone html file to visualize the task."""
if show_socket_depth is None:
show_socket_depth = self.show_socket_depth
if self._widget is None:
print(WIDGET_INSTALLATION_MESSAGE)
return
self._widget.from_node(self)
self._widget.from_node(node=self, show_socket_depth=show_socket_depth)
return self._widget.to_html(output=output, **kwargs)


Expand Down
22 changes: 22 additions & 0 deletions aiida_workgraph/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,3 +651,25 @@ def validate_task_inout(inout_list: list[str | dict], list_type: str) -> list[di
processed_inout_list.append(item)

return processed_inout_list


def filter_keys_namespace_depth(
dict_: dict[Any, Any], max_depth: int = 0
) -> dict[Any, Any]:
"""
Filter top-level keys of a dictionary based on the namespace nesting level (number of periods) in the key.
:param dict dict_: The dictionary to filter.
:param int max_depth: Maximum depth of namespaces to retain (number of periods).
:return: The filtered dictionary with only keys satisfying the depth condition.
:rtype: dict
"""
result: dict[Any, Any] = {}

for key, value in dict_.items():
depth = key.count(".")

if depth <= max_depth:
result[key] = value

return result
21 changes: 14 additions & 7 deletions aiida_workgraph/widget/src/widget/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import anywidget
import traitlets
from .utils import wait_to_link
from aiida_workgraph.utils import filter_keys_namespace_depth

try:
__version__ = importlib.metadata.version("widget")
Expand Down Expand Up @@ -37,16 +38,22 @@ def from_workgraph(self, workgraph: Any) -> None:
wgdata = workgraph_to_short_json(wgdata)
self.value = wgdata

def from_node(self, node: Any) -> None:
def from_node(self, node: Any, show_socket_depth: int = 0) -> None:

tdata = node.to_dict()
tdata.pop("properties", None)
tdata.pop("executor", None)
tdata.pop("node_class", None)
tdata.pop("process", None)
tdata["label"] = tdata["identifier"]

# Remove certain elements of the dict-representation of the Node that we don't want to show
for key in ("properties", "executor", "node_class", "process"):
tdata.pop(key, None)
for input in tdata["inputs"].values():
input.pop("property")
tdata["inputs"] = list(tdata["inputs"].values())

tdata["label"] = tdata["identifier"]

filtered_inputs = filter_keys_namespace_depth(
dict_=tdata["inputs"], max_depth=show_socket_depth
)
tdata["inputs"] = list(filtered_inputs.values())
tdata["outputs"] = list(tdata["outputs"].values())
wgdata = {"name": node.name, "nodes": {node.name: tdata}, "links": []}
self.value = wgdata
Expand Down

0 comments on commit 512d1a8

Please sign in to comment.