NNX WeightNorm
Is there an NNX implementation of the WeightNorm module as used in Linen? I can't seem to find it in the normalization layers implemented in NNX.
I'm in the process of writing a new model in NNX that requires it.
We don't have a WeightNorm implementation in NNX yet. It should be very easy to implement using nnx.state + nnx.statelib.map + nnx.update. Contributions are welcome.
Is it okay if I take a shot at this implementation? I also noticed that NNX lacks Linen's InstanceNorm and SpectralNorm.
@cgarciae I see the old implementation of WeightNorm from Linen here: https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/normalization.html#WeightNorm
Are you suggesting that there's a better way using nnx.state, nnx.statelib.map, and nnx.update?
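For context, the reparameterization that weight normalization applies can be written as follows. Here v is the stored kernel, j indexes the feature axis, the norm runs over all other axes, and g is the per-feature scale. Placing the small constant ε inside the square root is an assumption made for numerical safety and may differ from how the Linen module linked above handles it:

```latex
w_{\dots,j} \;=\; g_j \,\frac{v_{\dots,j}}{\sqrt{\sum_{i} v_{i,j}^{2} + \epsilon}}
\qquad\Longrightarrow\qquad
\bigl\lVert w_{\dots,j} \bigr\rVert_2 \approx \lvert g_j \rvert
```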
I have a branch where I've added the InstanceNorm and SpectralNorm implementations to NNX. @cgarciae do you think it would be worth opening a PR for these?
Any updates on this? @mattbahr have you tried to use the weight norm implementation here?