Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 106 additions & 59 deletions yt/data_objects/profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,6 +1202,44 @@ def sanitize_field_tuple_keys(input_dict, data_source):
return input_dict


def _sanitize_dictarg_required_bin_fields(
input_dict, data_source, bin_fields, arg_name
):
"""
Helper function for create_profile to coerce and check arguments
that expect None or expect Mappings with keys for each bin_field
to a standardized format.

While it would be more efficient to avoid explicitly checking for
the presence of bin_field keys in this function, the cost is
worthwhile for a few reasons:
1. len(bin_fields) is always small and the calculation that this
check accompanies is commonly substantially more expensive than
the check itself
2. it helps us simplify the implementation of create_profile
3. it lets us explicitly document slightly idiosyncratic behavior
(that was poorly documented historically).
"""
if input_dict is None:
return input_dict

tmp = sanitize_field_tuple_keys(input_dict, data_source)
out = {}
for bin_field in bin_fields:
try:
out[bin_field] = tmp[bin_field[-1]]
except KeyError:
try:
out[bin_field] = tmp[bin_field]
except KeyError:
raise ValueError(
f"The {arg_name} argument must be None or a dict with keys for "
f"each bin_field. {arg_name} is missing an entry for {bin_field} "
f"(or equivalently, for {bin_field[-1]})"
) from None
return out


def create_profile(
data_source,
bin_fields,
Expand Down Expand Up @@ -1300,12 +1338,35 @@ def create_profile(
if len(bin_fields) > 1 and isinstance(accumulation, bool):
accumulation = [accumulation for _ in range(len(bin_fields))]

# handle sanitization/sanitization
bin_fields = data_source._determine_fields(bin_fields)
fields = data_source._determine_fields(fields)
units = sanitize_field_tuple_keys(units, data_source)
extrema = sanitize_field_tuple_keys(extrema, data_source)
logs = sanitize_field_tuple_keys(logs, data_source)
override_bins = sanitize_field_tuple_keys(override_bins, data_source)
override_bins = _sanitize_dictarg_required_bin_fields(
override_bins, data_source, bin_fields, arg_name="override_bins"
)

# The following logic explicitly maintains the historical idiosyncratic behavior
# when extrema is a dict and it is missing a key for one of the bin_fields:
# -> this is allowed if `all(bool(v) == False for v in collapse(extrema.values()))`
# -> this is an error in **all** other cases
if extrema is None or not any(collapse(extrema.values())):
# when extrema are all Nones, we treat them as though extrema was set as None.
#
# In the future, it may be better to replace `not any(...)` with either:
# `all(v is None for v in collapse(extrema.values()))` OR
# `all(v is None or v == (None, None) for v in collapse(extrema.values))`
#
# Be aware that doing this changes behavior, when extrema has the value:
# `{<bin_field>: (0., None)}` or `{<bin_field>: (None, 0.)}`
# I happen to think think that the existing behavior is a bug (i.e. the
# behavior changes if we replaced 0. with **ANY** other number)
Comment on lines +1354 to +1364
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to reviewers: This is mostly for posterity. I don't think we should deal with this issue in this PR (since this PR is focused on not changing observable behavior). I plan to make an issue for this after the PR is reviewed


I think we should probably change the if-statement to:

    if extrema is None or all(v is None for v in collapse(extrema.values())):

OR

    if extrema is None or all(list(v) == [None, None] for v in extrema.values()):

As the comments mention, both choices technically change the behavior when one of the extrema is a 0 and the other is a None, but I think the existing behavior is probably a bug. The former option is more backwards compatible while the latter option is more consistent with the rest of our extrema handling.

For more context, the following table whether various choices for the extrema argument are considered valid, when bin_fields is [("gas", "velocity_x"), ("gas", "density")]

Current Logic Snippet 1 Snippet 2
{("gas","velocity_x"): (0.0, None)} ️✅*
{("gas","velocity_x"): None, ("gas", "density"): None} ️✅ ️✅
{("gas","velocity_x"): (None, None), ("gas", "density"): (None, None)} ️✅ ️✅ ️✅
{("gas","velocity_x"): (0.0, None), ("gas", "density"): (None, None)} ️✅*
{("gas","velocity_x"): None, ("gas", "density"): (1e-27, 1e-23)}
{("gas","velocity_x"): (None,None), ("gas", "density"): (1e-27, 1e-23)} ️✅ ️✅ ️✅

Importantly: The starred cases treat are scenarios that the snippets would fix. In those cases, the existing logic totally ignores the fact that the caller specifies that the lower extrema for ("gas","velocity") is 0.0 and acts like it was set to None. Furthermore, the existing logic causes the function to behave completely differently for the following cases:

  • {("gas","velocity_x"): (-1.0, None)} (this is considered to be invalid)
  • {("gas","velocity_x"): (1.0, None), ("gas", "density"): (None, None)}
  • {("gas","velocity_x"): (0.0, None), ("gas", "density"): (None, 1e-20)}.

extrema = dict.fromkeys(bin_fields, (None, None))
else:
extrema = _sanitize_dictarg_required_bin_fields(
extrema, data_source, bin_fields, arg_name="extrema"
)

if any(is_pfield) and not all(is_pfield):
if hasattr(data_source.ds, "_sph_ptypes"):
Expand Down Expand Up @@ -1370,36 +1431,53 @@ def create_profile(
logs_list.append(data_source.ds.field_info[bin_field].take_log)
logs = logs_list

# Are the extrema all Nones? Then treat them as though extrema was set as None
if extrema is None or not any(collapse(extrema.values())):
ex = [
data_source.quantities["Extrema"](f, non_zero=l)
for f, l in zip(bin_fields, logs, strict=True)
]
# pad extrema by epsilon so cells at bin edges are not excluded
for i, (mi, ma) in enumerate(ex):
mi = mi - np.spacing(mi)
ma = ma + np.spacing(ma)
ex[i][0], ex[i][1] = mi, ma
if override_bins is None:
o_bins = [None for _ in bin_fields]
else:
ex = []
o_bins = []
for bin_field in bin_fields:
bf_units = data_source.ds.field_info[bin_field].output_units
try:
field_ex = list(extrema[bin_field[-1]])
except KeyError as e:
try:
field_ex = list(extrema[bin_field])
except KeyError:
raise RuntimeError(
f"Could not find field {bin_field[-1]} or {bin_field} in extrema"
) from e
field_obin = override_bins[bin_field]

if field_obin is None:
o_bins.append(None)
continue

if isinstance(field_obin, tuple):
field_obin = data_source.ds.arr(*field_obin)

if units is not None and bin_field in units:
fe = data_source.ds.arr(field_obin, units[bin_field])
else:
if hasattr(field_obin, "units"):
fe = field_obin.to(bf_units)
else:
fe = data_source.ds.arr(field_obin, bf_units)
fe.convert_to_units(bf_units)
field_obin = fe.d
o_bins.append(field_obin)

# infer the extrema for each bin_field
ex = []
for bin_field, log, field_obin in zip(bin_fields, logs, o_bins, strict=True):
if field_obin is not None:
# extrema are **only** used to infer bins. When override_bins are provided,
# the inferred bins are ignored. Thus, it's ok to set the bin_field's
# extrema to any arbitrary pair of values that produce valid bins
ex.append([1, 2]) # we use a positive minimum in case ``log == True``
continue
else:
bf_units = data_source.ds.field_info[bin_field].output_units

# get (& sanitize) the extrema specified for field_ex
field_ex = list(extrema[bin_field])
if isinstance(field_ex[0], tuple):
field_ex = [data_source.ds.quan(*f) for f in field_ex]
if any(exi is None for exi in field_ex):

# compute any missing extrema and use the result to update field_ex
if None in field_ex:
try:
ds_extrema = data_source.quantities.extrema(bin_field)
ds_extrema = data_source.quantities.extrema(bin_field, non_zero=log)
except AttributeError:
# ytdata profile datasets don't have data_source.quantities
bf_vals = data_source[bin_field]
Expand All @@ -1410,6 +1488,8 @@ def create_profile(
# pad extrema by epsilon so cells at bin edges are
# not excluded
field_ex[i] -= (-1) ** i * np.spacing(field_ex[i])

# handle units
if units is not None and bin_field in units:
for i, exi in enumerate(field_ex):
if hasattr(exi, "units"):
Expand All @@ -1423,41 +1503,8 @@ def create_profile(
else:
fe = data_source.ds.arr(field_ex, bf_units)
fe.convert_to_units(bf_units)
field_ex = [fe[0].v, fe[1].v]
if is_sequence(field_ex[0]):
field_ex[0] = data_source.ds.quan(field_ex[0][0], field_ex[0][1])
field_ex[0] = field_ex[0].in_units(bf_units)
if is_sequence(field_ex[1]):
field_ex[1] = data_source.ds.quan(field_ex[1][0], field_ex[1][1])
field_ex[1] = field_ex[1].in_units(bf_units)
ex.append(field_ex)

if override_bins is not None:
o_bins = []
for bin_field in bin_fields:
bf_units = data_source.ds.field_info[bin_field].output_units
try:
field_obin = override_bins[bin_field[-1]]
except KeyError:
field_obin = override_bins[bin_field]

if field_obin is None:
o_bins.append(None)
continue

if isinstance(field_obin, tuple):
field_obin = data_source.ds.arr(*field_obin)

if units is not None and bin_field in units:
fe = data_source.ds.arr(field_obin, units[bin_field])
else:
if hasattr(field_obin, "units"):
fe = field_obin.to(bf_units)
else:
fe = data_source.ds.arr(field_obin, bf_units)
fe.convert_to_units(bf_units)
field_obin = fe.d
o_bins.append(field_obin)
ex.append([fe[0].v, fe[1].v]) # record the extrema

args = [data_source]
for f, n, (mi, ma), l in zip(bin_fields, n_bins, ex, logs, strict=True):
Expand Down
Loading