From 3b4e78285c848a5fd235685f9553a1cb8878e6e8 Mon Sep 17 00:00:00 2001 From: robingotz Date: Tue, 5 Mar 2024 19:57:06 +0100 Subject: [PATCH] feat: add endpoints to create human/persona --- memgpt/server/rest_api/humans/index.py | 19 +++++++++++++++++-- memgpt/server/rest_api/personas/index.py | 19 +++++++++++++++++-- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/memgpt/server/rest_api/humans/index.py b/memgpt/server/rest_api/humans/index.py index a43cd5b959..a2e7242648 100644 --- a/memgpt/server/rest_api/humans/index.py +++ b/memgpt/server/rest_api/humans/index.py @@ -2,13 +2,13 @@ from functools import partial from typing import List -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Body from pydantic import BaseModel, Field +from memgpt.models.pydantic_models import HumanModel from memgpt.server.rest_api.auth_token import get_current_user from memgpt.server.rest_api.interface import QueuingInterface from memgpt.server.server import SyncServer -from memgpt.models.pydantic_models import HumanModel router = APIRouter() @@ -17,6 +17,11 @@ class ListHumansResponse(BaseModel): humans: List[HumanModel] = Field(..., description="List of human configurations.") +class CreateHumanRequest(BaseModel): + text: str = Field(..., description="The human text.") + name: str = Field(..., description="The name of the human.") + + def setup_humans_index_router(server: SyncServer, interface: QueuingInterface, password: str): get_current_user_with_server = partial(partial(get_current_user, server), password) @@ -29,4 +34,14 @@ async def list_humans( humans = server.ms.list_humans(user_id=user_id) return ListHumansResponse(humans=humans) + @router.post("/humans", tags=["humans"], response_model=HumanModel) + async def create_persona( + request: CreateHumanRequest = Body(...), + user_id: uuid.UUID = Depends(get_current_user_with_server), + ): + interface.clear() + new_human = HumanModel(text=request.text, name=request.name, user_id=user_id) + server.ms.add_human(new_human) + return new_human + return router diff --git a/memgpt/server/rest_api/personas/index.py b/memgpt/server/rest_api/personas/index.py index 468b937eeb..dd30370a29 100644 --- a/memgpt/server/rest_api/personas/index.py +++ b/memgpt/server/rest_api/personas/index.py @@ -2,13 +2,13 @@ from functools import partial from typing import List -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Body from pydantic import BaseModel, Field +from memgpt.models.pydantic_models import PersonaModel from memgpt.server.rest_api.auth_token import get_current_user from memgpt.server.rest_api.interface import QueuingInterface from memgpt.server.server import SyncServer -from memgpt.models.pydantic_models import PersonaModel router = APIRouter() @@ -17,6 +17,11 @@ class ListPersonasResponse(BaseModel): personas: List[PersonaModel] = Field(..., description="List of persona configurations.") +class CreatePersonaRequest(BaseModel): + text: str = Field(..., description="The persona text.") + name: str = Field(..., description="The name of the persona.") + + def setup_personas_index_router(server: SyncServer, interface: QueuingInterface, password: str): get_current_user_with_server = partial(partial(get_current_user, server), password) @@ -30,4 +35,14 @@ async def list_personas( personas = server.ms.list_personas(user_id=user_id) return ListPersonasResponse(personas=personas) + @router.post("/personas", tags=["personas"], response_model=PersonaModel) + async def create_persona( + request: CreatePersonaRequest = Body(...), + user_id: uuid.UUID = Depends(get_current_user_with_server), + ): + interface.clear() + new_persona = PersonaModel(text=request.text, name=request.name, user_id=user_id) + server.ms.add_persona(new_persona) + return new_persona + return router