Description
The FBetaScore metric in metrax.classification_metrics does not implement the merge method. This causes a NotImplementedError when the metric is used through the metrax.nnx wrapper (e.g., metrax.nnx.FBetaScore), because the wrapper's update method calls merge to combine the running metric with the metric computed from each new batch.
It appears the merge method is commented out in the source code of metrax/classification_metrics.py.
metrax/src/metrax/classification_metrics.py
Lines 677 to 695 in 4ff6ccf
```python
"""
This function is currently unused as the 'from_model_output' function can handle the whole
dataset without needing to split and merge them. I'm leaving this here for now incase we want to
repurpose this or need to change something that requires this function's use again. This function would need
to be reworked for it to work with the current implementation of this class.
"""
# # Merge datasets together
# def merge(self, other: 'FBetaScore') -> 'FBetaScore':
#
#     # Check if the incoming beta is the same value as the current beta
#     if other.beta == self.beta:
#         return type(self)(
#             true_positives = self.true_positives + other.true_positives,
#             false_positives = self.false_positives + other.false_positives,
#             false_negatives = self.false_negatives + other.false_negatives,
#             beta=self.beta,
#         )
#     else:
#         raise ValueError('The "Beta" values between the two are not equal.')
```
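The state that the commented-out method combines is three additive counts plus beta, so merging reduces to element-wise addition followed by a beta check. The sketch below illustrates that idea under stated assumptions: FBetaCounts is a hypothetical, framework-free stand-in (plain Python numbers, not JAX arrays), not metrax's class, and whether the current FBetaScore still stores beta on the instance is not confirmed here. As the docstring above says, the real method would need rework to fit the current implementation.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FBetaCounts:
    """Hypothetical stand-in for FBetaScore's accumulated state (not metrax code)."""

    true_positives: float
    false_positives: float
    false_negatives: float
    beta: float = 1.0

    def merge(self, other: "FBetaCounts") -> "FBetaCounts":
        # Mirrors the commented-out method: refuse to combine different betas.
        if other.beta != self.beta:
            raise ValueError('The "Beta" values between the two are not equal.')
        # Counts are additive across batches, so merging is element-wise addition.
        return FBetaCounts(
            true_positives=self.true_positives + other.true_positives,
            false_positives=self.false_positives + other.false_positives,
            false_negatives=self.false_negatives + other.false_negatives,
            beta=self.beta,
        )

    def compute(self) -> float:
        # F_beta = (1 + b^2) TP / ((1 + b^2) TP + b^2 FN + FP); 0.0 when undefined.
        b2 = self.beta ** 2
        denom = (
            (1 + b2) * self.true_positives
            + b2 * self.false_negatives
            + self.false_positives
        )
        return (1 + b2) * self.true_positives / denom if denom else 0.0


# Accumulate two batches by merging their counts.
batch_a = FBetaCounts(true_positives=3, false_positives=1, false_negatives=2)
batch_b = FBetaCounts(true_positives=1, false_positives=2, false_negatives=0)
merged = batch_a.merge(batch_b)
# merged holds TP=4, FP=3, FN=2, so merged.compute() is 8 / 13 (about 0.615)
```

Because the score is computed from the summed counts, merge-then-compute matches a single pass over all the data, whereas averaging per-batch scores generally does not (here the per-batch F1 values are 2/3 and 1/2).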
Minimal Reproduction
```python
import metrax.nnx
import jax.numpy as jnp
import jax.random

# Setup dummy data
predictions = jax.random.normal(jax.random.PRNGKey(0), (3,))
labels = jnp.arange(3) % 2

# Initialize and update metric
f1_metric = metrax.nnx.FBetaScore()
f1_metric.update(predictions=predictions, labels=labels)  # Raises NotImplementedError
```
Traceback
```
Traceback (most recent call last):
  File "repro.py", line 10, in <module>
    f1_metric.update(predictions=predictions, labels=labels)
  File ".../site-packages/metrax/nnx/nnx_wrapper.py", line 31, in update
    self.clu_metric = self.clu_metric.merge(other_clu_metric)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../site-packages/clu/metrics.py", line 148, in merge
    raise NotImplementedError("Must override merge()")
NotImplementedError: Must override merge()
```
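The failure path in the traceback can be mimicked without metrax or CLU. The sketch below assumes, based only on the traceback, that the wrapper builds a metric from each batch with from_model_output and merges it into the running clu_metric, starting from an empty state (the error on the very first update call suggests merge runs even then). BaseMetric, CountMetric, NoMergeMetric and Wrapper are hypothetical mimics, not the real metrax or CLU classes.

```python
class BaseMetric:
    """Hypothetical base class: merge() must be overridden, as in the traceback."""

    def merge(self, other):
        raise NotImplementedError("Must override merge()")


class CountMetric(BaseMetric):
    """Hypothetical metric that counts labels and does override merge()."""

    def __init__(self, count=0):
        self.count = count

    @classmethod
    def empty(cls):
        return cls(count=0)

    @classmethod
    def from_model_output(cls, predictions, labels):
        return cls(count=len(labels))

    def merge(self, other):
        return type(self)(count=self.count + other.count)


class NoMergeMetric(CountMetric):
    """Hypothetical metric that, like FBetaScore, falls back to BaseMetric.merge."""

    merge = BaseMetric.merge


class Wrapper:
    """Hypothetical wrapper: every update() builds a batch metric and merges it."""

    def __init__(self, metric_cls):
        self.clu_metric = metric_cls.empty()

    def update(self, predictions, labels):
        other_clu_metric = type(self.clu_metric).from_model_output(predictions, labels)
        self.clu_metric = self.clu_metric.merge(other_clu_metric)


ok = Wrapper(CountMetric)
ok.update(predictions=[0.1, 0.9], labels=[0, 1])
ok.update(predictions=[0.4], labels=[1])
# ok.clu_metric.count is now 3

broken = Wrapper(NoMergeMetric)
try:
    broken.update(predictions=[0.1], labels=[0])
    error_message = None
except NotImplementedError as err:
    error_message = str(err)  # "Must override merge()"
```

Under these assumptions, any metric class used with such a wrapper must provide a working merge, so FBetaScore fails on its first update even though from_model_output alone handles a full dataset.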