Struggling to restore metadata on another device
Hello,
I am trying to load metadata from a checkpoint on a new device via the CheckpointManager API, but I am struggling to find a solution. Below is a minimal example of what I am trying to do.
First I do "training" on GPU by running:
from orbax import checkpoint as ocp
import pathlib
import jax.numpy as jnp

ckpt_dir = pathlib.Path('.').expanduser().absolute()
ckpt_mngr = ocp.CheckpointManager(
    ocp.test_utils.create_empty(ckpt_dir / 'checkpoints'),
    item_names=('params',),
)
params = {'a': jnp.array([1.])}
for i in jnp.arange(10):
    ckpt_mngr.save(
        i,
        args=ocp.args.Composite(params=ocp.args.StandardSave(params)),
    )
I then copy the checkpoint to my local machine, which has only a CPU available. When I try to get the metadata, I see the following behaviour.
# Load with the old API
from orbax import checkpoint as ocp
import pathlib
import jax.numpy as jnp

ckpt_dir = pathlib.Path('.').expanduser().absolute()
ckpt_load = ocp.CheckpointManager(
    ckpt_dir / 'checkpoints',
    {'params': ocp.PyTreeCheckpointer()},
)
latest_step = ckpt_load.latest_step()
ckpt_load.item_metadata(0)
This gives me:
File ~/Documents/venvs/mlff/lib/python3.9/site-packages/orbax/checkpoint/type_handlers.py:164, in _deserialize_sharding_from_json_string(sharding_string)
159 if device := _deserialize_sharding_from_json_string.device_map.get(
160 device_str, None
161 ):
162 return SingleDeviceSharding(device)
--> 164 raise ValueError(
165 f'{ShardingTypes.SINGLE_DEVICE_SHARDING.value} with'
166 f' Device={device_str} was not found in jax.local_devices().'
167 )
169 else:
170 raise NotImplementedError(
171 'Sharding types other than `jax.sharding.NamedSharding` have not been '
172 'implemented.'
173 )
ValueError: SingleDeviceSharding with Device=cuda:0 was not found in jax.local_devices().
Does that mean that metadata calls only work as long as I am on the same kind of device? How else could I get the pytree structure without calling model.init itself? Following this issue https://github.com/google/orbax/issues/648, I can delete the _sharding file and then restore the metadata by setting restore_kwargs appropriately. However, this only works with the old API (see below) and seems a bit hacky, so I feel I am doing something wrong here. Using the new API,
# Load with the new API
ckpt_dir = pathlib.Path('.').expanduser().absolute()
ckpt_load = ocp.CheckpointManager(
    ckpt_dir / 'checkpoints',
    item_names=('params',),
)
latest_step = ckpt_load.latest_step()
ckpt_load.item_metadata(0)
I get
CompositeArgs({})
so no metadata at all.
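For completeness, here is roughly what the issue-648 workaround looks like on my side (just a sketch; the glob pattern for the _sharding files is my assumption about the checkpoint layout):

# Workaround sketch (old API): delete the per-item _sharding files so that
# metadata reads no longer try to look up cuda:0 on this CPU-only host.
import pathlib
from orbax import checkpoint as ocp

ckpt_dir = pathlib.Path('.').expanduser().absolute() / 'checkpoints'
for sharding_file in ckpt_dir.glob('*/params/_sharding'):
    sharding_file.unlink()

ckpt_load = ocp.CheckpointManager(
    ckpt_dir,
    {'params': ocp.PyTreeCheckpointer()},
)
tree_metadata = ckpt_load.item_metadata(ckpt_load.latest_step())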
We're working on a fix for this; unfortunately the sharding metadata doesn't yet work well in every case. If you must call metadata, just delete the sharding file and continue using the old API for now.
Hi @cpgaffney1! Are there any updates on this matter?
I found a PR with a fix for metadata reading, but it has not been updated since January 17: https://github.com/google/orbax/pull/671
Thanks, Simon
Hi, apologies for the long delay on this - we concluded that using jax.Sharding directly in the metadata was a bad decision from the start, since it can't always be loaded correctly. We're adding a new representation of the sharding metadata that doesn't try to interact directly with real devices. You can track changes here: https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/sharding_metadata.py (the latest change doesn't have an external pull request yet). I expect this to be fixed this week or next.
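Roughly, the idea is that the serialized metadata keeps only device-independent information (for example a device string) and only converts back to a jax.sharding.Sharding when explicitly asked to on the current host. A purely illustrative sketch, not the actual classes in sharding_metadata.py:

# Illustrative only (not the real orbax API): metadata stores the device as a
# plain string, so reading it never requires the device to exist locally.
import dataclasses
import jax

@dataclasses.dataclass
class SingleDeviceShardingMetadataSketch:
    device_str: str  # e.g. 'cuda:0', kept as text in the checkpoint

    def to_jax_sharding(self) -> jax.sharding.SingleDeviceSharding:
        # Only this conversion touches real devices; metadata reads do not.
        for d in jax.local_devices():
            if str(d) == self.device_str:
                return jax.sharding.SingleDeviceSharding(d)
        raise ValueError(f'{self.device_str} is not present on this host')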
cc @liangyaning33 who is working on the implementation.
Hi, sorry about the delay. The issue is now fixed. Can you please try again? Thanks!
Hi, I ran into a similar issue.
I save my checkpoints with metrics, train only on CPU, and then want to load a checkpoint on the same machine. But somehow it looks for a cuda:0 device when reading the metadata. Any help would be greatly appreciated!
Error:
ValueError: SingleDeviceSharding with Device=cuda:0 was not found in jax.local_devices().
Checkpoint Manager creation:
options = CheckpointManagerOptions(
    best_fn=lambda metrics: metrics["metric1"],
    best_mode="min",
    max_to_keep=1,
    save_interval_steps=1,
)
checkpoint_manager = CheckpointManager(
    directory=checkpoint_dir,
    options=options,
)
Checkpoint saving:
for step in tbar:
    train_batch = generate_batch(datamodule, "train")
    valid_batch = generate_batch(datamodule, "valid")
    state_neural_net, current_logs = step_fn(
        state_neural_net, train_batch, valid_batch
    )
    ckpt = state_neural_net
    checkpoint_manager.save(
        step,
        args=StandardSave(ckpt),
        metrics={
            "metric1": float(metric1),
            "metric2": float(metric2),
            "metric3": float(metric3),
        },
    )
checkpoint_manager.wait_until_finished()
And then to load the checkpoint:
# Sets up the Ckpt manager as described above
out_class = cls(
    jobid=jobid,
    logger_path=logger_path,
    config=config,
    datamodule=datamodule,
)
if step is None:
    # Only checks steps with metrics available
    step = out_class.checkpoint_manager.best_step()
out_class.neural_net = out_class.checkpoint_manager.restore(
    step, args=StandardRestore()
)
But I get the following output and error:
File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py:951, in CheckpointManager.restore(self, step, items, restore_kwargs, directory, args)
948 args = typing.cast(args_lib.Composite, args)
950 restore_directory = self._get_read_step_directory(step, directory)
--> 951 restored = self._checkpointer.restore(restore_directory, args=args)
952 if self._single_item:
953 return restored[DEFAULT_ITEM_NAME]
File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/async_checkpointer.py:338, in AsyncCheckpointer.restore(self, directory, *args, **kwargs)
336 """See superclass documentation."""
337 self.wait_until_finished()
--> 338 return super().restore(directory, *args, **kwargs)
File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py:168, in Checkpointer.restore(self, directory, *args, **kwargs)
166 logging.info('Restoring item from %s.', directory)
167 ckpt_args = construct_checkpoint_args(self._handler, False, *args, **kwargs)
--> 168 restored = self._handler.restore(directory, args=ckpt_args)
169 logging.info('Finished restoring checkpoint from %s.', directory)
170 return restored
File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/composite_checkpoint_handler.py:464, in CompositeCheckpointHandler.restore(self, directory, args)
462 continue
463 handler = self._get_or_set_handler(item_name, arg)
--> 464 restored[item_name] = handler.restore(
465 self._get_item_directory(directory, item_name), args=arg
466 )
467 return CompositeResults(**restored)
File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/standard_checkpoint_handler.py:166, in StandardCheckpointHandler.restore(self, directory, item, args)
163 restore_args = checkpoint_utils.construct_restore_args(args.item)
164 else:
165 restore_args = checkpoint_utils.construct_restore_args(
--> 166 self.metadata(directory)
167 )
168 return super().restore(
169 directory,
170 args=pytree_checkpoint_handler.PyTreeRestoreArgs(
171 item=args.item, restore_args=restore_args
172 ),
173 )
File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py:1453, in PyTreeCheckpointHandler.metadata(self, directory)
1427 """Returns tree metadata.
1428
1429 The result will be a PyTree matching the structure of the saved checkpoint.
(...)
1450 tree containing metadata.
1451 """
1452 try:
-> 1453 return self._get_user_metadata(directory)
1454 except FileNotFoundError as e:
1455 raise FileNotFoundError('Could not locate metadata file.') from e
File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py:1418, in PyTreeCheckpointHandler._get_user_metadata(self, directory)
1415 async def _get_metadata():
1416 return await asyncio.gather(*metadata_ops)
-> 1418 batched_metadatas = asyncio.run(_get_metadata())
1419 for keypath_batch, metadata_batch in zip(
1420 batched_keypaths.values(), batched_metadatas
1421 ):
1422 for keypath, value in zip(keypath_batch, metadata_batch):
File /.conda/envs/condreq/lib/python3.10/site-packages/nest_asyncio.py:30, in _patch_asyncio.<locals>.run(main, debug)
28 task = asyncio.ensure_future(main)
29 try:
---> 30 return loop.run_until_complete(task)
31 finally:
32 if not task.done():
File /.conda/envs/condreq/lib/python3.10/site-packages/nest_asyncio.py:98, in _patch_loop.<locals>.run_until_complete(self, future)
95 if not f.done():
96 raise RuntimeError(
97 'Event loop stopped before Future completed.')
---> 98 return f.result()
File /.conda/envs/condreq/lib/python3.10/asyncio/futures.py:201, in Future.result(self)
199 self.__log_traceback = False
200 if self._exception is not None:
--> 201 raise self._exception.with_traceback(self._exception_tb)
202 return self._result
File /.conda/envs/condreq/lib/python3.10/asyncio/tasks.py:234, in Task.__step(***failed resolving arguments***)
232 result = coro.send(None)
233 else:
--> 234 result = coro.throw(exc)
235 except StopIteration as exc:
236 if self._must_cancel:
237 # Task is cancelled right before coro stops.
File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py:1416, in PyTreeCheckpointHandler._get_user_metadata.<locals>._get_metadata()
1415 async def _get_metadata():
-> 1416 return await asyncio.gather(*metadata_ops)
File /.conda/envs/condreq/lib/python3.10/asyncio/tasks.py:304, in Task.__wakeup(self, future)
302 def __wakeup(self, future):
303 try:
--> 304 future.result()
305 except BaseException as exc:
306 # This may also be a cancellation.
307 self.__step(exc)
File /.conda/envs/condreq/lib/python3.10/asyncio/tasks.py:232, in Task.__step(***failed resolving arguments***)
228 try:
229 if exc is None:
230 # We use the `send` method directly, because coroutines
231 # don't have `__iter__` and `__next__` methods.
--> 232 result = coro.send(None)
233 else:
234 result = coro.throw(exc)
File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py:1403, in ArrayHandler.metadata(self, infos)
1401 shardings.append(None)
1402 continue
-> 1403 deserialized = _deserialize_sharding_from_json_string(
1404 sharding_string.item()
1405 )
1406 shardings.append(deserialized or None)
1407 else:
File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py:166, in _deserialize_sharding_from_json_string(sharding_string)
161 if device := _deserialize_sharding_from_json_string.device_map.get(
162 device_str, None
163 ):
164 return SingleDeviceSharding(device)
--> 166 raise ValueError(
167 f'{ShardingTypes.SINGLE_DEVICE_SHARDING.value} with'
168 f' Device={device_str} was not found in jax.local_devices().'
169 )
171 else:
172 raise NotImplementedError(
173 'Sharding types other than `jax.sharding.NamedSharding` have not been '
174 'implemented.'
175 )
ValueError: SingleDeviceSharding with Device=cuda:0 was not found in jax.local_devices().
Versions: flax==0.7.4 jax==0.4.20 jaxlib==0.4.20+cuda12.cudnn89 optax==0.1.9 orbax-checkpoint==0.5.7
UPDATE, SOLVED: for me it was solved by downgrading nvidia-cudnn-cu12 from 9.1.0.70 to a version matching jaxlib-0.4.20+cuda12.cudnn89, i.e. pip install nvidia-cudnn-cu12==8.9.7.29.
Also note: prefer to specify the shardings for your tree in args=StandardRestore() whenever possible, or alternatively specify the restore_type as np.ndarray. See https://orbax.readthedocs.io/en/latest/checkpointing_pytrees.html
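For example, something along these lines on a CPU-only host (a sketch only; state_neural_net stands in for any template tree with the right structure and dtypes, and step comes from best_step() as above):

# Sketch: avoid the saved cuda:0 sharding by telling Orbax what to restore to.
import numpy as np
import jax
from orbax import checkpoint as ocp

# Option 1: give StandardRestore an abstract tree that carries CPU shardings.
cpu = jax.sharding.SingleDeviceSharding(jax.devices('cpu')[0])
abstract_state = jax.tree_util.tree_map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=cpu),
    state_neural_net,  # template tree with the desired structure
)
restored = checkpoint_manager.restore(
    step, args=ocp.args.StandardRestore(abstract_state)
)

# Option 2: restore every leaf as a plain numpy array instead.
restore_args = jax.tree_util.tree_map(
    lambda _: ocp.RestoreArgs(restore_type=np.ndarray), abstract_state
)
restored = checkpoint_manager.restore(
    step, args=ocp.args.PyTreeRestore(restore_args=restore_args)
)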