
FBetaScore missing merge method causes NotImplementedError with metrax.nnx wrapper #131

@steventango

Description


The FBetaScore metric in metrax.classification_metrics does not implement the merge method. This causes a NotImplementedError when the metric is used through the metrax.nnx wrapper (e.g., metrax.nnx.FBetaScore), because the wrapper's update method builds a metric from the new batch and then calls merge to combine it with the accumulated state (see the traceback below).
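The control flow described above can be illustrated with a minimal sketch. Every class here is a hypothetical, simplified stand-in: `Metric` imitates the clu base class only insofar as its default merge raises the error shown in the traceback, and `Wrapper` assumes the nnx wrapper starts from an empty metric and does "from_model_output, then merge" on each update. This is not the actual clu or metrax code.

```python
class Metric:
    """Hypothetical stand-in for clu.metrics.Metric (only the merge default)."""

    def merge(self, other):
        # Mirrors the default behaviour reported in the traceback.
        raise NotImplementedError("Must override merge()")


class CountMetric(Metric):
    """Toy metric that, like FBetaScore, does not override merge()."""

    def __init__(self, count=0):
        self.count = count

    @classmethod
    def empty(cls):
        return cls(0)

    @classmethod
    def from_model_output(cls, values):
        return cls(len(values))


class MergeableCountMetric(CountMetric):
    """The same toy metric, but with merge() implemented."""

    def merge(self, other):
        return type(self)(self.count + other.count)


class Wrapper:
    """Hypothetical stateful wrapper: each update() builds a batch metric and merges it."""

    def __init__(self, metric_cls):
        self.metric_cls = metric_cls
        self.clu_metric = metric_cls.empty()

    def update(self, **kwargs):
        other_clu_metric = self.metric_cls.from_model_output(**kwargs)
        self.clu_metric = self.clu_metric.merge(other_clu_metric)


# With merge implemented, state accumulates across batches.
ok = Wrapper(MergeableCountMetric)
ok.update(values=[1, 2, 3])
ok.update(values=[4, 5])
print(ok.clu_metric.count)  # 5

# Without it, the very first update fails, as in this issue.
broken = Wrapper(CountMetric)
error_message = None
try:
    broken.update(values=[1, 2, 3])
except NotImplementedError as err:
    error_message = str(err)
print(error_message)  # Must override merge()
```

The point of the sketch is that the wrapper never calls compute on a single batch directly; it always goes through merge, so any metric used with it must override merge even if from_model_output alone would work on a full dataset.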

It appears the merge method is commented out in the source code of metrax/classification_metrics.py.

"""
This function is currently unused as the 'from_model_output' function can handle the whole
dataset without needing to split and merge them. I'm leaving this here for now incase we want to
repurpose this or need to change something that requires this function's use again. This function would need
to be reworked for it to work with the current implementation of this class.
"""
# # Merge datasets together
# def merge(self, other: 'FBetaScore') -> 'FBetaScore':
#
#     # Check if the incoming beta is the same value as the current beta
#     if other.beta == self.beta:
#         return type(self)(
#             true_positives = self.true_positives + other.true_positives,
#             false_positives = self.false_positives + other.false_positives,
#             false_negatives = self.false_negatives + other.false_negatives,
#             beta=self.beta,
#         )
#     else:
#         raise ValueError('The "Beta" values between the two are not equal.')
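For illustration, here is a minimal sketch of what a re-enabled merge could look like, following the logic of the commented-out code. It uses a hypothetical stand-in class, `FBetaCounts`, with plain Python floats; the real FBetaScore is a clu metric whose fields are presumably JAX arrays, and the docstring above notes the method would need rework for the current class, so this is not a drop-in patch. The `compute` method uses the standard F-beta formula in terms of counts, which is why summing counts across batches gives the same score as computing on the whole dataset.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FBetaCounts:
    """Hypothetical stand-in for FBetaScore's state (plain floats, not JAX arrays)."""

    true_positives: float
    false_positives: float
    false_negatives: float
    beta: float = 1.0

    def merge(self, other: "FBetaCounts") -> "FBetaCounts":
        # Confusion counts are additive, so merging two partial states is an
        # element-wise sum, provided both were built with the same beta.
        if other.beta != self.beta:
            raise ValueError('The "Beta" values between the two are not equal.')
        return type(self)(
            true_positives=self.true_positives + other.true_positives,
            false_positives=self.false_positives + other.false_positives,
            false_negatives=self.false_negatives + other.false_negatives,
            beta=self.beta,
        )

    def compute(self) -> float:
        # F_beta = (1 + b^2) * TP / ((1 + b^2) * TP + b^2 * FN + FP)
        b2 = self.beta ** 2
        numerator = (1 + b2) * self.true_positives
        denominator = numerator + b2 * self.false_negatives + self.false_positives
        return numerator / denominator if denominator > 0 else 0.0


# Two batches merged give the same state (and score) as the combined counts.
batch_a = FBetaCounts(true_positives=3.0, false_positives=1.0, false_negatives=2.0)
batch_b = FBetaCounts(true_positives=4.0, false_positives=2.0, false_negatives=0.0)
merged = batch_a.merge(batch_b)
whole = FBetaCounts(true_positives=7.0, false_positives=3.0, false_negatives=2.0)
print(merged == whole, merged.compute())  # True, 14/19 for beta = 1
```

If the real fields are JAX arrays, the `other.beta != self.beta` check would need care, since comparing arrays returns an array rather than a plain bool; keeping beta as a static Python float avoids that.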

Minimal Reproduction

import metrax.nnx
import jax.numpy as jnp
import jax.random

# Setup dummy data
predictions = jax.random.normal(jax.random.PRNGKey(0), (3,))
labels = jnp.arange(3) % 2

# Initialize and update metric
f1_metric = metrax.nnx.FBetaScore()
f1_metric.update(predictions=predictions, labels=labels)  # Raises NotImplementedError

Traceback

Traceback (most recent call last):
  File "repro.py", line 10, in <module>
    f1_metric.update(predictions=predictions, labels=labels)
  File ".../site-packages/metrax/nnx/nnx_wrapper.py", line 31, in update
    self.clu_metric = self.clu_metric.merge(other_clu_metric)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../site-packages/clu/metrics.py", line 148, in merge
    raise NotImplementedError("Must override merge()")
NotImplementedError: Must override merge()
