Skip to content

Commit

Permalink
feat: enable tool selection via agent creation POST (#1137)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpacker authored Mar 12, 2024
1 parent 994c015 commit efa12cf
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
18 changes: 17 additions & 1 deletion memgpt/server/rest_api/agents/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,23 @@ def create_agent(
interface.clear()

try:
agent_state = server.create_agent(user_id=user_id, **request.config)
try:
agent_state = server.create_agent(
user_id=user_id,
# **request.config
# TODO turn into a pydantic model
name=request.config["name"],
preset=request.config["preset"] if "preset" in request.config else None,
persona=request.config["persona_text"] if "persona_text" in request.config else None,
human=request.config["human_text"] if "human_text" in request.config else None,
# llm_config=LLMConfigModel(
# model=request.config['model'],
# )
function_names=request.config["function_names"].split(",") if "function_names" in request.config else None,
)
except:
print(f"Failed to create agent from provided config:\n{request.config}")
raise
llm_config = LLMConfigModel(**vars(agent_state.llm_config))
embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config))
return CreateAgentResponse(
Expand Down
9 changes: 9 additions & 0 deletions memgpt/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,7 @@ def create_agent(
embedding_config: Optional[EmbeddingConfig] = None,
interface: Union[AgentInterface, None] = None,
# persistence_manager: Union[PersistenceManager, None] = None,
function_names: Optional[List[str]] = None, # TODO remove
) -> AgentState:
"""Create a new agent using a config"""
if self.ms.get_user(user_id=user_id) is None:
Expand Down Expand Up @@ -610,6 +611,14 @@ def create_agent(
llm_config = llm_config if llm_config else self.server_llm_config
embedding_config = embedding_config if embedding_config else self.server_embedding_config

# TODO remove (https://github.com/cpacker/MemGPT/issues/1138)
if function_names is not None:
available_tools = self.ms.list_tools(user_id=user_id)
available_tools_names = [t.name for t in available_tools]
assert all([f_name in available_tools_names for f_name in function_names])
preset_obj.functions_schema = [t.json_schema for t in available_tools if t.name in function_names]
print("overriding preset_obj tools with:", preset_obj.functions_schema)

agent = Agent(
interface=interface,
preset=preset_obj,
Expand Down

0 comments on commit efa12cf

Please sign in to comment.