Davis Yoshida

29 comments by Davis Yoshida

@nestordemeure In case you're interested, I have a first draft of a port to JAX/Haiku [here](https://github.com/davisyoshida/haiku-mup). If you're not attached to Flax in particular, you could use this. You could...

Ah that makes perfect sense, I'll generate new versions of the figures. Thanks!

@n2cholas Ah that's a good idea, I'll look into it

I replied over on the discussion post: https://github.com/google/flax/discussions/3586

> the Triton version we have internally crashes if any of the dot operands is transposed

Good to know, I'm sure that just saved me an hour of fruitless debugging....

@superbobry The monkeypatch works for me. Thanks for making it so much easier to get Pallas up and running on GPU!

+1 as well. I spent quite a while installing different combinations of package versions, but I couldn't find one that worked.

Worked for me as well, thanks!

Just ran into this while using JAX's Pallas on a Quadro RTX 6000 with CUDA 12.3.

@pidajay The issue here is that get_model is just called once. You want to apply the `@checkpointable` decorator to a callable that will actually be run during the training loop....
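To illustrate the point about where the decorator belongs, here is a minimal sketch. The library that provides `@checkpointable` isn't identified in this thread, so the sketch swaps in JAX's own `jax.checkpoint` transform as an analogue. The model factory (`get_model`) runs once at setup, so the checkpointing transform is applied to the inner function it returns (`block`), which is the function actually executed on every training step. `block`, `loss_fn`, and the parameter layout are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp


def get_model(width):
    """Hypothetical model factory: called once, at setup time."""
    key = jax.random.PRNGKey(0)
    params = {
        "w1": jax.random.normal(key, (width, width)) / jnp.sqrt(width),
        "w2": jax.random.normal(jax.random.fold_in(key, 1), (width, 1)) / jnp.sqrt(width),
    }

    # This inner function runs on every forward pass during training, so it
    # is the callable that gets the checkpointing (rematerialization)
    # transform. Decorating get_model instead would only affect setup.
    @jax.checkpoint
    def block(params, x):
        h = jnp.tanh(x @ params["w1"])  # recomputed in the backward pass
        return h @ params["w2"]

    return params, block


def loss_fn(params, block, x, y):
    pred = block(params, x)
    return jnp.mean((pred - y) ** 2)


params, block = get_model(8)
x = jnp.ones((4, 8))
y = jnp.zeros((4, 1))

# Differentiate only with respect to params; block, x, y are closed over.
grads = jax.grad(lambda p: loss_fn(p, block, x, y))(params)
```

Checkpointing only trades compute for memory in the backward pass, so the gradients should match an un-checkpointed version of the same computation.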