Skip to content

Commit

Permalink
feat: update list message response to not be a union (#783)
Browse files Browse the repository at this point in the history
  • Loading branch information
carenthomas authored Jan 28, 2025
1 parent 5bc11f1 commit b8e4b15
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 16 deletions.
9 changes: 5 additions & 4 deletions letta/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def get_archival_memory(

def get_messages(
self, agent_id: str, after: Optional[str] = None, before: Optional[str] = None, limit: Optional[int] = 1000
) -> List[Message]:
) -> List[LettaMessage]:
raise NotImplementedError

def list_model_configs(self) -> List[LLMConfig]:
Expand Down Expand Up @@ -965,7 +965,7 @@ def delete_archival_memory(self, agent_id: str, memory_id: str):

def get_messages(
self, agent_id: str, before: Optional[str] = None, after: Optional[str] = None, limit: Optional[int] = 1000
) -> List[Message]:
) -> List[LettaMessage]:
"""
Get messages from an agent with pagination.
Expand All @@ -983,7 +983,7 @@ def get_messages(
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages", params=params, headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to get messages: {response.text}")
return [Message(**message) for message in response.json()]
return [LettaMessage(**message) for message in response.json()]

def send_message(
self,
Expand Down Expand Up @@ -3355,7 +3355,7 @@ def get_archival_memory(

def get_messages(
self, agent_id: str, before: Optional[str] = None, after: Optional[str] = None, limit: Optional[int] = 1000
) -> List[Message]:
) -> List[LettaMessage]:
"""
Get messages from an agent with pagination.
Expand All @@ -3377,6 +3377,7 @@ def get_messages(
after=after,
limit=limit,
reverse=True,
return_message_object=False,
)

def list_blocks(self, label: Optional[str] = None, templates_only: Optional[bool] = True) -> List[Block]:
Expand Down
15 changes: 3 additions & 12 deletions letta/server/rest_api/routers/v1/agents.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Annotated, List, Optional, Union
from typing import Annotated, List, Optional

from fastapi import APIRouter, BackgroundTasks, Body, Depends, Header, HTTPException, Query, status
from fastapi.responses import JSONResponse
Expand Down Expand Up @@ -391,15 +391,7 @@ def delete_archival_memory(


AgentMessagesResponse = Annotated[
Union[List[Message], List[LettaMessageUnion]],
Field(
json_schema_extra={
"anyOf": [
{"type": "array", "items": {"$ref": "#/components/schemas/Message"}},
{"type": "array", "items": {"$ref": "#/components/schemas/LettaMessageUnion"}},
]
}
),
List[LettaMessageUnion], Field(json_schema_extra={"type": "array", "items": {"$ref": "#/components/schemas/LettaMessageUnion"}})
]


Expand All @@ -410,7 +402,6 @@ def list_messages(
after: Optional[str] = Query(None, description="Message after which to retrieve the returned messages."),
before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."),
limit: int = Query(10, description="Maximum number of messages to retrieve."),
msg_object: bool = Query(False, description="If true, returns Message objects. If false, return LettaMessage objects."),
config: LettaRequestConfig = Query(LettaRequestConfig(), description="Configuration options for the LettaRequest."),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
Expand All @@ -426,7 +417,7 @@ def list_messages(
before=before,
limit=limit,
reverse=True,
return_message_object=msg_object,
return_message_object=False,
assistant_message_tool_name=config.assistant_message_tool_name,
assistant_message_tool_kwarg=config.assistant_message_tool_kwarg,
)
Expand Down

0 comments on commit b8e4b15

Please sign in to comment.