feat: Make quadmesh support bandwise 2D #1472

Merged
hoxbro merged 32 commits into main from feat_3d_quadmesh
Feb 6, 2026

@hoxbro (Member) commented Dec 12, 2025

Resolves #1463

This PR adds support for bandwise 2D data in quadmesh, where the "new" dimension is independent: aggregating the 3D array gives the same result as looping over its 2D bands, e.g., cvs.quadmesh(da, x='x', y='y').isel(band=0) == cvs.quadmesh(da.isel(band=0), x='x', y='y'). This is done in the following steps:

  1. Make the glyph_dispatch for data_libraries/xarray.py and data_libraries/dask_xarray.py understand the third dimension, and have the converted aggregate be a bandwise 2D array.
  2. Add two "extender" functions in glyphs/quadmesh.py to convert a 2D calculation into a 3D calculation. This is done in two ways: one for the CPU and one for CUDA.
  3. Add some small if-statements to handle the new structure, and change some d[:2] to d[-2:].
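
The switch from d[:2] to d[-2:] in step 3 can be illustrated with a small NumPy sketch. It assumes that d is a shape-like tuple and that the band dimension comes first (as in the benchmark script in this PR), so the spatial dimensions are always the last two. `spatial_shape` is a hypothetical helper for illustration, not a datashader function.

```python
import numpy as np

def spatial_shape(arr):
    """Return (height, width), assuming the spatial dims are the last two."""
    return arr.shape[-2:]

plain = np.zeros((4, 5))        # 2D: (y, x)
banded = np.zeros((3, 4, 5))    # bandwise 2D: (band, y, x)

# shape[:2] silently picks up the band dimension for 3D input ...
leading = banded.shape[:2]
# ... while shape[-2:] gives the same answer for both layouts.
same = spatial_shape(plain) == spatial_shape(banded)
```

Indexing from the end keeps the 2D code path unchanged while letting the same line serve the bandwise case.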

For 2), this is done differently for CPU and CUDA.

For the CPU, we dynamically generate a function that iterates over the bands. This is done so we can handle a varying number of *aggs_and_cols, which depends on the reduction type: rd.mean has two (summing and counting), whereas rd.sum has only one (summing). This function is then Numba-compiled and cached for future use.
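
As a rough illustration of the idea (not the code in this PR), the sketch below generates a band-looping wrapper from source text so that the number of aggregate arguments is fixed per generated function, then caches it. All names here (`make_band_extender`, `sum_2d`, `mean_2d`) are hypothetical, and the Numba compilation step is left out so the example runs with NumPy alone.

```python
import numpy as np

_EXTENDER_CACHE = {}

def make_band_extender(func_2d, n_aggs):
    """Generate (and cache) a wrapper that calls func_2d once per band."""
    key = (func_2d, n_aggs)
    if key not in _EXTENDER_CACHE:
        names = [f"agg{i}" for i in range(n_aggs)]
        params = ", ".join(names)
        sliced = ", ".join(f"{name}[band]" for name in names)
        # The source is built per argument count, so the generated
        # function has an explicit signature instead of *args.
        src = (
            f"def extended(values, {params}):\n"
            f"    for band in range(values.shape[0]):\n"
            f"        func_2d(values[band], {sliced})\n"
        )
        namespace = {"func_2d": func_2d}
        exec(src, namespace)
        _EXTENDER_CACHE[key] = namespace["extended"]
    return _EXTENDER_CACHE[key]

def sum_2d(values, total):
    # One aggregate: running sum (updated in place).
    total += values

def mean_2d(values, total, count):
    # Two aggregates: running sum and running count.
    total += values
    count += 1

values = np.arange(12, dtype=float).reshape(3, 2, 2)

sum_total = np.zeros_like(values)
make_band_extender(sum_2d, 1)(values, sum_total)

mean_total = np.zeros_like(values)
mean_count = np.zeros_like(values)
make_band_extender(mean_2d, 2)(values, mean_total, mean_count)
```

Because each `agg[band]` is a view, the in-place updates inside the 2D function land in the corresponding band of the 3D aggregate.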

For CUDA, we use cuda.streams(), which was mainly generated by Claude Code. The performance appears to be as expected: processing 3 bands at once gives a speedup of slightly less than 3x compared to looping over the 2D bands.

Performance

Results from ccd9319 (10 iterations)

Script

"""
Validate that bandwise 2D quadmesh produces identical results to running 2D quadmesh
in a loop over each band, and benchmark the performance difference.

Uses the same sizes and setup as datashader/tests/benchmarks/test_quadmesh.py
"""
import numpy as np
import xarray as xr
import datashader as ds
import datashader.reductions as rd
import time
import argparse
from tqdm import tqdm

DATA_SIZES = (256, 512, 1024, 2048, 4096)
CANVAS_SIZE = (1024, 1024)

try:
    import cupy as cp
except ImportError:
    cp = None

try:
    import dask.array as da
except ImportError:
    da = None

def get_array_module(backend):
    """Get the array module for the specified backend."""
    if backend == 'numpy':
        return np
    elif backend == 'cupy':
        return cp
    elif backend == 'dask':
        return da
    else:
        raise ValueError(f"Unknown backend: {backend}")


def create_test_data(size, nz, array_module=np, seed=42):
    """
    Create test data with fixed seed for reproducibility.

    Args:
        size: grid size (size x size)
        nz: number of bands
        array_module: array module to use (np, cupy, or dask.array)
        seed: random seed

    Returns:
        array of shape (nz, size, size) using the specified array module
    """
    # Always generate with numpy for consistency, then convert
    rng = np.random.default_rng(seed=seed)
    data_3d = np.zeros((nz, size, size))
    for z in range(nz):
        data_3d[z] = rng.random((size, size)) * 100 + z * 100

    # Convert to appropriate backend
    if array_module is da:
        # Optimal: keep all bands together, minimize spatial chunks
        chunks = (nz, min(size, 1024), min(size, 1024))
        return da.from_array(data_3d, chunks=chunks)
    else:
        return array_module.array(data_3d)



def create_raster_mesh(data_3d, size, nz):
    """Create raster quadmesh data (evenly spaced coordinates)."""
    lon_coords = np.linspace(3123580.0, 4250380.0, size)
    lat_coords = np.linspace(4376200.0, 3249400.0, size)

    data_xr = xr.DataArray(
        data_3d,
        dims=("band", "y", "x"),
        coords={
            "lon": ("x", lon_coords),
            "lat": ("y", lat_coords),
            "band": list(range(nz)),
        },
        name="test_data",
    )
    # Swap dims for raster (matches benchmark)
    data_xr = data_xr.swap_dims({"y": "lat", "x": "lon"})
    return data_xr, "lon", "lat"


def create_rectilinear_mesh(data_3d, size, nz, y_range, seed=42):
    """Create rectilinear quadmesh data (non-uniformly spaced 1D coordinates)."""
    rng = np.random.default_rng(seed=seed)
    lon_coords = np.linspace(3123580.0, 4250380.0, size)
    lat_coords = np.linspace(4376200.0, 3249400.0, size)

    # Add random deltas to make it non-uniform (matches benchmark)
    dy = (y_range[1] - y_range[0]) / size
    deltas = rng.uniform(-dy/2, dy/2, size)
    lat_coords = lat_coords + deltas

    data_xr = xr.DataArray(
        data_3d,
        dims=("band", "y", "x"),
        coords={
            "lon": ("x", lon_coords),
            "lat": ("y", lat_coords),
            "band": list(range(nz)),
        },
        name="test_data",
    )
    # Swap dims for rectilinear (matches benchmark)
    data_xr = data_xr.swap_dims({"y": "lat", "x": "lon"})
    return data_xr, "lon", "lat"


def create_curvilinear_mesh(data_3d, size, nz):
    """Create curvilinear quadmesh data (2D coordinate arrays)."""
    lon_1d = np.linspace(3123580.0, 4250380.0, size)
    lat_1d = np.linspace(4376200.0, 3249400.0, size)

    # Create base DataArray with dims (y, x, band) to match test setup
    data_base = xr.DataArray(
        data_3d.transpose(1, 2, 0),  # Transpose from (nz, size, size) to (size, size, nz)
        dims=("y", "x", "band"),
        coords={
            "x": lon_1d,
            "y": lat_1d,
            "band": list(range(nz)),
        },
        name="test_data",
    )

    # Broadcast to create 2D coordinate arrays (matches benchmark)
    lon_coord, lat_coord = xr.broadcast(data_base.x, data_base.y)

    # If data is dask array, convert coordinates to dask with matching chunks
    if hasattr(data_base.data, 'chunks'):
        # Get spatial chunks (ignore band dimension for coordinate chunks)
        y_chunks = data_base.data.chunks[0]
        x_chunks = data_base.data.chunks[1]

        # Convert coordinate arrays to dask with matching spatial chunks
        lon_coord.data = da.from_array(lon_coord.values, chunks=(y_chunks, x_chunks))
        lat_coord.data = da.from_array(lat_coord.values, chunks=(y_chunks, x_chunks))

    data_base = data_base.assign_coords({"lon": lon_coord, "lat": lat_coord})

    # Transpose to (band, y, x) for 3D processing
    # Note: transpose uses dimension names (y, x), not coordinate names (lon, lat)
    data_xr = data_base.transpose("band", "y", "x")
    return data_xr, "lon", "lat"


def test_correctness(mesh_type, reduction, size, nz=3, benchmark_iters=5, canvas_size=CANVAS_SIZE, backend='numpy'):
    """
    Test that 3D quadmesh matches 2D quadmesh run in a loop, and benchmark performance.

    Args:
        mesh_type: 'raster', 'rectilinear', or 'curvilinear'
        reduction: reduction function (e.g., rd.sum, rd.mean)
        size: grid size
        nz: number of bands
        benchmark_iters: number of iterations for benchmarking (after warmup)
        backend: 'numpy', 'cupy', or 'dask'

    Returns:
        tuple: (passed, time_3d_ms, time_2d_ms, speedup)
    """

    # Use coordinate system from benchmarks/test_quadmesh.py
    west = 3125000.0
    south = 3250000.0
    east = 4250000.0
    north = 4375000.0
    x_range = (west, east)
    y_range = (south, north)

    # Get the appropriate array module for this backend
    array_module = get_array_module(backend)

    # Create test data using the appropriate array module
    data_3d = create_test_data(size, nz, array_module=array_module, seed=42)

    # Create mesh data structure based on type
    if mesh_type == 'raster':
        data_xr, x_name, y_name = create_raster_mesh(data_3d, size, nz)
    elif mesh_type == 'rectilinear':
        data_xr, x_name, y_name = create_rectilinear_mesh(data_3d, size, nz, y_range, seed=42)
    elif mesh_type == 'curvilinear':
        data_xr, x_name, y_name = create_curvilinear_mesh(data_3d, size, nz)
    else:
        raise ValueError(f"Unknown mesh_type: {mesh_type}")

    # Setup canvas (use provided canvas_size)
    cvs = ds.Canvas(plot_width=canvas_size[0], plot_height=canvas_size[1],
                    x_range=x_range, y_range=y_range)

    # Method 1: Run 3D quadmesh directly
    # Warmup run
    result_3d = cvs.quadmesh(data_xr, x=x_name, y=y_name, agg=reduction("test_data"))

    # Benchmark runs
    times_3d = []
    for _ in range(benchmark_iters):
        t0 = time.perf_counter()
        _ = cvs.quadmesh(data_xr, x=x_name, y=y_name, agg=reduction("test_data"))
        t1 = time.perf_counter()
        times_3d.append((t1 - t0) * 1000)  # Convert to ms

    time_3d_ms = np.mean(times_3d)

    # Method 2: Run 2D quadmesh in a loop for each band
    # Function to run 2D loop
    def run_2d_loop():
        results = []
        for z in range(nz):
            # Extract single band
            data_2d = data_xr.isel(band=z)
            result_2d = cvs.quadmesh(data_2d, x=x_name, y=y_name, agg=reduction("test_data"))

            # Get underlying data array (use .data to avoid CuPy->NumPy conversion)
            vals = result_2d.data

            # For dask, need to compute before appending
            if hasattr(vals, 'compute'):
                vals = vals.compute()

            results.append(vals)

        # Stack using appropriate array module
        if backend == 'cupy' and len(results) > 0 and isinstance(results[0], cp.ndarray):
            return cp.stack(results, axis=0)
        else:
            return np.stack(results, axis=0)

    # Warmup run
    result_2d_stacked = run_2d_loop()

    # Benchmark runs
    times_2d = []
    for _ in range(benchmark_iters):
        t0 = time.perf_counter()
        _ = run_2d_loop()
        t1 = time.perf_counter()
        times_2d.append((t1 - t0) * 1000)  # Convert to ms

    time_2d_ms = np.mean(times_2d)

    # Compare results
    speedup = time_2d_ms / time_3d_ms if time_3d_ms > 0 else 0.0

    if result_3d.shape != result_2d_stacked.shape:
        # Report the mismatch unconditionally (no `verbose` flag is defined)
        print("   ❌ FAIL: Shape mismatch!")
        print(f"      3D: {result_3d.shape}")
        print(f"      2D: {result_2d_stacked.shape}")
        return False, time_3d_ms, time_2d_ms, speedup

    # Compare values (accounting for NaN)
    # Get underlying data arrays (use .data to avoid CuPy->NumPy conversion)
    result_3d_vals = result_3d.data
    if backend == "dask" and hasattr(result_3d_vals, 'compute'):
        result_3d_vals = result_3d_vals.compute()

    # Convert to numpy for comparison if needed
    if backend == 'cupy' and isinstance(result_3d_vals, cp.ndarray):
        result_3d_vals = cp.asnumpy(result_3d_vals)
    if backend == 'cupy' and isinstance(result_2d_stacked, cp.ndarray):
        result_2d_vals = cp.asnumpy(result_2d_stacked)
    else:
        result_2d_vals = result_2d_stacked

    # Check NaN locations match
    nan_mask_3d = np.isnan(result_3d_vals)
    nan_mask_2d = np.isnan(result_2d_vals)

    if not np.array_equal(nan_mask_3d, nan_mask_2d):
        return False, time_3d_ms, time_2d_ms, speedup

    # Compare non-NaN values
    valid_mask = ~nan_mask_3d
    diff = np.abs(result_3d_vals[valid_mask] - result_2d_vals[valid_mask])
    max_diff = diff.max() if diff.size > 0 else 0

    # Use relative tolerance for floating point comparison
    atol = 1e-10
    rtol = 1e-10
    passed = np.allclose(result_3d_vals[valid_mask], result_2d_vals[valid_mask],
                         atol=atol, rtol=rtol)
    return passed, time_3d_ms, time_2d_ms, speedup


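def main(args):
    # Test configurations
    nz = args.nz
    test_configs = []

    # --backends defaults to None: fall back to every backend that imported
    if args.backends is None:
        args.backends = ['numpy']
        if da is not None:
            args.backends.append('dask')
        if cp is not None:
            args.backends.append('cupy')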

    # Build test matrix
    for backend in args.backends:
        for mesh_type in args.mesh_types:
            for size in args.sizes:
                for red_name in args.reductions:
                    test_configs.append((mesh_type, getattr(rd, red_name), size, nz, backend))

    # Run tests with progress bar
    results = []
    pbar = tqdm(test_configs, desc="Running tests", leave=False)

    for mesh_type, reduction, size, nz, backend in pbar:
        desc = f"{backend:<10} | {mesh_type:<13} | {size:4d} | {reduction.__name__:8}"
        pbar.set_description(desc)

        passed, time_3d, time_2d, speedup = test_correctness(
            mesh_type, reduction, size, nz,
            benchmark_iters=args.benchmark_iters,
            backend=backend
        )
        results.append((mesh_type, reduction.__name__, size, nz, backend, passed, time_3d, time_2d, speedup))

    pbar.close()

    # Header
    print("| Backend | Type         | Reduction | Size | 3D (ms) | 2D (ms) | Speedup | Status |")
    print("|:--------|:-------------|:----------|-----:|--------:|--------:|--------:|-------:|")

    # Group by backend
    for backend in args.backends:
        backend_results = [r for r in results if r[4] == backend]
        if not backend_results:
            continue

        for mesh_type, red_name, size, nz, bknd, status, time_3d, time_2d, speedup in backend_results:
            status_str = "✅" if status else "❌"
            print(f"| {backend:<7} | {mesh_type:<12} | {red_name:<9} | {size:>4} | {time_3d:>7.2f} | {time_2d:>7.2f} | {speedup:>6.2f}x | {status_str:>5} |")

    return not all(r[5] for r in results)


if __name__ == "__main__":
    import sys

    parser = argparse.ArgumentParser(
        description="Validate 3D quadmesh correctness and benchmark performance"
    )
    parser.add_argument(
        '--mesh-types',
        nargs='+',
        choices=['raster', 'rectilinear', 'curvilinear'],
        default=['raster', 'rectilinear', 'curvilinear'],
        help='Quadmesh types to test (default: all)'
    )
    parser.add_argument(
        '--sizes',
        nargs='+',
        type=int,
        default=list(DATA_SIZES),
        help=f'Data sizes to test (default: {list(DATA_SIZES)})'
    )
    parser.add_argument(
        '--reductions',
        nargs='+',
        choices=['sum', 'mean'],
        default=['sum', 'mean'],
        help='Reductions to test (default: sum mean)'
    )
    parser.add_argument(
        '--backends',
        nargs='+',
        choices=['numpy', 'cupy', 'dask'],
        default=None,
        help='Array backends to test (default: all available)'
    )
    parser.add_argument(
        '--nz',
        type=int,
        default=3,
        help='Number of bands (default: 3)'
    )
    parser.add_argument(
        '--benchmark-iters',
        type=int,
        default=5,
        help='Number of benchmark iterations (default: 5)'
    )

    args = parser.parse_args()

    sys.exit(main(args))

| Backend | Type | Reduction | Size | 3D (ms) | 2D (ms) | Speedup | Status |
|:--------|:-----|:----------|-----:|--------:|--------:|--------:|-------:|
| numpy | raster | sum | 256 | 2.14 | 6.35 | 2.97x | |
| numpy | raster | mean | 256 | 2.00 | 6.11 | 3.06x | |
| numpy | raster | sum | 512 | 1.90 | 6.35 | 3.34x | |
| numpy | raster | mean | 512 | 1.87 | 6.58 | 3.51x | |
| numpy | raster | sum | 1024 | 2.34 | 6.45 | 2.75x | |
| numpy | raster | mean | 1024 | 2.24 | 6.48 | 2.89x | |
| numpy | raster | sum | 2048 | 6.04 | 11.77 | 1.95x | |
| numpy | raster | mean | 2048 | 13.78 | 20.67 | 1.50x | |
| numpy | raster | sum | 4096 | 16.90 | 20.13 | 1.19x | |
| numpy | raster | mean | 4096 | 22.52 | 30.03 | 1.33x | |
| numpy | rectilinear | sum | 256 | 5.02 | 12.42 | 2.47x | |
| numpy | rectilinear | mean | 256 | 15.37 | 20.51 | 1.33x | |
| numpy | rectilinear | sum | 512 | 6.55 | 20.69 | 3.16x | |
| numpy | rectilinear | mean | 512 | 20.16 | 33.97 | 1.69x | |
| numpy | rectilinear | sum | 1024 | 12.90 | 50.04 | 3.88x | |
| numpy | rectilinear | mean | 1024 | 28.83 | 64.50 | 2.24x | |
| numpy | rectilinear | sum | 2048 | 25.86 | 161.17 | 6.23x | |
| numpy | rectilinear | mean | 2048 | 75.88 | 203.67 | 2.68x | |
| numpy | rectilinear | sum | 4096 | 124.79 | 743.84 | 5.96x | |
| numpy | rectilinear | mean | 4096 | 311.92 | 896.31 | 2.87x | |
| numpy | curvilinear | sum | 256 | 18.27 | 36.62 | 2.00x | |
| numpy | curvilinear | mean | 256 | 22.51 | 49.38 | 2.19x | |
| numpy | curvilinear | sum | 512 | 26.33 | 66.01 | 2.51x | |
| numpy | curvilinear | mean | 512 | 26.80 | 70.30 | 2.62x | |
| numpy | curvilinear | sum | 1024 | 42.88 | 125.15 | 2.92x | |
| numpy | curvilinear | mean | 1024 | 52.55 | 129.08 | 2.46x | |
| numpy | curvilinear | sum | 2048 | 131.70 | 348.74 | 2.65x | |
| numpy | curvilinear | mean | 2048 | 140.77 | 380.37 | 2.70x | |
| numpy | curvilinear | sum | 4096 | 443.07 | 1291.16 | 2.91x | |
| numpy | curvilinear | mean | 4096 | 499.67 | 1453.31 | 2.91x | |
| dask | raster | sum | 256 | 10.05 | 17.97 | 1.79x | |
| dask | raster | mean | 256 | 9.68 | 17.92 | 1.85x | |
| dask | raster | sum | 512 | 10.13 | 19.01 | 1.88x | |
| dask | raster | mean | 512 | 9.89 | 18.58 | 1.88x | |
| dask | raster | sum | 1024 | 10.16 | 19.52 | 1.92x | |
| dask | raster | mean | 1024 | 10.49 | 19.72 | 1.88x | |
| dask | raster | sum | 2048 | 50.23 | 66.40 | 1.32x | |
| dask | raster | mean | 2048 | 51.94 | 65.87 | 1.27x | |
| dask | raster | sum | 4096 | 170.34 | 226.65 | 1.33x | |
| dask | raster | mean | 4096 | 144.98 | 186.76 | 1.29x | |
| dask | rectilinear | sum | 256 | 16.78 | 27.44 | 1.64x | |
| dask | rectilinear | mean | 256 | 24.45 | 33.97 | 1.39x | |
| dask | rectilinear | sum | 512 | 18.49 | 35.71 | 1.93x | |
| dask | rectilinear | mean | 512 | 30.05 | 47.96 | 1.60x | |
| dask | rectilinear | sum | 1024 | 20.61 | 62.71 | 3.04x | |
| dask | rectilinear | mean | 1024 | 38.74 | 78.85 | 2.04x | |
| dask | rectilinear | sum | 2048 | 75.24 | 107.07 | 1.42x | |
| dask | rectilinear | mean | 2048 | 63.59 | 111.73 | 1.76x | |
| dask | rectilinear | sum | 4096 | 189.12 | 269.52 | 1.43x | |
| dask | rectilinear | mean | 4096 | 198.30 | 251.29 | 1.27x | |
| dask | curvilinear | sum | 256 | 36.25 | 74.74 | 2.06x | |
| dask | curvilinear | mean | 256 | 39.53 | 80.47 | 2.04x | |
| dask | curvilinear | sum | 512 | 43.01 | 97.47 | 2.27x | |
| dask | curvilinear | mean | 512 | 48.71 | 105.39 | 2.16x | |
| dask | curvilinear | sum | 1024 | 71.30 | 181.37 | 2.54x | |
| dask | curvilinear | mean | 1024 | 79.01 | 194.01 | 2.46x | |
| dask | curvilinear | sum | 2048 | 133.64 | 271.41 | 2.03x | |
| dask | curvilinear | mean | 2048 | 123.59 | 274.05 | 2.22x | |
| dask | curvilinear | sum | 4096 | 396.11 | 773.43 | 1.95x | |
| dask | curvilinear | mean | 4096 | 393.05 | 744.88 | 1.90x | |
| cupy | raster | sum | 256 | 2.28 | 5.35 | 2.35x | |
| cupy | raster | mean | 256 | 2.17 | 5.26 | 2.42x | |
| cupy | raster | sum | 512 | 2.21 | 5.21 | 2.36x | |
| cupy | raster | mean | 512 | 2.20 | 5.46 | 2.49x | |
| cupy | raster | sum | 1024 | 2.31 | 5.22 | 2.26x | |
| cupy | raster | mean | 1024 | 2.29 | 5.23 | 2.29x | |
| cupy | raster | sum | 2048 | 3.40 | 7.23 | 2.13x | |
| cupy | raster | mean | 2048 | 4.47 | 8.23 | 1.84x | |
| cupy | raster | sum | 4096 | 3.80 | 7.85 | 2.07x | |
| cupy | raster | mean | 4096 | 4.81 | 8.52 | 1.77x | |
| cupy | rectilinear | sum | 256 | 6.16 | 14.58 | 2.37x | |
| cupy | rectilinear | mean | 256 | 7.39 | 16.13 | 2.18x | |
| cupy | rectilinear | sum | 512 | 6.08 | 14.47 | 2.38x | |
| cupy | rectilinear | mean | 512 | 7.01 | 15.77 | 2.25x | |
| cupy | rectilinear | sum | 1024 | 6.19 | 14.47 | 2.34x | |
| cupy | rectilinear | mean | 1024 | 7.10 | 15.89 | 2.24x | |
| cupy | rectilinear | sum | 2048 | 6.24 | 14.76 | 2.37x | |
| cupy | rectilinear | mean | 2048 | 7.21 | 16.20 | 2.25x | |
| cupy | rectilinear | sum | 4096 | 7.12 | 16.17 | 2.27x | |
| cupy | rectilinear | mean | 4096 | 8.65 | 18.00 | 2.08x | |
| cupy | curvilinear | sum | 256 | 6.69 | 15.71 | 2.35x | |
| cupy | curvilinear | mean | 256 | 7.66 | 17.18 | 2.24x | |
| cupy | curvilinear | sum | 512 | 9.64 | 31.65 | 3.28x | |
| cupy | curvilinear | mean | 512 | 10.86 | 26.49 | 2.44x | |
| cupy | curvilinear | sum | 1024 | 13.93 | 33.20 | 2.38x | |
| cupy | curvilinear | mean | 1024 | 14.16 | 36.97 | 2.61x | |
| cupy | curvilinear | sum | 2048 | 33.14 | 87.73 | 2.65x | |
| cupy | curvilinear | mean | 2048 | 33.58 | 97.35 | 2.90x | |
| cupy | curvilinear | sum | 4096 | 98.82 | 271.52 | 2.75x | |
| cupy | curvilinear | mean | 4096 | 101.95 | 298.14 | 2.92x | |
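
The Speedup column is the ratio of the two timings, as computed in the script (speedup = time_2d_ms / time_3d_ms). The short check below recomputes it for three rows copied from the table above. Because the printed timings are already rounded to two decimals, a recomputed ratio can differ from the printed one in the last digit.

```python
def speedup(time_3d_ms, time_2d_ms):
    """Ratio used in the benchmark script: 2D-loop time over 3D time."""
    return time_2d_ms / time_3d_ms if time_3d_ms > 0 else 0.0

# (3D ms, 2D ms, printed speedup) copied from the results table
rows = [
    (2.14, 6.35, 2.97),       # numpy raster sum 256
    (124.79, 743.84, 5.96),   # numpy rectilinear sum 4096
    (98.82, 271.52, 2.75),    # cupy curvilinear sum 4096
]

recomputed = [round(speedup(t3, t2), 2) for t3, t2, _ in rows]
```

A value above 1 means the bandwise call was faster than looping over the bands; a value close to the number of bands (3 here) means the per-call overhead was amortised almost completely.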

GPU Test

image

@hoxbro hoxbro added this to the v0.19.0 milestone Dec 12, 2025

codspeed-hq bot commented Dec 12, 2025

CodSpeed Performance Report

Merging this PR will improve performance by 11.51%

Comparing feat_3d_quadmesh (4021bed) with main (5f5bbee) [1]

Summary

⚡ 1 improved benchmark
✅ 51 untouched benchmarks

Performance Changes

| Mode | Benchmark | BASE | HEAD | Efficiency |
|:-----|:----------|-----:|-----:|-----------:|
| Simulation | test_dask_raster[8192] | 3.9 s | 3.5 s | +11.51% |

Footnotes

  1. No successful run was found on main (9edecf4) during the generation of this report, so 5f5bbee was used instead as the comparison base. There might be some changes unrelated to this pull request in this report.


codecov bot commented Dec 12, 2025

Codecov Report

❌ Patch coverage is 84.95935% with 37 lines in your changes missing coverage. Please review.
✅ Project coverage is 88.38%. Comparing base (f641b57) to head (4021bed).
⚠️ Report is 8 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|:-------------------------|--------:|:------|
| datashader/glyphs/quadmesh.py | 76.03% | 29 Missing ⚠️ |
| datashader/core.py | 0.00% | 3 Missing ⚠️ |
| datashader/data_libraries/xarray.py | 93.93% | 2 Missing ⚠️ |
| datashader/tests/test_quadmesh.py | 95.34% | 2 Missing ⚠️ |
| datashader/tests/test_macros.py | 83.33% | 1 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1472      +/-   ##
==========================================
- Coverage   88.43%   88.38%   -0.05%     
==========================================
  Files          96       97       +1     
  Lines       19183    19432     +249     
==========================================
+ Hits        16964    17175     +211     
- Misses       2219     2257      +38     


@jbednar jbednar changed the title from "enh: Make quadmesh support 3D" to "enh: Make quadmesh support RGB" Dec 13, 2025
@hoxbro hoxbro marked this pull request as draft January 9, 2026 13:58
@hoxbro (Member, Author) commented Jan 9, 2026

I have pushed some changes that avoid doing a basic loop.

This still needs some cleanup and a look at the code with fresher eyes. Also, some basic support for dask arrays and cupy is likely needed...

Roughly what has been done so far (I will try to write more details in the original post closer to this PR being mergeable):

  1. Make xarray dispatch account for 3D data.
  2. Make a 3D function converter for the 2D CPU functions (and cache it). This basically just loops over the bands using prange, which is handled at a lower level than before. I initially played around with the expand_varargs macro, but found it to have a lot of duplicated code.
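
A minimal sketch of the band loop described in step 2, under stated assumptions: `block_sum_bands` is a hypothetical stand-in kernel (a simple block-sum downsample), not one of the datashader quadmesh kernels, and the import falls back to a serial pure-Python loop when Numba is not installed.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:
    # Fallback so the sketch still runs without Numba (serial, uncompiled).
    prange = range

    def njit(*args, **kwargs):
        def decorator(func):
            return func
        return decorator


@njit(parallel=True)
def block_sum_bands(values, out, factor):
    """Downsample every band by summing factor x factor blocks."""
    nbands, ny, nx = values.shape
    for band in prange(nbands):  # each band writes only to out[band]
        for i in range(ny):
            for j in range(nx):
                out[band, i // factor, j // factor] += values[band, i, j]


values = np.arange(3 * 4 * 4, dtype=np.float64).reshape(3, 4, 4)

# All bands in one call.
out_3d = np.zeros((3, 2, 2))
block_sum_bands(values, out_3d, 2)

# Reference: the same kernel applied to one band at a time.
out_loop = np.zeros((3, 2, 2))
for band in range(3):
    block_sum_bands(values[band:band + 1], out_loop[band:band + 1], 2)
```

Since the bands never share an output cell, the outer loop can run in parallel without atomics, and the two results should agree exactly.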

Some initial benchmarks follow. I haven't really reviewed the benchmark code, as it is AI-generated. I also need to double-check that the tests actually go down the right code path.

Benchmark script

"""
Validate that 3D quadmesh produces identical results to running 2D quadmesh
in a loop over each band, and benchmark the performance difference.

Uses the same sizes and setup as datashader/tests/benchmarks/test_quadmesh.py
"""
import numpy as np
import xarray as xr
import datashader as ds
from datashader.reductions import sum as ds_sum, mean, count
import time

# Sizes from benchmarks/test_quadmesh.py
DATA_SIZES = (256, 512, 1024, 2048, 4096, 8192)
CANVAS_SIZE = (1024, 1024)

# Test different quadmesh types and reductions
def test_correctness(mesh_type, reduction, size, nz=3, benchmark_iters=5, canvas_size=CANVAS_SIZE):
    """
    Test that 3D quadmesh matches 2D quadmesh run in a loop, and benchmark performance.

    Args:
        mesh_type: 'raster', 'rectilinear', or 'curvilinear'
        reduction: reduction function (e.g., ds_sum, mean)
        size: grid size
        nz: number of bands
        benchmark_iters: number of iterations for benchmarking (after warmup)

    Returns:
        tuple: (passed, time_3d_ms, time_2d_ms, speedup)
    """
    print(f"\n{'='*60}")
    print(f"Testing {mesh_type} quadmesh with {reduction.__name__} reduction")
    print(f"Data size: {nz} bands × {size}×{size}, Canvas: {canvas_size[0]}×{canvas_size[1]}")
    print(f"{'='*60}")

    # Use coordinate system from benchmarks/test_quadmesh.py
    west = 3125000.0
    south = 3250000.0
    east = 4250000.0
    north = 4375000.0
    x_range = (west, east)
    y_range = (south, north)

    # Create test data with values that make it easy to verify correctness
    # Each band has a distinct range: band z has values in [z*100, z*100 + 100)
    rng = np.random.default_rng(seed=42)  # Fixed seed for reproducibility
    data_3d = np.zeros((nz, size, size))
    for z in range(nz):
        data_3d[z] = rng.random((size, size)) * 100 + z * 100

    if mesh_type == 'raster':
        # Evenly spaced coordinates (matches benchmark setup)
        lon_coords = np.linspace(3123580.0, 4250380.0, size)
        lat_coords = np.linspace(4376200.0, 3249400.0, size)

        data_xr = xr.DataArray(
            data_3d,
            dims=("band", "y", "x"),
            coords={
                "lon": ("x", lon_coords),
                "lat": ("y", lat_coords),
                "band": list(range(nz)),
            },
            name="test_data",
        )
        # Swap dims for raster (matches benchmark)
        data_xr = data_xr.swap_dims({"y": "lat", "x": "lon"})
        x_name, y_name = "lon", "lat"

    elif mesh_type == 'rectilinear':
        # Non-uniformly spaced 1D coordinates (matches benchmark setup)
        lon_coords = np.linspace(3123580.0, 4250380.0, size)
        lat_coords = np.linspace(4376200.0, 3249400.0, size)

        # Add random deltas to make it non-uniform (matches benchmark)
        dy = (y_range[1] - y_range[0]) / size
        deltas = rng.uniform(-dy/2, dy/2, size)
        lat_coords = lat_coords + deltas

        data_xr = xr.DataArray(
            data_3d,
            dims=("band", "y", "x"),
            coords={
                "lon": ("x", lon_coords),
                "lat": ("y", lat_coords),
                "band": list(range(nz)),
            },
            name="test_data",
        )
        # Swap dims for rectilinear (matches benchmark)
        data_xr = data_xr.swap_dims({"y": "lat", "x": "lon"})
        x_name, y_name = "lon", "lat"

    elif mesh_type == 'curvilinear':
        # 2D coordinate arrays (matches benchmark setup with broadcast)
        lon_1d = np.linspace(3123580.0, 4250380.0, size)
        lat_1d = np.linspace(4376200.0, 3249400.0, size)

        # Create base DataArray with dims (y, x, band) to match test setup
        data_base = xr.DataArray(
            data_3d.transpose(1, 2, 0),  # Transpose from (nz, size, size) to (size, size, nz)
            dims=("y", "x", "band"),
            coords={
                "x": lon_1d,
                "y": lat_1d,
                "band": list(range(nz)),
            },
            name="test_data",
        )

        # Broadcast to create 2D coordinate arrays (matches benchmark)
        lon_coord, lat_coord = xr.broadcast(data_base.x, data_base.y)
        data_base = data_base.assign_coords({"lon": lon_coord, "lat": lat_coord})

        # Transpose to (band, y, x) for 3D processing
        data_xr = data_base.transpose(..., "y", "x")
        x_name, y_name = "lon", "lat"
    else:
        raise ValueError(f"Unknown mesh_type: {mesh_type}")

    # Setup canvas (use provided canvas_size)
    cvs = ds.Canvas(plot_width=canvas_size[0], plot_height=canvas_size[1],
                    x_range=x_range, y_range=y_range)

    # Method 1: Run 3D quadmesh directly
    print("\n1. Running 3D quadmesh (optimized)...")

    # Warmup run
    result_3d = cvs.quadmesh(data_xr, x=x_name, y=y_name, agg=reduction("test_data"))
    print(f"   Result shape: {result_3d.shape}")

    # Benchmark runs
    print(f"   Benchmarking ({benchmark_iters} iterations)...")
    times_3d = []
    for _ in range(benchmark_iters):
        t0 = time.perf_counter()
        _ = cvs.quadmesh(data_xr, x=x_name, y=y_name, agg=reduction("test_data"))
        t1 = time.perf_counter()
        times_3d.append((t1 - t0) * 1000)  # Convert to ms

    time_3d_ms = np.mean(times_3d)
    print(f"   Average time: {time_3d_ms:.3f} ms (±{np.std(times_3d):.3f} ms)")

    # Method 2: Run 2D quadmesh in a loop for each band
    print("\n2. Running 2D quadmesh in loop (reference)...")

    # Function to run 2D loop
    def run_2d_loop():
        results = []
        for z in range(nz):
            # Extract single band
            # For curvilinear, slice from transposed data to ensure lon/lat coords have consistent dims
            data_2d = data_xr.isel(band=z)
            result_2d = cvs.quadmesh(data_2d, x=x_name, y=y_name, agg=reduction("test_data"))
            results.append(result_2d.values)
        return np.stack(results, axis=0)

    # Warmup run
    result_2d_stacked = run_2d_loop()
    print(f"   Stacked shape: {result_2d_stacked.shape}")

    # Benchmark runs
    print(f"   Benchmarking ({benchmark_iters} iterations)...")
    times_2d = []
    for _ in range(benchmark_iters):
        t0 = time.perf_counter()
        _ = run_2d_loop()
        t1 = time.perf_counter()
        times_2d.append((t1 - t0) * 1000)  # Convert to ms

    time_2d_ms = np.mean(times_2d)
    print(f"   Average time: {time_2d_ms:.3f} ms (±{np.std(times_2d):.3f} ms)")

    # Compare results
    print("\n3. Comparing results...")
    speedup = time_2d_ms / time_3d_ms if time_3d_ms > 0 else 0.0

    if result_3d.shape != result_2d_stacked.shape:
        print(f"   ❌ FAIL: Shape mismatch!")
        print(f"      3D: {result_3d.shape}")
        print(f"      2D: {result_2d_stacked.shape}")
        return False, time_3d_ms, time_2d_ms, speedup

    # Compare values (accounting for NaN)
    result_3d_vals = result_3d.values
    result_2d_vals = result_2d_stacked

    # Check NaN locations match
    nan_mask_3d = np.isnan(result_3d_vals)
    nan_mask_2d = np.isnan(result_2d_vals)

    if not np.array_equal(nan_mask_3d, nan_mask_2d):
        print(f"   ❌ FAIL: NaN locations don't match!")
        print(f"      3D NaN count: {nan_mask_3d.sum()}")
        print(f"      2D NaN count: {nan_mask_2d.sum()}")
        return False, time_3d_ms, time_2d_ms, speedup

    # Compare non-NaN values
    valid_mask = ~nan_mask_3d
    diff = np.abs(result_3d_vals[valid_mask] - result_2d_vals[valid_mask])
    max_diff = diff.max() if diff.size > 0 else 0

    # Use relative tolerance for floating point comparison
    atol = 1e-10
    rtol = 1e-10
    close = np.allclose(result_3d_vals[valid_mask], result_2d_vals[valid_mask],
                       atol=atol, rtol=rtol)

    if close:
        print(f"   ✅ PASS: Results match perfectly!")
        print(f"      Max absolute difference: {max_diff:.2e}")
        print(f"      Valid pixels: {valid_mask.sum()}/{valid_mask.size}")
        print(f"\n4. Performance comparison:")
        print(f"   3D optimized: {time_3d_ms:.3f} ms")
        print(f"   2D loop:      {time_2d_ms:.3f} ms")
        print(f"   Speedup:      {speedup:.2f}x")
        return True, time_3d_ms, time_2d_ms, speedup
    else:
        print(f"   ❌ FAIL: Results don't match!")
        print(f"      Max absolute difference: {max_diff:.2e}")
        print(f"      Valid pixels: {valid_mask.sum()}/{valid_mask.size}")

        # Show some examples of differences
        diff_locs = np.where(diff > atol + rtol * np.abs(result_2d_vals[valid_mask]))
        if len(diff_locs[0]) > 0:
            print(f"      First 5 mismatches:")
            for idx in range(min(5, len(diff_locs[0]))):
                i = diff_locs[0][idx]
                print(f"        3D: {result_3d_vals[valid_mask][i]:.6f}, "
                      f"2D: {result_2d_vals[valid_mask][i]:.6f}, "
                      f"diff: {diff[i]:.2e}")
        return False, time_3d_ms, time_2d_ms, speedup


def main():
    """Run all validation tests."""
    print("\n" + "="*80)
    print("3D QUADMESH CORRECTNESS VALIDATION & BENCHMARKS")
    print("="*80)
    print("\nThis validates that the optimized 3D quadmesh (which computes")
    print("coordinates once and loops over bands in numba) produces identical")
    print("results to running 2D quadmesh separately for each band.")
    print(f"\nUsing benchmark sizes from test_quadmesh.py: {DATA_SIZES}")
    print(f"Canvas size: {CANVAS_SIZE}")

    # Test configurations organized by quadmesh type
    # Use 3 bands to simulate RGB data (the main use case)
    nz = 3
    test_configs = []

    # All quadmesh types now support 3D
    for mesh_type in ['raster', 'rectilinear', 'curvilinear']:
        for size in DATA_SIZES:
            # Test with sum (simple reduction)
            test_configs.append((mesh_type, ds_sum, size, nz))
            test_configs.append((mesh_type, mean, size, nz))

    results = []
    for mesh_type, reduction, size, nz in test_configs:
        passed, time_3d, time_2d, speedup = test_correctness(mesh_type, reduction, size, nz)
        results.append((mesh_type, reduction.__name__, size, nz, passed, time_3d, time_2d, speedup))

    # Summary - organize by quadmesh type
    print("\n" + "="*80)
    print("SUMMARY - ORGANIZED BY QUADMESH TYPE")
    print("="*80)
    total = len(results)
    passed_count = sum(1 for r in results if r[4])

    # Group results by mesh type
    for mesh_type in ['raster', 'rectilinear', 'curvilinear']:
        type_results = [r for r in results if r[0] == mesh_type]
        if not type_results:
            continue

        print(f"\n{mesh_type.upper()} QUADMESH:")
        print(f"{'Size':<12} {'Reduction':<10} {'Status':<10} {'3D (ms)':<12} {'2D (ms)':<12} {'Speedup':<10}")
        print("-" * 80)

        for _, red_name, size, nz, status, time_3d, time_2d, speedup in type_results:
            status_str = "✅ PASS" if status else "❌ FAIL"
            size_str = f"{size}×{size}"
            print(f"{size_str:<12} {red_name:<10} {status_str:<10} {time_3d:>10.1f}   {time_2d:>10.1f}   {speedup:>8.2f}x")

        # Calculate average speedup for this type
        type_speedups = [speedup for _, _, _, _, status, _, _, speedup in type_results if status]
        if type_speedups:
            avg_speedup = np.mean(type_speedups)
            print(f"\n  Average speedup for {mesh_type}: {avg_speedup:.2f}x")

    print("\n" + "="*80)
    print(f"Total: {passed_count}/{total} tests passed")

    # Calculate overall average speedup for passed tests
    if passed_count > 0:
        avg_speedup = np.mean([speedup for _, _, _, _, status, _, _, speedup in results if status])
        print(f"Overall average speedup: {avg_speedup:.2f}x")

    if passed_count == total:
        print("\n🎉 All tests passed! 3D quadmesh optimization is working correctly.")
        return 0
    else:
        print(f"\n⚠️  {total - passed_count} test(s) failed!")
        return 1


if __name__ == "__main__":
    import sys
    sys.exit(main())
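
The mismatch report in the script flags elements where `diff > atol + rtol * |2D value|`. Assuming `diff` is the absolute difference between the 3D and 2D results, that is the complement of the `np.isclose` criterion with the 2D result as the reference. A minimal sketch with made-up values and assumed tolerances (the script's own `atol` and `rtol` are defined earlier and are not repeated here):

```python
import numpy as np

# Assumed tolerances for illustration only; not taken from the script.
rtol, atol = 1e-5, 1e-8

ref = np.array([1.0, 100.0, 0.0, 2.5])                   # stands in for the 2D-loop result
got = np.array([1.0 + 5e-6, 100.0 + 2e-3, 1e-9, 2.5])    # stands in for the 3D result

diff = np.abs(got - ref)
# Same test as in the script: the allowed error grows with the reference magnitude.
mismatch = diff > atol + rtol * np.abs(ref)

# np.isclose(a, b) checks |a - b| <= atol + rtol * |b|, so it is the exact complement.
same_as_isclose = np.array_equal(mismatch, ~np.isclose(got, ref, rtol=rtol, atol=atol))
```

Only the second element is flagged here: its error of 2e-3 exceeds the allowed 1e-3 + 1e-8, while the others stay within tolerance.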

================================================================================
SUMMARY - ORGANIZED BY QUADMESH TYPE
================================================================================

RASTER QUADMESH:
Size         Reduction  Status     3D (ms)      2D (ms)      Speedup   
--------------------------------------------------------------------------------
256×256      sum        ✅ PASS            7.6          7.1       0.93x
256×256      mean       ✅ PASS            6.8          8.3       1.23x
512×512      sum        ✅ PASS            6.6          6.8       1.03x
512×512      mean       ✅ PASS            6.7          8.5       1.27x
1024×1024    sum        ✅ PASS            6.6          7.8       1.17x
1024×1024    mean       ✅ PASS            9.2          7.5       0.81x
2048×2048    sum        ✅ PASS           11.3         10.4       0.92x
2048×2048    mean       ✅ PASS           19.2         20.4       1.07x
4096×4096    sum        ✅ PASS           18.8         21.0       1.12x
4096×4096    mean       ✅ PASS           27.2         29.0       1.07x
8192×8192    sum        ✅ PASS           55.5         59.4       1.07x
8192×8192    mean       ✅ PASS           63.1         70.1       1.11x

  Average speedup for raster: 1.06x

RECTILINEAR QUADMESH:
Size         Reduction  Status     3D (ms)      2D (ms)      Speedup   
--------------------------------------------------------------------------------
256×256      sum        ✅ PASS            5.2         13.2       2.53x
256×256      mean       ✅ PASS           15.2         21.1       1.39x
512×512      sum        ✅ PASS            6.7         20.8       3.12x
512×512      mean       ✅ PASS           22.7         33.9       1.50x
1024×1024    sum        ✅ PASS           10.6         48.4       4.56x
1024×1024    mean       ✅ PASS           29.2         63.2       2.17x
2048×2048    sum        ✅ PASS           25.7        159.9       6.22x
2048×2048    mean       ✅ PASS           78.8        205.0       2.60x
4096×4096    sum        ✅ PASS          124.9        748.5       5.99x
4096×4096    mean       ✅ PASS          315.3        923.5       2.93x
8192×8192    sum        ✅ PASS          971.0       3377.3       3.48x
8192×8192    mean       ✅ PASS         1296.6       3784.4       2.92x

  Average speedup for rectilinear: 3.28x

CURVILINEAR QUADMESH:
Size         Reduction  Status     3D (ms)      2D (ms)      Speedup   
--------------------------------------------------------------------------------
256×256      sum        ✅ PASS           20.8         44.7       2.15x
256×256      mean       ✅ PASS           25.4         46.7       1.84x
512×512      sum        ✅ PASS           20.2         61.3       3.03x
512×512      mean       ✅ PASS           26.8         72.1       2.69x
1024×1024    sum        ✅ PASS           39.8        128.9       3.24x
1024×1024    mean       ✅ PASS           56.1        147.7       2.63x
2048×2048    sum        ✅ PASS          126.8        360.4       2.84x
2048×2048    mean       ✅ PASS          147.4        398.4       2.70x
4096×4096    sum        ✅ PASS          449.0       1298.1       2.89x
4096×4096    mean       ✅ PASS          502.3       1457.6       2.90x
8192×8192    sum        ✅ PASS         1714.0       5055.6       2.95x
8192×8192    mean       ✅ PASS         1931.7       5653.7       2.93x

  Average speedup for curvilinear: 2.73x

================================================================================
Total: 36/36 tests passed
Overall average speedup: 2.36x

🎉 All tests passed! 3D quadmesh optimization is working correctly.
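
The "Average speedup" lines are the arithmetic mean of the per-case speedups (`np.mean` in the script), so a 256×256 case counts as much as an 8192×8192 one. As a rough cross-check, the sketch below recomputes that mean for the rectilinear rows from the rounded table values and compares it with the ratio of total times, which is dominated by the large grids. The variable names are illustrative, and the rounded inputs mean the result can differ slightly from the script's own figure.

```python
import numpy as np

# Rectilinear rows from the table above, in order (ms, rounded as printed).
t3d = np.array([5.2, 15.2, 6.7, 22.7, 10.6, 29.2,
                25.7, 78.8, 124.9, 315.3, 971.0, 1296.6])
t2d = np.array([13.2, 21.1, 20.8, 33.9, 48.4, 63.2,
                159.9, 205.0, 748.5, 923.5, 3377.3, 3784.4])

# What the script reports: every case weighted equally.
mean_of_ratios = np.mean(t2d / t3d)

# Alternative view: total wall time of the 2D loop over total 3D time.
ratio_of_totals = t2d.sum() / t3d.sum()

print(f"mean of ratios:  {mean_of_ratios:.2f}x")
print(f"ratio of totals: {ratio_of_totals:.2f}x")
```

Both figures land a little above 3x for this table, so the conclusion does not depend much on which average is used.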

hoxbro and others added 6 commits January 10, 2026 10:41
Implement GPU parallelization for 3D quadmesh using CUDA streams,
mirroring the CPU factory pattern with prange.

Changes:
- Add _CUDAStreamPool class for managing reusable CUDA streams
- Add _make_3d_from_2d_cuda() factory that returns factory_3d(grid_shape),
  matching the CPU pattern where runtime parameters (n_arrays for CPU,
  grid_shape for GPU) are passed at usage site
- Apply to all three quadmesh types:
  * QuadMeshRaster: upsample_cuda_3d and downsample_cuda_3d
  * QuadMeshRectilinear: extend_cuda_3d
  * QuadMeshCurvilinear: extend_cuda_3d
- Fix CuPy compatibility:
  * Use .data instead of .values to preserve CuPy arrays
  * Use xp.clip() instead of np.clip() for array module compatibility
- Remove NotImplementedError for 3D CUDA support
- Z-slices execute in parallel across CUDA streams (up to 16 concurrent)

Implementation pattern:
  CPU:  do_extend = extend_cpu_3d(n_arrays=len(aggs_and_cols))
  GPU:  do_extend = extend_cuda_3d(grid_shape=(grid_w, grid_h))
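
As an illustration of the CPU side of this pattern, here is a minimal sketch of a factory that turns a 2D kernel into a per-band 3D wrapper for a given number of aggregate arrays. It is plain Python with no numba compilation or `prange`, and `make_3d_from_2d`, `sum_and_count_2d` and the generated `extend_3d` are hypothetical names, not the functions in this PR.

```python
import numpy as np


def make_3d_from_2d(func_2d):
    """Hypothetical factory: wrap a 2D kernel so it runs once per band."""
    cache = {}

    def factory_3d(n_arrays):
        # Reuse the generated wrapper for a given number of aggregate arrays.
        if n_arrays in cache:
            return cache[n_arrays]

        # Generate explicit argument names so the wrapper has a fixed
        # signature and each aggregate array is sliced per band.
        names = [f"agg{i}" for i in range(n_arrays)]
        sliced = ", ".join(f"{name}[band]" for name in names)
        src = (
            f"def extend_3d(values, {', '.join(names)}):\n"
            f"    for band in range(values.shape[0]):\n"
            f"        func_2d(values[band], {sliced})\n"
        )
        namespace = {"func_2d": func_2d}
        exec(src, namespace)
        cache[n_arrays] = namespace["extend_3d"]
        return cache[n_arrays]

    return factory_3d


def sum_and_count_2d(values, total, count):
    """Toy 2D kernel: accumulate a sum and a count of finite values in place."""
    finite = np.isfinite(values)
    total += np.where(finite, values, 0.0)
    count += finite


values = np.arange(24, dtype=np.float64).reshape(3, 2, 4)  # (band, y, x)
values[1, 0, 0] = np.nan
total = np.zeros_like(values)
count = np.zeros(values.shape, dtype=np.int64)

# A mean-like reduction needs two aggregate arrays, a sum-like one only one.
extend_3d = make_3d_from_2d(sum_and_count_2d)(n_arrays=2)
extend_3d(values, total, count)
```

Each band is processed independently, so the result for band `k` matches calling the 2D kernel on `values[k]` alone.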

Test changes:
- Fix rectilinear coord generation to use xp.array() for CuPy
- Use close=True for floating-point comparison in downsample cases

Results:
- All GPU tests pass: 6/6 ✅
- All CPU tests pass: 12/12 ✅ (no regression)
- Parallel execution of independent z-slices on GPU

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
@hoxbro
Member Author

hoxbro commented Jan 10, 2026

Got (a lot of) help from Claude to add GPU support:

Tests run locally:
image

@hoxbro hoxbro marked this pull request as ready for review January 12, 2026 10:39
dask_glyph_dispatch = Dispatcher()


def _flatten_dask_keys(keys_array):
Contributor

Could use or vendor dask.base.flatten (I think that's the import path).

Member Author

Done in 6ddb6f7.

As I may have overlooked some nuances, you are welcome to review or test the PR.
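
For readers unfamiliar with the helper under discussion, the following is a minimal pure-Python sketch of what flattening nested dask key lists means. It is not dask's implementation and not the code from this PR; `flatten_keys` is a hypothetical name, and it assumes keys are tuples nested inside lists.

```python
def flatten_keys(nested):
    """Yield keys from arbitrarily nested lists, leaving tuple keys intact."""
    if not isinstance(nested, list):
        # Anything that is not a list (for example a tuple key) is a leaf.
        yield nested
        return
    for item in nested:
        yield from flatten_keys(item)


# Keys for a hypothetical 2x2 chunked array named "x".
keys = [[("x", 0, 0), ("x", 0, 1)], [("x", 1, 0), ("x", 1, 1)]]
flat = list(flatten_keys(keys))
```

Only lists are descended into, so each tuple key comes out whole rather than being split into its name and chunk indices.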

@hoxbro hoxbro changed the title from "enh: Make quadmesh support RGB" to "enh: Make quadmesh support bandwise 2D" Jan 13, 2026
@hoxbro hoxbro changed the title from "enh: Make quadmesh support bandwise 2D" to "feat: Make quadmesh support bandwise 2D" Jan 14, 2026

This is the GPU equivalent of _make_3d_from_2d, creating wrappers that launch
2D CUDA kernels in parallel streams for each z-slice. This achieves true
parallelism on GPU, similar to how prange provides parallelism on CPU.
Member

Neat!

Member

@philippjfr philippjfr left a comment

Some clarification questions but as far as I'm able to tell this looks great. Checking the 3D implementation against the 2D implementation is fine as long as the existing test coverage for the 2D case is good (which I didn't confirm).

@hoxbro hoxbro enabled auto-merge (squash) February 6, 2026 13:01
@hoxbro hoxbro disabled auto-merge February 6, 2026 13:01
@hoxbro hoxbro enabled auto-merge (squash) February 6, 2026 13:01
@hoxbro hoxbro merged commit c18c685 into main Feb 6, 2026
14 checks passed
@hoxbro hoxbro deleted the feat_3d_quadmesh branch February 6, 2026 13:40
Development

Successfully merging this pull request may close these issues.

Quadmesh with 3D (RGB) array

3 participants