Streaming response #190

Merged: 15 commits, Jun 12, 2024
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default.

6 changes: 5 additions & 1 deletion pyproject.toml
@@ -56,6 +56,7 @@ typer = {extras = ["all"], version = "^0.12.3"}
langchain-community = "^0.2.4"
tiktoken = "^0.7.0"
llama-index-embeddings-huggingface = "^0.2.1"
rich = "^13.7.1"


[tool.poetry.group.dev.dependencies]
@@ -105,10 +106,13 @@ build-backend = "poetry.core.masonry.api"
minversion = "6.0"
testpaths = [
    "tests",
    "reginald",
]
addopts = """
--cov=estios
--cov=reginald
--cov-report=term:skip-covered
--cov-append
--pdbcls=IPython.terminal.debugger:TerminalPdb
--doctest-modules
"""
doctest_optionflags = ["NORMALIZE_WHITESPACE", "ELLIPSIS",]
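The pytest additions point coverage at the reginald package and collect doctests from it via --doctest-modules, while the doctest_optionflags relax how expected output is matched. As a quick illustration (a hypothetical function, not part of this PR), a doctest like the following passes only because NORMALIZE_WHITESPACE and ELLIPSIS are enabled:

```python
def greet(user_id: str) -> str:
    """Return a Slack-style greeting.

    The expected output below uses extra spaces (tolerated by
    NORMALIZE_WHITESPACE) and elides part of the user ID with "..."
    (tolerated by ELLIPSIS):

    >>> greet("U0123456789")
    'Hello    <@U0...>'
    """
    return f"Hello <@{user_id}>"
```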
25 changes: 22 additions & 3 deletions reginald/cli.py
@@ -25,6 +25,7 @@
"device": "Device to use (ignored if not using llama-index).",
"api_url": "API URL for the Reginald app.",
"emoji": "Emoji to use for the bot.",
"streaming": "Whether to use streaming for the chat interaction.",
}

cli = typer.Typer()
@@ -102,6 +103,11 @@ def run_all(
str, typer.Option(envvar="LLAMA_INDEX_DEVICE", help=HELP_TEXT["device"])
] = DEFAULT_ARGS["device"],
) -> None:
"""
Run all the components of the Reginald slack bot.
Establishes the connection to the Slack API, sets up the bot,
and creates a Reginald model to query from.
"""
set_up_logging_config(level=20)
main(
cli="run_all",
@@ -135,7 +141,7 @@ def bot(
    ] = EMOJI_DEFAULT,
) -> None:
    """
    Main function to run the Slack bot which sets up the bot
    Run the Slack bot which sets up the bot
    (which uses an API for responding to messages) and
    then establishes a WebSocket connection to the
    Socket Mode servers and listens for events.
@@ -213,8 +219,8 @@ def app(
] = DEFAULT_ARGS["device"],
) -> None:
"""
Main function to run the app which sets up the response model
and then creates a FastAPI app to serve the model.
Sets up the response model and then creates a
FastAPI app to serve the model.

The app listens on port 8000 and has two endpoints:
- /direct_message: for obtaining responses from direct messages
@@ -262,6 +268,9 @@ def create_index(
int, typer.Option(envvar="LLAMA_INDEX_NUM_OUTPUT")
] = DEFAULT_ARGS["num_output"],
) -> None:
"""
Create an index for the Reginald model.
"""
set_up_logging_config(level=20)
main(
cli="create_index",
@@ -288,6 +297,12 @@ def chat(
        Optional[str],
        typer.Option(envvar="REGINALD_MODEL_NAME", help=HELP_TEXT["model_name"]),
    ] = None,
    streaming: Annotated[
        bool,
        typer.Option(
            help=HELP_TEXT["streaming"],
        ),
    ] = True,
    mode: Annotated[
        str, typer.Option(envvar="LLAMA_INDEX_MODE", help=HELP_TEXT["mode"])
    ] = DEFAULT_ARGS["mode"],
@@ -339,9 +354,13 @@ def chat(
str, typer.Option(envvar="LLAMA_INDEX_DEVICE", help=HELP_TEXT["device"])
] = DEFAULT_ARGS["device"],
) -> None:
"""
Run the chat interaction with the Reginald model.
"""
set_up_logging_config(level=40)
main(
cli="chat",
streaming=streaming,
model=model,
model_name=model_name,
mode=mode,
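Because streaming is a plain bool option defaulting to True, Typer generates paired --streaming/--no-streaming flags for the chat command. A quick check with Typer's test runner (a sketch; the cli object is the Typer app defined in reginald/cli.py above):

```python
from typer.testing import CliRunner

from reginald.cli import cli

runner = CliRunner()

# The generated help text lists the paired flags Typer derives from
# the bool option: "--streaming / --no-streaming".
result = runner.invoke(cli, ["chat", "--help"])
assert "--no-streaming" in result.output
```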
4 changes: 2 additions & 2 deletions reginald/models/models/__init__.py
@@ -25,13 +25,13 @@
}

DEFAULTS = {
    "chat-completion-azure": "reginald-curie",
    "chat-completion-azure": "reginald-gpt4",
    "chat-completion-openai": "gpt-3.5-turbo",
    "hello": None,
    "llama-index-ollama": "llama3",
    "llama-index-llama-cpp": "https://huggingface.co/TheBloke/Llama-2-13B-chat-GGUF/resolve/main/llama-2-13b-chat.Q6_K.gguf",
    "llama-index-hf": "microsoft/phi-1_5",
    "llama-index-gpt-azure": "reginald-gpt35-turbo",
    "llama-index-gpt-azure": "reginald-gpt4",
    "llama-index-gpt-openai": "gpt-3.5-turbo",
}

4 changes: 4 additions & 0 deletions reginald/models/models/base.py
@@ -28,9 +28,13 @@ def __init__(self, emoji: Optional[str], *args: Any, **kwargs: Any):
        Emoji to use for the bot's response
        """
        self.emoji = emoji
        self.mode = "NA"

    def direct_message(self, message: str, user_id: str) -> MessageResponse:
        raise NotImplementedError

    def channel_mention(self, message: str, user_id: str) -> MessageResponse:
        raise NotImplementedError

    def stream_message(self, message: str, user_id: str) -> None:
        raise NotImplementedError
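Adding stream_message to the base class with a NotImplementedError default means existing models still load unchanged, and a caller can treat the exception as a cue to fall back to the blocking path. A hypothetical sketch of such a fallback (the PR's actual chat loop is not part of this diff, and the MessageResponse attribute name is an assumption):

```python
from reginald.models.models.base import ResponseModel


def respond(model: ResponseModel, message: str, user_id: str, streaming: bool) -> None:
    """Stream if the model supports it, otherwise fall back to a blocking call."""
    if streaming:
        try:
            model.stream_message(message=message, user_id=user_id)
            return
        except NotImplementedError:
            pass  # model predates streaming support
    # Assumes MessageResponse stores its text in a `.message` attribute.
    print(model.direct_message(message=message, user_id=user_id).message)
```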
42 changes: 40 additions & 2 deletions reginald/models/models/chat_completion.py
@@ -1,13 +1,12 @@
import logging
import os
import sys
from typing import Any

import openai
from openai import AzureOpenAI, OpenAI

from reginald.models.models.base import MessageResponse, ResponseModel
from reginald.utils import get_env_var
from reginald.utils import get_env_var, stream_iter_progress_wrapper


class ChatCompletionBase(ResponseModel):
@@ -155,6 +154,35 @@ def channel_mention(self, message: str, user_id: str) -> MessageResponse:
        """
        return self._respond(message=message, user_id=user_id)

    def stream_message(self, message: str, user_id: str) -> None:
        if self.mode == "chat":
            response = self.client.chat.completions.create(
                model=self.engine,
                messages=[{"role": "user", "content": message}],
                frequency_penalty=self.frequency_penalty,
                max_tokens=self.max_tokens,
                presence_penalty=self.presence_penalty,
                stop=None,
                temperature=self.temperature,
                top_p=self.top_p,
                stream=True,
            )
        elif self.mode == "query":
            response = self.client.completions.create(
                model=self.engine,
                frequency_penalty=self.frequency_penalty,
                max_tokens=self.max_tokens,
                presence_penalty=self.presence_penalty,
                prompt=message,
                stop=None,
                temperature=self.temperature,
                top_p=self.top_p,
                stream=True,
            )
        else:
            raise ValueError(f"Unknown mode: {self.mode}")

        for chunk in stream_iter_progress_wrapper(response):
            # Chat chunks carry text in choices[0].delta.content (None on the
            # final chunk); completions chunks carry it in choices[0].text.
            if self.mode == "chat":
                print(chunk.choices[0].delta.content or "", end="", flush=True)
            else:
                print(chunk.choices[0].text, end="", flush=True)


class ChatCompletionOpenAI(ChatCompletionBase):
    def __init__(
@@ -233,3 +261,13 @@ def channel_mention(self, message: str, user_id: str) -> MessageResponse:
        Response from the query engine.
        """
        return self._respond(message=message, user_id=user_id)

    def stream_message(self, message: str, user_id: str) -> None:
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=[{"role": "user", "content": message}],
            stream=True,
        )

        for chunk in stream_iter_progress_wrapper(response):
            # delta.content is None on the final chunk; fall back to "".
            print(chunk.choices[0].delta.content or "", end="", flush=True)
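stream_iter_progress_wrapper is imported from reginald.utils but defined outside this diff. Given the rich dependency added in pyproject.toml, a plausible reading is a generator that shows a spinner until the first chunk arrives and then yields the stream through unchanged. A sketch under that assumption (not the PR's actual implementation):

```python
from collections.abc import Iterable, Iterator
from typing import TypeVar

from rich.progress import Progress, SpinnerColumn, TextColumn

T = TypeVar("T")


def stream_iter_progress_wrapper(iterable: Iterable[T]) -> Iterator[T]:
    """Spin while waiting for the first item, then yield items as they come."""
    iterator = iter(iterable)
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        transient=True,  # clear the spinner once tokens start flowing
    ) as progress:
        progress.add_task(description="Thinking...", total=None)
        try:
            first = next(iterator)  # blocks until the model starts responding
        except StopIteration:
            return
    yield first
    yield from iterator
```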
7 changes: 7 additions & 0 deletions reginald/models/models/hello.py
@@ -1,6 +1,7 @@
from typing import Any

from reginald.models.models.base import MessageResponse, ResponseModel
from reginald.utils import stream_iter_progress_wrapper


class Hello(ResponseModel):
@@ -16,3 +17,9 @@ def direct_message(self, message: str, user_id: str) -> MessageResponse:

    def channel_mention(self, message: str, user_id: str) -> MessageResponse:
        return MessageResponse(f"Hello <@{user_id}>")

    def stream_message(self, message: str, user_id: str) -> None:
        # print("\nReginald: ", end="")
        token_list: tuple[str, ...] = ("Hello", "!", " How", " are", " you", "?")
        for token in stream_iter_progress_wrapper(token_list):
            print(token, end="", flush=True)