Skip to content

Commit

Permalink
Using AIMessage in ChatModels
Browse files Browse the repository at this point in the history
  • Loading branch information
ThibaultLSDC committed Dec 3, 2024
1 parent 38b2c0b commit 7e49aa7
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/agentlab/llm/chat_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import agentlab.llm.tracking as tracking
from agentlab.llm.base_api import AbstractChatModel, BaseModelArgs
from agentlab.llm.huggingface_utils import HFBaseChatModel
from agentlab.llm.llm_utils import Discussion
from agentlab.llm.llm_utils import AIMessage, Discussion


def make_system_message(content: str) -> dict:
Expand Down Expand Up @@ -305,7 +305,7 @@ def __call__(self, messages: list[dict]) -> dict:
):
tracking.TRACKER.instance(input_tokens, output_tokens, cost)

return make_assistant_message(completion.choices[0].message.content)
return AIMessage(completion.choices[0].message.content)

def get_stats(self):
return {
Expand Down
4 changes: 2 additions & 2 deletions src/agentlab/llm/huggingface_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from transformers import AutoTokenizer, GPT2TokenizerFast

from agentlab.llm.base_api import AbstractChatModel
from agentlab.llm.llm_utils import Discussion
from agentlab.llm.llm_utils import AIMessage, Discussion
from agentlab.llm.prompt_templates import PromptTemplate, get_prompt_template


Expand Down Expand Up @@ -80,7 +80,7 @@ def __call__(
itr = 0
while True:
try:
response = self.llm(prompt)
response = AIMessage(self.llm(prompt))
return response
except Exception as e:
if itr == self.n_retry_server - 1:
Expand Down

0 comments on commit 7e49aa7

Please sign in to comment.