-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'working' of https://github.com/daethyra/openai-utilikit …
…into working
- Loading branch information
Showing
2 changed files
with
349 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
from typing import List, Any, Optional, Dict | ||
import argparse | ||
from langchain.document_loaders import PyPDFDirectoryLoader | ||
from langchain.text_splitter import RecursiveCharacterTextSplitter | ||
from langchain.embeddings import OpenAIEmbeddings, CacheBackedEmbeddings, HuggingFaceEmbeddings | ||
from langchain.filters import EmbeddingsRedundantFilter | ||
from langchain.chat_models import ChatOpenAI | ||
from langchain.chains.conversation.memory import ConversationBufferWindowMemory | ||
from langchain.chains import RetrievalQA | ||
import chromadb | ||
from langchain.vectorstores import Chroma | ||
|
||
# PDF Document Management | ||
class PDFDocumentManager: | ||
def __init__(self, directory: str): | ||
""" | ||
Initialize the PDFDocumentManager with a directory path. | ||
Args: | ||
directory (str): The path to the directory containing PDF files. | ||
""" | ||
try: | ||
self.loader = PyPDFDirectoryLoader(directory) | ||
except Exception as e: | ||
raise ValueError(f"Error initializing PyPDFDirectoryLoader: {e}") from e | ||
|
||
def load_documents(self) -> List[Any]: | ||
""" | ||
Load PDF documents from the specified directory. | ||
Returns: | ||
List[Any]: A list of loaded PDF documents. | ||
""" | ||
try: | ||
return self.loader.load() | ||
except Exception as e: | ||
raise ValueError(f"Error loading documents: {e}") from e | ||
|
||
# Text Splitting | ||
class TextSplitManager: | ||
def __init__(self, chunk_size: int, chunk_overlap: int, length_function=len, add_start_index=True): | ||
""" | ||
Initialize TextSplitManager with configuration for text splitting. | ||
Args: | ||
chunk_size (int): The maximum size for each chunk. | ||
chunk_overlap (int): The overlap between adjacent chunks. | ||
length_function (callable, optional): Function to compute the length of a chunk. Defaults to len. | ||
add_start_index (bool, optional): Whether to include the start index of each chunk. Defaults to True. | ||
""" | ||
self.text_splitter = RecursiveCharacterTextSplitter( | ||
chunk_size=chunk_size, | ||
chunk_overlap=chunk_overlap, | ||
length_function=length_function, | ||
add_start_index=add_start_index | ||
) | ||
|
||
def create_documents(self, docs: List[Any]) -> List[Any]: | ||
""" | ||
Create document chunks based on the configuration. | ||
Args: | ||
docs (List[Any]): List of documents to be chunked. | ||
Returns: | ||
List[Any]: List of document chunks. | ||
""" | ||
try: | ||
return self.text_splitter.create_documents(docs) | ||
except Exception as e: | ||
raise ValueError(f"Error in text splitting: {e}") from e | ||
|
||
# Embeddings and Filtering | ||
class EmbeddingManager: | ||
def __init__(self): | ||
""" | ||
Initialize EmbeddingManager for handling document embeddings. | ||
""" | ||
self.embedder = CacheBackedEmbeddings(OpenAIEmbeddings()) | ||
|
||
def embed_documents(self, docs: List[Any]) -> List[Any]: | ||
""" | ||
Embed the documents using the configured embedder. | ||
Args: | ||
docs (List[Any]): List of documents to be embedded. | ||
Returns: | ||
List[Any]: List of embedded documents. | ||
""" | ||
try: | ||
return self.embedder.embed_documents(docs) | ||
except Exception as e: | ||
raise ValueError(f"Error in embedding documents: {e}") from e | ||
|
||
def filter_redundant(self, embeddings: List[Any]) -> List[Any]: | ||
""" | ||
Filter redundant embeddings from the list. | ||
Args: | ||
embeddings (List[Any]): List of embeddings. | ||
Returns: | ||
List[Any]: List of non-redundant embeddings. | ||
""" | ||
try: | ||
filter_instance = EmbeddingsRedundantFilter(embeddings) | ||
return filter_instance() | ||
except Exception as e: | ||
raise ValueError(f"Error in filtering redundant embeddings: {e}") from e | ||
|
||
# Document Retrieval and Reordering | ||
class DocumentRetriever: | ||
def __init__(self, model_name: str, texts: List[str], search_kwargs: Dict[str, Any]): | ||
""" | ||
Initialize DocumentRetriever for document retrieval and reordering. | ||
Args: | ||
model_name (str): Name of the embedding model to use. | ||
texts (List[str]): Texts for retriever training. | ||
search_kwargs (Dict[str, Any]): Additional search parameters. | ||
""" | ||
self.embeddings = HuggingFaceEmbeddings(model_name=model_name) | ||
self.retriever = Chroma.from_texts(texts, embedding=self.embeddings).as_retriever( | ||
search_kwargs=search_kwargs | ||
) | ||
|
||
def get_relevant_documents(self, query: str) -> List[Any]: | ||
""" | ||
Retrieve relevant documents based on the query. | ||
Args: | ||
query (str): The query string. | ||
Returns: | ||
List[Any]: List of relevant documents. | ||
""" | ||
try: | ||
return self.retriever.get_relevant_documents(query) | ||
except Exception as e: | ||
raise ValueError(f"Error retrieving relevant documents: {e}") from e | ||
|
||
# Chat and QA functionalities | ||
class ChatQA: | ||
def __init__(self, api_key: str, model_name: str, texts: List[str], search_kwargs: Dict[str, Any]): | ||
""" | ||
Initialize ChatQA for chat and QA functionalities. | ||
Args: | ||
api_key (str): API key for OpenAI. | ||
model_name (str): Name of the model for embeddings. | ||
texts (List[str]): Texts for retriever training. | ||
search_kwargs (Dict[str, Any]): Additional search parameters. | ||
""" | ||
self.llm = ChatOpenAI( | ||
openai_api_key=api_key, | ||
model_name='gpt-3.5-turbo', | ||
temperature=0.0 | ||
) | ||
self.conversational_memory = ConversationBufferWindowMemory( | ||
memory_key='chat_history', | ||
k=5, | ||
return_messages=True | ||
) | ||
self.retriever = DocumentRetriever(model_name, texts, search_kwargs) | ||
self.qa = RetrievalQA.from_chain_type( | ||
llm=self.llm, | ||
chain_type="stuff", | ||
retriever=self.retriever.retriever | ||
) | ||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser(description="Run the QA module.") | ||
parser.add_argument("--api_key", type=str, required=True, help="API key for OpenAI.") | ||
parser.add_argument("--model_name", type=str, required=True, help="Name of the model for embeddings.") | ||
parser.add_argument("--texts", type=str, nargs='+', required=True, help="Texts for retriever training.") | ||
parser.add_argument("--search_k", type=int, default=10, help="Number of documents to retrieve.") | ||
|
||
args = parser.parse_args() | ||
|
||
# Initialize the ChatQA class | ||
chat_qa = ChatQA( | ||
api_key=args.api_key, | ||
model_name=args.model_name, | ||
texts=args.texts, | ||
search_kwargs={"k": args.search_k} | ||
) | ||
|
||
while True: | ||
query = input("Enter your query (or type 'quit' to exit): ") | ||
if query.lower() == "quit": | ||
break | ||
else: | ||
# Retrieve relevant documents | ||
relevant_docs = chat_qa.retriever.get_relevant_documents(query) | ||
print(f"Relevant Documents: {relevant_docs}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
|
||
import os | ||
import glob | ||
from dotenv import load_dotenv | ||
from retrying import retry | ||
from langchain.document_loaders import PyPDFLoader | ||
from langchain.text_splitter import RecursiveCharacterTextSplitter | ||
from langchain.vectorstores import Chroma | ||
from langchain.embeddings.openai import OpenAIEmbeddings | ||
from langchain.llms import OpenAI as OpenAILLM | ||
from langchain.chains.question_answering import load_qa_chain | ||
|
||
# Define the retrying decorator for specific functions | ||
def retry_if_value_error(exception): | ||
"""Return True if we should retry (in this case when it's a ValueError), False otherwise""" | ||
return isinstance(exception, ValueError) | ||
|
||
def retry_if_file_not_found_error(exception): | ||
"""Return True if we should retry (in this case when it's a FileNotFoundError), False otherwise""" | ||
return isinstance(exception, FileNotFoundError) | ||
|
||
class PDFProcessor: | ||
""" | ||
A class to handle PDF document processing, similarity search, and question answering. | ||
Attributes | ||
---------- | ||
OPENAI_API_KEY : str | ||
OpenAI API Key for authentication. | ||
embeddings : OpenAIEmbeddings | ||
Object for OpenAI embeddings. | ||
llm : OpenAILLM | ||
Language model for generating embeddings. | ||
Methods | ||
------- | ||
get_user_query(prompt="Please enter your query: "): | ||
Get query from the user. | ||
load_pdfs_from_directory(directory_path='data/'): | ||
Load PDFs from a specified directory. | ||
_load_and_split_document(file_path, chunk_size=2000, chunk_overlap=0): | ||
Load and split a single document. | ||
perform_similarity_search(docsearch, query): | ||
Perform similarity search on documents. | ||
""" | ||
|
||
def __init__(self): | ||
"""Initialize PDFProcessor with environment variables and reusable objects.""" | ||
self._load_env_vars() | ||
self._initialize_reusable_objects() | ||
|
||
@retry(retry_on_exception=retry_if_value_error, stop_max_attempt_number=3) | ||
def _load_env_vars(self): | ||
"""Load environment variables.""" | ||
try: | ||
load_dotenv() | ||
self.OPENAI_API_KEY = os.getenv('OPENAI_API_KEY', 'sk-') | ||
if not self.OPENAI_API_KEY: | ||
raise ValueError("OPENAI_API_KEY is missing. Please set the environment variable.") | ||
except ValueError as ve: | ||
print(f"ValueError encountered: {ve}") | ||
raise | ||
|
||
def _initialize_reusable_objects(self): | ||
"""Initialize reusable objects like embeddings and language models.""" | ||
self.embeddings = OpenAIEmbeddings(openai_api_key=self.OPENAI_API_KEY) | ||
self.llm = OpenAILLM(temperature=0, openai_api_key=self.OPENAI_API_KEY) | ||
|
||
@staticmethod | ||
def get_user_query(prompt="Please enter your query: "): | ||
""" | ||
Get user input for a query. | ||
Parameters: | ||
prompt (str): The prompt to display for user input. | ||
Returns: | ||
str: User's query input. | ||
""" | ||
return input(prompt) | ||
|
||
@retry(retry_on_exception=retry_if_file_not_found_error, stop_max_attempt_number=3) | ||
def load_pdfs_from_directory(self, directory_path='data/'): | ||
""" | ||
Load all PDF files from a given directory. | ||
Parameters: | ||
directory_path (str): Directory path to load PDFs from. | ||
Returns: | ||
list: List of text chunks from all loaded PDFs. | ||
""" | ||
try: | ||
if not os.path.exists(directory_path): | ||
raise FileNotFoundError(f"The directory {directory_path} does not exist.") | ||
pdf_files = glob.glob(f"{directory_path}/*.pdf") | ||
if not pdf_files: | ||
raise FileNotFoundError(f"No PDF files found in the directory {directory_path}.") | ||
all_texts = [] | ||
for pdf_file in pdf_files: | ||
all_texts.extend(self._load_and_split_document(pdf_file)) | ||
return all_texts | ||
except FileNotFoundError as fe: | ||
print(f"FileNotFoundError encountered: {fe}") | ||
raise | ||
|
||
def _load_and_split_document(self, file_path, chunk_size=2000, chunk_overlap=0): | ||
""" | ||
Load and split a PDF document into text chunks. | ||
Parameters: | ||
file_path (str): Path to the PDF file. | ||
chunk_size (int): Size of each text chunk. | ||
chunk_overlap (int): Overlapping characters between chunks. | ||
Returns: | ||
list: List of text chunks. | ||
""" | ||
if not os.path.exists(file_path): | ||
raise FileNotFoundError(f"The file {file_path} does not exist.") | ||
loader = PyPDFLoader(file_path) | ||
data = loader.load() | ||
text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap) | ||
return text_splitter.split_documents(data) | ||
|
||
def perform_similarity_search(self, docsearch, query): | ||
""" | ||
Perform similarity search on documents based on a query. | ||
Parameters: | ||
docsearch (Chroma): Chroma object containing document vectors. | ||
query (str): User query for similarity search. | ||
Returns: | ||
list: List of similar documents or chunks. | ||
""" | ||
if not query: | ||
raise ValueError("Query should not be empty.") | ||
return docsearch.similarity_search(query) | ||
|
||
if __name__ == "__main__": | ||
try: | ||
# Initialize PDFProcessor class | ||
pdf_processor = PDFProcessor() | ||
|
||
# Load PDFs from directory and count the number of loaded documents | ||
texts = pdf_processor.load_pdfs_from_directory() | ||
num_docs = len(texts) | ||
print(f'Loaded {num_docs} document(s).') | ||
|
||
# Create a Chroma object for document similarity search | ||
docsearch = Chroma.from_documents(texts, pdf_processor.embeddings) | ||
|
||
# Load a QA chain | ||
chain = load_qa_chain(pdf_processor.llm, chain_type="stuff") | ||
|
||
# Get user query for similarity search | ||
query = pdf_processor.get_user_query() | ||
|
||
# Perform similarity search based on the query | ||
result = pdf_processor.perform_similarity_search(docsearch, query) | ||
|
||
# Run the QA chain on the result | ||
chain.run(input_documents=result, question=query) | ||
except Exception as e: | ||
print(f"An error occurred: {e}") |