Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/).

## Unreleased

Fixed:
- Conservative regridding with sparse weights (the default when the optional `sparse` package is installed) no longer depends on `opt-einsum` for acceptable performance. The per-axis weights are now applied with a scipy CSR sparse-dense matmul instead of a multi-operand sparse `xr.dot`, which was 20–80x slower when `opt-einsum` was absent (e.g. when installing the `conservative-2d` extra, which brings `sparse`, without `accel`, which brings `opt-einsum`). The weights stay sparse (no extra memory at high resolution), the regridded result is now a dense array rather than `sparse.COO`, and numerical results are unchanged.

## 0.4.2 (2026-01-28)

New contributors:
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,7 @@ warn_return_any = true
warn_unused_ignores = true
show_error_codes = true
exclude = ["tests/*", "docs"]

[[tool.mypy.overrides]]
module = ["scipy.*"]
ignore_missing_imports = true
96 changes: 82 additions & 14 deletions src/xarray_regrid/methods/conservative.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import overload

import numpy as np
import scipy.sparse
import xarray as xr

try:
Expand Down Expand Up @@ -178,34 +179,101 @@ def conservative_regrid_dataset(
return ds_regridded


def _csr_apply_axis(
    da: xr.DataArray, weight: xr.DataArray, coord: Hashable
) -> xr.DataArray:
    """Contract ``da`` along ``coord`` with ``weight`` via a scipy CSR matmul.

    The result replaces the ``coord`` dimension of ``da`` with
    ``target_{coord}``. A direct CSR sparse-dense matmul is fast and -- unlike
    the multi-operand sparse ``xr.dot`` -- does not need ``opt_einsum`` to find
    an efficient contraction path. A ``sparse.COO`` weight is converted to
    scipy CSR without being densified; any other weight is read through
    ``np.asarray``. Only the regridded result (one value per target cell) is
    produced as a dense array.

    Args:
        da: Data to regrid; must have a ``coord`` dimension.
        weight: 2-D weights over the dims ``coord`` and ``target_{coord}``.
            The dims may be stored in either order; they are transposed to
            ``(coord, target_{coord})`` before use.
        coord: Name of the source dimension to contract.

    Returns:
        ``da`` with ``coord`` replaced by ``target_{coord}`` (as the last
        dimension), labelled with the ``target_{coord}`` values of ``weight``.
    """
    target_dim = f"target_{coord}"
    target_coords = weight[target_dim].to_numpy()

    # The kernel below relies on a (coord, target) layout. Enforce it instead
    # of trusting the caller's dimension order; this is a no-op when the
    # weight is already ordered that way.
    wdata = weight.transpose(coord, target_dim).data
    if hasattr(wdata, "compute"):  # dask-backed weight; materialize (it's small)
        wdata = wdata.compute()
    if sparse is not None and isinstance(wdata, sparse.COO):
        scipy_coo = wdata.to_scipy_sparse()
    else:
        scipy_coo = scipy.sparse.coo_matrix(np.asarray(wdata))
    # store as (n_dst, n_src) CSR so the kernel is ``csr @ dense -> dense``
    csr = scipy_coo.T.tocsr()
    n_dst = csr.shape[0]
    out_dtype = np.result_type(da.dtype, wdata.dtype)

    def _matmul(arr: np.ndarray) -> np.ndarray:
        # apply_ufunc moves the core dim last: flatten every leading dim into
        # rows, multiply, then restore the leading shape.
        flat = arr.reshape(-1, arr.shape[-1]).astype(out_dtype, copy=False)
        dense: np.ndarray = np.asarray(csr @ flat.T).T  # (n_rows, n_dst)
        return dense.reshape(*arr.shape[:-1], n_dst)

    result: xr.DataArray = xr.apply_ufunc(
        _matmul,
        da,
        input_core_dims=[[coord]],
        output_core_dims=[[target_dim]],
        # the core dimension changes size (n_src -> n_dst)
        exclude_dims={coord},
        dask="parallelized",
        output_dtypes=[out_dtype],
        dask_gufunc_kwargs={
            "output_sizes": {target_dim: n_dst},
            "allow_rechunk": True,
        },
    )
    result = result.assign_coords({target_dim: target_coords})
    return result


def apply_weights(
    da: xr.DataArray,
    weights: dict[Hashable, xr.DataArray],
    skipna: bool,
    nan_threshold: float,
) -> xr.DataArray:
    """Apply the regridding weights over all regridding dimensions.

    Each per-axis weight is applied with a scipy CSR sparse-dense matmul
    (:func:`_csr_apply_axis`); separability lets us contract one axis at a
    time. A direct CSR matmul is fast and, unlike a multi-operand sparse
    ``xr.dot``, does not depend on ``opt_einsum`` (whose absence makes that
    contraction an order of magnitude or more slower). The weights stay
    compressed and the result is dense.

    Args:
        da: Data to regrid. NaNs are replaced by 0 before the weights are
            applied.
        weights: Mapping from each regridding dimension name to its weight
            array, with dims ``coord`` and ``target_{coord}``.
        skipna: If True, renormalize each target cell by the weighted
            fraction of valid (non-null) source data and mask cells whose
            valid fraction is below ``get_valid_threshold(nan_threshold)``.
        nan_threshold: Passed to ``get_valid_threshold``; only used when
            ``skipna`` is True.

    Returns:
        The regridded array, with the temporary ``target_*`` dimensions
        renamed back to the original names and the dimension order of ``da``.
    """

    def _apply_all(arr: xr.DataArray) -> xr.DataArray:
        # Contract the axes one after another.
        for coord, weight in weights.items():
            arr = _csr_apply_axis(arr, weight, coord)
        return arr

    da_regrid = _apply_all(da.fillna(0))
    if skipna:
        valid_frac = _apply_all(da.notnull())
        # Divide by the valid fraction, avoiding 0/0 where a target cell has no
        # valid source (those cells are masked to NaN by the threshold below).
        da_regrid = da_regrid / valid_frac.where(valid_frac != 0, 1.0)
        da_regrid = da_regrid.where(valid_frac >= get_valid_threshold(nan_threshold))

    # apply_ufunc collapses/splits the new target dims, so restore the target
    # chunking carried by the weights (a no-op when the weights are unchunked).
    rechunk = {}
    for coord, weight in weights.items():
        target_dim = f"target_{coord}"
        target_chunks = weight.chunksizes.get(target_dim)
        if target_chunks is not None:
            rechunk[target_dim] = target_chunks
    if rechunk:
        da_regrid = da_regrid.chunk(rechunk)

    # Rename temporary coordinates and ensure original dimension order
    coord_map = {f"target_{coord}": coord for coord in weights}
    regridded: xr.DataArray = da_regrid.rename(coord_map)
    regridded = regridded.transpose(*da.dims)
    return regridded


def get_valid_threshold(nan_threshold: float) -> float:
Expand Down
71 changes: 71 additions & 0 deletions tests/test_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,77 @@ def test_conservative_nan_thresholds_against_xesmf():
xr.testing.assert_equal(data_regrid.isnull(), data_esmf.isnull())


def test_conservative_conserves_known_integral():
    """Conservation check for the axis-factored method against a known integral.

    The field ``cos^2(lat) * (1.5 + sin(lon))`` integrates to ``4*pi`` over the
    unit sphere. It is regridded with the spherical correction
    (``latitude_coord``) from a global grid onto a coarser global grid with the
    same extent. The source and target integrals are both measured with
    analytic spherical cell areas (sin-latitude band width times dlon) that are
    computed here, independently of the regridder's weights. The two integrals
    must agree tightly, and the source integral must be close to ``4*pi``.

    A spatially varying field, independent areas and matched domains exercise
    the real weights; a constant field (reproduced by any row-normalized
    regridder) or a self-area check (a row-sum identity) would not.
    """

    def cell_centers(n_cells, lower, upper):
        # Midpoints of n_cells equal cells whose outer edges are lower/upper.
        bounds = np.linspace(lower, upper, n_cells + 1)
        return (bounds[:-1] + bounds[1:]) / 2

    def sphere_cell_areas(n_lon, n_lat):
        # Exact unit-sphere cell areas, shape (n_lat, n_lon); regridder-free.
        lat_bounds = np.deg2rad(np.linspace(-90, 90, n_lat + 1))
        lon_bounds = np.deg2rad(np.linspace(-180, 180, n_lon + 1))
        return np.outer(np.diff(np.sin(lat_bounds)), np.diff(lon_bounds))

    ns_lat, ns_lon = 120, 240
    nt_lat, nt_lon = 40, 80
    lat_s = cell_centers(ns_lat, -90, 90)
    lon_s = cell_centers(ns_lon, -180, 180)
    lat_t = cell_centers(nt_lat, -90, 90)
    lon_t = cell_centers(nt_lon, -180, 180)

    lat_factor = np.cos(np.deg2rad(lat_s)) ** 2
    lon_factor = 1.5 + np.sin(np.deg2rad(lon_s))
    field = lat_factor[:, None] * lon_factor[None, :]
    da = xr.DataArray(field, dims=("lat", "lon"), coords={"lat": lat_s, "lon": lon_s})
    target = xr.Dataset(coords={"lat": lat_t, "lon": lon_t})

    regridded = da.regrid.conservative(target, latitude_coord="lat", skipna=False)
    out = regridded.transpose("lat", "lon").values
    # co-extensive global grids -> every target cell is covered
    assert np.isfinite(out).all()

    i_src = float(np.sum(da.values * sphere_cell_areas(ns_lon, ns_lat)))
    i_tgt = float(np.sum(out * sphere_cell_areas(nt_lon, nt_lat)))
    np.testing.assert_allclose(i_tgt, i_src, rtol=1e-5)  # mass conserved
    np.testing.assert_allclose(i_src, 4 * np.pi, rtol=2e-3)  # ~ the known value


def test_conservative_returns_dense_output():
    """Regression guard: regridding must return a dense ``np.ndarray``.

    Sparse weights are applied with a scipy CSR matmul (see
    ``methods.conservative.apply_weights``), which yields a dense result. A
    ``sparse.COO`` result here would mean the apply step fell back to the
    sparse ``xr.dot`` path, which is an order of magnitude or more slower
    without ``opt_einsum`` and wraps a dense field in a sparse container.
    """
    src_lat = np.linspace(-89, 89, 60)
    src_lon = np.linspace(-179, 179, 120)
    # Field varies with latitude only and is constant along longitude.
    values = np.cos(np.deg2rad(src_lat))[:, None] * np.ones(src_lon.size)
    da = xr.DataArray(
        values,
        dims=("lat", "lon"),
        coords={"lat": src_lat, "lon": src_lon},
    )

    target_coords = {
        "lat": np.linspace(-88, 88, 30),
        "lon": np.linspace(-178, 178, 60),
    }
    target = xr.Dataset(coords=target_coords)

    out = da.regrid.conservative(target, latitude_coord="lat")
    assert isinstance(out.data, np.ndarray), (
        f"expected dense ndarray, got {type(out.data).__name__}"
    )


class TestCoordOrder:
@pytest.mark.parametrize("method", ["linear", "nearest", "cubic"])
@pytest.mark.parametrize("dataarray", [True, False])
Expand Down