
Support for haiku models with non-trainable state

Open · marcociccone opened this issue on Dec 29, 2021 · 2 comments

Hi! Congrats on this great library! I started using it a few days ago and I love it!

Is there any way to use a Haiku model with non-trainable state (e.g. to use batch norm)? I couldn't find a straightforward way, but maybe I'm missing something.

Thanks a lot for your help!

marcociccone · Dec 29, 2021

Thanks for the feedback. Currently, we do not support Haiku models with non-trainable state. Tracking that state across federated rounds is nontrivial, and we have not found a compelling use case for it. If you share your use case, we are happy to see whether there is an alternative way to implement it in fedjax.

stheertha · Jan 4, 2022
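The difficulty of tracking state across rounds can be made concrete with a small sketch. The code below is pure JAX and is not fedjax's API: a hypothetical server step averages client `params` FedAvg-style, weighted by example counts, and must then choose a policy for the non-trainable state, either averaging it as well or leaving each client's statistics local. `weighted_average`, `aggregate_round`, and `state_policy` are illustrative names, not fedjax functions.

```python
from typing import Any, Sequence

import jax
import jax.numpy as jnp

PyTree = Any


def weighted_average(trees: Sequence[PyTree], weights: Sequence[float]) -> PyTree:
    """Leaf-wise weighted mean of pytrees that share the same structure."""
    w = jnp.asarray(weights, dtype=jnp.float32)
    w = w / w.sum()
    return jax.tree_util.tree_map(
        lambda *leaves: sum(wi * leaf for wi, leaf in zip(w, leaves)), *trees)


def aggregate_round(client_params, client_states, num_examples, state_policy):
    """Hypothetical server step: FedAvg on params, configurable policy for state.

    Returns (global_params, per_client_states).
    """
    global_params = weighted_average(client_params, num_examples)
    if state_policy == "average":
        # One shared set of statistics for all clients. Averaging per-client
        # variances ignores the spread between client means, so it is only an
        # approximation of the pooled variance.
        shared = weighted_average(client_states, num_examples)
        return global_params, [shared] * len(client_states)
    if state_policy == "local":
        # Each client keeps its own statistics (client-specific normalization).
        return global_params, list(client_states)
    raise ValueError(f"unknown state_policy: {state_policy!r}")


# Toy usage with two clients holding 1 and 3 examples.
params = [{"w": jnp.array([1.0, 2.0])}, {"w": jnp.array([3.0, 4.0])}]
states = [{"mean": jnp.array([0.0])}, {"mean": jnp.array([4.0])}]
global_params, avg_states = aggregate_round(params, states, [1, 3], "average")
_, local_states = aggregate_round(params, states, [1, 3], "local")
```

Neither policy is obviously right: averaging blurs genuinely different client distributions, while keeping state local means the global model alone is not enough to run inference on a new client.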

Thanks @stheertha! I agree that tracking statistics is nontrivial in FL. For the moment I've worked around the issue by replacing BatchNorm with GroupNorm, and it seems to be working fine. There might be cases, though, where you want to use only client-specific statistics.

marcociccone · Jan 5, 2022