Skip to content

Commit c18c685

Browse files
authored
feat: Make quadmesh support bandwise 2D (#1472)
1 parent 9edecf4 commit c18c685

File tree

7 files changed

+507
-86
lines changed

7 files changed

+507
-86
lines changed

datashader/core.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -780,7 +780,23 @@ def polygons(self, source, geometry, agg=None):
780780
return bypixel(source, self, glyph, agg)
781781

782782
def quadmesh(self, source, x=None, y=None, agg=None):
783-
"""Samples a recti- or curvi-linear quadmesh by canvas size and bounds.
783+
r"""Samples a raster, rectilinear or curvilinear quadmesh by canvas size and bounds.
784+
785+
+---------------------+---------------------+---------------------+
786+
| RASTER | RECTILINEAR | CURVILINEAR |
787+
+---------------------+---------------------+---------------------+
788+
| Regular spacing | Variable spacing | Variable 2D spacing |
789+
| o---o---o---o---o | o-o---o----o-o | o---o---o---o |
790+
| | | | | | | | | | | | | / / / / |
791+
| o---o---o---o---o | o-o---o----o-o | o----o--o---o |
792+
| | | | | | | | | | | | | / / / / |
793+
| o---o---o---o---o | o-o---o----o-o | o--o----o---o |
794+
| | | | | | | | | | | | | \ \ \ \ |
795+
| o---o---o---o---o | o-o---o----o-o | o---o---o---o |
796+
| `dx = dy = constant` | `dx` & `dy` vary in 1D | `dx` & `dy` vary in 2D |
797+
| `x[i] = i * dx` | `x[i], y[j]` | `x[i,j], y[i,j]` |
798+
| `y[j] = j * dy` | | |
799+
+---------------------+---------------------+---------------------+
784800
785801
Parameters
786802
----------
@@ -795,10 +811,12 @@ def quadmesh(self, source, x=None, y=None, agg=None):
795811
Returns
796812
-------
797813
data : xarray.DataArray
814+
815+
Note
816+
----
817+
Table from EarthMover
798818
"""
799819
from .glyphs import QuadMeshRaster, QuadMeshRectilinear, QuadMeshCurvilinear
800-
801-
# Determine reduction operation
802820
from .reductions import mean as mean_rnd
803821

804822
if isinstance(source, Dataset):
@@ -819,7 +837,10 @@ def quadmesh(self, source, x=None, y=None, agg=None):
819837
agg = mean_rnd(name)
820838

821839
if x is None and y is None:
822-
y, x = source[name].dims
840+
if len(source[name].dims) == 2:
841+
y, x = source[name].dims
842+
else:
843+
raise ValueError("x and y must be specified if dims is not 2D.")
823844
elif not x or not y:
824845
raise ValueError("Either specify both x and y coordinates"
825846
"or allow them to be inferred.")

datashader/data_libraries/dask_xarray.py

Lines changed: 58 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,29 @@
44
from datashader.glyphs.quadmesh import (
55
QuadMeshRaster, QuadMeshRectilinear, QuadMeshCurvilinear, build_scale_translate
66
)
7+
from .xarray import _extract_third_dim
78
from datashader.utils import apply
89
import dask
910
import numpy as np
1011
import xarray as xr
11-
from dask.base import tokenize, compute
12+
from dask.base import tokenize, compute, flatten
1213
from dask.array.overlap import overlap
1314
dask_glyph_dispatch = Dispatcher()
1415

1516

17+
def _prepare_3d_coords_and_dims(third_dim, xr_ds, axis, glyph):
18+
dims_list = [glyph.y_label, glyph.x_label]
19+
20+
if third_dim:
21+
coords_dict = axis.copy()
22+
coords_dict[third_dim] = xr_ds.coords[third_dim]
23+
dims_list = [third_dim, glyph.y_label, glyph.x_label]
24+
else:
25+
coords_dict = axis
26+
27+
return coords_dict, dims_list
28+
29+
1630
def dask_xarray_pipeline(glyph, xr_ds, schema, canvas, summary, *, antialias=False, cuda=False):
1731
dsk, name = dask_glyph_dispatch(
1832
glyph, xr_ds, schema, canvas, summary, antialias=antialias, cuda=cuda)
@@ -54,9 +68,19 @@ def shape_bounds_st_and_axis(xr_ds, canvas, glyph):
5468

5569
return shape, bounds, st, axis
5670

71+
def _data_info_3d(xr_ds, canvas, glyph):
    """Return aggregate shape, bounds, scale/translate, axis and extra dimension.

    Wraps ``shape_bounds_st_and_axis`` and, when ``_extract_third_dim``
    reports an extra dimension, prepends that dimension's length to the
    ``(height, width)`` shape.

    Parameters
    ----------
    xr_ds : object
        Source data; ``xr_ds.coords[third_dim]`` is measured with ``len``
        when an extra dimension is present.
    canvas, glyph : object
        Forwarded unchanged to ``shape_bounds_st_and_axis``.

    Returns
    -------
    tuple
        ``(shape, bounds, st, axis, third_dim)``. ``shape`` is
        ``(height, width)`` when ``third_dim`` is ``None`` and
        ``(n, height, width)`` otherwise, where ``n`` is the length of the
        extra dimension's coordinate.
    """
    shape, bounds, st, axis = shape_bounds_st_and_axis(xr_ds, canvas, glyph)

    third_dim = _extract_third_dim(glyph, xr_ds)
    # Test against None explicitly so that a falsy dimension name (0, "")
    # is still treated as a real extra dimension.
    if third_dim is not None:
        height, width = shape
        shape = (len(xr_ds.coords[third_dim]), height, width)

    return shape, bounds, st, axis, third_dim
80+
5781

5882
def dask_rectilinear(glyph, xr_ds, schema, canvas, summary, *, antialias=False, cuda=False):
59-
shape, bounds, st, axis = shape_bounds_st_and_axis(xr_ds, canvas, glyph)
83+
shape, bounds, st, axis, third_dim = _data_info_3d(xr_ds, canvas, glyph)
6084

6185
# Compile functions
6286
create, info, append, combine, finalize, antialias_stage_2, antialias_stage_2_funcs, _ = \
@@ -129,17 +153,19 @@ def chunk(np_arr, *inds):
129153
return aggs
130154

131155
name = tokenize(xr_ds.__dask_tokenize__(), canvas, glyph, summary)
132-
keys = [k for row in xr_ds.__dask_keys__()[0] for k in row]
156+
keys = tuple(flatten(xr_ds.__dask_keys__()[0]))
133157
keys2 = [(name, i) for i in range(len(keys))]
134-
dsk = dict((k2, (chunk, k, k[1], k[2])) for (k2, k) in zip(keys2, keys))
158+
dsk = dict((k2, (chunk, k, *k[1:])) for (k2, k) in zip(keys2, keys))
159+
160+
coords_dict, dims_list = _prepare_3d_coords_and_dims(third_dim, xr_ds, axis, glyph)
135161
dsk[name] = (apply, finalize, [(combine, keys2)],
136-
dict(cuda=cuda, coords=axis, dims=[glyph.y_label, glyph.x_label],
162+
dict(cuda=cuda, coords=coords_dict, dims=dims_list,
137163
attrs=dict(x_range=x_range, y_range=y_range)))
138164
return dsk, name
139165

140166

141167
def dask_raster(glyph, xr_ds, schema, canvas, summary, *, antialias=False, cuda=False):
142-
shape, bounds, st, axis = shape_bounds_st_and_axis(xr_ds, canvas, glyph)
168+
shape, bounds, st, axis, third_dim = _data_info_3d(xr_ds, canvas, glyph)
143169

144170
# Compile functions
145171
create, info, append, combine, finalize, antialias_stage_2, antialias_stage_2_funcs, _ = \
@@ -177,7 +203,7 @@ def dask_raster(glyph, xr_ds, schema, canvas, summary, *, antialias=False, cuda=
177203
ybinsize = abs(float(xr_ds[y_name][1] - xr_ds[y_name][0]))
178204

179205
# Compute scale/translate
180-
out_h, out_w = shape
206+
out_h, out_w = shape[-2:]
181207
src_h, src_w = [xr_ds[glyph.name].shape[i] for i in [ydim_ind, xdim_ind]]
182208
out_x0, out_x1, out_y0, out_y1 = bounds
183209
scale_y, translate_y = build_scale_translate(
@@ -224,19 +250,20 @@ def chunk(np_arr, *inds):
224250
return aggs
225251

226252
name = tokenize(xr_ds.__dask_tokenize__(), canvas, glyph, summary)
227-
keys = [k for row in xr_ds.__dask_keys__()[0] for k in row]
253+
keys = tuple(flatten(xr_ds.__dask_keys__()[0]))
228254
keys2 = [(name, i) for i in range(len(keys))]
229-
dsk = dict((k2, (chunk, k, k[1], k[2])) for (k2, k) in zip(keys2, keys))
255+
dsk = dict((k2, (chunk, k, *k[1:])) for (k2, k) in zip(keys2, keys))
256+
257+
coords_dict, dims_list = _prepare_3d_coords_and_dims(third_dim, xr_ds, axis, glyph)
230258
dsk[name] = (apply, finalize, [(combine, keys2)],
231-
dict(cuda=cuda, coords=axis, dims=[glyph.y_label, glyph.x_label],
259+
dict(cuda=cuda, coords=coords_dict, dims=dims_list,
232260
attrs=dict(x_range=x_range, y_range=y_range)))
233261
return dsk, name
234262

235263

236264
def dask_curvilinear(glyph, xr_ds, schema, canvas, summary, *, antialias=False, cuda=False):
237-
shape, bounds, st, axis = shape_bounds_st_and_axis(xr_ds, canvas, glyph)
265+
shape, bounds, st, axis, third_dim = _data_info_3d(xr_ds, canvas, glyph)
238266

239-
# Compile functions
240267
create, info, append, combine, finalize, antialias_stage_2, antialias_stage_2_funcs, _ = \
241268
compile_components(summary, schema, glyph, antialias=antialias, cuda=cuda, partitioned=True)
242269
x_mapper = canvas.x_axis.mapper
@@ -259,20 +286,29 @@ def dask_curvilinear(glyph, xr_ds, schema, canvas, summary, *, antialias=False,
259286

260287
var_name = list(xr_ds.data_vars.keys())[0]
261288

262-
# Validate coordinates
289+
# Validate coordinates - for 3D, compare only spatial dimensions (exclude third_dim)
290+
if third_dim:
291+
expected_dims = tuple(d for d in xr_ds[glyph.name].dims if d != third_dim)
292+
expected_chunks = tuple(
293+
c for d, c in zip(xr_ds[glyph.name].dims, xr_ds[glyph.name].chunks) if d != third_dim
294+
)
295+
else:
296+
expected_dims = xr_ds[glyph.name].dims
297+
expected_chunks = xr_ds[glyph.name].chunks
298+
263299
err_msg = (
264300
"DataArray {name} is backed by a Dask array, \n"
265301
"but coordinate {coord} is not backed by a Dask array with identical \n"
266302
"dimension order and chunks"
267303
)
268304
if (not isinstance(x_centers, dask.array.Array) or
269-
xr_ds[glyph.name].dims != xr_ds[glyph.x].dims or
270-
xr_ds[glyph.name].chunks != xr_ds[glyph.x].chunks):
305+
xr_ds[glyph.x].dims != expected_dims or
306+
xr_ds[glyph.x].chunks != expected_chunks):
271307
raise ValueError(err_msg.format(name=glyph.name, coord=glyph.x))
272308

273309
if (not isinstance(y_centers, dask.array.Array) or
274-
xr_ds[glyph.name].dims != xr_ds[glyph.y].dims or
275-
xr_ds[glyph.name].chunks != xr_ds[glyph.y].chunks):
310+
xr_ds[glyph.y].dims != expected_dims or
311+
xr_ds[glyph.y].chunks != expected_chunks):
276312
raise ValueError(err_msg.format(name=glyph.name, coord=glyph.y))
277313

278314
# Make sure coordinates are floats so that overlap with nan will behave properly
@@ -324,9 +360,9 @@ def chunk(np_zs, np_x_centers, np_y_centers):
324360

325361
result_name = tokenize(xr_ds.__dask_tokenize__(), canvas, glyph, summary)
326362

327-
z_keys = [k for row in zs.__dask_keys__() for k in row]
328-
x_overlap_keys = [k for row in x_overlapped_centers.__dask_keys__() for k in row]
329-
y_overlap_keys = [k for row in y_overlapped_centers.__dask_keys__() for k in row]
363+
z_keys = tuple(flatten(zs.__dask_keys__()))
364+
x_overlap_keys = tuple(flatten(x_overlapped_centers.__dask_keys__()))
365+
y_overlap_keys = tuple(flatten(y_overlapped_centers.__dask_keys__()))
330366

331367
result_keys = [(result_name, i) for i in range(len(z_keys))]
332368

@@ -337,9 +373,10 @@ def chunk(np_zs, np_x_centers, np_y_centers):
337373
)
338374
)
339375

376+
coords_dict, dims_list = _prepare_3d_coords_and_dims(third_dim, xr_ds, axis, glyph)
340377
dsk[result_name] = (
341378
apply, finalize, [(combine, result_keys)],
342-
dict(cuda=cuda, coords=axis, dims=[glyph.y_label, glyph.x_label],
379+
dict(cuda=cuda, coords=coords_dict, dims=dims_list,
343380
attrs=dict(x_range=x_range, y_range=y_range))
344381
)
345382

datashader/data_libraries/xarray.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from datashader.core import bypixel
66
import xarray as xr
77
from datashader.utils import Dispatcher
8+
from datashader.compiler import compile_components
89

910

1011
try:
@@ -33,6 +34,61 @@ def xarray_pipeline(xr_ds, schema, canvas, glyph, summary, *, antialias=False):
3334
glyph, xr_ds, schema, canvas, summary, antialias=antialias, cuda=cuda)
3435

3536

37+
def _extract_third_dim(glyph, source):
38+
x_dims = set(source.coords[glyph.x].dims) if glyph.x in source.coords else {glyph.x}
39+
y_dims = set(source.coords[glyph.y].dims) if glyph.y in source.coords else {glyph.y}
40+
dims = set(source.dims) - (x_dims | y_dims)
41+
match len(dims):
42+
case 0:
43+
return None
44+
case 1:
45+
return next(iter(dims))
46+
case _:
47+
raise ValueError("Only one additional dimension supported for QuadMesh glyphs.")
48+
49+
50+
@glyph_dispatch.register(_QuadMeshLike)
def quadmesh_default(glyph, source, schema, canvas, summary, *, antialias=False, cuda=False):
    """Aggregate a quadmesh glyph, with support for one extra dimension.

    When ``_extract_third_dim`` finds no extra dimension the call is handed
    to ``default`` unchanged. Otherwise an aggregate of shape
    ``(len(source.coords[third_dim]), plot_height, plot_width)`` is created,
    filled by the glyph's ``extend`` function and finalized with
    ``third_dim`` as the leading dimension.

    Parameters
    ----------
    glyph, source, schema, canvas, summary : object
        Forwarded to ``default`` or to ``compile_components`` and the glyph
        helpers used below.
    antialias : bool, optional
        Forwarded to ``default`` / ``compile_components``.
    cuda : bool, optional
        Forwarded to ``default`` / ``compile_components`` and ``finalize``.

    Returns
    -------
    object
        Whatever ``default`` returns in the 2D case, otherwise the result of
        ``finalize`` with dims ``[third_dim, glyph.y_label, glyph.x_label]``.

    Raises
    ------
    ValueError
        Propagated from ``_extract_third_dim`` when ``source`` has more than
        one non-spatial dimension.
    """
    third_dim = _extract_third_dim(glyph, source)
    # Explicit None test: a falsy dimension name (0, "") is still a dimension.
    if third_dim is None:
        return default(glyph, source, schema, canvas, summary, antialias=antialias, cuda=cuda)

    create, info, append, _, finalize, antialias_stage_2, antialias_stage_2_funcs, _ = \
        compile_components(summary, schema, glyph, antialias=antialias, cuda=cuda,
                           partitioned=False)
    x_mapper = canvas.x_axis.mapper
    y_mapper = canvas.y_axis.mapper
    extend = glyph._build_extend(
        x_mapper, y_mapper, info, append, antialias_stage_2, antialias_stage_2_funcs)

    # Fall back to the data's own bounds when the canvas does not fix a range.
    x_range = canvas.x_range or glyph.compute_x_bounds(source)
    y_range = canvas.y_range or glyph.compute_y_bounds(source)
    canvas.validate_ranges(x_range, y_range)

    width = canvas.plot_width
    height = canvas.plot_height

    x_st = canvas.x_axis.compute_scale_and_translate(x_range, width)
    y_st = canvas.y_axis.compute_scale_and_translate(y_range, height)

    x_axis = canvas.x_axis.compute_index(x_st, width)
    y_axis = canvas.y_axis.compute_index(y_st, height)

    # One (height, width) plane per entry of the extra dimension.
    bases = create((len(source.coords[third_dim]), height, width))

    extend(bases, source, x_st + y_st, x_range + y_range)

    return finalize(
        bases,
        cuda=cuda,
        coords={
            third_dim: source.coords[third_dim],
            glyph.x_label: x_axis,
            glyph.y_label: y_axis,
        },
        dims=[third_dim, glyph.y_label, glyph.x_label],
        attrs=dict(x_range=x_range, y_range=y_range)
    )
92+
3693
# Default to default pandas implementation
37-
glyph_dispatch.register(_QuadMeshLike)(default)
3894
glyph_dispatch.register(LinesXarrayCommonX)(default)

0 commit comments

Comments
 (0)