Skip to content

Commit 719e54a

Browse files
committed
wip code cleanup
1 parent 35e6b5a commit 719e54a

File tree

5 files changed

+63
-50
lines changed

5 files changed

+63
-50
lines changed

notebooks/transformations_demo.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
"source": [
1010
"import zarr\n",
1111
"import numpy as np\n",
12+
"\n",
13+
"import ngff_transformations.graph\n",
1214
"import ngff_transformations.transform as ngt\n",
1315
"\n",
1416
"from pathlib import Path\n",
@@ -95,9 +97,7 @@
9597
"id": "72648a07-c8d1-4078-80ef-2b7fb4758451",
9698
"metadata": {},
9799
"outputs": [],
98-
"source": [
99-
"transformation_path, (src_coord_system, tgt_coord_system) = ngt.find_walks_in_graph(nx_graph, 'VOI-01.ome.zarr/1', None, 'overview.ome.zarr', 'anatomical')"
100-
]
100+
"source": "transformation_path, (src_coord_system, tgt_coord_system) = ngff_transformations.graph.find_walks_in_graph(nx_graph, 'VOI-01.ome.zarr/1', None, 'overview.ome.zarr', 'anatomical')"
101101
},
102102
{
103103
"cell_type": "code",

src/ngff_transformations/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
__version__ = version("ngff-transformations")
44

5-
from .graph import (
6-
create_sequence_transformation_from_path,
5+
from ngff_transformations.graph import (
6+
find_walks_in_graph,
77
draw_graph,
88
get_relative_path,
99
transform_graph_to_networkx,

src/ngff_transformations/graph.py

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,13 @@
66
from ome_zarr_models._v06.coordinate_transforms import (
77
CoordinateSystemIdentifier,
88
Transform,
9+
Sequence,
910
)
1011
from pydantic import ValidationError
12+
import logging
13+
from ome_zarr_models._v06.coordinate_transforms import (
14+
Sequence as SequenceTransformation,
15+
)
1116

1217

1318
def transform_graph_to_networkx(tgraph: TransformGraph) -> nx.DiGraph:
@@ -120,7 +125,9 @@ def _get_name_of_subgraph(
120125
f"Ambiguous coordinate system name '{cs_identifier}' found in both root and subgraph '{path_name}'. Use full identifier."
121126
)
122127
if cs_identifier not in nodes and cs_path_name not in nodes:
123-
raise ValueError(f"Coordinate system '{cs_identifier}' not found in graph nodes.")
128+
raise ValueError(
129+
f"Coordinate system '{cs_identifier}' not found in graph nodes."
130+
)
124131
if cs_path_name in nodes:
125132
return cs_path_name
126133
return cs_identifier
@@ -152,7 +159,9 @@ def _add_transform_and_inverse_transformation_edges(
152159
pass
153160

154161

155-
def draw_graph(g: nx.DiGraph, figsize: tuple[int, int] = (12, 8), with_edge_labels: bool = True) -> None:
162+
def draw_graph(
163+
g: nx.DiGraph, figsize: tuple[int, int] = (12, 8), with_edge_labels: bool = True
164+
) -> None:
156165
"""
157166
Draw a NetworkX graph showing all nodes and edges with their names.
158167
@@ -210,7 +219,9 @@ def draw_graph(g: nx.DiGraph, figsize: tuple[int, int] = (12, 8), with_edge_labe
210219
plt.show()
211220

212221

213-
def get_relative_path(graph: nx.DiGraph, source_coordinate_system: str, target_coordinate_system: str) -> list[str]:
222+
def get_relative_path(
223+
graph: nx.DiGraph, source_coordinate_system: str, target_coordinate_system: str
224+
) -> list[str]:
214225
cost_key = "cost"
215226
"""
216227
Get the relative path from one node to another in the transformation graph.
@@ -244,23 +255,51 @@ def get_relative_path(graph: nx.DiGraph, source_coordinate_system: str, target_c
244255
return path
245256

246257

247-
def create_sequence_transformation_from_path(
258+
def create_sequence_transformation_from_graph_walk(
248259
graph: nx.DiGraph,
249-
path: list[str],
250-
) -> list[Any]:
260+
walk: list[str | CoordinateSystemIdentifier],
261+
) -> SequenceTransformation:
251262
"""
252-
Create a sequence of transformations from a path of coordinate systems
263+
Create a sequence of transformations from a walk of coordinate systems
253264
in the transformation graph.
254265
"""
255266
transformations = []
256-
for i in range(len(path) - 1):
257-
source = path[i]
258-
target = path[i + 1]
267+
for i in range(len(walk) - 1):
268+
source = walk[i]
269+
target = walk[i + 1]
259270
edge_transformation = graph.get_edge_data(source, target)["transformation"]
260271
transformations.append(edge_transformation)
261272

262-
from ome_zarr_models._v06.coordinate_transforms import Sequence
273+
return Sequence(
274+
input=walk[0], output=walk[-1], transformations=tuple(transformations)
275+
)
263276

264-
transformations = Sequence(transformations=transformations)
265277

266-
return transformations
278+
def get_node(
279+
path: str | None = None, name: str | None = None
280+
) -> str | CoordinateSystemIdentifier:
281+
if path is None and name is None:
282+
raise ValueError("Both path and name of the coordinate system cannot be None")
283+
if path is None:
284+
return name
285+
if name is None:
286+
return path
287+
return CoordinateSystemIdentifier(path=path, name=name)
288+
289+
290+
def find_walks_in_graph(
291+
graph, src_path, src_name, tgt_path, tgt_name
292+
) -> list[str | CoordinateSystemIdentifier]:
293+
src_node = get_node(src_path, src_name)
294+
tgt_node = get_node(tgt_path, tgt_name)
295+
296+
graph_walk = list(nx.all_shortest_paths(graph, src_node, tgt_node))
297+
if not graph_walk:
298+
raise ValueError(
299+
f"No path found from {src_node} to {tgt_node} in the transformation graph."
300+
)
301+
if len(graph_walk) > 1:
302+
logging.warning(
303+
f"Multiple paths found from {src_node} to {tgt_node} in the transformation graph. Using the first one."
304+
)
305+
return graph_walk[0]

src/ngff_transformations/transform.py

Lines changed: 3 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
import networkx as nx
21
import numpy as np
3-
from ome_zarr_models._v06.coordinate_transforms import CoordinateSystemIdentifier, Sequence
2+
from ome_zarr_models._v06.coordinate_transforms import Sequence
43
from xarray import DataArray
54

65

@@ -9,32 +8,6 @@ def validata_point_shape(point: np.ndarray, transformation_sequence: Sequence):
98
assert len(point) == transformation.ndim, "Point ndim doesn't match transformation ndim"
109

1110

12-
def get_node(path: str | None = None, name: str | None = None) -> str | CoordinateSystemIdentifier:
13-
if path is None and name is None:
14-
raise ValueError("Both path and name of the coordinate system cannot be None")
15-
if path is None:
16-
return name
17-
if name is None:
18-
return path
19-
return CoordinateSystemIdentifier(path=path, name=name)
20-
21-
22-
def find_walks_in_graph(graph, src_path, src_name, tgt_path, tgt_name):
23-
src_node = get_node(src_path, src_name)
24-
tgt_node = get_node(tgt_path, tgt_name)
25-
26-
graph_walk = list(nx.all_shortest_paths(graph, src_node, tgt_node))[0]
27-
28-
transformation_sequence = []
29-
for i in range(len(graph_walk) - 1):
30-
transformation_sequence.append(graph.get_edge_data(graph_walk[i], graph_walk[i + 1])["transformation"])
31-
32-
transformation_sequence = Sequence(
33-
input=graph_walk[0], output=graph_walk[-1], transformations=transformation_sequence
34-
)
35-
return transformation_sequence, (graph_walk[0], graph_walk[-1])
36-
37-
3811
def transform_with_sequence3D(
3912
data: np.ndarray, axes: list[str], transformation_sequence: Sequence, output_axes: list[str]
4013
) -> DataArray:
@@ -58,7 +31,7 @@ def transform_with_sequence3D(
5831
y_prime = transformed_points[:, 1].reshape(Y, X, Z)
5932
z_prime = transformed_points[:, 2].reshape(Y, X, Z)
6033

61-
return xarray.DataArray(
34+
return DataArray(
6235
data,
6336
coords={
6437
"x_prime": (("y", "x", "z"), x_prime),
@@ -91,7 +64,7 @@ def transform_with_sequence(
9164
x_prime = transformed_points[:, 0].reshape(H, W)
9265
y_prime = transformed_points[:, 1].reshape(H, W)
9366

94-
return xarray.DataArray(
67+
return DataArray(
9568
data,
9669
coords={
9770
"x_prime": (("y", "x"), x_prime),

tests/test_graph.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ome_zarr_models._v06.image import Image
99

1010
from ngff_transformations.graph import (
11-
create_sequence_transformation_from_path,
11+
find_walks_in_graph,
1212
get_relative_path,
1313
transform_graph_to_networkx,
1414
)
@@ -67,6 +67,7 @@ def test_graph(zarr_path: Path):
6767

6868
example_edge = list(nx_graph.edges)[0]
6969
path = get_relative_path(nx_graph, example_edge[0], example_edge[1])
70-
sequence_transformation = create_sequence_transformation_from_path(nx_graph, path)
70+
sequence_transformation = find_walks_in_graph(
71+
graph=nx_graph, path)
7172

7273
assert isinstance(sequence_transformation, Sequence)

0 commit comments

Comments
 (0)