Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cluster generic products when visualizing graph #51

Merged
merged 2 commits into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 3 additions & 12 deletions src/sciline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from .param_table import ParamTable
from .scheduler import Scheduler
from .series import Series
from .typing import Graph, Item, Key, Label, Provider
from .typing import Graph, Item, Key, Label, Provider, get_optional

T = TypeVar('T')
KeyType = TypeVar('KeyType')
Expand Down Expand Up @@ -116,15 +116,6 @@ def _find_nodes_in_paths(
return list(nodes)


def _get_optional(tp: Key) -> Optional[Any]:
if get_origin(tp) != Union:
return None
args = get_args(tp)
if len(args) != 2 or type(None) not in args:
return None
return args[0] if args[1] == type(None) else args[1] # noqa: E721


def provide_none() -> None:
return None

Expand Down Expand Up @@ -173,7 +164,7 @@ def _copy_node(
)

def key(self, i: IndexType, value_name: Union[Type[T], Item[T]]) -> Item[T]:
value_name = _get_optional(value_name) or value_name
value_name = get_optional(value_name) or value_name
label = Label(self._index_name, i)
if isinstance(value_name, Item):
return Item(value_name.label + (label,), value_name.tp)
Expand Down Expand Up @@ -488,7 +479,7 @@ def build(
if get_origin(tp) == Series:
graph.update(self._build_series(tp)) # type: ignore[arg-type]
continue
if (optional_arg := _get_optional(tp)) is not None:
if (optional_arg := get_optional(tp)) is not None:
try:
optional_subgraph = self.build(
optional_arg, search_param_tables=search_param_tables
Expand Down
23 changes: 22 additions & 1 deletion src/sciline/typing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from dataclasses import dataclass
from typing import Any, Callable, Dict, Generic, Tuple, Type, TypeVar, Union
from typing import (
Any,
Callable,
Dict,
Generic,
Optional,
Tuple,
Type,
TypeVar,
Union,
get_args,
get_origin,
)


@dataclass(frozen=True)
Expand All @@ -24,3 +36,12 @@ class Item(Generic[T]):

Key = Union[type, Item]
Graph = Dict[Key, Tuple[Provider, Tuple[Key, ...]]]


def get_optional(tp: Key) -> Optional[Any]:
if get_origin(tp) != Union:
return None
args = get_args(tp)
if len(args) != 2 or type(None) not in args:
return None
return args[0] if args[1] == type(None) else args[1] # noqa: E721
70 changes: 60 additions & 10 deletions src/sciline/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Callable,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
Expand All @@ -17,7 +18,7 @@
from graphviz import Digraph

from .pipeline import Pipeline, SeriesProvider
from .typing import Graph, Item, Key
from .typing import Graph, Item, Key, get_optional


@dataclass
Expand All @@ -26,7 +27,16 @@ class Node:
collapsed: bool = False


def to_graphviz(graph: Graph, compact: bool = False, **kwargs: Any) -> Digraph:
FormattedGraph = Dict[str, Tuple[str, List[Node], Node]]


def to_graphviz(
graph: Graph,
compact: bool = False,
cluster_generics: bool = True,
cluster_color: Optional[str] = '#f0f0ff',
**kwargs: Any,
) -> Digraph:
"""
Convert output of :py:class:`sciline.Pipeline.get_graph` to a graphviz graph.

Expand All @@ -36,13 +46,53 @@ def to_graphviz(graph: Graph, compact: bool = False, **kwargs: Any) -> Digraph:
Output of :py:class:`sciline.Pipeline.get_graph`.
compact:
If True, parameter-table-dependent branches are collapsed into a single copy
of the branch. Recommendend for large graphs with long parameter tables.
of the branch. Recommended for large graphs with long parameter tables.
cluster_generics:
If True, generic products are grouped into clusters.
cluster_color:
Background color of clusters. If None, clusters are dotted.
kwargs:
Keyword arguments passed to :py:class:`graphviz.Digraph`.
"""
dot = Digraph(strict=True, **kwargs)
for p, (p_name, args, ret) in _format_graph(graph, compact=compact).items():
dot.node(ret.name, ret.name, shape='box3d' if ret.collapsed else 'rectangle')
formatted_graph = _format_graph(graph, compact=compact)
ordered_graph = dict(
sorted(formatted_graph.items(), key=lambda item: item[1][2].name)
)
subgraphs = _to_subgraphs(ordered_graph)

for origin, subgraph in subgraphs.items():
cluster = cluster_generics and len(subgraph) > 1
name = f'cluster_{origin}' if cluster else None
with dot.subgraph(name=name) as dot_subgraph:
if cluster:
dot_subgraph.attr(rank='same')
if cluster_color is None:
dot_subgraph.attr(style='dotted')
else:
dot_subgraph.attr(style='filled', color=cluster_color)
_add_subgraph(subgraph, dot, dot_subgraph)
return dot


def _to_subgraphs(graph: FormattedGraph) -> Dict[str, FormattedGraph]:
def get_subgraph_name(name: str) -> str:
return name.split('[')[0]

subgraphs: Dict[str, FormattedGraph] = {}
for p, (p_name, args, ret) in graph.items():
subgraph_name = get_subgraph_name(ret.name)
if subgraph_name not in subgraphs:
subgraphs[subgraph_name] = {}
subgraphs[subgraph_name][p] = (p_name, args, ret)
return subgraphs


def _add_subgraph(graph: FormattedGraph, dot: Digraph, subgraph: Digraph) -> None:
for p, (p_name, args, ret) in graph.items():
subgraph.node(
ret.name, ret.name, shape='box3d' if ret.collapsed else 'rectangle'
)
# Do not draw dummy providers created by Pipeline when setting instances
if p_name in (
f'{_qualname(Pipeline.__setitem__)}.<locals>.<lambda>',
Expand All @@ -59,7 +109,6 @@ def to_graphviz(graph: Graph, compact: bool = False, **kwargs: Any) -> Digraph:
for arg in args:
dot.edge(arg.name, p)
dot.edge(p, ret.name)
return dot


def _qualname(obj: Any) -> Any:
Expand All @@ -68,9 +117,7 @@ def _qualname(obj: Any) -> Any:
)


def _format_graph(
graph: Graph, compact: bool
) -> Dict[str, Tuple[str, List[Node], Node]]:
def _format_graph(graph: Graph, compact: bool) -> FormattedGraph:
return {
_format_provider(provider, ret, compact=compact): (
_qualname(provider),
Expand Down Expand Up @@ -108,7 +155,10 @@ def _format_type(tp: Key, compact: bool = False) -> Node:

tp, labels = _extract_type_and_labels(tp, compact=compact)

def get_base(tp: type) -> str:
if (tp_ := get_optional(tp)) is not None:
tp = tp_

def get_base(tp: Key) -> str:
return tp.__name__ if hasattr(tp, '__name__') else str(tp).split('.')[-1]

def format_label(label: Union[type, Tuple[type, Any]]) -> str:
Expand Down
7 changes: 6 additions & 1 deletion tests/visualize_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from typing import Generic, TypeVar
from typing import Generic, Optional, TypeVar

import sciline as sl
from sciline.visualize import to_graphviz
Expand Down Expand Up @@ -33,3 +33,8 @@ class SubA(A[T]):
assert sl.visualize._format_type(A[float]).name == 'A[float]'
assert sl.visualize._format_type(SubA[float]).name == 'SubA[float]'
assert sl.visualize._format_type(B[float]).name == 'B[float]'


def test_optional_types_formatted_as_their_content() -> None:
formatted = sl.visualize._format_type(Optional[float]) # type: ignore[arg-type]
assert formatted.name == 'float'