Skip to content

Commit

Permalink
Implement plotly (#92)
Browse files Browse the repository at this point in the history
* added plotly function for visualization

* API update

* removed unused import

* API update

* refactored visualize_2d_network

* API update

---------

Co-authored-by: Alina Lacheim <[email protected]>
Co-authored-by: sdRDM Bot <[email protected]>
  • Loading branch information
3 people authored Aug 23, 2024
1 parent 4a78e3f commit 3b679f9
Show file tree
Hide file tree
Showing 12 changed files with 142 additions and 14 deletions.
1 change: 0 additions & 1 deletion pyeed/core/abstractannotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ class AbstractAnnotation(
_repo: Optional[str] = PrivateAttr(default="https://github.com/PyEED/pyeed")
_commit: Optional[str] = PrivateAttr(
default="72d2203f2e3ce4b319b29fa0d2f146b5eead7b00"

)

_raw_xml_data: Dict = PrivateAttr(default_factory=dict)
Expand Down
1 change: 0 additions & 1 deletion pyeed/core/blastdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ class BlastData(
_repo: Optional[str] = PrivateAttr(default="https://github.com/PyEED/pyeed")
_commit: Optional[str] = PrivateAttr(
default="72d2203f2e3ce4b319b29fa0d2f146b5eead7b00"

)

_raw_xml_data: Dict = PrivateAttr(default_factory=dict)
Expand Down
1 change: 0 additions & 1 deletion pyeed/core/numberedsequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ class NumberedSequence(

_repo: Optional[str] = PrivateAttr(default="https://github.com/PyEED/pyeed")
_commit: Optional[str] = PrivateAttr(

default="72d2203f2e3ce4b319b29fa0d2f146b5eead7b00"
)

Expand Down
2 changes: 1 addition & 1 deletion pyeed/core/pairwisealignmentresult.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class PairwiseAlignmentResult(

_repo: Optional[str] = PrivateAttr(default="https://github.com/PyEED/pyeed")
_commit: Optional[str] = PrivateAttr(
default="72d2203f2e3ce4b319b29fa0d2f146b5eead7b00"
default="b926bfec3aa1ec45a5614cf6ac4a546252dd384c"
)

_raw_xml_data: Dict = PrivateAttr(default_factory=dict)
Expand Down
2 changes: 1 addition & 1 deletion pyeed/core/proteinrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class ProteinRecord(

_repo: Optional[str] = PrivateAttr(default="https://github.com/PyEED/pyeed")
_commit: Optional[str] = PrivateAttr(
default="72d2203f2e3ce4b319b29fa0d2f146b5eead7b00"
default="b926bfec3aa1ec45a5614cf6ac4a546252dd384c"
)

_raw_xml_data: Dict = PrivateAttr(default_factory=dict)
Expand Down
3 changes: 1 addition & 2 deletions pyeed/core/region.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ class Region(

_repo: Optional[str] = PrivateAttr(default="https://github.com/PyEED/pyeed")
_commit: Optional[str] = PrivateAttr(

default="72d2203f2e3ce4b319b29fa0d2f146b5eead7b00"
default="b926bfec3aa1ec45a5614cf6ac4a546252dd384c"
)

_raw_xml_data: Dict = PrivateAttr(default_factory=dict)
Expand Down
2 changes: 1 addition & 1 deletion pyeed/core/regionset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class RegionSet(

_repo: Optional[str] = PrivateAttr(default="https://github.com/PyEED/pyeed")
_commit: Optional[str] = PrivateAttr(
default="72d2203f2e3ce4b319b29fa0d2f146b5eead7b00"
default="b926bfec3aa1ec45a5614cf6ac4a546252dd384c"
)

_object_terms: Set[str] = PrivateAttr(
Expand Down
3 changes: 1 addition & 2 deletions pyeed/core/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ class Sequence(

_repo: Optional[str] = PrivateAttr(default="https://github.com/PyEED/pyeed")
_commit: Optional[str] = PrivateAttr(

default="72d2203f2e3ce4b319b29fa0d2f146b5eead7b00"
default="b926bfec3aa1ec45a5614cf6ac4a546252dd384c"
)

_raw_xml_data: Dict = PrivateAttr(default_factory=dict)
Expand Down
2 changes: 1 addition & 1 deletion pyeed/core/sequencerecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class SequenceRecord(

_repo: Optional[str] = PrivateAttr(default="https://github.com/PyEED/pyeed")
_commit: Optional[str] = PrivateAttr(
default="72d2203f2e3ce4b319b29fa0d2f146b5eead7b00"
default="b926bfec3aa1ec45a5614cf6ac4a546252dd384c"
)

_raw_xml_data: Dict = PrivateAttr(default_factory=dict)
Expand Down
2 changes: 1 addition & 1 deletion pyeed/core/site.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Site(

_repo: Optional[str] = PrivateAttr(default="https://github.com/PyEED/pyeed")
_commit: Optional[str] = PrivateAttr(
default="72d2203f2e3ce4b319b29fa0d2f146b5eead7b00"
default="b926bfec3aa1ec45a5614cf6ac4a546252dd384c"
)

_raw_xml_data: Dict = PrivateAttr(default_factory=dict)
Expand Down
2 changes: 1 addition & 1 deletion pyeed/core/standardnumbering.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class StandardNumbering(

_repo: Optional[str] = PrivateAttr(default="https://github.com/PyEED/pyeed")
_commit: Optional[str] = PrivateAttr(
default="72d2203f2e3ce4b319b29fa0d2f146b5eead7b00"
default="b926bfec3aa1ec45a5614cf6ac4a546252dd384c"
)

_raw_xml_data: Dict = PrivateAttr(default_factory=dict)
Expand Down
135 changes: 134 additions & 1 deletion pyeed/network/network.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
from os import name
from typing import List, Optional, Any

import matplotlib.pyplot as plt
Expand All @@ -8,6 +9,7 @@
from py4cytoscape import gen_node_size_map, scheme_c_number_continuous
from pydantic import BaseModel, Field, PrivateAttr, field_serializer, field_validator
from requests import RequestException
import plotly.graph_objects as go

from pyeed.align.pairwise import PairwiseAligner
from pyeed.core.sequencerecord import SequenceRecord
Expand Down Expand Up @@ -476,12 +478,143 @@ def visualize(
plt.savefig(save_path, dpi=dpi)
plt.show()



def visualize_2d_network(
self,
color: Optional[List[str]] = None,
size: bool = False,
edges: bool = True,
edge_color: str = "grey",
save_path: Optional[str] = None
) -> None:
"""
Visualizes a 2D network graph using Plotly by plotting nodes and optionally edges.
For large networks, it is recommended to disable the edges for performance.
Parameters:
color (List[str], optional): List of colors to colorize the nodes. Default is None, where nodes are colored blue.
size (bool, optional): Whether to size the nodes based on centrality. Default is False.
edges (bool, optional): Whether to plot edges. Default is True.
edge_color (str, optional): Color of the edges. Default is grey.
save_path (str, optional): File path to save the plot. If None, the plot is shown directly.
Raises:
ValueError: If the specified color list length does not match the number of nodes.
Returns:
None
"""

# Initialize traces list
traces = []

# Plot edges if enabled
if edges:
for edge in self.network.edges:
x_values = [self.network.nodes[edge[0]]["x_pos"], self.network.nodes[edge[1]]["x_pos"]]
y_values = [self.network.nodes[edge[0]]["y_pos"], self.network.nodes[edge[1]]["y_pos"]]
traces.append(
go.Scatter(
x=x_values,
y=y_values,
mode="lines",
line=dict(width=0.7, color=edge_color), # Correctly define the line dictionary
opacity=0.5 # Set opacity at the correct level in the Scatter object
)
)

# Set node colors
if color and len(color) != len(self.network.nodes):
raise ValueError("Length of the color list must match the number of nodes.")
color_list = color if color else ["blue"] * len(self.network.nodes)

# Calculate node sizes if centrality is enabled
node_sizes = []
if size:
centrality = nx.degree_centrality(self.network)
max_size = 20 # Maximum size for the largest node
node_sizes = [6 + (centrality[node] * max_size) for node in self.network.nodes]
else:
node_sizes = [6] * len(self.network.nodes)

# Plot nodes and handle annotations
annotations = []
for counter, (key, node) in enumerate(self.network.nodes(data=True)):
# Node properties
node_size = node_sizes[counter]
node_color = color_list[counter]
symbol = "circle"
label = node.get("label", key)
name = node.get("name", key)

# Highlight targets differently
if key in self.targets:
node_size = 15
node_color = "black"
symbol = "cross"

# Add node trace
traces.append(
go.Scatter(
x=[node["x_pos"]],
y=[node["y_pos"]],
mode="markers",
marker=dict(
size=node_size,
color=node_color,
symbol=symbol,
),
hovertemplate="<b>ID: %{customdata[0]}</b><extra></extra>",
customdata=[[key]], # Pass node ID as custom data
)
)

# Add node annotation
annotations.append(
dict(
x=node["x_pos"],
y=node["y_pos"],
text=label,
showarrow=False,
xanchor='center',
yanchor='middle',
font=dict(
size=10,
color=node_color
)
)
)

# Create the figure
fig = go.Figure(
data=traces,
layout=go.Layout(
plot_bgcolor="white",
showlegend=False,
hovermode="closest",
margin=dict(b=0, l=0, r=0, t=0)
#annotations=annotations, # Add annotations to layout
),
)

# Hide axis lines and ticks
fig.update_xaxes(visible=False)
fig.update_yaxes(visible=False)

# Show or save the plot
if save_path:
fig.write_image(save_path)
else:
fig.show(scale=10)


def _2d_position_nodes_and_edges(self, graph: nx.Graph):
"""Calculates node positions based on weight metric and
adds position information of nodes and edges to the respective
entry in the graph."""

positions = nx.spring_layout(graph, weight=self.weight, dim=2, seed=42)
positions = nx.spring_layout(graph, weight=self.weight, iterations=400, dim=2, seed=42)

# Add node position as coordinates
for node in graph.nodes():
Expand Down

0 comments on commit 3b679f9

Please sign in to comment.