Enable saving modules as pytorch_model.bin (#2542)
* Enable saving modules as pytorch_model.bin

* Update sentence_transformers/models/Transformer.py

Co-authored-by: Tom Aarsen <[email protected]>

* Add test to SentenceTransformer for default serialization

* Enable switching serialization type between 'pytorch' and 'safetensors' in CrossEncoder

* Fix code quality

* Update sentence_transformers/SentenceTransformer.py

Co-authored-by: Tom Aarsen <[email protected]>

---------

Co-authored-by: christopherkeibel <[email protected]>
Co-authored-by: Tom Aarsen <[email protected]>
3 people authored Mar 15, 2024
1 parent b9255d9 commit fc2a2d8
Showing 5 changed files with 49 additions and 5 deletions.
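
As context for the diffs below, a minimal usage sketch of the new flag (not part of the commit itself; the model name is the tiny test model used in the new tests, and the output paths are placeholders):

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")

    # Default: weights are written as model.safetensors
    model.save("output/tiny-model-safetensors")

    # Opt out of safetensors to get the legacy pytorch_model.bin instead
    model.save("output/tiny-model-pytorch", safe_serialization=False)
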
8 changes: 7 additions & 1 deletion sentence_transformers/SentenceTransformer.py
@@ -576,6 +576,7 @@ def save(
         model_name: Optional[str] = None,
         create_model_card: bool = True,
         train_datasets: Optional[List[str]] = None,
+        safe_serialization: bool = True,
     ):
         """
         Saves all elements for this seq. sentence embedder into different sub-folders
@@ -584,6 +585,7 @@ def save(
         :param model_name: Optional model name
         :param create_model_card: If True, create a README.md with basic information about this model
         :param train_datasets: Optional list with the names of the datasets used to to train the model
+        :param safe_serialization: If true, save the model using safetensors. If false, save the model the traditional PyTorch way
         """
         if path is None:
             return
@@ -616,7 +618,11 @@
             model_path = os.path.join(path, str(idx) + "_" + type(module).__name__)
 
             os.makedirs(model_path, exist_ok=True)
-            module.save(model_path)
+            if isinstance(module, Transformer):
+                module.save(model_path, safe_serialization=safe_serialization)
+            else:
+                module.save(model_path)
+
             modules_config.append(
                 {"idx": idx, "name": name, "path": os.path.basename(model_path), "type": type(module).__module__}
             )
4 changes: 2 additions & 2 deletions sentence_transformers/cross_encoder/CrossEncoder.py
@@ -448,15 +448,15 @@ def _eval_during_training(self, evaluator, output_path, save_best_model, epoch,
                 if save_best_model:
                     self.save(output_path)
 
-    def save(self, path: str) -> None:
+    def save(self, path: str, safe_serialization: bool = True) -> None:
         """
         Saves all model and tokenizer to path
         """
         if path is None:
             return
 
         logger.info("Save model to {}".format(path))
-        self.model.save_pretrained(path)
+        self.model.save_pretrained(path, safe_serialization=safe_serialization)
         self.tokenizer.save_pretrained(path)
 
     def save_pretrained(self, path: str) -> None:
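
A corresponding usage sketch for the CrossEncoder change (the model name is the one exercised in the new test below; output paths are placeholders):

    from sentence_transformers import CrossEncoder

    model = CrossEncoder("cross-encoder/stsb-distilroberta-base")

    # safetensors is the default; pass safe_serialization=False to get pytorch_model.bin
    model.save("output/ce-safetensors")
    model.save("output/ce-pytorch", safe_serialization=False)
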
4 changes: 2 additions & 2 deletions sentence_transformers/models/Transformer.py
@@ -156,8 +156,8 @@ def tokenize(self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]]):
     def get_config_dict(self):
         return {key: self.__dict__[key] for key in self.config_keys}
 
-    def save(self, output_path: str):
-        self.auto_model.save_pretrained(output_path)
+    def save(self, output_path: str, safe_serialization: bool = True):
+        self.auto_model.save_pretrained(output_path, safe_serialization=safe_serialization)
         self.tokenizer.save_pretrained(output_path)
 
         with open(os.path.join(output_path, "sentence_bert_config.json"), "w") as fOut:
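
Transformer.save forwards the flag straight to Hugging Face's save_pretrained, so the choice is visible as which weight file lands in each module folder. A small check along those lines (assuming a model was already saved to the placeholder path from the earlier sketch):

    from pathlib import Path

    saved_path = "output/tiny-model-pytorch"  # placeholder path from the sketch above
    weight_files = sorted(
        p.name for p in Path(saved_path).rglob("*") if p.name in ("model.safetensors", "pytorch_model.bin")
    )
    # Expect ["model.safetensors"] for the default and ["pytorch_model.bin"] with safe_serialization=False
    print(weight_files)
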
20 changes: 20 additions & 0 deletions tests/test_cross_encoder.py
@@ -5,6 +5,8 @@
 import csv
 import gzip
 import os
+from pathlib import Path
+import tempfile
 
 import pytest
 import torch
@@ -151,3 +153,21 @@ def test_rank() -> None:
     ranks = model.rank(query, corpus)
     pred_ranking = [rank["corpus_id"] for rank in ranks]
     assert pred_ranking == expected_ranking
+
+
+@pytest.mark.parametrize("safe_serialization", [True, False, None])
+def test_safe_serialization(safe_serialization: bool) -> None:
+    with tempfile.TemporaryDirectory() as cache_folder:
+        model = CrossEncoder("cross-encoder/stsb-distilroberta-base")
+        if safe_serialization:
+            model.save(cache_folder, safe_serialization=safe_serialization)
+            model_files = list(Path(cache_folder).glob("**/model.safetensors"))
+            assert 1 == len(model_files)
+        elif safe_serialization is None:
+            model.save(cache_folder)
+            model_files = list(Path(cache_folder).glob("**/model.safetensors"))
+            assert 1 == len(model_files)
+        else:
+            model.save(cache_folder, safe_serialization=safe_serialization)
+            model_files = list(Path(cache_folder).glob("**/pytorch_model.bin"))
+            assert 1 == len(model_files)
18 changes: 18 additions & 0 deletions tests/test_sentence_transformer.py
@@ -176,6 +176,24 @@ def mock_list_repo_refs(self, repo_id=None, **kwargs):
     )
 
 
+@pytest.mark.parametrize("safe_serialization", [True, False, None])
+def test_safe_serialization(safe_serialization: bool) -> None:
+    with tempfile.TemporaryDirectory() as cache_folder:
+        model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")
+        if safe_serialization:
+            model.save(cache_folder, safe_serialization=safe_serialization)
+            model_files = list(Path(cache_folder).glob("**/model.safetensors"))
+            assert 1 == len(model_files)
+        elif safe_serialization is None:
+            model.save(cache_folder)
+            model_files = list(Path(cache_folder).glob("**/model.safetensors"))
+            assert 1 == len(model_files)
+        else:
+            model.save(cache_folder, safe_serialization=safe_serialization)
+            model_files = list(Path(cache_folder).glob("**/pytorch_model.bin"))
+            assert 1 == len(model_files)
+
+
 def test_load_with_revision() -> None:
     main_model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors", revision="main")
     latest_model = SentenceTransformer(
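
Whichever format is chosen at save time, the resulting folder loads back the same way, since the underlying transformers loader falls back to pytorch_model.bin when no safetensors file is present (a sketch reusing the placeholder path from above):

    from sentence_transformers import SentenceTransformer

    reloaded = SentenceTransformer("output/tiny-model-pytorch")
    embeddings = reloaded.encode(["The weights were stored as pytorch_model.bin."])
    print(embeddings.shape)
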
