Skip to content

Commit

Permalink
fix: CLI conveniences (add-on to #674) (#675)
Browse files Browse the repository at this point in the history
* for openai, check for key and if missing allow user to pass it, for azure, throw error if the key isn't present

* correct prior checking of azure to be more strict, added similar checks at the embedding endpoint config stage

* forgot to override value in config before saving

* clean up the valuerrors from missing keys so that no stacktrace gets printed, make success text green to match others
  • Loading branch information
cpacker authored Dec 22, 2023
1 parent cfbec58 commit 09c7fa7
Showing 1 changed file with 59 additions and 25 deletions.
84 changes: 59 additions & 25 deletions memgpt/cli/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,36 @@ def configure_llm_endpoint(config: MemGPTConfig):

# set: model_endpoint_type, model_endpoint
if provider == "openai":
# check for key
if config.openai_key is None:
# allow key to get pulled from env vars
openai_api_key = os.getenv("OPENAI_API_KEY", None)
if openai_api_key is None:
# if we still can't find it, ask for it as input
while openai_api_key is None or len(openai_api_key) == 0:
# Ask for API key as input
openai_api_key = questionary.text(
"Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):"
).ask()
config.openai_key = openai_api_key
config.save()

model_endpoint_type = "openai"
model_endpoint = "https://api.openai.com/v1"
model_endpoint = questionary.text("Override default endpoint:", default=model_endpoint).ask()
provider = "openai"

elif provider == "azure":
# check for necessary vars
azure_creds = get_azure_credentials()
if not all([azure_creds["azure_key"], azure_creds["azure_endpoint"], azure_creds["azure_version"]]):
raise ValueError(
"Missing environment variables for Azure (see https://memgpt.readme.io/docs/endpoints#azure-openai). Please set then run `memgpt configure` again."
)

model_endpoint_type = "azure"
model_endpoint = get_azure_credentials()["azure_endpoint"]
model_endpoint = azure_creds["azure_endpoint"]

else: # local models
backend_options = ["webui", "webui-legacy", "llamacpp", "koboldcpp", "ollama", "lmstudio", "lmstudio-legacy", "vllm", "openai"]
default_model_endpoint_type = None
Expand Down Expand Up @@ -178,14 +201,38 @@ def configure_embedding_endpoint(config: MemGPTConfig):
embedding_provider = questionary.select(
"Select embedding provider:", choices=["openai", "azure", "hugging-face", "local"], default=default_embedding_endpoint_type
).ask()

if embedding_provider == "openai":
# check for key
if config.openai_key is None:
# allow key to get pulled from env vars
openai_api_key = os.getenv("OPENAI_API_KEY", None)
if openai_api_key is None:
# if we still can't find it, ask for it as input
while openai_api_key is None or len(openai_api_key) == 0:
# Ask for API key as input
openai_api_key = questionary.text(
"Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):"
).ask()
config.openai_key = openai_api_key
config.save()

embedding_endpoint_type = "openai"
embedding_endpoint = "https://api.openai.com/v1"
embedding_dim = 1536

elif embedding_provider == "azure":
# check for necessary vars
azure_creds = get_azure_credentials()
if not all([azure_creds["azure_key"], azure_creds["azure_embedding_endpoint"], azure_creds["azure_embedding_version"]]):
raise ValueError(
"Missing environment variables for Azure (see https://memgpt.readme.io/docs/endpoints#azure-openai). Please set then run `memgpt configure` again."
)

embedding_endpoint_type = "azure"
embedding_endpoint = get_azure_credentials()["azure_embedding_endpoint"]
embedding_endpoint = azure_creds["azure_embedding_endpoint"]
embedding_dim = 1536

elif embedding_provider == "hugging-face":
# configure hugging face embedding endpoint (https://github.com/huggingface/text-embeddings-inference)
# supports custom model/endpoints
Expand Down Expand Up @@ -292,32 +339,19 @@ def configure():
openai_key = get_openai_credentials()
azure_creds = get_azure_credentials()

if not openai_key and all(value is None or value == "" for value in azure_creds.values()):
raise ValueError(
"Missing environment variables (see https://memgpt.readme.io/docs/endpoints). Please set them and run `memgpt configure` again."
)
else: # Detecting non-empty configurations for Azure or OpenAI
detected_services = []

if all([azure_creds["azure_key"], azure_creds["azure_endpoint"], azure_creds["azure_version"]]):
detected_services.append("Azure")

if openai_key:
detected_services.append("OpenAI")

if detected_services:
detected_services_message = ", ".join(detected_services)
typer.secho(f"Detected {detected_services_message} configuration.", fg=typer.colors.YELLOW)

MemGPTConfig.create_config_dir()

# Will pre-populate with defaults, or what the user previously set
config = MemGPTConfig.load()
model_endpoint_type, model_endpoint = configure_llm_endpoint(config)
model, model_wrapper, context_window = configure_model(config, model_endpoint_type)
embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = configure_embedding_endpoint(config)
default_preset, default_persona, default_human, default_agent = configure_cli(config)
archival_storage_type, archival_storage_uri, archival_storage_path = configure_archival_storage(config)
try:
model_endpoint_type, model_endpoint = configure_llm_endpoint(config)
model, model_wrapper, context_window = configure_model(config, model_endpoint_type)
embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = configure_embedding_endpoint(config)
default_preset, default_persona, default_human, default_agent = configure_cli(config)
archival_storage_type, archival_storage_uri, archival_storage_path = configure_archival_storage(config)
except ValueError as e:
typer.secho(str(e), fg=typer.colors.RED)
return

config = MemGPTConfig(
# model configs
Expand Down Expand Up @@ -348,7 +382,7 @@ def configure():
archival_storage_uri=archival_storage_uri,
archival_storage_path=archival_storage_path,
)
print(f"Saving config to {config.config_path}")
typer.secho(f"📖 Saving config to {config.config_path}", fg=typer.colors.GREEN)
config.save()


Expand Down

0 comments on commit 09c7fa7

Please sign in to comment.