diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py index 2ed8f0d6..e22d2405 100644 --- a/src/agentlab/llm/chat_api.py +++ b/src/agentlab/llm/chat_api.py @@ -13,7 +13,7 @@ import agentlab.llm.tracking as tracking from agentlab.llm.base_api import AbstractChatModel, BaseModelArgs from agentlab.llm.huggingface_utils import HFBaseChatModel -from agentlab.llm.llm_utils import Discussion +from agentlab.llm.llm_utils import AIMessage, Discussion def make_system_message(content: str) -> dict: @@ -305,7 +305,7 @@ def __call__(self, messages: list[dict]) -> dict: ): tracking.TRACKER.instance(input_tokens, output_tokens, cost) - return make_assistant_message(completion.choices[0].message.content) + return AIMessage(completion.choices[0].message.content) def get_stats(self): return { diff --git a/src/agentlab/llm/huggingface_utils.py b/src/agentlab/llm/huggingface_utils.py index 153978b1..9bb2d7ab 100644 --- a/src/agentlab/llm/huggingface_utils.py +++ b/src/agentlab/llm/huggingface_utils.py @@ -6,7 +6,7 @@ from transformers import AutoTokenizer, GPT2TokenizerFast from agentlab.llm.base_api import AbstractChatModel -from agentlab.llm.llm_utils import Discussion +from agentlab.llm.llm_utils import AIMessage, Discussion from agentlab.llm.prompt_templates import PromptTemplate, get_prompt_template @@ -80,7 +80,7 @@ def __call__( itr = 0 while True: try: - response = self.llm(prompt) + response = AIMessage(self.llm(prompt)) return response except Exception as e: if itr == self.n_retry_server - 1: