diff --git a/yt/data_objects/profiles.py b/yt/data_objects/profiles.py
index 15a266e96d3..058a871dac6 100644
--- a/yt/data_objects/profiles.py
+++ b/yt/data_objects/profiles.py
@@ -1202,6 +1202,44 @@ def sanitize_field_tuple_keys(input_dict, data_source):
         return input_dict
 
 
+def _sanitize_dictarg_required_bin_fields(
+    input_dict, data_source, bin_fields, arg_name
+):
+    """
+    Helper function for create_profile that coerces arguments expected
+    to be either None or a Mapping with a key for each bin_field into
+    a standardized format, validating them along the way.
+
+    While it would be more efficient to avoid explicitly checking for
+    the presence of bin_field keys in this function, the cost is
+    worthwhile for a few reasons:
+    1. len(bin_fields) is always small, and the calculation that this
+       check accompanies is usually far more expensive than the check
+       itself
+    2. it helps us simplify the implementation of create_profile
+    3. it lets us explicitly document slightly idiosyncratic behavior
+       (that was poorly documented historically).
+    """
+    if input_dict is None:
+        return input_dict
+
+    tmp = sanitize_field_tuple_keys(input_dict, data_source)
+    out = {}
+    for bin_field in bin_fields:
+        try:
+            out[bin_field] = tmp[bin_field[-1]]
+        except KeyError:
+            try:
+                out[bin_field] = tmp[bin_field]
+            except KeyError:
+                raise ValueError(
+                    f"The {arg_name} argument must be None or a dict with keys for "
+                    f"each bin_field. {arg_name} is missing an entry for {bin_field} "
+                    f"(or equivalently, for {bin_field[-1]})"
+                ) from None
+    return out
+
+
 def create_profile(
     data_source,
     bin_fields,
@@ -1300,12 +1338,35 @@ def create_profile(
     if len(bin_fields) > 1 and isinstance(accumulation, bool):
         accumulation = [accumulation for _ in range(len(bin_fields))]
 
+    # handle sanitization/validation of arguments
     bin_fields = data_source._determine_fields(bin_fields)
     fields = data_source._determine_fields(fields)
     units = sanitize_field_tuple_keys(units, data_source)
-    extrema = sanitize_field_tuple_keys(extrema, data_source)
     logs = sanitize_field_tuple_keys(logs, data_source)
-    override_bins = sanitize_field_tuple_keys(override_bins, data_source)
+    override_bins = _sanitize_dictarg_required_bin_fields(
+        override_bins, data_source, bin_fields, arg_name="override_bins"
+    )
+
+    # The following logic explicitly maintains the historical idiosyncratic behavior
+    # when extrema is a dict and it is missing a key for one of the bin_fields:
+    # -> this is allowed if `all(bool(v) == False for v in collapse(extrema.values()))`
+    # -> this is an error in **all** other cases
+    if extrema is None or not any(collapse(extrema.values())):
+        # when the extrema are all None, treat them as though extrema was set to None.
+        #
+        # In the future, it may be better to replace `not any(...)` with either:
+        # `all(v is None for v in collapse(extrema.values()))` OR
+        # `all(v is None or v == (None, None) for v in collapse(extrema.values()))`
+        #
+        # Be aware that doing so changes behavior when extrema holds a value like:
+        # `{<bin_field>: (0., None)}` or `{<bin_field>: (None, 0.)}`
+        # I happen to think that the existing behavior is a bug (i.e., the
+        # behavior changes if 0. is replaced with **ANY** other number)
+        extrema = dict.fromkeys(bin_fields, (None, None))
+    else:
+        extrema = _sanitize_dictarg_required_bin_fields(
+            extrema, data_source, bin_fields, arg_name="extrema"
+        )
 
     if any(is_pfield) and not all(is_pfield):
         if hasattr(data_source.ds, "_sph_ptypes"):
@@ -1370,36 +1431,53 @@ def create_profile(
         logs_list.append(data_source.ds.field_info[bin_field].take_log)
     logs = logs_list
 
-    # Are the extrema all Nones? Then treat them as though extrema was set as None
-    if extrema is None or not any(collapse(extrema.values())):
-        ex = [
-            data_source.quantities["Extrema"](f, non_zero=l)
-            for f, l in zip(bin_fields, logs, strict=True)
-        ]
-        # pad extrema by epsilon so cells at bin edges are not excluded
-        for i, (mi, ma) in enumerate(ex):
-            mi = mi - np.spacing(mi)
-            ma = ma + np.spacing(ma)
-            ex[i][0], ex[i][1] = mi, ma
+    if override_bins is None:
+        o_bins = [None for _ in bin_fields]
     else:
-        ex = []
+        o_bins = []
         for bin_field in bin_fields:
             bf_units = data_source.ds.field_info[bin_field].output_units
-            try:
-                field_ex = list(extrema[bin_field[-1]])
-            except KeyError as e:
-                try:
-                    field_ex = list(extrema[bin_field])
-                except KeyError:
-                    raise RuntimeError(
-                        f"Could not find field {bin_field[-1]} or {bin_field} in extrema"
-                    ) from e
+            field_obin = override_bins[bin_field]
+
+            if field_obin is None:
+                o_bins.append(None)
+                continue
+
+            if isinstance(field_obin, tuple):
+                field_obin = data_source.ds.arr(*field_obin)
+
+            if units is not None and bin_field in units:
+                fe = data_source.ds.arr(field_obin, units[bin_field])
+            else:
+                if hasattr(field_obin, "units"):
+                    fe = field_obin.to(bf_units)
+                else:
+                    fe = data_source.ds.arr(field_obin, bf_units)
+                fe.convert_to_units(bf_units)
+            field_obin = fe.d
+            o_bins.append(field_obin)
+    # infer the extrema for each bin_field
+    ex = []
+    for bin_field, log, field_obin in zip(bin_fields, logs, o_bins, strict=True):
+        if field_obin is not None:
+            # extrema are **only** used to infer bins. When override_bins are
+            # provided, the inferred bins are ignored. Thus, it is fine to set the
+            # bin_field's extrema to any pair of values that produces valid bins
+            ex.append([1, 2])  # we use a positive minimum in case ``log == True``
+            continue
+        else:
+            bf_units = data_source.ds.field_info[bin_field].output_units
+
+            # get (& sanitize) the extrema specified for this bin_field
+            field_ex = list(extrema[bin_field])
             if isinstance(field_ex[0], tuple):
                 field_ex = [data_source.ds.quan(*f) for f in field_ex]
-            if any(exi is None for exi in field_ex):
+
+            # compute any missing extrema and use the result to update field_ex
+            if None in field_ex:
                 try:
-                    ds_extrema = data_source.quantities.extrema(bin_field)
+                    ds_extrema = data_source.quantities.extrema(bin_field, non_zero=log)
                 except AttributeError:
                     # ytdata profile datasets don't have data_source.quantities
                     bf_vals = data_source[bin_field]
                     ds_extrema = data_source.ds.arr([bf_vals.min(), bf_vals.max()])
@@ -1410,6 +1488,8 @@ def create_profile(
                         # pad extrema by epsilon so cells at bin edges are
                         # not excluded
                         field_ex[i] -= (-1) ** i * np.spacing(field_ex[i])
+
+            # handle units
             if units is not None and bin_field in units:
                 for i, exi in enumerate(field_ex):
                     if hasattr(exi, "units"):
@@ -1423,41 +1503,8 @@ def create_profile(
                 else:
                     fe = data_source.ds.arr(field_ex, bf_units)
                 fe.convert_to_units(bf_units)
-            field_ex = [fe[0].v, fe[1].v]
-            if is_sequence(field_ex[0]):
-                field_ex[0] = data_source.ds.quan(field_ex[0][0], field_ex[0][1])
-                field_ex[0] = field_ex[0].in_units(bf_units)
-            if is_sequence(field_ex[1]):
-                field_ex[1] = data_source.ds.quan(field_ex[1][0], field_ex[1][1])
-                field_ex[1] = field_ex[1].in_units(bf_units)
-            ex.append(field_ex)
-    if override_bins is not None:
-        o_bins = []
-        for bin_field in bin_fields:
-            bf_units = data_source.ds.field_info[bin_field].output_units
-            try:
-                field_obin = override_bins[bin_field[-1]]
-            except KeyError:
-                field_obin = override_bins[bin_field]
-
-            if field_obin is None:
-                o_bins.append(None)
-                continue
-
-            if isinstance(field_obin, tuple):
-                field_obin = data_source.ds.arr(*field_obin)
-
-            if units is not None and bin_field in units:
-                fe = data_source.ds.arr(field_obin, units[bin_field])
-            else:
-                if hasattr(field_obin, "units"):
-                    fe = field_obin.to(bf_units)
-                else:
-                    fe = data_source.ds.arr(field_obin, bf_units)
-                fe.convert_to_units(bf_units)
-            field_obin = fe.d
-            o_bins.append(field_obin)
+            ex.append([fe[0].v, fe[1].v])
 
     # record the extrema
     args = [data_source]
     for f, n, (mi, ma), l in zip(bin_fields, n_bins, ex, logs, strict=True):
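
# ---------------------------------------------------------------------------
# Illustrative usage (not part of the patch): a minimal sketch of the calling
# conventions that the refactored argument handling above supports. The
# dataset path and field names are placeholders, not mandated by the patch;
# any dataset defining ("gas", "density") and ("gas", "temperature") works.
# ---------------------------------------------------------------------------
import numpy as np

import yt

ds = yt.load("IsolatedGalaxy/galaxy0030/galaxy0030")  # placeholder sample data
ad = ds.all_data()

# When extrema is a dict, it must now carry an entry for every bin_field
# (keyed by the full field tuple or by the bare field name). A missing key
# raises ValueError, unless every supplied value is None-like, in which case
# the extrema are inferred from the data exactly as if extrema were None.
p1 = yt.create_profile(
    ad,
    bin_fields=[("gas", "density")],
    fields=[("gas", "temperature")],
    extrema={("gas", "density"): ((1e-30, "g/cm**3"), (1e-24, "g/cm**3"))},
)

# override_bins supplies the bin edges directly; any extrema inferred for
# that bin_field are ignored (the loop above records a dummy [1, 2] pair).
p2 = yt.create_profile(
    ad,
    bin_fields=[("gas", "density")],
    fields=[("gas", "temperature")],
    override_bins={("gas", "density"): np.logspace(-30, -24, 33)},
)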