Skip to content

Commit

Permalink
Fixed compatibility issue with networkX
Browse files Browse the repository at this point in the history
  • Loading branch information
LucaCappelletti94 committed May 4, 2023
1 parent d79633b commit ef19d21
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 8 deletions.
Binary file added dist/embiggen-0.11.47.tar.gz
Binary file not shown.
2 changes: 1 addition & 1 deletion embiggen/__version__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
"""Current version of package Embiggen."""
__version__ = "0.11.47"
__version__ = "0.11.48"
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ def _fit_transform(
"It is not clear what to do with this object."
)

graph_nx = convert_ensmallen_graph_to_networkx_graph(graph)
graph_nx = convert_ensmallen_graph_to_networkx_graph(
graph,
numeric_node_ids=True
)
model.fit(graph_nx)

node_embeddings: np.ndarray = model.get_embedding()
Expand Down
32 changes: 26 additions & 6 deletions embiggen/utils/networkx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,19 @@


def convert_ensmallen_graph_to_networkx_graph(
graph: Graph
graph: Graph,
numeric_node_ids: bool = False
) -> nx.Graph:
"""Return NetworkX graph derived from the provided Ensmallen Graph.
Parameters
-----------
graph: Graph
The graph to be converted from ensmallen to NetworkX.
numeric_node_ids: bool = False
Whether to use numeric node IDs or string node IDs.
By default, we use the string node IDs as they are more
interpretable.
"""
if graph.is_directed():
result_graph = nx.DiGraph(name=graph.get_name())
Expand All @@ -26,9 +31,10 @@ def convert_ensmallen_graph_to_networkx_graph(
desc="Parsing nodes"
):
result_graph.add_node(
graph.get_node_name_from_node_id(node_id),
node_id if numeric_node_ids else graph.get_node_name_from_node_id(node_id),
node_types=graph.get_unchecked_node_type_names_from_node_id(
node_id),
node_id
),
)

for edge_id in trange(
Expand All @@ -38,9 +44,23 @@ def convert_ensmallen_graph_to_networkx_graph(
desc="Parsing edges"
):
result_graph.add_edge(
*graph.get_node_names_from_edge_id(edge_id),
weight=graph.get_unchecked_edge_weight_from_edge_id(edge_id),
edge_type=graph.get_unchecked_edge_type_name_from_edge_id(edge_id)
*(
graph.get_node_ids_from_edge_id(edge_id)
if numeric_node_ids else
graph.get_node_names_from_edge_id(edge_id)
),
**(
dict(
weight=graph.get_unchecked_edge_weight_from_edge_id(edge_id),
)
if graph.has_edge_weights() else dict()
),
**(
dict(
edge_type=graph.get_unchecked_edge_type_name_from_edge_id(edge_id),
)
if graph.has_edge_types() else dict()
),
)

return result_graph
Expand Down

0 comments on commit ef19d21

Please sign in to comment.