The test model appears to require a CUDA device #3

Open
hi589dsTq92 opened this issue May 13, 2020 · 2 comments
Comments

@hi589dsTq92

Hi bro:
When I run test.py, it fails at line 58:
trainer.load_checkpoint(log_dir + 'checkpoint')
with the following error:
'''
raise RuntimeError('Attempting to deserialize object on a CUDA '
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

'''
Does this mean I must run this code on a CUDA device?
Or, because the test model was trained on a CUDA device, do I have to build a new model that runs on CPU only?

@Xu-Yao
Contributor

Xu-Yao commented May 14, 2020

Hello knowasdf, I tried to reproduce the issue but could not. I searched for your error and found this. So I propose replacing line 219 of trainer.py with:
"""
state_dict = torch.load(checkpoint_path, map_location={'cuda:0': 'cpu'})
"""
I hope that helps you solve the problem :)
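For context, `torch.load` normally restores each tensor to the device it was saved from, which is why a checkpoint written on a GPU fails on a CPU-only machine. The dict form above only remaps storages tagged `cuda:0`; a checkpoint saved from another GPU index (for example `cuda:1`) would still fail. A minimal sketch of a more general variant, assuming the checkpoint is a plain dict of tensors; the helper name `load_checkpoint_portable` is hypothetical and not part of this repository:

```python
import os
import tempfile

import torch


def load_checkpoint_portable(checkpoint_path, map_location=None):
    """Load a checkpoint regardless of which device it was saved on.

    Hypothetical helper, not part of the project's trainer.py.
    """
    if map_location is None:
        # Use the current GPU if one is available, otherwise fall back to CPU.
        # A torch.device remaps *every* saved storage (cuda:0, cuda:1, ...),
        # unlike the {'cuda:0': 'cpu'} dict, which only remaps cuda:0.
        map_location = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.load(checkpoint_path, map_location=map_location)


if __name__ == "__main__":
    # Round trip: save a tiny state dict, then load it back portably.
    path = os.path.join(tempfile.mkdtemp(), "checkpoint")
    torch.save({"weight": torch.ones(2, 3)}, path)
    state_dict = load_checkpoint_portable(path)
    print(state_dict["weight"].device)
```

Passing `map_location='cpu'` explicitly forces everything onto the CPU even when a GPU is present, which can be handy for quick inspection of a checkpoint.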

@hi589dsTq92
Author

Hey bro:
Cool job!! It's WORKING now.
To make this more flexible, I revised the code as follows:
trainer.py at lines 218-219:
def load_checkpoint(self, checkpoint_path, device={'cuda:0': 'cpu'}):
    state_dict = torch.load(checkpoint_path, map_location=device)

train.py at line 77:
epoch_0 = trainer.load_checkpoint(opts.checkpoint, device)
test.py at line 56:
trainer.load_checkpoint(opts.checkpoint, device)
test.py at line 58:
trainer.load_checkpoint(log_dir + 'checkpoint', device)

;)
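The same idea of threading a device argument through the trainer, as in the change above, could look roughly like this. A minimal sketch only: the `Trainer` class, its `model`/`optimizer` attributes, and the checkpoint keys (`model`, `optimizer`, `epoch`) are hypothetical stand-ins, not the repository's actual trainer.py. Defaulting `device` to `None` and resolving it at call time also avoids using a mutable dict as a default argument.

```python
import os
import tempfile

import torch
import torch.nn as nn


class Trainer:
    """Hypothetical stand-in for the project's trainer."""

    def __init__(self):
        self.model = nn.Linear(4, 2)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)

    def save_checkpoint(self, checkpoint_path, epoch):
        torch.save(
            {
                "model": self.model.state_dict(),
                "optimizer": self.optimizer.state_dict(),
                "epoch": epoch,
            },
            checkpoint_path,
        )

    def load_checkpoint(self, checkpoint_path, device=None):
        # `device` is forwarded to torch.load as map_location; it may be a
        # torch.device, a string such as 'cpu', or a dict like {'cuda:0': 'cpu'}.
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        state_dict = torch.load(checkpoint_path, map_location=device)
        self.model.load_state_dict(state_dict["model"])
        self.optimizer.load_state_dict(state_dict["optimizer"])
        # Return the saved epoch so callers can resume, e.g. epoch_0 in train.py.
        return state_dict["epoch"]


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "checkpoint")
    Trainer().save_checkpoint(path, epoch=5)
    epoch_0 = Trainer().load_checkpoint(path, torch.device("cpu"))
    print(epoch_0)  # 5
```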
