diff --git a/memgpt/agent.py b/memgpt/agent.py
index 8926ff20f5..47ecff7d01 100644
--- a/memgpt/agent.py
+++ b/memgpt/agent.py
@@ -24,7 +24,7 @@
 from memgpt.memory import ArchivalMemory, BaseMemory, RecallMemory, summarize_messages
 from memgpt.metadata import MetadataStore
 from memgpt.models import chat_completion_response
-from memgpt.models.pydantic_models import ToolModel
+from memgpt.models.pydantic_models import OptionState, ToolModel
 from memgpt.persistence_manager import LocalStateManager
 from memgpt.system import (
     get_initial_boot_messages,
@@ -314,6 +314,7 @@ def _get_ai_reply(
         function_call: str = "auto",
         first_message: bool = False,  # hint
         stream: bool = False,  # TODO move to config?
+        inner_thoughts_in_kwargs: OptionState = OptionState.DEFAULT,
     ) -> chat_completion_response.ChatCompletionResponse:
         """Get response from LLM API"""
         try:
@@ -330,6 +331,8 @@ def _get_ai_reply(
                 # streaming
                 stream=stream,
                 stream_inferface=self.interface,
+                # whether to put inner thoughts in the function call kwargs
+                inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
             )
             # special case for 'length'
             if response.choices[0].finish_reason == "length":
@@ -401,7 +404,7 @@ def _handle_ai_response(
             printd(f"Request to call function {function_name} with tool_call_id: {tool_call_id}")
             try:
                 function_to_call = self.functions_python[function_name]
-            except KeyError as e:
+            except KeyError:
                 error_msg = f"No function named {function_name}"
                 function_response = package_function_response(False, error_msg)
                 messages.append(
@@ -424,7 +427,7 @@ def _handle_ai_response(
             try:
                 raw_function_args = function_call.arguments
                 function_args = parse_json(raw_function_args)
-            except Exception as e:
+            except Exception:
                 error_msg = f"Error parsing JSON for function '{function_name}' arguments: {function_call.arguments}"
                 function_response = package_function_response(False, error_msg)
                 messages.append(
@@ -550,6 +553,7 @@ def step(
         recreate_message_timestamp: bool = True,  # if True, when input is a Message type, recreate the 'created_at' field
         stream: bool = False,  # TODO move to config?
         timestamp: Optional[datetime.datetime] = None,
+        inner_thoughts_in_kwargs: OptionState = OptionState.DEFAULT,
     ) -> Tuple[List[Union[dict, Message]], bool, bool, bool]:
         """Top-level event message handler for the MemGPT agent"""
@@ -634,6 +638,7 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str:
                         message_sequence=input_message_sequence,
                         first_message=True,  # passed through to the prompt formatter
                         stream=stream,
+                        inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
                     )
                     if verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono):
                         break
@@ -646,6 +651,7 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str:
                 response = self._get_ai_reply(
                     message_sequence=input_message_sequence,
                     stream=stream,
+                    inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
                 )
 
             # Step 2: check if LLM wanted to call a function
@@ -716,7 +722,18 @@ def validate_json(user_message_text: str, raise_on_error: bool) -> str:
                 self.summarize_messages_inplace()
 
                 # Try step again
-                return self.step(user_message, first_message=first_message, return_dicts=return_dicts)
+                return self.step(
+                    user_message,
+                    first_message=first_message,
+                    first_message_retry_limit=first_message_retry_limit,
+                    skip_verify=skip_verify,
+                    return_dicts=return_dicts,
+                    recreate_message_timestamp=recreate_message_timestamp,
+                    stream=stream,
+                    timestamp=timestamp,
+                    inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
+                )
+
             else:
                 printd(f"step() failed with an unrecognized exception: '{str(e)}'")
                 raise e
diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py
index c7367306e9..ffd8fe930f 100644
--- a/memgpt/cli/cli.py
+++ b/memgpt/cli/cli.py
@@ -24,6 +24,7 @@
 from memgpt.memory import ChatMemory
 from memgpt.metadata import MetadataStore
 from memgpt.migrate import migrate_all_agents, migrate_all_sources
+from memgpt.models.pydantic_models import OptionState
 from memgpt.server.constants import WS_DEFAULT_PORT
 from memgpt.server.server import logger as server_logger
@@ -410,6 +411,10 @@ def run(
     yes: Annotated[bool, typer.Option("-y", help="Skip confirmation prompt and use defaults")] = False,
     # streaming
     stream: Annotated[bool, typer.Option(help="Enables message streaming in the CLI (if the backend supports it)")] = False,
+    # whether or not to put the inner thoughts inside the function args
+    no_content: Annotated[
+        OptionState, typer.Option(help="Set to 'yes' for LLM APIs that omit the `content` field during tool calling")
+    ] = OptionState.DEFAULT,
 ):
     """Start chatting with a MemGPT agent
@@ -671,7 +676,13 @@ def run(
         print()  # extra space
         run_agent_loop(
-            memgpt_agent=memgpt_agent, config=config, first=first, ms=ms, no_verify=no_verify, stream=stream
+            memgpt_agent=memgpt_agent,
+            config=config,
+            first=first,
+            ms=ms,
+            no_verify=no_verify,
+            stream=stream,
+            inner_thoughts_in_kwargs=no_content,
         )  # TODO: add back no_verify
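For context on the new CLI flag: because `OptionState` subclasses both `str` and `Enum`, typer renders the `no_content` parameter as a three-way choice option (`--no-content [yes|no|default]`). A minimal standalone sketch of how such a flag parses — not part of this patch, just an illustration of the mechanism:

from enum import Enum

import typer


class OptionState(str, Enum):
    """Mirrors memgpt.models.pydantic_models.OptionState"""

    YES = "yes"
    NO = "no"
    DEFAULT = "default"


app = typer.Typer()


@app.command()
def run(no_content: OptionState = OptionState.DEFAULT):
    # typer parses `--no-content yes` into OptionState.YES
    print(f"inner_thoughts_in_kwargs option: {no_content.value}")


if __name__ == "__main__":
    app()

With this wiring, a backend that drops `content` during tool calls would presumably be driven with `memgpt run --no-content yes`.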
""" +import copy import json import uuid +import warnings from datetime import datetime, timezone from typing import Dict, List, Optional, TypeVar @@ -70,6 +72,27 @@ def to_dict(self): } +def add_inner_thoughts_to_tool_call( + tool_call: ToolCall, + inner_thoughts: str, + inner_thoughts_key: str, +) -> ToolCall: + """Add inner thoughts (arg + value) to a tool call""" + # because the kwargs are stored as strings, we need to load then write the JSON dicts + try: + # load the args list + func_args = json.loads(tool_call.function["arguments"]) + # add the inner thoughts to the args list + func_args[inner_thoughts_key] = inner_thoughts + # create the updated tool call (as a string) + updated_tool_call = copy.deepcopy(tool_call) + updated_tool_call.function["arguments"] = json.dumps(func_args, ensure_ascii=JSON_ENSURE_ASCII) + return updated_tool_call + except json.JSONDecodeError as e: + warnings.warn(f"Failed to put inner thoughts in kwargs: {e}") + raise e + + class Message(Record): """Representation of a message sent. @@ -249,12 +272,16 @@ def dict_to_message( tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None, ) - def to_openai_dict_search_results(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN) -> dict: + def to_openai_dict_search_results(self, max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN) -> dict: result_json = self.to_openai_dict() search_result_json = {"timestamp": self.created_at, "message": {"content": result_json["content"], "role": result_json["role"]}} return search_result_json - def to_openai_dict(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN) -> dict: + def to_openai_dict( + self, + max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN, + put_inner_thoughts_in_kwargs: bool = True, + ) -> dict: """Go from Message class to ChatCompletion message object""" # TODO change to pydantic casting, eg `return SystemMessageModel(self)` @@ -282,14 +309,25 @@ def to_openai_dict(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN) -> dict: elif self.role == "assistant": assert self.tool_calls is not None or self.text is not None openai_message = { - "content": self.text, + "content": None if put_inner_thoughts_in_kwargs else self.text, "role": self.role, } # Optional fields, do not include if null if self.name is not None: openai_message["name"] = self.name if self.tool_calls is not None: - openai_message["tool_calls"] = [tool_call.to_dict() for tool_call in self.tool_calls] + if put_inner_thoughts_in_kwargs: + # put the inner thoughts inside the tool call before casting to a dict + openai_message["tool_calls"] = [ + add_inner_thoughts_to_tool_call( + tool_call, + inner_thoughts=self.text, + inner_thoughts_key=INNER_THOUGHTS_KWARG, + ).to_dict() + for tool_call in self.tool_calls + ] + else: + openai_message["tool_calls"] = [tool_call.to_dict() for tool_call in self.tool_calls] if max_tool_id_length: for tool_call_dict in openai_message["tool_calls"]: tool_call_dict["id"] = tool_call_dict["id"][:max_tool_id_length] @@ -453,7 +491,7 @@ def to_google_ai_dict(self, put_inner_thoughts_in_kwargs: bool = True) -> dict: assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(self) if self.name is None: - raise UserWarning(f"Couldn't find function name on tool call, defaulting to tool ID instead.") + warnings.warn(f"Couldn't find function name on tool call, defaulting to tool ID instead.") function_name = self.tool_call_id else: function_name = self.name diff --git a/memgpt/llm_api/llm_api_tools.py b/memgpt/llm_api/llm_api_tools.py index 
diff --git a/memgpt/llm_api/llm_api_tools.py b/memgpt/llm_api/llm_api_tools.py
index bb0d7e61b0..7e98c54f78 100644
--- a/memgpt/llm_api/llm_api_tools.py
+++ b/memgpt/llm_api/llm_api_tools.py
@@ -1,12 +1,15 @@
+import copy
+import json
 import os
 import random
 import time
 import uuid
+import warnings
 from typing import List, Optional, Union
 
 import requests
 
-from memgpt.constants import CLI_WARNING_PREFIX
+from memgpt.constants import CLI_WARNING_PREFIX, JSON_ENSURE_ASCII
 from memgpt.credentials import MemGPTCredentials
 from memgpt.data_types import Message
 from memgpt.llm_api.anthropic import anthropic_chat_completions_request
@@ -24,13 +27,17 @@
     openai_chat_completions_request,
 )
 from memgpt.local_llm.chat_completion_proxy import get_chat_completion
+from memgpt.local_llm.constants import (
+    INNER_THOUGHTS_KWARG,
+    INNER_THOUGHTS_KWARG_DESCRIPTION,
+)
 from memgpt.models.chat_completion_request import (
     ChatCompletionRequest,
     Tool,
     cast_message_to_subtype,
 )
 from memgpt.models.chat_completion_response import ChatCompletionResponse
-from memgpt.models.pydantic_models import LLMConfigModel
+from memgpt.models.pydantic_models import LLMConfigModel, OptionState
 from memgpt.streaming_interface import (
     AgentChunkStreamingInterface,
     AgentRefreshStreamingInterface,
@@ -39,6 +46,89 @@
 
 LLM_API_PROVIDER_OPTIONS = ["openai", "azure", "anthropic", "google_ai", "cohere", "local"]
 
 
+# TODO update to use better types
+def add_inner_thoughts_to_functions(
+    functions: List[dict],
+    inner_thoughts_key: str,
+    inner_thoughts_description: str,
+    inner_thoughts_required: bool = True,
+    # inner_thoughts_to_front: bool = True,  TODO support sorting somewhere, probably in the to_dict?
+) -> List[dict]:
+    """Add an inner_thoughts kwarg to every function in the provided list"""
+    # functions is a list of dicts in the OpenAI schema (https://platform.openai.com/docs/api-reference/chat/create)
+    new_functions = []
+
+    for function_object in functions:
+        # operate on a deep copy so the caller's schema dicts are never mutated
+        new_function_object = copy.deepcopy(function_object)
+        function_params = new_function_object["parameters"]["properties"]
+        required_params = list(new_function_object["parameters"]["required"])
+
+        # if the inner thoughts arg doesn't exist, add it
+        if inner_thoughts_key not in function_params:
+            function_params[inner_thoughts_key] = {
+                "type": "string",
+                "description": inner_thoughts_description,
+            }
+
+        # make sure it's tagged as required
+        if inner_thoughts_required and inner_thoughts_key not in required_params:
+            required_params.append(inner_thoughts_key)
+            new_function_object["parameters"]["required"] = required_params
+
+        new_functions.append(new_function_object)
+
+    # return a list of copies
+    return new_functions
+
+
+def unpack_inner_thoughts_from_kwargs(
+    response: ChatCompletionResponse,
+    inner_thoughts_key: str,
+) -> ChatCompletionResponse:
+    """Strip the inner thoughts out of the tool call and put it in the message content"""
+    if len(response.choices) == 0:
+        raise ValueError("Unpacking inner thoughts from empty response not supported")
+
+    new_choices = []
+    for choice in response.choices:
+        msg = choice.message
+        if msg.role == "assistant" and msg.tool_calls is not None and len(msg.tool_calls) >= 1:
+            if len(msg.tool_calls) > 1:
+                warnings.warn(f"Unpacking inner thoughts from more than one tool call ({len(msg.tool_calls)}) is not supported")
+            # TODO support multiple tool calls
+            tool_call = msg.tool_calls[0]
+
+            try:
+                # Sadly we need to parse the JSON since args are in string format
+                func_args = dict(json.loads(tool_call.function.arguments))
+                if inner_thoughts_key in func_args:
+                    # extract the inner thoughts
+                    inner_thoughts = func_args.pop(inner_thoughts_key)
+
+                    # replace the kwargs
+                    new_choice = choice.model_copy(deep=True)
+                    new_choice.message.tool_calls[0].function.arguments = json.dumps(func_args, ensure_ascii=JSON_ENSURE_ASCII)
+                    # also replace the message content
+                    if new_choice.message.content is not None:
+                        warnings.warn(f"Overwriting existing inner monologue ({new_choice.message.content}) with kwarg ({inner_thoughts})")
+                    new_choice.message.content = inner_thoughts
+
+                    # save copy
+                    new_choices.append(new_choice)
+                else:
+                    warnings.warn(f"Did not find inner thoughts in tool call: {str(tool_call)}")
+                    # keep the original choice so it isn't silently dropped
+                    new_choices.append(choice)
+
+            except json.JSONDecodeError as e:
+                warnings.warn(f"Failed to strip inner thoughts from kwargs: {e}")
+                raise e
+        else:
+            # not an assistant tool-call message, keep it unchanged
+            new_choices.append(choice)
+
+    # return an updated copy
+    new_response = response.model_copy(deep=True)
+    new_response.choices = new_choices
+    return new_response
+
+
 def is_context_overflow_error(exception: requests.exceptions.RequestException) -> bool:
     """Checks if an exception is due to context overflow (based on common OpenAI response messages)"""
     from memgpt.utils import printd
@@ -152,6 +242,9 @@ def create(
     # streaming?
     stream: bool = False,
     stream_inferface: Optional[Union[AgentRefreshStreamingInterface, AgentChunkStreamingInterface]] = None,
+    # TODO move to llm_config?
+    # if unspecified (DEFAULT), fall back to behavior we've tested for the model
+    inner_thoughts_in_kwargs: OptionState = OptionState.DEFAULT,
 ) -> ChatCompletionResponse:
     """Return response to chat completion with backoff"""
     from memgpt.utils import printd
@@ -166,8 +259,31 @@ def create(
         printd("unsetting function_call because functions is None")
         function_call = None
 
     # openai
     if llm_config.model_endpoint_type == "openai":
+
+        if inner_thoughts_in_kwargs == OptionState.DEFAULT:
+            # models that are known to omit the `content` field during tool calling
+            inner_thoughts_in_kwargs = (
+                "gpt-4o" in llm_config.model or "gpt-4-turbo" in llm_config.model or "gpt-3.5-turbo" in llm_config.model
+            )
+        else:
+            inner_thoughts_in_kwargs = True if inner_thoughts_in_kwargs == OptionState.YES else False
+
+        assert isinstance(inner_thoughts_in_kwargs, bool), type(inner_thoughts_in_kwargs)
+        if inner_thoughts_in_kwargs:
+            functions = add_inner_thoughts_to_functions(
+                functions=functions,
+                inner_thoughts_key=INNER_THOUGHTS_KWARG,
+                inner_thoughts_description=INNER_THOUGHTS_KWARG_DESCRIPTION,
+            )
+
+        openai_message_list = [
+            cast_message_to_subtype(m.to_openai_dict(put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs)) for m in messages
+        ]
+
+        # TODO do the same for Azure?
         if credentials.openai_key is None and llm_config.model_endpoint == "https://api.openai.com/v1":
             # only a problem if we are *not* using an openai proxy
@@ -175,7 +291,7 @@ def create(
         if use_tool_naming:
             data = ChatCompletionRequest(
                 model=llm_config.model,
-                messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
+                messages=openai_message_list,
                 tools=[{"type": "function", "function": f} for f in functions] if functions else None,
                 tool_choice=function_call,
                 user=str(user_id),
@@ -183,7 +299,7 @@ def create(
         else:
             data = ChatCompletionRequest(
                 model=llm_config.model,
-                messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
+                messages=openai_message_list,
                 functions=functions,
                 function_call=function_call,
                 user=str(user_id),
@@ -198,7 +314,7 @@ def create(
             assert isinstance(stream_inferface, AgentChunkStreamingInterface) or isinstance(
                 stream_inferface, AgentRefreshStreamingInterface
             ), type(stream_inferface)
-            return openai_chat_completions_process_stream(
+            response = openai_chat_completions_process_stream(
                 url=llm_config.model_endpoint,  # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
                 api_key=credentials.openai_key,
                 chat_completion_request=data,
@@ -217,7 +333,11 @@ def create(
             finally:
                 if isinstance(stream_inferface, AgentChunkStreamingInterface):
                     stream_inferface.stream_end()
-        return response
+
+        if inner_thoughts_in_kwargs:
+            response = unpack_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG)
+
+        return response
 
     # azure
     elif llm_config.model_endpoint_type == "azure":
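On the request side, `add_inner_thoughts_to_functions` rewrites every tool schema so the model is required to emit its monologue as an extra kwarg. A standalone sketch of the same rewrite; the schema plus the key and description strings are illustrative stand-ins for the real `INNER_THOUGHTS_KWARG` / `INNER_THOUGHTS_KWARG_DESCRIPTION` constants:

import copy

send_message_schema = {
    "name": "send_message",
    "description": "Sends a message to the human user.",
    "parameters": {
        "type": "object",
        "properties": {"message": {"type": "string", "description": "Message contents."}},
        "required": ["message"],
    },
}

# copy first so the caller's schema is untouched, then add the extra kwarg
new_schema = copy.deepcopy(send_message_schema)
new_schema["parameters"]["properties"]["inner_thoughts"] = {
    "type": "string",
    "description": "Deep inner monologue private to you only.",
}
new_schema["parameters"]["required"].append("inner_thoughts")

# with `content` left empty by the API, the model now has exactly one
# place to think out loud: the required `inner_thoughts` argument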
diff --git a/memgpt/main.py b/memgpt/main.py
index 567503788f..1de91bd69a 100644
--- a/memgpt/main.py
+++ b/memgpt/main.py
@@ -34,6 +34,7 @@
     REQ_HEARTBEAT_MESSAGE,
 )
 from memgpt.metadata import MetadataStore
+from memgpt.models.pydantic_models import OptionState
 
 # from memgpt.interface import CLIInterface as interface  # for printing to terminal
 from memgpt.streaming_interface import AgentRefreshStreamingInterface
@@ -71,7 +72,14 @@ def clear_line(console, strip_ui=False):
 
 
 def run_agent_loop(
-    memgpt_agent: agent.Agent, config: MemGPTConfig, first, ms: MetadataStore, no_verify=False, cfg=None, strip_ui=False, stream=False
+    memgpt_agent: agent.Agent,
+    config: MemGPTConfig,
+    first: bool,
+    ms: MetadataStore,
+    no_verify: bool = False,
+    strip_ui: bool = False,
+    stream: bool = False,
+    inner_thoughts_in_kwargs: OptionState = OptionState.DEFAULT,
 ):
     if isinstance(memgpt_agent.interface, AgentRefreshStreamingInterface):
         # memgpt_agent.interface.toggle_streaming(on=stream)
@@ -386,6 +394,7 @@ def process_agent_step(user_message, no_verify):
                 first_message=False,
                 skip_verify=no_verify,
                 stream=stream,
+                inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
             )
 
             skip_next_user_input = False
@@ -419,7 +428,7 @@ def process_agent_step(user_message, no_verify):
                     retry = questionary.confirm("Retry agent.step()?").ask()
                     if not retry:
                         break
-            except Exception as e:
+            except Exception:
                 print("An exception occurred when running agent.step(): ")
                 traceback.print_exc()
                 retry = questionary.confirm("Retry agent.step()?").ask()
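The option threads CLI -> `run_agent_loop` -> `Agent.step` -> `create()`, where `OptionState.DEFAULT` falls back to per-model detection. A hypothetical helper (not in the patch) restating that resolution logic:

from memgpt.models.pydantic_models import OptionState


def resolve_inner_thoughts_in_kwargs(option: OptionState, model: str) -> bool:
    """Collapse the three-way OptionState into the bool used by create()'s OpenAI branch"""
    if option == OptionState.DEFAULT:
        # models observed to omit `content` during tool calling
        return any(s in model for s in ("gpt-4o", "gpt-4-turbo", "gpt-3.5-turbo"))
    return option == OptionState.YES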
diff --git a/memgpt/models/pydantic_models.py b/memgpt/models/pydantic_models.py
index 24ea2af69b..a7e7819681 100644
--- a/memgpt/models/pydantic_models.py
+++ b/memgpt/models/pydantic_models.py
@@ -13,6 +13,14 @@
 from memgpt.utils import get_human_text, get_persona_text, get_utc_time
 
 
+class OptionState(str, Enum):
+    """Useful for kwargs that are bool + default option"""
+
+    YES = "yes"
+    NO = "no"
+    DEFAULT = "default"
+
+
 class MemGPTUsageStatistics(BaseModel):
     completion_tokens: int
     prompt_tokens: int
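Closing the loop, the response side reverses the packing: `unpack_inner_thoughts_from_kwargs` pops the kwarg back out of the tool call and restores it as the message's inner monologue. A dict-level sketch, again assuming the key is "inner_thoughts":

import json

# a toy assistant response from an API that left `content` empty during a tool call
message = {
    "role": "assistant",
    "content": None,
    "tool_calls": [
        {
            "id": "call_abc123",
            "type": "function",
            "function": {
                "name": "send_message",
                "arguments": json.dumps({"inner_thoughts": "Greet back.", "message": "Hi there!"}),
            },
        }
    ],
}

# pop the kwarg out of the argument string and restore it as message content
func_args = json.loads(message["tool_calls"][0]["function"]["arguments"])
message["content"] = func_args.pop("inner_thoughts")
message["tool_calls"][0]["function"]["arguments"] = json.dumps(func_args)

assert message["content"] == "Greet back."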