@@ -392,7 +392,6 @@ def method(self: Dataset, *args, **kwargs) -> Dataset:
392392 return method
393393
394394
395-
396395class _CountNamer :
397396
398397 def __init__ (self ):
@@ -408,6 +407,7 @@ def apply(self, name: str) -> str:
408407@dataclasses .dataclass (frozen = True )
409408class _LazyPCollection :
410409 """Pipeline and PTransform not yet combined into a PCollection."""
410+
411411 # Beam does not provide a public API for manipulating Pipeline objects, so
412412 # instead of applying pipelines eagerly, we store them in this wrapper. This
413413 # allows for performance optimizations specialized to Xarray-Beam PTransforms,
@@ -715,12 +715,16 @@ def _check_shards_or_chunks(
715715 zarr_chunks : Mapping [str , int ],
716716 chunks_name : Literal ['shards' , 'chunks' ],
717717 ) -> None :
718- if any (self .chunks [k ] % zarr_chunks [k ] for k in self .chunks ):
719- raise ValueError (
720- f'cannot write a dataset with chunks { self .chunks } to Zarr with '
721- f'{ chunks_name } { zarr_chunks } , which do not divide evenly into '
722- f'{ chunks_name } '
723- )
718+ for k in self .chunks :
719+ if (
720+ self .chunks [k ] % zarr_chunks [k ]
721+ and self .chunks [k ] != self .template .sizes [k ]
722+ ):
723+ raise ValueError (
724+ f'cannot write a dataset with chunks { self .chunks } to Zarr with '
725+ f'{ chunks_name } { zarr_chunks } , which do not divide evenly into '
726+ f'{ chunks_name } '
727+ )
724728
725729 def to_zarr (
726730 self ,
@@ -804,6 +808,16 @@ def to_zarr(
804808 previous_chunks = self .chunks ,
805809 )
806810 if zarr_shards is not None :
811+ # Zarr shards must currently be an integer multiple of chunk sizes, so a
812+ # shard spanning a full dimension is rounded up to the next multiple of
813+ # the chunk size. This will likely be relaxed in the future:
814+ # https://github.com/zarr-developers/zarr-extensions/issues/34
815+ zarr_shards = dict (zarr_shards )
816+ for k in zarr_shards :
817+ if zarr_shards [k ] == self .sizes [k ]:
818+ zarr_shards [k ] = (
819+ math .ceil (zarr_shards [k ] / zarr_chunks [k ]) * zarr_chunks [k ]
820+ )
807821 self ._check_shards_or_chunks (zarr_shards , 'shards' )
808822 else :
809823 self ._check_shards_or_chunks (zarr_chunks , 'chunks' )
@@ -956,9 +970,7 @@ def rechunk(
956970 ):
957971 # Rechunking can be performed by re-reading the source dataset with new
958972 # chunks, rather than using a separate rechunking transform.
959- ptransform = core .DatasetToChunks (
960- ptransform .dataset , chunks , split_vars
961- )
973+ ptransform = core .DatasetToChunks (ptransform .dataset , chunks , split_vars )
962974 ptransform .label = _concat_labels (ptransform .label , label )
963975 if pipeline is not None :
964976 ptransform = _LazyPCollection (pipeline , ptransform )
0 commit comments