CommonLoopUtils
metrics: Use Array and ArrayLike types throughout
Currently the inputs to `from_model_output` are not typed. However, these functions cannot accept arbitrary inputs: each input must be a value convertible to a `jax.Array`. This change:

- Types the inputs of `from_model_output` as `Array` or `ArrayLike`
- Removes uses of `jnp.array` as a type, since `jnp.array` is a function and annotating with it is equivalent to `Any`
- Gives the members of `Metric` classes the type `Array`
- Moves the mask-checking code into its own function
While we could make everything use Array (instead of ArrayLike),
this would break code like:
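```python
import flax.struct
from clu import metrics

@flax.struct.dataclass
class Collection(metrics.Collection):
  train_accuracy: metrics.Accuracy
  learning_rate: metrics.LastValue.from_output("learning_rate")

Collection.gather_from_model_output(learning_rate=0.02)  # ... other outputs elided
```

Here `learning_rate` is a plain Python float. If every input were typed as `Array`, a type checker would reject this call and callers would have to wrap such values in `jnp.asarray` first, which seems undesirable.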
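The tradeoff can be sketched with two hypothetical entry points (neither is CLU API). Annotations are not enforced at runtime, so the difference only shows up under a static checker such as pytype or mypy: the strict variant rejects `learning_rate=0.02`, while the `ArrayLike` variant accepts it and converts every output to a `jax.Array` once, at the boundary.

```python
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


def from_outputs_strict(**model_outputs: jax.Array) -> dict[str, jax.Array]:
  # Hypothetical: a static checker flags learning_rate=0.02 here, because a
  # Python float is not a jax.Array. Nothing is checked at runtime.
  return dict(model_outputs)


def from_outputs_lenient(**model_outputs: ArrayLike) -> dict[str, jax.Array]:
  # Hypothetical: ArrayLike admits Python scalars, NumPy arrays and jax.Arrays;
  # converting here means downstream code only ever sees jax.Array.
  return {name: jnp.asarray(v) for name, v in model_outputs.items()}
```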
Note that `count` and `value` for `LastValue` have type `ArrayLike`, as this code needs to support passing a plain number for `value` or `count`. Also, the base `Metric.compute()` method has return type `Any`, because some metrics return an `Array` while others return a `dict[str, Array]`.
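A minimal sketch of both exceptions, using hypothetical stand-ins for `Metric`, `LastValue` and a dict-returning metric (these are not CLU's real classes). `LastValue` keeps `ArrayLike` fields so a plain number can be stored directly, and the base `compute()` is annotated `Any` because one subclass returns a `jax.Array` and another a `dict[str, jax.Array]`.

```python
import dataclasses
from typing import Any

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


class Metric:
  """Hypothetical minimal base class, not CLU's metrics.Metric."""

  def compute(self) -> Any:  # Any: subclasses return an Array or a dict.
    raise NotImplementedError


@dataclasses.dataclass(frozen=True)
class LastValue(Metric):
  # ArrayLike rather than jax.Array, so a plain Python number is a valid value.
  value: ArrayLike
  count: ArrayLike = 1

  def compute(self) -> jax.Array:
    return jnp.asarray(self.value)


@dataclasses.dataclass(frozen=True)
class MinMax(Metric):
  """Hypothetical metric whose compute() returns a dict instead of an Array."""

  lo: jax.Array
  hi: jax.Array

  @classmethod
  def from_model_output(cls, values: ArrayLike) -> "MinMax":
    values = jnp.asarray(values)
    return cls(lo=values.min(), hi=values.max())

  def compute(self) -> dict[str, jax.Array]:
    return {"min": self.lo, "max": self.hi}
```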