Skip to content

Commit 1e5ddef

Browse files
shoyer and Xarray-Beam authors
authored and committed
Refactor xbeam.Dataset.to_zarr()
Moved the zarr_chunks_per_shard logic to a helper function. PiperOrigin-RevId: 827309513
1 parent 1ad6808 commit 1e5ddef

File tree

1 file changed

+38
-26
lines changed

1 file changed

+38
-26
lines changed

xarray_beam/_src/dataset.py

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,7 @@ def method(self: Dataset, *args, **kwargs) -> Dataset:
392392
return method
393393

394394

395+
395396
class _CountNamer:
396397

397398
def __init__(self):
@@ -676,6 +677,39 @@ def from_zarr(
676677
result._ptransform = _LazyPCollection(pipeline, result.ptransform)
677678
return result
678679

680+
def _zarr_chunks_per_shard_to_chunks(
    self,
    zarr_chunks_per_shard: Mapping[str, int],
    zarr_shards: Mapping[str, int],
) -> Mapping[str, int]:
  """Convert a chunks-per-shard specification into Zarr chunk sizes.

  Args:
    zarr_chunks_per_shard: mapping from dimension name to the number of
      chunks each shard is split into along that dimension. The special key
      ``...`` (Ellipsis) sets the value used for dimensions that are not
      listed explicitly; without it, unlisted dimensions use 1 chunk per
      shard.
    zarr_shards: mapping from dimension name to shard size. One chunk size
      is computed for every dimension in this mapping.

  Returns:
    A new dict mapping each dimension in ``zarr_shards`` to its chunk size
    (shard size divided by chunks per shard).

  Raises:
    ValueError: if ``zarr_chunks_per_shard`` names a key (other than
      ``...``) that is not a dimension of ``self.template``, if the number
      of chunks per shard used for a dimension is less than 1, or if it
      does not evenly divide the shard size.
  """
  # Copy so that popping the Ellipsis default does not mutate the caller's
  # mapping.
  chunks_per_shard = dict(zarr_chunks_per_shard)
  default_cps = chunks_per_shard.pop(..., 1)

  extra_keys = set(chunks_per_shard) - set(self.template.dims)
  if extra_keys:
    raise ValueError(
        f'{zarr_chunks_per_shard=} includes keys that are not dimensions '
        f'in template: {extra_keys}'
    )

  zarr_chunks = {}
  for dim, shard_size in zarr_shards.items():
    cps = chunks_per_shard.get(dim, default_cps)
    # Reject zero/negative counts explicitly: zero would otherwise surface
    # as a bare ZeroDivisionError from divmod(), and a negative count would
    # silently yield a negative chunk size.
    if cps < 1:
      raise ValueError(
          f'{zarr_chunks_per_shard=} must specify at least 1 chunk per '
          f'shard, but got {cps!r} for dimension {dim!r}'
      )
    chunk_size, remainder = divmod(shard_size, cps)
    if remainder != 0:
      raise ValueError(
          f'cannot write a dataset with chunks {self.chunks} to Zarr with '
          f'{zarr_chunks_per_shard=}, which do not evenly divide into '
          f'chunks. Computed chunk size for dimension {dim!r} is '
          f'{chunk_size}, based on {cps} chunks per shard.'
      )
    zarr_chunks[dim] = chunk_size
  return zarr_chunks
712+
679713
def _check_shards_or_chunks(
680714
self,
681715
zarr_chunks: Mapping[str, int],
@@ -754,33 +788,11 @@ def to_zarr(
754788
)
755789
if zarr_shards is None:
756790
zarr_shards = self.chunks
791+
zarr_chunks = self._zarr_chunks_per_shard_to_chunks(
792+
zarr_chunks_per_shard, zarr_shards
793+
)
757794

758-
chunks_per_shard = dict(zarr_chunks_per_shard)
759-
if ... in chunks_per_shard:
760-
default_cps = chunks_per_shard.pop(...)
761-
else:
762-
default_cps = 1
763-
764-
extra_keys = set(chunks_per_shard) - set(self.template.dims)
765-
if extra_keys:
766-
raise ValueError(
767-
f'{zarr_chunks_per_shard=} includes keys that are not dimensions '
768-
f' in template: {extra_keys}'
769-
)
770-
771-
zarr_chunks = {}
772-
for dim, shard_size in zarr_shards.items():
773-
cps = chunks_per_shard.get(dim, default_cps)
774-
chunk_size, remainder = divmod(shard_size, cps)
775-
if remainder != 0:
776-
raise ValueError(
777-
f'cannot write a dataset with chunks {self.chunks} to Zarr with '
778-
f'{zarr_chunks_per_shard=}, which do not evenly divide into '
779-
f'chunks. Computed chunk size for dimension {dim!r} is '
780-
f'{chunk_size}, based on {cps} chunks per shard.'
781-
)
782-
zarr_chunks[dim] = chunk_size
783-
elif zarr_chunks is None:
795+
if zarr_chunks is None:
784796
if zarr_shards is not None:
785797
raise ValueError('cannot supply zarr_shards without zarr_chunks')
786798
zarr_chunks = {}

0 commit comments

Comments
 (0)