Davis Yoshida
@nestordemeure In case you're interested, I have a first draft of a port to JAX/Haiku [here](https://github.com/davisyoshida/haiku-mup). If you're not attached to Flax in particular, you could use this. You could...
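For anyone who wants a quick look before clicking through, here's roughly what usage looks like, sketched from memory of that draft's README; the `Mup`/`apply_mup` names and the `init_base`/`init_target`/`wrap_model`/`wrap_optimizer` methods come from that early draft and may have changed since:

```python
from functools import partial

import jax
import jax.numpy as jnp
import haiku as hk
import optax
from haiku_mup import Mup, apply_mup

def fn(x, width=100):
    with apply_mup():  # marks parameter creation so muP can rescale it
        x = hk.Linear(width)(x)
        x = jax.nn.relu(x)
        return hk.Linear(10)(x)

mup = Mup()

# Initialize a width-1 "base" model so muP can infer which dims scale with width.
with mup.init_base():
    hk.transform(partial(fn, width=1)).init(jax.random.PRNGKey(0), jnp.zeros(1))

# Initialize the real ("target") model; muP records its shapes for rescaling.
model = hk.transform(fn)
with mup.init_target():
    params = model.init(jax.random.PRNGKey(0), jnp.zeros(1))

model = mup.wrap_model(model)  # needed when using apply_mup()
optimizer = mup.wrap_optimizer(optax.adam(3e-4), adaptive=True)  # adaptive=True for Adam-family
```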
Ah that makes perfect sense, I'll generate new versions of the figures. Thanks!
@n2cholas Ah that's a good idea, I'll look into it
I replied over on the discussion post: https://github.com/google/flax/discussions/3586
> the Triton version we have internally crashes if any of the dot operands is transposed

Good to know, I'm sure that just saved me an hour of fruitless debugging....
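For anyone hitting the same crash, here's a minimal single-block sketch (no grid or blocking) of the workaround I'd try: materialize the transpose outside the kernel so the in-kernel dot only ever sees untransposed operands. `matmul_xt` is a hypothetical helper name, not anything from the library:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def matmul_kernel(x_ref, y_ref, o_ref):
    # Keep the in-kernel dot untransposed; transposing an operand here
    # is what reportedly crashes the Triton lowering.
    o_ref[...] = pl.dot(x_ref[...], y_ref[...])

@jax.jit
def matmul_xt(x, y):
    # Compute x.T @ y by transposing *outside* the kernel.
    xt = jnp.swapaxes(x, 0, 1)
    return pl.pallas_call(
        matmul_kernel,
        out_shape=jax.ShapeDtypeStruct((x.shape[1], y.shape[1]), x.dtype),
    )(xt, y)

x = jnp.ones((64, 128), jnp.float32)
y = jnp.ones((64, 32), jnp.float32)
out = matmul_xt(x, y)  # shape (128, 32)
```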
@superbobry The monkeypatch works for me, and thanks for making it so much easier to get Pallas up and running on GPU.
+1 as well. I spent quite a while trying different combinations of package versions, but couldn't find one that worked.
Worked for me as well, thanks!
Just ran into this while using JAX's Pallas on a Quadro RTX 6000 with CUDA 12.3.
@pidajay The issue here is that `get_model` is only called once. You want to apply the `@checkpointable` decorator to a callable that will actually be run during the training loop....
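A quick sketch of the distinction, with `checkpointable` stubbed out since only the placement matters here:

```python
def checkpointable(f):
    # Stub standing in for the real decorator, just to make this runnable;
    # the point is *where* it's applied, not what it does.
    return f

def get_model():
    # Called exactly once at setup, so decorating it (or code inside it)
    # never fires again during training.
    return lambda params, batch: params

@checkpointable  # decorate the callable that runs every step instead
def train_step(params, batch):
    return params - 0.1 * batch  # placeholder update

model = get_model()
params = 1.0
for batch in [0.1, 0.2, 0.3]:  # toy training loop
    params = train_step(params, batch)
```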