feat: added groq support via local option w/ auth #1203

Merged
4 commits merged on Apr 1, 2024
Changes from 3 commits
77 changes: 48 additions & 29 deletions memgpt/cli/cli_config.py
@@ -132,7 +132,9 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
model_endpoint = azure_creds["azure_endpoint"]

else: # local models
backend_options = ["webui", "webui-legacy", "llamacpp", "koboldcpp", "ollama", "lmstudio", "lmstudio-legacy", "vllm", "openai"]
# backend_options_old = ["webui", "webui-legacy", "llamacpp", "koboldcpp", "ollama", "lmstudio", "lmstudio-legacy", "vllm", "openai"]
backend_options = builtins.list(DEFAULT_ENDPOINTS.keys())
# assert backend_options_old == backend_options, (backend_options_old, backend_options)
default_model_endpoint_type = None
if config.default_llm_config.model_endpoint_type in backend_options:
# set from previous config
@@ -223,8 +225,12 @@ def get_model_options(

else:
# Attempt to do OpenAI endpoint style model fetching
# TODO support local auth
fetched_model_options_response = openai_get_model_list(url=model_endpoint, api_key=None, fix_url=True)
# TODO support local auth with api-key header
if credentials.openllm_auth_type == "bearer_token":
api_key = credentials.openllm_key
else:
api_key = None
fetched_model_options_response = openai_get_model_list(url=model_endpoint, api_key=api_key, fix_url=True)
model_options = [obj["id"] for obj in fetched_model_options_response["data"]]
# NOTE no filtering of local model options

@@ -289,6 +295,44 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
raise KeyboardInterrupt

else: # local models

# ask about local auth
if model_endpoint_type in ["groq"]: # TODO all llm engines under 'local' that will require api keys
use_local_auth = True
local_auth_type = "bearer_token"
local_auth_key = questionary.password(
"Enter your Groq API key:",
).ask()
if local_auth_key is None:
raise KeyboardInterrupt
credentials.openllm_auth_type = local_auth_type
credentials.openllm_key = local_auth_key
credentials.save()
else:
use_local_auth = questionary.confirm(
"Is your LLM endpoint authenticated? (default no)",
default=False,
).ask()
if use_local_auth is None:
raise KeyboardInterrupt
if use_local_auth:
local_auth_type = questionary.select(
"What HTTP authentication method does your endpoint require?",
choices=SUPPORTED_AUTH_TYPES,
default=SUPPORTED_AUTH_TYPES[0],
).ask()
if local_auth_type is None:
raise KeyboardInterrupt
local_auth_key = questionary.password(
"Enter your authentication key:",
).ask()
if local_auth_key is None:
raise KeyboardInterrupt
# credentials = MemGPTCredentials.load()
credentials.openllm_auth_type = local_auth_type
credentials.openllm_key = local_auth_key
credentials.save()

# ollama also needs model type
if model_endpoint_type == "ollama":
default_model = (
@@ -311,7 +355,7 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
)

# vllm needs huggingface model tag
if model_endpoint_type == "vllm":
if model_endpoint_type in ["vllm", "groq"]:
try:
# Don't filter model list for vLLM since model list is likely much smaller than OpenAI/Azure endpoint
# + probably has custom model names
@@ -366,31 +410,6 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
if model_wrapper is None:
raise KeyboardInterrupt

# ask about local auth
use_local_auth = questionary.confirm(
"Is your LLM endpoint authenticated? (default no)",
default=False,
).ask()
if use_local_auth is None:
raise KeyboardInterrupt
if use_local_auth:
local_auth_type = questionary.select(
"What HTTP authentication method does your endpoint require?",
choices=SUPPORTED_AUTH_TYPES,
default=SUPPORTED_AUTH_TYPES[0],
).ask()
if local_auth_type is None:
raise KeyboardInterrupt
local_auth_key = questionary.password(
"Enter your authentication key:",
).ask()
if local_auth_key is None:
raise KeyboardInterrupt
# credentials = MemGPTCredentials.load()
credentials.openllm_auth_type = local_auth_type
credentials.openllm_key = local_auth_key
credentials.save()

# set: context_window
if str(model) not in LLM_MAX_TOKENS:
# Ask the user to specify the context length
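For context, a minimal sketch of what the authenticated model-options fetch above amounts to once a Groq bearer token has been saved. The endpoint path and header layout are assumptions based on Groq exposing an OpenAI-compatible API; inside MemGPT this is handled by openai_get_model_list.

import requests

def list_groq_models(api_key: str) -> list:
    """Hypothetical standalone equivalent of the model-options fetch (URL/path assumed)."""
    resp = requests.get(
        "https://api.groq.com/openai/v1/models",
        headers={"Authorization": f"Bearer {api_key}"},
        timeout=30,
    )
    resp.raise_for_status()
    # Mirrors the filtering in get_model_options: keep only the model ids
    return [obj["id"] for obj in resp.json()["data"]]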
3 changes: 3 additions & 0 deletions memgpt/local_llm/chat_completion_proxy.py
@@ -13,6 +13,7 @@
from memgpt.local_llm.koboldcpp.api import get_koboldcpp_completion
from memgpt.local_llm.ollama.api import get_ollama_completion
from memgpt.local_llm.vllm.api import get_vllm_completion
from memgpt.local_llm.groq.api import get_groq_completion
from memgpt.local_llm.llm_chat_completion_wrappers import simple_summary_wrapper
from memgpt.local_llm.constants import DEFAULT_WRAPPER
from memgpt.local_llm.utils import get_available_wrappers, count_tokens
@@ -155,6 +156,8 @@ def get_chat_completion(
result, usage = get_ollama_completion(endpoint, auth_type, auth_key, model, prompt, context_window)
elif endpoint_type == "vllm":
result, usage = get_vllm_completion(endpoint, auth_type, auth_key, model, prompt, context_window, user)
elif endpoint_type == "groq":
result, usage = get_groq_completion(endpoint, auth_type, auth_key, model, prompt, context_window)
else:
raise LocalLLMError(
f"Invalid endpoint type {endpoint_type}, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp)"
4 changes: 4 additions & 0 deletions memgpt/local_llm/constants.py
@@ -2,6 +2,7 @@
from memgpt.local_llm.llm_chat_completion_wrappers.chatml import ChatMLInnerMonologueWrapper, ChatMLOuterInnerMonologueWrapper

DEFAULT_ENDPOINTS = {
# Local
"koboldcpp": "http://localhost:5001",
"llamacpp": "http://localhost:8080",
"lmstudio": "http://localhost:1234",
@@ -10,6 +11,9 @@
"webui-legacy": "http://localhost:5000",
"webui": "http://localhost:5000",
"vllm": "http://localhost:8000",
# APIs
"openai": "https://api.openai.com",
"groq": "https://api.groq.com/openai",
}

DEFAULT_OLLAMA_MODEL = "dolphin2.2-mistral:7b-q6_K"
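A quick sketch of how the new "groq" base URL composes with the chat-completions suffix used in memgpt/local_llm/groq/api.py below; the join logic mirrors the URI construction in that file.

from urllib.parse import urljoin

endpoint = "https://api.groq.com/openai"  # DEFAULT_ENDPOINTS["groq"]
suffix = "/v1/chat/completions"           # API_CHAT_SUFFIX in groq/api.py
uri = urljoin(endpoint.strip("/") + "/", suffix.strip("/"))
print(uri)  # -> https://api.groq.com/openai/v1/chat/completions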
79 changes: 79 additions & 0 deletions memgpt/local_llm/groq/api.py
@@ -0,0 +1,79 @@
from urllib.parse import urljoin
import requests
from typing import Tuple

from memgpt.local_llm.settings.settings import get_completions_settings
from memgpt.local_llm.utils import post_json_auth_request
from memgpt.utils import count_tokens


API_CHAT_SUFFIX = "/v1/chat/completions"
# LMSTUDIO_API_COMPLETIONS_SUFFIX = "/v1/completions"


def get_groq_completion(endpoint: str, auth_type: str, auth_key: str, model: str, prompt: str, context_window: int) -> Tuple[str, dict]:
"""TODO no support for function calling OR raw completions, so we need to route the request into /chat/completions instead"""
from memgpt.utils import printd

prompt_tokens = count_tokens(prompt)
if prompt_tokens > context_window:
raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")

settings = get_completions_settings()
settings.update(
{
# see https://console.groq.com/docs/text-chat, supports:
# "temperature": ,
# "max_tokens": ,
# "top_p",
# "stream",
# "stop",
}
)

URI = urljoin(endpoint.strip("/") + "/", API_CHAT_SUFFIX.strip("/"))

# Settings for the generation, includes the prompt + stop tokens, max length, etc
request = settings
request["model"] = model
request["max_tokens"] = context_window
# NOTE: Hack for chat/completion-only endpoints: put the entire completion string inside the first message
message_structure = [{"role": "user", "content": prompt}]
request["messages"] = message_structure

if not endpoint.startswith(("http://", "https://")):
raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")

try:
response = post_json_auth_request(uri=URI, json_payload=request, auth_type=auth_type, auth_key=auth_key)
if response.status_code == 200:
result_full = response.json()
printd(f"JSON API response:\n{result_full}")
result = result_full["choices"][0]["message"]["content"]
usage = result_full.get("usage", None)
else:
# Example error: msg={"error":"Context length exceeded. Tokens in context: 8000, Context length: 8000"}
if "context length" in str(response.text).lower():
# "exceeds context length" is what appears in the LM Studio error message
# raise an alternate exception that matches OpenAI's message, which is "maximum context length"
raise Exception(f"Request exceeds maximum context length (code={response.status_code}, msg={response.text}, URI={URI})")
else:
raise Exception(
f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
+ f" Make sure that the inference server is running and reachable at {URI}."
)
except:
# TODO handle gracefully
raise

# Pass usage statistics back to main thread
# These are used to compute memory warning messages
completion_tokens = usage.get("completion_tokens", None) if usage is not None else None
total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None
usage = {
"prompt_tokens": prompt_tokens, # can grab from usage dict, but it's usually wrong (set to 0)
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
}

return result, usage
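A hedged sketch of the HTTP request get_groq_completion ends up issuing when auth_type is "bearer_token". The exact header handling lives in post_json_auth_request (memgpt/local_llm/utils.py), so treat the header construction here as an assumption rather than a description of that helper.

import requests

def _post_chat_completion_with_bearer(uri: str, payload: dict, api_key: str) -> requests.Response:
    # Assumed equivalent of post_json_auth_request(..., auth_type="bearer_token", auth_key=api_key)
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    return requests.post(uri, json=payload, headers=headers)

# Example payload shape, matching the request built above (model id and key are placeholders):
# _post_chat_completion_with_bearer(
#     "https://api.groq.com/openai/v1/chat/completions",
#     {"model": "<groq-model-id>", "messages": [{"role": "user", "content": "Hello"}]},
#     "<GROQ_API_KEY>",
# )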
17 changes: 16 additions & 1 deletion memgpt/local_llm/json_parser.py
@@ -5,7 +5,19 @@
from memgpt.errors import LLMJSONParsingError


def extract_first_json(string):
def replace_escaped_underscores(string: str):
"""Handles the case of escaped underscores, e.g.:

{
"function":"send\_message",
"params": {
"inner\_thoughts": "User is asking for information about themselves. Retrieving data from core memory.",
"message": "I know that you are Chad. Is there something specific you would like to know or talk about regarding yourself?"
"""
return string.replace("\_", "_")


def extract_first_json(string: str):
"""Handles the case of two JSON objects back-to-back"""
from memgpt.utils import printd

@@ -163,6 +175,9 @@ def clean_json(raw_llm_output, messages=None, functions=None):
lambda output: json.loads(repair_even_worse_json(output), strict=JSON_LOADS_STRICT),
lambda output: extract_first_json(output + "}}"),
lambda output: clean_and_interpret_send_message_json(output),
# replace underscores
lambda output: json.loads(replace_escaped_underscores(output), strict=JSON_LOADS_STRICT),
lambda output: extract_first_json(replace_escaped_underscores(output) + "}}"),
]

for strategy in strategies:
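To illustrate why the new strategies help: "\_" is not a valid JSON escape, so json.loads rejects the raw output until the backslashes are stripped. A minimal reproduction:

import json

raw = '{"function": "send\\_message", "params": {"message": "hi"}}'
try:
    json.loads(raw)
except json.JSONDecodeError:
    # Same repair as replace_escaped_underscores: drop the spurious backslash
    fixed = raw.replace("\\_", "_")
    print(json.loads(fixed)["function"])  # send_message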
1 change: 1 addition & 0 deletions paper_experiments/nested_kv_task/nested_kv.py
@@ -18,6 +18,7 @@
sample 30 different ordering configurations including both
the initial key position and nesting key positions.
"""

import math
import json
import argparse
9 changes: 9 additions & 0 deletions tests/test_json_parsers.py
@@ -4,6 +4,14 @@
import memgpt.local_llm.json_parser as json_parser


EXAMPLE_ESCAPED_UNDERSCORES = """{
"function":"send\_message",
"params": {
"inner\_thoughts": "User is asking for information about themselves. Retrieving data from core memory.",
"message": "I know that you are Chad. Is there something specific you would like to know or talk about regarding yourself?"
"""


EXAMPLE_MISSING_CLOSING_BRACE = """{
"function": "send_message",
"params": {
@@ -72,6 +80,7 @@ def test_json_parsers():
"""Try various broken JSON and check that the parsers can fix it"""

test_strings = [
EXAMPLE_ESCAPED_UNDERSCORES,
EXAMPLE_MISSING_CLOSING_BRACE,
EXAMPLE_BAD_TOKEN_END,
EXAMPLE_DOUBLE_JSON,
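The new case can be exercised on its own with, for example, pytest tests/test_json_parsers.py -k test_json_parsers (command shown for illustration; the test iterates every example string through the JSON repair path).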