Restoring flax model checkpoints using orbax throws ValueError #1305
You need to pass restore_args when restoring with a different sharding than the one the checkpoint was saved with. I'd recommend following this documentation: https://orbax.readthedocs.io/en/latest/guides/checkpoint/checkpointing_pytrees.html
I managed to recover the model states and batch norm layer statistics by just changing
The error is as follows:
From the documentation that was shared, I could change both saving and restoring to make it work on different devices, but it's really important for me to be able to recover existing training runs on a different device. Thank you.
I see, I understand your wish now. The problem is that you are calling restore with no arguments. Or, cast to a numpy array using an example from the docs I linked:
The following code blocks are being utilized to save the train state of the model during training and to restore the state back into memory.
Version being used
The training of the model was done on a GPU, and loading the state onto a GPU device works fine. However, when trying to load the model onto a CPU, the following error is thrown by the Orbax checkpoint manager's restore method:
Please let me know if you'd like any other details. I also added most of the traceback; it's a mess, but I hope it helps.