33
44import matplotlib .pyplot as plt
55import networkx as nx
6+ from pydantic import ValidationError
67from ome_zarr_models ._utils import TransformGraph
7-
8+ from ome_zarr_models . _v06 . coordinate_transforms import CoordinateSystemIdentifier , Transform
89
910def transform_graph_to_networkx (tgraph : TransformGraph ) -> nx .DiGraph :
1011 """
@@ -34,38 +35,58 @@ def transform_graph_to_networkx(tgraph: TransformGraph) -> nx.DiGraph:
3435 g = nx .DiGraph ()
3536
3637
37- # Add all coordinate systems as nodes
38+ # Add all named coordinate systems as nodes
3839 for cs_name in tgraph ._named_systems :
39- node_name = cs_name
4040 g .add_node (
41- node_name ,
41+ cs_name ,
4242 coordinate_system = tgraph ._named_systems [cs_name ],
43- is_default = (cs_name == tgraph ._default_system ),
4443 )
4544
46- # Add transformations as edges
47- for input_cs , output_dict in tgraph ._graph .items ():
48- for output_cs , transform in output_dict . items () :
49- source_node = input_cs
50- target_node = output_cs
51- g . add_edge (
52- source_node ,
53- target_node ,
54- transformation = transform ,
55- edge_type = "transformation" ,
45+ # Add also the named coordinate systems from the subgraphs
46+ for path_name , subgraph in tgraph ._subgraphs .items ():
47+ for cs_name in subgraph . _named_systems :
48+ identifier = CoordinateSystemIdentifier (
49+ name = cs_name ,
50+ path = path_name ,
51+ )
52+ g . add_node (
53+ identifier ,
54+ coordinate_system = subgraph . _named_systems [ cs_name ] ,
5655 )
57- try :
58- inverse = transform .get_inverse ()
59- g .add_edge (
60- target_node ,
61- source_node ,
62- transformation = inverse ,
63- edge_type = "transformation" ,
56+
57+ # finally add the "paths" coordinate systems as nodes
58+ for path_name , subgraph in tgraph ._subgraphs .items ():
59+ for src , edges in subgraph ._graph .items ():
60+ for tgt , transform in edges .items ():
61+ input_image = transform .input
62+ path = f"{ path_name } /{ input_image } "
63+ g .add_node (
64+ path ,
65+ coordinate_system = None ,
6466 )
65- except NotImplementedError :
66- pass
6767
68- for graph_name , subgraph in tgraph ._subgraphs .items ():
68+ for src , edges in tgraph ._graph .items ():
69+ for tgt , transform in edges .items ():
70+ _add_transform_and_inverse_transformation_edges (
71+ g ,
72+ src ,
73+ tgt ,
74+ transform ,
75+ )
76+
77+
78+ for path_name , subgraph in tgraph ._subgraphs .items ():
79+ for src , edges in subgraph ._graph .items ():
80+ for tgt , transform in edges .items ():
81+ # TODO: hack! Replace
82+ # TODO: we assume the the subgraph does not contain transformation from the parent graph, so we use the CoordinateSystemIdentifier
83+
84+ _add_transform_and_inverse_transformation_edges (
85+ g = g ,
86+ input_cs = _get_name_of_subgraph (src , path_name ),
87+ output_cs = _get_name_of_subgraph (tgt , path_name ),
88+ transform = transform ,
89+ )
6990 pass
7091 # _add_graph_to_networkx(subgraph, g)
7192 # if subgraph._default_system:
@@ -79,6 +100,40 @@ def transform_graph_to_networkx(tgraph: TransformGraph) -> nx.DiGraph:
79100
80101 return g
81102
103+ def _get_name_of_subgraph (cs_name : str , path_name : str ) -> str | CoordinateSystemIdentifier :
104+ if cs_name in ['0' , '1' , '2' , '3' , '4' , '5' ]:
105+ return f'{ path_name } /{ cs_name } '
106+ return CoordinateSystemIdentifier (
107+ name = cs_name ,
108+ path = path_name ,
109+ )
110+
111+ def _add_transform_and_inverse_transformation_edges (
112+ g : nx .DiGraph ,
113+ input_cs : str | CoordinateSystemIdentifier ,
114+ output_cs : str | CoordinateSystemIdentifier ,
115+ transform : Transform ,
116+ ):
117+ source_node = input_cs
118+ target_node = output_cs
119+ g .add_edge (
120+ source_node ,
121+ target_node ,
122+ transformation = transform ,
123+ edge_type = "transformation" ,
124+ )
125+ try :
126+ inverse = transform .get_inverse ()
127+ g .add_edge (
128+ target_node ,
129+ source_node ,
130+ transformation = inverse ,
131+ edge_type = "transformation" ,
132+ )
133+ except (NotImplementedError , ValidationError ):
134+ pass
135+
136+
82137
83138def draw_graph (g : nx .DiGraph , figsize : tuple [int , int ] = (12 , 8 ), with_edge_labels : bool = True ) -> None :
84139 """
@@ -96,7 +151,7 @@ def draw_graph(g: nx.DiGraph, figsize: tuple[int, int] = (12, 8), with_edge_labe
96151 plt .figure (figsize = figsize )
97152
98153 # Use spring layout for node positioning
99- pos = nx .spring_layout (g , k = 2 , iterations = 50 , seed = 42 )
154+ pos = nx .spring_layout (g )
100155
101156 # Draw nodes
102157 nx .draw_networkx_nodes (
0 commit comments