-
Notifications
You must be signed in to change notification settings - Fork 116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add rerank document compressor #331
base: main
Are you sure you want to change the base?
Changes from 4 commits
99ed23d
bbc0243
13ad88f
9e22d93
567ed12
f45340b
0489125
66f0d64
9d65d5e
5c8aca3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,142 @@ | ||||||||||||||
from copy import deepcopy | ||||||||||||||
from typing import Any, Dict, List, Optional, Sequence, Union | ||||||||||||||
|
||||||||||||||
import boto3 | ||||||||||||||
from langchain_core.callbacks.manager import Callbacks | ||||||||||||||
from langchain_core.documents import BaseDocumentCompressor, Document | ||||||||||||||
from langchain_core.utils import from_env | ||||||||||||||
from pydantic import ConfigDict, Field, model_validator | ||||||||||||||
from typing_extensions import Self | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
class BedrockRerank(BaseDocumentCompressor):
    """Document compressor that uses AWS Bedrock Rerank API."""

    client: Any = None
    """Bedrock client to use for compressing documents."""
    top_n: Optional[int] = 3
    """Number of documents to return."""
    model_id: Optional[str] = "amazon.rerank-v1:0"
    """Model ID to fetch ARN dynamically."""
    # Optional: the environment lookup falls back to None when
    # AWS_DEFAULT_REGION is unset, so a plain `str` annotation would
    # contradict the default and reject an explicit aws_region=None.
    aws_region: Optional[str] = Field(
        default_factory=from_env("AWS_DEFAULT_REGION", default=None)
    )
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Though I like the short name here, we have been using
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Solved on 0489125 |
||||||||||||||
"""AWS region to initialize the Bedrock client.""" | ||||||||||||||
# Resolved from the AWS_PROFILE environment variable; None when it is unset.
aws_profile: Optional[str] = Field(
    default_factory=from_env("AWS_PROFILE", default=None),
)
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. As with the region, I would like this to be consistent with the other implementations in langchain_aws.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Solved on 0489125 |
||||||||||||||
"""AWS profile for authentication, optional.""" | ||||||||||||||
|
||||||||||||||
# Reject unknown constructor keywords and permit field types that pydantic
# has no validator for.
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
|
||||||||||||||
@model_validator(mode="after")
def initialize_client(self) -> Self:
    """Create the Bedrock Agent Runtime client when none was supplied.

    A caller-provided ``client`` is left untouched. Otherwise a client for
    the ``bedrock-agent-runtime`` service is built from the session returned
    by ``_get_session``.

    Returns:
        The validated model instance.
    """
    if not self.client:
        session = self._get_session()
        # Pass the configured region so this client targets the same region
        # as the one built in _get_model_arn. region_name=None lets boto3
        # fall back to its own region resolution.
        self.client = session.client(
            "bedrock-agent-runtime", region_name=self.aws_region
        )
    return self
|
||||||||||||||
def _get_session(self):
    """Return a boto3 session, bound to ``aws_profile`` when one is set."""
    if not self.aws_profile:
        return boto3.Session()
    return boto3.Session(profile_name=self.aws_profile)
|
||||||||||||||
def _get_model_arn(self) -> str:
    """Look up the reranker model's ARN from ``model_id``.

    Builds a ``bedrock`` client for ``aws_region`` and reads the ARN from
    the ``get_foundation_model`` response.
    """
    bedrock = self._get_session().client("bedrock", self.aws_region)
    model_info = bedrock.get_foundation_model(modelIdentifier=self.model_id)
    model_details = model_info["modelDetails"]
    return model_details["modelArn"]
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Rather than relying on the API to fetch the model ARN, I suggest we keep this consistent with the Rerank API and take model_arn as input instead of model_id.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Solved on f45340b |
||||||||||||||
|
||||||||||||||
def rerank(
    self,
    documents: Sequence[Union[str, Document, dict]],
    query: str,
    top_n: Optional[int] = None,
    extra_model_fields: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
    """Return the documents' indices ordered by relevance to the query.

    Args:
        documents: A sequence of documents to rerank. Strings and
            ``Document`` objects are sent as text documents; anything else
            is sent as a JSON document.
        query: The query to use for reranking.
        top_n: The number of top-ranked results to return. Defaults to
            ``self.top_n``. When neither is set, ``numberOfResults`` is
            left out of the request instead of being sent as ``None``.
        extra_model_fields: A dictionary of additional fields to pass to
            the model.

    Returns:
        A list of dicts with the keys ``index`` (position in ``documents``)
        and ``relevance_score``, in the order returned by the API. Empty
        when ``documents`` is empty.
    """
    if not documents:
        return []

    model_arn = self._get_model_arn()

    # Serialize documents for the Bedrock API.
    sources = []
    for doc in documents:
        if isinstance(doc, str):
            inline = {"textDocument": {"text": doc}, "type": "TEXT"}
        elif isinstance(doc, Document):
            inline = {"textDocument": {"text": doc.page_content}, "type": "TEXT"}
        else:
            inline = {"jsonDocument": doc, "type": "JSON"}
        sources.append({"inlineDocumentSource": inline, "type": "INLINE"})

    bedrock_config: Dict[str, Any] = {
        "modelConfiguration": {
            "modelArn": model_arn,
            "additionalModelRequestFields": extra_model_fields or {},
        },
    }
    # top_n and self.top_n are both Optional; only send a limit when one is
    # actually configured, since None is not a valid integer parameter.
    number_of_results = top_n or self.top_n
    if number_of_results:
        bedrock_config["numberOfResults"] = number_of_results

    response = self.client.rerank(
        queries=[{"textQuery": {"text": query}, "type": "TEXT"}],
        rerankingConfiguration={
            "bedrockRerankingConfiguration": bedrock_config,
            "type": "BEDROCK_RERANKING_MODEL",
        },
        sources=sources,
    )

    return [
        {"index": result["index"], "relevance_score": result["relevanceScore"]}
        for result in response.get("results", [])
    ]
|
||||||||||||||
def compress_documents(
    self,
    documents: Sequence[Document],
    query: str,
    callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
    """Rerank ``documents`` against ``query`` and return scored copies.

    Args:
        documents: A sequence of documents to compress.
        query: The query to use for compressing the documents.
        callbacks: Callbacks to run during the compression process.

    Returns:
        New ``Document`` objects, in the order produced by ``rerank``, each
        carrying a deep copy of the source metadata plus a
        ``relevance_score`` entry. The input documents are not modified.
    """
    ranked = self.rerank(documents, query)
    reranked_docs = []
    for entry in ranked:
        source = documents[entry["index"]]
        # Copy first so the score is never written into the caller's metadata.
        metadata = deepcopy(source.metadata)
        metadata["relevance_score"] = entry["relevance_score"]
        reranked_docs.append(Document(source.page_content, metadata=metadata))
    return reranked_docs
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from unittest.mock import MagicMock, patch | ||
|
||
import pytest | ||
from langchain.schema import Document | ||
|
||
from langchain_aws import BedrockRerank | ||
|
||
|
||
@pytest.fixture
def reranker():
    """Provide a BedrockRerank whose client is a MagicMock.

    The mock is injected through the constructor so that the
    ``initialize_client`` validator skips building a real boto3 client,
    which would otherwise happen before the mock could be assigned.
    """
    return BedrockRerank(client=MagicMock())
|
||
|
||
@patch("langchain_aws.rerank.rerank.boto3.Session")
def test_initialize_client(mock_boto_session, reranker):
    """initialize_client builds a bedrock-agent-runtime client when unset."""
    session_instance = MagicMock()
    mock_boto_session.return_value = session_instance

    # The fixture pre-populates the client; clear it so the validator
    # actually takes the client-creation branch.
    reranker.client = None
    reranker.initialize_client()

    session_instance.client.assert_called_once()
    assert session_instance.client.call_args.args[0] == "bedrock-agent-runtime"
    assert reranker.client is session_instance.client.return_value
|
||
|
||
@patch("langchain_aws.rerank.rerank.BedrockRerank._get_model_arn")
def test_rerank(mock_get_model_arn, reranker):
    """rerank maps each API result to an index/relevance_score dict."""
    mock_get_model_arn.return_value = "arn:aws:bedrock:model/amazon.rerank-v1:0"
    expected = [(0, 0.9), (1, 0.8)]
    reranker.client.rerank.return_value = {
        "results": [
            {"index": index, "relevanceScore": score} for index, score in expected
        ]
    }

    results = reranker.rerank(
        [Document("Doc 1"), Document("Doc 2")], "Example Query"
    )

    assert len(results) == 2
    for result, (index, score) in zip(results, expected):
        assert result["index"] == index
        assert result["relevance_score"] == score
|
||
|
||
@patch("langchain_aws.rerank.rerank.BedrockRerank.rerank")
def test_compress_documents(mock_rerank, reranker):
    """compress_documents copies each score into the document metadata."""
    expected_scores = [0.95, 0.85]
    mock_rerank.return_value = [
        {"index": position, "relevance_score": score}
        for position, score in enumerate(expected_scores)
    ]

    compressed_docs = reranker.compress_documents(
        [Document("Content 1"), Document("Content 2")], "Relevant query"
    )

    assert len(compressed_docs) == 2
    for compressed, score in zip(compressed_docs, expected_scores):
        assert compressed.metadata["relevance_score"] == score
|
||
|
||
@patch("langchain_aws.rerank.rerank.boto3.Session")
def test_get_model_arn(mock_boto_session, reranker):
    """_get_model_arn returns the ARN from get_foundation_model's response.

    boto3 is patched rather than ``_get_model_arn`` itself, so the real
    lookup logic runs against a fake Bedrock client.
    """
    expected_arn = "arn:aws:bedrock:model/amazon.rerank-v1:0"
    bedrock_client = mock_boto_session.return_value.client.return_value
    bedrock_client.get_foundation_model.return_value = {
        "modelDetails": {"modelArn": expected_arn}
    }

    model_arn = reranker._get_model_arn()

    assert model_arn == expected_arn
    bedrock_client.get_foundation_model.assert_called_once_with(
        modelIdentifier=reranker.model_id
    )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Update this to model_arn.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Solved on f45340b