Skip to content

Commit

Permalink
Add fix for self-hosted HF models (#167)
Browse files: browse the repository at this point in the history
* add fix for self-hosted HF models

* Update src/agentlab/llm/huggingface_utils.py

* Update huggingface_utils.py

* updating test

---------

Co-authored-by: Thibault LSDC <[email protected]>
Co-authored-by: ThibaultLSDC <[email protected]>
  • Loading branch information
3 people authored Nov 29, 2024
1 parent 9e9b800 commit 8dc809c
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/agentlab/llm/huggingface_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from transformers import AutoTokenizer, GPT2TokenizerFast

from agentlab.llm.base_api import AbstractChatModel
from agentlab.llm.llm_utils import Discussion
from agentlab.llm.prompt_templates import PromptTemplate, get_prompt_template


Expand Down Expand Up @@ -59,6 +60,8 @@ def __call__(
if self.tokenizer:
# messages_formated = _convert_messages_to_dict(messages) ## ?
try:
if isinstance(messages, Discussion):
messages.merge()
prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
except Exception as e:
if "Conversation roles must alternate" in str(e):
Expand Down
2 changes: 2 additions & 0 deletions src/agentlab/llm/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,8 @@ def merge(self):
else:
new_content.append(elem)
self["content"] = new_content
if len(self["content"]) == 1:
self["content"] = self["content"][0]["text"]


class SystemMessage(BaseMessage):
Expand Down
3 changes: 1 addition & 2 deletions tests/llm/test_llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,7 @@ def test_message_merge_only_text():
]
message = llm_utils.BaseMessage(role="system", content=content)
message.merge()
assert len(message["content"]) == 1
assert message["content"][0]["text"] == "Hello, world!\nThis is a test."
assert message["content"] == "Hello, world!\nThis is a test."


def test_message_merge_text_image():
Expand Down

0 comments on commit 8dc809c

Please sign in to comment.