Skip to content

Commit

Permalink
Relax verify_first_message_correctness to accept any function call (l…
Browse files Browse the repository at this point in the history
…etta-ai#340)

* Relax verify_first_message_correctness to accept any function call

* Also allow missing internal monologue if request_heartbeat

* Cleanup

* get instead of raw dict access
  • Loading branch information
vivi authored Nov 7, 2023
1 parent b1c5566 commit 30e9110
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions memgpt/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,25 +466,37 @@ def load_from_json_file_inplace(self, json_file):
state = json.load(file)
self.load_inplace(state)

def verify_first_message_correctness(self, response, require_send_message=True, require_monologue=False):
def verify_first_message_correctness(self, response, require_function_call=True, require_monologue=False):
"""Can be used to enforce that the first message always uses send_message"""
response_message = response.choices[0].message

# First message should be a call to send_message with a non-empty content
if require_send_message and not response_message.get("function_call"):
if require_function_call and not response_message.get("function_call"):
printd(f"First message didn't include function call: {response_message}")
return False

expected_function_calls = [x["name"] for x in self.functions]
function_name = response_message["function_call"]["name"]
if require_send_message and function_name != "send_message":
printd(f"First message function call wasn't send_message: {response_message}")
if require_function_call and function_name not in expected_function_calls:
printd(f"First message function call wasn't one of {expected_function_calls}: {response_message}")
return False

if require_monologue and (
not response_message.get("content") or response_message["content"] is None or response_message["content"] == ""
):
printd(f"First message missing internal monologue: {response_message}")
return False
if function_name in expected_function_calls:
try:
raw_function_args = response_message["function_call"]["arguments"]
function_args = parse_json(raw_function_args)
except Exception as e:
printd(f"First message missing internal monologue and has badly formed arguments: {response_message}")
return False
if function_args.get("request_heartbeat") != True:
printd(f"First message missing internal monologue and does not chain further functions: {response_message}")
return False
else:
printd(f"First message missing internal monologue: {response_message}")
return False

if response_message.get("content"):
### Extras
Expand Down

0 comments on commit 30e9110

Please sign in to comment.