flax
[WIP] Require `transform_metadata` when variables have sharding annotations
Goal: force users to pass `transform_metadata` whenever they apply a transform to variables that carry sharding annotations.
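A minimal sketch of the intended check, in plain Python. It is not Flax code. `Variable`, `PARTITION_NAME`, `check_transform_metadata` and `transformed_sharding` are hypothetical stand-ins. The sketch assumes that a sharding annotation is a tuple of mesh-axis names, one per array axis. It also assumes that the transform inserts the name given in `transform_metadata` at the transformed axis.

```python
# Hypothetical sketch: none of these names are Flax's actual API.
from dataclasses import dataclass
from typing import Any, Mapping, Optional, Tuple

# Hypothetical metadata key naming the mesh axis for the new transform axis.
PARTITION_NAME = "partition_name"


@dataclass
class Variable:
    value: Any
    # One mesh-axis name (or None) per array axis; None means "not annotated".
    sharding: Optional[Tuple[Optional[str], ...]] = None


def check_transform_metadata(
    variables: Mapping[str, Variable], transform_metadata: Mapping[str, Any]
) -> None:
    """Raise if any variable is sharding-annotated but no axis name was given."""
    annotated = [name for name, v in variables.items() if v.sharding is not None]
    if annotated and PARTITION_NAME not in transform_metadata:
        raise ValueError(
            f"Variables {annotated} have sharding annotations; pass "
            f"transform_metadata={{{PARTITION_NAME!r}: ...}} to the transform."
        )


def transformed_sharding(
    v: Variable, transform_metadata: Mapping[str, Any], axis: int = 0
) -> Optional[Tuple[Optional[str], ...]]:
    """Return the annotation after the transform adds a new axis at `axis`."""
    if v.sharding is None:
        return None
    spec = list(v.sharding)
    spec.insert(axis, transform_metadata[PARTITION_NAME])
    return tuple(spec)
```

In this sketch, unannotated variables pass through without an error. Only the presence of an annotation makes the metadata mandatory, which matches the stated goal.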
Todo: Can we auto-infer the transform axis name from the annotated variables and raise an error only when inference fails? Would that be better than always requiring it?