With some BERT and RoBERTa models, such as `sentence-transformers/all-MiniLM-L12-v2`, I found that the output does not match the one generated by sentence-transformers. If I add the following print to `_normalize_prompt_text_to_input()` in `serving_engine.py`:

print(f"{input_ids=}")

I get `[101, 100, 3007, 1997, 100, 2003, 100, 1012, 102]` for the sentence "The capital of France is Paris."; 100 is the `UNK` token. When I run with sentence-transformers, I get `[101, 1996, 3007, 1997, 2605, 2003, 3000, 1012, 102]`. This problem happens both with `--tokenizer-mode auto` and `--tokenizer-mode slow`.

cc: @DarkLight1337
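For reference, a minimal sketch (assuming the Hugging Face `transformers` package and toggling `do_lower_case` by hand) that reproduces both id sequences outside of vLLM:

```python
# Sketch: tokenize the same sentence with do_lower_case toggled to show
# where the UNK (100) ids come from. Not the vLLM code path, just a repro.
from transformers import AutoTokenizer

model = "sentence-transformers/all-MiniLM-L12-v2"
text = "The capital of France is Paris."

# do_lower_case=True matches tokenizer_config.json and the sentence-transformers output
lower = AutoTokenizer.from_pretrained(model, do_lower_case=True)
print(lower(text)["input_ids"])
# [101, 1996, 3007, 1997, 2605, 2003, 3000, 1012, 102]

# do_lower_case=False leaves cased words ("The", "France", "Paris") out of the
# uncased vocab, so they map to [UNK] (id 100) -- the ids reported above
no_lower = AutoTokenizer.from_pretrained(model, do_lower_case=False)
print(no_lower(text)["input_ids"])
# [101, 100, 3007, 1997, 100, 2003, 100, 1012, 102]
```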
In the model's `sentence_bert_config.json`, `do_lower_case` is false:
{
"max_seq_length": 128,
"do_lower_case": false
}
but in `tokenizer_config.json` it's true:
{"do_lower_case": true, ....
When we added support for `sentence_bert_config.json`, we assumed this configuration was meant to override the tokenizer configuration, but that's probably wrong.
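To illustrate the assumption, here is a hypothetical sketch (the dicts below are stand-ins, not the actual vLLM code): values from `sentence_bert_config.json` win over the tokenizer's own config, which is how `do_lower_case` ends up false.

```python
# Hypothetical illustration of the override assumption described above
# (stand-in dicts, not the actual vLLM code paths).
sentence_bert_config = {"max_seq_length": 128, "do_lower_case": False}
tokenizer_config = {"do_lower_case": True}

# sentence_bert_config overrides tokenizer_config, so do_lower_case becomes False
effective = {**tokenizer_config, **sentence_bert_config}
print(effective["do_lower_case"])  # False -> cased words tokenize to [UNK]
```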