|
6 | 6 | from ome_zarr_models._v06.coordinate_transforms import ( |
7 | 7 | CoordinateSystemIdentifier, |
8 | 8 | Transform, |
| 9 | + Sequence, |
9 | 10 | ) |
10 | 11 | from pydantic import ValidationError |
| 12 | +import logging |
| 13 | +from ome_zarr_models._v06.coordinate_transforms import ( |
| 14 | + Sequence as SequenceTransformation, |
| 15 | +) |
11 | 16 |
|
12 | 17 |
|
13 | 18 | def transform_graph_to_networkx(tgraph: TransformGraph) -> nx.DiGraph: |
@@ -120,7 +125,9 @@ def _get_name_of_subgraph( |
120 | 125 | f"Ambiguous coordinate system name '{cs_identifier}' found in both root and subgraph '{path_name}'. Use full identifier." |
121 | 126 | ) |
122 | 127 | if cs_identifier not in nodes and cs_path_name not in nodes: |
123 | | - raise ValueError(f"Coordinate system '{cs_identifier}' not found in graph nodes.") |
| 128 | + raise ValueError( |
| 129 | + f"Coordinate system '{cs_identifier}' not found in graph nodes." |
| 130 | + ) |
124 | 131 | if cs_path_name in nodes: |
125 | 132 | return cs_path_name |
126 | 133 | return cs_identifier |
@@ -152,7 +159,9 @@ def _add_transform_and_inverse_transformation_edges( |
152 | 159 | pass |
153 | 160 |
|
154 | 161 |
|
155 | | -def draw_graph(g: nx.DiGraph, figsize: tuple[int, int] = (12, 8), with_edge_labels: bool = True) -> None: |
| 162 | +def draw_graph( |
| 163 | + g: nx.DiGraph, figsize: tuple[int, int] = (12, 8), with_edge_labels: bool = True |
| 164 | +) -> None: |
156 | 165 | """ |
157 | 166 | Draw a NetworkX graph showing all nodes and edges with their names. |
158 | 167 |
|
@@ -210,7 +219,9 @@ def draw_graph(g: nx.DiGraph, figsize: tuple[int, int] = (12, 8), with_edge_labe |
210 | 219 | plt.show() |
211 | 220 |
|
212 | 221 |
|
213 | | -def get_relative_path(graph: nx.DiGraph, source_coordinate_system: str, target_coordinate_system: str) -> list[str]: |
| 222 | +def get_relative_path( |
| 223 | + graph: nx.DiGraph, source_coordinate_system: str, target_coordinate_system: str |
| 224 | +) -> list[str]: |
214 | 225 | cost_key = "cost" |
215 | 226 | """ |
216 | 227 | Get the relative path from one node to another in the transformation graph. |
@@ -244,23 +255,51 @@ def get_relative_path(graph: nx.DiGraph, source_coordinate_system: str, target_c |
244 | 255 | return path |
245 | 256 |
|
246 | 257 |
|
247 | | -def create_sequence_transformation_from_path( |
| 258 | +def create_sequence_transformation_from_graph_walk( |
248 | 259 | graph: nx.DiGraph, |
249 | | - path: list[str], |
250 | | -) -> list[Any]: |
| 260 | + walk: list[str | CoordinateSystemIdentifier], |
| 261 | +) -> SequenceTransformation: |
251 | 262 | """ |
252 | | - Create a sequence of transformations from a path of coordinate systems |
| 263 | + Create a sequence of transformations from a walk of coordinate systems |
253 | 264 | in the transformation graph. |
254 | 265 | """ |
255 | 266 | transformations = [] |
256 | | - for i in range(len(path) - 1): |
257 | | - source = path[i] |
258 | | - target = path[i + 1] |
| 267 | + for i in range(len(walk) - 1): |
| 268 | + source = walk[i] |
| 269 | + target = walk[i + 1] |
259 | 270 | edge_transformation = graph.get_edge_data(source, target)["transformation"] |
260 | 271 | transformations.append(edge_transformation) |
261 | 272 |
|
262 | | - from ome_zarr_models._v06.coordinate_transforms import Sequence |
| 273 | + return Sequence( |
| 274 | + input=walk[0], output=walk[-1], transformations=tuple(transformations) |
| 275 | + ) |
263 | 276 |
|
264 | | - transformations = Sequence(transformations=transformations) |
265 | 277 |
|
266 | | - return transformations |
| 278 | +def get_node( |
| 279 | + path: str | None = None, name: str | None = None |
| 280 | +) -> str | CoordinateSystemIdentifier: |
| 281 | + if path is None and name is None: |
| 282 | + raise ValueError("Both path and name of the coordinate system cannot be None") |
| 283 | + if path is None: |
| 284 | + return name |
| 285 | + if name is None: |
| 286 | + return path |
| 287 | + return CoordinateSystemIdentifier(path=path, name=name) |
| 288 | + |
| 289 | + |
| 290 | +def find_walks_in_graph( |
| 291 | + graph, src_path, src_name, tgt_path, tgt_name |
| 292 | +) -> list[str | CoordinateSystemIdentifier]: |
| 293 | + src_node = get_node(src_path, src_name) |
| 294 | + tgt_node = get_node(tgt_path, tgt_name) |
| 295 | + |
| 296 | + graph_walk = list(nx.all_shortest_paths(graph, src_node, tgt_node)) |
| 297 | + if not graph_walk: |
| 298 | + raise ValueError( |
| 299 | + f"No path found from {src_node} to {tgt_node} in the transformation graph." |
| 300 | + ) |
| 301 | + if len(graph_walk) > 1: |
| 302 | + logging.warning( |
| 303 | + f"Multiple paths found from {src_node} to {tgt_node} in the transformation graph. Using the first one." |
| 304 | + ) |
| 305 | + return graph_walk[0] |
0 commit comments