Skip to content

Commit

Permalink
feat(torchscript): support _extra_files (#3480)
Browse files Browse the repository at this point in the history
provides `_extra_files` to `torch.jit.save` and `torch.jit.load` 

Co-authored-by: Sauyon Lee <[email protected]>
  • Loading branch information
aarnphm and sauyon authored Feb 9, 2023
1 parent 928a480 commit 5f7ebe3
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions src/bentoml/_internal/frameworks/torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,20 @@ def get(tag_like: str | Tag) -> Model:

def load_model(
bentoml_model: str | Tag | Model,
device_id: t.Optional[str] = "cpu",
) -> torch.ScriptModule:
device_id: str | None = "cpu",
*,
_extra_files: dict[str, t.Any] | None = None,
) -> torch.ScriptModule | tuple[torch.ScriptModule, dict[str, t.Any]]:
"""
Load a model from BentoML local modelstore with given name.
Args:
tag (:code:`Union[str, Tag]`):
tag:
Tag of a saved model in BentoML local modelstore.
device_id (:code:`str`, `optional`):
device_id:
Optional devices to put the given model on. Refer to https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device
model_store (:mod:`~bentoml._internal.models.store.ModelStore`, default to :mod:`BentoMLContainer.model_store`):
BentoML modelstore, provided by DI Container.
_extra_files:
A dictionary of file names and a empty string. See https://pytorch.org/docs/stable/generated/torch.jit.load.html.
Returns:
:obj:`torch.ScriptModule`: an instance of :obj:`torch.ScriptModule` from BentoML modelstore.
Expand All @@ -67,7 +69,12 @@ def load_model(
f"Model {bentoml_model.tag} was saved with module {bentoml_model.info.module}, not loading with {MODULE_NAME}."
)
weight_file = bentoml_model.path_of(MODEL_FILENAME)
model: torch.ScriptModule = torch.jit.load(weight_file, map_location=device_id) # type: ignore[reportPrivateImportUsage]

model: torch.ScriptModule = torch.jit.load(
weight_file,
map_location=device_id,
_extra_files=_extra_files,
)
return model


Expand All @@ -82,6 +89,7 @@ def save_model(
metadata: t.Dict[str, t.Any] | None = None,
_framework_name: str = "torchscript",
_module_name: str = MODULE_NAME,
_extra_files: dict[str, t.Any] | None = None,
) -> bentoml.Model:
"""
Save a model instance to BentoML modelstore.
Expand Down Expand Up @@ -113,8 +121,6 @@ def save_model(
import bentoml
import torch
TODO(jiang)
"""
if not isinstance(model, (torch.ScriptModule, torch.jit.ScriptModule)):
raise TypeError(f"Given model ({model}) is not a torch.ScriptModule.")
Expand All @@ -131,6 +137,10 @@ def save_model(
framework_name=_framework_name,
framework_versions=framework_versions,
)
if _extra_files is not None:
if metadata is None:
metadata = {}
metadata["_extra_files"] = [f for f in _extra_files]

if signatures is None:
signatures = {"__call__": {"batchable": False}}
Expand All @@ -152,8 +162,9 @@ def save_model(
context=context,
metadata=metadata,
) as bento_model:
weight_file = bento_model.path_of(MODEL_FILENAME)
torch.jit.save(model, weight_file) # type: ignore
torch.jit.save(
model, bento_model.path_of(MODEL_FILENAME), _extra_files=_extra_files
)
return bento_model


Expand Down

0 comments on commit 5f7ebe3

Please sign in to comment.