Skip to content

Commit 9138bbe

Browse files
jpan84philipc2Copilot
authored
Add Azimuthal Averaging (#1255)
* created docstring for azimuthal_mean * wrote azimuthal mean computation in public function * draft azimuthal mean ready to test * fixed typos, return hit count for radial bins * made azimuthal mean more robust to axis ordering * added central coord as attribute in output UxDataArray of azimuthal_mean * run pre-commit * update parameters and set default values to nan * add tests, update to use kdtree * update api * update type hint * Update docs/api.rst Co-authored-by: Copilot <[email protected]> * Update docs/api.rst Co-authored-by: Copilot <[email protected]> * Update test/core/test_azimuthal.py Co-authored-by: Copilot <[email protected]> * return an xr.DataArray, update hit counts * add validation * fix api.rst --------- Co-authored-by: Philip Chmielowiec <[email protected]> Co-authored-by: Philip Chmielowiec <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 71b858f commit 9138bbe

File tree

3 files changed

+200
-0
lines changed

3 files changed

+200
-0
lines changed

docs/api.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,17 @@ on each face.
453453
UxDataArray.topological_all
454454
UxDataArray.topological_any
455455

456+
Azimuthal
457+
~~~~~~~~~
458+
459+
Azimuthal aggregations apply an aggregation (i.e. averaging) along circles of constant great-circle distance from a specified point on the sphere.
460+
461+
462+
.. autosummary::
463+
:toctree: generated/
464+
465+
UxDataArray.azimuthal_mean
466+
456467
Zonal Average
457468
~~~~~~~~~~~~~
458469
.. autosummary::

test/core/test_azimuthal.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import pytest
2+
import uxarray as ux
3+
import numpy as np
4+
5+
6+
7+
def test_gaussian(gridpath, datasetpath):
8+
uxds = ux.open_dataset(
9+
gridpath("mpas", "dyamond-30km", "gradient_grid_subset.nc"),
10+
datasetpath("mpas", "dyamond-30km", "gradient_data_subset.nc")
11+
)
12+
13+
res = uxds['gaussian'].azimuthal_mean(center_coord=(45, 0), outer_radius=2, radius_step=0.5)
14+
15+
# Expects decreasing values from center
16+
valid_vals = res[1:]
17+
18+
np.testing.assert_array_less(
19+
valid_vals.diff("radius").values, 1e-12
20+
)
21+
22+
23+
24+
def test_inverse_gaussian(gridpath, datasetpath):
25+
26+
uxds = ux.open_dataset(
27+
gridpath("mpas", "dyamond-30km", "gradient_grid_subset.nc"),
28+
datasetpath("mpas", "dyamond-30km", "gradient_data_subset.nc")
29+
)
30+
31+
res = uxds['inverse_gaussian'].azimuthal_mean(center_coord=(45, 0), outer_radius=2, radius_step=0.5)
32+
33+
# Expects increasing values from center
34+
atol = 1e-12
35+
diffs = res[1:].diff("radius").values
36+
diffs = diffs[np.isfinite(diffs)]
37+
np.testing.assert_array_less(-atol, diffs)
38+
39+
def test_non_zero_hit_counts(gridpath, datasetpath):
40+
uxds = ux.open_dataset(
41+
gridpath("mpas", "dyamond-30km", "gradient_grid_subset.nc"),
42+
datasetpath("mpas", "dyamond-30km", "gradient_data_subset.nc")
43+
)
44+
45+
res, hit_counts = uxds['inverse_gaussian'].azimuthal_mean(center_coord=(45, 0), outer_radius=2, radius_step=0.5, return_hit_counts=True)
46+
47+
# At least one hit after the first circle
48+
assert np.all(hit_counts[1:] > 1)
49+
50+
assert 'radius' in hit_counts.dims
51+
assert hit_counts.sizes['radius'] == res.sizes['radius']
52+
53+
def test_zero_hit_counts(gridpath, datasetpath):
54+
uxds = ux.open_dataset(
55+
gridpath("mpas", "dyamond-30km", "gradient_grid_subset.nc"),
56+
datasetpath("mpas", "dyamond-30km", "gradient_data_subset.nc")
57+
)
58+
59+
# Outside of grid domain
60+
res, hit_counts = uxds['inverse_gaussian'].azimuthal_mean(center_coord=(-45, 0), outer_radius=2, radius_step=0.5, return_hit_counts=True)
61+
62+
assert 'radius' in hit_counts.dims
63+
assert hit_counts.sizes['radius'] == res.sizes['radius']
64+
65+
# No hits
66+
assert np.all(hit_counts == 0)
67+
68+
print(res)

uxarray/core/dataarray.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,127 @@ def zonal_mean(self, lat=(-90, 90, 10), **kwargs):
588588
# Alias for 'zonal_mean', since this name is also commonly used.
589589
zonal_average = zonal_mean
590590

591+
def azimuthal_mean(
592+
self,
593+
center_coord,
594+
outer_radius: int | float,
595+
radius_step: int | float,
596+
return_hit_counts: bool = False,
597+
):
598+
"""Compute averages along circles of constant great-circle distance from a point.
599+
600+
Parameters
601+
----------
602+
center_coord: tuple, list, ndarray
603+
Longitude and latitude of the center of the bounding circle
604+
outer_radius: scalar, int, float
605+
The maximum radius, in great-circle degrees, at which the azimuthal mean will be computed.
606+
radius_step: scalar, int, float
607+
Means will be computed at intervals of `radius_step` on the interval [0, outer_radius]
608+
return_hit_counts: bool, false
609+
Indicates whether to return the number of hits at each radius
610+
611+
Returns
612+
-------
613+
azimuthal_mean: xr.DataArray
614+
Contains a variable with a dimension 'radius' corresponding to the azimuthal average.
615+
hit_counts: xr.DataArray
616+
The number of hits at each radius
617+
618+
619+
Examples
620+
--------
621+
# Range from 0° to 5° at 0.5° intervals, around the central point lon,lat=10,50
622+
>>> az = uxds["var"].azimuthal_mean(
623+
... center_coord=(10, 50), outer_radius=5.0, radius_step=0.5
624+
... )
625+
>>> az.plot(title="Azimuthal Mean")
626+
627+
Notes
628+
-----
629+
Only supported for face-centered data variables. Candidate faces are determined
630+
using bounding circles - for radii = [r1, r2, r3, ...] faces whose centers lie at distance d,
631+
r2 < d <= r3 are included in calculations for r3.
632+
"""
633+
from uxarray.grid.coordinates import _lonlat_rad_to_xyz
634+
635+
if not self._face_centered():
636+
raise ValueError(
637+
"Azimuthal mean computations are currently only supported for face-centered data variables."
638+
)
639+
640+
if outer_radius <= 0:
641+
raise ValueError("Radius must be a positive scalar.")
642+
643+
kdtree = self.uxgrid._get_scipy_kd_tree()
644+
645+
lon_deg, lat_deg = map(float, np.asarray(center_coord))
646+
center_xyz = np.array(
647+
_lonlat_rad_to_xyz(np.deg2rad(lon_deg), np.deg2rad(lat_deg))
648+
)
649+
650+
radii_deg = np.arange(0.0, outer_radius + radius_step, radius_step, dtype=float)
651+
radii_rad = np.deg2rad(radii_deg)
652+
chord_radii = 2.0 * np.sin(radii_rad / 2.0)
653+
654+
faces_processed = np.array([], dtype=np.int_)
655+
means = np.full(
656+
(radii_deg.size, *self.to_xarray().isel(drop=True, n_face=0).shape), np.nan
657+
)
658+
hit_count = np.zeros_like(radii_deg, dtype=np.int_)
659+
660+
for ii, r_chord in enumerate(chord_radii):
661+
# indices of faces within the bounding circle for this radius
662+
within = np.array(
663+
kdtree.query_ball_point(center_xyz, r_chord), dtype=np.int_
664+
)
665+
if within.size:
666+
within.sort()
667+
668+
# include only the new ring: r_(i-1) < d <= r_i
669+
faces_in_bin = np.setdiff1d(within, faces_processed, assume_unique=True)
670+
hit_count[ii] = faces_in_bin.size
671+
672+
if hit_count[ii] == 0:
673+
continue
674+
675+
faces_processed = within # cumulative set for next iteration
676+
677+
tpose = self.isel(n_face=faces_in_bin).transpose(..., "n_face")
678+
means[ii, ...] = tpose.weighted_mean().data
679+
680+
# swap the leading 'radius' axis into the former n_face position
681+
face_axis = self.dims.index("n_face")
682+
dims = list(self.dims)
683+
dims[face_axis] = "radius"
684+
means = np.moveaxis(means, 0, face_axis)
685+
686+
hit_count = xr.DataArray(
687+
data=hit_count, dims="radius", coords={"radius": radii_deg}
688+
)
689+
690+
uxda = xr.DataArray(
691+
means,
692+
dims=dims,
693+
coords={"radius": radii_deg},
694+
name=self.name + "_azimuthal_mean"
695+
if self.name is not None
696+
else "azimuthal_mean",
697+
attrs={
698+
"azimuthal_mean": True,
699+
"center_lon": lon_deg,
700+
"center_lat": lat_deg,
701+
"radius_units": "degrees",
702+
},
703+
)
704+
705+
if return_hit_counts:
706+
return uxda, hit_count
707+
else:
708+
return uxda
709+
710+
azimuthal_average = azimuthal_mean
711+
591712
def weighted_mean(self, weights=None):
592713
"""Computes a weighted mean.
593714

0 commit comments

Comments
 (0)