|
4 | 4 | from typing import Optional, Tuple, Dict, Union |
5 | 5 |
|
6 | 6 |
|
7 | | -def seed_competition(seeds: np.ndarray, image: Optional[np.ndarray] = None, graph: Optional[sparse.csr_matrix] = None, |
8 | | - image_3d: bool = False) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: |
| 7 | +def seed_competition( |
| 8 | + seeds: np.ndarray, |
| 9 | + image: Optional[np.ndarray] = None, |
| 10 | + graph: Optional[Union[sparse.csr_matrix, sparse.csr_array]] = None, |
| 11 | + image_3d: bool = False, |
| 12 | +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: |
9 | 13 | """ |
10 | 14 | Performs the shortest path classification from the `seeds` nodes |
11 | 15 | using the image foresting transform algorithm [1]_. |
@@ -102,8 +106,8 @@ def seed_competition(seeds: np.ndarray, image: Optional[np.ndarray] = None, grap |
102 | 106 | return _pyift.seed_competition_grid(image, seeds) |
103 | 107 |
|
104 | 108 | # graph is provided |
105 | | - if not isinstance(graph, sparse.csr_matrix): |
106 | | - raise TypeError('`graph` must be a `csr_matrix`.') |
| 109 | + if not isinstance(graph, (sparse.csr_matrix, sparse.csr_array)): |
| 110 | + raise TypeError('`graph` must be a `csr_matrix` or `csr_array`.') |
107 | 111 |
|
108 | 112 | if graph.shape[0] != graph.shape[1]: |
109 | 113 | raise ValueError('`graph` must be a square adjacency matrix, current shape %r.' % graph.shape) |
|
0 commit comments