Description
metrax.nnx.nnx_wrapper.NnxWrapper metrics are currently incompatible with flax.nnx.MultiMetric when used alongside other metrics that require different keyword arguments.
When nnx.MultiMetric.update(**kwargs) is called, it passes all kwargs to every registered metric. However, the underlying metrax metrics (wrapped by NnxWrapper) typically implement from_model_output with specific arguments (e.g., predictions, labels, values) and do not accept arbitrary **kwargs. This results in a TypeError when extraneous arguments (intended for other metrics) are passed.
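The mechanism can be sketched in plain Python without either library. The classes below (`StrictAverage`, `StrictAccuracy`, `FanOutMetrics`) are hypothetical stand-ins, not the real flax.nnx or metrax implementations; they only mirror the behaviour described above, where one container forwards every keyword argument to every metric.

```python
# Plain-Python stand-ins; these are NOT the real flax.nnx or metrax classes.


class StrictAverage:
    """Hypothetical metric whose update accepts only `values`."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, values):
        self.total += sum(values)
        self.count += len(values)


class StrictAccuracy:
    """Hypothetical metric whose update accepts only `predictions` and `labels`."""

    def __init__(self):
        self.correct = 0
        self.count = 0

    def update(self, predictions, labels):
        self.correct += sum(int(p == l) for p, l in zip(predictions, labels))
        self.count += len(labels)


class FanOutMetrics:
    """Hypothetical container that forwards every keyword argument to every metric."""

    def __init__(self, **metrics):
        self.metrics = metrics

    def update(self, **kwargs):
        for metric in self.metrics.values():
            metric.update(**kwargs)  # each metric receives ALL kwargs


def try_update():
    """Return the TypeError message raised by the fan-out, or None on success."""
    metrics = FanOutMetrics(loss=StrictAverage(), accuracy=StrictAccuracy())
    try:
        metrics.update(values=[0.5, 0.2], predictions=[0, 1], labels=[0, 1])
    except TypeError as err:
        return str(err)
    return None


if __name__ == "__main__":
    print(try_update())
```

The first metric in the container already rejects the call, because it is handed keyword arguments it never declared.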
Minimal Reproduction
Here is an example where two metrax metrics (Average and Accuracy) are used together. The update fails because each metric expects different arguments and neither ignores the arguments intended for the other.
import metrax.nnx
from flax import nnx
import jax.numpy as jnp
# Define metrics
metrics = nnx.MultiMetric(
    loss=metrax.nnx.Average(),  # Expects 'values'
    accuracy=metrax.nnx.Accuracy(),  # Expects 'predictions', 'labels'
)
# Update with arguments for both.
metrics.update(
    values=jnp.array([0.5, 0.2]),
    predictions=jnp.array([0, 1]),
    labels=jnp.array([0, 1]),
)
Traceback
TypeError: Average.from_model_output() got an unexpected keyword argument 'predictions'
Suggested Fix
Modify the base Metric class or the individual metric implementations in metrax to accept and ignore arbitrary **kwargs in from_model_output.
For example, in metrax/base.py:
@classmethod
def from_model_output(
    cls,
    values: jax.Array,
    sample_weights: jax.Array | None = None,
    **kwargs,  # Add this to absorb extra arguments
) -> 'Average':
    # ... implementation ...
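To illustrate the effect of the suggested change, here is a minimal plain-Python sketch. `TolerantAverage`, `TolerantAccuracy`, and `fan_out` are hypothetical names invented for this example and do not reflect the actual metrax code; the sketch only assumes that each metric builds itself in a `from_model_output` classmethod, as described above.

```python
from dataclasses import dataclass


@dataclass
class TolerantAverage:
    """Hypothetical stand-in for an average metric that ignores extra kwargs."""

    total: float = 0.0
    count: float = 0.0

    @classmethod
    def from_model_output(cls, values, sample_weights=None, **kwargs):
        del kwargs  # arguments meant for other metrics are discarded
        if sample_weights is None:
            sample_weights = [1.0] * len(values)
        total = sum(v * w for v, w in zip(values, sample_weights))
        return cls(total=total, count=sum(sample_weights))

    def compute(self):
        return self.total / self.count


@dataclass
class TolerantAccuracy:
    """Hypothetical stand-in for an accuracy metric that ignores extra kwargs."""

    correct: float = 0.0
    count: float = 0.0

    @classmethod
    def from_model_output(cls, predictions, labels, **kwargs):
        del kwargs  # arguments meant for other metrics are discarded
        correct = sum(float(p == l) for p, l in zip(predictions, labels))
        return cls(correct=correct, count=float(len(labels)))

    def compute(self):
        return self.correct / self.count


def fan_out(metric_classes, **kwargs):
    """Pass every keyword argument to every metric, as a multi-metric container would."""
    return {
        name: cls.from_model_output(**kwargs)
        for name, cls in metric_classes.items()
    }


results = fan_out(
    {"loss": TolerantAverage, "accuracy": TolerantAccuracy},
    values=[0.5, 0.2],
    predictions=[0, 1],
    labels=[0, 1],
)

if __name__ == "__main__":
    for name, metric in results.items():
        print(name, metric.compute())
```

One trade-off of this approach is that a misspelled argument name is also silently absorbed by `**kwargs`. A required argument that is misspelled still raises a TypeError because it is then missing, but a misspelled optional argument such as `sample_weights` would simply be ignored.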