Skip to content

Commit 6164dae

Browse files
committed
add edge_label_kwargs interface to :meth:.PipelineState.plot_graph
1 parent 595cc41 commit 6164dae

File tree

2 files changed

+20
-9
lines changed

2 files changed

+20
-9
lines changed

renard/pipeline/core.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,7 @@ def plot_graph(
428428
node_kwargs: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
429429
edge_kwargs: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
430430
label_kwargs: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
431+
edge_label_kwargs: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
431432
tight_layout: bool = False,
432433
legend: bool = False,
433434
):
@@ -458,7 +459,10 @@ def plot_graph(
458459
:param node_kwargs: passed to :func:`nx.draw_networkx_nodes`
459460
:param edge_kwargs: passed to :func:`nx.draw_networkx_nodes`
460461
:param label_kwargs: passed to :func:`nx.draw_networkx_labels`
461-
:param tight_layout: if ``True``, will use matplotlib's tight_layout
462+
:param edge_label_kwargs: passed to
463+
:func:`nx.draw_networkx_labels`
464+
:param tight_layout: if ``True``, will use matplotlib's
465+
tight_layout
462466
:param legend: passed to :func:`.plot_nx_graph_reasonably`
463467
"""
464468
import matplotlib.pyplot as plt
@@ -481,6 +485,7 @@ def plot_graph(
481485
assert not isinstance(node_kwargs, list)
482486
assert not isinstance(edge_kwargs, list)
483487
assert not isinstance(label_kwargs, list)
488+
assert not isinstance(edge_label_kwargs, list)
484489
if tight_layout:
485490
fig.tight_layout()
486491
plot_nx_graph_reasonably(
@@ -490,6 +495,7 @@ def plot_graph(
490495
node_kwargs=node_kwargs,
491496
edge_kwargs=edge_kwargs,
492497
label_kwargs=label_kwargs,
498+
edge_label_kwargs=edge_label_kwargs,
493499
legend=legend,
494500
)
495501
return
@@ -504,6 +510,10 @@ def plot_graph(
504510
assert isinstance(edge_kwargs, list)
505511
label_kwargs = label_kwargs or [{} for _ in range(len(self.character_network))]
506512
assert isinstance(label_kwargs, list)
513+
edge_label_kwargs = edge_label_kwargs or [
514+
{} for _ in range(len(self.character_network))
515+
]
516+
assert isinstance(edge_label_kwargs, list)
507517

508518
if fig is None:
509519
fig, ax = plt.subplots()
@@ -541,6 +551,7 @@ def update(slider_value):
541551
node_kwargs=node_kwargs[slider_i],
542552
edge_kwargs=edge_kwargs[slider_i],
543553
label_kwargs=label_kwargs[slider_i],
554+
edge_label_kwargs=edge_label_kwargs[slider_i],
544555
legend=legend,
545556
)
546557
ax.set_xlim(-1.2, 1.2)

renard/plot_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def plot_nx_graph_reasonably(
2525
node_kwargs: Optional[Dict[str, Any]] = None,
2626
edge_kwargs: Optional[Dict[str, Any]] = None,
2727
label_kwargs: Optional[Dict[str, Any]] = None,
28-
edge_labels_kwargs: Optional[Dict[str, Any]] = None,
28+
edge_label_kwargs: Optional[Dict[str, Any]] = None,
2929
legend: bool = False,
3030
):
3131
"""Try to plot a :class:`nx.Graph` with 'reasonable' parameters
@@ -82,17 +82,17 @@ def plot_nx_graph_reasonably(
8282
edge_kwargs["alpha"] = edge_kwargs.get("alpha", 0.35)
8383
nx.draw_networkx_edges(G, pos, ax=ax, **edge_kwargs)
8484

85-
edge_labels_kwargs = edge_labels_kwargs or {}
86-
if not "edge_labels" in edge_labels_kwargs:
87-
edge_labels_kwargs["edge_labels"] = {
85+
edge_label_kwargs = edge_label_kwargs or {}
86+
if not "edge_labels" in edge_label_kwargs:
87+
edge_label_kwargs["edge_labels"] = {
8888
(char1, char2): data.get("relations", "")
8989
for char1, char2, data in G.edges.data()
9090
}
91-
for (char1, char2), rel in edge_labels_kwargs["edge_labels"].items():
91+
for (char1, char2), rel in edge_label_kwargs["edge_labels"].items():
9292
if rel == set():
93-
edge_labels_kwargs["edge_labels"][(char1, char2)] = ""
94-
edge_labels_kwargs["font_size"] = edge_labels_kwargs.get("font_size", 6)
95-
nx.draw_networkx_edge_labels(G, pos, ax=ax, **edge_labels_kwargs)
93+
edge_label_kwargs["edge_labels"][(char1, char2)] = ""
94+
edge_label_kwargs["font_size"] = edge_label_kwargs.get("font_size", 6)
95+
nx.draw_networkx_edge_labels(G, pos, ax=ax, **edge_label_kwargs)
9696

9797
label_kwargs = label_kwargs or {}
9898
label_kwargs["verticalalignment"] = label_kwargs.get("verticalalignment", "top")

0 commit comments

Comments
 (0)