Description
There is a dtype discrepancy between using inject_hyperparams and not using it.
When inject_hyperparams is used, the dtype of the hyperparameters is inferred from the dtype of the params.
See this:
```python
import optax
import jax.numpy as jnp

optimizer = optax.inject_hyperparams(optax.adamw)(learning_rate=1e-3)
print(optimizer.init(jnp.array([1, 2], dtype=jnp.bfloat16)))
```
This outputs:
```
InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(0.000999451, dtype=bfloat16), 'b1': Array(0.898438, dtype=bfloat16), 'b2': Array(1, dtype=bfloat16), 'eps': Array(1.00117e-08, dtype=bfloat16), 'eps_root': Array(0, dtype=bfloat16), 'weight_decay': Array(0.000100136, dtype=bfloat16)}, hyperparams_states={}, inner_state=(ScaleByAdamState(count=Array(0, dtype=int32), mu=Array([0, 0], dtype=bfloat16), nu=Array([0, 0], dtype=bfloat16)), EmptyState(), EmptyState()))
```
The hyperparameter dtype is inferred from the params dtype, so every hyperparameter is stored as bfloat16.
However, if you check the dtype of the hyperparameters without injection, it is float32. See the example code at https://github.com/google-deepmind/optax/blob/main/optax/_src/alias.py#L601.
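To make the precision cost concrete, here is a minimal sketch that compares each adamw hyperparameter from the example in float32 and bfloat16. It assumes the cast behaves like `jnp.asarray(value, dtype=...)` (an assumption about the mechanism), and it uses adamw's default values, which match the rounded values printed above. Note in particular that `b2=0.999` is printed as `Array(1, dtype=bfloat16)`.

```python
import jax.numpy as jnp

# learning_rate=1e-3 from the example, plus adamw's defaults, which match
# the (rounded) values in the printed InjectStatefulHyperparamsState.
hyperparams = {
    "learning_rate": 1e-3,
    "b1": 0.9,
    "b2": 0.999,
    "eps": 1e-8,
    "eps_root": 0.0,
    "weight_decay": 1e-4,
}

# Assumption: casting to the params dtype behaves like jnp.asarray(..., dtype=...).
f32 = {k: jnp.asarray(v, dtype=jnp.float32) for k, v in hyperparams.items()}
bf16 = {k: jnp.asarray(v, dtype=jnp.bfloat16) for k, v in hyperparams.items()}

for name in hyperparams:
    print(f"{name:>13}: float32={float(f32[name]):.9g}  bfloat16={float(bf16[name]):.9g}")

# 0.999 is not representable in bfloat16 and rounds to exactly 1.0, so the
# (1 - b2) weight on the squared gradient in the standard Adam second-moment
# update would become 0.
one_minus_b2_f32 = 1.0 - f32["b2"]
one_minus_b2_bf16 = 1.0 - bf16["b2"]
print("1 - b2 (float32): ", float(one_minus_b2_f32))
print("1 - b2 (bfloat16):", float(one_minus_b2_bf16))
```

The sketch only shows the rounding itself; how much it matters in practice depends on how the inner transformation uses each value.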
Should inject_hyperparams respect the dtype given in the function signature instead?
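One way to read this proposal, as a hypothetical sketch: keep float hyperparameters in float32 (the dtype a Python float default becomes in JAX when 64-bit mode is disabled) instead of following the params dtype. `cast_hyperparams` and `respect_signature` are illustrative names only, not part of the optax API.

```python
import jax.numpy as jnp


def cast_hyperparams(hyperparams, param_dtype, respect_signature=True):
    """Hypothetical helper (not part of optax) illustrating two dtype policies.

    respect_signature=True: keep float hyperparameters in float32, matching
    what a Python float default becomes in JAX with 64-bit mode disabled.
    respect_signature=False: follow the params dtype, as the output above
    suggests inject_hyperparams currently does.
    """
    target = jnp.float32 if respect_signature else param_dtype
    return {name: jnp.asarray(value, dtype=target) for name, value in hyperparams.items()}


params = jnp.array([1, 2], dtype=jnp.bfloat16)
hparams = {"learning_rate": 1e-3, "b1": 0.9, "b2": 0.999}

proposed = cast_hyperparams(hparams, params.dtype)
current = cast_hyperparams(hparams, params.dtype, respect_signature=False)

print({k: (str(v.dtype), float(v)) for k, v in proposed.items()})
print({k: (str(v.dtype), float(v)) for k, v in current.items()})
```

Under the first policy the params can stay in bfloat16 while the scalar hyperparameters keep float32 precision. A float32 array multiplied by a bfloat16 array could promote the updates to float32, so this sketch does not settle how the inner transformation should handle mixed dtypes.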