From b1cf71fb96b09aa71d1c1fdb4e13aa6fa68b14e5 Mon Sep 17 00:00:00 2001 From: cpacker Date: Mon, 1 Apr 2024 13:06:19 -0700 Subject: [PATCH 1/3] feat: added groq support via local option w/ auth --- memgpt/cli/cli_config.py | 77 ++++++++++++++--------- memgpt/local_llm/chat_completion_proxy.py | 3 + memgpt/local_llm/constants.py | 4 ++ memgpt/local_llm/json_parser.py | 17 ++++- tests/test_json_parsers.py | 9 +++ 5 files changed, 80 insertions(+), 30 deletions(-) diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py index bbdaa43173..5dd5e52e2e 100644 --- a/memgpt/cli/cli_config.py +++ b/memgpt/cli/cli_config.py @@ -132,7 +132,9 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials) model_endpoint = azure_creds["azure_endpoint"] else: # local models - backend_options = ["webui", "webui-legacy", "llamacpp", "koboldcpp", "ollama", "lmstudio", "lmstudio-legacy", "vllm", "openai"] + # backend_options_old = ["webui", "webui-legacy", "llamacpp", "koboldcpp", "ollama", "lmstudio", "lmstudio-legacy", "vllm", "openai"] + backend_options = builtins.list(DEFAULT_ENDPOINTS.keys()) + # assert backend_options_old == backend_options, (backend_options_old, backend_options) default_model_endpoint_type = None if config.default_llm_config.model_endpoint_type in backend_options: # set from previous config @@ -223,8 +225,12 @@ def get_model_options( else: # Attempt to do OpenAI endpoint style model fetching - # TODO support local auth - fetched_model_options_response = openai_get_model_list(url=model_endpoint, api_key=None, fix_url=True) + # TODO support local auth with api-key header + if credentials.openllm_auth_type == "bearer_token": + api_key = credentials.openllm_key + else: + api_key = None + fetched_model_options_response = openai_get_model_list(url=model_endpoint, api_key=api_key, fix_url=True) model_options = [obj["id"] for obj in fetched_model_options_response["data"]] # NOTE no filtering of local model options @@ -289,6 +295,44 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_ raise KeyboardInterrupt else: # local models + + # ask about local auth + if model_endpoint_type in ["groq"]: # TODO all llm engines under 'local' that will require api keys + use_local_auth = True + local_auth_type = "bearer_token" + local_auth_key = questionary.password( + "Enter your Groq API key:", + ).ask() + if local_auth_key is None: + raise KeyboardInterrupt + credentials.openllm_auth_type = local_auth_type + credentials.openllm_key = local_auth_key + credentials.save() + else: + use_local_auth = questionary.confirm( + "Is your LLM endpoint authenticated? 
(default no)", + default=False, + ).ask() + if use_local_auth is None: + raise KeyboardInterrupt + if use_local_auth: + local_auth_type = questionary.select( + "What HTTP authentication method does your endpoint require?", + choices=SUPPORTED_AUTH_TYPES, + default=SUPPORTED_AUTH_TYPES[0], + ).ask() + if local_auth_type is None: + raise KeyboardInterrupt + local_auth_key = questionary.password( + "Enter your authentication key:", + ).ask() + if local_auth_key is None: + raise KeyboardInterrupt + # credentials = MemGPTCredentials.load() + credentials.openllm_auth_type = local_auth_type + credentials.openllm_key = local_auth_key + credentials.save() + # ollama also needs model type if model_endpoint_type == "ollama": default_model = ( @@ -311,7 +355,7 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_ ) # vllm needs huggingface model tag - if model_endpoint_type == "vllm": + if model_endpoint_type in ["vllm", "groq"]: try: # Don't filter model list for vLLM since model list is likely much smaller than OpenAI/Azure endpoint # + probably has custom model names @@ -366,31 +410,6 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_ if model_wrapper is None: raise KeyboardInterrupt - # ask about local auth - use_local_auth = questionary.confirm( - "Is your LLM endpoint authenticated? (default no)", - default=False, - ).ask() - if use_local_auth is None: - raise KeyboardInterrupt - if use_local_auth: - local_auth_type = questionary.select( - "What HTTP authentication method does your endpoint require?", - choices=SUPPORTED_AUTH_TYPES, - default=SUPPORTED_AUTH_TYPES[0], - ).ask() - if local_auth_type is None: - raise KeyboardInterrupt - local_auth_key = questionary.password( - "Enter your authentication key:", - ).ask() - if local_auth_key is None: - raise KeyboardInterrupt - # credentials = MemGPTCredentials.load() - credentials.openllm_auth_type = local_auth_type - credentials.openllm_key = local_auth_key - credentials.save() - # set: context_window if str(model) not in LLM_MAX_TOKENS: # Ask the user to specify the context length diff --git a/memgpt/local_llm/chat_completion_proxy.py b/memgpt/local_llm/chat_completion_proxy.py index 8d87c65764..b6fa2a758c 100644 --- a/memgpt/local_llm/chat_completion_proxy.py +++ b/memgpt/local_llm/chat_completion_proxy.py @@ -13,6 +13,7 @@ from memgpt.local_llm.koboldcpp.api import get_koboldcpp_completion from memgpt.local_llm.ollama.api import get_ollama_completion from memgpt.local_llm.vllm.api import get_vllm_completion +from memgpt.local_llm.groq.api import get_groq_completion from memgpt.local_llm.llm_chat_completion_wrappers import simple_summary_wrapper from memgpt.local_llm.constants import DEFAULT_WRAPPER from memgpt.local_llm.utils import get_available_wrappers, count_tokens @@ -155,6 +156,8 @@ def get_chat_completion( result, usage = get_ollama_completion(endpoint, auth_type, auth_key, model, prompt, context_window) elif endpoint_type == "vllm": result, usage = get_vllm_completion(endpoint, auth_type, auth_key, model, prompt, context_window, user) + elif endpoint_type == "groq": + result, usage = get_groq_completion(endpoint, auth_type, auth_key, model, prompt, context_window) else: raise LocalLLMError( f"Invalid endpoint type {endpoint_type}, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp)" diff --git a/memgpt/local_llm/constants.py b/memgpt/local_llm/constants.py index 139e3a2576..a3734f165c 100644 --- a/memgpt/local_llm/constants.py +++ 
b/memgpt/local_llm/constants.py @@ -2,6 +2,7 @@ from memgpt.local_llm.llm_chat_completion_wrappers.chatml import ChatMLInnerMonologueWrapper, ChatMLOuterInnerMonologueWrapper DEFAULT_ENDPOINTS = { + # Local "koboldcpp": "http://localhost:5001", "llamacpp": "http://localhost:8080", "lmstudio": "http://localhost:1234", @@ -10,6 +11,9 @@ "webui-legacy": "http://localhost:5000", "webui": "http://localhost:5000", "vllm": "http://localhost:8000", + # APIs + "openai": "https://api.openai.com", + "groq": "https://api.groq.com/openai", } DEFAULT_OLLAMA_MODEL = "dolphin2.2-mistral:7b-q6_K" diff --git a/memgpt/local_llm/json_parser.py b/memgpt/local_llm/json_parser.py index 9c7b0ca761..5ca712fbe6 100644 --- a/memgpt/local_llm/json_parser.py +++ b/memgpt/local_llm/json_parser.py @@ -5,7 +5,19 @@ from memgpt.errors import LLMJSONParsingError -def extract_first_json(string): +def replace_escaped_underscores(string: str): + """Handles the case of escaped underscores, e.g.: + + { + "function":"send\_message", + "params": { + "inner\_thoughts": "User is asking for information about themselves. Retrieving data from core memory.", + "message": "I know that you are Chad. Is there something specific you would like to know or talk about regarding yourself?" + """ + return string.replace("\_", "_") + + +def extract_first_json(string: str): """Handles the case of two JSON objects back-to-back""" from memgpt.utils import printd @@ -163,6 +175,9 @@ def clean_json(raw_llm_output, messages=None, functions=None): lambda output: json.loads(repair_even_worse_json(output), strict=JSON_LOADS_STRICT), lambda output: extract_first_json(output + "}}"), lambda output: clean_and_interpret_send_message_json(output), + # replace underscores + lambda output: json.loads(replace_escaped_underscores(output), strict=JSON_LOADS_STRICT), + lambda output: extract_first_json(replace_escaped_underscores(output) + "}}"), ] for strategy in strategies: diff --git a/tests/test_json_parsers.py b/tests/test_json_parsers.py index 6a260e1fd3..3ada8e3429 100644 --- a/tests/test_json_parsers.py +++ b/tests/test_json_parsers.py @@ -4,6 +4,14 @@ import memgpt.local_llm.json_parser as json_parser +EXAMPLE_ESCAPED_UNDERSCORES = """{ + "function":"send\_message", + "params": { + "inner\_thoughts": "User is asking for information about themselves. Retrieving data from core memory.", + "message": "I know that you are Chad. Is there something specific you would like to know or talk about regarding yourself?" 
+""" + + EXAMPLE_MISSING_CLOSING_BRACE = """{ "function": "send_message", "params": { @@ -72,6 +80,7 @@ def test_json_parsers(): """Try various broken JSON and check that the parsers can fix it""" test_strings = [ + EXAMPLE_ESCAPED_UNDERSCORES, EXAMPLE_MISSING_CLOSING_BRACE, EXAMPLE_BAD_TOKEN_END, EXAMPLE_DOUBLE_JSON, From df9ff7bf11f69cd7caee6a4f9b9ba02626fc1fc9 Mon Sep 17 00:00:00 2001 From: cpacker Date: Mon, 1 Apr 2024 13:07:28 -0700 Subject: [PATCH 2/3] fix: missing file --- memgpt/local_llm/groq/api.py | 79 ++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 memgpt/local_llm/groq/api.py diff --git a/memgpt/local_llm/groq/api.py b/memgpt/local_llm/groq/api.py new file mode 100644 index 0000000000..c680af3014 --- /dev/null +++ b/memgpt/local_llm/groq/api.py @@ -0,0 +1,79 @@ +from urllib.parse import urljoin +import requests +from typing import Tuple + +from memgpt.local_llm.settings.settings import get_completions_settings +from memgpt.local_llm.utils import post_json_auth_request +from memgpt.utils import count_tokens + + +API_CHAT_SUFFIX = "/v1/chat/completions" +# LMSTUDIO_API_COMPLETIONS_SUFFIX = "/v1/completions" + + +def get_groq_completion(endpoint: str, auth_type: str, auth_key: str, model: str, prompt: str, context_window: int) -> Tuple[str, dict]: + """TODO no support for function calling OR raw completions, so we need to route the request into /chat/completions instead""" + from memgpt.utils import printd + + prompt_tokens = count_tokens(prompt) + if prompt_tokens > context_window: + raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)") + + settings = get_completions_settings() + settings.update( + { + # see https://console.groq.com/docs/text-chat, supports: + # "temperature": , + # "max_tokens": , + # "top_p", + # "stream", + # "stop", + } + ) + + URI = urljoin(endpoint.strip("/") + "/", API_CHAT_SUFFIX.strip("/")) + + # Settings for the generation, includes the prompt + stop tokens, max length, etc + request = settings + request["model"] = model + request["max_tokens"] = context_window + # NOTE: Hack for chat/completion-only endpoints: put the entire completion string inside the first message + message_structure = [{"role": "user", "content": prompt}] + request["messages"] = message_structure + + if not endpoint.startswith(("http://", "https://")): + raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://") + + try: + response = post_json_auth_request(uri=URI, json_payload=request, auth_type=auth_type, auth_key=auth_key) + if response.status_code == 200: + result_full = response.json() + printd(f"JSON API response:\n{result_full}") + result = result_full["choices"][0]["message"]["content"] + usage = result_full.get("usage", None) + else: + # Example error: msg={"error":"Context length exceeded. Tokens in context: 8000, Context length: 8000"} + if "context length" in str(response.text).lower(): + # "exceeds context length" is what appears in the LM Studio error message + # raise an alternate exception that matches OpenAI's message, which is "maximum context length" + raise Exception(f"Request exceeds maximum context length (code={response.status_code}, msg={response.text}, URI={URI})") + else: + raise Exception( + f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}." + + f" Make sure that the inference server is running and reachable at {URI}." 
+ ) + except: + # TODO handle gracefully + raise + + # Pass usage statistics back to main thread + # These are used to compute memory warning messages + completion_tokens = usage.get("completion_tokens", None) if usage is not None else None + total_tokens = prompt_tokens + completion_tokens if completion_tokens is not None else None + usage = { + "prompt_tokens": prompt_tokens, # can grab from usage dict, but it's usually wrong (set to 0) + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + } + + return result, usage From 90566f3282ca555b4f74eecf5beea95b32a53f1d Mon Sep 17 00:00:00 2001 From: cpacker Date: Mon, 1 Apr 2024 13:11:39 -0700 Subject: [PATCH 3/3] black --- paper_experiments/nested_kv_task/nested_kv.py | 1 + 1 file changed, 1 insertion(+) diff --git a/paper_experiments/nested_kv_task/nested_kv.py b/paper_experiments/nested_kv_task/nested_kv.py index 4cf6ae9c72..5aa8c90b86 100644 --- a/paper_experiments/nested_kv_task/nested_kv.py +++ b/paper_experiments/nested_kv_task/nested_kv.py @@ -18,6 +18,7 @@ sample 30 different ordering configurations including both the initial key position and nesting key positions. """ + import math import json import argparse
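
Note (illustrative, not part of the patches above): the escaped-underscore repair added to memgpt/local_llm/json_parser.py in PATCH 1/3 can be exercised directly. The sketch below mirrors the EXAMPLE_ESCAPED_UNDERSCORES case added to tests/test_json_parsers.py; the broken-JSON sample content is made up, and the only assumption is that memgpt is importable.

from memgpt.local_llm.json_parser import clean_json

# LLM output with escaped underscores (invalid JSON escapes) and missing closing braces,
# the same failure mode covered by the new test case.
raw = """{
  "function":"send\\_message",
  "params": {
    "inner\\_thoughts": "Greeting the user.",
    "message": "Hello!"
"""

# clean_json tries plain json.loads first, then falls back through its repair strategies,
# which now include replace_escaped_underscores() before re-parsing.
parsed = clean_json(raw)
print(parsed)  # a dict recovered from the broken output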