gabeweisz
I'm seeing the same error. Commit 97ae4b672f0a9d8bc30ab536d4bac42c3d044aff works for me on GPU.
For the version of the repo I pointed you to, it works for me with JAX 0.4.25, flax==0.8.2, and chex==0.1.86. I'm not part of your Google collaboration, but...
I used commit 97ae4b6 without any changes. I installed packages from the requirements.txt in that commit, then manually updated the two packages mentioned above using...
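To confirm that an environment matches the versions reported as working above, a small check like the following might help. This is only a sketch: `KNOWN_GOOD` and `find_mismatches` are hypothetical names for illustration and are not part of the repo. The only library it relies on is the standard `importlib.metadata` module.

```python
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version

# Versions reported as working with commit 97ae4b6 in the comment above.
# Hypothetical helper data: adjust the pins for your own setup.
KNOWN_GOOD = {
    "jax": "0.4.25",
    "flax": "0.8.2",
    "chex": "0.1.86",
}


def find_mismatches(expected: dict[str, str]) -> dict[str, tuple[str, str | None]]:
    """Return {package: (expected, installed)} for every package whose installed
    version differs from the expected one; `installed` is None if it is missing."""
    mismatches = {}
    for name, wanted in expected.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None
        # Exact string comparison: "0.4.25" and "0.4.25.post1" count as different.
        if installed != wanted:
            mismatches[name] = (wanted, installed)
    return mismatches


if __name__ == "__main__":
    problems = find_mismatches(KNOWN_GOOD)
    if not problems:
        print("All pinned packages match.")
    for name, (wanted, installed) in problems.items():
        print(f"{name}: expected {wanted}, found {installed}")
```

Running it before launching a job shows which of the pins, if any, was not applied after installing from requirements.txt.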
The newest commit (https://github.com/LargeWorldModel/LWM/commit/b8e36023d17d965f40071dcbc6dcdb1865d84a49) fixes this error for me.
Instead of checking out the repo, you can also run `pip install git+https://github.com/state-spaces/mamba.git` to check out, build, and install in one step.
I ran this successfully on the `rocm/pytorch:latest` Docker image. Can you try it there?
I wrote a script to look at the checkpoint that we generated and compare it to the original data. What I found (at least in my initial look) is that...
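The script itself isn't shown here. As a rough sketch of the kind of comparison described, assume the checkpoint has already been restored into a nested dict of arrays (the restore step, e.g. via Orbax, is deliberately left out). `compare_trees` is a hypothetical helper, not the script referred to above.

```python
import numpy as np


def compare_trees(restored, original, path=""):
    """Recursively compare two nested dicts of arrays.

    Returns a list of strings describing each mismatch: missing or extra keys,
    shape or dtype differences, and element-wise value differences.
    Comparison is exact, so NaNs are reported as differing elements.
    """
    problems = []
    if isinstance(original, dict):
        if not isinstance(restored, dict):
            return [f"{path or '<root>'}: expected a dict, got {type(restored).__name__}"]
        for key in sorted(set(original) | set(restored), key=str):
            sub = f"{path}/{key}"
            if key not in restored:
                problems.append(f"{sub}: missing from restored checkpoint")
            elif key not in original:
                problems.append(f"{sub}: unexpected extra key")
            else:
                problems.extend(compare_trees(restored[key], original[key], sub))
        return problems

    # Leaf: treat both sides as arrays and compare metadata, then values.
    a = np.asarray(restored)
    b = np.asarray(original)
    if a.shape != b.shape:
        problems.append(f"{path}: shape {a.shape} != {b.shape}")
    elif a.dtype != b.dtype:
        problems.append(f"{path}: dtype {a.dtype} != {b.dtype}")
    elif not np.array_equal(a, b):
        n_bad = int(np.count_nonzero(a != b))
        problems.append(f"{path}: {n_bad} of {a.size} elements differ")
    return problems


if __name__ == "__main__":
    original = {"params": {"w": np.ones((2, 2)), "b": np.zeros(3)}}
    restored = {"params": {"w": np.ones((2, 2)), "b": np.array([0.0, 1.0, 0.0])}}
    print(compare_trees(restored, original))
    # prints ['/params/b: 1 of 3 elements differ']
```

Reporting paths rather than a single pass/fail makes it easier to see whether only some shards or parameters came back wrong, which is what you would expect if some hosts wrote to a different filesystem.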
We ran the script in much the same way you did; my colleague Gowtham shared what we did earlier. When we ran this, we didn't...
We just found where the Orbax documentation says that all nodes need to [write to the same filesystem](https://orbax.readthedocs.io/en/latest/_modules/orbax/checkpoint/type_handlers.html#merge_ocdbt_per_process_files), which explains what went wrong for us.
Please go ahead and resolve the ticket.