diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index ee8e3b089..c8e27cc09 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -576,6 +576,7 @@ def save( model_name: Optional[str] = None, create_model_card: bool = True, train_datasets: Optional[List[str]] = None, + safe_serialization: bool = True, ): """ Saves all elements for this seq. sentence embedder into different sub-folders @@ -584,6 +585,7 @@ def save( :param model_name: Optional model name :param create_model_card: If True, create a README.md with basic information about this model :param train_datasets: Optional list with the names of the datasets used to to train the model + :param safe_serialization: If true, save the model using safetensors. If false, save the model the traditional PyTorch way """ if path is None: return @@ -616,7 +618,11 @@ def save( model_path = os.path.join(path, str(idx) + "_" + type(module).__name__) os.makedirs(model_path, exist_ok=True) - module.save(model_path) + if isinstance(module, Transformer): + module.save(model_path, safe_serialization=safe_serialization) + else: + module.save(model_path) + modules_config.append( {"idx": idx, "name": name, "path": os.path.basename(model_path), "type": type(module).__module__} ) diff --git a/sentence_transformers/cross_encoder/CrossEncoder.py b/sentence_transformers/cross_encoder/CrossEncoder.py index 02a22cf8e..9dc70840a 100644 --- a/sentence_transformers/cross_encoder/CrossEncoder.py +++ b/sentence_transformers/cross_encoder/CrossEncoder.py @@ -448,7 +448,7 @@ def _eval_during_training(self, evaluator, output_path, save_best_model, epoch, if save_best_model: self.save(output_path) - def save(self, path: str) -> None: + def save(self, path: str, safe_serialization: bool = True) -> None: """ Saves all model and tokenizer to path """ @@ -456,7 +456,7 @@ def save(self, path: str) -> None: return logger.info("Save model to {}".format(path)) - self.model.save_pretrained(path) + self.model.save_pretrained(path, safe_serialization=safe_serialization) self.tokenizer.save_pretrained(path) def save_pretrained(self, path: str) -> None: diff --git a/sentence_transformers/models/Transformer.py b/sentence_transformers/models/Transformer.py index 1e595e155..3e72c865c 100644 --- a/sentence_transformers/models/Transformer.py +++ b/sentence_transformers/models/Transformer.py @@ -156,8 +156,8 @@ def tokenize(self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]]): def get_config_dict(self): return {key: self.__dict__[key] for key in self.config_keys} - def save(self, output_path: str): - self.auto_model.save_pretrained(output_path) + def save(self, output_path: str, safe_serialization: bool = True): + self.auto_model.save_pretrained(output_path, safe_serialization=safe_serialization) self.tokenizer.save_pretrained(output_path) with open(os.path.join(output_path, "sentence_bert_config.json"), "w") as fOut: diff --git a/tests/test_cross_encoder.py b/tests/test_cross_encoder.py index dce2313b3..2a7f16d5b 100644 --- a/tests/test_cross_encoder.py +++ b/tests/test_cross_encoder.py @@ -5,6 +5,8 @@ import csv import gzip import os +from pathlib import Path +import tempfile import pytest import torch @@ -151,3 +153,21 @@ def test_rank() -> None: ranks = model.rank(query, corpus) pred_ranking = [rank["corpus_id"] for rank in ranks] assert pred_ranking == expected_ranking + + +@pytest.mark.parametrize("safe_serialization", [True, False, None]) +def test_safe_serialization(safe_serialization: bool) -> None: + with tempfile.TemporaryDirectory() as cache_folder: + model = CrossEncoder("cross-encoder/stsb-distilroberta-base") + if safe_serialization: + model.save(cache_folder, safe_serialization=safe_serialization) + model_files = list(Path(cache_folder).glob("**/model.safetensors")) + assert 1 == len(model_files) + elif safe_serialization is None: + model.save(cache_folder) + model_files = list(Path(cache_folder).glob("**/model.safetensors")) + assert 1 == len(model_files) + else: + model.save(cache_folder, safe_serialization=safe_serialization) + model_files = list(Path(cache_folder).glob("**/pytorch_model.bin")) + assert 1 == len(model_files) diff --git a/tests/test_sentence_transformer.py b/tests/test_sentence_transformer.py index 87e9e7364..5bd1d4b71 100644 --- a/tests/test_sentence_transformer.py +++ b/tests/test_sentence_transformer.py @@ -176,6 +176,24 @@ def mock_list_repo_refs(self, repo_id=None, **kwargs): ) +@pytest.mark.parametrize("safe_serialization", [True, False, None]) +def test_safe_serialization(safe_serialization: bool) -> None: + with tempfile.TemporaryDirectory() as cache_folder: + model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors") + if safe_serialization: + model.save(cache_folder, safe_serialization=safe_serialization) + model_files = list(Path(cache_folder).glob("**/model.safetensors")) + assert 1 == len(model_files) + elif safe_serialization is None: + model.save(cache_folder) + model_files = list(Path(cache_folder).glob("**/model.safetensors")) + assert 1 == len(model_files) + else: + model.save(cache_folder, safe_serialization=safe_serialization) + model_files = list(Path(cache_folder).glob("**/pytorch_model.bin")) + assert 1 == len(model_files) + + def test_load_with_revision() -> None: main_model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors", revision="main") latest_model = SentenceTransformer(