From ef8063310df5a85c0efb7b69bda91c7347fc9ed8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E8=85=BE?= <101850389+hangters@users.noreply.github.com> Date: Thu, 11 Jul 2024 15:41:00 +0800 Subject: [PATCH] add support for Gemini (#1465) ### What problem does this PR solve? #1036 ### Type of change - [x] New Feature (non-breaking change which adds functionality) Co-authored-by: Zhedong Cen --- api/db/init_data.py | 37 +++++- rag/llm/chat_model.py | 61 ++++++++++ rag/llm/cv_model.py | 23 ++++ rag/llm/embedding_model.py | 26 +++- requirements.txt | 1 + requirements_arm.txt | 1 + requirements_dev.txt | 1 + web/src/assets/svg/llm/gemini.svg | 114 ++++++++++++++++++ .../user-setting/setting-model/index.tsx | 1 + 9 files changed, 263 insertions(+), 2 deletions(-) create mode 100644 web/src/assets/svg/llm/gemini.svg diff --git a/api/db/init_data.py b/api/db/init_data.py index 5405fc3e732..c649e0e7bc6 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -175,6 +175,11 @@ def init_superuser(): "logo": "", "tags": "LLM,TEXT EMBEDDING", "status": "1", +},{ + "name": "Gemini", + "logo": "", + "tags": "LLM,TEXT EMBEDDING,IMAGE2TEXT", + "status": "1", } # { # "name": "文心一言", @@ -898,7 +903,37 @@ def init_llm_factory(): "tags": "TEXT EMBEDDING", "max_tokens": 2048, "model_type": LLMType.EMBEDDING.value - }, + }, { + "fid": factory_infos[17]["name"], + "llm_name": "gemini-1.5-pro-latest", + "tags": "LLM,CHAT,1024K", + "max_tokens": 1024*1024, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[17]["name"], + "llm_name": "gemini-1.5-flash-latest", + "tags": "LLM,CHAT,1024K", + "max_tokens": 1024*1024, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[17]["name"], + "llm_name": "gemini-1.0-pro", + "tags": "LLM,CHAT,30K", + "max_tokens": 30*1024, + "model_type": LLMType.CHAT.value + }, { + "fid": factory_infos[17]["name"], + "llm_name": "gemini-1.0-pro-vision-latest", + "tags": "LLM,IMAGE2TEXT,12K", + "max_tokens": 12*1024, + "model_type": LLMType.IMAGE2TEXT.value + }, { + "fid": factory_infos[17]["name"], + "llm_name": "text-embedding-004", + "tags": "TEXT EMBEDDING", + "max_tokens": 2048, + "model_type": LLMType.EMBEDDING.value + } ] for info in factory_infos: try: diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index f7f137b6fc9..f0fcb39bcf7 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -621,3 +621,64 @@ def chat_streamly(self, system, history, gen_conf): yield ans + f"ERROR: Can't invoke '{self.model_name}'. 
Reason: {e}" yield num_tokens_from_string(ans) + +class GeminiChat(Base): + + def __init__(self, key, model_name,base_url=None): + from google.generativeai import client,GenerativeModel + + client.configure(api_key=key) + _client = client.get_default_generative_client() + self.model_name = 'models/' + model_name + self.model = GenerativeModel(model_name=self.model_name) + self.model._client = _client + + def chat(self,system,history,gen_conf): + if system: + history.insert(0, {"role": "user", "parts": system}) + if 'max_tokens' in gen_conf: + gen_conf['max_output_tokens'] = gen_conf['max_tokens'] + for k in list(gen_conf.keys()): + if k not in ["temperature", "top_p", "max_output_tokens"]: + del gen_conf[k] + for item in history: + if 'role' in item and item['role'] == 'assistant': + item['role'] = 'model' + if 'content' in item : + item['parts'] = item.pop('content') + + try: + response = self.model.generate_content( + history, + generation_config=gen_conf) + ans = response.text + return ans, response.usage_metadata.total_token_count + except Exception as e: + return "**ERROR**: " + str(e), 0 + + def chat_streamly(self, system, history, gen_conf): + if system: + history.insert(0, {"role": "user", "parts": system}) + if 'max_tokens' in gen_conf: + gen_conf['max_output_tokens'] = gen_conf['max_tokens'] + for k in list(gen_conf.keys()): + if k not in ["temperature", "top_p", "max_output_tokens"]: + del gen_conf[k] + for item in history: + if 'role' in item and item['role'] == 'assistant': + item['role'] = 'model' + if 'content' in item : + item['parts'] = item.pop('content') + ans = "" + try: + response = self.model.generate_content( + history, + generation_config=gen_conf,stream=True) + for resp in response: + ans += resp.text + yield ans + + except Exception as e: + yield ans + "\n**ERROR**: " + str(e) + + yield response._chunks[-1].usage_metadata.total_token_count \ No newline at end of file diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py index 9c25ffd87f2..19843a352bd 100644 --- a/rag/llm/cv_model.py +++ b/rag/llm/cv_model.py @@ -203,6 +203,29 @@ def describe(self, image, max_tokens=300): ) return res.choices[0].message.content.strip(), res.usage.total_tokens +class GeminiCV(Base): + def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs): + from google.generativeai import client,GenerativeModel + client.configure(api_key=key) + _client = client.get_default_generative_client() + self.model_name = model_name + self.model = GenerativeModel(model_name=self.model_name) + self.model._client = _client + self.lang = lang + + def describe(self, image, max_tokens=2048): + from PIL.Image import open + gen_config = {'max_output_tokens':max_tokens} + prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \ + "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out." 
diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index 48081e0124d..06cc56975d6 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -31,7 +31,7 @@
 import asyncio
 from api.utils.file_utils import get_home_cache_dir
 from rag.utils import num_tokens_from_string, truncate
-
+import google.generativeai as genai
 
 class Base(ABC):
     def __init__(self, key, model_name):
@@ -419,3 +419,27 @@ def encode_queries(self, text):
         return np.array(embeddings), token_count
 
 
+class GeminiEmbed(Base):
+    def __init__(self, key, model_name='models/text-embedding-004',
+                 **kwargs):
+        genai.configure(api_key=key)
+        self.model_name = model_name if model_name.startswith('models/') else 'models/' + model_name
+
+    def encode(self, texts: list, batch_size=32):
+        texts = [truncate(t, 2048) for t in texts]
+        token_count = sum(num_tokens_from_string(text) for text in texts)
+        result = genai.embed_content(
+            model=self.model_name,
+            content=texts,
+            task_type="retrieval_document",
+            title="Embedding of list of strings")
+        return np.array(result['embedding']), token_count
+
+    def encode_queries(self, text):
+        # queries use task_type "retrieval_query"; 'title' applies only to documents
+        result = genai.embed_content(
+            model=self.model_name,
+            content=truncate(text, 2048),
+            task_type="retrieval_query")
+        token_count = num_tokens_from_string(text)
+        return np.array(result['embedding']), token_count
diff --git a/requirements.txt b/requirements.txt
index 048761dabf2..497c946832d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -147,3 +147,4 @@ markdown==3.6
 mistralai==0.4.2
 boto3==1.34.140
 duckduckgo_search==6.1.9
+google-generativeai==0.7.2
diff --git a/requirements_arm.txt b/requirements_arm.txt
index 448b093619b..a5cbf5d7009 100644
--- a/requirements_arm.txt
+++ b/requirements_arm.txt
@@ -148,3 +148,4 @@ markdown==3.6
 mistralai==0.4.2
 boto3==1.34.140
 duckduckgo_search==6.1.9
+google-generativeai==0.7.2
diff --git a/requirements_dev.txt b/requirements_dev.txt
index 8bc49f07ef0..f6b6799b56f 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -133,3 +133,4 @@ markdown==3.6
 mistralai==0.4.2
 boto3==1.34.140
 duckduckgo_search==6.1.9
+google-generativeai==0.7.2
diff --git a/web/src/assets/svg/llm/gemini.svg b/web/src/assets/svg/llm/gemini.svg
new file mode 100644
index 00000000000..3f06bf3b5a1
--- /dev/null
+++ b/web/src/assets/svg/llm/gemini.svg
@@ -0,0 +1,114 @@
+[SVG markup for the Gemini logo icon (114 lines) omitted]
diff --git a/web/src/pages/user-setting/setting-model/index.tsx b/web/src/pages/user-setting/setting-model/index.tsx
index dd927e5cbc0..42f58a61e57 100644
--- a/web/src/pages/user-setting/setting-model/index.tsx
+++ b/web/src/pages/user-setting/setting-model/index.tsx
@@ -61,6 +61,7 @@ const IconMap = {
   Mistral: 'mistral',
   'Azure-OpenAI': 'azure',
   Bedrock: 'bedrock',
+  Gemini: 'gemini',
 };
 
 const LlmIcon = ({ name }: { name: string }) => {
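A similarly hedged sketch of the chat and embedding paths added above, with placeholder key and prompts. Note that `GeminiChat.chat` rewrites the `history` entries in place (OpenAI-style `content` becomes Gemini-style `parts`), so pass a copy if the original list must be reused.

```python
from rag.llm.chat_model import GeminiChat
from rag.llm.embedding_model import GeminiEmbed

chat_mdl = GeminiChat(key="YOUR_GEMINI_API_KEY",
                      model_name="gemini-1.5-flash-latest")
answer, used_tokens = chat_mdl.chat(
    system="You are a concise assistant.",
    history=[{"role": "user", "content": "Summarize RAG in one sentence."}],
    gen_conf={"temperature": 0.3, "max_tokens": 256},
)

embd_mdl = GeminiEmbed(key="YOUR_GEMINI_API_KEY",
                       model_name="text-embedding-004")
vectors, token_count = embd_mdl.encode(["retrieval-augmented generation"])
```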