Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions datashader/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2171,6 +2171,27 @@ def validate(self, input_dshape):
for v in self.values:
v.validate(input_dshape)

for key, value in zip(self.keys, self.values, strict=True):
# Summary keys become Dataset variable names. They cannot collide with
# extra dimensions introduced by reductions (e.g. by(cat_column)).
dim_names = set()
stack = [value]
while stack:
reduction = stack.pop()
if isinstance(reduction, by):
dim_names.add(reduction.cat_column)
stack.append(reduction.reduction)
elif isinstance(reduction, where):
stack.append(reduction.selector)
elif isinstance(reduction, FloatingNReduction):
dim_names.add("n")

if key in dim_names:
raise ValueError(
f"Invalid summary reduction name {key!r}: it conflicts with "
f"a generated dimension name. Rename the summary key or reduction."
)

# Check that any included FloatingNReductions have the same n values.
n_values = []
for v in self.values:
Expand Down
13 changes: 13 additions & 0 deletions datashader/tests/test_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from copy import deepcopy
import numpy as np
from numpy import nan
import pandas as pd
import xarray as xr

import datashader as ds
Expand Down Expand Up @@ -47,6 +48,18 @@ def assert_eq(agg, b):
assert agg.equals(b)


def test_summary_name_conflicts_with_by_dimension():
source = xr.Dataset(coords={
"x": xr.DataArray(np.array([0, 0, 1, 1]), dims="record"),
"y": xr.DataArray(np.array([0, 1, 0, 1]), dims="record"),
"foo foo": xr.DataArray(pd.Categorical(["a", "a", "b", "b"]), dims="record"),
})
agg = ds.summary(**{"foo foo": ds.by("foo foo")})
msg = "Invalid summary reduction name 'foo foo': it conflicts with a generated dimension name"
with pytest.raises(ValueError, match=msg):
c.points(source, "x", "y", agg)


@pytest.mark.parametrize("source", [xda, xdda, xds, xdds])
def test_count(source):
if source is None:
Expand Down