
Flax NNX and Orbax Checkpointing require hacking to work together

Open hdrwilkinson opened this issue 1 year ago • 6 comments

I'm building a system using flax.nnx and orbax.checkpoint. However, saving and restoring models is overly complicated because flax.nnx uses the new typed keys from jax.random.key() rather than the legacy jax.random.PRNGKey().
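A quick way to see the difference, assuming the default threefry PRNG implementation: the typed key has an extended dtype (key<fry>), while the legacy key is an ordinary uint32 array. The extended dtype appears to be what the serializer trips over, judging by the "Cannot interpret 'key<fry>' as a data type" error reported further down.

```python
import jax
import jax.numpy as jnp

typed_key = jax.random.key(0)        # new-style typed key, as used by flax.nnx
legacy_key = jax.random.PRNGKey(0)   # legacy raw key: a plain uint32 array

print(typed_key.dtype, typed_key.shape)    # key<fry> ()
print(legacy_key.dtype, legacy_key.shape)  # uint32 (2,)

# Both keys carry the same underlying bits; only the dtype wrapper differs.
print(bool(jnp.array_equal(jax.random.key_data(typed_key), legacy_key)))  # True
```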

I have had to create a workaround where every layer with rng and key in its path is converted from dtype=key<fry> to a format suitable for saving. Then, upon restoration, it has to be changed back.
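The conversion described above boils down to two JAX functions, jax.random.key_data and jax.random.wrap_key_data. Here is a minimal sketch on a single key (not the notebook's code), assuming the default PRNG implementation so that wrap_key_data rebuilds the same kind of key that was saved:

```python
import jax
import jax.numpy as jnp

key = jax.random.key(42)                   # typed key, dtype key<fry>
raw = jax.random.key_data(key)             # uint32 array any serializer can store
rewrapped = jax.random.wrap_key_data(raw)  # back to a typed key after loading

# The re-wrapped key produces exactly the same random numbers as the original.
a = jax.random.normal(key, (3,))
b = jax.random.normal(rewrapped, (3,))
print(raw.dtype, bool(jnp.array_equal(a, b)))  # uint32 True
```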

I am attaching a link to a notebook explaining what I've done, but I would be keen to hear if there are simpler workarounds. Or, preferably, is there a way to simply save and restore models?

https://colab.research.google.com/drive/1ozln9ejG7eRtxvbkqHYU3K6OyPvveH9w?usp=sharing

Note: I am also opening an issue on Orbax (#1337) to see if there is a fix on their side.

hdrwilkinson, Nov 15 '24

Thank you! I'll contact the Orbax team to see if they can fix this on their end.

cgarciae, Nov 15 '24

Hey! Here's a quick and dirty workaround.

Generally, the idea is to use nnx.split with NNX filters to separate the nnx.RngState variables from the rest of the state and then not save those:

graphdef, rng_state, other_state = nnx.split(model, nnx.RngState, ...)

and then save only other_state instead of the full state. I've edited your Colab notebook to demonstrate this.

This means that RNG state will not be restored, which might be sub-optimal for certain scenarios but should work for most stuff. Hope it helps!

mishmish66, Nov 17 '24

Another workaround for Dropout layers, and maybe custom layers too if they follow the same pattern, is to initialize them without the rngs arg, and only pass the RNG at __call__ time, like so:

import jax.numpy as jnp
import orbax.checkpoint as ocp
from flax import nnx

# Init dropout without rng arg.
model = nnx.Dropout(0.5)

# Pass RNG at call time.
output = model(jnp.ones(()), rngs=nnx.Rngs(0))

# This now works.
ckpt_dir = ocp.test_utils.erase_and_create_empty("/tmp/my-checkpoints/")
checkpointer = ocp.StandardCheckpointer()
checkpointer.save(ckpt_dir / "state", nnx.split(model)[1])

By contrast, if the RNG is supplied at initialization, the last line raises the following:

TypeError: Cannot interpret 'key<fry>' as a data type

But this is only a workaround: the RNG state is still not serialized, and the call signature becomes more verbose.

jkyl, Nov 25 '24

Here is a newer way to work around this:

from copy import deepcopy

import jax
from flax import nnx


def key_uint32(state):
    """Convert jax.random.key leaves into uint32 arrays for saving."""
    # Work on a copy so the live model state is never modified.
    dict_state = deepcopy(nnx.to_pure_dict(state))

    def traverse_dict(d):
        for k, v in d.items():
            if isinstance(v, dict):
                traverse_dict(v)
            elif k == 'key':
                # Typed key (dtype key<fry>) -> raw uint32 key data.
                d[k] = jax.random.key_data(v)

    traverse_dict(dict_state)
    return dict_state


def uint32_key(dict_state):
    """Convert uint32 arrays back into jax.random.key leaves (in place)."""

    def traverse_dict(d):
        for k, v in d.items():
            if isinstance(v, dict):
                traverse_dict(v)
            elif k == 'key':
                # Raw uint32 key data -> typed key (default PRNG implementation).
                d[k] = jax.random.wrap_key_data(v)

    traverse_dict(dict_state)
    return dict_state

This is how I restore the checkpoint:

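from flax import nnx

# create_model and checkpoint_manager come from the surrounding training code.
abstract_model = nnx.eval_shape(create_model)
graphdef, abstract_state = nnx.split(abstract_model)

restored = checkpoint_manager.restore(checkpoint_manager.best_step())

# Turn the saved uint32 arrays back into typed keys, then fill the abstract state.
dict_state = uint32_key(restored['state'])
nnx.replace_by_pure_dict(abstract_state, dict_state)
model1 = nnx.merge(graphdef, abstract_state)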

In my case, I save with ocp.args.Composite, and I couldn't get it to load only the state by passing ocp.args.Composite.

jasonzhang2022, Oct 05 '25

Could this bug be given high priority so that it gets fixed soon?

Impact: This issue affects any model configured with dropout. Consequence: It prevents the automatic saving of any non-trivial model.

jasonzhang2022, Oct 08 '25

It works for me sometimes, and sometimes not.

The following works if run outside of the "with jax.set_mesh(...)" block:

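import grain.python as pygrain
import orbax.checkpoint as ocp

# Save model state, metrics and the data-iterator position together.
checkpoint_manager.save(
    step,
    metrics=metrics,
    args=ocp.args.Composite(
        # Save parameters with StandardSave
        state=ocp.args.StandardSave(state),
        metrics=ocp.args.JsonSave(metrics),
        data_idx=pygrain.PyGrainCheckpointSave(data_iter),
    ),
)

# Restore all three items from the latest step.
restored = checkpoint_manager.restore(
    checkpoint_manager.latest_step(),
    args=ocp.args.Composite(
        state=ocp.args.StandardRestore(),
        metrics=ocp.args.JsonRestore(),
        # Restores the iterator position into restored_train_iter.
        data_idx=pygrain.PyGrainCheckpointRestore(restored_train_iter),
    ),
)
state1 = restored['state']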

jasonzhang2022, Oct 18 '25