Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hide metadata inputs for tasks in HTML visualization. #346

Merged
merged 8 commits into from
Dec 2, 2024
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,4 @@ dmypy.json
tests/work
/tests/**/*.png
/tests/**/*txt
.vscode
.vscode/
13 changes: 10 additions & 3 deletions aiida_workgraph/task.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from node_graph.node import Node as GraphNode
from aiida_workgraph import USE_WIDGET
from aiida_workgraph.properties import property_pool
Expand Down Expand Up @@ -56,6 +58,7 @@ def __init__(
self._widget = None
self.state = "PLANNED"
self.action = ""
self.show_socket_depth = 0

def to_dict(self, short: bool = False) -> Dict[str, Any]:
from aiida.orm.utils.serialize import serialize
Expand Down Expand Up @@ -174,18 +177,22 @@ def _repr_mimebundle_(self, *args: Any, **kwargs: Any) -> any:
print(WIDGET_INSTALLATION_MESSAGE)
return
# if ipywdigets > 8.0.0, use _repr_mimebundle_ instead of _ipython_display_
self._widget.from_node(self)
self._widget.from_node(self, show_socket_depth=self.show_socket_depth)
if hasattr(self._widget, "_repr_mimebundle_"):
return self._widget._repr_mimebundle_(*args, **kwargs)
else:
return self._widget._ipython_display_(*args, **kwargs)

def to_html(self, output: str = None, **kwargs):
def to_html(
self, output: str = None, show_socket_depth: Optional[int] = None, **kwargs
):
"""Write a standalone html file to visualize the task."""
if show_socket_depth is None:
show_socket_depth = self.show_socket_depth
if self._widget is None:
print(WIDGET_INSTALLATION_MESSAGE)
return
self._widget.from_node(self)
self._widget.from_node(node=self, show_socket_depth=show_socket_depth)
return self._widget.to_html(output=output, **kwargs)


Expand Down
22 changes: 22 additions & 0 deletions aiida_workgraph/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,3 +651,25 @@ def validate_task_inout(inout_list: list[str | dict], list_type: str) -> list[di
processed_inout_list.append(item)

return processed_inout_list


def filter_keys_namespace_depth(
dict_: dict[Any, Any], max_depth: int = 0
) -> dict[Any, Any]:
"""
Filter top-level keys of a dictionary based on the namespace nesting level (number of periods) in the key.
:param dict dict_: The dictionary to filter.
:param int max_depth: Maximum depth of namespaces to retain (number of periods).
:return: The filtered dictionary with only keys satisfying the depth condition.
:rtype: dict
"""
result: dict[Any, Any] = {}

for key, value in dict_.items():
depth = key.count(".")

if depth <= max_depth:
result[key] = value

return result
21 changes: 14 additions & 7 deletions aiida_workgraph/widget/src/widget/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import anywidget
import traitlets
from .utils import wait_to_link
from aiida_workgraph.utils import filter_keys_namespace_depth

try:
__version__ = importlib.metadata.version("widget")
Expand Down Expand Up @@ -37,16 +38,22 @@ def from_workgraph(self, workgraph: Any) -> None:
wgdata = workgraph_to_short_json(wgdata)
self.value = wgdata

def from_node(self, node: Any) -> None:
def from_node(self, node: Any, show_socket_depth: int = 0) -> None:

tdata = node.to_dict()
tdata.pop("properties", None)
tdata.pop("executor", None)
tdata.pop("node_class", None)
tdata.pop("process", None)
tdata["label"] = tdata["identifier"]

# Remove certain elements of the dict-representation of the Node that we don't want to show
for key in ("properties", "executor", "node_class", "process"):
tdata.pop(key, None)
for input in tdata["inputs"].values():
input.pop("property")
tdata["inputs"] = list(tdata["inputs"].values())

tdata["label"] = tdata["identifier"]

filtered_inputs = filter_keys_namespace_depth(
dict_=tdata["inputs"], max_depth=show_socket_depth
)
tdata["inputs"] = list(filtered_inputs.values())
tdata["outputs"] = list(tdata["outputs"].values())
wgdata = {"name": node.name, "nodes": {node.name: tdata}, "links": []}
self.value = wgdata
Expand Down
Loading