Skip to content

Commit de2a762

Browse files
shoyer and Xarray-Beam authors
authored and committed
Add check to validate_chunks() that input datasets are not Dask chunked.
This is generally a bad idea for chunks used in xarray-beam pipelines, because dask arrays may have arbitrarily large associated task graphs, which may be very expensive to serialize.

PiperOrigin-RevId: 753317208
1 parent c52f5ca commit de2a762

File tree

2 files changed

+60
-36
lines changed

2 files changed

+60
-36
lines changed

xarray_beam/_src/core.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
import immutabledict
3333
import numpy as np
3434
import xarray
35-
3635
from xarray_beam._src import threadmap
3736

3837
_DEFAULT = object()
@@ -473,10 +472,25 @@ def expand(self, pcoll):
473472

474473

475474
def validate_chunk(key: Key, datasets: DatasetOrDatasets) -> None:
476-
"""Verify that keys correspond to Dataset properties."""
475+
"""Verify that a key and dataset(s) are valid for xarray-beam transforms."""
477476
if isinstance(datasets, xarray.Dataset):
478477
datasets: list[xarray.Dataset] = [datasets]
478+
479479
for dataset in datasets:
480+
# Verify that no variables are chunked with Dask
481+
for var_name, variable in dataset.variables.items():
482+
if variable.chunks is not None:
483+
raise ValueError(
484+
f"Dataset variable {var_name!r} corresponding to key {key} is"
485+
" chunked with Dask. Datasets passed to validate_chunk must be"
486+
f" fully computed (not chunked): {dataset}\nThis typically arises"
487+
" with datasets originating with `xarray.open_zarr()`, which by"
488+
" default use Dask. If this is the case, you can fix it by passing"
489+
" `chunks=None` or xarray_beam.open_zarr(). Alternatively, you"
490+
" can load datasets explicitly into memory with `.compute()`."
491+
)
492+
493+
# Validate key offsets
480494
missing_keys = [
481495
repr(k) for k in key.offsets.keys() if k not in dataset.dims
482496
]
@@ -486,22 +500,20 @@ def validate_chunk(key: Key, datasets: DatasetOrDatasets) -> None:
486500
f" Dataset dimensions: {dataset!r}"
487501
)
488502

489-
if key.vars is None:
490-
continue
491-
492-
missing_vars = [repr(v) for v in key.vars if v not in dataset.data_vars]
493-
if missing_vars:
494-
raise ValueError(
495-
f"Key var(s) {', '.join(missing_vars)} in {key} not found in Dataset"
496-
f" data variables: {dataset!r}"
497-
)
503+
# Validate key vars
504+
if key.vars is not None:
505+
missing_vars = [repr(v) for v in key.vars if v not in dataset.data_vars]
506+
if missing_vars:
507+
raise ValueError(
508+
f"Key var(s) {', '.join(missing_vars)} in {key} not found in"
509+
f" Dataset data variables: {dataset!r}"
510+
)
498511

499512

500513
class ValidateEachChunk(beam.PTransform):
501-
"""Check that keys match the dataset for each key, dataset tuple."""
514+
"""Check that keys and dataset(s) are valid for xarray-beam transforms."""
502515

503516
def _validate(self, key, dataset):
504-
# Other checks may come later...
505517
validate_chunk(key, dataset)
506518
return key, dataset
507519

xarray_beam/_src/core_test.py

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
"""Tests for xarray_beam._src.core."""
1515

16+
import re
1617
from absl.testing import absltest
1718
from absl.testing import parameterized
1819
import apache_beam as beam
@@ -509,48 +510,59 @@ def test_validate(self):
509510

510511
class ValidateEachChunkTest(test_util.TestCase):
511512

513+
def test_validate_chunk_raises_on_dask_chunked(self):
514+
dataset = xarray.Dataset({'foo': ('x', np.arange(6))}).chunk()
515+
key = xbeam.Key({'x': 0})
516+
517+
with self.assertRaisesRegex(
518+
ValueError,
519+
re.escape(
520+
"Dataset variable 'foo' corresponding to key Key(offsets={'x': 0},"
521+
' vars=None) is chunked with Dask. Datasets passed to'
522+
' validate_chunk must be fully computed (not chunked):'
523+
),
524+
):
525+
core.validate_chunk(key, dataset)
526+
512527
def test_unmatched_dimension_raises_error(self):
513528
dataset = xarray.Dataset({'foo': ('x', np.arange(6))})
514-
with self.assertRaises(ValueError) as e:
515-
([(xbeam.Key({'x': 0, 'y': 0}), dataset)] | xbeam.ValidateEachChunk())
516-
self.assertIn(
517-
(
529+
key = xbeam.Key({'x': 0, 'y': 0})
530+
with self.assertRaisesRegex(
531+
ValueError,
532+
re.escape(
518533
"Key offset(s) 'y' in Key(offsets={'x': 0, 'y': 0}, vars=None) not "
519534
'found in Dataset dimensions'
520535
),
521-
e.exception.args[0],
522-
)
536+
):
537+
core.validate_chunk(key, dataset)
523538

524-
def test_unmatched_variables_raises_error(self):
539+
def test_unmatched_variables_raises_error_core(self):
525540
dataset = xarray.Dataset({'foo': ('x', np.arange(6))})
526-
with self.assertRaises(ValueError) as e:
527-
([(xbeam.Key({'x': 0}, {'bar'}), dataset)] | xbeam.ValidateEachChunk())
528-
self.assertIn(
529-
(
541+
key = xbeam.Key({'x': 0}, {'bar'})
542+
with self.assertRaisesRegex(
543+
ValueError,
544+
re.escape(
530545
"Key var(s) 'bar' in Key(offsets={'x': 0}, vars={'bar'}) not found"
531546
' in Dataset data variables'
532547
),
533-
e.exception.args[0],
534-
)
548+
):
549+
core.validate_chunk(key, dataset)
535550

536-
def test_unmatched_variables_multiple_datasets_raises_error(self):
551+
def test_unmatched_variables_multiple_datasets_raises_error_core(self):
537552
datasets = [
538553
xarray.Dataset({'foo': ('x', i + np.arange(6))}) for i in range(11)
539554
]
540555
datasets[5] = datasets[5].rename({'foo': 'bar'})
556+
key = xbeam.Key({'x': 0}, vars={'foo'})
541557

542-
with self.assertRaisesWithLiteralMatch(
558+
with self.assertRaisesRegex(
543559
ValueError,
544-
(
560+
re.escape(
545561
"Key var(s) 'foo' in Key(offsets={'x': 0}, vars={'foo'}) "
546-
f'not found in Dataset data variables: {datasets[5]} '
547-
"[while running 'ValidateEachChunk/MapTuple(_validate)']"
562+
f'not found in Dataset data variables: {datasets[5]}'
548563
),
549-
) as e:
550-
(
551-
[(xbeam.Key({'x': 0}, vars={'foo'}), datasets)]
552-
| xbeam.ValidateEachChunk()
553-
)
564+
):
565+
core.validate_chunk(key, datasets)
554566

555567
def test_validate_chunks_compose_in_pipeline(self):
556568
dataset = xarray.Dataset({'foo': ('x', np.arange(6))})

0 commit comments

Comments (0)