Skip to content

Commit dea2462

Browse files
committed
fixed type checking for latest scipy
1 parent 44d4397 commit dea2462

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

src/pyift/shortestpath.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,12 @@
44
from typing import Optional, Tuple, Dict, Union
55

66

7-
def seed_competition(seeds: np.ndarray, image: Optional[np.ndarray] = None, graph: Optional[sparse.csr_matrix] = None,
8-
image_3d: bool = False) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
7+
def seed_competition(
8+
seeds: np.ndarray,
9+
image: Optional[np.ndarray] = None,
10+
graph: Optional[Union[sparse.csr_matrix, sparse.csr_array]] = None,
11+
image_3d: bool = False,
12+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
913
"""
1014
Performs the shortest path classification from the `seeds` nodes
1115
using the image foresting transform algorithm [1]_.
@@ -102,8 +106,8 @@ def seed_competition(seeds: np.ndarray, image: Optional[np.ndarray] = None, grap
102106
return _pyift.seed_competition_grid(image, seeds)
103107

104108
# graph is provided
105-
if not isinstance(graph, sparse.csr_matrix):
106-
raise TypeError('`graph` must be a `csr_matrix`.')
109+
if not isinstance(graph, (sparse.csr_matrix, sparse.csr_array)):
110+
raise TypeError('`graph` must be a `csr_matrix` or `csr_array`.')
107111

108112
if graph.shape[0] != graph.shape[1]:
109113
raise ValueError('`graph` must be a square adjacency matrix, current shape %r.' % graph.shape)

0 commit comments

Comments
 (0)