From 45cb9a8b186f4324c00db8d0f362eac564106936 Mon Sep 17 00:00:00 2001 From: Samet Date: Fri, 16 Sep 2022 07:19:29 -0700 Subject: [PATCH] Add map_location when loading the weights --- anomalib/utils/callbacks/model_loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/anomalib/utils/callbacks/model_loader.py b/anomalib/utils/callbacks/model_loader.py index a89b5ecd68..ad1e7ebc71 100644 --- a/anomalib/utils/callbacks/model_loader.py +++ b/anomalib/utils/callbacks/model_loader.py @@ -27,7 +27,7 @@ def on_test_start(self, _trainer, pl_module: AnomalyModule) -> None: # pylint: Loads the model weights from ``weights_path`` into the PyTorch module. """ logger.info("Loading the model from %s", self.weights_path) - pl_module.load_state_dict(torch.load(self.weights_path)["state_dict"]) + pl_module.load_state_dict(torch.load(self.weights_path, map_location=pl_module.device)["state_dict"]) def on_predict_start(self, _trainer, pl_module: AnomalyModule) -> None: """Call when inference begins. @@ -35,4 +35,4 @@ def on_predict_start(self, _trainer, pl_module: AnomalyModule) -> None: Loads the model weights from ``weights_path`` into the PyTorch module. """ logger.info("Loading the model from %s", self.weights_path) - pl_module.load_state_dict(torch.load(self.weights_path)["state_dict"]) + pl_module.load_state_dict(torch.load(self.weights_path, map_location=pl_module.device)["state_dict"])