
Commit

[enh] Throw error if StaticEmbedding-based model is trained with incompatible loss (#2990)
tomaarsen authored Oct 17, 2024
1 parent a1db32d commit 72d5649
Showing 5 changed files with 34 additions and 2 deletions.
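
For context, a minimal sketch of the behavior this commit introduces. It is illustrative only: the tokenizer choice, embedding dimension, and model construction below are assumptions, not part of this diff.

    from tokenizers import Tokenizer

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.losses import (
        CachedMultipleNegativesRankingLoss,
        MultipleNegativesRankingLoss,
    )
    from sentence_transformers.models import StaticEmbedding

    # A StaticEmbedding-based model: embeddings come from a lookup table,
    # with no Transformer forward pass. Tokenizer and dimension are illustrative.
    static = StaticEmbedding(Tokenizer.from_pretrained("bert-base-uncased"), embedding_dim=256)
    model = SentenceTransformer(modules=[static])

    # After this commit, an incompatible cached loss fails fast at construction:
    try:
        loss = CachedMultipleNegativesRankingLoss(model)
    except ValueError as err:
        print(err)  # suggests MultipleNegativesRankingLoss instead

    # The non-cached variant named in the error message works as before:
    loss = MultipleNegativesRankingLoss(model)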
7 changes: 6 additions & 1 deletion sentence_transformers/losses/CachedGISTEmbedLoss.py
@@ -10,7 +10,7 @@
 from torch.utils.checkpoint import get_device_states, set_device_states
 
 from sentence_transformers import SentenceTransformer
-from sentence_transformers.models import Transformer
+from sentence_transformers.models import StaticEmbedding, Transformer
 
 
 class RandContext:
@@ -139,6 +139,11 @@ def __init__(
             trainer.train()
         """
         super().__init__()
+        if isinstance(model[0], StaticEmbedding):
+            raise ValueError(
+                "CachedGISTEmbedLoss is not compatible with a SentenceTransformer model based on a StaticEmbedding. "
+                "Consider using GISTEmbedLoss instead."
+            )
         self.model = model
         self.guide = guide
         self.temperature = temperature
7 changes: 7 additions & 0 deletions sentence_transformers/losses/CachedMultipleNegativesRankingLoss.py
@@ -10,6 +10,7 @@
 from torch.utils.checkpoint import get_device_states, set_device_states
 
 from sentence_transformers import SentenceTransformer, util
+from sentence_transformers.models import StaticEmbedding
 
 
 class RandContext:
@@ -145,6 +146,12 @@ def __init__(
             trainer.train()
         """
         super().__init__()
+        if isinstance(model[0], StaticEmbedding):
+            raise ValueError(
+                "CachedMultipleNegativesRankingLoss is not compatible with a SentenceTransformer model based on a StaticEmbedding. "
+                "Consider using MultipleNegativesRankingLoss instead."
+            )
+
         self.model = model
         self.scale = scale
         self.similarity_fct = similarity_fct
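
The error above points users at the non-cached loss. The following hedged sketch shows that replacement in a training loop, patterned on the trainer example in these losses' docstrings; the dataset contents are illustrative.

    from datasets import Dataset
    from tokenizers import Tokenizer

    from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
    from sentence_transformers.models import StaticEmbedding

    static = StaticEmbedding(Tokenizer.from_pretrained("bert-base-uncased"), embedding_dim=256)
    model = SentenceTransformer(modules=[static])

    train_dataset = Dataset.from_dict({
        "anchor": ["It's nice weather outside today.", "He drove to work."],
        "positive": ["It's so sunny.", "He took the car to the office."],
    })

    # MultipleNegativesRankingLoss scores in-batch negatives in a single forward
    # pass, so it avoids the gradient-caching machinery the Cached* losses need.
    loss = losses.MultipleNegativesRankingLoss(model)

    trainer = SentenceTransformerTrainer(model=model, train_dataset=train_dataset, loss=loss)
    trainer.train()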
7 changes: 7 additions & 0 deletions sentence_transformers/losses/CachedMultipleNegativesSymmetricRankingLoss.py
@@ -10,6 +10,7 @@
 
 from sentence_transformers import SentenceTransformer, util
 from sentence_transformers.losses.CachedMultipleNegativesRankingLoss import RandContext
+from sentence_transformers.models import StaticEmbedding
 
 
 def _backward_hook(
@@ -114,6 +115,12 @@ def __init__(
             - Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup: https://arxiv.org/pdf/2101.06983.pdf
         """
         super().__init__()
+        if isinstance(model[0], StaticEmbedding):
+            raise ValueError(
+                "CachedMultipleNegativesSymmetricRankingLoss is not compatible with a SentenceTransformer model based on a StaticEmbedding. "
+                "Consider using MultipleNegativesSymmetricRankingLoss instead."
+            )
+
         self.model = model
         self.scale = scale
         self.similarity_fct = similarity_fct
7 changes: 7 additions & 0 deletions sentence_transformers/losses/DenoisingAutoEncoderLoss.py
@@ -7,6 +7,7 @@
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
 
 from sentence_transformers import SentenceTransformer
+from sentence_transformers.models import StaticEmbedding
 
 logger = logging.getLogger(__name__)
 
@@ -73,6 +74,12 @@ def __init__(
             )
         """
         super().__init__()
+
+        if isinstance(model[0], StaticEmbedding):
+            raise ValueError(
+                "DenoisingAutoEncoderLoss is not compatible with a SentenceTransformer model based on a StaticEmbedding."
+            )
+
         self.encoder = model  # This will be the final model used during the inference time.
         self.tokenizer_encoder = model.tokenizer
 
8 changes: 7 additions & 1 deletion sentence_transformers/losses/GISTEmbedLoss.py
@@ -5,7 +5,7 @@
 import torch
 from torch import Tensor, nn
 
-from sentence_transformers.models import Transformer
+from sentence_transformers.models import StaticEmbedding, Transformer
 from sentence_transformers.SentenceTransformer import SentenceTransformer
 
 
@@ -91,6 +91,12 @@ def __init__(
         if self.must_retokenize:
             self.tokenizer = self.model.tokenizer
 
+            if isinstance(self.model[0], StaticEmbedding):
+                raise ValueError(
+                    "If we must retokenize because the guide model has a different tokenizer, "
+                    "then the Sentence Transformer model must not be based on a StaticEmbedding."
+                )
+
     def sim_matrix(self, embed1: Tensor, embed2: Tensor) -> Tensor:
         return self.similarity_fct(embed1.unsqueeze(1), embed2.unsqueeze(0))
 
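
Unlike the other four losses, GISTEmbedLoss rejects StaticEmbedding-based models only conditionally: the check sits inside the must_retokenize branch, so a guide sharing the model's tokenizer remains supported. A hedged sketch of the failing case follows; the model and guide choices are assumptions, picked so that (per the error message) the guide's tokenizer differs from the model's.

    from tokenizers import Tokenizer

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.losses import GISTEmbedLoss
    from sentence_transformers.models import StaticEmbedding

    static = StaticEmbedding(Tokenizer.from_pretrained("bert-base-uncased"), embedding_dim=256)
    model = SentenceTransformer(modules=[static])

    # Guide chosen so its tokenizer differs from the model's (an assumption of
    # this sketch): guide inputs would have to be retokenized, which requires a
    # Transformer-style model, so construction now fails.
    guide = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

    try:
        loss = GISTEmbedLoss(model, guide)
    except ValueError as err:
        print(err)  # "If we must retokenize ... must not be based on a StaticEmbedding."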
