Skip to content

Commit 7ea8f69

Browse files
shoyer and Xarray-Beam authors
authored and committed
Round-up zarr shards when they are equal to the full dimension size
Otherwise, it is extremely hard to use shards for dimensions with irregular sizes (e.g., "time" in a reanalysis dataset). With this change, a dataset with `sizes={'x': 19}` and `chunks={'x': 10}` can be sharded with `shards={'x': 20}`. PiperOrigin-RevId: 827316686
1 parent 1e5ddef commit 7ea8f69

File tree

2 files changed

+44
-10
lines changed

2 files changed

+44
-10
lines changed

xarray_beam/_src/dataset.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,6 @@ def method(self: Dataset, *args, **kwargs) -> Dataset:
392392
return method
393393

394394

395-
396395
class _CountNamer:
397396

398397
def __init__(self):
@@ -408,6 +407,7 @@ def apply(self, name: str) -> str:
408407
@dataclasses.dataclass(frozen=True)
409408
class _LazyPCollection:
410409
"""Pipeline and PTransform not yet been combined into a PCollection."""
410+
411411
# Beam does not provide a public API for manipulating Pipeline objects, so
412412
# instead of applying pipelines eagerly, we store them in this wrapper. This
413413
# allows for performance optimizations specialized to Xarray-Beam PTransforms,
@@ -715,12 +715,16 @@ def _check_shards_or_chunks(
715715
zarr_chunks: Mapping[str, int],
716716
chunks_name: Literal['shards', 'chunks'],
717717
) -> None:
718-
if any(self.chunks[k] % zarr_chunks[k] for k in self.chunks):
719-
raise ValueError(
720-
f'cannot write a dataset with chunks {self.chunks} to Zarr with '
721-
f'{chunks_name} {zarr_chunks}, which do not divide evenly into '
722-
f'{chunks_name}'
723-
)
718+
for k in self.chunks:
719+
if (
720+
self.chunks[k] % zarr_chunks[k]
721+
and self.chunks[k] != self.template.sizes[k]
722+
):
723+
raise ValueError(
724+
f'cannot write a dataset with chunks {self.chunks} to Zarr with '
725+
f'{chunks_name} {zarr_chunks}, which do not divide evenly into '
726+
f'{chunks_name}'
727+
)
724728

725729
def to_zarr(
726730
self,
@@ -804,6 +808,16 @@ def to_zarr(
804808
previous_chunks=self.chunks,
805809
)
806810
if zarr_shards is not None:
811+
# Zarr shards are currently constrained to be an integer multiple of
812+
# chunk sizes, which means shard sizes must be rounded up to be larger
813+
# than the full dimension size. This will likely be relaxed in the future:
814+
# https://github.com/zarr-developers/zarr-extensions/issues/34
815+
zarr_shards = dict(zarr_shards)
816+
for k in zarr_shards:
817+
if zarr_shards[k] == self.sizes[k]:
818+
zarr_shards[k] = (
819+
math.ceil(zarr_shards[k] / zarr_chunks[k]) * zarr_chunks[k]
820+
)
807821
self._check_shards_or_chunks(zarr_shards, 'shards')
808822
else:
809823
self._check_shards_or_chunks(zarr_chunks, 'chunks')
@@ -956,9 +970,7 @@ def rechunk(
956970
):
957971
# Rechunking can be performed by re-reading the source dataset with new
958972
# chunks, rather than using a separate rechunking transform.
959-
ptransform = core.DatasetToChunks(
960-
ptransform.dataset, chunks, split_vars
961-
)
973+
ptransform = core.DatasetToChunks(ptransform.dataset, chunks, split_vars)
962974
ptransform.label = _concat_labels(ptransform.label, label)
963975
if pipeline is not None:
964976
ptransform = _LazyPCollection(pipeline, ptransform)

xarray_beam/_src/dataset_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -715,6 +715,28 @@ def test_to_zarr_shards(self):
715715
zarr_shards={'x': 9},
716716
)
717717

718+
@parameterized.named_parameters(
719+
dict(testcase_name='empty', zarr_shards={}),
720+
dict(testcase_name='minus_one', zarr_shards=-1),
721+
dict(testcase_name='explicit_19', zarr_shards={'x': 19}),
722+
dict(testcase_name='explicit_20', zarr_shards={'x': 20}),
723+
)
724+
def test_to_zarr_shards_round_up(self, zarr_shards):
725+
temp_dir = self.create_tempdir().full_path
726+
ds = xarray.Dataset({'foo': ('x', np.arange(19, dtype='int64'))})
727+
beam_ds = xbeam.Dataset.from_xarray(ds, {'x': 19})
728+
729+
with beam.Pipeline() as p:
730+
p |= beam_ds.to_zarr(
731+
temp_dir,
732+
zarr_chunks={'x': 10},
733+
zarr_shards=zarr_shards,
734+
)
735+
opened, chunks = xbeam.open_zarr(temp_dir)
736+
xarray.testing.assert_identical(ds, opened)
737+
self.assertEqual(chunks, {'x': 10})
738+
self.assertEqual(opened['foo'].encoding['shards'], (20,))
739+
718740
def test_to_zarr_chunks_per_shard(self):
719741
temp_dir = self.create_tempdir().full_path
720742
ds = xarray.Dataset({'foo': ('x', np.arange(12))})

0 commit comments

Comments (0)