
Restoring flax model checkpoints using orbax throws ValueError #1305

Open
ybangaru opened this issue Nov 7, 2024 · 3 comments

Comments


ybangaru commented Nov 7, 2024

The following code blocks are used to save the model's train state during training and to restore that state back into memory.
Version being used:

orbax-checkpoint                   0.8.0

from flax.training import orbax_utils
import orbax.checkpoint

directory_gen_path = "checkpoints_loc"
orbax_checkpointer_gen = orbax.checkpoint.PyTreeCheckpointer()
gen_options = orbax.checkpoint.CheckpointManagerOptions(save_interval_steps=5, create=True)
gen_checkpoint_manager = orbax.checkpoint.CheckpointManager(
    directory_gen_path, orbax_checkpointer_gen, gen_options
)

def save_model_checkpoints(step_, generator_state, generator_batch_stats):

    gen_ckpt = {
        "model": generator_state,
        "batch_stats": generator_batch_stats,
    }

    save_args_gen = orbax_utils.save_args_from_target(gen_ckpt)
    gen_checkpoint_manager.save(step_, gen_ckpt, save_kwargs={"save_args": save_args_gen})

def load_model_checkpoints(generator_state, generator_batch_stats):
    gen_target = {
        "model": generator_state,
        "batch_stats": generator_batch_stats,
    }

    latest_step = gen_checkpoint_manager.latest_step()
    gen_ckpt = gen_checkpoint_manager.restore(latest_step, items=gen_target)
    generator_state = gen_ckpt["model"]
    generator_batch_stats = gen_ckpt["batch_stats"]

    return generator_state, generator_batch_stats


The model was trained on a GPU, and loading the state onto a GPU device works fine. However, when trying to load the model onto the CPU, the orbax checkpoint manager's restore method throws the following error:

/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/type_handlers.py:1386: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.
  warnings.warn(
ERROR:root:Device cuda:0 was not found in jax.local_devices().
ERROR:root:Device cuda:0 was not found in jax.local_devices()
.......
......
......

  File "/user/yashbangaru/simgan/pysrc/package/model_handlers.py", line 453, in load_model_checkpoints
    gen_ckpt = self.gen_checkpoint_manager.restore(latest_step, items=gen_target)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1356, in restore
    restored = self._checkpointer.restore(restore_directory, args=args)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/checkpointer.py", line 239, in restore
    restored = self._handler.restore(directory, args=ckpt_args)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py", line 811, in restore
    restored[item_name] = handler.restore(
                          ^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py", line 769, in restore
    return self._handler_impl.restore(directory, args=args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py", line 694, in restore
    tree_memory_size, restored_item = asyncio_utils.run_sync(
                                      ^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/asyncio_utils.py", line 50, in run_sync
    return asyncio.run(coro)
           ^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/asyncio/runners.py", line 194, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/asyncio/base_events.py", line 687, in run_until_complete
    return future.result()
           ^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py", line 551, in _maybe_deserialize
    deserialized_batches += await asyncio.gather(*deserialized_batches_ops)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/type_handlers.py", line 1442, in deserialize
    ret = await asyncio.gather(*deserialize_ops)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/serialization.py", line 591, in async_deserialize
    raise ValueError(
ValueError: sharding passed to deserialization should be specified, concrete and an instance of `jax.sharding.Sharding`. Got None

Please let me know if you'd like any other details. I also included most of the traceback; it's a mess, but I hope it helps.

@cpgaffney1
Collaborator

You need to pass restore_args when restoring with a different sharding than the one the checkpoint was saved with. I'd recommend following this documentation: https://orbax.readthedocs.io/en/latest/guides/checkpoint/checkpointing_pytrees.html

@ybangaru
Author

I managed to recover the model states and batch norm layer statistics just by changing orbax_checkpointer_gen = orbax.checkpoint.PyTreeCheckpointer() to orbax_checkpointer_gen = orbax.checkpoint.StandardCheckpointer(). However, I also have continuous normalization metrics for the different channels of my 3D arrays, of the form shown in the image: essentially a dictionary with integer keys and dict values. Unfortunately, I'm not able to recover this with the aforementioned change. Could you share any thoughts on how I might recover these values on a different device, i.e. the CPU?

[Image: a dictionary with integer keys and dict values]

        norm_data_ckpt = {"data_norm_states": self.data_handler.norm_state_and_config["curr_scaler_state"]}
        save_args_data_norm = orbax_utils.save_args_from_target(norm_data_ckpt)
        self.data_norm_checkpoint_manager.save(
            iter_value, norm_data_ckpt, save_kwargs={"save_args": save_args_data_norm}
        )

        latest_step = self.data_norm_checkpoint_manager.latest_step()
        norm_data_ckpt = self.data_norm_checkpoint_manager.restore(latest_step)

The error is the following:

 File "/user/yashbangaru/simgan/pysrc/package/training_handlers.py", line 95, in __init__
    self.load_checkpoints(checkpoint_flag)
  File "/user/yashbangaru/simgan/pysrc/package/training_handlers.py", line 285, in load_checkpoints
    latest_state = self._load_data_norm_from_checkpoints(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/simgan/pysrc/package/training_handlers.py", line 331, in _load_data_norm_from_checkpoints
    norm_data_ckpt = self.data_norm_checkpoint_manager.restore(latest_step)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1356, in restore
    restored = self._checkpointer.restore(restore_directory, args=args)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/async_checkpointer.py", line 429, in restore
    return super().restore(directory, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/checkpointer.py", line 239, in restore
    restored = self._handler.restore(directory, args=ckpt_args)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py", line 811, in restore
    restored[item_name] = handler.restore(
                          ^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/handlers/standard_checkpoint_handler.py", line 220, in restore
    return self._impl.restore(
           ^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py", line 769, in restore
    return self._handler_impl.restore(directory, args=args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py", line 694, in restore
    tree_memory_size, restored_item = asyncio_utils.run_sync(
                                      ^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/asyncio_utils.py", line 50, in run_sync
    return asyncio.run(coro)
           ^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/asyncio/runners.py", line 194, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/asyncio/base_events.py", line 687, in run_until_complete
    return future.result()
           ^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py", line 551, in _maybe_deserialize
    deserialized_batches += await asyncio.gather(*deserialized_batches_ops)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/type_handlers.py", line 1382, in deserialize
    sharding = arg.sharding.to_jax_sharding()
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/user/yashbangaru/miniconda3/envs/simganenv/lib/python3.12/site-packages/orbax/checkpoint/_src/metadata/sharding.py", line 290, in to_jax_sharding
    raise ValueError(
ValueError: Device cuda:0 was not found in jax.local_devices().

From the documentation that was shared, I could change both saving and restoring to make this work across devices, but it's really important for me to be able to recover the existing training runs on a different device. Thank you.

@cpgaffney1
Collaborator

I see, I understand now. The problem is that you are calling restore with no arguments: norm_data_ckpt = self.data_norm_checkpoint_manager.restore(latest_step). The CheckpointManager does not know that you want to restore on CPU; by default it tries to restore the checkpoint with the same topology it was saved with. You must explicitly instruct it to restore on a different topology. This can be done by providing a target tree structure whose shardings are SingleDeviceSharding('cpu') (something like that).

Or, cast to a numpy array, using an example from the docs I linked:

ckptr.restore(
    path / '2',
    args=ocp.args.PyTreeRestore(
        # `item` serves as a guide to what the result tree structure should look
        # like.
        item={
            # Value doesn't really matter here, as long as it's not None.
            'c': ...,
            # Can add in extra keys.
            'd': np.arange(8)
        },
        # `restore_args` must be relative to the result tree, not the
        # checkpoint.
        restore_args={
          'c': ocp.RestoreArgs(restore_type=np.ndarray),
        },
        transforms={
            'c': ocp.Transform(original_key='a')
        },
    ),
)
