-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: refactor
send_message
API (#1633)
- Loading branch information
1 parent
191e86a
commit 41358a5
Showing
10 changed files
with
383 additions
and
172 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
from datetime import datetime, timezone | ||
from typing import Literal, Union | ||
|
||
from pydantic import BaseModel, field_serializer | ||
|
||
# MemGPT API style responses (intended to be easier to use vs getting true Message types) | ||
|
||
|
||
class BaseMemGPTMessage(BaseModel): | ||
id: str | ||
date: datetime | ||
|
||
@field_serializer("date") | ||
def serialize_datetime(self, dt: datetime, _info): | ||
return dt.now(timezone.utc).isoformat() | ||
|
||
|
||
class InternalMonologue(BaseMemGPTMessage): | ||
""" | ||
{ | ||
"internal_monologue": msg, | ||
"date": msg_obj.created_at.isoformat() if msg_obj is not None else get_utc_time().isoformat(), | ||
"id": str(msg_obj.id) if msg_obj is not None else None, | ||
} | ||
""" | ||
|
||
internal_monologue: str | ||
|
||
|
||
class FunctionCall(BaseModel): | ||
name: str | ||
arguments: str | ||
|
||
|
||
class FunctionCallMessage(BaseMemGPTMessage): | ||
""" | ||
{ | ||
"function_call": { | ||
"name": function_call.function.name, | ||
"arguments": function_call.function.arguments, | ||
}, | ||
"id": str(msg_obj.id), | ||
"date": msg_obj.created_at.isoformat(), | ||
} | ||
""" | ||
|
||
function_call: FunctionCall | ||
|
||
|
||
class FunctionReturn(BaseMemGPTMessage): | ||
""" | ||
{ | ||
"function_return": msg, | ||
"status": "success" or "error", | ||
"id": str(msg_obj.id), | ||
"date": msg_obj.created_at.isoformat(), | ||
} | ||
""" | ||
|
||
function_return: str | ||
status: Literal["success", "error"] | ||
|
||
|
||
MemGPTMessage = Union[InternalMonologue, FunctionCallMessage, FunctionReturn] | ||
|
||
|
||
# Legacy MemGPT API had an additional type "assistant_message" and the "function_call" was a formatted string | ||
|
||
|
||
class AssistantMessage(BaseMemGPTMessage): | ||
assistant_message: str | ||
|
||
|
||
class LegacyFunctionCallMessage(BaseMemGPTMessage): | ||
function_call: str | ||
|
||
|
||
LegacyMemGPTMessage = Union[InternalMonologue, AssistantMessage, LegacyFunctionCallMessage, FunctionReturn] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,17 @@ | ||
from typing import List | ||
from typing import List, Union | ||
|
||
from pydantic import BaseModel, Field | ||
|
||
from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage | ||
from memgpt.schemas.message import Message | ||
from memgpt.schemas.usage import MemGPTUsageStatistics | ||
|
||
|
||
# TODO: consider moving into own file | ||
|
||
|
||
class MemGPTResponse(BaseModel): | ||
messages: List[Message] = Field(..., description="The messages returned by the agent.") | ||
# messages: List[Message] = Field(..., description="The messages returned by the agent.") | ||
messages: Union[List[Message], List[MemGPTMessage], List[LegacyMemGPTMessage]] = Field( | ||
..., description="The messages returned by the agent." | ||
) | ||
usage: MemGPTUsageStatistics = Field(..., description="The usage statistics of the agent.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.