Colin Gaffney

Results: 54 comments by Colin Gaffney

I recognize that it's not the most convenient solution, but you could also implement a TypeHandler to deal with this. Would be a relatively simple override of the existing NumpyHandler.

Thanks for reporting, we're looking into some refactoring that will resolve these empty node issues.

pmap arrays are always difficult to work with. A few things I would try: 1. Use pjit instead of pmap. Probably not a reasonable suggestion for me to make, but...
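As a rough illustration of the first suggestion: in recent JAX releases `pjit` has been merged into `jax.jit`, so the sketch below uses `jax.jit` with an explicit `NamedSharding`. The mesh axis name `"data"`, the toy `scale` function, and the array shapes are all assumptions made for illustration and are not taken from the original thread.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh over whatever devices are available (possibly a single CPU).
# The axis name "data" is a hypothetical choice for this sketch.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Shard the leading (batch) axis across the mesh, similar in spirit to pmap.
batch_sharding = NamedSharding(mesh, P("data"))

n = devices.size
x = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)
x = jax.device_put(x, batch_sharding)

@jax.jit
def scale(a):
    # Toy computation; jit propagates the input sharding to the output.
    return a * 2.0

y = scale(x)
```

The result `y` is a single global `jax.Array` of shape `(n, 4)` rather than a per-device stacked array, which tends to be easier to hand to checkpointing code.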

Hi Simon, This seems quite reasonable to me. We don't have a mechanism set up for mirroring code from external to internal, but probably an external contributor could just submit...

We're working on a fix for this; unfortunately, the sharding metadata doesn't work that well in every case yet. If you must call `metadata`, just delete the sharding file and...

Hi, apologies for the long delay on this - we concluded that using `jax.Sharding` directly in the metadata was a bad decision from the start, since it can't always be...

Also note: prefer to specify the shardings for your tree in `args=StandardRestore()` whenever possible. Either that or specify the `restore_type` as `np.ndarray`. https://orbax.readthedocs.io/en/latest/checkpointing_pytrees.html

The warning looks to be erroneous; it probably has to do with your `state` containing some empty nodes, or similar. I'd say don't worry about it. We're doing some refactoring that should...

* If you are _not_ using OCDBT, it is expected that checkpoints with more parameters take longer to load. So for a model with stacked weights, it should be much...

It may have to do with loading to GPU, which is not something that is well tested. Could you provide a minimal repro, along with details about the environment you're...