
metrics: Use Array and ArrayLike types throughout

Open · copybara-service[bot] opened this issue 2 years ago • 0 comments


Currently the inputs to from_model_output are not typed. However, these functions cannot accept arbitrary inputs: each input must be a value convertible to a jax.Array. This change fixes this so that it:

  • Types the inputs of from_model_output as Array or ArrayLike
  • Removes uses of jnp.array as a type annotation, since it is effectively equivalent to Any
  • Types the members of Metric classes as Array
  • Moves the mask-checking code into its own function

While we could make everything use Array (instead of ArrayLike), this would break code like:

import flax
from clu import metrics

@flax.struct.dataclass
class Collection(metrics.Collection):
  train_accuracy: metrics.Accuracy
  learning_rate: metrics.LastValue.from_output("learning_rate")

Collection.gather_from_model_output(learning_rate=0.02, ...)

which seems undesirable.
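The distinction can be seen in a short, hypothetical sketch: a plain Python float such as 0.02 is not a jax.Array, so a static type checker would reject it for a parameter annotated as jax.Array, while jax.typing.ArrayLike admits Python scalars, NumPy arrays and jax.Arrays alike. The two function names are illustrative only.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.typing import ArrayLike


def strict_learning_rate(value: jax.Array) -> jax.Array:
  # A type checker would flag strict_learning_rate(0.02): float is not jax.Array.
  return value


def lenient_learning_rate(value: ArrayLike) -> jax.Array:
  # ArrayLike covers Python scalars, NumPy values and jax.Arrays;
  # jnp.asarray turns any of them into a jax.Array.
  return jnp.asarray(value)


print(isinstance(0.02, jax.Array))                    # False
print(isinstance(lenient_learning_rate(0.02), jax.Array))  # True
print(isinstance(lenient_learning_rate(np.float32(0.02)), jax.Array))  # True
```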

Note that count and value for LastValue have type ArrayLike, because this code needs to support passing a plain number for value or count. Also, the base Metric.compute() method has return type Any, because some metrics return an Array while others return a dict[str, Array].
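A minimal sketch of how these annotations can fit together, assuming simplified stand-ins rather than CLU's real classes: the base compute() returns Any, and subclasses narrow it to jax.Array or to a dict. LastValue here is a cut-down illustration (its default count of 1 is an assumption), and MeanAndMax is a hypothetical metric that returns a dict.

```python
import dataclasses
from typing import Any, Dict

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike


class Metric:
  def compute(self) -> Any:
    # Any, because subclasses return differently shaped results.
    raise NotImplementedError


@dataclasses.dataclass(frozen=True)
class LastValue(Metric):
  value: ArrayLike  # ArrayLike so a plain number such as 0.02 can be stored.
  count: ArrayLike = 1  # Hypothetical default, for illustration only.

  def compute(self) -> jax.Array:
    return jnp.asarray(self.value)


@dataclasses.dataclass(frozen=True)
class MeanAndMax(Metric):
  """Hypothetical metric whose compute() returns a dict of arrays."""
  total: jax.Array
  count: jax.Array
  maximum: jax.Array

  def compute(self) -> Dict[str, jax.Array]:
    return {"mean": self.total / self.count, "max": self.maximum}
```

Narrowing a return type in an override is accepted by Python type checkers, so each concrete metric can still advertise a precise result type even though the base class uses Any.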

copybara-service[bot], May 04 '23 00:05