
NNX WeightNorm

Open rbflx opened this issue 1 year ago • 4 comments

Is there an NNX implementation of the WeightNorm module available in Linen? I can't find it among the normalization layers implemented in NNX.

I'm in the process of writing a new model in NNX that requires weight normalization.

rbflx, Dec 10 '24

We don't have a WeightNorm implementation in NNX yet. It should be straightforward to implement using nnx.state + nnx.statelib.map + nnx.update. Contributions are welcome.
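The reparametrization itself is small: each kernel v is replaced by w = g · v / ‖v‖, where the norm is taken over every axis except the output-feature axis and g is a learned per-feature scale. As a framework-agnostic sketch of the "map over the parameters" step, the code below applies it with plain JAX pytree utilities rather than NNX's API. The dict layout, the helper names `l2_normalize` and `apply_weight_norm`, and the choice of the last axis as the feature axis are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def l2_normalize(v, feature_axis=-1, eps=1e-12):
    """Divide v by its L2 norm taken over every axis except `feature_axis`."""
    feature_axis = feature_axis % v.ndim
    reduce_axes = tuple(a for a in range(v.ndim) if a != feature_axis)
    norm = jnp.sqrt(jnp.sum(v * v, axis=reduce_axes, keepdims=True) + eps)
    return v / norm


def apply_weight_norm(kernels, scales):
    """Map w = g * v / ||v|| over two pytrees that share the same structure."""
    return jax.tree_util.tree_map(lambda v, g: g * l2_normalize(v), kernels, scales)


# Hypothetical two-layer parameter set: directions v and per-feature scales g.
kernels = {
    "dense1": jnp.arange(12.0).reshape(3, 4) + 1.0,  # (in=3, out=4)
    "dense2": jnp.ones((4, 2)),                      # (in=4, out=2)
}
scales = {"dense1": jnp.full((4,), 2.0), "dense2": jnp.ones((2,))}

weights = apply_weight_norm(kernels, scales)
# Each output column of w now has an L2 norm approximately equal to its g.
print(jnp.linalg.norm(weights["dense1"], axis=0))
print(jnp.linalg.norm(weights["dense2"], axis=0))
```

In an NNX module, the kernels would presumably come from the layer's state instead of a hand-built dict; how that maps onto nnx.state and nnx.update is not shown here.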

cgarciae, Dec 10 '24

Is it okay if I take a shot at this implementation? I also see that NNX is missing Linen's InstanceNorm and SpectralNorm.

prakashkagitha, Dec 23 '24

@cgarciae, I found Linen's existing WeightNorm implementation here: https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/normalization.html#WeightNorm

Are you suggesting that there's a better way using nnx.state, nnx.statelib.map, and nnx.update?

mattbahr, Feb 24 '25

I have a branch where I've added InstanceNorm and SpectralNorm implementations to NNX. @cgarciae, do you think it would be worth opening a PR for these?

mattbahr, Mar 11 '25

Any updates on this? @mattbahr, have you tried using the weight norm implementation here?

AshwinSankar17, Aug 25 '25