Skip to content

Commit

Permalink
feat: flatten content field in user message (#781)
Browse files Browse the repository at this point in the history
  • Loading branch information
carenthomas authored Jan 28, 2025
1 parent b8e4b15 commit 2cc801c
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 4 deletions.
4 changes: 3 additions & 1 deletion letta/schemas/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
ToolReturnMessage,
UserMessage,
)
from letta.system import unpack_message
from letta.utils import get_utc_time, is_utc_datetime, json_dumps


Expand Down Expand Up @@ -264,11 +265,12 @@ def to_letta_message(
elif self.role == MessageRole.user:
# This is type UserMessage
assert self.text is not None, self
message_str = unpack_message(self.text)
messages.append(
UserMessage(
id=self.id,
date=self.created_at,
content=self.text,
content=message_str or self.text,
)
)
elif self.role == MessageRole.system:
Expand Down
20 changes: 20 additions & 0 deletions letta/system.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import uuid
import warnings
from typing import Optional

from .constants import (
Expand Down Expand Up @@ -205,3 +206,22 @@ def get_token_limit_warning():
}

return json_dumps(packaged_message)


def unpack_message(packed_message) -> str:
"""Take a packed message string and attempt to extract the inner message content"""

try:
message_json = json.loads(packed_message)
except:
warnings.warn(f"Was unable to load message as JSON to unpack: ''{packed_message}")
return packed_message

if "message" not in message_json:
if "type" in message_json and message_json["type"] in ["login", "heartbeat"]:
# This is a valid user message that the ADE expects, so don't print warning
return packed_message
warnings.warn(f"Was unable to find 'message' field in packed message object: '{packed_message}'")
return packed_message
else:
return message_json.get("message")
6 changes: 3 additions & 3 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from letta.schemas.message import Message
from letta.schemas.source import Source as PydanticSource
from letta.server.server import SyncServer
from letta.system import unpack_message

from .utils import DummyDataConnector

Expand Down Expand Up @@ -711,7 +712,7 @@ def _test_get_messages_letta_format(

elif message.role == MessageRole.user:
assert isinstance(letta_message, UserMessage)
assert message.text == letta_message.content
assert unpack_message(message.text) == letta_message.content
letta_message_index += 1

elif message.role == MessageRole.system:
Expand All @@ -734,8 +735,7 @@ def _test_get_messages_letta_format(


def test_get_messages_letta_format(server, user, agent_id):
# for reverse in [False, True]:
for reverse in [False]:
for reverse in [False, True]:
_test_get_messages_letta_format(server, user, agent_id, reverse=reverse)


Expand Down

0 comments on commit 2cc801c

Please sign in to comment.