Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

♻️ refactor: refactor getLlmOptionsFromPayload from AgentRuntime #4790

Merged
merged 8 commits into from
Nov 26, 2024
240 changes: 21 additions & 219 deletions src/server/modules/AgentRuntime/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,54 +27,32 @@ export interface AgentChatOptions {
* @returns The options object.
*/
const getLlmOptionsFromPayload = (provider: string, payload: JWTPayload) => {
const llmConfig = getLLMConfig() as Record<string, any>;

switch (provider) {
default: // Use Openai options as default
case ModelProvider.OpenAI: {
const { OPENAI_API_KEY } = getLLMConfig();
default: {
let upperProvider = provider.toUpperCase();

if (!llmConfig[`${upperProvider}_API_KEY`]) {
upperProvider = ModelProvider.OpenAI.toUpperCase(); // Use OpenAI options as default
}

const openaiApiKey = payload?.apiKey || OPENAI_API_KEY;
const baseURL = payload?.endpoint || process.env.OPENAI_PROXY_URL;
const apiKey = apiKeyManager.pick(openaiApiKey);
const apiKey = apiKeyManager.pick(payload?.apiKey || llmConfig[`${upperProvider}_API_KEY`]);
const baseURL = payload?.endpoint || process.env[`${upperProvider}_PROXY_URL`];

return { apiKey, baseURL };
}

case ModelProvider.Azure: {
const { AZURE_API_KEY, AZURE_API_VERSION, AZURE_ENDPOINT } = getLLMConfig();
const apiKey = apiKeyManager.pick(payload?.apiKey || AZURE_API_KEY);
const { AZURE_API_KEY, AZURE_API_VERSION, AZURE_ENDPOINT } = llmConfig;
const apikey = apiKeyManager.pick(payload?.apiKey || AZURE_API_KEY);
const endpoint = payload?.endpoint || AZURE_ENDPOINT;
const apiVersion = payload?.azureApiVersion || AZURE_API_VERSION;
return {
apiVersion,
apikey: apiKey,
endpoint,
};
}
case ModelProvider.ZhiPu: {
const { ZHIPU_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || ZHIPU_API_KEY);

return { apiKey };
return { apiVersion, apikey, endpoint };
}
case ModelProvider.Google: {
const { GOOGLE_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || GOOGLE_API_KEY);
const baseURL = payload?.endpoint || process.env.GOOGLE_PROXY_URL;

return { apiKey, baseURL };
}
case ModelProvider.Moonshot: {
const { MOONSHOT_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || MOONSHOT_API_KEY);
const baseURL = payload?.endpoint || process.env.MOONSHOT_PROXY_URL;

return { apiKey, baseURL };
}
case ModelProvider.Bedrock: {
const { AWS_SECRET_ACCESS_KEY, AWS_ACCESS_KEY_ID, AWS_REGION, AWS_SESSION_TOKEN } =
getLLMConfig();
const { AWS_SECRET_ACCESS_KEY, AWS_ACCESS_KEY_ID, AWS_REGION, AWS_SESSION_TOKEN } = llmConfig;
let accessKeyId: string | undefined = AWS_ACCESS_KEY_ID;
let accessKeySecret: string | undefined = AWS_SECRET_ACCESS_KEY;
let region = AWS_REGION;
Expand All @@ -88,127 +66,9 @@ const getLlmOptionsFromPayload = (provider: string, payload: JWTPayload) => {
}
return { accessKeyId, accessKeySecret, region, sessionToken };
}
case ModelProvider.Ollama: {
const baseURL = payload?.endpoint || process.env.OLLAMA_PROXY_URL;
return { baseURL };
}
case ModelProvider.Perplexity: {
const { PERPLEXITY_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || PERPLEXITY_API_KEY);
const baseURL = payload?.endpoint || process.env.PERPLEXITY_PROXY_URL;

return { apiKey, baseURL };
}
case ModelProvider.Anthropic: {
const { ANTHROPIC_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || ANTHROPIC_API_KEY);
const baseURL = payload?.endpoint || process.env.ANTHROPIC_PROXY_URL;

return { apiKey, baseURL };
}
case ModelProvider.Minimax: {
const { MINIMAX_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || MINIMAX_API_KEY);

return { apiKey };
}
case ModelProvider.Mistral: {
const { MISTRAL_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || MISTRAL_API_KEY);

return { apiKey };
}
case ModelProvider.Groq: {
const { GROQ_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || GROQ_API_KEY);
const baseURL = payload?.endpoint || process.env.GROQ_PROXY_URL;

return { apiKey, baseURL };
}
case ModelProvider.Github: {
const { GITHUB_TOKEN } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || GITHUB_TOKEN);

return { apiKey };
}
case ModelProvider.OpenRouter: {
const { OPENROUTER_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || OPENROUTER_API_KEY);

return { apiKey };
}
case ModelProvider.DeepSeek: {
const { DEEPSEEK_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || DEEPSEEK_API_KEY);

return { apiKey };
}
case ModelProvider.TogetherAI: {
const { TOGETHERAI_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || TOGETHERAI_API_KEY);

return { apiKey };
}
case ModelProvider.FireworksAI: {
const { FIREWORKSAI_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || FIREWORKSAI_API_KEY);

return { apiKey };
}
case ModelProvider.ZeroOne: {
const { ZEROONE_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || ZEROONE_API_KEY);

return { apiKey };
}
case ModelProvider.Qwen: {
const { QWEN_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || QWEN_API_KEY);

return { apiKey };
}
case ModelProvider.Stepfun: {
const { STEPFUN_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || STEPFUN_API_KEY);

return { apiKey };
}
case ModelProvider.Novita: {
const { NOVITA_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || NOVITA_API_KEY);

return { apiKey };
}
case ModelProvider.Baichuan: {
const { BAICHUAN_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || BAICHUAN_API_KEY);

return { apiKey };
}
case ModelProvider.Taichu: {
const { TAICHU_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || TAICHU_API_KEY);

return { apiKey };
}
case ModelProvider.Cloudflare: {
const { CLOUDFLARE_API_KEY, CLOUDFLARE_BASE_URL_OR_ACCOUNT_ID } = getLLMConfig();
const { CLOUDFLARE_API_KEY, CLOUDFLARE_BASE_URL_OR_ACCOUNT_ID } = llmConfig;

const apiKey = apiKeyManager.pick(payload?.apiKey || CLOUDFLARE_API_KEY);
const baseURLOrAccountID =
Expand All @@ -218,68 +78,25 @@ const getLlmOptionsFromPayload = (provider: string, payload: JWTPayload) => {

return { apiKey, baseURLOrAccountID };
}
case ModelProvider.Ai360: {
const { AI360_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || AI360_API_KEY);

return { apiKey };
}
case ModelProvider.SiliconCloud: {
const { SILICONCLOUD_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || SILICONCLOUD_API_KEY);
const baseURL = payload?.endpoint || process.env.SILICONCLOUD_PROXY_URL;

return { apiKey, baseURL };
}
case ModelProvider.GiteeAI: {
const { GITEE_AI_API_KEY } = getLLMConfig();
const { GITEE_AI_API_KEY } = llmConfig;

const apiKey = apiKeyManager.pick(payload?.apiKey || GITEE_AI_API_KEY);

return { apiKey };
}

case ModelProvider.HuggingFace: {
const { HUGGINGFACE_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || HUGGINGFACE_API_KEY);
const baseURL = payload?.endpoint || process.env.HUGGINGFACE_PROXY_URL;

return { apiKey, baseURL };
}

case ModelProvider.Upstage: {
const { UPSTAGE_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || UPSTAGE_API_KEY);

return { apiKey };
}
case ModelProvider.Spark: {
const { SPARK_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || SPARK_API_KEY);

return { apiKey };
}
case ModelProvider.Ai21: {
const { AI21_API_KEY } = getLLMConfig();
case ModelProvider.Github: {
const { GITHUB_TOKEN } = llmConfig;

const apiKey = apiKeyManager.pick(payload?.apiKey || AI21_API_KEY);
const apiKey = apiKeyManager.pick(payload?.apiKey || GITHUB_TOKEN);

return { apiKey };
}
case ModelProvider.Hunyuan: {
const { HUNYUAN_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || HUNYUAN_API_KEY);

return { apiKey };
}
case ModelProvider.SenseNova: {
const { SENSENOVA_ACCESS_KEY_ID, SENSENOVA_ACCESS_KEY_SECRET } = getLLMConfig();
const { SENSENOVA_ACCESS_KEY_ID, SENSENOVA_ACCESS_KEY_SECRET } = llmConfig;

const sensenovaAccessKeyID = apiKeyManager.pick(
payload?.sensenovaAccessKeyID || SENSENOVA_ACCESS_KEY_ID,
Expand All @@ -290,21 +107,6 @@ const getLlmOptionsFromPayload = (provider: string, payload: JWTPayload) => {

const apiKey = sensenovaAccessKeyID + ':' + sensenovaAccessKeySecret;

return { apiKey };
}
case ModelProvider.XAI: {
const { XAI_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || XAI_API_KEY);
const baseURL = payload?.endpoint || process.env.XAI_PROXY_URL;

return { apiKey, baseURL };
}
case ModelProvider.InternLM: {
const { INTERNLM_API_KEY } = getLLMConfig();

const apiKey = apiKeyManager.pick(payload?.apiKey || INTERNLM_API_KEY);

return { apiKey };
}
}
Expand Down