Skip to content

Commit

Permalink
📂 feat: load full context of files via /documents/{id}/context (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
danny-avila authored Mar 22, 2024
1 parent d7a7b38 commit 90970a7
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 1 deletion.
25 changes: 24 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from langchain.schema import Document
from contextlib import asynccontextmanager
from dotenv import find_dotenv, load_dotenv
from fastapi import FastAPI, File, Form, UploadFile, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI, File, Form, UploadFile, HTTPException, status
from langchain_core.runnables.config import run_in_executor
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import (
Expand All @@ -25,6 +25,7 @@
from psql import PSQLDatabase, ensure_custom_id_index_on_embedding
from middleware import security_middleware
from pgvector_routes import router as pgvector_router
from parsers import process_documents
from constants import ERROR_MESSAGES
from store import AsyncPgVector

Expand Down Expand Up @@ -263,6 +264,28 @@ async def embed_file(document: StoreDocument):
detail=ERROR_MESSAGES.DEFAULT(e),
)

@app.get("/documents/{id}/context")
async def load_document_context(id: str):
ids = [id]
try:
if isinstance(vector_store, AsyncPgVector):
existing_ids = await vector_store.get_all_ids()
documents = await vector_store.get_documents_by_ids(ids)
else:
existing_ids = vector_store.get_all_ids()
documents = vector_store.get_documents_by_ids(ids)

if not all(id in existing_ids for id in ids):
raise HTTPException(status_code=404, detail="The specified file_id was not found")

return process_documents(documents)
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)

@app.post("/embed-upload")
async def embed_file_upload(file_id: str = Form(...), uploaded_file: UploadFile = File(...)):
temp_file_path = os.path.join(RAG_UPLOAD_DIR, uploaded_file.filename)
Expand Down
29 changes: 29 additions & 0 deletions parsers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import List, Optional
from langchain.schema import Document
from config import CHUNK_OVERLAP

def process_documents(documents: List[Document]) -> str:
processed_text = ""
last_page: Optional[int] = None
doc_basename = ""

for doc in documents:
if 'source' in doc.metadata:
doc_basename = doc.metadata['source'].split('/')[-1]
break

processed_text += f"{doc_basename}\n"

for doc in documents:
current_page = doc.metadata.get('page')
if current_page and current_page != last_page:
processed_text += f"\n# PAGE {doc.metadata['page']}\n\n"
last_page = current_page

new_content = doc.page_content
if processed_text.endswith(new_content[:CHUNK_OVERLAP]):
processed_text += new_content[CHUNK_OVERLAP:]
else:
processed_text += new_content

return processed_text.strip()

0 comments on commit 90970a7

Please sign in to comment.