|
7 | 7 | from collections.abc import Iterator, Sequence |
8 | 8 | from dataclasses import dataclass |
9 | 9 | from enum import Enum |
10 | | -from functools import reduce |
| 10 | +from functools import lru_cache, reduce |
11 | 11 | from types import EllipsisType |
12 | 12 | from typing import ( |
13 | 13 | TYPE_CHECKING, |
@@ -1467,16 +1467,21 @@ def decode_morton(z: int, chunk_shape: tuple[int, ...]) -> tuple[int, ...]: |
1467 | 1467 | return tuple(out) |
1468 | 1468 |
|
1469 | 1469 |
|
1470 | | -def morton_order_iter(chunk_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]: |
1471 | | - i = 0 |
| 1470 | +@lru_cache |
| 1471 | +def _morton_order(chunk_shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]: |
| 1472 | + n_total = product(chunk_shape) |
1472 | 1473 | order: list[tuple[int, ...]] = [] |
1473 | | - while len(order) < product(chunk_shape): |
| 1474 | + i = 0 |
| 1475 | + while len(order) < n_total: |
1474 | 1476 | m = decode_morton(i, chunk_shape) |
1475 | | - if m not in order and all(x < y for x, y in zip(m, chunk_shape, strict=False)): |
| 1477 | + if all(x < y for x, y in zip(m, chunk_shape, strict=False)): |
1476 | 1478 | order.append(m) |
1477 | 1479 | i += 1 |
1478 | | - for j in range(product(chunk_shape)): |
1479 | | - yield order[j] |
| 1480 | + return tuple(order) |
| 1481 | + |
| 1482 | + |
| 1483 | +def morton_order_iter(chunk_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]: |
| 1484 | + return iter(_morton_order(tuple(chunk_shape))) |
1480 | 1485 |
|
1481 | 1486 |
|
1482 | 1487 | def c_order_iter(chunks_per_shard: tuple[int, ...]) -> Iterator[tuple[int, ...]]: |
|
0 commit comments