Skip to content

Commit

Permalink
add support for OpenAI-API-Compatible llm (#1787)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

#1771  add support for OpenAI-API-Compatible

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <[email protected]>
  • Loading branch information
hangters and aopstudio authored Aug 6, 2024
1 parent 66e4113 commit b67484e
Show file tree
Hide file tree
Showing 12 changed files with 74 additions and 11 deletions.
9 changes: 6 additions & 3 deletions api/apps/llm_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ def add_llm():
elif factory == "LocalAI":
llm_name = req["llm_name"]+"___LocalAI"
api_key = "xxxxxxxxxxxxxxx"
elif factory == "OpenAI-API-Compatible":
llm_name = req["llm_name"]+"___OpenAI-API"
api_key = req["api_key"]
else:
llm_name = req["llm_name"]
api_key = "xxxxxxxxxxxxxxx"
Expand All @@ -145,7 +148,7 @@ def add_llm():
msg = ""
if llm["model_type"] == LLMType.EMBEDDING.value:
mdl = EmbeddingModel[factory](
key=llm['api_key'] if factory in ["VolcEngine", "Bedrock"] else None,
key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible"] else None,
model_name=llm["llm_name"],
base_url=llm["api_base"])
try:
Expand All @@ -156,7 +159,7 @@ def add_llm():
msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
elif llm["model_type"] == LLMType.CHAT.value:
mdl = ChatModel[factory](
key=llm['api_key'] if factory in ["VolcEngine", "Bedrock"] else None,
key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible"] else None,
model_name=llm["llm_name"],
base_url=llm["api_base"]
)
Expand All @@ -181,7 +184,7 @@ def add_llm():
e)
elif llm["model_type"] == LLMType.IMAGE2TEXT.value:
mdl = CvModel[factory](
key=None, model_name=llm["llm_name"], base_url=llm["api_base"]
key=llm["api_key"] if factory in ["OpenAI-API-Compatible"] else None, model_name=llm["llm_name"], base_url=llm["api_base"]
)
try:
img_url = (
Expand Down
7 changes: 7 additions & 0 deletions conf/llm_factories.json
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,13 @@
"status": "1",
"llm": []
},
{
"name": "OpenAI-API-Compatible",
"logo": "",
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
"status": "1",
"llm": []
},
{
"name": "Moonshot",
"logo": "",
Expand Down
12 changes: 8 additions & 4 deletions rag/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@
"Bedrock": BedrockEmbed,
"Gemini": GeminiEmbed,
"NVIDIA": NvidiaEmbed,
"LM-Studio": LmStudioEmbed
"LM-Studio": LmStudioEmbed,
"OpenAI-API-Compatible": OpenAI_APIEmbed
}


Expand All @@ -53,7 +54,8 @@
"LocalAI": LocalAICV,
"NVIDIA": NvidiaCV,
"LM-Studio": LmStudioCV,
"StepFun":StepFunCV
"StepFun":StepFunCV,
"OpenAI-API-Compatible": OpenAI_APICV
}


Expand All @@ -78,7 +80,8 @@
"OpenRouter": OpenRouterChat,
"StepFun": StepFunChat,
"NVIDIA": NvidiaChat,
"LM-Studio": LmStudioChat
"LM-Studio": LmStudioChat,
"OpenAI-API-Compatible": OpenAI_APIChat
}


Expand All @@ -88,7 +91,8 @@
"Youdao": YoudaoRerank,
"Xinference": XInferenceRerank,
"NVIDIA": NvidiaRerank,
"LM-Studio": LmStudioRerank
"LM-Studio": LmStudioRerank,
"OpenAI-API-Compatible": OpenAI_APIRerank
}


Expand Down
14 changes: 12 additions & 2 deletions rag/llm/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,6 +887,16 @@ def __init__(self, key, model_name, base_url):
if not base_url:
raise ValueError("Local llm url cannot be None")
if base_url.split("/")[-1] != "v1":
self.base_url = os.path.join(base_url, "v1")
self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
base_url = os.path.join(base_url, "v1")
self.client = OpenAI(api_key="lm-studio", base_url=base_url)
self.model_name = model_name


class OpenAI_APIChat(Base):
    """Chat client for any OpenAI-API-compatible endpoint.

    Normalizes the base URL so it ends in a "v1" path segment and strips
    the internal "___<factory>" marker appended to the model name when the
    model was registered (see add_llm: ``req["llm_name"] + "___OpenAI-API"``).
    """

    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("url cannot be None")
        # Append "v1" only when the URL does not already end with it.
        normalized_url = (
            base_url
            if base_url.split("/")[-1] == "v1"
            else os.path.join(base_url, "v1")
        )
        # Keep only the user-visible model name, dropping the suffix.
        plain_name, *_ = model_name.split("___")
        super().__init__(key, plain_name, normalized_url)
11 changes: 11 additions & 0 deletions rag/llm/cv_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,3 +638,14 @@ def __init__(self, key, model_name, base_url, lang="Chinese"):
self.client = OpenAI(api_key="lm-studio", base_url=base_url)
self.model_name = model_name
self.lang = lang


class OpenAI_APICV(GptV4):
    """Image-to-text (vision) client for any OpenAI-API-compatible endpoint.

    Builds an OpenAI client against the given base URL (ensuring a trailing
    "v1" path segment) and strips the internal "___<factory>" marker from
    the registered model name.
    """

    def __init__(self, key, model_name, base_url, lang="Chinese"):
        if not base_url:
            raise ValueError("url cannot be None")
        # Ensure the endpoint path ends with the "v1" segment.
        needs_suffix = base_url.split("/")[-1] != "v1"
        endpoint = os.path.join(base_url, "v1") if needs_suffix else base_url
        self.client = OpenAI(api_key=key, base_url=endpoint)
        # Drop the "___..." registration suffix from the model name.
        self.model_name = model_name.split("___")[0]
        self.lang = lang
10 changes: 10 additions & 0 deletions rag/llm/embedding_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,3 +513,13 @@ def __init__(self, key, model_name, base_url):
self.base_url = os.path.join(base_url, "v1")
self.client = OpenAI(api_key="lm-studio", base_url=self.base_url)
self.model_name = model_name


class OpenAI_APIEmbed(OpenAIEmbed):
    """Embedding client for any OpenAI-API-compatible endpoint.

    Ensures the base URL ends in a "v1" path segment and strips the internal
    "___<factory>" marker appended to the model name at registration time.
    """

    def __init__(self, key, model_name, base_url):
        if not base_url:
            raise ValueError("url cannot be None")
        if base_url.split("/")[-1] != "v1":
            # BUG FIX: the original assigned the normalized URL to
            # self.base_url but passed the *unmodified* base_url to the
            # client, silently dropping the "v1" segment. Rebind the local
            # instead, matching OpenAI_APIChat / OpenAI_APICV.
            base_url = os.path.join(base_url, "v1")
        self.client = OpenAI(api_key=key, base_url=base_url)
        # Keep only the user-visible model name, dropping the suffix.
        self.model_name = model_name.split("___")[0]
8 changes: 8 additions & 0 deletions rag/llm/rerank_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,11 @@ def __init__(self, key, model_name, base_url):

def similarity(self, query: str, texts: list):
raise NotImplementedError("The LmStudioRerank has not been implement")


class OpenAI_APIRerank(Base):
    """Placeholder rerank client for OpenAI-API-compatible endpoints.

    Reranking is not part of the OpenAI-compatible surface wired up by this
    PR; the class exists so the factory table resolves, and every call to
    ``similarity`` raises NotImplementedError.
    """

    def __init__(self, key, model_name, base_url):
        # Intentionally a no-op: nothing to configure until reranking
        # is actually implemented for this factory.
        pass

    def similarity(self, query: str, texts: list):
        # Fixed message grammar ("implement" -> "implemented"); callers
        # catch NotImplementedError by type, so this is backward-compatible.
        raise NotImplementedError("The api has not been implemented")
1 change: 1 addition & 0 deletions web/src/assets/svg/llm/openai-api.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions web/src/interfaces/request/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ export interface IAddLlmRequestBody {
llm_name: string;
model_type: string;
api_base?: string; // chat|embedding|speech2text|image2text
api_key: string;
}

export interface IDeleteLlmRequestBody {
Expand Down
2 changes: 1 addition & 1 deletion web/src/pages/user-setting/constants.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ export const UserSettingIconMap = {

export * from '@/constants/setting';

export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI','LM-Studio'];
export const LocalLlmFactories = ['Ollama', 'Xinference','LocalAI','LM-Studio',"OpenAI-API-Compatible"];
3 changes: 2 additions & 1 deletion web/src/pages/user-setting/setting-model/constant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ export const IconMap = {
LocalAI: 'local-ai',
StepFun: 'stepfun',
NVIDIA:'nvidia',
'LM-Studio':'lm-studio'
'LM-Studio':'lm-studio',
'OpenAI-API-Compatible':'openai-api'
};

export const BedrockRegionList = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,13 @@ const OllamaModal = ({
>
<Input placeholder={t('baseUrlNameMessage')} />
</Form.Item>
<Form.Item<FieldType>
label={t('apiKey')}
name="api_key"
rules={[{ required: false, message: t('apiKeyMessage') }]}
>
<Input placeholder={t('apiKeyMessage')} />
</Form.Item>
<Form.Item noStyle dependencies={['model_type']}>
{({ getFieldValue }) =>
getFieldValue('model_type') === 'chat' && (
Expand Down

0 comments on commit b67484e

Please sign in to comment.