Skip to content

Commit 26e4612

Browse files
shoyerXarray-Beam authors
authored andcommitted
Preserve existing dimension order in replace_template_dims
This fixes a bug where replace_template_dims() could inadvertently change the order of dimensions. PiperOrigin-RevId: 824059437
1 parent 7db84e4 commit 26e4612

File tree

3 files changed

+31
-22
lines changed

3 files changed

+31
-22
lines changed

xarray_beam/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,4 @@
5555
DatasetToZarr as DatasetToZarr,
5656
)
5757

58-
__version__ = '0.11.3' # automatically synchronized to pyproject.toml
58+
__version__ = '0.11.4' # automatically synchronized to pyproject.toml

xarray_beam/_src/zarr.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -151,12 +151,13 @@ def replace_template_dims(
151151
template: xarray.Dataset,
152152
**dim_replacements: int | np.ndarray | pd.Index | xarray.DataArray,
153153
) -> xarray.Dataset:
154+
# pyformat: disable
154155
"""Replaces dimension(s) in a template with updates coordinates and/or sizes.
155156
156157
This is convenient for creating templates from evaluated results for a
157158
single chunk.
158159
159-
Example usage:
160+
Example usage::
160161
161162
import numpy as np
162163
import pandas as pd
@@ -178,27 +179,21 @@ def replace_template_dims(
178179
# Dimensions: (time: 1, longitude: 1440, latitude: 721)
179180
# Coordinates:
180181
# * time (time) datetime64[ns] 8B 1940-01-01
181-
# * longitude (longitude) float64 12kB 0.0 0.25 0.5 0.75 ... 359.2 359.5
182-
359.8
183-
# * latitude (latitude) float64 6kB -90.0 -89.75 -89.5 ... 89.5 89.75
184-
90.0
182+
# * longitude (longitude) float64 12kB 0.0 0.25 0.5 0.75 ... 359.2 359.5 359.8
183+
# * latitude (latitude) float64 6kB -90.0 -89.75 -89.5 ... 89.5 89.75 90.0
185184
# Data variables:
186-
# foo (time, longitude, latitude) float64 8MB
187-
dask.array<chunksize=(1, 1440, 721), meta=np.ndarray>
185+
# foo (time, longitude, latitude) float64 8MB dask.array<chunksize=(1, 1440, 721), meta=np.ndarray>
188186
189187
template = xbeam.replace_template_dims(template, time=times)
190188
print(template)
191189
# <xarray.Dataset> Size: 6TB
192190
# Dimensions: (time: 747769, longitude: 1440, latitude: 721)
193191
# Coordinates:
194-
# * longitude (longitude) float64 12kB 0.0 0.25 0.5 0.75 ... 359.2 359.5
195-
359.8
196-
# * latitude (latitude) float64 6kB -90.0 -89.75 -89.5 ... 89.5 89.75
197-
90.0
192+
# * longitude (longitude) float64 12kB 0.0 0.25 0.5 0.75 ... 359.2 359.5 359.8
193+
# * latitude (latitude) float64 6kB -90.0 -89.75 -89.5 ... 89.5 89.75 90.0
198194
# * time (time) datetime64[ns] 6MB 1940-01-01 ... 2025-04-21
199195
# Data variables:
200-
# foo (time, longitude, latitude) float64 6TB
201-
dask.array<chunksize=(747769, 1440, 721), meta=np.ndarray>
196+
# foo (time, longitude, latitude) float64 6TB dask.array<chunksize=(747769, 1440, 721), meta=np.ndarray>
202197
203198
Args:
204199
template: The template to replace dimensions in.
@@ -209,22 +204,25 @@ def replace_template_dims(
209204
Returns:
210205
Template with the replaced dimensions.
211206
"""
212-
expansions = {}
207+
# pyformat: enable
208+
expansions_with_axes = {}
213209
for name, variable in template.items():
214210
if variable.chunks is None:
215211
raise ValueError(
216212
f'Data variable {name} is not chunked with Dask. Please call'
217213
' xarray_beam.make_template() to create a valid template before '
218214
f' calling replace_template_dims(): {template}'
219215
)
220-
expansions[name] = {
221-
dim: replacement
222-
for dim, replacement in dim_replacements.items()
223-
if dim in variable.dims
224-
}
216+
# identify which dimensions of this variable need to be replaced, in order
217+
dims_to_replace = [dim for dim in variable.dims if dim in dim_replacements]
218+
if dims_to_replace:
219+
expansions = {dim: dim_replacements[dim] for dim in dims_to_replace}
220+
axes = [variable.dims.index(dim) for dim in dims_to_replace]
221+
expansions_with_axes[name] = (expansions, axes)
222+
225223
template = template.isel({dim: 0 for dim in dim_replacements}, drop=True)
226-
for name, variable in template.items():
227-
template[name] = variable.expand_dims(expansions[name])
224+
for name, (expansions, axes) in expansions_with_axes.items():
225+
template[name] = template[name].expand_dims(expansions, axis=axes)
228226
return template
229227

230228

xarray_beam/_src/zarr_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,17 @@ def test_replace_template_dims_multiple_vars(self):
182182
self.assertIsInstance(new_template.bar.data, da.Array)
183183
self.assertIsInstance(new_template.baz.data, da.Array)
184184

185+
def test_replace_template_dims_multiple_dims_unordered(self):
186+
source = xarray.Dataset(
187+
{'foo': (('x', 'y', 'z'), np.zeros((1, 2, 3)))},
188+
coords={'x': [0], 'y': [10, 20], 'z': [1, 2, 3]},
189+
)
190+
template = xbeam.make_template(source)
191+
new_template = xbeam.replace_template_dims(template, z=4, x=5)
192+
193+
self.assertEqual(new_template.sizes, {'x': 5, 'y': 2, 'z': 4})
194+
self.assertEqual(new_template.foo.dims, ('x', 'y', 'z'))
195+
185196
def test_replace_template_dims_error_on_non_template(self):
186197
source = xarray.Dataset({'foo': ('x', np.zeros(1))}) # Not a template
187198
with self.assertRaisesRegex(ValueError, 'is not chunked with Dask'):

0 commit comments

Comments
 (0)