Skip to content

Commit

Permalink
add option to clear history in cli
Browse files Browse the repository at this point in the history
  • Loading branch information
rchan26 committed Jun 12, 2024
1 parent bf8d04a commit a04b0c0
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 8 deletions.
17 changes: 14 additions & 3 deletions reginald/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,11 @@ def run_all(
str, typer.Option(envvar="LLAMA_INDEX_DEVICE", help=HELP_TEXT["device"])
] = DEFAULT_ARGS["device"],
) -> None:
"""
Run all the components of the Reginald slack bot.
Establishes the connection to the Slack API, sets up the bot,
and creates a Reginald model to query from.
"""
set_up_logging_config(level=20)
main(
cli="run_all",
Expand Down Expand Up @@ -136,7 +141,7 @@ def bot(
] = EMOJI_DEFAULT,
) -> None:
"""
Main function to run the Slack bot which sets up the bot
Run the Slack bot which sets up the bot
(which uses an API for responding to messages) and
then establishes a WebSocket connection to the
Socket Mode servers and listens for events.
Expand Down Expand Up @@ -214,8 +219,8 @@ def app(
] = DEFAULT_ARGS["device"],
) -> None:
"""
Main function to run the app which sets up the response model
and then creates a FastAPI app to serve the model.
Sets up the response model and then creates a
FastAPI app to serve the model.
The app listens on port 8000 and has two endpoints:
- /direct_message: for obtaining responses from direct messages
Expand Down Expand Up @@ -263,6 +268,9 @@ def create_index(
int, typer.Option(envvar="LLAMA_INDEX_NUM_OUTPUT")
] = DEFAULT_ARGS["num_output"],
) -> None:
"""
Create an index for the Reginald model.
"""
set_up_logging_config(level=20)
main(
cli="create_index",
Expand Down Expand Up @@ -346,6 +354,9 @@ def chat(
str, typer.Option(envvar="LLAMA_INDEX_DEVICE", help=HELP_TEXT["device"])
] = DEFAULT_ARGS["device"],
) -> None:
"""
Run the chat interaction with the Reginald model.
"""
set_up_logging_config(level=40)
main(
cli="chat",
Expand Down
1 change: 1 addition & 0 deletions reginald/models/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(self, emoji: Optional[str], *args: Any, **kwargs: Any):
Emoji to use for the bot's response
"""
self.emoji = emoji
self.mode = "NA"

def direct_message(self, message: str, user_id: str) -> MessageResponse:
    """
    Respond to a direct message. Abstract hook: subclasses must override.

    Parameters
    ----------
    message : str
        The text of the incoming direct message.
    user_id : str
        Identifier of the user who sent the message.

    Returns
    -------
    MessageResponse
        The model's response to the message.

    Raises
    ------
    NotImplementedError
        Always, in this base class.
    """
    raise NotImplementedError
Expand Down
22 changes: 17 additions & 5 deletions reginald/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
LISTENING_MSG: Final[str] = "Listening for requests..."


async def run_bot(api_url: str | None, emoji: str):
async def run_bot(api_url: str | None, emoji: str) -> None:
if api_url is None:
logging.error(
"API URL is not set. Please set the REGINALD_API_URL "
Expand All @@ -44,7 +44,7 @@ async def run_reginald_app(**kwargs) -> None:
uvicorn.run(app, host="0.0.0.0", port=8000)


async def run_full_pipeline(**kwargs):
async def run_full_pipeline(**kwargs) -> None:
# set up response model
response_model = setup_llm(**kwargs)
bot = setup_slack_bot(response_model)
Expand All @@ -56,20 +56,32 @@ async def run_full_pipeline(**kwargs):
def run_chat_interact(streaming: bool = False, **kwargs) -> ResponseModel:
    """
    Run an interactive command-line chat with a Reginald response model.

    Reads user input in a loop and prints the model's replies. Special inputs:
    - "exit", "exit()", "quit()", "bye Reginald": end the session and return
      the response model
    - "clear_history", "\\clear_history" (with a literal backslash): reset the
      chat history for this session, when the model is in "chat" mode and has
      an engine for this user

    Parameters
    ----------
    streaming : bool, optional
        If True, stream the model's reply token-by-token; otherwise print the
        full reply at once. Default False.
    **kwargs
        Keyword arguments forwarded to setup_llm to build the response model.

    Returns
    -------
    ResponseModel
        The response model used for the chat session.
    """
    # set up response model
    response_model = setup_llm(**kwargs)
    user_id = "command_line_chat"

    while True:
        message = input(">>> ")
        if message in ["exit", "exit()", "quit()", "bye Reginald"]:
            return response_model
        # NOTE: "\\clear_history" spells backslash-c-l-e-a-r...; the previous
        # literal "\clear_history" relied on the invalid escape "\c" being kept
        # verbatim, which emits SyntaxWarning on Python 3.12+.
        if message in ["clear_history", "\\clear_history"]:
            # only chat-mode models keep per-user history to reset
            if (
                response_model.mode == "chat"
                and response_model.chat_engine.get(user_id) is not None
            ):
                response_model.chat_engine[user_id].reset()
                print("\nReginald: History cleared.")
            else:
                print("\nReginald: No history to clear.")
            continue

        if streaming:
            # stream_message prints tokens as they arrive; add trailing newline
            response = response_model.stream_message(message=message, user_id=user_id)
            print("")
        else:
            response = response_model.direct_message(message=message, user_id=user_id)
            print(f"\nReginald: {response.message}")


async def connect_client(client: SocketModeClient):
async def connect_client(client: SocketModeClient) -> None:
await client.connect()
# listen for events
logging.info(LISTENING_MSG)
Expand Down
16 changes: 16 additions & 0 deletions tests/test_chat_interact.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,19 @@ def test_chat_interact_bye():
with mock.patch.object(builtins, "input", lambda _: "bye Reginald"):
interaction = run_chat_interact(model="hello")
assert isinstance(interaction, Hello)


def test_chat_interact_clear_history():
    """Asking to clear history in a fresh session prints the no-history reply."""
    result = runner.invoke(cli, ["chat"], input="clear_history\n")
    output_lines: list[str] = result.stdout.split("\n")
    expected_prefix = [">>> ", "Reginald: No history to clear.", ">>> "]
    assert output_lines[: len(expected_prefix)] == expected_prefix


def test_chat_interact_slash_clear_history():
    """
    Entering "\\clear_history" (a literal backslash prefix) in a fresh session
    prints the no-history reply.
    """
    # "\\clear_history\n" is the same runtime string as the original
    # "\clear_history\n", but without the invalid escape sequence "\c"
    # (SyntaxWarning on Python 3.12+).
    result = runner.invoke(cli, ["chat"], input="\\clear_history\n")
    term_stdout_lines: list[str] = result.stdout.split("\n")
    assert term_stdout_lines[0] == ">>> "
    assert term_stdout_lines[1] == "Reginald: No history to clear."
    assert term_stdout_lines[2] == ">>> "

0 comments on commit a04b0c0

Please sign in to comment.