Add deployment config field to SagemakerEndpoint. #318

Draft
wants to merge 2 commits into base: main
40 changes: 36 additions & 4 deletions libs/aws/langchain_aws/llms/sagemaker_endpoint.py
@@ -133,9 +133,15 @@ class LLMContentHandler(ContentHandlerBase[str, str]):
class SagemakerEndpoint(LLM):
"""Sagemaker Inference Endpoint models.

    To use with a pre-deployed SageMaker endpoint or inference component, you must
    supply the endpoint name and optional inference component name from your deployed
    SageMaker model, and the region where it is deployed.

    To use with undeployed SageMaker resources, you can supply an endpoint name,
    an optional inference component name, and a deployment configuration that defines
    the endpoint and model configs. This construct can then be used by the SageMaker
    Python SDK ModelBuilder class to deploy the model on the desired compute.

To authenticate, the AWS client uses the following methods to
automatically load credentials:
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
@@ -191,8 +197,7 @@ class SagemakerEndpoint(LLM):
region_name=region_name,
credentials_profile_name=credentials_profile_name
)

        # Usage with boto3 client
client = boto3.client(
"sagemaker-runtime",
region_name=region_name
@@ -208,7 +213,7 @@ class SagemakerEndpoint(LLM):
"""Boto3 client for sagemaker runtime"""

endpoint_name: str = ""
"""The name of the endpoint from the deployed Sagemaker model.
"""The name of the endpoint created from a Sagemaker model.
Must be unique within an AWS Region."""

inference_component_name: Optional[str] = None
@@ -263,6 +268,33 @@ def transform_output(self, output: bytes) -> str:
.. _boto3: <https://boto3.amazonaws.com/v1/documentation/api/latest/index.html>
"""

deployment_config: Optional[Dict] = None
Collaborator comment: Rather than keeping this as an open-ended dictionary, how about encoding this as a pydantic model, with allowed additions.

"""The deployment configuration for an undeployed endpoint or inference component
which can be deployed through the Sagemaker Python SDK ModelBuilder class.
Collaborator comment: Are the Model Builder SDK changes available to test this?
Comprises two sub-dictionaries model_config and endpoint_config.
Collaborator suggested change:
    Comprises two sub-dictionaries model_config and endpoint_config.

"""

"""
Schema:
.. code-block:: python
deployment_config = {
"model_config": {
"model": Optional[str],
"model_path": Optional[str],
"image_uri": Optional[str],
"model_server": Optional[str],
"content_type": Optional[str],
"accept_type": Optional[str]
},
"endpoint_config": {
"resources": Optional[Dict[str, int]],
"instance_type": Optional[str],
"initial_instance_count": Optional[int]
},
"tags": Optional[List[Dict]]
}
"""

model_config = ConfigDict(
extra="forbid",
)
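The `extra="forbid"` setting makes pydantic reject unknown fields at construction time, which is also the spirit of the reviewer's suggestion to replace the open-ended dictionary with a typed model. What that enforcement amounts to can be sketched with a stdlib-only check (the allowed key set mirrors the schema above and is an assumption):

```python
# Stdlib-only sketch of what pydantic's extra="forbid" enforces:
# reject any key not declared in the schema.
ALLOWED_TOP_LEVEL = {"model_config", "endpoint_config", "tags"}

def validate_deployment_config(cfg: dict) -> None:
    unknown = set(cfg) - ALLOWED_TOP_LEVEL
    if unknown:
        raise ValueError(f"extra fields not permitted: {sorted(unknown)}")

validate_deployment_config({"model_config": {}, "tags": []})  # passes silently
```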