diff --git a/deeppavlov/core/models/torch_model.py b/deeppavlov/core/models/torch_model.py index e14ac658de..9fc5912baf 100644 --- a/deeppavlov/core/models/torch_model.py +++ b/deeppavlov/core/models/torch_model.py @@ -16,7 +16,7 @@ from copy import deepcopy from logging import getLogger from pathlib import Path -from typing import Optional +from typing import Optional, Union import torch from overrides import overrides @@ -123,7 +123,7 @@ def init_from_opt(self, model_func: str) -> None: raise AttributeError("Model is not defined.") @overrides - def load(self, fname: Optional[str] = None, *args, **kwargs) -> None: + def load(self, fname: Optional[Union[str, Path]] = None, *args, **kwargs) -> None: """Load model from `fname` (if `fname` is not given, use `self.load_path`) to `self.model` along with the optimizer `self.optimizer`, optionally `self.lr_scheduler`. If `fname` (if `fname` is not given, use `self.load_path`) does not exist, initialize model from scratch. @@ -139,15 +139,18 @@ def load(self, fname: Optional[str] = None, *args, **kwargs) -> None: if fname is not None: self.load_path = fname + if isinstance(self.load_path, str): + self.load_path = Path(self.load_path) + model_func = getattr(self, self.opt.get("model_name"), None) if self.load_path: log.info(f"Load path {self.load_path} is given.") - if isinstance(self.load_path, Path) and not self.load_path.parent.is_dir(): + if not self.load_path.parent.is_dir(): raise ConfigError("Provided load path is incorrect!") - weights_path = Path(self.load_path.resolve()) - weights_path = weights_path.with_suffix(f".pth.tar") + weights_path = self.load_path.resolve() + weights_path = weights_path.with_suffix(".pth.tar") if weights_path.exists(): log.info(f"Load path {weights_path} exists.") log.info(f"Initializing `{self.__class__.__name__}` from saved.") @@ -169,7 +172,7 @@ def load(self, fname: Optional[str] = None, *args, **kwargs) -> None: self.init_from_opt(model_func) @overrides - def save(self, fname: Optional[str] = None, *args, **kwargs) -> None: + def save(self, fname: Optional[Union[str, Path]] = None, *args, **kwargs) -> None: """Save torch model to `fname` (if `fname` is not given, use `self.save_path`). Checkpoint includes `model_state_dict`, `optimizer_state_dict`, and `epochs_done` (number of training epochs). @@ -184,10 +187,11 @@ def save(self, fname: Optional[str] = None, *args, **kwargs) -> None: if fname is None: fname = self.save_path + fname = Path(fname) if not fname.parent.is_dir(): raise ConfigError("Provided save path is incorrect!") - weights_path = Path(fname).with_suffix(f".pth.tar") + weights_path = fname.with_suffix(".pth.tar") log.info(f"Saving model to {weights_path}.") # move the model to `cpu` before saving to provide consistency torch.save({