Skip to content

Commit

Permalink
Add api/list_kb_docs function and modify api/list_chunks (infiniflow#874
Browse files Browse the repository at this point in the history
)

### What problem does this PR solve?
infiniflow#717 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
  • Loading branch information
guoyuhao2330 authored May 22, 2024
1 parent 4b43d06 commit 216c40a
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 14 deletions.
63 changes: 49 additions & 14 deletions api/apps/api_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,14 @@
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import queue_tasks, TaskService
from api.db.services.user_service import UserTenantService
from api.settings import RetCode
from api.settings import RetCode, retrievaler
from api.utils import get_uuid, current_timestamp, datetime_format
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request
from itsdangerous import URLSafeTimedSerializer

from api.utils.file_utils import filename_type, thumbnail
from rag.utils.minio_conn import MINIO

from rag.utils.es_conn import ELASTICSEARCH
from rag.nlp import search
from elasticsearch_dsl import Q

def generate_confirmation_token(tenent_id):
serializer = URLSafeTimedSerializer(tenent_id)
Expand Down Expand Up @@ -369,27 +366,65 @@ def list_chunks():
try:
if "doc_name" in form_data.keys():
tenant_id = DocumentService.get_tenant_id_by_name(form_data['doc_name'])
q = Q("match", docnm_kwd=form_data['doc_name'])
doc_id = DocumentService.get_doc_id_by_doc_name(form_data['doc_name'])

elif "doc_id" in form_data.keys():
tenant_id = DocumentService.get_tenant_id(form_data['doc_id'])
q = Q("match", doc_id=form_data['doc_id'])
doc_id = form_data['doc_id']
else:
return get_json_result(
data=False,retmsg="Can't find doc_name or doc_id"
)

res_es_search = ELASTICSEARCH.search(q,idxnm=search.index_name(tenant_id),timeout="600s")
res = retrievaler.chunk_list(doc_id=doc_id, tenant_id=tenant_id)
res = [
{
"content": res_item["content_with_weight"],
"doc_name": res_item["docnm_kwd"],
"img_id": res_item["img_id"]
} for res_item in res
]

res = [{} for _ in range(len(res_es_search['hits']['hits']))]
except Exception as e:
return server_error_response(e)

return get_json_result(data=res)


@manager.route('/list_kb_docs', methods=['POST'])
# @login_required
def list_kb_docs():
    """List the documents of a knowledge base, one page at a time.

    The API token is read from the ``Authorization`` header (second
    whitespace-separated part). Form fields:

    - ``kb_name``: name of the knowledge base (looked up for the token's tenant).
    - ``page``: page number, default 1.
    - ``page_size``: items per page, default 15.
    - ``orderby``: column to sort on, default ``create_time``.
    - ``desc``: sort descending unless given as a false-like value
      (``""``, ``false``, ``0``, ``no``, ``off``; case-insensitive).
    - ``keywords``: filter string forwarded to ``DocumentService.get_by_kb_id``.

    Returns a JSON result whose data is ``{"total": ..., "docs": [...]}``,
    each doc being ``{"doc_id": ..., "doc_name": ...}``.
    """
    # A missing or malformed header is treated exactly like an unknown token
    # instead of raising AttributeError/IndexError.
    auth_parts = (request.headers.get('Authorization') or '').split()
    objs = APIToken.query(token=auth_parts[1]) if len(auth_parts) > 1 else None
    if not objs:
        return get_json_result(
            data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)

    tenant_id = objs[0].tenant_id
    # An absent kb_name falls through to the normal "not found" answer below.
    kb_name = (request.form.get("kb_name") or "").strip()

    try:
        found, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id)
        if not found:
            return get_data_error_result(
                retmsg="Can't find this knowledgebase!")
        kb_id = kb.id
    except Exception as e:
        return server_error_response(e)

    try:
        page_number = int(request.form.get("page", 1))
        items_per_page = int(request.form.get("page_size", 15))
    except (TypeError, ValueError):
        return get_data_error_result(
            retmsg="page and page_size must be integers!")

    orderby = request.form.get("orderby", "create_time")
    keywords = request.form.get("keywords", "")

    # Form values are strings, so "false" would be truthy if used directly.
    desc = request.form.get("desc", True)
    if isinstance(desc, str):
        desc = desc.strip().lower() not in ("", "false", "0", "no", "off")

    try:
        docs, tol = DocumentService.get_by_kb_id(
            kb_id, page_number, items_per_page, orderby, desc, keywords)
        docs = [{"doc_id": doc['id'], "doc_name": doc['name']} for doc in docs]

        return get_json_result(data={"total": tol, "docs": docs})

    except Exception as e:
        return server_error_response(e)
11 changes: 11 additions & 0 deletions api/db/services/document_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,17 @@ def get_tenant_id_by_name(cls, name):
return
return docs[0]["tenant_id"]

@classmethod
@DB.connection_context()
def get_doc_id_by_doc_name(cls, doc_name):
    """Look up a document id by document name.

    Selects only the ``id`` column of rows whose ``name`` equals
    *doc_name* and returns the ``id`` of the first row, or ``None`` when
    no row matches.
    """
    rows = (
        cls.model
        .select(cls.model.id)
        .where(cls.model.name == doc_name)
        .dicts()
    )
    if rows:
        return rows[0]["id"]
    return None

@classmethod
@DB.connection_context()
def get_thumbnails(cls, docids):
Expand Down
10 changes: 10 additions & 0 deletions rag/nlp/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,3 +407,13 @@ def sql_retrieval(self, sql, fetch_size=128, format="json"):
except Exception as e:
chat_logger.error(f"SQL failure: {sql} =>" + str(e))
return {"error": str(e)}

def chunk_list(self, doc_id, tenant_id, max_count=1024, fields=None):
    """Return selected source fields of the chunks matching one document.

    Builds a ``match`` query on ``doc_id``, sliced to the first
    *max_count* hits, and runs it through ``self.es.search`` against the
    index named by ``index_name(tenant_id)`` with a 600s timeout.

    Args:
        doc_id: value matched against the ``doc_id`` field.
        tenant_id: passed to ``index_name`` to pick the index.
        max_count: upper bound of the result slice (default 1024).
        fields: names to copy out of each hit's ``_source``; defaults to
            ``["docnm_kwd", "content_with_weight", "img_id"]``. The same
            value is forwarded as ``src`` to ``self.es.search``
            (presumably a source filter — confirm against the ES wrapper).

    Returns:
        A list with one dict per hit, mapping each requested field to its
        value in ``_source``, or ``None`` when the hit lacks that field.
    """
    # Build the default per call rather than sharing one mutable list
    # object across every invocation.
    if fields is None:
        fields = ["docnm_kwd", "content_with_weight", "img_id"]

    query = Search().query(Q("match", doc_id=doc_id))[0:max_count].to_dict()
    es_res = self.es.search(query, idxnm=index_name(tenant_id), timeout="600s", src=fields)
    return [
        {fld: hit['_source'].get(fld) for fld in fields}
        for hit in es_res['hits']['hits']
    ]

0 comments on commit 216c40a

Please sign in to comment.