-
Notifications
You must be signed in to change notification settings - Fork 45
/
Copy pathchatbot.py
127 lines (104 loc) · 4.78 KB
/
chatbot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from langchain.chat_models import ChatOpenAI
from langchain.chat_models.base import BaseChatModel
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.embeddings import ModelScopeEmbeddings
from langchain.schema import (
BaseMessage,
HumanMessage,
SystemMessage
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Hologres
from typing import List
import os
import argparse
from utils import export_env, DIR_PATH
class Chatbot:
    """Chatbot that optionally augments questions with documents from Hologres.

    The system prompt is read from ``<DIR_PATH>/config/prompt.txt`` and kept
    as the first entry of ``chat_history``. Unless ``no_vector_store`` is set,
    each question is used for a similarity search against a Hologres vector
    store, and the retrieved documents are prepended to the messages sent to
    the chat model as ``CONTEXT`` system messages.
    """

    # Upper bound (in characters of ``page_content``) on the total retrieved
    # context added to a single query. Override on the class or instance.
    max_context_length = 2000
    # Number of candidate documents requested from the similarity search.
    top_k = 10

    def __init__(self, chat_model: BaseChatModel, clear_db: bool = False, no_vector_store: bool = False) -> None:
        """Load the prompt and embeddings and, unless disabled, connect to Hologres.

        Args:
            chat_model: model invoked in ``query`` with the list of messages.
            clear_db: forwarded to ``Hologres`` as ``pre_delete_table``
                (presumably drops the existing embedding table — confirm
                against the Hologres vector store docs).
            no_vector_store: when True, Hologres is never contacted, no
                retrieval context is generated and ``load_db`` is unavailable.

        Raises:
            KeyError: if a required ``HOLO_*`` environment variable is missing
                (only when the vector store is enabled).
        """
        # Explicit UTF-8 so reading a (presumably Chinese) prompt does not
        # depend on the platform's locale encoding.
        with open(os.path.join(DIR_PATH, 'config', 'prompt.txt'), encoding='utf-8') as f:
            self.prompt = f.read()
        self.embeddings = ModelScopeEmbeddings(
            model_id='damo/nlp_corom_sentence-embedding_chinese-base')
        self.chat_history = [SystemMessage(content=self.prompt)]
        self.chat_model = chat_model
        self.no_vector_store = no_vector_store
        # Always defined so load_db can report a clear error instead of an
        # AttributeError when the vector store is disabled.
        self.vectorstore = None
        if no_vector_store:
            print('当前正使用无向量引擎模式,不连接hologres')
            return

        print('connecting hologres vector store...')
        connection_string = Hologres.connection_string_from_db_params(
            os.environ['HOLO_ENDPOINT'],
            int(os.environ['HOLO_PORT']),
            os.environ['HOLO_DATABASE'],
            os.environ['HOLO_USER'],
            os.environ['HOLO_PASSWORD'])
        # ndims=768 presumably matches the embedding model's output size —
        # confirm if the model_id above is changed.
        self.vectorstore = Hologres(connection_string=connection_string,
                                    embedding_function=self.embeddings, ndims=768,
                                    table_name='langchain_embedding', pre_delete_table=clear_db)

    def load_db(self, files: List[str]) -> None:
        """Load CSV files, split them into chunks and add them to the vector store.

        Args:
            files: paths of CSV files readable by ``CSVLoader``.

        Raises:
            RuntimeError: if the bot was created with ``no_vector_store=True``.
        """
        if self.vectorstore is None:
            raise RuntimeError('vector store is disabled (no_vector_store=True); cannot load documents')
        documents = []
        for fname in files:
            documents.extend(CSVLoader(fname).load())
        if not documents:
            # Nothing to embed; avoid a pointless round-trip to the store.
            return
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000, chunk_overlap=100)
        self.vectorstore.add_documents(text_splitter.split_documents(documents))

    def _generate_context(self, question: str) -> List[BaseMessage]:
        """Return documents related to *question* as ``CONTEXT`` system messages.

        Documents are considered in the order returned by the similarity
        search. A document that would push the accumulated length past
        ``max_context_length`` is skipped, but later, shorter documents may
        still be included. Returns an empty list when the vector store is off.
        """
        if self.no_vector_store:
            return []
        docs = self.vectorstore.similarity_search(question, k=self.top_k)
        context: List[BaseMessage] = []
        used_length = 0
        for doc in docs:
            doc_length = len(doc.page_content)
            if used_length + doc_length > self.max_context_length:
                continue
            used_length += doc_length
            context.append(SystemMessage(content='CONTEXT:\n' + doc.page_content))
        return context

    def query(self, question: str) -> str:
        """Ask the chat model *question* and return the text of its reply.

        The model receives the retrieved context, then the full chat history,
        then the new question. The question and reply are recorded in
        ``chat_history`` only after the model call succeeds, so a failed call
        does not leave an unanswered human message in the history.
        """
        human_message = HumanMessage(content=question)
        input_messages = self._generate_context(question) + self.chat_history + [human_message]
        ai_response = self.chat_model(input_messages)
        self.chat_history.append(human_message)
        self.chat_history.append(ai_response)
        return ai_response.content
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        prog='chatbot',
        description='holo chatbot command line interface')
    parser.add_argument('-l', '--load', action='store_true',
                        help='generate embeddings and update the vector database.')
    parser.add_argument('-f', '--files', nargs='*', default=[],
                        help='specify the csv data file to update. If leave empty, all .csv files in ./data will be updated. Only valid when --load is set.')
    parser.add_argument('-c', '--clear', action='store_true',
                        help='clear all data in vector store')
    parser.add_argument('-n', '--no-vector-store', action='store_true',
                        help='run pure ChatGPT without vector store')
    args = parser.parse_args()
    # --load writes into the vector store, which --no-vector-store disables;
    # reject the combination up front instead of failing after model setup.
    if args.load and args.no_vector_store:
        parser.error('--load cannot be combined with --no-vector-store')

    export_env()
    if args.clear:
        print('start with drop vector store')

    # FIXME: replace None with a concrete chat model instance (e.g. ChatOpenAI(...)).
    # BaseChatModel itself is an abstract base class, so instantiating it
    # directly raised before the hint below could ever be printed. SystemExit
    # is used instead of `assert False`, which is stripped under `python -O`.
    chat_model = None
    if chat_model is None:
        print('请换用具体的大模型')
        raise SystemExit(1)

    bot = Chatbot(chat_model, args.clear, args.no_vector_store)

    if args.load:
        files = args.files
        if not files:
            data_dir = os.path.join(DIR_PATH, 'data')
            # Only CSV files can be read by CSVLoader; skip anything else
            # (READMEs, OS metadata files, ...) and load in a stable order.
            files = sorted(os.path.join(data_dir, name)
                           for name in os.listdir(data_dir)
                           if name.lower().endswith('.csv'))
        print(f'start loading files: {files}')
        bot.load_db(files)
        print('load finished!')
        raise SystemExit(0)

    print('下面开始问答,请输入问题。按 ctrl-C 退出')
    try:
        while True:
            question = input('Human: ')
            if not question.strip():
                # Don't spend a model call on an empty line.
                continue
            answer = bot.query(question)
            print('Chatbot: ' + answer)
    except (KeyboardInterrupt, EOFError):
        # Ctrl-C / Ctrl-D is the advertised way to quit: exit without a traceback.
        print()