@@ -151,12 +151,13 @@ def replace_template_dims(
151151 template : xarray .Dataset ,
152152 ** dim_replacements : int | np .ndarray | pd .Index | xarray .DataArray ,
153153) -> xarray .Dataset :
154+ # pyformat: disable
154155 """Replaces dimension(s) in a template with updates coordinates and/or sizes.
155156
156157 This is convenient for creating templates from evaluated results for a
157158 single chunk.
158159
159- Example usage:
160+ Example usage::
160161
161162 import numpy as np
162163 import pandas as pd
@@ -178,27 +179,21 @@ def replace_template_dims(
178179 # Dimensions: (time: 1, longitude: 1440, latitude: 721)
179180 # Coordinates:
180181 # * time (time) datetime64[ns] 8B 1940-01-01
181- # * longitude (longitude) float64 12kB 0.0 0.25 0.5 0.75 ... 359.2 359.5
182- 359.8
183- # * latitude (latitude) float64 6kB -90.0 -89.75 -89.5 ... 89.5 89.75
184- 90.0
182+ # * longitude (longitude) float64 12kB 0.0 0.25 0.5 0.75 ... 359.2 359.5 359.8
183+ # * latitude (latitude) float64 6kB -90.0 -89.75 -89.5 ... 89.5 89.75 90.0
185184 # Data variables:
186- # foo (time, longitude, latitude) float64 8MB
187- dask.array<chunksize=(1, 1440, 721), meta=np.ndarray>
185+ # foo (time, longitude, latitude) float64 8MB dask.array<chunksize=(1, 1440, 721), meta=np.ndarray>
188186
189187 template = xbeam.replace_template_dims(template, time=times)
190188 print(template)
191189 # <xarray.Dataset> Size: 6TB
192190 # Dimensions: (time: 747769, longitude: 1440, latitude: 721)
193191 # Coordinates:
194- # * longitude (longitude) float64 12kB 0.0 0.25 0.5 0.75 ... 359.2 359.5
195- 359.8
196- # * latitude (latitude) float64 6kB -90.0 -89.75 -89.5 ... 89.5 89.75
197- 90.0
192+ # * longitude (longitude) float64 12kB 0.0 0.25 0.5 0.75 ... 359.2 359.5 359.8
193+ # * latitude (latitude) float64 6kB -90.0 -89.75 -89.5 ... 89.5 89.75 90.0
198194 # * time (time) datetime64[ns] 6MB 1940-01-01 ... 2025-04-21
199195 # Data variables:
200- # foo (time, longitude, latitude) float64 6TB
201- dask.array<chunksize=(747769, 1440, 721), meta=np.ndarray>
196+ # foo (time, longitude, latitude) float64 6TB dask.array<chunksize=(747769, 1440, 721), meta=np.ndarray>
202197
203198 Args:
204199 template: The template to replace dimensions in.
@@ -209,22 +204,25 @@ def replace_template_dims(
209204 Returns:
210205 Template with the replaced dimensions.
211206 """
212- expansions = {}
207+ # pyformat: enable
208+ expansions_with_axes = {}
213209 for name , variable in template .items ():
214210 if variable .chunks is None :
215211 raise ValueError (
216212 f'Data variable { name } is not chunked with Dask. Please call'
217213 ' xarray_beam.make_template() to create a valid template before '
218214 f' calling replace_template_dims(): { template } '
219215 )
220- expansions [name ] = {
221- dim : replacement
222- for dim , replacement in dim_replacements .items ()
223- if dim in variable .dims
224- }
216+ # identify which dimensions of this variable need to be replaced, in order
217+ dims_to_replace = [dim for dim in variable .dims if dim in dim_replacements ]
218+ if dims_to_replace :
219+ expansions = {dim : dim_replacements [dim ] for dim in dims_to_replace }
220+ axes = [variable .dims .index (dim ) for dim in dims_to_replace ]
221+ expansions_with_axes [name ] = (expansions , axes )
222+
225223 template = template .isel ({dim : 0 for dim in dim_replacements }, drop = True )
226- for name , variable in template .items ():
227- template [name ] = variable .expand_dims (expansions [ name ] )
224+ for name , ( expansions , axes ) in expansions_with_axes .items ():
225+ template [name ] = template [ name ] .expand_dims (expansions , axis = axes )
228226 return template
229227
230228
0 commit comments