Add rerank document compressor #331
Hi @jpfcabral, interesting contribution! I noticed that the default region is set to
Fair point, @mgvalverde, I just changed it in bbc0243.
@jpfcabral
@3coins Following the documentation you sent me, I need to call it through bedrock-agent-runtime so we can configure the request with
Note: bedrock-agent-runtime rerank only supports requests with
I just added commit 13ad88f with the changes above; let me know if it needs any fixes.
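For readers following the switch to bedrock-agent-runtime, here is a minimal sketch of what such a rerank request could look like. It is not the code from 13ad88f: the helper names `build_rerank_request` and `order_by_score` are hypothetical, and the payload and response keys follow my reading of the boto3 `rerank` operation, so they should be checked against the current API reference before use.

```python
from typing import Any, Dict, List, Tuple


def build_rerank_request(
    query: str, texts: List[str], model_arn: str, top_n: int = 3
) -> Dict[str, Any]:
    """Build keyword arguments for a bedrock-agent-runtime rerank call.

    Hypothetical helper; key names are assumed from the boto3 documentation.
    """
    return {
        "queries": [{"type": "TEXT", "textQuery": {"text": query}}],
        "sources": [
            {
                "type": "INLINE",
                "inlineDocumentSource": {
                    "type": "TEXT",
                    "textDocument": {"text": text},
                },
            }
            for text in texts
        ],
        "rerankingConfiguration": {
            "type": "BEDROCK_RERANKING_MODEL",
            "bedrockRerankingConfiguration": {
                # Never ask for more results than there are documents.
                "numberOfResults": min(top_n, len(texts)),
                "modelConfiguration": {"modelArn": model_arn},
            },
        },
    }


def order_by_score(
    texts: List[str], results: List[Dict[str, Any]]
) -> List[Tuple[str, float]]:
    """Map result entries back to the input texts, highest score first."""
    ranked = sorted(results, key=lambda r: r["relevanceScore"], reverse=True)
    return [(texts[r["index"]], r["relevanceScore"]) for r in ranked]
```

With credentials configured, the request would be sent with something like `boto3.client("bedrock-agent-runtime", region_name=...).rerank(**request)`, and the `results` list of the response passed to `order_by_score`.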
@jpfcabral
Thanks for submitting this PR and for the quick turnaround on the updates.
Great job adding the examples to the PR description. It might be more useful to add a notebook with those samples in the samples/document_compressors directory.
Also, to keep the module organization consistent with community, would it be better to put this under document_compressors rather than rerank?
def _get_model_arn(self) -> str:
    """Fetch the ARN of the reranker model using the model ID."""
    session = self._get_session()
    client = session.client("bedrock", self.aws_region)
    response = client.get_foundation_model(modelIdentifier=self.model_id)
    return response["modelDetails"]["modelArn"]
Rather than relying on the API to fetch the model ARN, I suggest we keep this consistent with the rerank API and take model_arn as input rather than model_id.
- def _get_model_arn(self) -> str:
-     """Fetch the ARN of the reranker model using the model ID."""
-     session = self._get_session()
-     client = session.client("bedrock", self.aws_region)
-     response = client.get_foundation_model(modelIdentifier=self.model_id)
-     return response["modelDetails"]["modelArn"]
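To illustrate the suggestion, here is a rough sketch of a settings object that takes model_arn directly, so no extra lookup against the bedrock control-plane API is needed. `RerankSettings` is a hypothetical stdlib stand-in, not the class from this PR (which uses pydantic fields), and the `arn:` prefix check is only an assumed sanity check.

```python
from dataclasses import dataclass


@dataclass
class RerankSettings:
    """Hypothetical stand-in for the compressor's fields, not the PR's class."""

    model_arn: str  # required; passed straight through to the rerank request
    top_n: int = 3  # number of documents to return

    def __post_init__(self) -> None:
        # Assumed sanity check: reject a bare model ID such as "amazon.rerank-v1:0".
        if not self.model_arn.startswith("arn:"):
            raise ValueError(f"model_arn must be an ARN, got {self.model_arn!r}")
        if self.top_n < 1:
            raise ValueError("top_n must be at least 1")
```

Making model_arn required surfaces a missing value at construction time instead of at the first rerank call.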
Solved on f45340b
"""Bedrock client to use for compressing documents."""
top_n: Optional[int] = 3
"""Number of documents to return."""
model_id: Optional[str] = "amazon.rerank-v1:0"
Update this to model_arn.
- model_id: Optional[str] = "amazon.rerank-v1:0"
+ model_arn: str
Solved on f45340b
aws_region: str = Field(
    default_factory=from_env("AWS_DEFAULT_REGION", default=None)
)
Though I like the short name here, we have been using region_name in other places, and would prefer we keep it consistent here as well.
- aws_region: str = Field(
-     default_factory=from_env("AWS_DEFAULT_REGION", default=None)
- )
+ region_name: str = Field(
+     default_factory=from_env("AWS_DEFAULT_REGION", default=None)
+ )
Solved on 0489125
aws_profile: Optional[str] = Field(
    default_factory=from_env("AWS_PROFILE", default=None)
)
Similar to the region, I would like this to be consistent with other implementations in langchain_aws.
- aws_profile: Optional[str] = Field(
-     default_factory=from_env("AWS_PROFILE", default=None)
- )
+ credentials_profile_name: Optional[str] = Field(
+     default_factory=from_env("AWS_PROFILE", default=None)
+ )
Solved on 0489125
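As a sketch of how the renamed fields might feed a boto3 session, the helper below resolves region_name and credentials_profile_name with the same environment fallbacks shown in the diffs above. `session_kwargs` is a hypothetical function written for illustration; the PR itself relies on `from_env` defaults on the pydantic fields.

```python
import os
from typing import Dict, Mapping, Optional


def session_kwargs(
    region_name: Optional[str] = None,
    credentials_profile_name: Optional[str] = None,
    env: Optional[Mapping[str, str]] = None,
) -> Dict[str, str]:
    """Resolve boto3.Session keyword arguments (hypothetical helper).

    Explicit arguments win; otherwise fall back to AWS_DEFAULT_REGION and
    AWS_PROFILE, mirroring the defaults in the reviewed fields.
    """
    env = os.environ if env is None else env
    region = region_name or env.get("AWS_DEFAULT_REGION")
    profile = credentials_profile_name or env.get("AWS_PROFILE")

    kwargs: Dict[str, str] = {}
    if region:
        kwargs["region_name"] = region
    if profile:
        # boto3.Session names this argument profile_name.
        kwargs["profile_name"] = profile
    return kwargs
```

The result can be passed as `boto3.Session(**session_kwargs(...))`; unset values are omitted so boto3's own default resolution still applies.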
This PR resolves #298
Added:
from langchain_aws import BedrockRerank
Some snippets: