44from datashader .glyphs .quadmesh import (
55 QuadMeshRaster , QuadMeshRectilinear , QuadMeshCurvilinear , build_scale_translate
66)
7+ from .xarray import _extract_third_dim
78from datashader .utils import apply
89import dask
910import numpy as np
1011import xarray as xr
11- from dask .base import tokenize , compute
12+ from dask .base import tokenize , compute , flatten
1213from dask .array .overlap import overlap
1314dask_glyph_dispatch = Dispatcher ()
1415
1516
17+ def _prepare_3d_coords_and_dims (third_dim , xr_ds , axis , glyph ):
18+ dims_list = [glyph .y_label , glyph .x_label ]
19+
20+ if third_dim :
21+ coords_dict = axis .copy ()
22+ coords_dict [third_dim ] = xr_ds .coords [third_dim ]
23+ dims_list = [third_dim , glyph .y_label , glyph .x_label ]
24+ else :
25+ coords_dict = axis
26+
27+ return coords_dict , dims_list
28+
29+
1630def dask_xarray_pipeline (glyph , xr_ds , schema , canvas , summary , * , antialias = False , cuda = False ):
1731 dsk , name = dask_glyph_dispatch (
1832 glyph , xr_ds , schema , canvas , summary , antialias = antialias , cuda = cuda )
@@ -54,9 +68,19 @@ def shape_bounds_st_and_axis(xr_ds, canvas, glyph):
5468
5569 return shape , bounds , st , axis
5670
71+ def _data_info_3d (xr_ds , canvas , glyph ):
72+ shape , bounds , st , axis = shape_bounds_st_and_axis (xr_ds , canvas , glyph )
73+
74+ third_dim = _extract_third_dim (glyph , xr_ds )
75+ if third_dim :
76+ height , width = shape
77+ shape = (len (xr_ds .coords [third_dim ]), height , width )
78+
79+ return shape , bounds , st , axis , third_dim
80+
5781
5882def dask_rectilinear (glyph , xr_ds , schema , canvas , summary , * , antialias = False , cuda = False ):
59- shape , bounds , st , axis = shape_bounds_st_and_axis (xr_ds , canvas , glyph )
83+ shape , bounds , st , axis , third_dim = _data_info_3d (xr_ds , canvas , glyph )
6084
6185 # Compile functions
6286 create , info , append , combine , finalize , antialias_stage_2 , antialias_stage_2_funcs , _ = \
@@ -129,17 +153,19 @@ def chunk(np_arr, *inds):
129153 return aggs
130154
131155 name = tokenize (xr_ds .__dask_tokenize__ (), canvas , glyph , summary )
132- keys = [ k for row in xr_ds .__dask_keys__ ()[0 ] for k in row ]
156+ keys = tuple ( flatten ( xr_ds .__dask_keys__ ()[0 ]))
133157 keys2 = [(name , i ) for i in range (len (keys ))]
134- dsk = dict ((k2 , (chunk , k , k [1 ], k [2 ])) for (k2 , k ) in zip (keys2 , keys ))
158+ dsk = dict ((k2 , (chunk , k , * k [1 :])) for (k2 , k ) in zip (keys2 , keys ))
159+
160+ coords_dict , dims_list = _prepare_3d_coords_and_dims (third_dim , xr_ds , axis , glyph )
135161 dsk [name ] = (apply , finalize , [(combine , keys2 )],
136- dict (cuda = cuda , coords = axis , dims = [ glyph . y_label , glyph . x_label ] ,
162+ dict (cuda = cuda , coords = coords_dict , dims = dims_list ,
137163 attrs = dict (x_range = x_range , y_range = y_range )))
138164 return dsk , name
139165
140166
141167def dask_raster (glyph , xr_ds , schema , canvas , summary , * , antialias = False , cuda = False ):
142- shape , bounds , st , axis = shape_bounds_st_and_axis (xr_ds , canvas , glyph )
168+ shape , bounds , st , axis , third_dim = _data_info_3d (xr_ds , canvas , glyph )
143169
144170 # Compile functions
145171 create , info , append , combine , finalize , antialias_stage_2 , antialias_stage_2_funcs , _ = \
@@ -177,7 +203,7 @@ def dask_raster(glyph, xr_ds, schema, canvas, summary, *, antialias=False, cuda=
177203 ybinsize = abs (float (xr_ds [y_name ][1 ] - xr_ds [y_name ][0 ]))
178204
179205 # Compute scale/translate
180- out_h , out_w = shape
206+ out_h , out_w = shape [ - 2 :]
181207 src_h , src_w = [xr_ds [glyph .name ].shape [i ] for i in [ydim_ind , xdim_ind ]]
182208 out_x0 , out_x1 , out_y0 , out_y1 = bounds
183209 scale_y , translate_y = build_scale_translate (
@@ -224,19 +250,20 @@ def chunk(np_arr, *inds):
224250 return aggs
225251
226252 name = tokenize (xr_ds .__dask_tokenize__ (), canvas , glyph , summary )
227- keys = [ k for row in xr_ds .__dask_keys__ ()[0 ] for k in row ]
253+ keys = tuple ( flatten ( xr_ds .__dask_keys__ ()[0 ]))
228254 keys2 = [(name , i ) for i in range (len (keys ))]
229- dsk = dict ((k2 , (chunk , k , k [1 ], k [2 ])) for (k2 , k ) in zip (keys2 , keys ))
255+ dsk = dict ((k2 , (chunk , k , * k [1 :])) for (k2 , k ) in zip (keys2 , keys ))
256+
257+ coords_dict , dims_list = _prepare_3d_coords_and_dims (third_dim , xr_ds , axis , glyph )
230258 dsk [name ] = (apply , finalize , [(combine , keys2 )],
231- dict (cuda = cuda , coords = axis , dims = [ glyph . y_label , glyph . x_label ] ,
259+ dict (cuda = cuda , coords = coords_dict , dims = dims_list ,
232260 attrs = dict (x_range = x_range , y_range = y_range )))
233261 return dsk , name
234262
235263
236264def dask_curvilinear (glyph , xr_ds , schema , canvas , summary , * , antialias = False , cuda = False ):
237- shape , bounds , st , axis = shape_bounds_st_and_axis (xr_ds , canvas , glyph )
265+ shape , bounds , st , axis , third_dim = _data_info_3d (xr_ds , canvas , glyph )
238266
239- # Compile functions
240267 create , info , append , combine , finalize , antialias_stage_2 , antialias_stage_2_funcs , _ = \
241268 compile_components (summary , schema , glyph , antialias = antialias , cuda = cuda , partitioned = True )
242269 x_mapper = canvas .x_axis .mapper
@@ -259,20 +286,29 @@ def dask_curvilinear(glyph, xr_ds, schema, canvas, summary, *, antialias=False,
259286
260287 var_name = list (xr_ds .data_vars .keys ())[0 ]
261288
262- # Validate coordinates
289+ # Validate coordinates - for 3D, compare only spatial dimensions (exclude third_dim)
290+ if third_dim :
291+ expected_dims = tuple (d for d in xr_ds [glyph .name ].dims if d != third_dim )
292+ expected_chunks = tuple (
293+ c for d , c in zip (xr_ds [glyph .name ].dims , xr_ds [glyph .name ].chunks ) if d != third_dim
294+ )
295+ else :
296+ expected_dims = xr_ds [glyph .name ].dims
297+ expected_chunks = xr_ds [glyph .name ].chunks
298+
263299 err_msg = (
264300 "DataArray {name} is backed by a Dask array, \n "
265301 "but coordinate {coord} is not backed by a Dask array with identical \n "
266302 "dimension order and chunks"
267303 )
268304 if (not isinstance (x_centers , dask .array .Array ) or
269- xr_ds [glyph .name ].dims != xr_ds [ glyph . x ]. dims or
270- xr_ds [glyph .name ].chunks != xr_ds [ glyph . x ]. chunks ):
305+ xr_ds [glyph .x ].dims != expected_dims or
306+ xr_ds [glyph .x ].chunks != expected_chunks ):
271307 raise ValueError (err_msg .format (name = glyph .name , coord = glyph .x ))
272308
273309 if (not isinstance (y_centers , dask .array .Array ) or
274- xr_ds [glyph .name ].dims != xr_ds [ glyph . y ]. dims or
275- xr_ds [glyph .name ].chunks != xr_ds [ glyph . y ]. chunks ):
310+ xr_ds [glyph .y ].dims != expected_dims or
311+ xr_ds [glyph .y ].chunks != expected_chunks ):
276312 raise ValueError (err_msg .format (name = glyph .name , coord = glyph .y ))
277313
278314 # Make sure coordinates are floats so that overlap with nan will behave properly
@@ -324,9 +360,9 @@ def chunk(np_zs, np_x_centers, np_y_centers):
324360
325361 result_name = tokenize (xr_ds .__dask_tokenize__ (), canvas , glyph , summary )
326362
327- z_keys = [ k for row in zs .__dask_keys__ () for k in row ]
328- x_overlap_keys = [ k for row in x_overlapped_centers .__dask_keys__ () for k in row ]
329- y_overlap_keys = [ k for row in y_overlapped_centers .__dask_keys__ () for k in row ]
363+ z_keys = tuple ( flatten ( zs .__dask_keys__ ()))
364+ x_overlap_keys = tuple ( flatten ( x_overlapped_centers .__dask_keys__ ()))
365+ y_overlap_keys = tuple ( flatten ( y_overlapped_centers .__dask_keys__ ()))
330366
331367 result_keys = [(result_name , i ) for i in range (len (z_keys ))]
332368
@@ -337,9 +373,10 @@ def chunk(np_zs, np_x_centers, np_y_centers):
337373 )
338374 )
339375
376+ coords_dict , dims_list = _prepare_3d_coords_and_dims (third_dim , xr_ds , axis , glyph )
340377 dsk [result_name ] = (
341378 apply , finalize , [(combine , result_keys )],
342- dict (cuda = cuda , coords = axis , dims = [ glyph . y_label , glyph . x_label ] ,
379+ dict (cuda = cuda , coords = coords_dict , dims = dims_list ,
343380 attrs = dict (x_range = x_range , y_range = y_range ))
344381 )
345382
0 commit comments