
Add `BatchRenorm` layer to `linen.normalization`

danielpalenicek opened this issue · 0 comments

I propose adding a batch renormalization (BatchRenorm) layer to flax. I would be happy to make a PR.

BatchRenorm (https://arxiv.org/pdf/1702.03275.pdf) is an improved version of the vanilla BatchNorm layer. The difference from BatchNorm is that, after a warm-up period, the running statistics are used to normalize the batch in both train and eval mode, whereas BatchNorm uses the mini-batch statistics during train mode. This helps to deal with BatchNorm's stability issues during long training runs.
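For illustration, here is a minimal NumPy sketch of the Batch Renormalization forward pass from the paper, not the proposed Flax implementation. The function name and parameter defaults (`rmax`, `dmax`, `momentum`) are hypothetical; the paper treats the correction terms `r` and `d` as constants with respect to the gradient, which a real implementation would enforce with a stop-gradient.

```python
import numpy as np

def batch_renorm(x, running_mean, running_var, rmax=3.0, dmax=5.0,
                 momentum=0.99, eps=1e-5):
    """Sketch of a Batch Renormalization forward pass (training mode).

    Hypothetical helper for illustration; in a real layer, r and d must be
    treated as constants (no gradient flows through them).
    """
    batch_mean = x.mean(axis=0)
    batch_var = x.var(axis=0)
    batch_std = np.sqrt(batch_var + eps)
    running_std = np.sqrt(running_var + eps)
    # Correction terms that re-express the mini-batch normalization in terms
    # of the running statistics; clipped to stabilize early training.
    r = np.clip(batch_std / running_std, 1.0 / rmax, rmax)
    d = np.clip((batch_mean - running_mean) / running_std, -dmax, dmax)
    x_hat = (x - batch_mean) / batch_std * r + d
    # Update running statistics with the usual exponential moving average.
    new_mean = momentum * running_mean + (1 - momentum) * batch_mean
    new_var = momentum * running_var + (1 - momentum) * batch_var
    return x_hat, new_mean, new_var
```

When the running statistics match the batch statistics, `r = 1` and `d = 0`, and the output reduces to the standard BatchNorm normalization.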

Alternatively, the existing BatchNorm layer could be refactored to support renormalization. However, I believe it would be cleaner to put this in a separate BatchRenorm class.

Just recently, BatchRenorm has been shown to yield new state-of-the-art results in deep reinforcement learning (https://openreview.net/pdf?id=PczQtTsTIX), and I believe this might also lead to wider adoption in this community.

danielpalenicek — Apr 03 '24 21:04