From f025b8c0447edd08095096c4cb362e6e6c0b436a Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Wed, 6 Dec 2023 22:28:46 -0800 Subject: [PATCH 001/120] [wip] Sagemaker serving agent Signed-off-by: Ketan Umare --- .../awssagemaker/agents/__init__.py | 0 .../awssagemaker/agents/boto3_mixin.py | 71 ++++++++++++++++ .../awssagemaker/agents/sagemaker_agents.py | 81 +++++++++++++++++++ .../tests/agents/__init__.py | 0 .../tests/agents/test_boto3_mixin.py | 40 +++++++++ 5 files changed, 192 insertions(+) create mode 100644 plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/__init__.py create mode 100644 plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto3_mixin.py create mode 100644 plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_agents.py create mode 100644 plugins/flytekit-aws-sagemaker/tests/agents/__init__.py create mode 100644 plugins/flytekit-aws-sagemaker/tests/agents/test_boto3_mixin.py diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/__init__.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto3_mixin.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto3_mixin.py new file mode 100644 index 0000000000..30a63db79a --- /dev/null +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto3_mixin.py @@ -0,0 +1,71 @@ +import typing +from dataclasses import dataclass + +import boto3 +import botocore.exceptions + + +def update_dict(d: typing.Any, u: typing.Dict[str, typing.Any]) -> typing.Any: + """ + Recursively update a dictionary with values from another dictionary. E.g. if d is {"EndpointConfigName": "{endpoint_config_name}"}, + and u is {"endpoint_config_name": "my-endpoint-config"}, then the result will be {"EndpointConfigName": "my-endpoint-config"}. 
+ :param d: The dictionary to update (in place) + :param u: The dictionary to use for updating + :return: The updated dictionary - Note that the original dictionary is updated in place. + """ + if d is None: + return None + if isinstance(d, str): + if "{" in d and "}" in d: + v = d.format(**u) + if v == d: + raise ValueError(f"Could not find value for {d}") + orig_v = u.get(d.replace("{", "").replace("}", "")) + if isinstance(orig_v, str): + return v + return orig_v + return d + if isinstance(d, list): + return [update_dict(i, u) for i in d] + if isinstance(d, dict): + for k, v in d.items(): + d[k] = update_dict(d.get(k), u) + return d + + +# TODO write AsyncBoto3AgentBase - https://github.com/terrycain/aioboto3 +class Boto3AgentMixin: + """ + This mixin can be used to create a Flyte agent for any AWS service, using boto3. + The mixin provides a single method `_call` that can be used to call any boto3 method. + """ + + def __init__(self, *, service: str, region: typing.Optional[str] = None, **kwargs): + """ + :param service: The AWS service to use - e.g. sagemaker + :param region: The region to use for the boto3 client, can be overridden when calling the boto3 method. + """ + self._region = region + self._service = service + + def _call(self, method: str, config: typing.Dict[str, typing.Any], + args: typing.Optional[typing.Dict[str, typing.Any]] = None, + region: typing.Optional[str] = None) -> typing.Dict[str, typing.Any]: + """ + TODO we should probably also accept task_template and inputs separately, and then call update_dict + Use this method to call any boto3 method (aws service method). + :param method: The boto3 method to call - e.g. create_endpoint_config + :param config: The config for the method - e.g. {"EndpointConfigName": "my-endpoint-config"}. The config can + contain placeholders that will be replaced by values from the args dict. 
For example, if the config is + {"EndpointConfigName": "{endpoint_config_name}"}, and the args dict is {"endpoint_config_name": "my-endpoint-config"}, + then the config will be updated to {"EndpointConfigName": "my-endpoint-config"} before calling the boto3 method. + :param args: The args dict can be used to provide values for placeholders in the config dict. For example, if the config is + :param region: The region to use for the boto3 client. If not provided, the region provided in the constructor will be used. + """ + client = boto3.client(self._service, region) + updated_config = update_dict(config, args or {}) + try: + res = getattr(client, method)(**updated_config) + except botocore.exceptions.ClientError as e: + raise e + return res diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_agents.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_agents.py new file mode 100644 index 0000000000..5bd1cee319 --- /dev/null +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_agents.py @@ -0,0 +1,81 @@ +import typing + +import grpc +from flyteidl.admin.agent_pb2 import DeleteTaskResponse, GetTaskResponse, CreateTaskResponse +from flyteidl.core.tasks_pb2 import TaskTemplate + +from flytekit.extend.backend.base_agent import AgentBase +from flytekit.interaction.string_literals import literal_map_string_repr +from flytekit.models.literals import LiteralMap +from .boto3_mixin import Boto3AgentMixin + + +class SagemakerEndpointAgent(Boto3AgentMixin, AgentBase): + + def __init__(self, task_type: str): + super().__init__(service="sagemaker", region="us-east-2", task_type=task_type, asynchronous=False, ) + + def create(self, context: grpc.ServicerContext, output_prefix: str, task_template: TaskTemplate, + inputs: typing.Optional[LiteralMap] = None) -> CreateTaskResponse: + inputs = inputs or LiteralMap(literals={}) + str_inputs = literal_map_string_repr(inputs) + res = 
self._call("us-east-2", "create_endpoint_config", task_template.custom, str_inputs) + res = self._call("us-east-2", "create_endpoint", task_template.custom, str_inputs) + + def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: + pass + + def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: + pass + + +class SagemakerModelAgent(Boto3AgentMixin, AgentBase): + + def __init__(self, task_type: str): + super().__init__(service="sagemaker", region="us-east-2", task_type=task_type, asynchronous=False, ) + + def create(self, context: grpc.ServicerContext, output_prefix: str, task_template: TaskTemplate, + inputs: typing.Optional[LiteralMap] = None) -> CreateTaskResponse: + inputs = inputs or LiteralMap(literals={}) + str_inputs = literal_map_string_repr(inputs) + res = self._call("us-east-2", "create_model", task_template.custom, str_inputs) + + def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: + pass + + def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: + pass + + +class SagemakerInvokeEndpointAgent(Boto3AgentMixin, AgentBase): + + def __init__(self, task_type: str): + super().__init__(service="sagemaker-runtime", region="us-east-2", task_type=task_type, asynchronous=False, ) + + def create(self, context: grpc.ServicerContext, output_prefix: str, task_template: TaskTemplate, + inputs: typing.Optional[LiteralMap] = None) -> CreateTaskResponse: + inputs = inputs or LiteralMap(literals={}) + str_inputs = literal_map_string_repr(inputs) + res = self._call("us-east-2", "invoke_endpoint", task_template.custom, str_inputs) + + def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: + pass + + def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: + pass + + + +# create_model = SagemakerCreateModelTask( +# inputs=kwtypes(model_name=str, image=str, 
model_path=str), +# config={ +# "ModelName": "{inputs.model_name}", +# "Image": "{container.image}", +# "PrimaryContainer": { +# "Image": "{image}", +# "ModelDataUrl": "{model_path}", +# "region": "us-east-2", +# }, +# }, +# ) + diff --git a/plugins/flytekit-aws-sagemaker/tests/agents/__init__.py b/plugins/flytekit-aws-sagemaker/tests/agents/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-aws-sagemaker/tests/agents/test_boto3_mixin.py b/plugins/flytekit-aws-sagemaker/tests/agents/test_boto3_mixin.py new file mode 100644 index 0000000000..7ea135a5ff --- /dev/null +++ b/plugins/flytekit-aws-sagemaker/tests/agents/test_boto3_mixin.py @@ -0,0 +1,40 @@ +import typing +from dataclasses import dataclass + +from flytekit import FlyteContext, StructuredDataset +from flytekit.core.type_engine import TypeEngine +from flytekit.interaction.string_literals import literal_map_string_repr +from flytekit.types.file import FlyteFile +from flytekitplugins.awssagemaker.agents.boto3_mixin import update_dict + + +@dataclass +class MyData: + image: str + model_name: str + model_path: str + + +# TODO improve this test to actually assert outputs +def test_update_dict(): + d = update_dict( + {"a": "{a}", "b": "{b}", "c": "{c}", "d": "{d}", "e": "{e}", "f": "{f}", + "j": {"a": "{a}", "b": "{f}", "c": "{e}"}}, + {"a": 1, "b": "hello", "c": True, "d": 1.0, "e": [1, 2, 3], "f": {"a": "b"}}) + assert d == {'a': 1, 'b': 'hello', 'c': True, 'd': 1.0, 'e': [1, 2, 3], 'f': {'a': 'b'}, + 'j': {'a': 1, 'b': {'a': 'b'}, 'c': [1, 2, 3]}} + + lm = TypeEngine.dict_to_literal_map(FlyteContext.current_context(), + {"a": 1, "b": "hello", "c": True, "d": 1.0, + "e": [1, 2, 3], "f": {"a": "b"}, "g": None, + "h": FlyteFile("s3://foo/bar", remote_path=False), + "i": StructuredDataset(uri="s3://foo/bar")}, + {"a": int, "b": str, "c": bool, "d": float, "e": typing.List[int], + "f": typing.Dict[str, str], "g": typing.Optional[str], "h": FlyteFile, + "i": 
StructuredDataset}) + + d = literal_map_string_repr(lm) + print(d) + + print("{data.image}, {data.model_name}, {data.model_path}".format( + data=MyData(image="foo", model_name="bar", model_path="baz"))) From 5db022a457e724791b4aa046988409080e702a63 Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Sun, 10 Dec 2023 22:43:26 -0800 Subject: [PATCH 002/120] added sync agent example Signed-off-by: Ketan Umare --- .../awssagemaker/agents/boto3_mixin.py | 37 +++++++++++++++---- .../awssagemaker/agents/boto_agent.py | 29 +++++++++++++++ .../awssagemaker/agents/sagemaker_agents.py | 25 ++++++++----- .../awssagemaker/agents/sync_agent_base.py | 36 ++++++++++++++++++ 4 files changed, 110 insertions(+), 17 deletions(-) create mode 100644 plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto_agent.py create mode 100644 plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sync_agent_base.py diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto3_mixin.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto3_mixin.py index 30a63db79a..a7613e9c9e 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto3_mixin.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto3_mixin.py @@ -1,8 +1,12 @@ import typing -from dataclasses import dataclass import boto3 import botocore.exceptions +from flyteidl.core.tasks_pb2 import TaskTemplate +from google.protobuf.json_format import MessageToDict + +from flytekit.interaction.string_literals import literal_map_string_repr +from flytekit.models.literals import LiteralMap def update_dict(d: typing.Any, u: typing.Dict[str, typing.Any]) -> typing.Any: @@ -40,7 +44,7 @@ class Boto3AgentMixin: The mixin provides a single method `_call` that can be used to call any boto3 method. 
""" - def __init__(self, *, service: str, region: typing.Optional[str] = None, **kwargs): + def __init__(self, *, service: typing.Optional[str] = None, region: typing.Optional[str] = None, **kwargs): """ :param service: The AWS service to use - e.g. sagemaker :param region: The region to use for the boto3 client, can be overridden when calling the boto3 method. @@ -49,21 +53,38 @@ def __init__(self, *, service: str, region: typing.Optional[str] = None, **kwarg self._service = service def _call(self, method: str, config: typing.Dict[str, typing.Any], - args: typing.Optional[typing.Dict[str, typing.Any]] = None, + task_template: TaskTemplate, + inputs: typing.Optional[LiteralMap] = None, + additional_args: typing.Optional[typing.Dict[str, typing.Any]] = None, region: typing.Optional[str] = None) -> typing.Dict[str, typing.Any]: """ TODO we should probably also accept task_template and inputs separately, and then call update_dict Use this method to call any boto3 method (aws service method). :param method: The boto3 method to call - e.g. create_endpoint_config :param config: The config for the method - e.g. {"EndpointConfigName": "my-endpoint-config"}. The config can - contain placeholders that will be replaced by values from the args dict. For example, if the config is - {"EndpointConfigName": "{endpoint_config_name}"}, and the args dict is {"endpoint_config_name": "my-endpoint-config"}, - then the config will be updated to {"EndpointConfigName": "my-endpoint-config"} before calling the boto3 method. - :param args: The args dict can be used to provide values for placeholders in the config dict. For example, if the config is + contain placeholders that will be replaced by values from the inputs, task_template or additional_args. 
+ For example, if the config is + {"EndpointConfigName": "{inputs.endpoint_config_name}", "EndpointName": "{endpoint_name}", + "Image": "{container.image}"} + and the additional_args dict is {"endpoint_name": "my-endpoint"}, the inputs contains a string literal for + endpoint_config_name and the task_template contains a container with an image, then the config will be updated + to {"EndpointConfigName": "my-endpoint-config", "EndpointName": "my-endpoint", "Image": "my-image"} + before calling the boto3 method. + :param task_template: The task template for the task that is being created. + :param inputs: The inputs for the task that is being created. + :param additional_args: Additional arguments to use for updating the config. These are optional and + can be controlled by the task author. :param region: The region to use for the boto3 client. If not provided, the region provided in the constructor will be used. """ client = boto3.client(self._service, region) - updated_config = update_dict(config, args or {}) + args = {} + if inputs: + args["inputs"] = literal_map_string_repr(inputs) + if task_template: + args["container"] = MessageToDict(task_template.container) + if additional_args: + args.update(additional_args) + updated_config = update_dict(config, args) try: res = getattr(client, method)(**updated_config) except botocore.exceptions.ClientError as e: diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto_agent.py new file mode 100644 index 0000000000..0c98d35eab --- /dev/null +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto_agent.py @@ -0,0 +1,29 @@ +import typing + +import grpc +from flyteidl.admin.agent_pb2 import CreateTaskResponse +from flyteidl.core.tasks_pb2 import TaskTemplate + +from flytekit.interaction.string_literals import literal_map_string_repr +from flytekit.models.literals import LiteralMap +from 
.boto3_mixin import Boto3AgentMixin +from .external_api_task import ExternalApiTask + + +class GenericSyncBotoAgent(Boto3AgentMixin, ExternalApiTask): + """ + This provides a general purpose Boto3 agent that can be used to call any boto3 method, synchronously. + The method has to be provided as part of the task template custom field. + + TODO this needs a common base. + """ + + def __init__(self, task_type: str): + super().__init__(task_type=task_type, asynchronous=False) + + def do(self, context: grpc.ServicerContext, output_prefix: str, task_template: TaskTemplate, + inputs: typing.Optional[LiteralMap] = None) -> CreateTaskResponse: + inputs = inputs or LiteralMap(literals={}) + str_inputs = literal_map_string_repr(inputs) + res = self._call("us-east-2", "create_model", task_template.custom, str_inputs) + return CreateTaskResponse() diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_agents.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_agents.py index 5bd1cee319..1ec91b750b 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_agents.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_agents.py @@ -3,42 +3,49 @@ import grpc from flyteidl.admin.agent_pb2 import DeleteTaskResponse, GetTaskResponse, CreateTaskResponse from flyteidl.core.tasks_pb2 import TaskTemplate +from google.protobuf.json_format import MessageToDict from flytekit.extend.backend.base_agent import AgentBase from flytekit.interaction.string_literals import literal_map_string_repr from flytekit.models.literals import LiteralMap from .boto3_mixin import Boto3AgentMixin +from .sync_agent_base import SyncAgentBase -class SagemakerEndpointAgent(Boto3AgentMixin, AgentBase): +class SagemakerEndpointAgent(Boto3AgentMixin, SyncAgentBase): def __init__(self, task_type: str): super().__init__(service="sagemaker", region="us-east-2", task_type=task_type, 
asynchronous=False, ) def create(self, context: grpc.ServicerContext, output_prefix: str, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None) -> CreateTaskResponse: - inputs = inputs or LiteralMap(literals={}) - str_inputs = literal_map_string_repr(inputs) - res = self._call("us-east-2", "create_endpoint_config", task_template.custom, str_inputs) - res = self._call("us-east-2", "create_endpoint", task_template.custom, str_inputs) + # We probably want to make 2 parts in endpoint config, one for the model and one for the endpoint + res = self._call("create_endpoint_config", task_template.custom, task_template, inputs) + res = self._call("create_endpoint", task_template.custom, task_template, inputs) def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: + # Wait for endpoint to be created pass def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: + # delete endpoint and endpoint config pass class SagemakerModelAgent(Boto3AgentMixin, AgentBase): + boto3_method = "create_model" + boto3_service = "sagemaker" def __init__(self, task_type: str): - super().__init__(service="sagemaker", region="us-east-2", task_type=task_type, asynchronous=False, ) + super().__init__(service=self.boto3_service, region="us-east-2", task_type=task_type, asynchronous=False, ) def create(self, context: grpc.ServicerContext, output_prefix: str, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None) -> CreateTaskResponse: - inputs = inputs or LiteralMap(literals={}) - str_inputs = literal_map_string_repr(inputs) - res = self._call("us-east-2", "create_model", task_template.custom, str_inputs) + custom_config = {} + if task_template.custom: + custom_config = MessageToDict(task_template.custom) + res = self._call(self.boto3_method, custom_config, task_template, inputs, region="us-east-2") + # Should return immediately def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> 
GetTaskResponse: pass diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sync_agent_base.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sync_agent_base.py new file mode 100644 index 0000000000..f149681c1a --- /dev/null +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sync_agent_base.py @@ -0,0 +1,36 @@ +import typing + +import grpc +from flyteidl.admin.agent_pb2 import DoTaskResponse +from flytekit.core.external_api_task import TASK_TYPE + +from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate + +T = typing.TypeVar("T") + + +class SyncAgentBase(AgentBase): + """ + TaskExecutor is an agent responsible for executing external API tasks. + + This class is meant to be subclassed when implementing plugins that require + an external API to perform the task execution. It provides a routing mechanism + to direct the task to the appropriate handler based on the task's specifications. 
+ """ + + def __init__(self): + super().__init__(task_type=TASK_TYPE, asynchronous=True) + + async def async_do( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: typing.Optional[LiteralMap] = None, + ) -> DoTaskResponse: + pass + + +AgentRegistry.register(TaskExecutor()) \ No newline at end of file From deed189a2246163ba728de50123932b237e4b448 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 28 Dec 2023 22:32:58 +0530 Subject: [PATCH 003/120] initial version Signed-off-by: Samhita Alla --- plugins/flytekit-aws-sagemaker/README.md | 4 - .../awssagemaker/agents/boto3_agent.py | 51 +++ .../awssagemaker/agents/boto3_mixin.py | 170 +++++---- .../awssagemaker/agents/boto_agent.py | 29 -- .../awssagemaker/agents/sagemaker_agents.py | 88 ----- .../agents/sagemaker_deploy_agents.py | 208 +++++++++++ .../awssagemaker/agents/sync_agent_base.py | 36 -- .../awssagemaker/distributed_training.py | 82 ----- .../flytekitplugins/awssagemaker/hpo.py | 168 --------- .../awssagemaker/models/__init__.py | 0 .../awssagemaker/models/hpo_job.py | 181 ---------- .../awssagemaker/models/parameter_ranges.py | 315 ----------------- .../awssagemaker/models/training_job.py | 326 ------------------ .../flytekitplugins/awssagemaker/task.py | 41 +++ .../flytekitplugins/awssagemaker/training.py | 192 ----------- .../scripts/flytekit_sagemaker_runner.py | 92 ----- .../tests/test_flytekit_sagemaker_running.py | 37 -- .../flytekit-aws-sagemaker/tests/test_hpo.py | 124 ------- .../tests/test_hpo_job.py | 79 ----- .../tests/test_parameter_ranges.py | 93 ----- .../tests/test_training.py | 134 ------- .../tests/test_training_job.py | 87 ----- 22 files changed, 409 insertions(+), 2128 deletions(-) create mode 100644 plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto3_agent.py delete mode 100644 plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto_agent.py delete mode 100644 
plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_agents.py create mode 100644 plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_deploy_agents.py delete mode 100644 plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sync_agent_base.py delete mode 100644 plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/distributed_training.py delete mode 100644 plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/hpo.py delete mode 100644 plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/__init__.py delete mode 100644 plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/hpo_job.py delete mode 100644 plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/parameter_ranges.py delete mode 100644 plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/training_job.py create mode 100644 plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py delete mode 100644 plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/training.py delete mode 100644 plugins/flytekit-aws-sagemaker/scripts/flytekit_sagemaker_runner.py delete mode 100644 plugins/flytekit-aws-sagemaker/tests/test_flytekit_sagemaker_running.py delete mode 100644 plugins/flytekit-aws-sagemaker/tests/test_hpo.py delete mode 100644 plugins/flytekit-aws-sagemaker/tests/test_hpo_job.py delete mode 100644 plugins/flytekit-aws-sagemaker/tests/test_parameter_ranges.py delete mode 100644 plugins/flytekit-aws-sagemaker/tests/test_training.py delete mode 100644 plugins/flytekit-aws-sagemaker/tests/test_training_job.py diff --git a/plugins/flytekit-aws-sagemaker/README.md b/plugins/flytekit-aws-sagemaker/README.md index 0974da52c5..02abb4de49 100644 --- a/plugins/flytekit-aws-sagemaker/README.md +++ b/plugins/flytekit-aws-sagemaker/README.md @@ -7,7 +7,3 @@ To install the plugin, run the following command: ```bash pip install flytekitplugins-awssagemaker ``` - -To install Sagemaker in 
the Flyte deployment's backend, go through the [prerequisites](https://docs.flyte.org/projects/cookbook/en/latest/auto/integrations/aws/sagemaker_training/index.html#prerequisites). - -[Built-in sagemaker](https://docs.flyte.org/projects/cookbook/en/latest/auto/integrations/aws/sagemaker_training/sagemaker_builtin_algo_training.html#sphx-glr-auto-integrations-aws-sagemaker-training-sagemaker-builtin-algo-training-py) and [custom sagemaker](https://docs.flyte.org/projects/cookbook/en/latest/auto/integrations/aws/sagemaker_training/sagemaker_custom_training.html#sphx-glr-auto-integrations-aws-sagemaker-training-sagemaker-custom-training-py) training models can be found in the documentation. diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto3_agent.py new file mode 100644 index 0000000000..95aaa4fa54 --- /dev/null +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto3_agent.py @@ -0,0 +1,51 @@ +from typing import Any, Optional, Type + +from flyteidl.admin.agent_pb2 import SUCCEEDED, CreateTaskResponse, Resource +from flyteidl.core.tasks_pb2 import TaskTemplate + +from flytekit import FlyteContextManager +from flytekit.core.external_api_task import ExternalApiTask +from flytekit.core.type_engine import TypeEngine +from flytekit.models.literals import LiteralMap + +from .boto3_mixin import Boto3AgentMixin + + +class SyncBotoAgentTask(Boto3AgentMixin, ExternalApiTask): + """A general purpose boto3 agent that can be used to call any boto3 method synchronously.""" + + def __init__(self, name: str, config: dict[str, Any], service: str, region: Optional[str] = None, **kwargs): + super().__init__(service=service, region=region, name=name, config=config, **kwargs) + + def do( + self, + task_template: TaskTemplate, + method: str, + output_result_type: Type, + inputs: Optional[LiteralMap] = None, + additional_args: Optional[dict[str, Any]] 
= None, + region: Optional[str] = None, + ): + inputs = inputs or LiteralMap(literals={}) + result = self._call( + method=method, + config=task_template.custom["task_config"], + inputs=inputs, + task_template=task_template, + additional_args=additional_args, + region=region, + ) + + ctx = FlyteContextManager.current_context() + + outputs = LiteralMap( + { + "o0": TypeEngine.to_literal( + ctx, + result, + output_result_type, + TypeEngine.to_literal_type(output_result_type), + ) + } + ).to_flyte_idl() + return CreateTaskResponse(resource=Resource(state=SUCCEEDED, outputs=outputs)) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto3_mixin.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto3_mixin.py index a7613e9c9e..8fa8025e19 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto3_mixin.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto3_mixin.py @@ -1,7 +1,6 @@ -import typing +from typing import Any, Optional -import boto3 -import botocore.exceptions +import aioboto3 from flyteidl.core.tasks_pb2 import TaskTemplate from google.protobuf.json_format import MessageToDict @@ -9,74 +8,112 @@ from flytekit.models.literals import LiteralMap -def update_dict(d: typing.Any, u: typing.Dict[str, typing.Any]) -> typing.Any: +def update_dict_fn(original_dict: Any, update_dict: dict[str, Any]) -> Any: """ - Recursively update a dictionary with values from another dictionary. E.g. if d is {"EndpointConfigName": "{endpoint_config_name}"}, - and u is {"endpoint_config_name": "my-endpoint-config"}, then the result will be {"EndpointConfigName": "my-endpoint-config"}. - :param d: The dictionary to update (in place) - :param u: The dictionary to use for updating - :return: The updated dictionary - Note that the original dictionary is updated in place. + Recursively update a dictionary with values from another dictionary. 
+ For example, if original_dict is {"EndpointConfigName": "{endpoint_config_name}"}, + and update_dict is {"endpoint_config_name": "my-endpoint-config"}, + then the result will be {"EndpointConfigName": "my-endpoint-config"}. + + :param original_dict: The dictionary to update (in place) + :param update_dict: The dictionary to use for updating + :return: The updated dictionary - note that the original dictionary is updated in place """ - if d is None: + if original_dict is None: return None - if isinstance(d, str): - if "{" in d and "}" in d: - v = d.format(**u) - if v == d: - raise ValueError(f"Could not find value for {d}") - orig_v = u.get(d.replace("{", "").replace("}", "")) - if isinstance(orig_v, str): - return v - return orig_v - return d - if isinstance(d, list): - return [update_dict(i, u) for i in d] - if isinstance(d, dict): - for k, v in d.items(): - d[k] = update_dict(d.get(k), u) - return d - - -# TODO write AsyncBoto3AgentBase - https://github.com/terrycain/aioboto3 + + # If the original value is a string and contains placeholder curly braces + if isinstance(original_dict, str): + if "{" in original_dict and "}" in original_dict: + # Check if there are nested keys + if "." in original_dict: + # Create a copy of update_dict + update_dict_copy = update_dict.copy() + + # Fetch keys from the original_dict + keys = original_dict.strip("{}").split(".") + + # Get value from the nested dictionary + for key in keys: + update_dict_copy = update_dict_copy.get(key) + if not update_dict_copy: + raise ValueError(f"Could not find the key {key} in {update_dict_copy}.") + + return update_dict_copy + + # Retrieve the original value using the key without curly braces + original_value = update_dict.get(original_dict.replace("{", "").replace("}", "")) + + # Check if original_value exists; if so, return it, + # otherwise, raise a ValueError indicating that the value for the key original_dict could not be found. 
+ if original_value: + return original_value + else: + raise ValueError(f"Could not find value for {original_dict}.") + + # If the string does not contain placeholders, return it as is + return original_dict + + # If the original value is a list, recursively update each element in the list + if isinstance(original_dict, list): + return [update_dict_fn(item, update_dict) for item in original_dict] + + # If the original value is a dictionary, recursively update each key-value pair + if isinstance(original_dict, dict): + for key, value in original_dict.items(): + original_dict[key] = update_dict_fn(value, update_dict) + + # Return the updated original dict + return original_dict + + class Boto3AgentMixin: """ - This mixin can be used to create a Flyte agent for any AWS service, using boto3. - The mixin provides a single method `_call` that can be used to call any boto3 method. + This mixin facilitates the creation of a Flyte agent for any AWS service using boto3. + It provides a single method, `_call`, which can be employed to invoke any boto3 method. """ - def __init__(self, *, service: typing.Optional[str] = None, region: typing.Optional[str] = None, **kwargs): + def __init__(self, *, service: Optional[str] = None, region: Optional[str] = None, **kwargs): """ - :param service: The AWS service to use - e.g. sagemaker - :param region: The region to use for the boto3 client, can be overridden when calling the boto3 method. + Initialize the Boto3AgentMixin. + + :param service: The AWS service to use, e.g., sagemaker. + :param region: The region for the boto3 client; can be overridden when calling boto3 methods. 
""" - self._region = region self._service = service + self._region = region + super().__init__(**kwargs) - def _call(self, method: str, config: typing.Dict[str, typing.Any], - task_template: TaskTemplate, - inputs: typing.Optional[LiteralMap] = None, - additional_args: typing.Optional[typing.Dict[str, typing.Any]] = None, - region: typing.Optional[str] = None) -> typing.Dict[str, typing.Any]: + async def _call( + self, + method: str, + config: dict[str, Any], + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + additional_args: Optional[dict[str, Any]] = None, + region: Optional[str] = None, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + ) -> Any: """ - TODO we should probably also accept task_template and inputs separately, and then call update_dict - Use this method to call any boto3 method (aws service method). - :param method: The boto3 method to call - e.g. create_endpoint_config - :param config: The config for the method - e.g. {"EndpointConfigName": "my-endpoint-config"}. The config can - contain placeholders that will be replaced by values from the inputs, task_template or additional_args. + Utilize this method to invoke any boto3 method (AWS service method). + + :param method: The boto3 method to invoke, e.g., create_endpoint_config. + :param config: The configuration for the method, e.g., {"EndpointConfigName": "my-endpoint-config"}. The config + may contain placeholders replaced by values from inputs, task_template, or additional_args. 
For example, if the config is {"EndpointConfigName": "{inputs.endpoint_config_name}", "EndpointName": "{endpoint_name}", "Image": "{container.image}"} - and the additional_args dict is {"endpoint_name": "my-endpoint"}, the inputs contains a string literal for - endpoint_config_name and the task_template contains a container with an image, then the config will be updated - to {"EndpointConfigName": "my-endpoint-config", "EndpointName": "my-endpoint", "Image": "my-image"} - before calling the boto3 method. - :param task_template: The task template for the task that is being created. - :param inputs: The inputs for the task that is being created. - :param additional_args: Additional arguments to use for updating the config. These are optional and - can be controlled by the task author. - :param region: The region to use for the boto3 client. If not provided, the region provided in the constructor will be used. + and the additional_args dict is {"endpoint_name": "my-endpoint"}, the inputs contain a string literal for + endpoint_config_name, and the task_template contains a container with an image, + then the config will be updated to {"EndpointConfigName": "my-endpoint-config", "EndpointName": "my-endpoint", + "Image": "my-image"} before invoking the boto3 method. + :param task_template: The task template for the task being created. + :param inputs: The inputs for the task being created. + :param additional_args: Additional arguments for updating the config. These are optional and can be controlled by the task author. + :param region: The region for the boto3 client. If not provided, the region specified in the constructor will be used. 
""" - client = boto3.client(self._service, region) args = {} if inputs: args["inputs"] = literal_map_string_repr(inputs) @@ -84,9 +121,20 @@ def _call(self, method: str, config: typing.Dict[str, typing.Any], args["container"] = MessageToDict(task_template.container) if additional_args: args.update(additional_args) - updated_config = update_dict(config, args) - try: - res = getattr(client, method)(**updated_config) - except botocore.exceptions.ClientError as e: - raise e - return res + + updated_config = update_dict_fn(config, args) + + session = aioboto3.Session() + async with session.client( + service_name=self._service, + region_name=region, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + ) as client: + try: + result = await getattr(client, method)(**updated_config) + except Exception as e: + raise e + + return result diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto_agent.py deleted file mode 100644 index 0c98d35eab..0000000000 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto_agent.py +++ /dev/null @@ -1,29 +0,0 @@ -import typing - -import grpc -from flyteidl.admin.agent_pb2 import CreateTaskResponse -from flyteidl.core.tasks_pb2 import TaskTemplate - -from flytekit.interaction.string_literals import literal_map_string_repr -from flytekit.models.literals import LiteralMap -from .boto3_mixin import Boto3AgentMixin -from .external_api_task import ExternalApiTask - - -class GenericSyncBotoAgent(Boto3AgentMixin, ExternalApiTask): - """ - This provides a general purpose Boto3 agent that can be used to call any boto3 method, synchronously. - The method has to be provided as part of the task template custom field. - - TODO this needs a common base. 
- """ - - def __init__(self, task_type: str): - super().__init__(task_type=task_type, asynchronous=False) - - def do(self, context: grpc.ServicerContext, output_prefix: str, task_template: TaskTemplate, - inputs: typing.Optional[LiteralMap] = None) -> CreateTaskResponse: - inputs = inputs or LiteralMap(literals={}) - str_inputs = literal_map_string_repr(inputs) - res = self._call("us-east-2", "create_model", task_template.custom, str_inputs) - return CreateTaskResponse() diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_agents.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_agents.py deleted file mode 100644 index 1ec91b750b..0000000000 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_agents.py +++ /dev/null @@ -1,88 +0,0 @@ -import typing - -import grpc -from flyteidl.admin.agent_pb2 import DeleteTaskResponse, GetTaskResponse, CreateTaskResponse -from flyteidl.core.tasks_pb2 import TaskTemplate -from google.protobuf.json_format import MessageToDict - -from flytekit.extend.backend.base_agent import AgentBase -from flytekit.interaction.string_literals import literal_map_string_repr -from flytekit.models.literals import LiteralMap -from .boto3_mixin import Boto3AgentMixin -from .sync_agent_base import SyncAgentBase - - -class SagemakerEndpointAgent(Boto3AgentMixin, SyncAgentBase): - - def __init__(self, task_type: str): - super().__init__(service="sagemaker", region="us-east-2", task_type=task_type, asynchronous=False, ) - - def create(self, context: grpc.ServicerContext, output_prefix: str, task_template: TaskTemplate, - inputs: typing.Optional[LiteralMap] = None) -> CreateTaskResponse: - # We probably want to make 2 parts in endpoint config, one for the model and one for the endpoint - res = self._call("create_endpoint_config", task_template.custom, task_template, inputs) - res = self._call("create_endpoint", task_template.custom, task_template, inputs) - - 
def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: - # Wait for endpoint to be created - pass - - def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: - # delete endpoint and endpoint config - pass - - -class SagemakerModelAgent(Boto3AgentMixin, AgentBase): - boto3_method = "create_model" - boto3_service = "sagemaker" - - def __init__(self, task_type: str): - super().__init__(service=self.boto3_service, region="us-east-2", task_type=task_type, asynchronous=False, ) - - def create(self, context: grpc.ServicerContext, output_prefix: str, task_template: TaskTemplate, - inputs: typing.Optional[LiteralMap] = None) -> CreateTaskResponse: - custom_config = {} - if task_template.custom: - custom_config = MessageToDict(task_template.custom) - res = self._call(self.boto3_method, custom_config, task_template, inputs, region="us-east-2") - # Should return immediately - - def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: - pass - - def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: - pass - - -class SagemakerInvokeEndpointAgent(Boto3AgentMixin, AgentBase): - - def __init__(self, task_type: str): - super().__init__(service="sagemaker-runtime", region="us-east-2", task_type=task_type, asynchronous=False, ) - - def create(self, context: grpc.ServicerContext, output_prefix: str, task_template: TaskTemplate, - inputs: typing.Optional[LiteralMap] = None) -> CreateTaskResponse: - inputs = inputs or LiteralMap(literals={}) - str_inputs = literal_map_string_repr(inputs) - res = self._call("us-east-2", "invoke_endpoint", task_template.custom, str_inputs) - - def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: - pass - - def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: - pass - - - -# create_model = SagemakerCreateModelTask( -# inputs=kwtypes(model_name=str, 
image=str, model_path=str), -# config={ -# "ModelName": "{inputs.model_name}", -# "Image": "{container.image}", -# "PrimaryContainer": { -# "Image": "{image}", -# "ModelDataUrl": "{model_path}", -# "region": "us-east-2", -# }, -# }, -# ) - diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_deploy_agents.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_deploy_agents.py new file mode 100644 index 0000000000..d97f01944b --- /dev/null +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_deploy_agents.py @@ -0,0 +1,208 @@ +import json +from dataclasses import asdict, dataclass +from typing import Any, Optional + +import grpc +from flyteidl.admin.agent_pb2 import ( + SUCCEEDED, + CreateTaskResponse, + DeleteTaskResponse, + GetTaskResponse, + Resource, +) +from flyteidl.core.tasks_pb2 import TaskTemplate + +from flytekit import FlyteContextManager +from flytekit.core.external_api_task import ExternalApiTask +from flytekit.core.type_engine import TypeEngine +from flytekit.extend.backend.base_agent import AgentBase, convert_to_flyte_state +from flytekit.models.literals import LiteralMap +from flytekit.extend.backend.base_agent import get_agent_secret + +from .boto3_mixin import Boto3AgentMixin + + +@dataclass +class Metadata: + endpoint_name: str + region: str + + +class SagemakerModelTask(Boto3AgentMixin, ExternalApiTask): + """This agent creates a Sagemaker model.""" + + def __init__(self, name: str, config: dict[str, Any], region: Optional[str] = None, **kwargs): + super().__init__(service="sagemaker-runtime", region=region, name=name, config=config, **kwargs) + + def do( + self, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + additional_args: Optional[dict[str, Any]] = None, + ) -> CreateTaskResponse: + inputs = inputs or LiteralMap(literals={}) + + result = self._call( + method="create_model", + config=task_template.custom["task_config"], + 
inputs=inputs, + task_template=task_template, + additional_args=additional_args, + aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), + aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), + aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), + ) + + ctx = FlyteContextManager.current_context() + outputs = LiteralMap( + { + "o0": TypeEngine.to_literal( + ctx, + result, + type(result), + TypeEngine.to_literal_type(type(result)), + ) + } + ).to_flyte_idl() + return CreateTaskResponse(resource=Resource(state=SUCCEEDED, outputs=outputs)) + + +class SagemakerEndpointConfigTask(Boto3AgentMixin, ExternalApiTask): + """This agent creates an endpoint config.""" + + def __init__(self, name: str, config: dict[str, Any], region: Optional[str] = None, **kwargs): + super().__init__(service="sagemaker-runtime", region=region, name=name, config=config, **kwargs) + + def do( + self, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + additional_args: Optional[dict[str, Any]] = None, + ) -> CreateTaskResponse: + inputs = inputs or LiteralMap(literals={}) + + result = self._call( + method="create_endpoint_config", + inputs=inputs, + config=task_template.custom["task_config"], + additional_args=additional_args, + aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), + aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), + aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), + ) + + ctx = FlyteContextManager.current_context() + outputs = LiteralMap( + { + "o0": TypeEngine.to_literal( + ctx, + result, + type(result), + TypeEngine.to_literal_type(type(result)), + ) + } + ).to_flyte_idl() + return CreateTaskResponse(resource=Resource(state=SUCCEEDED, outputs=outputs)) + + +class SagemakerEndpointAgent(Boto3AgentMixin, AgentBase): + """This agent creates an endpoint.""" + + def __init__(self, region: str): + super().__init__( + service="sagemaker-runtime", + 
region=region, + task_type="sagemaker-endpoint", + asynchronous=True, + ) + + async def async_create( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + ) -> CreateTaskResponse: + custom = task_template.custom + config = custom["config"] + region = custom["region"] + + await self._call( + "create_endpoint", + config=config, + task_template=task_template, + inputs=inputs, + region=region, + aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), + aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), + aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), + ) + + metadata = Metadata(endpoint_name=config["EndpointName"], region=region) + return CreateTaskResponse(resource_meta=json.dumps(asdict(metadata)).encode("utf-8")) + + async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: + metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) + + endpoint_status = await self._call( + "describe_endpoint", + config={"EndpointName": metadata.endpoint_name}, + ) + + current_state = endpoint_status.get("EndpointStatus") + message = "" + if current_state in ("Failed", "UpdateRollbackFailed"): + message = endpoint_status.get("FailureReason") + + # THIS WON'T WORK. NEED TO FIX THIS. 
+ flyte_state = convert_to_flyte_state(current_state) + + return GetTaskResponse(resource=Resource(state=flyte_state, message=message)) + + async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: + metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) + + await self._call( + "delete_endpoint", + config={"EndpointName": metadata.endpoint_name}, + ) + + return DeleteTaskResponse() + + +class SagemakerInvokeEndpointTask(Boto3AgentMixin, ExternalApiTask): + """This agent invokes an endpoint.""" + + def __init__(self, name: str, config: dict[str, Any], region: Optional[str] = None, **kwargs): + super().__init__(service="sagemaker-runtime", region=region, name=name, config=config, **kwargs) + + def do( + self, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + additional_args: Optional[dict[str, Any]] = None, + ) -> CreateTaskResponse: + inputs = inputs or LiteralMap(literals={}) + + result = self._call( + method="invoke_endpoint", + inputs=inputs, + config=task_template.custom["task_config"], + additional_args=additional_args, + aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), + aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), + aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), + ) + + ctx = FlyteContextManager.current_context() + outputs = LiteralMap( + { + "o0": TypeEngine.to_literal( + ctx, + result, + type(result), + TypeEngine.to_literal_type(type(result)), + ) + } + ).to_flyte_idl() + return CreateTaskResponse(resource=Resource(state=SUCCEEDED, outputs=outputs)) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sync_agent_base.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sync_agent_base.py deleted file mode 100644 index f149681c1a..0000000000 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sync_agent_base.py +++ /dev/null @@ -1,36 +0,0 @@ 
-import typing - -import grpc -from flyteidl.admin.agent_pb2 import DoTaskResponse -from flytekit.core.external_api_task import TASK_TYPE - -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry -from flytekit.models.literals import LiteralMap -from flytekit.models.task import TaskTemplate - -T = typing.TypeVar("T") - - -class SyncAgentBase(AgentBase): - """ - TaskExecutor is an agent responsible for executing external API tasks. - - This class is meant to be subclassed when implementing plugins that require - an external API to perform the task execution. It provides a routing mechanism - to direct the task to the appropriate handler based on the task's specifications. - """ - - def __init__(self): - super().__init__(task_type=TASK_TYPE, asynchronous=True) - - async def async_do( - self, - context: grpc.ServicerContext, - output_prefix: str, - task_template: TaskTemplate, - inputs: typing.Optional[LiteralMap] = None, - ) -> DoTaskResponse: - pass - - -AgentRegistry.register(TaskExecutor()) \ No newline at end of file diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/distributed_training.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/distributed_training.py deleted file mode 100644 index 2b69bd430d..0000000000 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/distributed_training.py +++ /dev/null @@ -1,82 +0,0 @@ -from __future__ import annotations - -import json -import os -import typing -from dataclasses import dataclass - -import retry - -SM_RESOURCE_CONFIG_FILE = "/opt/ml/input/config/resourceconfig.json" -SM_ENV_VAR_CURRENT_HOST = "SM_CURRENT_HOST" -SM_ENV_VAR_HOSTS = "SM_HOSTS" -SM_ENV_VAR_NETWORK_INTERFACE_NAME = "SM_NETWORK_INTERFACE_NAME" - - -def setup_envars_for_testing(): - """ - This method is useful in simulating the env variables that sagemaker will set on the execution environment - """ - os.environ[SM_ENV_VAR_CURRENT_HOST] = "host" - os.environ[SM_ENV_VAR_HOSTS] = 
'["host1","host2"]' - os.environ[SM_ENV_VAR_NETWORK_INTERFACE_NAME] = "nw" - - -@dataclass -class DistributedTrainingContext(object): - current_host: str - hosts: typing.List[str] - network_interface_name: str - - @classmethod - @retry.retry(exceptions=KeyError, delay=1, tries=10, backoff=1) - def from_env(cls) -> DistributedTrainingContext: - """ - SageMaker suggests "Hostname information might not be immediately available to the processing container. - We recommend adding a retry policy on hostname resolution operations as nodes become available in the cluster." - https://docs.aws.amazon.com/sagemaker/latest/dg/build-your-own-processing-container.html#byoc-config - This is why we have an automatic retry policy - """ - curr_host = os.environ.get(SM_ENV_VAR_CURRENT_HOST) - raw_hosts = os.environ.get(SM_ENV_VAR_HOSTS) - nw_iface = os.environ.get(SM_ENV_VAR_NETWORK_INTERFACE_NAME) - if not (curr_host and raw_hosts and nw_iface): - raise KeyError("Unable to locate Sagemaker Environment variables!") - hosts = json.loads(raw_hosts) - return DistributedTrainingContext(curr_host, hosts, nw_iface) - - @classmethod - @retry.retry(exceptions=FileNotFoundError, delay=1, tries=10, backoff=1) - def from_sagemaker_context_file(cls) -> DistributedTrainingContext: - with open(SM_RESOURCE_CONFIG_FILE, "r") as rc_file: - d = json.load(rc_file) - curr_host = d["current_host"] - hosts = d["hosts"] - nw_iface = d["network_interface_name"] - - if not (curr_host and hosts and nw_iface): - raise KeyError - - return DistributedTrainingContext(curr_host, hosts, nw_iface) - - @classmethod - def local_execute(cls) -> DistributedTrainingContext: - """ - Creates a dummy local execution context for distributed execution. 
- TODO revisit if this is a good idea - """ - return DistributedTrainingContext(hosts=["localhost"], current_host="localhost", network_interface_name="dummy") - - -DISTRIBUTED_TRAINING_CONTEXT_KEY = "DISTRIBUTED_TRAINING_CONTEXT" -""" -Use this key to retrieve the distributed training context of type :py:class:`DistributedTrainingContext`. -Usage: - -.. code-block:: python - - ctx = flytekit.current_context().distributed_training_context - # OR - ctx = flytekit.current_context().get(sagemaker.DISTRIBUTED_TRAINING_CONTEXT_KEY) - -""" diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/hpo.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/hpo.py deleted file mode 100644 index 1229b96195..0000000000 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/hpo.py +++ /dev/null @@ -1,168 +0,0 @@ -import json -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Type, Union - -from flyteidl.plugins.sagemaker import hyperparameter_tuning_job_pb2 as _pb2_hpo_job -from flyteidl.plugins.sagemaker import parameter_ranges_pb2 as _pb2_params -from flytekitplugins.awssagemaker.training import SagemakerBuiltinAlgorithmsTask, SagemakerCustomTrainingTask -from google.protobuf import json_format -from google.protobuf.json_format import MessageToDict - -from flytekit import FlyteContext -from flytekit.configuration import SerializationSettings -from flytekit.extend import DictTransformer, PythonTask, TypeEngine, TypeTransformer -from flytekit.models.literals import Literal -from flytekit.models.types import LiteralType, SimpleType - -from .models import hpo_job as _hpo_job_model -from .models import parameter_ranges as _params -from .models import training_job as _training_job_model - - -@dataclass -class HPOJob(object): - """ - HPOJob Configuration should be used to configure the HPO Job. 
- - Args: - max_number_of_training_jobs: maximum number of jobs to run for a training round - max_parallel_training_jobs: limits the concurrency of the training jobs - tunable_params: [optional] should be a list of parameters for which we want to provide the tuning ranges - """ - - max_number_of_training_jobs: int - max_parallel_training_jobs: int - # TODO. we could make the tunable params a tuple of name and type of range? - tunable_params: Optional[List[str]] = None - - -# TODO Not done yet, but once we clean this up, the client interface should be simplified. The interface should -# Just take a list of Union of different types of Parameter Ranges. Lets see how simplify that -class SagemakerHPOTask(PythonTask[HPOJob]): - _SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK = "sagemaker_hyperparameter_tuning_job_task" - - def __init__( - self, - name: str, - task_config: HPOJob, - training_task: Union[SagemakerCustomTrainingTask, SagemakerBuiltinAlgorithmsTask], - **kwargs, - ): - if training_task is None or not ( - isinstance(training_task, SagemakerCustomTrainingTask) - or isinstance(training_task, SagemakerBuiltinAlgorithmsTask) - ): - raise ValueError( - "Training Task of type SagemakerCustomTrainingTask/SagemakerBuiltinAlgorithmsTask is required to work" - " with Sagemaker HPO" - ) - - self._task_config = task_config - self._training_task = training_task - - extra_inputs = {"hyperparameter_tuning_job_config": _hpo_job_model.HyperparameterTuningJobConfig} - - if task_config.tunable_params: - extra_inputs.update({param: _params.ParameterRangeOneOf for param in task_config.tunable_params}) - - iface = training_task.python_interface - updated_iface = iface.with_inputs(extra_inputs) - super().__init__( - task_type=self._SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK, - name=name, - interface=updated_iface, - task_config=task_config, - **kwargs, - ) - - def execute(self, **kwargs) -> Any: - raise NotImplementedError("Sagemaker HPO Task cannot be executed locally, to execute locally 
mock it!") - - def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - training_job = _training_job_model.TrainingJob( - algorithm_specification=self._training_task.task_config.algorithm_specification, - training_job_resource_config=self._training_task.task_config.training_job_resource_config, - ) - return MessageToDict( - _hpo_job_model.HyperparameterTuningJob( - max_number_of_training_jobs=self.task_config.max_number_of_training_jobs, - max_parallel_training_jobs=self.task_config.max_parallel_training_jobs, - training_job=training_job, - ).to_flyte_idl() - ) - - -# %% -# HPO Task allows ParameterRangeOneOf and HyperparameterTuningJobConfig as inputs. In flytekit this is possible -# to allow these two types to be registered as valid input / output types and provide a custom transformer -# We will create custom transformers for them as follows and provide them once a user loads HPO task - - -class HPOTuningJobConfigTransformer(TypeTransformer[_hpo_job_model.HyperparameterTuningJobConfig]): - """ - Transformer to make ``HyperparameterTuningJobConfig`` an accepted value, for which a transformer is registered - """ - - def __init__(self): - super().__init__("sagemaker-hpojobconfig-transformer", _hpo_job_model.HyperparameterTuningJobConfig) - - def get_literal_type(self, t: Type[_hpo_job_model.HyperparameterTuningJobConfig]) -> LiteralType: - return LiteralType(simple=SimpleType.STRUCT, metadata=None) - - def to_literal( - self, - ctx: FlyteContext, - python_val: _hpo_job_model.HyperparameterTuningJobConfig, - python_type: Type[_hpo_job_model.HyperparameterTuningJobConfig], - expected: LiteralType, - ) -> Literal: - d = MessageToDict(python_val.to_flyte_idl()) - return DictTransformer.dict_to_generic_literal(d) - - def to_python_value( - self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[_hpo_job_model.HyperparameterTuningJobConfig] - ) -> _hpo_job_model.HyperparameterTuningJobConfig: - if lv and lv.scalar and lv.scalar.generic is not 
None: - d = json.loads(json_format.MessageToJson(lv.scalar.generic)) - o = _pb2_hpo_job.HyperparameterTuningJobConfig() - o = json_format.ParseDict(d, o) - return _hpo_job_model.HyperparameterTuningJobConfig.from_flyte_idl(o) - return None - - -class ParameterRangesTransformer(TypeTransformer[_params.ParameterRangeOneOf]): - """ - Transformer to make ``ParameterRange`` an accepted value, for which a transformer is registered - """ - - def __init__(self): - super().__init__("sagemaker-paramrange-transformer", _params.ParameterRangeOneOf) - - def get_literal_type(self, t: Type[_params.ParameterRangeOneOf]) -> LiteralType: - return LiteralType(simple=SimpleType.STRUCT, metadata=None) - - def to_literal( - self, - ctx: FlyteContext, - python_val: _params.ParameterRangeOneOf, - python_type: Type[_hpo_job_model.HyperparameterTuningJobConfig], - expected: LiteralType, - ) -> Literal: - d = MessageToDict(python_val.to_flyte_idl()) - return DictTransformer.dict_to_generic_literal(d) - - def to_python_value( - self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[_params.ParameterRangeOneOf] - ) -> _params.ParameterRangeOneOf: - if lv and lv.scalar and lv.scalar.generic is not None: - d = json.loads(json_format.MessageToJson(lv.scalar.generic)) - o = _pb2_params.ParameterRangeOneOf() - o = json_format.ParseDict(d, o) - return _params.ParameterRangeOneOf.from_flyte_idl(o) - return None - - -# %% -# Register the types -TypeEngine.register(HPOTuningJobConfigTransformer()) -TypeEngine.register(ParameterRangesTransformer()) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/__init__.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/hpo_job.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/hpo_job.py deleted file mode 100644 index 6d6b17189f..0000000000 
--- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/hpo_job.py +++ /dev/null @@ -1,181 +0,0 @@ -from flyteidl.plugins.sagemaker import hyperparameter_tuning_job_pb2 as _pb2_hpo_job - -from flytekit.models import common as _common - -from . import training_job as _training_job - - -class HyperparameterTuningObjectiveType(object): - MINIMIZE = _pb2_hpo_job.HyperparameterTuningObjectiveType.MINIMIZE - MAXIMIZE = _pb2_hpo_job.HyperparameterTuningObjectiveType.MAXIMIZE - - -class HyperparameterTuningObjective(_common.FlyteIdlEntity): - """ - HyperparameterTuningObjective is a data structure that contains the target metric and the - objective of the hyperparameter tuning. - - https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-metrics.html - """ - - def __init__( - self, - objective_type: int, - metric_name: str, - ): - self._objective_type = objective_type - self._metric_name = metric_name - - @property - def objective_type(self) -> int: - """ - Enum value of HyperparameterTuningObjectiveType. objective_type determines the direction of the tuning of - the Hyperparameter Tuning Job with respect to the specified metric. 
- :rtype: int - """ - return self._objective_type - - @property - def metric_name(self) -> str: - """ - The target metric name, which is the user-defined name of the metric specified in the - training job's algorithm specification - :rtype: str - """ - return self._metric_name - - def to_flyte_idl(self) -> _pb2_hpo_job.HyperparameterTuningObjective: - return _pb2_hpo_job.HyperparameterTuningObjective( - objective_type=self.objective_type, - metric_name=self._metric_name, - ) - - @classmethod - def from_flyte_idl(cls, pb2_object: _pb2_hpo_job.HyperparameterTuningObjective): - return cls( - objective_type=pb2_object.objective_type, - metric_name=pb2_object.metric_name, - ) - - -class HyperparameterTuningStrategy: - BAYESIAN = _pb2_hpo_job.HyperparameterTuningStrategy.BAYESIAN - RANDOM = _pb2_hpo_job.HyperparameterTuningStrategy.RANDOM - - -class TrainingJobEarlyStoppingType: - OFF = _pb2_hpo_job.TrainingJobEarlyStoppingType.OFF - AUTO = _pb2_hpo_job.TrainingJobEarlyStoppingType.AUTO - - -class HyperparameterTuningJobConfig(_common.FlyteIdlEntity): - """ - The specification of the hyperparameter tuning process - https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-ex-tuning-job.html#automatic-model-tuning-ex-low-tuning-config - """ - - def __init__( - self, - tuning_strategy: int, - tuning_objective: HyperparameterTuningObjective, - training_job_early_stopping_type: TrainingJobEarlyStoppingType, - ): - self._tuning_strategy = tuning_strategy - self._tuning_objective = tuning_objective - self._training_job_early_stopping_type = training_job_early_stopping_type - - @property - def tuning_strategy(self) -> int: - """ - Enum value of HyperparameterTuningStrategy. Setting the strategy used when searching in the hyperparameter space - :rtype: int - """ - return self._tuning_strategy - - @property - def tuning_objective(self) -> HyperparameterTuningObjective: - """ - The target metric and the objective of the hyperparameter tuning. 
- :rtype: HyperparameterTuningObjective - """ - return self._tuning_objective - - @property - def training_job_early_stopping_type(self) -> int: - """ - Enum value of TrainingJobEarlyStoppingType. When the training jobs launched by the hyperparameter tuning job - are not improving significantly, a hyperparameter tuning job can be stopping early. This attribute determines - how the early stopping is to be done. - Note that there's only a subset of built-in algorithms that supports early stopping. - see: https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-early-stopping.html - :rtype: int - """ - return self._training_job_early_stopping_type - - def to_flyte_idl(self) -> _pb2_hpo_job.HyperparameterTuningJobConfig: - return _pb2_hpo_job.HyperparameterTuningJobConfig( - tuning_strategy=self._tuning_strategy, - tuning_objective=self._tuning_objective.to_flyte_idl(), - training_job_early_stopping_type=self._training_job_early_stopping_type, - ) - - @classmethod - def from_flyte_idl(cls, pb2_object: _pb2_hpo_job.HyperparameterTuningJobConfig): - return cls( - tuning_strategy=pb2_object.tuning_strategy, - tuning_objective=HyperparameterTuningObjective.from_flyte_idl(pb2_object.tuning_objective), - training_job_early_stopping_type=pb2_object.training_job_early_stopping_type, - ) - - -class HyperparameterTuningJob(_common.FlyteIdlEntity): - def __init__( - self, - max_number_of_training_jobs: int, - max_parallel_training_jobs: int, - training_job: _training_job.TrainingJob, - ): - self._max_number_of_training_jobs = max_number_of_training_jobs - self._max_parallel_training_jobs = max_parallel_training_jobs - self._training_job = training_job - - @property - def max_number_of_training_jobs(self) -> int: - """ - The maximum number of training jobs that a hyperparameter tuning job can launch. 
- https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ResourceLimits.html - :rtype: int - """ - return self._max_number_of_training_jobs - - @property - def max_parallel_training_jobs(self) -> int: - """ - The maximum number of concurrent training job that an hpo job can launch - https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ResourceLimits.html - :rtype: int - """ - return self._max_parallel_training_jobs - - @property - def training_job(self) -> _training_job.TrainingJob: - """ - The reference to the underlying training job that the hyperparameter tuning job will launch during the process - :rtype: _training_job.TrainingJob - """ - return self._training_job - - def to_flyte_idl(self) -> _pb2_hpo_job.HyperparameterTuningJob: - return _pb2_hpo_job.HyperparameterTuningJob( - max_number_of_training_jobs=self._max_number_of_training_jobs, - max_parallel_training_jobs=self._max_parallel_training_jobs, - training_job=self._training_job.to_flyte_idl(), # SDK task has already serialized it - ) - - @classmethod - def from_flyte_idl(cls, pb2_object: _pb2_hpo_job.HyperparameterTuningJob): - return cls( - max_number_of_training_jobs=pb2_object.max_number_of_training_jobs, - max_parallel_training_jobs=pb2_object.max_parallel_training_jobs, - training_job=_training_job.TrainingJob.from_flyte_idl(pb2_object.training_job), - ) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/parameter_ranges.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/parameter_ranges.py deleted file mode 100644 index 738f1820a2..0000000000 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/parameter_ranges.py +++ /dev/null @@ -1,315 +0,0 @@ -from typing import Dict, List, Optional, Union - -from flyteidl.plugins.sagemaker import parameter_ranges_pb2 as _idl_parameter_ranges - -from flytekit.exceptions import user -from flytekit.models import common as _common - - -class HyperparameterScalingType(object): - 
AUTO = _idl_parameter_ranges.HyperparameterScalingType.AUTO - LINEAR = _idl_parameter_ranges.HyperparameterScalingType.LINEAR - LOGARITHMIC = _idl_parameter_ranges.HyperparameterScalingType.LOGARITHMIC - REVERSELOGARITHMIC = _idl_parameter_ranges.HyperparameterScalingType.REVERSELOGARITHMIC - - -class ContinuousParameterRange(_common.FlyteIdlEntity): - def __init__( - self, - max_value: float, - min_value: float, - scaling_type: int, - ): - """ - - :param float max_value: - :param float min_value: - :param int scaling_type: - """ - self._max_value = max_value - self._min_value = min_value - self._scaling_type = scaling_type - - @property - def max_value(self) -> float: - """ - - :rtype: float - """ - return self._max_value - - @property - def min_value(self) -> float: - """ - - :rtype: float - """ - return self._min_value - - @property - def scaling_type(self) -> int: - """ - enum value from HyperparameterScalingType - :rtype: int - """ - return self._scaling_type - - def to_flyte_idl(self) -> _idl_parameter_ranges.ContinuousParameterRange: - """ - :rtype: _idl_parameter_ranges.ContinuousParameterRange - """ - - return _idl_parameter_ranges.ContinuousParameterRange( - max_value=self._max_value, - min_value=self._min_value, - scaling_type=self.scaling_type, - ) - - @classmethod - def from_flyte_idl(cls, pb2_object: _idl_parameter_ranges.ContinuousParameterRange): - """ - - :param pb2_object: - :rtype: ContinuousParameterRange - """ - return cls( - max_value=pb2_object.max_value, - min_value=pb2_object.min_value, - scaling_type=pb2_object.scaling_type, - ) - - -class IntegerParameterRange(_common.FlyteIdlEntity): - def __init__( - self, - max_value: int, - min_value: int, - scaling_type: int, - ): - """ - :param int max_value: - :param int min_value: - :param int scaling_type: - """ - self._max_value = max_value - self._min_value = min_value - self._scaling_type = scaling_type - - @property - def max_value(self) -> int: - """ - :rtype: int - """ - return 
self._max_value - - @property - def min_value(self) -> int: - """ - - :rtype: int - """ - return self._min_value - - @property - def scaling_type(self) -> int: - """ - enum value from HyperparameterScalingType - :rtype: int - """ - return self._scaling_type - - def to_flyte_idl(self) -> _idl_parameter_ranges.IntegerParameterRange: - """ - :rtype: _idl_parameter_ranges.IntegerParameterRange - """ - return _idl_parameter_ranges.IntegerParameterRange( - max_value=self._max_value, - min_value=self._min_value, - scaling_type=self.scaling_type, - ) - - @classmethod - def from_flyte_idl(cls, pb2_object: _idl_parameter_ranges.IntegerParameterRange): - """ - - :param pb2_object: - :rtype: IntegerParameterRange - """ - return cls( - max_value=pb2_object.max_value, - min_value=pb2_object.min_value, - scaling_type=pb2_object.scaling_type, - ) - - -class CategoricalParameterRange(_common.FlyteIdlEntity): - def __init__( - self, - values: List[str], - ): - """ - - :param List[str] values: list of strings representing categorical values - """ - self._values = values - - @property - def values(self) -> List[str]: - """ - :rtype: List[str] - """ - return self._values - - def to_flyte_idl(self) -> _idl_parameter_ranges.CategoricalParameterRange: - """ - :rtype: _idl_parameter_ranges.CategoricalParameterRange - """ - return _idl_parameter_ranges.CategoricalParameterRange(values=self._values) - - @classmethod - def from_flyte_idl(cls, pb2_object: _idl_parameter_ranges.CategoricalParameterRange): - """ - - :param pb2_object: - :rtype: CategoricalParameterRange - """ - return cls(values=[v for v in pb2_object.values]) - - -class ParameterRanges(_common.FlyteIdlEntity): - def __init__( - self, - parameter_range_map: Dict[str, _common.FlyteIdlEntity], - ): - self._parameter_range_map = parameter_range_map - - def to_flyte_idl(self) -> _idl_parameter_ranges.ParameterRanges: - """ - - :rtype: _idl_parameter_ranges.ParameterRanges - """ - converted = {} - for k, v in 
self._parameter_range_map.items(): - if isinstance(v, IntegerParameterRange): - converted[k] = _idl_parameter_ranges.ParameterRangeOneOf(integer_parameter_range=v.to_flyte_idl()) - elif isinstance(v, ContinuousParameterRange): - converted[k] = _idl_parameter_ranges.ParameterRangeOneOf(continuous_parameter_range=v.to_flyte_idl()) - elif isinstance(v, CategoricalParameterRange): - converted[k] = _idl_parameter_ranges.ParameterRangeOneOf(categorical_parameter_range=v.to_flyte_idl()) - else: - raise user.FlyteTypeException( - received_type=type(v), - expected_type=type( - Union[IntegerParameterRange, ContinuousParameterRange, CategoricalParameterRange] - ), - ) - - return _idl_parameter_ranges.ParameterRanges( - parameter_range_map=converted, - ) - - @classmethod - def from_flyte_idl(cls, pb2_object: _idl_parameter_ranges.ParameterRanges): - """ - - :param pb2_object: - :rtype: ParameterRanges - """ - converted = {} - for k, v in pb2_object.parameter_range_map.items(): - if v.HasField("continuous_parameter_range"): - converted[k] = ContinuousParameterRange.from_flyte_idl(v.continuous_parameter_range) - elif v.HasField("integer_parameter_range"): - converted[k] = IntegerParameterRange.from_flyte_idl(v.integer_parameter_range) - else: - converted[k] = CategoricalParameterRange.from_flyte_idl(v.categorical_parameter_range) - - return cls( - parameter_range_map=converted, - ) - - -class ParameterRangeOneOf(_common.FlyteIdlEntity): - def __init__(self, param: Union[IntegerParameterRange, ContinuousParameterRange, CategoricalParameterRange]): - """ - Initializes a new ParameterRangeOneOf. - - :param Union[IntegerParameterRange, ContinuousParameterRange, CategoricalParameterRange] param: One of the - supported parameter ranges. 
- """ - self._integer_parameter_range = param if isinstance(param, IntegerParameterRange) else None - self._continuous_parameter_range = param if isinstance(param, ContinuousParameterRange) else None - self._categorical_parameter_range = param if isinstance(param, CategoricalParameterRange) else None - - @property - def integer_parameter_range(self) -> Optional[IntegerParameterRange]: - """ - Retrieves the integer parameter range if one is set. None otherwise. - :rtype: Optional[IntegerParameterRange] - """ - if self._integer_parameter_range: - return self._integer_parameter_range - - return None - - @property - def continuous_parameter_range(self) -> Optional[ContinuousParameterRange]: - """ - Retrieves the continuous parameter range if one is set. None otherwise. - :rtype: Optional[ContinuousParameterRange] - """ - if self._continuous_parameter_range: - return self._continuous_parameter_range - - return None - - @property - def categorical_parameter_range(self) -> Optional[CategoricalParameterRange]: - """ - Retrieves the categorical parameter range if one is set. None otherwise. 
- :rtype: Optional[CategoricalParameterRange] - """ - if self._categorical_parameter_range: - return self._categorical_parameter_range - - return None - - def to_flyte_idl(self) -> _idl_parameter_ranges.ParameterRangeOneOf: - return _idl_parameter_ranges.ParameterRangeOneOf( - integer_parameter_range=self.integer_parameter_range.to_flyte_idl() - if self.integer_parameter_range - else None, - continuous_parameter_range=self.continuous_parameter_range.to_flyte_idl() - if self.continuous_parameter_range - else None, - categorical_parameter_range=self.categorical_parameter_range.to_flyte_idl() - if self.categorical_parameter_range - else None, - ) - - @classmethod - def from_flyte_idl( - cls, - pb_object: Union[ - _idl_parameter_ranges.ParameterRangeOneOf, - _idl_parameter_ranges.IntegerParameterRange, - _idl_parameter_ranges.ContinuousParameterRange, - _idl_parameter_ranges.CategoricalParameterRange, - ], - ): - param = None - if isinstance(pb_object, _idl_parameter_ranges.ParameterRangeOneOf): - if pb_object.HasField("continuous_parameter_range"): - param = ContinuousParameterRange.from_flyte_idl(pb_object.continuous_parameter_range) - elif pb_object.HasField("integer_parameter_range"): - param = IntegerParameterRange.from_flyte_idl(pb_object.integer_parameter_range) - elif pb_object.HasField("categorical_parameter_range"): - param = CategoricalParameterRange.from_flyte_idl(pb_object.categorical_parameter_range) - elif isinstance(pb_object, _idl_parameter_ranges.IntegerParameterRange): - param = IntegerParameterRange.from_flyte_idl(pb_object) - elif isinstance(pb_object, _idl_parameter_ranges.ContinuousParameterRange): - param = ContinuousParameterRange.from_flyte_idl(pb_object) - elif isinstance(pb_object, _idl_parameter_ranges.CategoricalParameterRange): - param = CategoricalParameterRange.from_flyte_idl(pb_object) - - return cls(param=param) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/training_job.py 
b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/training_job.py deleted file mode 100644 index 238aa27fa4..0000000000 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/training_job.py +++ /dev/null @@ -1,326 +0,0 @@ -from typing import List - -from flyteidl.plugins.sagemaker import training_job_pb2 as _training_job_pb2 - -from flytekit.models import common as _common - - -class DistributedProtocol(object): - """ - The distribution framework is used for determining which underlying distributed training mechanism to use. - This is only required for use cases where the user wants to train its custom training job in a distributed manner - """ - - UNSPECIFIED = _training_job_pb2.DistributedProtocol.UNSPECIFIED - MPI = _training_job_pb2.DistributedProtocol.MPI - - -class TrainingJobResourceConfig(_common.FlyteIdlEntity): - """ - TrainingJobResourceConfig is a pass-through, specifying the instance type to use for the training job, the - number of instances to launch, and the size of the ML storage volume the user wants to provision - Refer to SageMaker official doc for more details: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html - """ - - def __init__( - self, - instance_type: str, - volume_size_in_gb: int, - instance_count: int = 1, - distributed_protocol: int = DistributedProtocol.UNSPECIFIED, - ): - self._instance_count = instance_count - self._instance_type = instance_type - self._volume_size_in_gb = volume_size_in_gb - self._distributed_protocol = distributed_protocol - - @property - def instance_count(self) -> int: - """ - The number of ML compute instances to use. For distributed training, provide a value greater than 1. - :rtype: int - """ - return self._instance_count - - @property - def instance_type(self) -> str: - """ - The ML compute instance type. 
- :rtype: str - """ - return self._instance_type - - @property - def volume_size_in_gb(self) -> int: - """ - The size of the ML storage volume that you want to provision to store the data and intermediate artifacts, etc. - :rtype: int - """ - return self._volume_size_in_gb - - @property - def distributed_protocol(self) -> int: - """ - The distribution framework is used to determine through which mechanism the distributed training is done. - enum value from DistributionFramework. - :rtype: int - """ - return self._distributed_protocol - - def to_flyte_idl(self) -> _training_job_pb2.TrainingJobResourceConfig: - """ - - :rtype: _training_job_pb2.TrainingJobResourceConfig - """ - return _training_job_pb2.TrainingJobResourceConfig( - instance_count=self.instance_count, - instance_type=self.instance_type, - volume_size_in_gb=self.volume_size_in_gb, - distributed_protocol=self.distributed_protocol, - ) - - @classmethod - def from_flyte_idl(cls, pb2_object: _training_job_pb2.TrainingJobResourceConfig): - """ - - :param pb2_object: - :rtype: TrainingJobResourceConfig - """ - return cls( - instance_count=pb2_object.instance_count, - instance_type=pb2_object.instance_type, - volume_size_in_gb=pb2_object.volume_size_in_gb, - distributed_protocol=pb2_object.distributed_protocol, - ) - - -class MetricDefinition(_common.FlyteIdlEntity): - def __init__( - self, - name: str, - regex: str, - ): - self._name = name - self._regex = regex - - @property - def name(self) -> str: - """ - The user-defined name of the metric - :rtype: str - """ - return self._name - - @property - def regex(self) -> str: - """ - SageMaker hyperparameter tuning using this regex to parses your algorithm’s stdout and stderr - streams to find the algorithm metrics on which the users want to track - :rtype: str - """ - return self._regex - - def to_flyte_idl(self) -> _training_job_pb2.MetricDefinition: - """ - - :rtype: _training_job_pb2.MetricDefinition - """ - return _training_job_pb2.MetricDefinition( - 
name=self.name, - regex=self.regex, - ) - - @classmethod - def from_flyte_idl(cls, pb2_object: _training_job_pb2.MetricDefinition): - """ - - :param pb2_object: _training_job_pb2.MetricDefinition - :rtype: MetricDefinition - """ - return cls( - name=pb2_object.name, - regex=pb2_object.regex, - ) - - -# TODO Convert to Enum -class InputMode(object): - """ - When using FILE input mode, different SageMaker built-in algorithms require different file types of input data - See https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html - https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-algo-docker-registry-paths.html - """ - - PIPE = _training_job_pb2.InputMode.PIPE - FILE = _training_job_pb2.InputMode.FILE - - -# TODO Convert to enum -class AlgorithmName(object): - """ - The algorithm name is used for deciding which pre-built image to point to. - This is only required for use cases where SageMaker's built-in algorithm mode is used. - While we currently only support a subset of the algorithms, more will be added to the list. - See: https://docs.aws.amazon.com/sagemaker/latest/dg/algos.html - """ - - CUSTOM = _training_job_pb2.AlgorithmName.CUSTOM - XGBOOST = _training_job_pb2.AlgorithmName.XGBOOST - - -# TODO convert to enum -class InputContentType(object): - """ - Specifies the type of content for input data. 
Different SageMaker built-in algorithms require different content types of input data - See https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html - https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-algo-docker-registry-paths.html - """ - - TEXT_CSV = _training_job_pb2.InputContentType.TEXT_CSV - - -class AlgorithmSpecification(_common.FlyteIdlEntity): - """ - Specifies the training algorithm to be used in the training job - This object is mostly a pass-through, with a couple of exceptions include: (1) in Flyte, users don't need to specify - TrainingImage; either use the built-in algorithm mode by using Flytekit's Simple Training Job and specifying an algorithm - name and an algorithm version or (2) when users want to supply custom algorithms they should set algorithm_name field to - CUSTOM. In this case, the value of the algorithm_version field has no effect - For pass-through use cases: refer to this AWS official document for more details - https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AlgorithmSpecification.html - """ - - def __init__( - self, - algorithm_name: int = AlgorithmName.CUSTOM, - algorithm_version: str = "", - input_mode: int = InputMode.FILE, - metric_definitions: List[MetricDefinition] = None, - input_content_type: int = InputContentType.TEXT_CSV, - ): - self._input_mode = input_mode - self._input_content_type = input_content_type - self._algorithm_name = algorithm_name - self._algorithm_version = algorithm_version - self._metric_definitions = metric_definitions or [] - - @property - def input_mode(self) -> int: - """ - enum value from InputMode. The input mode can be either PIPE or FILE - :rtype: int - """ - return self._input_mode - - @property - def input_content_type(self) -> int: - """ - enum value from InputContentType. 
The content type of the input data - See https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html - https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-algo-docker-registry-paths.html - :rtype: int - """ - return self._input_content_type - - @property - def algorithm_name(self) -> int: - """ - The algorithm name is used for deciding which pre-built image to point to. - enum value from AlgorithmName. - :rtype: int - """ - return self._algorithm_name - - @property - def algorithm_version(self) -> str: - """ - version of the algorithm (if using built-in algorithm mode). - :rtype: str - """ - return self._algorithm_version - - @property - def metric_definitions(self) -> List[MetricDefinition]: - """ - A list of metric definitions for SageMaker to evaluate/track on the progress of the training job - See this: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AlgorithmSpecification.html - - Note that, when you use one of the Amazon SageMaker built-in algorithms, you cannot define custom metrics. - If you are doing hyperparameter tuning, built-in algorithms automatically send metrics to hyperparameter tuning. - When using hyperparameter tuning, you do need to choose one of the metrics that the built-in algorithm emits as - the objective metric for the tuning job. 
- See this: https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-metrics.html - :rtype: List[MetricDefinition] - """ - return self._metric_definitions - - def to_flyte_idl(self) -> _training_job_pb2.AlgorithmSpecification: - return _training_job_pb2.AlgorithmSpecification( - input_mode=self.input_mode, - algorithm_name=self.algorithm_name, - algorithm_version=self.algorithm_version, - metric_definitions=[m.to_flyte_idl() for m in self.metric_definitions], - input_content_type=self.input_content_type, - ) - - @classmethod - def from_flyte_idl(cls, pb2_object: _training_job_pb2.AlgorithmSpecification): - return cls( - input_mode=pb2_object.input_mode, - algorithm_name=pb2_object.algorithm_name, - algorithm_version=pb2_object.algorithm_version, - metric_definitions=[MetricDefinition.from_flyte_idl(m) for m in pb2_object.metric_definitions], - input_content_type=pb2_object.input_content_type, - ) - - -class TrainingJob(_common.FlyteIdlEntity): - def __init__( - self, - algorithm_specification: AlgorithmSpecification, - training_job_resource_config: TrainingJobResourceConfig, - ): - self._algorithm_specification = algorithm_specification - self._training_job_resource_config = training_job_resource_config - - @property - def algorithm_specification(self) -> AlgorithmSpecification: - """ - Contains the information related to the algorithm to use in the training job - :rtype: AlgorithmSpecification - """ - return self._algorithm_specification - - @property - def training_job_resource_config(self) -> TrainingJobResourceConfig: - """ - Specifies the information around the instances that will be used to run the training job. 
- :rtype: TrainingJobResourceConfig - """ - return self._training_job_resource_config - - def to_flyte_idl(self) -> _training_job_pb2.TrainingJob: - """ - :rtype: _training_job_pb2.TrainingJob - """ - - return _training_job_pb2.TrainingJob( - algorithm_specification=self.algorithm_specification.to_flyte_idl() - if self.algorithm_specification - else None, - training_job_resource_config=self.training_job_resource_config.to_flyte_idl() - if self.training_job_resource_config - else None, - ) - - @classmethod - def from_flyte_idl(cls, pb2_object: _training_job_pb2.TrainingJob): - """ - - :param pb2_object: - :rtype: TrainingJob - """ - return cls( - algorithm_specification=pb2_object.algorithm_specification, - training_job_resource_config=pb2_object.training_job_resource_config, - ) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py new file mode 100644 index 0000000000..445bbd0884 --- /dev/null +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py @@ -0,0 +1,41 @@ +from dataclasses import dataclass +from typing import Any, Optional + +from google.protobuf import json_format +from google.protobuf.struct_pb2 import Struct + +from flytekit.configuration import SerializationSettings +from flytekit.core.base_task import PythonTask +from flytekit.core.interface import Interface +from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin + + +@dataclass +class SagemakerEndpointConfig(object): + config: dict[str, Any] + region: str + + +class SagemakerEndpointTask(AsyncAgentExecutorMixin, PythonTask[SagemakerEndpointConfig]): + _TASK_TYPE = "sagemaker-endpoint" + + def __init__( + self, + name: str, + task_config: SagemakerEndpointConfig, + inputs: Optional[dict[str, Any]] = None, + **kwargs, + ): + super().__init__( + name=name, + task_config=task_config, + interface=Interface(inputs=inputs or {}), + task_type=self._TASK_TYPE, + **kwargs, + ) + 
+ def get_custom(self, settings: SerializationSettings) -> dict[str, Any]: + config = {"config": self.task_config.config, "region": self.task_config.region} + s = Struct() + s.update(config) + return json_format.MessageToDict(s) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/training.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/training.py deleted file mode 100644 index 7f456d19a0..0000000000 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/training.py +++ /dev/null @@ -1,192 +0,0 @@ -import typing -from dataclasses import dataclass -from typing import Any, Callable, Dict - -from flytekitplugins.awssagemaker.distributed_training import DistributedTrainingContext -from google.protobuf.json_format import MessageToDict -from typing_extensions import Annotated - -import flytekit -from flytekit import ExecutionParameters, FlyteContextManager, PythonFunctionTask, kwtypes -from flytekit.configuration import SerializationSettings -from flytekit.extend import ExecutionState, IgnoreOutputs, Interface, PythonTask, TaskPlugins -from flytekit.loggers import logger -from flytekit.types.directory.types import FlyteDirectory -from flytekit.types.file import FileExt, FlyteFile - -from .models import training_job as _training_job_models - - -@dataclass -class SagemakerTrainingJobConfig(object): - """ - Configuration for Running Training Jobs on Sagemaker. This config can be used to run either the built-in algorithms - or custom algorithms. - - Args: - training_job_resource_config: Configuration for Resources to use during the training - algorithm_specification: Specification of the algorithm to use - should_persist_output: This method will be invoked and will decide if the generated model should be persisted - as the output. 
``NOTE: Useful only for distributed training`` - ``default: single node training - always persist output`` - ``default: distributed training - always persist output on node with rank-0`` - """ - - training_job_resource_config: _training_job_models.TrainingJobResourceConfig - algorithm_specification: _training_job_models.AlgorithmSpecification - # The default output-persisting predicate. - # With this predicate, only the copy running on the first host in the list of hosts would persist its output - should_persist_output: typing.Callable[[DistributedTrainingContext], bool] = lambda dctx: ( - dctx.current_host == dctx.hosts[0] - ) - - -class SagemakerBuiltinAlgorithmsTask(PythonTask[SagemakerTrainingJobConfig]): - """ - Implements an interface that allows execution of a SagemakerBuiltinAlgorithms. - Refer to `Sagemaker Builtin Algorithms`_ for details. - """ - - _SAGEMAKER_TRAINING_JOB_TASK = "sagemaker_training_job_task" - - OUTPUT_TYPE = Annotated[str, FileExt("tar.gz")] - - def __init__( - self, - name: str, - task_config: SagemakerTrainingJobConfig, - **kwargs, - ): - """ - Args: - name: name of this specific task. This should be unique within the project. 
A good strategy is to prefix - with the module name - metadata: Metadata for the task - task_config: Config to use for the SagemakerBuiltinAlgorithms - """ - if ( - task_config is None - or task_config.algorithm_specification is None - or task_config.training_job_resource_config is None - ): - raise ValueError("TaskConfig, algorithm_specification, training_job_resource_config are required") - - input_type = Annotated[ - str, FileExt(self._content_type_to_blob_format(task_config.algorithm_specification.input_content_type)) - ] - - interface = Interface( - # TODO change train and validation to be FlyteDirectory when available - inputs=kwtypes( - static_hyperparameters=dict, train=FlyteDirectory[input_type], validation=FlyteDirectory[input_type] - ), - outputs=kwtypes(model=FlyteFile[self.OUTPUT_TYPE]), - ) - super().__init__( - self._SAGEMAKER_TRAINING_JOB_TASK, - name, - interface=interface, - task_config=task_config, - **kwargs, - ) - - def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - training_job = _training_job_models.TrainingJob( - algorithm_specification=self._task_config.algorithm_specification, - training_job_resource_config=self._task_config.training_job_resource_config, - ) - return MessageToDict(training_job.to_flyte_idl()) - - def execute(self, **kwargs) -> Any: - raise NotImplementedError( - "Cannot execute Sagemaker Builtin Algorithms locally, for local testing, please mock!" - ) - - @classmethod - def _content_type_to_blob_format(cls, content_type: int) -> str: - """ - TODO Convert InputContentType to Enum and others - """ - if content_type == _training_job_models.InputContentType.TEXT_CSV: - return "csv" - else: - raise ValueError("Unsupported InputContentType: {}".format(content_type)) - - -class SagemakerCustomTrainingTask(PythonFunctionTask[SagemakerTrainingJobConfig]): - """ - Allows a custom training algorithm to be executed on Amazon Sagemaker. 
For this to work, make sure your container - is built according to Flyte plugin documentation (TODO point the link here) - """ - - _SAGEMAKER_CUSTOM_TRAINING_JOB_TASK = "sagemaker_custom_training_job_task" - - def __init__( - self, - task_config: SagemakerTrainingJobConfig, - task_function: Callable, - **kwargs, - ): - super().__init__( - task_config=task_config, - task_function=task_function, - task_type=self._SAGEMAKER_CUSTOM_TRAINING_JOB_TASK, - **kwargs, - ) - - def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - training_job = _training_job_models.TrainingJob( - algorithm_specification=self.task_config.algorithm_specification, - training_job_resource_config=self.task_config.training_job_resource_config, - ) - return MessageToDict(training_job.to_flyte_idl()) - - def _is_distributed(self) -> bool: - """ - Only if more than one instance is specified, we assume it is a distributed training setup - """ - return ( - self.task_config.training_job_resource_config - and self.task_config.training_job_resource_config.instance_count > 1 - ) - - def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: - """ - Pre-execute for Sagemaker will automatically add the distributed context to the execution params, only - if the number of execution instances is > 1. 
Otherwise this is considered to be a single node execution - """ - if self._is_distributed(): - logger.info("Distributed context detected!") - exec_state = FlyteContextManager.current_context().execution_state - if exec_state and exec_state.mode == ExecutionState.Mode.TASK_EXECUTION: - """ - This mode indicates we are actually in a remote execute environment (within sagemaker in this case) - """ - dist_ctx = DistributedTrainingContext.from_env() - else: - dist_ctx = DistributedTrainingContext.local_execute() - return user_params.builder().add_attr("DISTRIBUTED_TRAINING_CONTEXT", dist_ctx).build() - - return user_params - - def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: - """ - In the case of distributed execution, we check the should_persist_predicate in the configuration to determine - if the output should be persisted. This is because in distributed training, multiple nodes may produce partial - outputs and only the user process knows the output that should be generated. They can control the choice using - the predicate. - - To control if output is generated across every execution, we override the post_execute method and sometimes - return a None - """ - if self._is_distributed(): - logger.info("Distributed context detected!") - dctx = flytekit.current_context().distributed_training_context - if not self.task_config.should_persist_output(dctx): - logger.info("output persistence predicate not met, Flytekit will ignore outputs") - raise IgnoreOutputs(f"Distributed context - Persistence predicate not met. 
Ignoring outputs - {dctx}") - return rval - - -# Register the Tensorflow Plugin into the flytekit core plugin system -TaskPlugins.register_pythontask_plugin(SagemakerTrainingJobConfig, SagemakerCustomTrainingTask) diff --git a/plugins/flytekit-aws-sagemaker/scripts/flytekit_sagemaker_runner.py b/plugins/flytekit-aws-sagemaker/scripts/flytekit_sagemaker_runner.py deleted file mode 100644 index 4a3d94fab5..0000000000 --- a/plugins/flytekit-aws-sagemaker/scripts/flytekit_sagemaker_runner.py +++ /dev/null @@ -1,92 +0,0 @@ -import argparse -import logging -import os -import subprocess -import sys - -FLYTE_ARG_PREFIX = "--__FLYTE" -FLYTE_ENV_VAR_PREFIX = f"{FLYTE_ARG_PREFIX}_ENV_VAR_" -FLYTE_CMD_PREFIX = f"{FLYTE_ARG_PREFIX}_CMD_" -FLYTE_ARG_SUFFIX = "__" - - -# This script is the "entrypoint" script for SageMaker. An environment variable must be set on the container (typically -# in the Dockerfile) of SAGEMAKER_PROGRAM=flytekit_sagemaker_runner.py. When the container is launched in SageMaker, -# it'll run `train flytekit_sagemaker_runner.py `, the responsibility of this script is then to decode -# the known hyperparameters (passed as command line args) to recreate the original command that will actually run the -# virtual environment and execute the intended task (e.g. 
`service_venv pyflyte-execute --task-module ....`) - -# An example for a valid command: -# python flytekit_sagemaker_runner.py --__FLYTE_ENV_VAR_env1__ val1 --__FLYTE_ENV_VAR_env2__ val2 -# --__FLYTE_CMD_0_service_venv__ __FLYTE_CMD_DUMMY_VALUE__ -# --__FLYTE_CMD_1_pyflyte-execute__ __FLYTE_CMD_DUMMY_VALUE__ -# --__FLYTE_CMD_2_--task-module__ __FLYTE_CMD_DUMMY_VALUE__ -# --__FLYTE_CMD_3_blah__ __FLYTE_CMD_DUMMY_VALUE__ -# --__FLYTE_CMD_4_--task-name__ __FLYTE_CMD_DUMMY_VALUE__ -# --__FLYTE_CMD_5_bloh__ __FLYTE_CMD_DUMMY_VALUE__ -# --__FLYTE_CMD_6_--output-prefix__ __FLYTE_CMD_DUMMY_VALUE__ -# --__FLYTE_CMD_7_s3://fake-bucket__ __FLYTE_CMD_DUMMY_VALUE__ -# --__FLYTE_CMD_8_--inputs__ __FLYTE_CMD_DUMMY_VALUE__ -# --__FLYTE_CMD_9_s3://fake-bucket__ __FLYTE_CMD_DUMMY_VALUE__ - - -def parse_args(cli_args): - parser = argparse.ArgumentParser(description="Running sagemaker task") - args, unknowns = parser.parse_known_args(cli_args) - - # Parse the command line and env vars - flyte_cmd = [] - env_vars = {} - i = 0 - - while i < len(unknowns): - unknown = unknowns[i] - logging.info(f"Processing argument {unknown}") - if unknown.startswith(FLYTE_CMD_PREFIX) and unknown.endswith(FLYTE_ARG_SUFFIX): - processed = unknown[len(FLYTE_CMD_PREFIX) :][: -len(FLYTE_ARG_SUFFIX)] - # Parse the format `1_--task-module` - parts = processed.split("_", maxsplit=1) - flyte_cmd.append((parts[0], parts[1])) - i += 1 - elif unknown.startswith(FLYTE_ENV_VAR_PREFIX) and unknown.endswith(FLYTE_ARG_SUFFIX): - processed = unknown[len(FLYTE_ENV_VAR_PREFIX) :][: -len(FLYTE_ARG_SUFFIX)] - i += 1 - if unknowns[i].startswith(FLYTE_ARG_PREFIX) is False: - env_vars[processed] = unknowns[i] - i += 1 - else: - # To prevent SageMaker from ignoring our __FLYTE_CMD_*__ hyperparameters, we need to set a dummy value - # which serves as a placeholder for each of them. 
The dummy value placeholder `__FLYTE_CMD_DUMMY_VALUE__` - # falls into this branch and will be ignored - i += 1 - - return flyte_cmd, env_vars - - -def sort_flyte_cmd(flyte_cmd): - # Order the cmd using the index (the first element in each tuple) - flyte_cmd = sorted(flyte_cmd, key=lambda x: int(x[0])) - flyte_cmd = [x[1] for x in flyte_cmd] - return flyte_cmd - - -def set_env_vars(env_vars): - for key, val in env_vars.items(): - os.environ[key] = val - - -def run(cli_args): - flyte_cmd, env_vars = parse_args(cli_args) - flyte_cmd = sort_flyte_cmd(flyte_cmd) - set_env_vars(env_vars) - - logging.info(f"Cmd:{flyte_cmd}") - logging.info(f"Env vars:{env_vars}") - - # Launching a subprocess with the selected entrypoint script and the rest of the arguments - logging.info(f"Launching command: {flyte_cmd}") - subprocess.run(flyte_cmd, stdout=sys.stdout, stderr=sys.stderr, encoding="utf-8", check=True) - - -if __name__ == "__main__": - run(sys.argv) diff --git a/plugins/flytekit-aws-sagemaker/tests/test_flytekit_sagemaker_running.py b/plugins/flytekit-aws-sagemaker/tests/test_flytekit_sagemaker_running.py deleted file mode 100644 index f527deb5cc..0000000000 --- a/plugins/flytekit-aws-sagemaker/tests/test_flytekit_sagemaker_running.py +++ /dev/null @@ -1,37 +0,0 @@ -import os -import sys -from unittest import mock - -from scripts.flytekit_sagemaker_runner import run as _flyte_sagemaker_run - -cmd = [] -cmd.extend(["--__FLYTE_ENV_VAR_env1__", "val1"]) -cmd.extend(["--__FLYTE_ENV_VAR_env2__", "val2"]) -cmd.extend(["--__FLYTE_CMD_0_service_venv__", "__FLYTE_CMD_DUMMY_VALUE__"]) -cmd.extend(["--__FLYTE_CMD_1_pyflyte-execute__", "__FLYTE_CMD_DUMMY_VALUE__"]) -cmd.extend(["--__FLYTE_CMD_2_--task-module__", "__FLYTE_CMD_DUMMY_VALUE__"]) -cmd.extend(["--__FLYTE_CMD_3_blah__", "__FLYTE_CMD_DUMMY_VALUE__"]) -cmd.extend(["--__FLYTE_CMD_4_--task-name__", "__FLYTE_CMD_DUMMY_VALUE__"]) -cmd.extend(["--__FLYTE_CMD_5_bloh__", "__FLYTE_CMD_DUMMY_VALUE__"]) 
-cmd.extend(["--__FLYTE_CMD_6_--output-prefix__", "__FLYTE_CMD_DUMMY_VALUE__"]) -cmd.extend(["--__FLYTE_CMD_7_s3://fake-bucket__", "__FLYTE_CMD_DUMMY_VALUE__"]) -cmd.extend(["--__FLYTE_CMD_8_--inputs__", "__FLYTE_CMD_DUMMY_VALUE__"]) -cmd.extend(["--__FLYTE_CMD_9_s3://fake-bucket__", "__FLYTE_CMD_DUMMY_VALUE__"]) - - -@mock.patch.dict("os.environ") -@mock.patch("subprocess.run") -def test(mock_subprocess_run): - _flyte_sagemaker_run(cmd) - assert "env1" in os.environ - assert "env2" in os.environ - assert os.environ["env1"] == "val1" - assert os.environ["env2"] == "val2" - mock_subprocess_run.assert_called_with( - "service_venv pyflyte-execute --task-module blah --task-name bloh " - "--output-prefix s3://fake-bucket --inputs s3://fake-bucket".split(), - stdout=sys.stdout, - stderr=sys.stderr, - encoding="utf-8", - check=True, - ) diff --git a/plugins/flytekit-aws-sagemaker/tests/test_hpo.py b/plugins/flytekit-aws-sagemaker/tests/test_hpo.py deleted file mode 100644 index e52994c664..0000000000 --- a/plugins/flytekit-aws-sagemaker/tests/test_hpo.py +++ /dev/null @@ -1,124 +0,0 @@ -import os -import pathlib -import tempfile - -import pytest -from flytekitplugins.awssagemaker.hpo import ( - HPOJob, - HPOTuningJobConfigTransformer, - ParameterRangesTransformer, - SagemakerHPOTask, -) -from flytekitplugins.awssagemaker.models.hpo_job import ( - HyperparameterTuningJobConfig, - HyperparameterTuningObjective, - HyperparameterTuningObjectiveType, - TrainingJobEarlyStoppingType, -) -from flytekitplugins.awssagemaker.models.parameter_ranges import IntegerParameterRange, ParameterRangeOneOf -from flytekitplugins.awssagemaker.models.training_job import ( - AlgorithmName, - AlgorithmSpecification, - TrainingJobResourceConfig, -) -from flytekitplugins.awssagemaker.training import SagemakerBuiltinAlgorithmsTask, SagemakerTrainingJobConfig - -from flytekit import FlyteContext -from flytekit.models.types import LiteralType, SimpleType - -from .test_training import _get_reg_settings 
- - -def test_hpo_for_builtin(): - trainer = SagemakerBuiltinAlgorithmsTask( - name="builtin-trainer", - task_config=SagemakerTrainingJobConfig( - training_job_resource_config=TrainingJobResourceConfig( - instance_count=1, - instance_type="ml-xlarge", - volume_size_in_gb=1, - ), - algorithm_specification=AlgorithmSpecification( - algorithm_name=AlgorithmName.XGBOOST, - ), - ), - ) - - hpo = SagemakerHPOTask( - name="test", - task_config=HPOJob(10, 10, ["x"]), - training_task=trainer, - ) - - assert hpo.python_interface.inputs.keys() == { - "static_hyperparameters", - "train", - "validation", - "hyperparameter_tuning_job_config", - "x", - } - assert hpo.python_interface.outputs.keys() == {"model"} - - assert hpo.get_custom(_get_reg_settings()) == { - "maxNumberOfTrainingJobs": "10", - "maxParallelTrainingJobs": "10", - "trainingJob": { - "algorithmSpecification": {"algorithmName": "XGBOOST"}, - "trainingJobResourceConfig": {"instanceCount": "1", "instanceType": "ml-xlarge", "volumeSizeInGb": "1"}, - }, - } - - with pytest.raises(NotImplementedError): - with tempfile.TemporaryDirectory() as tmp: - x = pathlib.Path(os.path.join(tmp, "x")) - y = pathlib.Path(os.path.join(tmp, "y")) - x.mkdir(parents=True, exist_ok=True) - y.mkdir(parents=True, exist_ok=True) - - hpo( - static_hyperparameters={}, - train=f"{x}", # file transformer doesn't handle pathlib.Path yet - validation=f"{y}", # file transformer doesn't handle pathlib.Path yet - hyperparameter_tuning_job_config=HyperparameterTuningJobConfig( - tuning_strategy=1, - tuning_objective=HyperparameterTuningObjective( - objective_type=HyperparameterTuningObjectiveType.MINIMIZE, - metric_name="x", - ), - training_job_early_stopping_type=TrainingJobEarlyStoppingType.OFF, - ), - x=ParameterRangeOneOf(param=IntegerParameterRange(10, 1, 1)), - ) - - -def test_hpoconfig_transformer(): - t = HPOTuningJobConfigTransformer() - assert t.get_literal_type(HyperparameterTuningJobConfig) == LiteralType(simple=SimpleType.STRUCT) - o = 
HyperparameterTuningJobConfig( - tuning_strategy=1, - tuning_objective=HyperparameterTuningObjective( - objective_type=HyperparameterTuningObjectiveType.MINIMIZE, - metric_name="x", - ), - training_job_early_stopping_type=TrainingJobEarlyStoppingType.OFF, - ) - ctx = FlyteContext.current_context() - lit = t.to_literal(ctx, python_val=o, python_type=HyperparameterTuningJobConfig, expected=None) - assert lit is not None - assert lit.scalar.generic is not None - ro = t.to_python_value(ctx, lit, HyperparameterTuningJobConfig) - assert ro is not None - assert ro == o - - -def test_parameter_ranges_transformer(): - t = ParameterRangesTransformer() - assert t.get_literal_type(ParameterRangeOneOf) == LiteralType(simple=SimpleType.STRUCT) - o = ParameterRangeOneOf(param=IntegerParameterRange(10, 0, 1)) - ctx = FlyteContext.current_context() - lit = t.to_literal(ctx, python_val=o, python_type=ParameterRangeOneOf, expected=None) - assert lit is not None - assert lit.scalar.generic is not None - ro = t.to_python_value(ctx, lit, ParameterRangeOneOf) - assert ro is not None - assert ro == o diff --git a/plugins/flytekit-aws-sagemaker/tests/test_hpo_job.py b/plugins/flytekit-aws-sagemaker/tests/test_hpo_job.py deleted file mode 100644 index 494eecd2ab..0000000000 --- a/plugins/flytekit-aws-sagemaker/tests/test_hpo_job.py +++ /dev/null @@ -1,79 +0,0 @@ -from flytekitplugins.awssagemaker.models import hpo_job, training_job - - -def test_hyperparameter_tuning_objective(): - obj = hpo_job.HyperparameterTuningObjective( - objective_type=hpo_job.HyperparameterTuningObjectiveType.MAXIMIZE, metric_name="test_metric" - ) - obj2 = hpo_job.HyperparameterTuningObjective.from_flyte_idl(obj.to_flyte_idl()) - - assert obj == obj2 - - -def test_hyperparameter_job_config(): - jc = hpo_job.HyperparameterTuningJobConfig( - tuning_strategy=hpo_job.HyperparameterTuningStrategy.BAYESIAN, - tuning_objective=hpo_job.HyperparameterTuningObjective( - 
objective_type=hpo_job.HyperparameterTuningObjectiveType.MAXIMIZE, metric_name="test_metric" - ), - training_job_early_stopping_type=hpo_job.TrainingJobEarlyStoppingType.AUTO, - ) - - jc2 = hpo_job.HyperparameterTuningJobConfig.from_flyte_idl(jc.to_flyte_idl()) - assert jc2.tuning_strategy == jc.tuning_strategy - assert jc2.tuning_objective == jc.tuning_objective - assert jc2.training_job_early_stopping_type == jc.training_job_early_stopping_type - - -def test_hyperparameter_tuning_job(): - rc = training_job.TrainingJobResourceConfig( - instance_type="test_type", - instance_count=10, - volume_size_in_gb=25, - distributed_protocol=training_job.DistributedProtocol.MPI, - ) - alg = training_job.AlgorithmSpecification( - algorithm_name=training_job.AlgorithmName.CUSTOM, - algorithm_version="", - input_mode=training_job.InputMode.FILE, - input_content_type=training_job.InputContentType.TEXT_CSV, - ) - tj = training_job.TrainingJob( - training_job_resource_config=rc, - algorithm_specification=alg, - ) - hpo = hpo_job.HyperparameterTuningJob(max_number_of_training_jobs=10, max_parallel_training_jobs=5, training_job=tj) - - hpo2 = hpo_job.HyperparameterTuningJob.from_flyte_idl(hpo.to_flyte_idl()) - - assert hpo.max_number_of_training_jobs == hpo2.max_number_of_training_jobs - assert hpo.max_parallel_training_jobs == hpo2.max_parallel_training_jobs - assert ( - hpo2.training_job.training_job_resource_config.instance_type - == hpo.training_job.training_job_resource_config.instance_type - ) - assert ( - hpo2.training_job.training_job_resource_config.instance_count - == hpo.training_job.training_job_resource_config.instance_count - ) - assert ( - hpo2.training_job.training_job_resource_config.distributed_protocol - == hpo.training_job.training_job_resource_config.distributed_protocol - ) - assert ( - hpo2.training_job.training_job_resource_config.volume_size_in_gb - == hpo.training_job.training_job_resource_config.volume_size_in_gb - ) - assert ( - 
hpo2.training_job.algorithm_specification.algorithm_name - == hpo.training_job.algorithm_specification.algorithm_name - ) - assert ( - hpo2.training_job.algorithm_specification.algorithm_version - == hpo.training_job.algorithm_specification.algorithm_version - ) - assert hpo2.training_job.algorithm_specification.input_mode == hpo.training_job.algorithm_specification.input_mode - assert ( - hpo2.training_job.algorithm_specification.input_content_type - == hpo.training_job.algorithm_specification.input_content_type - ) diff --git a/plugins/flytekit-aws-sagemaker/tests/test_parameter_ranges.py b/plugins/flytekit-aws-sagemaker/tests/test_parameter_ranges.py deleted file mode 100644 index 6d33388c33..0000000000 --- a/plugins/flytekit-aws-sagemaker/tests/test_parameter_ranges.py +++ /dev/null @@ -1,93 +0,0 @@ -import unittest - -import pytest -from flytekitplugins.awssagemaker.models import parameter_ranges - - -# assert statements cannot be written inside lambda expressions. This is a convenient function to work around that. 
-def assert_equal(a, b): - assert a == b - - -def test_continuous_parameter_range(): - pr = parameter_ranges.ContinuousParameterRange( - max_value=10, min_value=0.5, scaling_type=parameter_ranges.HyperparameterScalingType.REVERSELOGARITHMIC - ) - - pr2 = parameter_ranges.ContinuousParameterRange.from_flyte_idl(pr.to_flyte_idl()) - assert pr == pr2 - assert type(pr2.max_value) == float - assert type(pr2.min_value) == float - assert pr2.max_value == 10.0 - assert pr2.min_value == 0.5 - assert pr2.scaling_type == parameter_ranges.HyperparameterScalingType.REVERSELOGARITHMIC - - -def test_integer_parameter_range(): - pr = parameter_ranges.IntegerParameterRange( - max_value=1, min_value=0, scaling_type=parameter_ranges.HyperparameterScalingType.LOGARITHMIC - ) - - pr2 = parameter_ranges.IntegerParameterRange.from_flyte_idl(pr.to_flyte_idl()) - assert pr == pr2 - assert type(pr2.max_value) == int - assert type(pr2.min_value) == int - assert pr2.max_value == 1 - assert pr2.min_value == 0 - assert pr2.scaling_type == parameter_ranges.HyperparameterScalingType.LOGARITHMIC - - -def test_categorical_parameter_range(): - case = unittest.TestCase() - pr = parameter_ranges.CategoricalParameterRange(values=["abc", "cat"]) - - pr2 = parameter_ranges.CategoricalParameterRange.from_flyte_idl(pr.to_flyte_idl()) - assert pr == pr2 - assert isinstance(pr2.values, list) - case.assertCountEqual(pr2.values, pr.values) - - -def test_parameter_ranges(): - pr = parameter_ranges.ParameterRanges( - { - "a": parameter_ranges.CategoricalParameterRange(values=["a-1", "a-2"]), - "b": parameter_ranges.IntegerParameterRange( - min_value=1, max_value=5, scaling_type=parameter_ranges.HyperparameterScalingType.LINEAR - ), - "c": parameter_ranges.ContinuousParameterRange( - min_value=0.1, max_value=1.0, scaling_type=parameter_ranges.HyperparameterScalingType.LOGARITHMIC - ), - }, - ) - pr2 = parameter_ranges.ParameterRanges.from_flyte_idl(pr.to_flyte_idl()) - assert pr == pr2 - - -LIST_OF_PARAMETERS = [ 
- ( - parameter_ranges.IntegerParameterRange( - min_value=1, max_value=5, scaling_type=parameter_ranges.HyperparameterScalingType.LINEAR - ), - lambda param: assert_equal(param.integer_parameter_range.max_value, 5), - ), - ( - parameter_ranges.ContinuousParameterRange( - min_value=0.1, max_value=1.0, scaling_type=parameter_ranges.HyperparameterScalingType.LOGARITHMIC - ), - lambda param: assert_equal(param.continuous_parameter_range.max_value, 1), - ), - ( - parameter_ranges.CategoricalParameterRange(values=["a-1", "a-2"]), - lambda param: assert_equal(len(param.categorical_parameter_range.values), 2), - ), -] - - -@pytest.mark.parametrize("param_tuple", LIST_OF_PARAMETERS) -def test_parameter_ranges_oneof(param_tuple): - param, assertion = param_tuple - oneof = parameter_ranges.ParameterRangeOneOf(param=param) - oneof2 = parameter_ranges.ParameterRangeOneOf.from_flyte_idl(oneof.to_flyte_idl()) - assert oneof2 == oneof - assertion(oneof) - assertion(oneof2) diff --git a/plugins/flytekit-aws-sagemaker/tests/test_training.py b/plugins/flytekit-aws-sagemaker/tests/test_training.py deleted file mode 100644 index 4d33a9e4bb..0000000000 --- a/plugins/flytekit-aws-sagemaker/tests/test_training.py +++ /dev/null @@ -1,134 +0,0 @@ -import os -import pathlib -import tempfile - -import pytest -from flytekitplugins.awssagemaker.distributed_training import setup_envars_for_testing -from flytekitplugins.awssagemaker.models.training_job import ( - AlgorithmName, - AlgorithmSpecification, - DistributedProtocol, - TrainingJobResourceConfig, -) -from flytekitplugins.awssagemaker.training import SagemakerBuiltinAlgorithmsTask, SagemakerTrainingJobConfig - -import flytekit -from flytekit import task -from flytekit.configuration import Image, ImageConfig, SerializationSettings -from flytekit.core.context_manager import ExecutionParameters - - -def _get_reg_settings(): - default_img = Image(name="default", fqn="test", tag="tag") - settings = SerializationSettings( - project="project", - 
domain="domain", - version="version", - env={"FOO": "baz"}, - image_config=ImageConfig(default_image=default_img, images=[default_img]), - ) - return settings - - -def test_builtin_training(): - trainer = SagemakerBuiltinAlgorithmsTask( - name="builtin-trainer", - task_config=SagemakerTrainingJobConfig( - training_job_resource_config=TrainingJobResourceConfig( - instance_count=1, - instance_type="ml-xlarge", - volume_size_in_gb=1, - ), - algorithm_specification=AlgorithmSpecification( - algorithm_name=AlgorithmName.XGBOOST, - ), - ), - ) - - assert trainer.python_interface.inputs.keys() == {"static_hyperparameters", "train", "validation"} - assert trainer.python_interface.outputs.keys() == {"model"} - - with tempfile.TemporaryDirectory() as tmp: - x = pathlib.Path(os.path.join(tmp, "x")) - y = pathlib.Path(os.path.join(tmp, "y")) - x.mkdir(parents=True, exist_ok=True) - y.mkdir(parents=True, exist_ok=True) - with pytest.raises(NotImplementedError): - # Type engine doesn't support pathlib.Path yet - trainer(static_hyperparameters={}, train=f"{x}", validation=f"{y}") - - assert trainer.get_custom(_get_reg_settings()) == { - "algorithmSpecification": {"algorithmName": "XGBOOST"}, - "trainingJobResourceConfig": {"instanceCount": "1", "instanceType": "ml-xlarge", "volumeSizeInGb": "1"}, - } - - -def test_custom_training(): - @task( - task_config=SagemakerTrainingJobConfig( - training_job_resource_config=TrainingJobResourceConfig( - instance_type="ml-xlarge", - volume_size_in_gb=1, - ), - algorithm_specification=AlgorithmSpecification( - algorithm_name=AlgorithmName.CUSTOM, - ), - ) - ) - def my_custom_trainer(x: int) -> int: - return x - - assert my_custom_trainer.python_interface.inputs == {"x": int} - assert my_custom_trainer.python_interface.outputs == {"o0": int} - - assert my_custom_trainer(x=10) == 10 - - assert my_custom_trainer.get_custom(_get_reg_settings()) == { - "algorithmSpecification": {}, - "trainingJobResourceConfig": {"instanceCount": "1", 
"instanceType": "ml-xlarge", "volumeSizeInGb": "1"}, - } - - -def test_distributed_custom_training(): - setup_envars_for_testing() - - @task( - task_config=SagemakerTrainingJobConfig( - training_job_resource_config=TrainingJobResourceConfig( - instance_type="ml-xlarge", - volume_size_in_gb=1, - instance_count=2, # Indicates distributed training - distributed_protocol=DistributedProtocol.MPI, - ), - algorithm_specification=AlgorithmSpecification( - algorithm_name=AlgorithmName.CUSTOM, - ), - ) - ) - def my_custom_trainer(x: int) -> int: - assert flytekit.current_context().distributed_training_context is not None - return x - - assert my_custom_trainer.python_interface.inputs == {"x": int} - assert my_custom_trainer.python_interface.outputs == {"o0": int} - - assert my_custom_trainer(x=10) == 10 - - assert my_custom_trainer._is_distributed() is True - - pb = ExecutionParameters.new_builder() - pb.working_dir = "/tmp" - p = pb.build() - new_p = my_custom_trainer.pre_execute(p) - assert new_p is not None - assert new_p.has_attr("distributed_training_context") - - assert my_custom_trainer.get_custom(_get_reg_settings()) == { - "algorithmSpecification": {}, - "trainingJobResourceConfig": { - "distributedProtocol": "MPI", - "instanceCount": "2", - "instanceType": "ml-xlarge", - "volumeSizeInGb": "1", - }, - } diff --git a/plugins/flytekit-aws-sagemaker/tests/test_training_job.py b/plugins/flytekit-aws-sagemaker/tests/test_training_job.py deleted file mode 100644 index 8774857b1f..0000000000 --- a/plugins/flytekit-aws-sagemaker/tests/test_training_job.py +++ /dev/null @@ -1,87 +0,0 @@ -import unittest - -from flytekitplugins.awssagemaker.models import training_job - - -def test_training_job_resource_config(): - rc = training_job.TrainingJobResourceConfig( - instance_count=1, - instance_type="random.instance", - volume_size_in_gb=25, - distributed_protocol=training_job.DistributedProtocol.MPI, - ) - - rc2 = 
training_job.TrainingJobResourceConfig.from_flyte_idl(rc.to_flyte_idl()) - assert rc2 == rc - assert rc2.distributed_protocol == training_job.DistributedProtocol.MPI - assert rc != training_job.TrainingJobResourceConfig( - instance_count=1, - instance_type="random.instance", - volume_size_in_gb=25, - distributed_protocol=training_job.DistributedProtocol.UNSPECIFIED, - ) - - assert rc != training_job.TrainingJobResourceConfig( - instance_count=1, - instance_type="oops", - volume_size_in_gb=25, - distributed_protocol=training_job.DistributedProtocol.MPI, - ) - - -def test_metric_definition(): - md = training_job.MetricDefinition(name="test-metric", regex="[a-zA-Z]*") - - md2 = training_job.MetricDefinition.from_flyte_idl(md.to_flyte_idl()) - assert md == md2 - assert md2.name == "test-metric" - assert md2.regex == "[a-zA-Z]*" - - -def test_algorithm_specification(): - case = unittest.TestCase() - alg_spec = training_job.AlgorithmSpecification( - algorithm_name=training_job.AlgorithmName.CUSTOM, - algorithm_version="v100", - input_mode=training_job.InputMode.FILE, - metric_definitions=[training_job.MetricDefinition(name="a", regex="b")], - input_content_type=training_job.InputContentType.TEXT_CSV, - ) - - alg_spec2 = training_job.AlgorithmSpecification.from_flyte_idl(alg_spec.to_flyte_idl()) - - assert alg_spec2.algorithm_name == training_job.AlgorithmName.CUSTOM - assert alg_spec2.algorithm_version == "v100" - assert alg_spec2.input_mode == training_job.InputMode.FILE - case.assertCountEqual(alg_spec.metric_definitions, alg_spec2.metric_definitions) - assert alg_spec == alg_spec2 - - -def test_training_job(): - rc = training_job.TrainingJobResourceConfig( - instance_type="test_type", - instance_count=10, - volume_size_in_gb=25, - distributed_protocol=training_job.DistributedProtocol.MPI, - ) - alg = training_job.AlgorithmSpecification( - algorithm_name=training_job.AlgorithmName.CUSTOM, - algorithm_version="", - input_mode=training_job.InputMode.FILE, - 
input_content_type=training_job.InputContentType.TEXT_CSV, - ) - tj = training_job.TrainingJob( - training_job_resource_config=rc, - algorithm_specification=alg, - ) - - tj2 = training_job.TrainingJob.from_flyte_idl(tj.to_flyte_idl()) - # checking tj == tj2 would return false because we don't have the __eq__ magic method defined - assert tj.training_job_resource_config.instance_type == tj2.training_job_resource_config.instance_type - assert tj.training_job_resource_config.instance_count == tj2.training_job_resource_config.instance_count - assert tj.training_job_resource_config.distributed_protocol == tj2.training_job_resource_config.distributed_protocol - assert tj.training_job_resource_config.volume_size_in_gb == tj2.training_job_resource_config.volume_size_in_gb - assert tj.algorithm_specification.algorithm_name == tj2.algorithm_specification.algorithm_name - assert tj.algorithm_specification.algorithm_version == tj2.algorithm_specification.algorithm_version - assert tj.algorithm_specification.input_mode == tj2.algorithm_specification.input_mode - assert tj.algorithm_specification.input_content_type == tj2.algorithm_specification.input_content_type From 56277a6ba8e10f39128e267cfd35c790f6f9f952 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Fri, 29 Dec 2023 19:53:52 +0530 Subject: [PATCH 004/120] add deployment workflow Signed-off-by: Samhita Alla --- plugins/flytekit-aws-sagemaker/README.md | 4 +- .../awssagemaker/agents/boto3_agent.py | 4 ++ .../agents/sagemaker_deploy_agents.py | 11 +++- .../flytekitplugins/awssagemaker/task.py | 4 +- .../flytekitplugins/awssagemaker/workflow.py | 56 +++++++++++++++++++ 5 files changed, 73 insertions(+), 6 deletions(-) create mode 100644 plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py diff --git a/plugins/flytekit-aws-sagemaker/README.md b/plugins/flytekit-aws-sagemaker/README.md index 02abb4de49..9705edbcaa 100644 --- a/plugins/flytekit-aws-sagemaker/README.md +++ 
b/plugins/flytekit-aws-sagemaker/README.md @@ -1,6 +1,6 @@ -# Flytekit AWS Sagemaker Plugin +# AWS Sagemaker Plugin -Amazon SageMaker provides several built-in machine learning algorithms that you can use for a variety of problem types. Flyte Sagemaker plugin intends to greatly simplify using Sagemaker for training. We have tried to distill the API into a meaningful subset that makes it easier for users to adopt and run with Sagemaker. +The plugin includes a deployment agent that allows you to deploy Sagemaker models, create, and inkoke endpoints for inference. To install the plugin, run the following command: diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto3_agent.py index 95aaa4fa54..0a80589819 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto3_agent.py @@ -7,6 +7,7 @@ from flytekit.core.external_api_task import ExternalApiTask from flytekit.core.type_engine import TypeEngine from flytekit.models.literals import LiteralMap +from flytekit.extend.backend.base_agent import get_agent_secret from .boto3_mixin import Boto3AgentMixin @@ -34,6 +35,9 @@ def do( task_template=task_template, additional_args=additional_args, region=region, + aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), + aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), + aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), ) ctx = FlyteContextManager.current_context() diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_deploy_agents.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_deploy_agents.py index d97f01944b..62b801c963 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_deploy_agents.py +++ 
b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_deploy_agents.py @@ -15,9 +15,13 @@ from flytekit import FlyteContextManager from flytekit.core.external_api_task import ExternalApiTask from flytekit.core.type_engine import TypeEngine -from flytekit.extend.backend.base_agent import AgentBase, convert_to_flyte_state +from flytekit.extend.backend.base_agent import ( + AgentBase, + convert_to_flyte_state, + get_agent_secret, + AgentRegistry, +) from flytekit.models.literals import LiteralMap -from flytekit.extend.backend.base_agent import get_agent_secret from .boto3_mixin import Boto3AgentMixin @@ -206,3 +210,6 @@ def do( } ).to_flyte_idl() return CreateTaskResponse(resource=Resource(state=SUCCEEDED, outputs=outputs)) + + +AgentRegistry.register(SagemakerEndpointAgent()) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py index 445bbd0884..0bcf7b23df 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Optional +from typing import Any, Optional, Type from google.protobuf import json_format from google.protobuf.struct_pb2 import Struct @@ -23,7 +23,7 @@ def __init__( self, name: str, task_config: SagemakerEndpointConfig, - inputs: Optional[dict[str, Any]] = None, + inputs: Optional[dict[str, Type]] = None, **kwargs, ): super().__init__( diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py new file mode 100644 index 0000000000..a5f7097999 --- /dev/null +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py @@ -0,0 +1,56 @@ +from flytekit import Workflow, kwtypes +from .agents.sagemaker_deploy_agents import ( + SagemakerModelTask, + 
SagemakerEndpointConfigTask, +) +from .task import SagemakerEndpointTask +from typing import Any, Optional + + +def create_sagemaker_deployment( + model_name: str, + region: str, + model_config: dict[str, Any], + endpoint_config_config: dict[str, Any], + endpoint_config: dict[str, Any], + model_additional_args: Optional[dict[str, Any]] = None, + endpoint_config_additional_args: Optional[dict[str, Any]] = None, +): + sagemaker_model_task = SagemakerModelTask( + name=f"sagemaker-model-{model_name}", + config=model_config, + region=region, + ) + + endpoint_config_task = SagemakerEndpointConfigTask( + name=f"sagemaker-endpoint-config-{model_name}", + config=endpoint_config_config, + region=region, + ) + + endpoint_task = SagemakerEndpointTask( + name=f"sagemaker-endpoint-{model_name}", + task_config=endpoint_config, + inputs=kwtypes(inputs=dict), + ) + + wf = Workflow(name=f"sagemaker-deploy-{model_name}") + wf.add_workflow_input("model_inputs", dict) + wf.add_workflow_input("endpoint_config_inputs", dict) + wf.add_workflow_input("endpoint_inputs", dict) + + wf.add_entity( + sagemaker_model_task, + inputs=wf.inputs["model_inputs"], + additional_args=model_additional_args, + ) + + wf.add_entity( + endpoint_config_task, + inputs=wf.inputs["endpoint_config_inputs"], + additional_args=endpoint_config_additional_args, + ) + + wf.add_entity(endpoint_task, inputs=wf.inputs["endpoint_inputs"]) + + return wf From 53b9fce15a4e44754e878137a9329c091f73e203 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 2 Jan 2024 16:17:30 +0530 Subject: [PATCH 005/120] add a workflow to delete sagemaker deployment Signed-off-by: Samhita Alla --- .../agents/sagemaker_deploy_agents.py | 90 +++++++++++++++++-- .../flytekitplugins/awssagemaker/workflow.py | 43 +++++++++ 2 files changed, 126 insertions(+), 7 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_deploy_agents.py 
b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_deploy_agents.py index 62b801c963..f8926b563a 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_deploy_agents.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_deploy_agents.py @@ -1,6 +1,6 @@ import json from dataclasses import asdict, dataclass -from typing import Any, Optional +from typing import Any, Optional, Type import grpc from flyteidl.admin.agent_pb2 import ( @@ -63,8 +63,8 @@ def do( "o0": TypeEngine.to_literal( ctx, result, - type(result), - TypeEngine.to_literal_type(type(result)), + dict[str, str], + TypeEngine.to_literal_type(dict[str, str]), ) } ).to_flyte_idl() @@ -101,8 +101,8 @@ def do( "o0": TypeEngine.to_literal( ctx, result, - type(result), - TypeEngine.to_literal_type(type(result)), + dict[str, str], + TypeEngine.to_literal_type(dict[str, str]), ) } ).to_flyte_idl() @@ -183,6 +183,7 @@ def __init__(self, name: str, config: dict[str, Any], region: Optional[str] = No def do( self, task_template: TaskTemplate, + output_result_type: Type, inputs: Optional[LiteralMap] = None, additional_args: Optional[dict[str, Any]] = None, ) -> CreateTaskResponse: @@ -204,12 +205,87 @@ def do( "o0": TypeEngine.to_literal( ctx, result, - type(result), - TypeEngine.to_literal_type(type(result)), + dict[str, output_result_type], + TypeEngine.to_literal_type(dict[str, output_result_type]), ) } ).to_flyte_idl() return CreateTaskResponse(resource=Resource(state=SUCCEEDED, outputs=outputs)) +class SagemakerDeleteEndpointTask(Boto3AgentMixin, ExternalApiTask): + """This agent deletes the Sagemaker model.""" + + def __init__(self, name: str, config: dict[str, Any], region: Optional[str] = None, **kwargs): + super().__init__(service="sagemaker-runtime", region=region, name=name, config=config, **kwargs) + + def do( + self, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + ) -> CreateTaskResponse: + 
inputs = inputs or LiteralMap(literals={}) + + self._call( + method="delete_endpoint", + inputs=inputs, + config=task_template.custom["task_config"], + aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), + aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), + aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), + ) + + return CreateTaskResponse(resource=Resource(state=SUCCEEDED, outputs=None)) + + +class SagemakerDeleteEndpointConfigTask(Boto3AgentMixin, ExternalApiTask): + """This agent deletes the endpoint config.""" + + def __init__(self, name: str, config: dict[str, Any], region: Optional[str] = None, **kwargs): + super().__init__(service="sagemaker-runtime", region=region, name=name, config=config, **kwargs) + + def do( + self, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + ) -> CreateTaskResponse: + inputs = inputs or LiteralMap(literals={}) + + self._call( + method="delete_endpoint_config", + inputs=inputs, + config=task_template.custom["task_config"], + aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), + aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), + aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), + ) + + return CreateTaskResponse(resource=Resource(state=SUCCEEDED, outputs=None)) + + +class SagemakerDeleteModelTask(Boto3AgentMixin, ExternalApiTask): + """This agent deletes an endpoint.""" + + def __init__(self, name: str, config: dict[str, Any], region: Optional[str] = None, **kwargs): + super().__init__(service="sagemaker-runtime", region=region, name=name, config=config, **kwargs) + + def do( + self, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + ) -> CreateTaskResponse: + inputs = inputs or LiteralMap(literals={}) + + self._call( + method="delete_model", + inputs=inputs, + config=task_template.custom["task_config"], + aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), + 
aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), + aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), + ) + + return CreateTaskResponse(resource=Resource(state=SUCCEEDED, outputs=None)) + + AgentRegistry.register(SagemakerEndpointAgent()) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py index a5f7097999..a49b6b0f53 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py @@ -2,7 +2,12 @@ from .agents.sagemaker_deploy_agents import ( SagemakerModelTask, SagemakerEndpointConfigTask, + SagemakerDeleteEndpointTask, + SagemakerDeleteEndpointConfigTask, + SagemakerDeleteModelTask, ) +from flytekit.models import literals + from .task import SagemakerEndpointTask from typing import Any, Optional @@ -54,3 +59,41 @@ def create_sagemaker_deployment( wf.add_entity(endpoint_task, inputs=wf.inputs["endpoint_inputs"]) return wf + + +def delete_sagemaker_deployment(name: str): + sagemaker_delete_endpoint_task = SagemakerDeleteEndpointTask( + name=f"sagemaker-delete-endpoint-{name}", + config={"EndpointName": "{endpoint_name}"}, + inputs=kwtypes(inputs=dict), + ) + sagemaker_delete_endpoint_config_task = SagemakerDeleteEndpointConfigTask( + name=f"sagemaker-delete-endpoint-config-{name}", + config={"EndpointConfigName": "{endpoint_config_name}"}, + inputs=kwtypes(inputs=dict), + ) + sagemaker_delete_model_task = SagemakerDeleteModelTask( + name=f"sagemaker-delete-model-{name}", + config={"ModelName": "{model_name}"}, + inputs=kwtypes(inputs=dict), + ) + + wf = Workflow(name=f"sagemaker-delete-endpoint-wf-{name}") + wf.add_workflow_input("endpoint_name", str) + wf.add_workflow_input("endpoint_config_name", str) + wf.add_workflow_input("model_name", str) + + wf.add_entity( + sagemaker_delete_endpoint_task, + 
inputs=literals.LiteralMap({"endpoint_name": wf.inputs["endpoint_name"]}), + ) + wf.add_entity( + sagemaker_delete_endpoint_config_task, + inputs=literals.LiteralMap({"endpoint_config_name": wf.inputs["endpoint_config_name"]}), + ) + wf.add_entity( + sagemaker_delete_model_task, + inputs=literals.LiteralMap({"model_name", wf.inputs["model_name"]}), + ) + + return wf From c64da88fb2dde89fa81d0e819d83936a129a8b54 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 2 Jan 2024 23:13:36 +0530 Subject: [PATCH 006/120] clean up Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/__init__.py | 27 -- .../flytekitplugins/awssagemaker/agent.py | 90 ++++++ .../agents/sagemaker_deploy_agents.py | 291 ------------------ .../{agents => boto3}/__init__.py | 0 .../{agents/boto3_agent.py => boto3/agent.py} | 38 +-- .../{agents/boto3_mixin.py => boto3/mixin.py} | 8 +- .../flytekitplugins/awssagemaker/workflow.py | 96 ++++-- 7 files changed, 187 insertions(+), 363 deletions(-) create mode 100644 plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py delete mode 100644 plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_deploy_agents.py rename plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/{agents => boto3}/__init__.py (100%) rename plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/{agents/boto3_agent.py => boto3/agent.py} (64%) rename plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/{agents/boto3_mixin.py => boto3/mixin.py} (95%) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py index 6dce099dfb..b008852a7e 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py @@ -56,30 +56,3 @@ "TrainingJobEarlyStoppingType", "TrainingJobResourceConfig", ] - -from 
flytekitplugins.awssagemaker.models.hpo_job import ( - HyperparameterTuningJobConfig, - HyperparameterTuningObjective, - HyperparameterTuningObjectiveType, - HyperparameterTuningStrategy, - TrainingJobEarlyStoppingType, -) -from flytekitplugins.awssagemaker.models.parameter_ranges import ( - CategoricalParameterRange, - ContinuousParameterRange, - HyperparameterScalingType, - IntegerParameterRange, - ParameterRangeOneOf, -) -from flytekitplugins.awssagemaker.models.training_job import ( - AlgorithmName, - AlgorithmSpecification, - DistributedProtocol, - InputContentType, - InputMode, - TrainingJobResourceConfig, -) - -from .distributed_training import DISTRIBUTED_TRAINING_CONTEXT_KEY, DistributedTrainingContext -from .hpo import HPOJob, SagemakerHPOTask -from .training import SagemakerBuiltinAlgorithmsTask, SagemakerCustomTrainingTask, SagemakerTrainingJobConfig diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py new file mode 100644 index 0000000000..9f505fc602 --- /dev/null +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py @@ -0,0 +1,90 @@ +import json +from dataclasses import asdict, dataclass +from typing import Optional + +import grpc +from flyteidl.admin.agent_pb2 import ( + CreateTaskResponse, + DeleteTaskResponse, + GetTaskResponse, + Resource, +) +from flyteidl.core.tasks_pb2 import TaskTemplate + +from flytekit.extend.backend.base_agent import ( + AgentBase, + convert_to_flyte_state, + get_agent_secret, +) +from flytekit.models.literals import LiteralMap + +from .boto3.mixin import Boto3AgentMixin + + +@dataclass +class Metadata: + endpoint_name: str + region: str + + +class SagemakerEndpointAgent(Boto3AgentMixin, AgentBase): + """This agent creates an endpoint.""" + + def __init__(self): + super().__init__( + service="sagemaker", + task_type="sagemaker-endpoint", + ) + + async def async_create( + self, + context: 
grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + ) -> CreateTaskResponse: + custom = task_template.custom + config = custom["config"] + region = custom["region"] + + await self._call( + "create_endpoint", + config=config, + task_template=task_template, + inputs=inputs, + region=region, + aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), + aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), + aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), + ) + + metadata = Metadata(endpoint_name=config["EndpointName"], region=region) + return CreateTaskResponse(resource_meta=json.dumps(asdict(metadata)).encode("utf-8")) + + async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: + metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) + + endpoint_status = await self._call( + "describe_endpoint", + config={"EndpointName": metadata.endpoint_name}, + ) + + current_state = endpoint_status.get("EndpointStatus") + message = "" + if current_state in ("Failed", "UpdateRollbackFailed"): + message = endpoint_status.get("FailureReason") + + # THIS WON'T WORK. NEED TO FIX THIS. 
+ flyte_state = convert_to_flyte_state(current_state) + + return GetTaskResponse(resource=Resource(state=flyte_state, message=message)) + + async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: + metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) + + await self._call( + "delete_endpoint", + config={"EndpointName": metadata.endpoint_name}, + ) + + return DeleteTaskResponse() diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_deploy_agents.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_deploy_agents.py deleted file mode 100644 index f8926b563a..0000000000 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/sagemaker_deploy_agents.py +++ /dev/null @@ -1,291 +0,0 @@ -import json -from dataclasses import asdict, dataclass -from typing import Any, Optional, Type - -import grpc -from flyteidl.admin.agent_pb2 import ( - SUCCEEDED, - CreateTaskResponse, - DeleteTaskResponse, - GetTaskResponse, - Resource, -) -from flyteidl.core.tasks_pb2 import TaskTemplate - -from flytekit import FlyteContextManager -from flytekit.core.external_api_task import ExternalApiTask -from flytekit.core.type_engine import TypeEngine -from flytekit.extend.backend.base_agent import ( - AgentBase, - convert_to_flyte_state, - get_agent_secret, - AgentRegistry, -) -from flytekit.models.literals import LiteralMap - -from .boto3_mixin import Boto3AgentMixin - - -@dataclass -class Metadata: - endpoint_name: str - region: str - - -class SagemakerModelTask(Boto3AgentMixin, ExternalApiTask): - """This agent creates a Sagemaker model.""" - - def __init__(self, name: str, config: dict[str, Any], region: Optional[str] = None, **kwargs): - super().__init__(service="sagemaker-runtime", region=region, name=name, config=config, **kwargs) - - def do( - self, - task_template: TaskTemplate, - inputs: Optional[LiteralMap] = None, - additional_args: 
Optional[dict[str, Any]] = None, - ) -> CreateTaskResponse: - inputs = inputs or LiteralMap(literals={}) - - result = self._call( - method="create_model", - config=task_template.custom["task_config"], - inputs=inputs, - task_template=task_template, - additional_args=additional_args, - aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), - aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), - aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), - ) - - ctx = FlyteContextManager.current_context() - outputs = LiteralMap( - { - "o0": TypeEngine.to_literal( - ctx, - result, - dict[str, str], - TypeEngine.to_literal_type(dict[str, str]), - ) - } - ).to_flyte_idl() - return CreateTaskResponse(resource=Resource(state=SUCCEEDED, outputs=outputs)) - - -class SagemakerEndpointConfigTask(Boto3AgentMixin, ExternalApiTask): - """This agent creates an endpoint config.""" - - def __init__(self, name: str, config: dict[str, Any], region: Optional[str] = None, **kwargs): - super().__init__(service="sagemaker-runtime", region=region, name=name, config=config, **kwargs) - - def do( - self, - task_template: TaskTemplate, - inputs: Optional[LiteralMap] = None, - additional_args: Optional[dict[str, Any]] = None, - ) -> CreateTaskResponse: - inputs = inputs or LiteralMap(literals={}) - - result = self._call( - method="create_endpoint_config", - inputs=inputs, - config=task_template.custom["task_config"], - additional_args=additional_args, - aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), - aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), - aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), - ) - - ctx = FlyteContextManager.current_context() - outputs = LiteralMap( - { - "o0": TypeEngine.to_literal( - ctx, - result, - dict[str, str], - TypeEngine.to_literal_type(dict[str, str]), - ) - } - ).to_flyte_idl() - return CreateTaskResponse(resource=Resource(state=SUCCEEDED, 
outputs=outputs)) - - -class SagemakerEndpointAgent(Boto3AgentMixin, AgentBase): - """This agent creates an endpoint.""" - - def __init__(self, region: str): - super().__init__( - service="sagemaker-runtime", - region=region, - task_type="sagemaker-endpoint", - asynchronous=True, - ) - - async def async_create( - self, - context: grpc.ServicerContext, - output_prefix: str, - task_template: TaskTemplate, - inputs: Optional[LiteralMap] = None, - ) -> CreateTaskResponse: - custom = task_template.custom - config = custom["config"] - region = custom["region"] - - await self._call( - "create_endpoint", - config=config, - task_template=task_template, - inputs=inputs, - region=region, - aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), - aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), - aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), - ) - - metadata = Metadata(endpoint_name=config["EndpointName"], region=region) - return CreateTaskResponse(resource_meta=json.dumps(asdict(metadata)).encode("utf-8")) - - async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: - metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - - endpoint_status = await self._call( - "describe_endpoint", - config={"EndpointName": metadata.endpoint_name}, - ) - - current_state = endpoint_status.get("EndpointStatus") - message = "" - if current_state in ("Failed", "UpdateRollbackFailed"): - message = endpoint_status.get("FailureReason") - - # THIS WON'T WORK. NEED TO FIX THIS. 
- flyte_state = convert_to_flyte_state(current_state) - - return GetTaskResponse(resource=Resource(state=flyte_state, message=message)) - - async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: - metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) - - await self._call( - "delete_endpoint", - config={"EndpointName": metadata.endpoint_name}, - ) - - return DeleteTaskResponse() - - -class SagemakerInvokeEndpointTask(Boto3AgentMixin, ExternalApiTask): - """This agent invokes an endpoint.""" - - def __init__(self, name: str, config: dict[str, Any], region: Optional[str] = None, **kwargs): - super().__init__(service="sagemaker-runtime", region=region, name=name, config=config, **kwargs) - - def do( - self, - task_template: TaskTemplate, - output_result_type: Type, - inputs: Optional[LiteralMap] = None, - additional_args: Optional[dict[str, Any]] = None, - ) -> CreateTaskResponse: - inputs = inputs or LiteralMap(literals={}) - - result = self._call( - method="invoke_endpoint", - inputs=inputs, - config=task_template.custom["task_config"], - additional_args=additional_args, - aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), - aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), - aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), - ) - - ctx = FlyteContextManager.current_context() - outputs = LiteralMap( - { - "o0": TypeEngine.to_literal( - ctx, - result, - dict[str, output_result_type], - TypeEngine.to_literal_type(dict[str, output_result_type]), - ) - } - ).to_flyte_idl() - return CreateTaskResponse(resource=Resource(state=SUCCEEDED, outputs=outputs)) - - -class SagemakerDeleteEndpointTask(Boto3AgentMixin, ExternalApiTask): - """This agent deletes the Sagemaker model.""" - - def __init__(self, name: str, config: dict[str, Any], region: Optional[str] = None, **kwargs): - super().__init__(service="sagemaker-runtime", region=region, name=name, 
config=config, **kwargs) - - def do( - self, - task_template: TaskTemplate, - inputs: Optional[LiteralMap] = None, - ) -> CreateTaskResponse: - inputs = inputs or LiteralMap(literals={}) - - self._call( - method="delete_endpoint", - inputs=inputs, - config=task_template.custom["task_config"], - aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), - aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), - aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), - ) - - return CreateTaskResponse(resource=Resource(state=SUCCEEDED, outputs=None)) - - -class SagemakerDeleteEndpointConfigTask(Boto3AgentMixin, ExternalApiTask): - """This agent deletes the endpoint config.""" - - def __init__(self, name: str, config: dict[str, Any], region: Optional[str] = None, **kwargs): - super().__init__(service="sagemaker-runtime", region=region, name=name, config=config, **kwargs) - - def do( - self, - task_template: TaskTemplate, - inputs: Optional[LiteralMap] = None, - ) -> CreateTaskResponse: - inputs = inputs or LiteralMap(literals={}) - - self._call( - method="delete_endpoint_config", - inputs=inputs, - config=task_template.custom["task_config"], - aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), - aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), - aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), - ) - - return CreateTaskResponse(resource=Resource(state=SUCCEEDED, outputs=None)) - - -class SagemakerDeleteModelTask(Boto3AgentMixin, ExternalApiTask): - """This agent deletes an endpoint.""" - - def __init__(self, name: str, config: dict[str, Any], region: Optional[str] = None, **kwargs): - super().__init__(service="sagemaker-runtime", region=region, name=name, config=config, **kwargs) - - def do( - self, - task_template: TaskTemplate, - inputs: Optional[LiteralMap] = None, - ) -> CreateTaskResponse: - inputs = inputs or LiteralMap(literals={}) - - self._call( - 
method="delete_model", - inputs=inputs, - config=task_template.custom["task_config"], - aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), - aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), - aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), - ) - - return CreateTaskResponse(resource=Resource(state=SUCCEEDED, outputs=None)) - - -AgentRegistry.register(SagemakerEndpointAgent()) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/__init__.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/__init__.py similarity index 100% rename from plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/__init__.py rename to plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/__init__.py diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/agent.py similarity index 64% rename from plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto3_agent.py rename to plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/agent.py index 0a80589819..b99b608a39 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/agent.py @@ -9,27 +9,27 @@ from flytekit.models.literals import LiteralMap from flytekit.extend.backend.base_agent import get_agent_secret -from .boto3_mixin import Boto3AgentMixin +from .mixin import Boto3AgentMixin class SyncBotoAgentTask(Boto3AgentMixin, ExternalApiTask): """A general purpose boto3 agent that can be used to call any boto3 method synchronously.""" - def __init__(self, name: str, config: dict[str, Any], service: str, region: Optional[str] = None, **kwargs): - super().__init__(service=service, region=region, name=name, config=config, **kwargs) + def __init__( + self, name: str, service: str, 
method: str, config: dict[str, Any], region: Optional[str] = None, **kwargs + ): + super().__init__(service=service, method=method, region=region, name=name, config=config, **kwargs) def do( self, task_template: TaskTemplate, - method: str, - output_result_type: Type, + output_result_type: Optional[Type] = None, inputs: Optional[LiteralMap] = None, additional_args: Optional[dict[str, Any]] = None, region: Optional[str] = None, ): inputs = inputs or LiteralMap(literals={}) result = self._call( - method=method, config=task_template.custom["task_config"], inputs=inputs, task_template=task_template, @@ -40,16 +40,18 @@ def do( aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), ) - ctx = FlyteContextManager.current_context() - - outputs = LiteralMap( - { - "o0": TypeEngine.to_literal( - ctx, - result, - output_result_type, - TypeEngine.to_literal_type(output_result_type), - ) - } - ).to_flyte_idl() + outputs = None + if result: + ctx = FlyteContextManager.current_context() + outputs = LiteralMap( + { + "o0": TypeEngine.to_literal( + ctx, + result, + output_result_type, + TypeEngine.to_literal_type(output_result_type), + ) + } + ).to_flyte_idl() + return CreateTaskResponse(resource=Resource(state=SUCCEEDED, outputs=outputs)) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto3_mixin.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/mixin.py similarity index 95% rename from plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto3_mixin.py rename to plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/mixin.py index 8fa8025e19..dff8c6c1ae 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agents/boto3_mixin.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/mixin.py @@ -73,7 +73,7 @@ class Boto3AgentMixin: It provides a single method, `_call`, which can be employed to invoke any boto3 method. 
""" - def __init__(self, *, service: Optional[str] = None, region: Optional[str] = None, **kwargs): + def __init__(self, *, service: str, method: str, region: Optional[str] = None, **kwargs): """ Initialize the Boto3AgentMixin. @@ -82,13 +82,13 @@ def __init__(self, *, service: Optional[str] = None, region: Optional[str] = Non """ self._service = service self._region = region + self._method = method super().__init__(**kwargs) async def _call( self, - method: str, config: dict[str, Any], - task_template: TaskTemplate, + task_template: Optional[TaskTemplate] = None, inputs: Optional[LiteralMap] = None, additional_args: Optional[dict[str, Any]] = None, region: Optional[str] = None, @@ -133,7 +133,7 @@ async def _call( aws_session_token=aws_session_token, ) as client: try: - result = await getattr(client, method)(**updated_config) + result = await getattr(client, self._method)(**updated_config) except Exception as e: raise e diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py index a49b6b0f53..daa7e094c1 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py @@ -1,35 +1,34 @@ -from flytekit import Workflow, kwtypes -from .agents.sagemaker_deploy_agents import ( - SagemakerModelTask, - SagemakerEndpointConfigTask, - SagemakerDeleteEndpointTask, - SagemakerDeleteEndpointConfigTask, - SagemakerDeleteModelTask, -) +from flytekit import Workflow, kwtypes, LaunchPlan +from .boto3.agent import SyncBotoAgentTask + from flytekit.models import literals from .task import SagemakerEndpointTask -from typing import Any, Optional +from typing import Any, Optional, Union, Type def create_sagemaker_deployment( model_name: str, - region: str, model_config: dict[str, Any], endpoint_config_config: dict[str, Any], endpoint_config: dict[str, Any], + region: Optional[str] = None, 
model_additional_args: Optional[dict[str, Any]] = None, endpoint_config_additional_args: Optional[dict[str, Any]] = None, ): - sagemaker_model_task = SagemakerModelTask( + sagemaker_model_task = SyncBotoAgentTask( name=f"sagemaker-model-{model_name}", config=model_config, + service="sagemaker", + method="create_model", region=region, ) - endpoint_config_task = SagemakerEndpointConfigTask( + endpoint_config_task = SyncBotoAgentTask( name=f"sagemaker-endpoint-config-{model_name}", config=endpoint_config_config, + service="sagemaker", + method="create_endpoint_config", region=region, ) @@ -46,36 +45,54 @@ def create_sagemaker_deployment( wf.add_entity( sagemaker_model_task, + output_result_type=dict[str, str], inputs=wf.inputs["model_inputs"], additional_args=model_additional_args, ) wf.add_entity( endpoint_config_task, + output_result_type=dict[str, str], inputs=wf.inputs["endpoint_config_inputs"], additional_args=endpoint_config_additional_args, ) wf.add_entity(endpoint_task, inputs=wf.inputs["endpoint_inputs"]) - return wf + lp = LaunchPlan.get_or_create( + workflow=wf, + default_inputs={ + "model_inputs": None, + "endpoint_config_inputs": None, + "endpoint_status": None, + }, + ) + return lp -def delete_sagemaker_deployment(name: str): - sagemaker_delete_endpoint_task = SagemakerDeleteEndpointTask( +def delete_sagemaker_deployment(name: str, region: Optional[str] = None): + sagemaker_delete_endpoint = SyncBotoAgentTask( name=f"sagemaker-delete-endpoint-{name}", config={"EndpointName": "{endpoint_name}"}, - inputs=kwtypes(inputs=dict), + service="sagemaker", + method="delete_endpoint", + region=region, ) - sagemaker_delete_endpoint_config_task = SagemakerDeleteEndpointConfigTask( + + sagemaker_delete_endpoint_config = SyncBotoAgentTask( name=f"sagemaker-delete-endpoint-config-{name}", config={"EndpointConfigName": "{endpoint_config_name}"}, - inputs=kwtypes(inputs=dict), + service="sagemaker", + method="delete_endpoint_config", + region=region, ) - 
sagemaker_delete_model_task = SagemakerDeleteModelTask( + + sagemaker_delete_model = SyncBotoAgentTask( name=f"sagemaker-delete-model-{name}", config={"ModelName": "{model_name}"}, - inputs=kwtypes(inputs=dict), + service="sagemaker", + method="delete_model", + region=region, ) wf = Workflow(name=f"sagemaker-delete-endpoint-wf-{name}") @@ -84,16 +101,49 @@ def delete_sagemaker_deployment(name: str): wf.add_workflow_input("model_name", str) wf.add_entity( - sagemaker_delete_endpoint_task, + sagemaker_delete_endpoint, inputs=literals.LiteralMap({"endpoint_name": wf.inputs["endpoint_name"]}), ) wf.add_entity( - sagemaker_delete_endpoint_config_task, + sagemaker_delete_endpoint_config, inputs=literals.LiteralMap({"endpoint_config_name": wf.inputs["endpoint_config_name"]}), ) wf.add_entity( - sagemaker_delete_model_task, + sagemaker_delete_model, inputs=literals.LiteralMap({"model_name", wf.inputs["model_name"]}), ) return wf + + +def invoke_endpoint( + name: str, + config: dict[str, Any], + output_result_type: Type, + region: Optional[str] = None, +): + sagemaker_invoke_endpoint = SyncBotoAgentTask( + name=f"sagemaker-invoke-endpoint-{name}", + config=config, + service="sagemaker-runtime", + method="invoke_endpoint_async", + region=region, + ) + + wf = Workflow(name=f"sagemaker-invoke-endpoint-wf-{name}") + wf.add_workflow_input("inputs", dict) + + invoke_node = wf.add_entity( + sagemaker_invoke_endpoint, + inputs=wf.inputs["inputs"], + output_result_type=dict[str, Union[str, output_result_type]], + ) + + wf.add_workflow_output( + "result", + invoke_node.outputs["o0"], + python_type=dict[str, Union[str, output_result_type]], + ) + + lp = LaunchPlan.get_or_create(workflow=wf, default_inputs={"inputs": None}) + return lp From e2bbe0862326d9784a830c13bbf87ed6c43cfdcc Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 3 Jan 2024 22:59:35 +0530 Subject: [PATCH 007/120] modify dict update logic, create sagemaker deployment tasks Signed-off-by: Samhita Alla --- 
.../flytekitplugins/awssagemaker/__init__.py | 72 +++++--------- .../flytekitplugins/awssagemaker/agent.py | 94 ++++++++++++++++++- .../awssagemaker/boto3/agent.py | 26 ++++- .../awssagemaker/boto3/mixin.py | 47 ++++------ .../flytekitplugins/awssagemaker/workflow.py | 75 +++++---------- 5 files changed, 178 insertions(+), 136 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py index b008852a7e..52b2a16c20 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py @@ -1,58 +1,32 @@ """ .. currentmodule:: flytekitplugins.awssagemaker -This package contains things that are useful when extending Flytekit. - .. autosummary:: :template: custom.rst :toctree: generated/ - AlgorithmName - AlgorithmSpecification - CategoricalParameterRange - ContinuousParameterRange - DISTRIBUTED_TRAINING_CONTEXT_KEY - DistributedProtocol - DistributedTrainingContext - HPOJob - HyperparameterScalingType - HyperparameterTuningJobConfig - HyperparameterTuningObjective - HyperparameterTuningObjectiveType - HyperparameterTuningStrategy - InputContentType - InputMode - IntegerParameterRange - ParameterRangeOneOf - SagemakerCustomTrainingTask - SagemakerHPOTask - SagemakerTrainingJobConfig - TrainingJobEarlyStoppingType - TrainingJobResourceConfig + SagemakerDeleteEndpointConfigTask + SagemakerDeleteEndpointTask + SagemakerDeleteModelTask + SagemakerEndpointAgent + SagemakerEndpointConfigTask + SagemakerInvokeEndpointTask + SagemakerModelTask + SyncBotoAgentTask + SagemakerEndpointTask + create_sagemaker_deployment + delete_sagemaker_deployment """ -__all__ = [ - "AlgorithmName", - "AlgorithmSpecification", - "CategoricalParameterRange", - "ContinuousParameterRange", - "DISTRIBUTED_TRAINING_CONTEXT_KEY", - "DistributedProtocol", - "DistributedTrainingContext", - "HPOJob", - 
"HyperparameterScalingType", - "HyperparameterTuningJobConfig", - "HyperparameterTuningObjective", - "HyperparameterTuningObjectiveType", - "HyperparameterTuningStrategy", - "InputContentType", - "InputMode", - "IntegerParameterRange", - "ParameterRangeOneOf", - "SagemakerBuiltinAlgorithmsTask", - "SagemakerCustomTrainingTask", - "SagemakerHPOTask", - "SagemakerTrainingJobConfig", - "TrainingJobEarlyStoppingType", - "TrainingJobResourceConfig", -] +from .agent import ( + SagemakerDeleteEndpointConfigTask, + SagemakerDeleteEndpointTask, + SagemakerDeleteModelTask, + SagemakerEndpointAgent, + SagemakerEndpointConfigTask, + SagemakerInvokeEndpointTask, + SagemakerModelTask, +) +from .boto3.agent import SyncBotoAgentTask +from .task import SagemakerEndpointTask +from .workflow import create_sagemaker_deployment, delete_sagemaker_deployment diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py index 9f505fc602..d68ccc661c 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py @@ -1,6 +1,6 @@ import json from dataclasses import asdict, dataclass -from typing import Optional +from typing import Optional, Any, Union, Type import grpc from flyteidl.admin.agent_pb2 import ( @@ -17,8 +17,10 @@ get_agent_secret, ) from flytekit.models.literals import LiteralMap +from flytekit import ImageSpec from .boto3.mixin import Boto3AgentMixin +from .boto3.agent import SyncBotoAgentTask @dataclass @@ -88,3 +90,93 @@ async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes ) return DeleteTaskResponse() + + +class SagemakerModelTask(SyncBotoAgentTask): + def __init__( + self, + name: str, + config: dict[str, Any], + region: Optional[str] = None, + container_image: Optional[Union[str, ImageSpec]] = None, + **kwargs, + ): + super(SagemakerModelTask, self).__init__( + 
service="sagemaker", + method="create_model", + region=region, + name=name, + config=config, + container_image=container_image, + **kwargs, + ) + + +class SagemakerEndpointConfigTask(SyncBotoAgentTask): + def __init__(self, name: str, config: dict[str, Any], region: Optional[str] = None, **kwargs): + super(SagemakerEndpointConfigTask, self).__init__( + service="sagemaker", + method="create_endpoint_config", + region=region, + name=name, + config=config, + **kwargs, + ) + + +class SagemakerDeleteEndpointTask(SyncBotoAgentTask): + def __init__(self, name: str, config: dict[str, Any], region: Optional[str] = None, **kwargs): + super(SagemakerDeleteEndpointTask, self).__init__( + service="sagemaker", + method="delete_endpoint", + region=region, + name=name, + config=config, + **kwargs, + ) + + +class SagemakerDeleteEndpointConfigTask(SyncBotoAgentTask): + def __init__(self, name: str, config: dict[str, Any], region: Optional[str] = None, **kwargs): + super(SagemakerDeleteEndpointConfigTask, self).__init__( + service="sagemaker", + method="delete_endpoint_config", + region=region, + name=name, + config=config, + **kwargs, + ) + + +class SagemakerDeleteModelTask(SyncBotoAgentTask): + def __init__(self, name: str, config: dict[str, Any], region: Optional[str] = None, **kwargs): + super(SagemakerDeleteModelTask, self).__init__( + service="sagemaker", + method="delete_model", + region=region, + name=name, + config=config, + **kwargs, + ) + + +class SagemakerInvokeEndpointTask(SyncBotoAgentTask): + def __init__(self, name: str, config: dict[str, Any], region: Optional[str] = None, **kwargs): + super(SagemakerInvokeEndpointTask, self).__init__( + service="sagemaker-runtime", + method="invoke_endpoint_async", + region=region, + name=name, + config=config, + **kwargs, + ) + + def do( + self, + output_result_type: Type = dict[str, str], + **kwargs, + ): + super(SagemakerInvokeEndpointTask, self).do( + output_result_type=dict[str, Union[str, output_result_type]], + **kwargs, + ) 
diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/agent.py index b99b608a39..5d576e4a0d 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/agent.py @@ -1,9 +1,9 @@ -from typing import Any, Optional, Type +from typing import Any, Optional, Type, Union from flyteidl.admin.agent_pb2 import SUCCEEDED, CreateTaskResponse, Resource from flyteidl.core.tasks_pb2 import TaskTemplate -from flytekit import FlyteContextManager +from flytekit import FlyteContextManager, ImageSpec from flytekit.core.external_api_task import ExternalApiTask from flytekit.core.type_engine import TypeEngine from flytekit.models.literals import LiteralMap @@ -12,18 +12,34 @@ from .mixin import Boto3AgentMixin +# ExternalApiTask needs to inherit from PythonFunctionTask class SyncBotoAgentTask(Boto3AgentMixin, ExternalApiTask): """A general purpose boto3 agent that can be used to call any boto3 method synchronously.""" def __init__( - self, name: str, service: str, method: str, config: dict[str, Any], region: Optional[str] = None, **kwargs + self, + name: str, + service: str, + method: str, + config: dict[str, Any], + region: Optional[str] = None, + container_image: Optional[Union[str, ImageSpec]] = None, + **kwargs, ): - super().__init__(service=service, method=method, region=region, name=name, config=config, **kwargs) + super().__init__( + service=service, + method=method, + region=region, + name=name, + config=config, + container_image=container_image, + **kwargs, + ) def do( self, task_template: TaskTemplate, - output_result_type: Optional[Type] = None, + output_result_type: Type = dict[str, str], inputs: Optional[LiteralMap] = None, additional_args: Optional[dict[str, Any]] = None, region: Optional[str] = None, diff --git 
a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/mixin.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/mixin.py index dff8c6c1ae..5f1ec40ae2 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/mixin.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/mixin.py @@ -8,6 +8,17 @@ from flytekit.models.literals import LiteralMap +class AttrDict(dict): + """ + This class converts a dictionary into an attribute-style lookup. It is specifically designed for + namespacing inputs and outputs, providing a convenient way to access dictionary elements using dot notation. + """ + + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + def update_dict_fn(original_dict: Any, update_dict: dict[str, Any]) -> Any: """ Recursively update a dictionary with values from another dictionary. @@ -22,34 +33,14 @@ def update_dict_fn(original_dict: Any, update_dict: dict[str, Any]) -> Any: if original_dict is None: return None - # If the original value is a string and contains placeholder curly braces + # If the original value is a string if isinstance(original_dict, str): + # If the string contains placeholder curly braces, replace the placeholder with the actual value if "{" in original_dict and "}" in original_dict: - # Check if there are nested keys - if "." 
in original_dict: - # Create a copy of update_dict - update_dict_copy = update_dict.copy() - - # Fetch keys from the original_dict - keys = original_dict.strip("{}").split(".") - - # Get value from the nested dictionary - for key in keys: - update_dict_copy = update_dict_copy.get(key) - if not update_dict_copy: - raise ValueError(f"Could not find the key {key} in {update_dict_copy}.") - - return update_dict_copy - - # Retrieve the original value using the key without curly braces - original_value = update_dict.get(original_dict.replace("{", "").replace("}", "")) - - # Check if original_value exists; if so, return it, - # otherwise, raise a ValueError indicating that the value for the key original_dict could not be found. - if original_value: - return original_value - else: - raise ValueError(f"Could not find value for {original_dict}.") + try: + return original_dict.format(**update_dict) + except KeyError as e: + raise ValueError(f"Variable {e} in placeholder not found in inputs {update_dict.keys()}") # If the string does not contain placeholders, return it as is return original_dict @@ -116,9 +107,9 @@ async def _call( """ args = {} if inputs: - args["inputs"] = literal_map_string_repr(inputs) + args["inputs"] = AttrDict(literal_map_string_repr(inputs)) if task_template: - args["container"] = MessageToDict(task_template.container) + args["container"] = AttrDict(MessageToDict(task_template.container)) if additional_args: args.update(additional_args) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py index daa7e094c1..a2a605da20 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py @@ -1,10 +1,16 @@ -from flytekit import Workflow, kwtypes, LaunchPlan -from .boto3.agent import SyncBotoAgentTask +from flytekit import Workflow, kwtypes, LaunchPlan, 
ImageSpec +from .agent import ( + SagemakerModelTask, + SagemakerEndpointConfigTask, + SagemakerDeleteEndpointTask, + SagemakerDeleteEndpointConfigTask, + SagemakerDeleteModelTask, +) from flytekit.models import literals from .task import SagemakerEndpointTask -from typing import Any, Optional, Union, Type +from typing import Any, Optional, Union def create_sagemaker_deployment( @@ -12,23 +18,24 @@ def create_sagemaker_deployment( model_config: dict[str, Any], endpoint_config_config: dict[str, Any], endpoint_config: dict[str, Any], + container_image: Optional[Union[str, ImageSpec]] = None, region: Optional[str] = None, model_additional_args: Optional[dict[str, Any]] = None, endpoint_config_additional_args: Optional[dict[str, Any]] = None, ): - sagemaker_model_task = SyncBotoAgentTask( + """ + Creates Sagemaker model, endpoint config and endpoint. + """ + sagemaker_model_task = SagemakerModelTask( name=f"sagemaker-model-{model_name}", config=model_config, - service="sagemaker", - method="create_model", region=region, + container_image=container_image, ) - endpoint_config_task = SyncBotoAgentTask( + endpoint_config_task = SagemakerEndpointConfigTask( name=f"sagemaker-endpoint-config-{model_name}", config=endpoint_config_config, - service="sagemaker", - method="create_endpoint_config", region=region, ) @@ -45,14 +52,12 @@ def create_sagemaker_deployment( wf.add_entity( sagemaker_model_task, - output_result_type=dict[str, str], inputs=wf.inputs["model_inputs"], additional_args=model_additional_args, ) wf.add_entity( endpoint_config_task, - output_result_type=dict[str, str], inputs=wf.inputs["endpoint_config_inputs"], additional_args=endpoint_config_additional_args, ) @@ -71,27 +76,24 @@ def create_sagemaker_deployment( def delete_sagemaker_deployment(name: str, region: Optional[str] = None): - sagemaker_delete_endpoint = SyncBotoAgentTask( + """ + Deletes Sagemaker model, endpoint config and endpoint. 
+ """ + sagemaker_delete_endpoint = SagemakerDeleteEndpointTask( name=f"sagemaker-delete-endpoint-{name}", config={"EndpointName": "{endpoint_name}"}, - service="sagemaker", - method="delete_endpoint", region=region, ) - sagemaker_delete_endpoint_config = SyncBotoAgentTask( + sagemaker_delete_endpoint_config = SagemakerDeleteEndpointConfigTask( name=f"sagemaker-delete-endpoint-config-{name}", config={"EndpointConfigName": "{endpoint_config_name}"}, - service="sagemaker", - method="delete_endpoint_config", region=region, ) - sagemaker_delete_model = SyncBotoAgentTask( + sagemaker_delete_model = SagemakerDeleteModelTask( name=f"sagemaker-delete-model-{name}", config={"ModelName": "{model_name}"}, - service="sagemaker", - method="delete_model", region=region, ) @@ -114,36 +116,3 @@ def delete_sagemaker_deployment(name: str, region: Optional[str] = None): ) return wf - - -def invoke_endpoint( - name: str, - config: dict[str, Any], - output_result_type: Type, - region: Optional[str] = None, -): - sagemaker_invoke_endpoint = SyncBotoAgentTask( - name=f"sagemaker-invoke-endpoint-{name}", - config=config, - service="sagemaker-runtime", - method="invoke_endpoint_async", - region=region, - ) - - wf = Workflow(name=f"sagemaker-invoke-endpoint-wf-{name}") - wf.add_workflow_input("inputs", dict) - - invoke_node = wf.add_entity( - sagemaker_invoke_endpoint, - inputs=wf.inputs["inputs"], - output_result_type=dict[str, Union[str, output_result_type]], - ) - - wf.add_workflow_output( - "result", - invoke_node.outputs["o0"], - python_type=dict[str, Union[str, output_result_type]], - ) - - lp = LaunchPlan.get_or_create(workflow=wf, default_inputs={"inputs": None}) - return lp From 6e24e75a8aa3238ffe1d575c4adaf4b10e170586 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 3 Jan 2024 23:00:54 +0530 Subject: [PATCH 008/120] nit Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py index d68ccc661c..da41cbf07a 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py @@ -173,7 +173,7 @@ def __init__(self, name: str, config: dict[str, Any], region: Optional[str] = No def do( self, - output_result_type: Type = dict[str, str], + output_result_type: Type, **kwargs, ): super(SagemakerInvokeEndpointTask, self).do( From 1fbdf512f9050d42dc033961babdfa7d07822f34 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 23 Jan 2024 11:31:31 +0530 Subject: [PATCH 009/120] dockerfile and setup changes Signed-off-by: Samhita Alla --- Dockerfile.agent | 1 + plugins/flytekit-aws-sagemaker/setup.py | 9 +++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/Dockerfile.agent b/Dockerfile.agent index 445cfcd8ca..4fdcc100a9 100644 --- a/Dockerfile.agent +++ b/Dockerfile.agent @@ -14,6 +14,7 @@ RUN pip install --no-cache-dir -U flytekit==$VERSION \ flytekitplugins-mmcloud==$VERSION \ flytekitplugins-spark==$VERSION \ flytekitplugins-snowflake==$VERSION \ + flytekitplugins-awssagemaker==$VERSION \ && : CMD pyflyte serve --port 8000 diff --git a/plugins/flytekit-aws-sagemaker/setup.py b/plugins/flytekit-aws-sagemaker/setup.py index 855dd32402..13035a5ea6 100644 --- a/plugins/flytekit-aws-sagemaker/setup.py +++ b/plugins/flytekit-aws-sagemaker/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "sagemaker-training>=3.6.2,<4.0.0", "retry2==0.9.5"] +plugin_requires = ["flytekit>=1.10.0", "flyteidl>=1.10.7b0", "aioboto3"] __version__ = "0.0.0+develop" @@ -14,9 +14,9 @@ version=__version__, author="flyteorg", author_email="admin@flyte.org", - description="AWS Plugins for flytekit", + description="Flytekit AWS Sagemaker plugin", 
namespace_packages=["flytekitplugins"], - packages=[f"flytekitplugins.{PLUGIN_NAME}", f"flytekitplugins.{PLUGIN_NAME}.models"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], install_requires=plugin_requires, license="apache2", python_requires=">=3.8", @@ -27,11 +27,12 @@ "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Software Development", "Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries :: Python Modules", ], - scripts=["scripts/flytekit_sagemaker_runner.py"], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, ) From e5392d4af39d40440f3bb155c484e920f020bbab Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Fri, 26 Jan 2024 22:29:11 +0530 Subject: [PATCH 010/120] update Signed-off-by: Samhita Alla --- plugins/flytekit-aws-sagemaker/README.md | 2 +- .../flytekitplugins/awssagemaker/__init__.py | 21 +- .../flytekitplugins/awssagemaker/agent.py | 114 ++-------- .../awssagemaker/boto3/agent.py | 70 +++--- .../awssagemaker/boto3/mixin.py | 36 +-- .../awssagemaker/boto3/task.py | 54 +++++ .../flytekitplugins/awssagemaker/task.py | 211 +++++++++++++++++- .../flytekitplugins/awssagemaker/workflow.py | 55 +++-- plugins/flytekit-aws-sagemaker/setup.py | 3 +- 9 files changed, 365 insertions(+), 201 deletions(-) create mode 100644 plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/task.py diff --git a/plugins/flytekit-aws-sagemaker/README.md b/plugins/flytekit-aws-sagemaker/README.md index 9705edbcaa..33cd38afef 100644 --- a/plugins/flytekit-aws-sagemaker/README.md +++ b/plugins/flytekit-aws-sagemaker/README.md @@ -1,6 +1,6 @@ # AWS Sagemaker Plugin -The plugin includes a deployment agent that allows you to deploy Sagemaker models, create, and inkoke endpoints 
for inference. +The plugin includes a deployment agent that allows you to deploy Sagemaker models, create and inkoke endpoints for inference. To install the plugin, run the following command: diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py index 52b2a16c20..dba763b3f4 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py @@ -5,28 +5,31 @@ :template: custom.rst :toctree: generated/ + SyncBotoAgent + SyncBotoTask + SagemakerModelTask + SagemakerEndpointConfigTask + SagemakerEndpointAgent + SagemakerEndpointTask SagemakerDeleteEndpointConfigTask SagemakerDeleteEndpointTask SagemakerDeleteModelTask - SagemakerEndpointAgent - SagemakerEndpointConfigTask SagemakerInvokeEndpointTask - SagemakerModelTask - SyncBotoAgentTask - SagemakerEndpointTask create_sagemaker_deployment delete_sagemaker_deployment """ -from .agent import ( +from .agent import SagemakerEndpointAgent +from .task import ( SagemakerDeleteEndpointConfigTask, SagemakerDeleteEndpointTask, SagemakerDeleteModelTask, - SagemakerEndpointAgent, + SagemakerEndpointTask, SagemakerEndpointConfigTask, SagemakerInvokeEndpointTask, SagemakerModelTask, ) -from .boto3.agent import SyncBotoAgentTask -from .task import SagemakerEndpointTask +from .boto3.agent import SyncBotoAgent +from .boto3.task import SyncBotoTask + from .workflow import create_sagemaker_deployment, delete_sagemaker_deployment diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py index da41cbf07a..b8cb9e961d 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py @@ -1,6 +1,6 @@ import json from dataclasses import asdict, dataclass -from 
typing import Optional, Any, Union, Type +from typing import Optional import grpc from flyteidl.admin.agent_pb2 import ( @@ -15,12 +15,19 @@ AgentBase, convert_to_flyte_state, get_agent_secret, + AgentRegistry, ) from flytekit.models.literals import LiteralMap -from flytekit import ImageSpec + from .boto3.mixin import Boto3AgentMixin -from .boto3.agent import SyncBotoAgentTask + + +states = { + "Creating": "Running", + "InService": "Success", + "Failed": "Failed", +} @dataclass @@ -36,6 +43,7 @@ def __init__(self): super().__init__( service="sagemaker", task_type="sagemaker-endpoint", + asynchronous=True, ) async def async_create( @@ -50,9 +58,8 @@ async def async_create( region = custom["region"] await self._call( - "create_endpoint", + method="create_endpoint", config=config, - task_template=task_template, inputs=inputs, region=region, aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), @@ -67,18 +74,16 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) endpoint_status = await self._call( - "describe_endpoint", + method="describe_endpoint", config={"EndpointName": metadata.endpoint_name}, ) current_state = endpoint_status.get("EndpointStatus") message = "" - if current_state in ("Failed", "UpdateRollbackFailed"): + if current_state == "Failed": message = endpoint_status.get("FailureReason") - # THIS WON'T WORK. NEED TO FIX THIS. 
- flyte_state = convert_to_flyte_state(current_state) - + flyte_state = convert_to_flyte_state(states[current_state]) return GetTaskResponse(resource=Resource(state=flyte_state, message=message)) async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: @@ -92,91 +97,4 @@ async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes return DeleteTaskResponse() -class SagemakerModelTask(SyncBotoAgentTask): - def __init__( - self, - name: str, - config: dict[str, Any], - region: Optional[str] = None, - container_image: Optional[Union[str, ImageSpec]] = None, - **kwargs, - ): - super(SagemakerModelTask, self).__init__( - service="sagemaker", - method="create_model", - region=region, - name=name, - config=config, - container_image=container_image, - **kwargs, - ) - - -class SagemakerEndpointConfigTask(SyncBotoAgentTask): - def __init__(self, name: str, config: dict[str, Any], region: Optional[str] = None, **kwargs): - super(SagemakerEndpointConfigTask, self).__init__( - service="sagemaker", - method="create_endpoint_config", - region=region, - name=name, - config=config, - **kwargs, - ) - - -class SagemakerDeleteEndpointTask(SyncBotoAgentTask): - def __init__(self, name: str, config: dict[str, Any], region: Optional[str] = None, **kwargs): - super(SagemakerDeleteEndpointTask, self).__init__( - service="sagemaker", - method="delete_endpoint", - region=region, - name=name, - config=config, - **kwargs, - ) - - -class SagemakerDeleteEndpointConfigTask(SyncBotoAgentTask): - def __init__(self, name: str, config: dict[str, Any], region: Optional[str] = None, **kwargs): - super(SagemakerDeleteEndpointConfigTask, self).__init__( - service="sagemaker", - method="delete_endpoint_config", - region=region, - name=name, - config=config, - **kwargs, - ) - - -class SagemakerDeleteModelTask(SyncBotoAgentTask): - def __init__(self, name: str, config: dict[str, Any], region: Optional[str] = None, **kwargs): - 
super(SagemakerDeleteModelTask, self).__init__( - service="sagemaker", - method="delete_model", - region=region, - name=name, - config=config, - **kwargs, - ) - - -class SagemakerInvokeEndpointTask(SyncBotoAgentTask): - def __init__(self, name: str, config: dict[str, Any], region: Optional[str] = None, **kwargs): - super(SagemakerInvokeEndpointTask, self).__init__( - service="sagemaker-runtime", - method="invoke_endpoint_async", - region=region, - name=name, - config=config, - **kwargs, - ) - - def do( - self, - output_result_type: Type, - **kwargs, - ): - super(SagemakerInvokeEndpointTask, self).do( - output_result_type=dict[str, Union[str, output_result_type]], - **kwargs, - ) +AgentRegistry.register(SagemakerEndpointAgent()) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/agent.py index 5d576e4a0d..7bb9de1481 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/agent.py @@ -1,56 +1,49 @@ -from typing import Any, Optional, Type, Union +from typing import Optional +import grpc from flyteidl.admin.agent_pb2 import SUCCEEDED, CreateTaskResponse, Resource from flyteidl.core.tasks_pb2 import TaskTemplate - -from flytekit import FlyteContextManager, ImageSpec -from flytekit.core.external_api_task import ExternalApiTask +from flytekit import FlyteContextManager from flytekit.core.type_engine import TypeEngine +from flytekit.extend.backend.base_agent import ( + AgentBase, + AgentRegistry, + get_agent_secret, +) from flytekit.models.literals import LiteralMap -from flytekit.extend.backend.base_agent import get_agent_secret from .mixin import Boto3AgentMixin -# ExternalApiTask needs to inherit from PythonFunctionTask -class SyncBotoAgentTask(Boto3AgentMixin, ExternalApiTask): +class SyncBotoAgent(AgentBase): """A general purpose boto3 agent that can be used to 
call any boto3 method synchronously.""" - def __init__( - self, - name: str, - service: str, - method: str, - config: dict[str, Any], - region: Optional[str] = None, - container_image: Optional[Union[str, ImageSpec]] = None, - **kwargs, - ): + def __init__(self): super().__init__( - service=service, - method=method, - region=region, - name=name, - config=config, - container_image=container_image, - **kwargs, + task_type="sync-boto", + asynchronous=False, ) - def do( + def create( self, + context: grpc.ServicerContext, + output_prefix: str, task_template: TaskTemplate, - output_result_type: Type = dict[str, str], inputs: Optional[LiteralMap] = None, - additional_args: Optional[dict[str, Any]] = None, - region: Optional[str] = None, - ): - inputs = inputs or LiteralMap(literals={}) - result = self._call( - config=task_template.custom["task_config"], + ) -> CreateTaskResponse: + custom = task_template.custom + service = custom["service"] + config = custom["config"] + region = custom["region"] + method = custom["method"] + output_type = custom["output_type"] + + boto3_object = Boto3AgentMixin(service=service, region=region) + result = boto3_object._call( + method=method, + config=config, + container=task_template.container, inputs=inputs, - task_template=task_template, - additional_args=additional_args, - region=region, aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), @@ -64,10 +57,13 @@ def do( "o0": TypeEngine.to_literal( ctx, result, - output_result_type, - TypeEngine.to_literal_type(output_result_type), + output_type, + TypeEngine.to_literal_type(output_type), ) } ).to_flyte_idl() return CreateTaskResponse(resource=Resource(state=SUCCEEDED, outputs=outputs)) + + +AgentRegistry.register(SyncBotoAgent()) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/mixin.py 
b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/mixin.py index 5f1ec40ae2..7a311b8fd5 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/mixin.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/mixin.py @@ -1,11 +1,11 @@ from typing import Any, Optional import aioboto3 -from flyteidl.core.tasks_pb2 import TaskTemplate from google.protobuf.json_format import MessageToDict from flytekit.interaction.string_literals import literal_map_string_repr from flytekit.models.literals import LiteralMap +from flytekit.models import task as _task_model class AttrDict(dict): @@ -28,7 +28,7 @@ def update_dict_fn(original_dict: Any, update_dict: dict[str, Any]) -> Any: :param original_dict: The dictionary to update (in place) :param update_dict: The dictionary to use for updating - :return: The updated dictionary - note that the original dictionary is updated in place + :return: The updated dictionary """ if original_dict is None: return None @@ -60,11 +60,11 @@ def update_dict_fn(original_dict: Any, update_dict: dict[str, Any]) -> Any: class Boto3AgentMixin: """ - This mixin facilitates the creation of a Flyte agent for any AWS service using boto3. - It provides a single method, `_call`, which can be employed to invoke any boto3 method. + This mixin facilitates the creation of a Boto3 agent for any AWS service. + It provides a single method, `_call`, which can be employed to invoke any Boto3 method. """ - def __init__(self, *, service: str, method: str, region: Optional[str] = None, **kwargs): + def __init__(self, *, service: str, region: Optional[str] = None, **kwargs): """ Initialize the Boto3AgentMixin. 
@@ -73,15 +73,15 @@ def __init__(self, *, service: str, method: str, region: Optional[str] = None, * """ self._service = service self._region = region - self._method = method + super().__init__(**kwargs) async def _call( self, + method: str, config: dict[str, Any], - task_template: Optional[TaskTemplate] = None, + container: Optional[_task_model.Container] = None, inputs: Optional[LiteralMap] = None, - additional_args: Optional[dict[str, Any]] = None, region: Optional[str] = None, aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, @@ -92,29 +92,29 @@ async def _call( :param method: The boto3 method to invoke, e.g., create_endpoint_config. :param config: The configuration for the method, e.g., {"EndpointConfigName": "my-endpoint-config"}. The config - may contain placeholders replaced by values from inputs, task_template, or additional_args. + may contain placeholders replaced by values from inputs and container. For example, if the config is {"EndpointConfigName": "{inputs.endpoint_config_name}", "EndpointName": "{endpoint_name}", "Image": "{container.image}"} - and the additional_args dict is {"endpoint_name": "my-endpoint"}, the inputs contain a string literal for - endpoint_config_name, and the task_template contains a container with an image, + the inputs contain a string literal for endpoint_config_name, and the container has the image, then the config will be updated to {"EndpointConfigName": "my-endpoint-config", "EndpointName": "my-endpoint", "Image": "my-image"} before invoking the boto3 method. - :param task_template: The task template for the task being created. + :param container: Container retrieved from the task template. :param inputs: The inputs for the task being created. - :param additional_args: Additional arguments for updating the config. These are optional and can be controlled by the task author. :param region: The region for the boto3 client. 
If not provided, the region specified in the constructor will be used. + :param aws_access_key_id: The access key ID to use to access the AWS resources. + :param aws_secret_access_key: The secret access key to use to access the AWS resources + :param aws_session_token: An AWS session token used as part of the credentials to authenticate the user. """ args = {} if inputs: args["inputs"] = AttrDict(literal_map_string_repr(inputs)) - if task_template: - args["container"] = AttrDict(MessageToDict(task_template.container)) - if additional_args: - args.update(additional_args) + if container: + args["container"] = AttrDict(MessageToDict(container)) updated_config = update_dict_fn(config, args) + # Asynchronouse Boto3 session session = aioboto3.Session() async with session.client( service_name=self._service, @@ -124,7 +124,7 @@ async def _call( aws_session_token=aws_session_token, ) as client: try: - result = await getattr(client, self._method)(**updated_config) + result = await getattr(client, method)(**updated_config) except Exception as e: raise e diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/task.py new file mode 100644 index 0000000000..fb70856dd9 --- /dev/null +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/task.py @@ -0,0 +1,54 @@ +from dataclasses import dataclass +from typing import Any, Optional, Type, Union + +from google.protobuf import json_format +from google.protobuf.struct_pb2 import Struct + +from flytekit import ImageSpec +from flytekit.configuration import SerializationSettings +from flytekit.core.python_function_task import PythonInstanceTask +from flytekit.core.interface import Interface +from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin + + +@dataclass +class SyncBotoConfig(object): + service: str + method: str + config: dict[str, Any] + region: str + + +class SyncBotoTask(AsyncAgentExecutorMixin, 
PythonInstanceTask[SyncBotoConfig]): + _TASK_TYPE = "sync-boto" + + def __init__( + self, + name: str, + task_config: SyncBotoConfig, + inputs: Optional[dict[str, Type]] = None, + output_type: Optional[Type] = None, + container_image: Optional[Union[str, ImageSpec]] = None, + **kwargs, + ): + self._output_type = output_type + super().__init__( + name=name, + task_config=task_config, + task_type=self._TASK_TYPE, + interface=Interface(inputs=inputs, outputs={"result": output_type}), + container_image=container_image, + **kwargs, + ) + + def get_custom(self, settings: SerializationSettings) -> dict[str, Any]: + config = { + "service": self.task_config.service, + "config": self.task_config.config, + "region": self.task_config.region, + "method": self.task_config.method, + "output_type": self._output_type, + } + s = Struct() + s.update(config) + return json_format.MessageToDict(s) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py index 0bcf7b23df..c7b054823b 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Optional, Type +from typing import Any, Optional, Type, Union from google.protobuf import json_format from google.protobuf.struct_pb2 import Struct @@ -8,29 +8,105 @@ from flytekit.core.base_task import PythonTask from flytekit.core.interface import Interface from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin +from .boto3.task import SyncBotoTask, SyncBotoConfig +from flytekit import ImageSpec + + +class SagemakerModelTask(SyncBotoTask): + def __init__( + self, + name: str, + config: dict[str, Any], + region: Optional[str] = None, + inputs: Optional[dict[str, Type]] = None, + container_image: Optional[Union[str, ImageSpec]] = None, + **kwargs, + ): + """ + Creates a 
Sagemaker model. + + :param name: The name of the task. + :param config: The configuration to be provided to the boto3 API call. + :param region: The region for the boto3 client. + :param inputs: The input literal map to be used for updating the configuration. + :param container_image: The path where inference code is stored. + This can be either in Amazon EC2 Container Registry or in a Docker registry + that is accessible from the same VPC that you configure for your endpoint. + """ + super(SagemakerModelTask, self).__init__( + name=name, + task_config=SyncBotoConfig(service="sagemaker", method="create_model", config=config, region=region), + inputs=inputs, + output_type=dict[str, str], + container_image=container_image, + **kwargs, + ) + + +class SagemakerEndpointConfigTask(SyncBotoTask): + def __init__( + self, + name: str, + config: dict[str, Any], + region: Optional[str] = None, + inputs: Optional[dict[str, Type]] = None, + **kwargs, + ): + """ + Creates a Sagemaker endpoint configuration. + + :param name: The name of the task. + :param config: The configuration to be provided to the boto3 API call. + :param region: The region for the boto3 client. + :param inputs: The input literal map to be used for updating the configuration. 
+ """ + super(SagemakerEndpointConfigTask, self).__init__( + name=name, + task_config=SyncBotoConfig( + service="sagemaker", + method="create_endpoint_config", + config=config, + region=region, + ), + inputs=inputs, + output_type=dict[str, str], + **kwargs, + ) @dataclass -class SagemakerEndpointConfig(object): +class SagemakerEndpointMetadata(object): config: dict[str, Any] region: str -class SagemakerEndpointTask(AsyncAgentExecutorMixin, PythonTask[SagemakerEndpointConfig]): +class SagemakerEndpointTask(AsyncAgentExecutorMixin, PythonTask[SagemakerEndpointMetadata]): _TASK_TYPE = "sagemaker-endpoint" def __init__( self, name: str, - task_config: SagemakerEndpointConfig, + config: dict[str, Any], + region: Optional[str] = None, inputs: Optional[dict[str, Type]] = None, **kwargs, ): + """ + Creates a Sagemaker endpoint. + + :param name: The name of the task. + :param config: The configuration to be provided to the boto3 API call. + :param region: The region for the boto3 client. + :param inputs: The input literal map to be used for updating the configuration. + """ super().__init__( name=name, - task_config=task_config, - interface=Interface(inputs=inputs or {}), + task_config=SagemakerEndpointMetadata( + config=config, + region=region, + ), task_type=self._TASK_TYPE, + interface=Interface(inputs=inputs, outputs={"result": dict[str, str]}), **kwargs, ) @@ -39,3 +115,126 @@ def get_custom(self, settings: SerializationSettings) -> dict[str, Any]: s = Struct() s.update(config) return json_format.MessageToDict(s) + + +class SagemakerDeleteEndpointTask(SyncBotoTask): + def __init__( + self, + name: str, + config: dict[str, Any], + region: Optional[str] = None, + inputs: Optional[dict[str, Type]] = None, + **kwargs, + ): + """ + Deletes a Sagemaker endpoint. + + :param name: The name of the task. + :param config: The configuration to be provided to the boto3 API call. + :param region: The region for the boto3 client. 
+ :param inputs: The input literal map to be used for updating the configuration. + """ + super(SagemakerDeleteEndpointTask, self).__init__( + name=name, + task_config=SyncBotoConfig( + service="sagemaker", + method="delete_endpoint", + config=config, + region=region, + ), + inputs=inputs, + **kwargs, + ) + + +class SagemakerDeleteEndpointConfigTask(SyncBotoTask): + def __init__( + self, + name: str, + config: dict[str, Any], + region: Optional[str] = None, + inputs: Optional[dict[str, Type]] = None, + **kwargs, + ): + """ + Deletes a Sagemaker endpoint config. + + :param name: The name of the task. + :param config: The configuration to be provided to the boto3 API call. + :param region: The region for the boto3 client. + :param inputs: The input literal map to be used for updating the configuration. + """ + super(SagemakerDeleteEndpointConfigTask, self).__init__( + name=name, + task_config=SyncBotoConfig( + service="sagemaker", + method="delete_endpoint_config", + config=config, + region=region, + ), + inputs=inputs, + **kwargs, + ) + + +class SagemakerDeleteModelTask(SyncBotoTask): + def __init__( + self, + name: str, + config: dict[str, Any], + region: Optional[str] = None, + inputs: Optional[dict[str, Type]] = None, + **kwargs, + ): + """ + Deletes a Sagemaker model. + + :param name: The name of the task. + :param config: The configuration to be provided to the boto3 API call. + :param region: The region for the boto3 client. + :param inputs: The input literal map to be used for updating the configuration. 
+ """ + super(SagemakerDeleteModelTask, self).__init__( + name=name, + task_config=SyncBotoConfig( + service="sagemaker", + method="delete_model", + config=config, + region=region, + ), + inputs=inputs, + **kwargs, + ) + + +class SagemakerInvokeEndpointTask(SyncBotoConfig): + def __init__( + self, + name: str, + config: dict[str, Any], + output_type: Type, + region: Optional[str] = None, + inputs: Optional[dict[str, Type]] = None, + **kwargs, + ): + """ + Invokes a Sagemaker endpoint. + + :param name: The name of the task. + :param config: The configuration to be provided to the boto3 API call. + :param output_type: The type of output. + :param region: The region for the boto3 client. + :param inputs: The input literal map to be used for updating the configuration. + """ + super(SagemakerInvokeEndpointTask, self).__init__( + name=name, + task_config=SyncBotoConfig( + service="sagemaker-runtime", + method="invoke_endpoint", + config=config, + region=region, + ), + inputs=inputs, + output_type=dict[str, Union[str, output_type]], + **kwargs, + ) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py index a2a605da20..e36a9e91b9 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py @@ -1,16 +1,14 @@ -from flytekit import Workflow, kwtypes, LaunchPlan, ImageSpec -from .agent import ( +from flytekit import Workflow, kwtypes, ImageSpec +from .task import ( SagemakerModelTask, SagemakerEndpointConfigTask, SagemakerDeleteEndpointTask, SagemakerDeleteEndpointConfigTask, SagemakerDeleteModelTask, + SagemakerEndpointTask, ) -from flytekit.models import literals - -from .task import SagemakerEndpointTask -from typing import Any, Optional, Union +from typing import Any, Optional, Union, Type def create_sagemaker_deployment( @@ -18,10 +16,11 @@ def 
create_sagemaker_deployment( model_config: dict[str, Any], endpoint_config_config: dict[str, Any], endpoint_config: dict[str, Any], + model_input_types: Optional[dict[str, Type]] = None, + endpoint_config_input_types: Optional[dict[str, Type]] = None, + endpoint_input_types: Optional[dict[str, Type]] = None, container_image: Optional[Union[str, ImageSpec]] = None, region: Optional[str] = None, - model_additional_args: Optional[dict[str, Any]] = None, - endpoint_config_additional_args: Optional[dict[str, Any]] = None, ): """ Creates Sagemaker model, endpoint config and endpoint. @@ -30,6 +29,7 @@ def create_sagemaker_deployment( name=f"sagemaker-model-{model_name}", config=model_config, region=region, + inputs=model_input_types, container_image=container_image, ) @@ -37,42 +37,34 @@ def create_sagemaker_deployment( name=f"sagemaker-endpoint-config-{model_name}", config=endpoint_config_config, region=region, + inputs=endpoint_config_input_types, ) endpoint_task = SagemakerEndpointTask( name=f"sagemaker-endpoint-{model_name}", - task_config=endpoint_config, - inputs=kwtypes(inputs=dict), + config=endpoint_config, + region=region, + inputs=endpoint_input_types, ) wf = Workflow(name=f"sagemaker-deploy-{model_name}") - wf.add_workflow_input("model_inputs", dict) - wf.add_workflow_input("endpoint_config_inputs", dict) - wf.add_workflow_input("endpoint_inputs", dict) + wf.add_workflow_input("model_inputs", Optional[dict]) + wf.add_workflow_input("endpoint_config_inputs", Optional[dict]) + wf.add_workflow_input("endpoint_inputs", Optional[dict]) wf.add_entity( sagemaker_model_task, - inputs=wf.inputs["model_inputs"], - additional_args=model_additional_args, + **wf.inputs["model_inputs"], ) wf.add_entity( endpoint_config_task, - inputs=wf.inputs["endpoint_config_inputs"], - additional_args=endpoint_config_additional_args, + **wf.inputs["endpoint_config_inputs"], ) - wf.add_entity(endpoint_task, inputs=wf.inputs["endpoint_inputs"]) + wf.add_entity(endpoint_task, 
**wf.inputs["endpoint_inputs"]) - lp = LaunchPlan.get_or_create( - workflow=wf, - default_inputs={ - "model_inputs": None, - "endpoint_config_inputs": None, - "endpoint_status": None, - }, - ) - return lp + return wf def delete_sagemaker_deployment(name: str, region: Optional[str] = None): @@ -83,18 +75,21 @@ def delete_sagemaker_deployment(name: str, region: Optional[str] = None): name=f"sagemaker-delete-endpoint-{name}", config={"EndpointName": "{endpoint_name}"}, region=region, + inputs=kwtypes(endpoint_name=str), ) sagemaker_delete_endpoint_config = SagemakerDeleteEndpointConfigTask( name=f"sagemaker-delete-endpoint-config-{name}", config={"EndpointConfigName": "{endpoint_config_name}"}, region=region, + inputs=kwtypes(endpoint_config_name=str), ) sagemaker_delete_model = SagemakerDeleteModelTask( name=f"sagemaker-delete-model-{name}", config={"ModelName": "{model_name}"}, region=region, + inputs=kwtypes(model_name=str), ) wf = Workflow(name=f"sagemaker-delete-endpoint-wf-{name}") @@ -104,15 +99,15 @@ def delete_sagemaker_deployment(name: str, region: Optional[str] = None): wf.add_entity( sagemaker_delete_endpoint, - inputs=literals.LiteralMap({"endpoint_name": wf.inputs["endpoint_name"]}), + **{"endpoint_name": wf.inputs["endpoint_name"]}, ) wf.add_entity( sagemaker_delete_endpoint_config, - inputs=literals.LiteralMap({"endpoint_config_name": wf.inputs["endpoint_config_name"]}), + **{"endpoint_config_name": wf.inputs["endpoint_config_name"]}, ) wf.add_entity( sagemaker_delete_model, - inputs=literals.LiteralMap({"model_name", wf.inputs["model_name"]}), + **{"model_name", wf.inputs["model_name"]}, ) return wf diff --git a/plugins/flytekit-aws-sagemaker/setup.py b/plugins/flytekit-aws-sagemaker/setup.py index 13035a5ea6..08fdc816ba 100644 --- a/plugins/flytekit-aws-sagemaker/setup.py +++ b/plugins/flytekit-aws-sagemaker/setup.py @@ -4,11 +4,10 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.10.0", "flyteidl>=1.10.7b0", 
"aioboto3"] +plugin_requires = ["flytekit>=1.10.0", "flyteidl>=1.10.7b0", "aioboto3<=2.5.4"] __version__ = "0.0.0+develop" -# TODO: move sagemaker install script here. setup( name=microlib_name, version=__version__, From 2cfba8a6d8608336a624bea32ae6bcd6a1a63178 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Fri, 26 Jan 2024 22:29:32 +0530 Subject: [PATCH 011/120] update Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py index dba763b3f4..68e5573aff 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py @@ -31,5 +31,4 @@ ) from .boto3.agent import SyncBotoAgent from .boto3.task import SyncBotoTask - from .workflow import create_sagemaker_deployment, delete_sagemaker_deployment From 93eb22219651405cd2a844bb99433d0db95e552a Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Fri, 26 Jan 2024 22:41:37 +0530 Subject: [PATCH 012/120] pin aioboto3 version Signed-off-by: Samhita Alla --- plugins/flytekit-aws-sagemaker/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-aws-sagemaker/setup.py b/plugins/flytekit-aws-sagemaker/setup.py index 08fdc816ba..77d8bf2255 100644 --- a/plugins/flytekit-aws-sagemaker/setup.py +++ b/plugins/flytekit-aws-sagemaker/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.10.0", "flyteidl>=1.10.7b0", "aioboto3<=2.5.4"] +plugin_requires = ["flytekit>=1.10.0", "flyteidl>=1.10.7b0", "aioboto3==11.1.1"] __version__ = "0.0.0+develop" From 34afdafd8e90cc1b16acd9446e6aa0b17506b582 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Fri, 26 Jan 2024 22:50:52 +0530 Subject: [PATCH 013/120] remove boto3 directory Signed-off-by: Samhita 
Alla --- .../flytekitplugins/awssagemaker/agent.py | 2 +- .../flytekitplugins/awssagemaker/boto3/__init__.py | 0 .../awssagemaker/{boto3/agent.py => boto3_agent.py} | 2 +- .../awssagemaker/{boto3/mixin.py => boto3_mixin.py} | 0 .../awssagemaker/{boto3/task.py => boto3_task.py} | 0 .../flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py | 2 +- plugins/flytekit-aws-sagemaker/setup.py | 1 + 7 files changed, 4 insertions(+), 3 deletions(-) delete mode 100644 plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/__init__.py rename plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/{boto3/agent.py => boto3_agent.py} (98%) rename plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/{boto3/mixin.py => boto3_mixin.py} (100%) rename plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/{boto3/task.py => boto3_task.py} (100%) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py index b8cb9e961d..521e218a38 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py @@ -20,7 +20,7 @@ from flytekit.models.literals import LiteralMap -from .boto3.mixin import Boto3AgentMixin +from .boto3_mixin import Boto3AgentMixin states = { diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/__init__.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py similarity index 98% rename from plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/agent.py rename to plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py index 7bb9de1481..724fbe198e 100644 --- 
a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py @@ -12,7 +12,7 @@ ) from flytekit.models.literals import LiteralMap -from .mixin import Boto3AgentMixin +from .boto3_mixin import Boto3AgentMixin class SyncBotoAgent(AgentBase): diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/mixin.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_mixin.py similarity index 100% rename from plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/mixin.py rename to plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_mixin.py diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py similarity index 100% rename from plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3/task.py rename to plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py index c7b054823b..6453221319 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py @@ -8,7 +8,7 @@ from flytekit.core.base_task import PythonTask from flytekit.core.interface import Interface from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin -from .boto3.task import SyncBotoTask, SyncBotoConfig +from .boto3_task import SyncBotoTask, SyncBotoConfig from flytekit import ImageSpec diff --git a/plugins/flytekit-aws-sagemaker/setup.py b/plugins/flytekit-aws-sagemaker/setup.py index 77d8bf2255..a328177c0d 100644 --- a/plugins/flytekit-aws-sagemaker/setup.py +++ b/plugins/flytekit-aws-sagemaker/setup.py @@ -4,6 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" +# s3fs 
2023.9.2 requires aiobotocore~=2.5.4 plugin_requires = ["flytekit>=1.10.0", "flyteidl>=1.10.7b0", "aioboto3==11.1.1"] __version__ = "0.0.0+develop" From 979379a9aa37cb39071ca804c027141efb2485e8 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Fri, 26 Jan 2024 22:53:08 +0530 Subject: [PATCH 014/120] update imports Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py index 68e5573aff..482021be7d 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py @@ -29,6 +29,6 @@ SagemakerInvokeEndpointTask, SagemakerModelTask, ) -from .boto3.agent import SyncBotoAgent -from .boto3.task import SyncBotoTask +from .boto3_agent import SyncBotoAgent +from .boto3_task import SyncBotoTask from .workflow import create_sagemaker_deployment, delete_sagemaker_deployment From 1760ffed6e6050624067c65c2fcdb5630fc16e93 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 30 Jan 2024 13:00:07 +0530 Subject: [PATCH 015/120] boto3 update code and add tests Signed-off-by: Samhita Alla --- .../awssagemaker/boto3_mixin.py | 48 ++++++++----- .../tests/agents/test_boto3_mixin.py | 40 ----------- .../__init__.py => test_boto3_agent.py} | 0 .../tests/test_boto3_mixin.py | 72 +++++++++++++++++++ 4 files changed, 101 insertions(+), 59 deletions(-) delete mode 100644 plugins/flytekit-aws-sagemaker/tests/agents/test_boto3_mixin.py rename plugins/flytekit-aws-sagemaker/tests/{agents/__init__.py => test_boto3_agent.py} (100%) create mode 100644 plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_mixin.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_mixin.py 
index 7a311b8fd5..af775cd717 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_mixin.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_mixin.py @@ -8,17 +8,6 @@ from flytekit.models import task as _task_model -class AttrDict(dict): - """ - This class converts a dictionary into an attribute-style lookup. It is specifically designed for - namespacing inputs and outputs, providing a convenient way to access dictionary elements using dot notation. - """ - - def __init__(self, *args, **kwargs): - super(AttrDict, self).__init__(*args, **kwargs) - self.__dict__ = self - - def update_dict_fn(original_dict: Any, update_dict: dict[str, Any]) -> Any: """ Recursively update a dictionary with values from another dictionary. @@ -33,14 +22,35 @@ def update_dict_fn(original_dict: Any, update_dict: dict[str, Any]) -> Any: if original_dict is None: return None - # If the original value is a string + # If the original value is a string and contains placeholder curly braces if isinstance(original_dict, str): - # If the string contains placeholder curly braces, replace the placeholder with the actual value if "{" in original_dict and "}" in original_dict: - try: - return original_dict.format(**update_dict) - except KeyError as e: - raise ValueError(f"Variable {e} in placeholder not found in inputs {update_dict.keys()}") + # Check if there are nested keys + if "." 
in original_dict: + # Create a copy of update_dict + update_dict_copy = update_dict.copy() + + # Fetch keys from the original_dict + keys = original_dict.strip("{}").split(".") + + # Get value from the nested dictionary + for key in keys: + try: + update_dict_copy = update_dict_copy[key] + except Exception: + raise ValueError(f"Could not find the key {key} in {update_dict_copy}.") + + return update_dict_copy + + # Retrieve the original value using the key without curly braces + original_value = update_dict.get(original_dict.strip("{}")) + + # Check if original_value exists; if so, return it, + # otherwise, raise a ValueError indicating that the value for the key original_dict could not be found. + if original_value: + return original_value + else: + raise ValueError(f"Could not find value for {original_dict}.") # If the string does not contain placeholders, return it as is return original_dict @@ -108,9 +118,9 @@ async def _call( """ args = {} if inputs: - args["inputs"] = AttrDict(literal_map_string_repr(inputs)) + args["inputs"] = literal_map_string_repr(inputs) if container: - args["container"] = AttrDict(MessageToDict(container)) + args["container"] = MessageToDict(container) updated_config = update_dict_fn(config, args) diff --git a/plugins/flytekit-aws-sagemaker/tests/agents/test_boto3_mixin.py b/plugins/flytekit-aws-sagemaker/tests/agents/test_boto3_mixin.py deleted file mode 100644 index 7ea135a5ff..0000000000 --- a/plugins/flytekit-aws-sagemaker/tests/agents/test_boto3_mixin.py +++ /dev/null @@ -1,40 +0,0 @@ -import typing -from dataclasses import dataclass - -from flytekit import FlyteContext, StructuredDataset -from flytekit.core.type_engine import TypeEngine -from flytekit.interaction.string_literals import literal_map_string_repr -from flytekit.types.file import FlyteFile -from flytekitplugins.awssagemaker.agents.boto3_mixin import update_dict - - -@dataclass -class MyData: - image: str - model_name: str - model_path: str - - -# TODO improve this test 
to actually assert outputs -def test_update_dict(): - d = update_dict( - {"a": "{a}", "b": "{b}", "c": "{c}", "d": "{d}", "e": "{e}", "f": "{f}", - "j": {"a": "{a}", "b": "{f}", "c": "{e}"}}, - {"a": 1, "b": "hello", "c": True, "d": 1.0, "e": [1, 2, 3], "f": {"a": "b"}}) - assert d == {'a': 1, 'b': 'hello', 'c': True, 'd': 1.0, 'e': [1, 2, 3], 'f': {'a': 'b'}, - 'j': {'a': 1, 'b': {'a': 'b'}, 'c': [1, 2, 3]}} - - lm = TypeEngine.dict_to_literal_map(FlyteContext.current_context(), - {"a": 1, "b": "hello", "c": True, "d": 1.0, - "e": [1, 2, 3], "f": {"a": "b"}, "g": None, - "h": FlyteFile("s3://foo/bar", remote_path=False), - "i": StructuredDataset(uri="s3://foo/bar")}, - {"a": int, "b": str, "c": bool, "d": float, "e": typing.List[int], - "f": typing.Dict[str, str], "g": typing.Optional[str], "h": FlyteFile, - "i": StructuredDataset}) - - d = literal_map_string_repr(lm) - print(d) - - print("{data.image}, {data.model_name}, {data.model_path}".format( - data=MyData(image="foo", model_name="bar", model_path="baz"))) diff --git a/plugins/flytekit-aws-sagemaker/tests/agents/__init__.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py similarity index 100% rename from plugins/flytekit-aws-sagemaker/tests/agents/__init__.py rename to plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py new file mode 100644 index 0000000000..2d08fded41 --- /dev/null +++ b/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py @@ -0,0 +1,72 @@ +import typing + +from flytekit import FlyteContext, StructuredDataset +from flytekit.core.type_engine import TypeEngine +from flytekit.interaction.string_literals import literal_map_string_repr +from flytekit.types.file import FlyteFile +from flytekitplugins.awssagemaker.boto3_mixin import update_dict_fn + + +def test_inputs(): + original_dict = { + "a": "{inputs.a}", + "b": "{inputs.b}", + "c": 
"{inputs.c}", + "d": "{inputs.d}", + "e": "{inputs.e}", + "f": "{inputs.f}", + "j": {"g": "{inputs.g}", "h": "{inputs.h}", "i": "{inputs.i}"}, + } + inputs = TypeEngine.dict_to_literal_map( + FlyteContext.current_context(), + { + "a": 1, + "b": "hello", + "c": True, + "d": 1.0, + "e": [1, 2, 3], + "f": {"a": "b"}, + "g": None, + "h": FlyteFile("s3://foo/bar", remote_path=False), + "i": StructuredDataset(uri="s3://foo/bar"), + }, + { + "a": int, + "b": str, + "c": bool, + "d": float, + "e": typing.List[int], + "f": typing.Dict[str, str], + "g": typing.Optional[str], + "h": FlyteFile, + "i": StructuredDataset, + }, + ) + + result = update_dict_fn( + original_dict=original_dict, + update_dict={"inputs": literal_map_string_repr(inputs)}, + ) + + assert result == { + "a": 1, + "b": "hello", + "c": True, + "d": 1.0, + "e": [1, 2, 3], + "f": {"a": "b"}, + "j": { + "g": None, + "h": "s3://foo/bar", + "i": "s3://foo/bar", + }, + } + + +def test_container(): + original_dict = {"a": "{container.image}"} + container = {"image": "cr.flyte.org/flyteorg/flytekit:py3.11-1.10.3"} + + result = update_dict_fn(original_dict=original_dict, update_dict={"container": container}) + + assert result == {"a": "cr.flyte.org/flyteorg/flytekit:py3.11-1.10.3"} From 78a2069ae1008f74859dcc18f8f1f1f602ad1fde Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 30 Jan 2024 16:03:30 +0530 Subject: [PATCH 016/120] remove output type Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/boto3_agent.py | 5 ++--- .../flytekitplugins/awssagemaker/boto3_task.py | 1 - 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py index 724fbe198e..7f3d21d849 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py @@ -36,7 +36,6 @@ def 
create( config = custom["config"] region = custom["region"] method = custom["method"] - output_type = custom["output_type"] boto3_object = Boto3AgentMixin(service=service, region=region) result = boto3_object._call( @@ -57,8 +56,8 @@ def create( "o0": TypeEngine.to_literal( ctx, result, - output_type, - TypeEngine.to_literal_type(output_type), + type(result), + TypeEngine.to_literal_type(type(result)), ) } ).to_flyte_idl() diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py index fb70856dd9..224b455cb3 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py @@ -47,7 +47,6 @@ def get_custom(self, settings: SerializationSettings) -> dict[str, Any]: "config": self.task_config.config, "region": self.task_config.region, "method": self.task_config.method, - "output_type": self._output_type, } s = Struct() s.update(config) From dab6780a1df87fe91c56ce2a1a93054bfcb4f93d Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 30 Jan 2024 16:12:10 +0530 Subject: [PATCH 017/120] add await Signed-off-by: Samhita Alla --- .../awssagemaker/boto3_agent.py | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py index 7f3d21d849..73273c7924 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py @@ -11,9 +11,12 @@ get_agent_secret, ) from flytekit.models.literals import LiteralMap +import asyncio from .boto3_mixin import Boto3AgentMixin +TIMEOUT_SECONDS = 20 + class SyncBotoAgent(AgentBase): """A general purpose boto3 agent that can be used to call any boto3 method 
synchronously.""" @@ -24,7 +27,7 @@ def __init__(self): asynchronous=False, ) - def create( + async def create( self, context: grpc.ServicerContext, output_prefix: str, @@ -38,14 +41,17 @@ def create( method = custom["method"] boto3_object = Boto3AgentMixin(service=service, region=region) - result = boto3_object._call( - method=method, - config=config, - container=task_template.container, - inputs=inputs, - aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), - aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), - aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), + result = await asyncio.wait_for( + boto3_object._call( + method=method, + config=config, + container=task_template.container, + inputs=inputs, + aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), + aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), + aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), + ), + timeout=TIMEOUT_SECONDS, ) outputs = None From 36850c33c7b2050746555f14edb067924921c6a4 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 30 Jan 2024 17:06:14 +0530 Subject: [PATCH 018/120] remove sync Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/agent.py | 6 +---- .../awssagemaker/boto3_agent.py | 13 ++++------ .../awssagemaker/boto3_task.py | 6 ++--- .../flytekitplugins/awssagemaker/task.py | 26 +++++++++---------- 4 files changed, 22 insertions(+), 29 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py index 521e218a38..594b7117c2 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py @@ -40,11 +40,7 @@ class SagemakerEndpointAgent(Boto3AgentMixin, AgentBase): """This agent creates an endpoint.""" def __init__(self): - super().__init__( - 
service="sagemaker", - task_type="sagemaker-endpoint", - asynchronous=True, - ) + super().__init__(service="sagemaker", task_type="sagemaker-endpoint") async def async_create( self, diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py index 73273c7924..e0f4465f34 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py @@ -18,16 +18,13 @@ TIMEOUT_SECONDS = 20 -class SyncBotoAgent(AgentBase): - """A general purpose boto3 agent that can be used to call any boto3 method synchronously.""" +class BotoAgent(AgentBase): + """A general purpose boto3 agent that can be used to call any boto3 method.""" def __init__(self): - super().__init__( - task_type="sync-boto", - asynchronous=False, - ) + super().__init__(task_type="sync-boto") - async def create( + async def async_create( self, context: grpc.ServicerContext, output_prefix: str, @@ -71,4 +68,4 @@ async def create( return CreateTaskResponse(resource=Resource(state=SUCCEEDED, outputs=outputs)) -AgentRegistry.register(SyncBotoAgent()) +AgentRegistry.register(BotoAgent()) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py index 224b455cb3..a451583126 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py @@ -12,20 +12,20 @@ @dataclass -class SyncBotoConfig(object): +class BotoConfig(object): service: str method: str config: dict[str, Any] region: str -class SyncBotoTask(AsyncAgentExecutorMixin, PythonInstanceTask[SyncBotoConfig]): +class BotoTask(AsyncAgentExecutorMixin, PythonInstanceTask[BotoConfig]): _TASK_TYPE = "sync-boto" def __init__( self, name: str, - task_config: 
SyncBotoConfig, + task_config: BotoConfig, inputs: Optional[dict[str, Type]] = None, output_type: Optional[Type] = None, container_image: Optional[Union[str, ImageSpec]] = None, diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py index 6453221319..e21f5bb82f 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py @@ -8,11 +8,11 @@ from flytekit.core.base_task import PythonTask from flytekit.core.interface import Interface from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin -from .boto3_task import SyncBotoTask, SyncBotoConfig +from .boto3_task import BotoTask, BotoConfig from flytekit import ImageSpec -class SagemakerModelTask(SyncBotoTask): +class SagemakerModelTask(BotoTask): def __init__( self, name: str, @@ -35,7 +35,7 @@ def __init__( """ super(SagemakerModelTask, self).__init__( name=name, - task_config=SyncBotoConfig(service="sagemaker", method="create_model", config=config, region=region), + task_config=BotoConfig(service="sagemaker", method="create_model", config=config, region=region), inputs=inputs, output_type=dict[str, str], container_image=container_image, @@ -43,7 +43,7 @@ def __init__( ) -class SagemakerEndpointConfigTask(SyncBotoTask): +class SagemakerEndpointConfigTask(BotoTask): def __init__( self, name: str, @@ -62,7 +62,7 @@ def __init__( """ super(SagemakerEndpointConfigTask, self).__init__( name=name, - task_config=SyncBotoConfig( + task_config=BotoConfig( service="sagemaker", method="create_endpoint_config", config=config, @@ -117,7 +117,7 @@ def get_custom(self, settings: SerializationSettings) -> dict[str, Any]: return json_format.MessageToDict(s) -class SagemakerDeleteEndpointTask(SyncBotoTask): +class SagemakerDeleteEndpointTask(BotoTask): def __init__( self, name: str, @@ -136,7 +136,7 @@ def __init__( """ 
super(SagemakerDeleteEndpointTask, self).__init__( name=name, - task_config=SyncBotoConfig( + task_config=BotoConfig( service="sagemaker", method="delete_endpoint", config=config, @@ -147,7 +147,7 @@ def __init__( ) -class SagemakerDeleteEndpointConfigTask(SyncBotoTask): +class SagemakerDeleteEndpointConfigTask(BotoTask): def __init__( self, name: str, @@ -166,7 +166,7 @@ def __init__( """ super(SagemakerDeleteEndpointConfigTask, self).__init__( name=name, - task_config=SyncBotoConfig( + task_config=BotoConfig( service="sagemaker", method="delete_endpoint_config", config=config, @@ -177,7 +177,7 @@ def __init__( ) -class SagemakerDeleteModelTask(SyncBotoTask): +class SagemakerDeleteModelTask(BotoTask): def __init__( self, name: str, @@ -196,7 +196,7 @@ def __init__( """ super(SagemakerDeleteModelTask, self).__init__( name=name, - task_config=SyncBotoConfig( + task_config=BotoConfig( service="sagemaker", method="delete_model", config=config, @@ -207,7 +207,7 @@ def __init__( ) -class SagemakerInvokeEndpointTask(SyncBotoConfig): +class SagemakerInvokeEndpointTask(BotoConfig): def __init__( self, name: str, @@ -228,7 +228,7 @@ def __init__( """ super(SagemakerInvokeEndpointTask, self).__init__( name=name, - task_config=SyncBotoConfig( + task_config=BotoConfig( service="sagemaker-runtime", method="invoke_endpoint", config=config, From 47f250e019dcf549adbf43ccd6a7fb21d73f0edb Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 30 Jan 2024 17:43:29 +0530 Subject: [PATCH 019/120] modify imports Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py index 482021be7d..e053f3dacd 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py +++ 
b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py @@ -5,8 +5,8 @@ :template: custom.rst :toctree: generated/ - SyncBotoAgent - SyncBotoTask + BotoAgent + BotoTask SagemakerModelTask SagemakerEndpointConfigTask SagemakerEndpointAgent @@ -29,6 +29,6 @@ SagemakerInvokeEndpointTask, SagemakerModelTask, ) -from .boto3_agent import SyncBotoAgent -from .boto3_task import SyncBotoTask +from .boto3_agent import BotoAgent +from .boto3_task import BotoTask from .workflow import create_sagemaker_deployment, delete_sagemaker_deployment From e601af80cc13250edefdc4eef286a900383b989e Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 30 Jan 2024 18:14:51 +0530 Subject: [PATCH 020/120] modify container logic Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/boto3_mixin.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_mixin.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_mixin.py index af775cd717..35b1244772 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_mixin.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_mixin.py @@ -1,7 +1,6 @@ from typing import Any, Optional import aioboto3 -from google.protobuf.json_format import MessageToDict from flytekit.interaction.string_literals import literal_map_string_repr from flytekit.models.literals import LiteralMap @@ -120,7 +119,7 @@ async def _call( if inputs: args["inputs"] = literal_map_string_repr(inputs) if container: - args["container"] = MessageToDict(container) + args["container"] = {"image": container.image} updated_config = update_dict_fn(config, args) From cfccdd0d355a32554348b6edb2cbb71659e9d859 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 30 Jan 2024 19:04:02 +0530 Subject: [PATCH 021/120] modify output key Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/boto3_agent.py | 2 +- 
.../flytekitplugins/awssagemaker/boto3_mixin.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py index e0f4465f34..e82ea6011e 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py @@ -56,7 +56,7 @@ async def async_create( ctx = FlyteContextManager.current_context() outputs = LiteralMap( { - "o0": TypeEngine.to_literal( + "result": TypeEngine.to_literal( ctx, result, type(result), diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_mixin.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_mixin.py index 35b1244772..917a6f8cf7 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_mixin.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_mixin.py @@ -127,7 +127,7 @@ async def _call( session = aioboto3.Session() async with session.client( service_name=self._service, - region_name=region, + region_name=self._region or region, aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, aws_session_token=aws_session_token, From 31500cef139a8eba51954d45eac7032a61f84717 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 30 Jan 2024 19:09:24 +0530 Subject: [PATCH 022/120] add default container image Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/task.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py index e21f5bb82f..7e839c72c2 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py @@ -4,7 +4,7 @@ 
from google.protobuf import json_format from google.protobuf.struct_pb2 import Struct -from flytekit.configuration import SerializationSettings +from flytekit.configuration import SerializationSettings, DefaultImages from flytekit.core.base_task import PythonTask from flytekit.core.interface import Interface from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin @@ -70,6 +70,7 @@ def __init__( ), inputs=inputs, output_type=dict[str, str], + container_image=DefaultImages.default_image(), **kwargs, ) @@ -143,6 +144,7 @@ def __init__( region=region, ), inputs=inputs, + container_image=DefaultImages.default_image(), **kwargs, ) @@ -173,6 +175,7 @@ def __init__( region=region, ), inputs=inputs, + container_image=DefaultImages.default_image(), **kwargs, ) @@ -203,6 +206,7 @@ def __init__( region=region, ), inputs=inputs, + container_image=DefaultImages.default_image(), **kwargs, ) @@ -236,5 +240,6 @@ def __init__( ), inputs=inputs, output_type=dict[str, Union[str, output_type]], + container_image=DefaultImages.default_image(), **kwargs, ) From b150887b514ebf09c440d867d1e185003c966bd3 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 30 Jan 2024 19:23:03 +0530 Subject: [PATCH 023/120] remove struct Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/boto3_task.py | 7 +------ .../flytekitplugins/awssagemaker/task.py | 14 +++++++------- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py index a451583126..33fe9b4398 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py @@ -1,8 +1,6 @@ from dataclasses import dataclass from typing import Any, Optional, Type, Union -from google.protobuf import json_format -from google.protobuf.struct_pb2 import Struct from flytekit import 
ImageSpec from flytekit.configuration import SerializationSettings @@ -42,12 +40,9 @@ def __init__( ) def get_custom(self, settings: SerializationSettings) -> dict[str, Any]: - config = { + return { "service": self.task_config.service, "config": self.task_config.config, "region": self.task_config.region, "method": self.task_config.method, } - s = Struct() - s.update(config) - return json_format.MessageToDict(s) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py index 7e839c72c2..1b06aa266f 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py @@ -17,7 +17,7 @@ def __init__( self, name: str, config: dict[str, Any], - region: Optional[str] = None, + region: Optional[str], inputs: Optional[dict[str, Type]] = None, container_image: Optional[Union[str, ImageSpec]] = None, **kwargs, @@ -48,7 +48,7 @@ def __init__( self, name: str, config: dict[str, Any], - region: Optional[str] = None, + region: Optional[str], inputs: Optional[dict[str, Type]] = None, **kwargs, ): @@ -88,7 +88,7 @@ def __init__( self, name: str, config: dict[str, Any], - region: Optional[str] = None, + region: Optional[str], inputs: Optional[dict[str, Type]] = None, **kwargs, ): @@ -123,7 +123,7 @@ def __init__( self, name: str, config: dict[str, Any], - region: Optional[str] = None, + region: Optional[str], inputs: Optional[dict[str, Type]] = None, **kwargs, ): @@ -154,7 +154,7 @@ def __init__( self, name: str, config: dict[str, Any], - region: Optional[str] = None, + region: Optional[str], inputs: Optional[dict[str, Type]] = None, **kwargs, ): @@ -185,7 +185,7 @@ def __init__( self, name: str, config: dict[str, Any], - region: Optional[str] = None, + region: Optional[str], inputs: Optional[dict[str, Type]] = None, **kwargs, ): @@ -217,7 +217,7 @@ def __init__( name: str, config: dict[str, Any], 
output_type: Type, - region: Optional[str] = None, + region: Optional[str], inputs: Optional[dict[str, Type]] = None, **kwargs, ): From 28c31fac9f9aee5f0ba8feba2da6ea9fc61742f0 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 30 Jan 2024 19:27:07 +0530 Subject: [PATCH 024/120] add region Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/agent.py | 2 ++ .../flytekitplugins/awssagemaker/task.py | 7 +------ 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py index 594b7117c2..71fcf92db5 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py @@ -72,6 +72,7 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - endpoint_status = await self._call( method="describe_endpoint", config={"EndpointName": metadata.endpoint_name}, + region=metadata.region, ) current_state = endpoint_status.get("EndpointStatus") @@ -88,6 +89,7 @@ async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes await self._call( "delete_endpoint", config={"EndpointName": metadata.endpoint_name}, + region=metadata.region, ) return DeleteTaskResponse() diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py index 1b06aa266f..a51da99d4c 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py @@ -1,8 +1,6 @@ from dataclasses import dataclass from typing import Any, Optional, Type, Union -from google.protobuf import json_format -from google.protobuf.struct_pb2 import Struct from flytekit.configuration import SerializationSettings, DefaultImages from flytekit.core.base_task import PythonTask @@ 
-112,10 +110,7 @@ def __init__( ) def get_custom(self, settings: SerializationSettings) -> dict[str, Any]: - config = {"config": self.task_config.config, "region": self.task_config.region} - s = Struct() - s.update(config) - return json_format.MessageToDict(s) + return {"config": self.task_config.config, "region": self.task_config.region} class SagemakerDeleteEndpointTask(BotoTask): From 568d3fcabe3c25f8b363ac3ab75205c755c1e423 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 30 Jan 2024 21:04:36 +0530 Subject: [PATCH 025/120] add output to gettaskresponse Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/agent.py | 23 +++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py index 71fcf92db5..4ab2b0a71e 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py @@ -18,8 +18,8 @@ AgentRegistry, ) from flytekit.models.literals import LiteralMap - - +from flytekit.core.type_engine import TypeEngine +from flytekit import FlyteContextManager from .boto3_mixin import Boto3AgentMixin @@ -76,12 +76,27 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - ) current_state = endpoint_status.get("EndpointStatus") + flyte_state = convert_to_flyte_state(states[current_state]) + message = "" if current_state == "Failed": message = endpoint_status.get("FailureReason") - flyte_state = convert_to_flyte_state(states[current_state]) - return GetTaskResponse(resource=Resource(state=flyte_state, message=message)) + res = None + if current_state == "Success": + ctx = FlyteContextManager.current_context() + res = LiteralMap( + { + "result": TypeEngine.to_literal( + ctx, + endpoint_status, + dict, + TypeEngine.to_literal_type(dict), + ) + } + ) + + return 
GetTaskResponse(resource=Resource(state=flyte_state, outputs=res, message=message)) async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) From 06b5002e671f2575375b78bac5e88e5e2b9667ea Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 30 Jan 2024 22:04:31 +0530 Subject: [PATCH 026/120] convert to dict to str Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/agent.py | 17 +++++++++++++---- .../flytekitplugins/awssagemaker/boto3_agent.py | 5 +++-- .../flytekitplugins/awssagemaker/boto3_task.py | 1 + .../flytekitplugins/awssagemaker/task.py | 2 +- 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py index 4ab2b0a71e..5f0f2e6e14 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py @@ -22,6 +22,7 @@ from flytekit import FlyteContextManager from .boto3_mixin import Boto3AgentMixin +from datetime import datetime states = { "Creating": "Running", @@ -36,6 +37,14 @@ class Metadata: region: str +class DateTimeEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, datetime): + return o.isoformat() + + return json.JSONEncoder.default(self, o) + + class SagemakerEndpointAgent(Boto3AgentMixin, AgentBase): """This agent creates an endpoint.""" @@ -83,15 +92,15 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - message = endpoint_status.get("FailureReason") res = None - if current_state == "Success": + if current_state == "InService": ctx = FlyteContextManager.current_context() res = LiteralMap( { "result": TypeEngine.to_literal( ctx, - endpoint_status, - dict, - TypeEngine.to_literal_type(dict), + json.dumps(endpoint_status, cls=DateTimeEncoder), + 
str, + TypeEngine.to_literal_type(str), ) } ) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py index e82ea6011e..6ba223e06f 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py @@ -36,6 +36,7 @@ async def async_create( config = custom["config"] region = custom["region"] method = custom["method"] + output_type = custom["output_type"] boto3_object = Boto3AgentMixin(service=service, region=region) result = await asyncio.wait_for( @@ -59,8 +60,8 @@ async def async_create( "result": TypeEngine.to_literal( ctx, result, - type(result), - TypeEngine.to_literal_type(type(result)), + output_type, + TypeEngine.to_literal_type(output_type), ) } ).to_flyte_idl() diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py index 33fe9b4398..27728dd8d9 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py @@ -45,4 +45,5 @@ def get_custom(self, settings: SerializationSettings) -> dict[str, Any]: "config": self.task_config.config, "region": self.task_config.region, "method": self.task_config.method, + "output_type": self._output_type, } diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py index a51da99d4c..5d38161683 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py @@ -105,7 +105,7 @@ def __init__( region=region, ), task_type=self._TASK_TYPE, - interface=Interface(inputs=inputs, outputs={"result": dict[str, str]}), + 
interface=Interface(inputs=inputs, outputs={"result": str}), **kwargs, ) From f5931b7245725ee8f37f1615dcc887332bc7c8e9 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 30 Jan 2024 22:20:16 +0530 Subject: [PATCH 027/120] revert Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/agent.py | 18 ++++++------------ .../flytekitplugins/awssagemaker/task.py | 2 +- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py index 5f0f2e6e14..0356614c03 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py @@ -22,7 +22,6 @@ from flytekit import FlyteContextManager from .boto3_mixin import Boto3AgentMixin -from datetime import datetime states = { "Creating": "Running", @@ -37,14 +36,6 @@ class Metadata: region: str -class DateTimeEncoder(json.JSONEncoder): - def default(self, o): - if isinstance(o, datetime): - return o.isoformat() - - return json.JSONEncoder.default(self, o) - - class SagemakerEndpointAgent(Boto3AgentMixin, AgentBase): """This agent creates an endpoint.""" @@ -98,9 +89,12 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - { "result": TypeEngine.to_literal( ctx, - json.dumps(endpoint_status, cls=DateTimeEncoder), - str, - TypeEngine.to_literal_type(str), + { + "EndpointName": endpoint_status.get("EndpointName"), + "EndpointArn": endpoint_status.get("EndpointArn"), + }, + dict, + TypeEngine.to_literal_type(dict), ) } ) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py index 5d38161683..1e06639898 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py @@ -105,7 +105,7 @@ 
def __init__( region=region, ), task_type=self._TASK_TYPE, - interface=Interface(inputs=inputs, outputs={"result": str}), + interface=Interface(inputs=inputs, outputs={"result": dict}), **kwargs, ) From 9fec7f422def2fa86bae3467afc56442999ff287 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 30 Jan 2024 23:04:41 +0530 Subject: [PATCH 028/120] remove timeout and add creds to boto3 calls Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/agent.py | 6 +++++ .../awssagemaker/boto3_agent.py | 22 +++++++------------ 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py index 0356614c03..3e1c44779c 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py @@ -73,6 +73,9 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - method="describe_endpoint", config={"EndpointName": metadata.endpoint_name}, region=metadata.region, + aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), + aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), + aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), ) current_state = endpoint_status.get("EndpointStatus") @@ -108,6 +111,9 @@ async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes "delete_endpoint", config={"EndpointName": metadata.endpoint_name}, region=metadata.region, + aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), + aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), + aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), ) return DeleteTaskResponse() diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py 
index 6ba223e06f..399f667ae9 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py @@ -11,12 +11,9 @@ get_agent_secret, ) from flytekit.models.literals import LiteralMap -import asyncio from .boto3_mixin import Boto3AgentMixin -TIMEOUT_SECONDS = 20 - class BotoAgent(AgentBase): """A general purpose boto3 agent that can be used to call any boto3 method.""" @@ -39,17 +36,14 @@ async def async_create( output_type = custom["output_type"] boto3_object = Boto3AgentMixin(service=service, region=region) - result = await asyncio.wait_for( - boto3_object._call( - method=method, - config=config, - container=task_template.container, - inputs=inputs, - aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), - aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), - aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), - ), - timeout=TIMEOUT_SECONDS, + result = await boto3_object._call( + method=method, + config=config, + container=task_template.container, + inputs=inputs, + aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), + aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), + aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), ) outputs = None From a548fe3cf713eab3015e607c16ed21738da7d9de Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 30 Jan 2024 23:31:57 +0530 Subject: [PATCH 029/120] add to_flyte_idl Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/agent.py | 22 ++++++++++++------- .../flytekitplugins/awssagemaker/task.py | 2 +- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py index 3e1c44779c..f063a17394 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py +++ 
b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py @@ -1,6 +1,7 @@ import json from dataclasses import asdict, dataclass from typing import Optional +from datetime import datetime import grpc from flyteidl.admin.agent_pb2 import ( @@ -30,6 +31,14 @@ } +class DateTimeEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, datetime): + return o.isoformat() + + return json.JSONEncoder.default(self, o) + + @dataclass class Metadata: endpoint_name: str @@ -81,7 +90,7 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - current_state = endpoint_status.get("EndpointStatus") flyte_state = convert_to_flyte_state(states[current_state]) - message = "" + message = None if current_state == "Failed": message = endpoint_status.get("FailureReason") @@ -92,15 +101,12 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - { "result": TypeEngine.to_literal( ctx, - { - "EndpointName": endpoint_status.get("EndpointName"), - "EndpointArn": endpoint_status.get("EndpointArn"), - }, - dict, - TypeEngine.to_literal_type(dict), + json.dumps(endpoint_status, cls=DateTimeEncoder), + str, + TypeEngine.to_literal_type(str), ) } - ) + ).to_flyte_idl() return GetTaskResponse(resource=Resource(state=flyte_state, outputs=res, message=message)) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py index 1e06639898..5d38161683 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py @@ -105,7 +105,7 @@ def __init__( region=region, ), task_type=self._TASK_TYPE, - interface=Interface(inputs=inputs, outputs={"result": dict}), + interface=Interface(inputs=inputs, outputs={"result": str}), **kwargs, ) From d84e4672cb54fd5499e7616d5201392c77210dff Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 30 Jan 2024 23:38:33 
+0530 Subject: [PATCH 030/120] subclass fix Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/task.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py index 5d38161683..53bb71efdd 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py @@ -206,13 +206,13 @@ def __init__( ) -class SagemakerInvokeEndpointTask(BotoConfig): +class SagemakerInvokeEndpointTask(BotoTask): def __init__( self, name: str, config: dict[str, Any], - output_type: Type, region: Optional[str], + output_type: Optional[Type] = None, inputs: Optional[dict[str, Type]] = None, **kwargs, ): From 97006c40b2b617b3f98031ebc94afb0505cb436d Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 31 Jan 2024 15:54:10 +0530 Subject: [PATCH 031/120] invoke endpoint async Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/task.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py index 53bb71efdd..4a8c261556 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py @@ -212,7 +212,6 @@ def __init__( name: str, config: dict[str, Any], region: Optional[str], - output_type: Optional[Type] = None, inputs: Optional[dict[str, Type]] = None, **kwargs, ): @@ -229,12 +228,12 @@ def __init__( name=name, task_config=BotoConfig( service="sagemaker-runtime", - method="invoke_endpoint", + method="invoke_endpoint_async", config=config, region=region, ), inputs=inputs, - output_type=dict[str, Union[str, output_type]], + output_type=dict[str, str], container_image=DefaultImages.default_image(), 
**kwargs, ) From ceab2dc968a622716691ac644d5ccd17293eb61d Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 31 Jan 2024 22:13:09 +0530 Subject: [PATCH 032/120] remove output type Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/boto3_agent.py | 5 ++--- .../flytekitplugins/awssagemaker/boto3_task.py | 5 +---- .../flytekitplugins/awssagemaker/task.py | 4 ---- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py index 399f667ae9..09295affb0 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py @@ -33,7 +33,6 @@ async def async_create( config = custom["config"] region = custom["region"] method = custom["method"] - output_type = custom["output_type"] boto3_object = Boto3AgentMixin(service=service, region=region) result = await boto3_object._call( @@ -54,8 +53,8 @@ async def async_create( "result": TypeEngine.to_literal( ctx, result, - output_type, - TypeEngine.to_literal_type(output_type), + dict, + TypeEngine.to_literal_type(dict), ) } ).to_flyte_idl() diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py index 27728dd8d9..e4d931eeca 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py @@ -25,16 +25,14 @@ def __init__( name: str, task_config: BotoConfig, inputs: Optional[dict[str, Type]] = None, - output_type: Optional[Type] = None, container_image: Optional[Union[str, ImageSpec]] = None, **kwargs, ): - self._output_type = output_type super().__init__( name=name, task_config=task_config, task_type=self._TASK_TYPE, - interface=Interface(inputs=inputs, 
outputs={"result": output_type}), + interface=Interface(inputs=inputs, outputs={"result": dict}), container_image=container_image, **kwargs, ) @@ -45,5 +43,4 @@ def get_custom(self, settings: SerializationSettings) -> dict[str, Any]: "config": self.task_config.config, "region": self.task_config.region, "method": self.task_config.method, - "output_type": self._output_type, } diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py index 4a8c261556..13cbf2a1e0 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py @@ -35,7 +35,6 @@ def __init__( name=name, task_config=BotoConfig(service="sagemaker", method="create_model", config=config, region=region), inputs=inputs, - output_type=dict[str, str], container_image=container_image, **kwargs, ) @@ -67,7 +66,6 @@ def __init__( region=region, ), inputs=inputs, - output_type=dict[str, str], container_image=DefaultImages.default_image(), **kwargs, ) @@ -220,7 +218,6 @@ def __init__( :param name: The name of the task. :param config: The configuration to be provided to the boto3 API call. - :param output_type: The type of output. :param region: The region for the boto3 client. :param inputs: The input literal map to be used for updating the configuration. 
""" @@ -233,7 +230,6 @@ def __init__( region=region, ), inputs=inputs, - output_type=dict[str, str], container_image=DefaultImages.default_image(), **kwargs, ) From afed76a25ab0b4035a8bcf8151d8d6a698667daa Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 31 Jan 2024 22:46:08 +0530 Subject: [PATCH 033/120] modify create sagemaker deployment code Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/workflow.py | 30 +++++++++---------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py index e36a9e91b9..cd7ebc4d1d 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py @@ -48,21 +48,19 @@ def create_sagemaker_deployment( ) wf = Workflow(name=f"sagemaker-deploy-{model_name}") - wf.add_workflow_input("model_inputs", Optional[dict]) - wf.add_workflow_input("endpoint_config_inputs", Optional[dict]) - wf.add_workflow_input("endpoint_inputs", Optional[dict]) - wf.add_entity( - sagemaker_model_task, - **wf.inputs["model_inputs"], - ) - - wf.add_entity( - endpoint_config_task, - **wf.inputs["endpoint_config_inputs"], - ) + inputs = { + sagemaker_model_task: model_input_types, + endpoint_config_task: endpoint_config_input_types, + endpoint_task: endpoint_input_types, + } - wf.add_entity(endpoint_task, **wf.inputs["endpoint_inputs"]) + for key, value in inputs.items(): + input_dict = {} + for param, type in value: + wf.add_workflow_input(param, type) + input_dict[param] = wf.inputs[param] + wf.add_entity(key, **input_dict) return wf @@ -99,15 +97,15 @@ def delete_sagemaker_deployment(name: str, region: Optional[str] = None): wf.add_entity( sagemaker_delete_endpoint, - **{"endpoint_name": wf.inputs["endpoint_name"]}, + endpoint_name=wf.inputs["endpoint_name"], ) wf.add_entity( 
sagemaker_delete_endpoint_config, - **{"endpoint_config_name": wf.inputs["endpoint_config_name"]}, + endpoint_config_name=wf.inputs["endpoint_config_name"], ) wf.add_entity( sagemaker_delete_model, - **{"model_name", wf.inputs["model_name"]}, + model_name=wf.inputs["model_name"], ) return wf From 29f875d2a8a1d6962993259ad95f9dfb764a75b7 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 31 Jan 2024 22:49:30 +0530 Subject: [PATCH 034/120] dict loop Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/workflow.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py index cd7ebc4d1d..a120c3043d 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py @@ -57,9 +57,10 @@ def create_sagemaker_deployment( for key, value in inputs.items(): input_dict = {} - for param, type in value: - wf.add_workflow_input(param, type) - input_dict[param] = wf.inputs[param] + if isinstance(value, dict): + for param, t in value.items(): + wf.add_workflow_input(param, t) + input_dict[param] = wf.inputs[param] wf.add_entity(key, **input_dict) return wf From 3452ca99109510f7ab673a7f70b7b36b3c71e448 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 1 Feb 2024 15:26:46 +0530 Subject: [PATCH 035/120] add wf output Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/workflow.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py index a120c3043d..45afaa0081 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py @@ -55,14 +55,16 @@ 
def create_sagemaker_deployment( endpoint_task: endpoint_input_types, } + nodes = [] for key, value in inputs.items(): input_dict = {} if isinstance(value, dict): for param, t in value.items(): wf.add_workflow_input(param, t) input_dict[param] = wf.inputs[param] - wf.add_entity(key, **input_dict) + nodes.append(wf.add_entity(key, **input_dict)) + wf.add_workflow_output("wf_output", nodes[2].outputs["result"], str) return wf From b78b69c0f7a693ca7ce40c9571b350393241bd1f Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Sat, 3 Feb 2024 22:31:42 +0530 Subject: [PATCH 036/120] set lhs to an empty string for pythoninstancetask & modify param name in create deployment task Signed-off-by: Samhita Alla --- flytekit/core/tracker.py | 6 +++++- .../flytekitplugins/awssagemaker/workflow.py | 10 +++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py index 240768cc82..45f0fbc7a2 100644 --- a/flytekit/core/tracker.py +++ b/flytekit/core/tracker.py @@ -10,6 +10,7 @@ from flytekit.configuration.feature_flags import FeatureFlags from flytekit.exceptions import system as _system_exceptions from flytekit.loggers import logger +from flytekit.core.python_function_task import PythonInstanceTask def import_module_from_file(module_name, file): @@ -316,7 +317,10 @@ def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str, elif f.instantiated_in: mod = importlib.import_module(f.instantiated_in) mod_name = mod.__name__ - name = f.lhs + if isinstance(f, PythonInstanceTask): + name = "" + else: + name = f.lhs else: mod, mod_name, name = _task_module_from_callable(f) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py index 45afaa0081..6cac22c598 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py +++ 
b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py @@ -12,7 +12,7 @@ def create_sagemaker_deployment( - model_name: str, + name: str, model_config: dict[str, Any], endpoint_config_config: dict[str, Any], endpoint_config: dict[str, Any], @@ -26,7 +26,7 @@ def create_sagemaker_deployment( Creates Sagemaker model, endpoint config and endpoint. """ sagemaker_model_task = SagemakerModelTask( - name=f"sagemaker-model-{model_name}", + name=f"sagemaker-model-{name}", config=model_config, region=region, inputs=model_input_types, @@ -34,20 +34,20 @@ def create_sagemaker_deployment( ) endpoint_config_task = SagemakerEndpointConfigTask( - name=f"sagemaker-endpoint-config-{model_name}", + name=f"sagemaker-endpoint-config-{name}", config=endpoint_config_config, region=region, inputs=endpoint_config_input_types, ) endpoint_task = SagemakerEndpointTask( - name=f"sagemaker-endpoint-{model_name}", + name=f"sagemaker-endpoint-{name}", config=endpoint_config, region=region, inputs=endpoint_input_types, ) - wf = Workflow(name=f"sagemaker-deploy-{model_name}") + wf = Workflow(name=f"sagemaker-deploy-{name}") inputs = { sagemaker_model_task: model_input_types, From 6a66480a6c412a16a8e11b03255f69d6fb2ad8db Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Sat, 3 Feb 2024 23:15:58 +0530 Subject: [PATCH 037/120] update tracker and delete deployment workflow Signed-off-by: Samhita Alla --- flytekit/core/tracker.py | 13 +++++++++++-- .../flytekitplugins/awssagemaker/workflow.py | 6 +++--- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py index 45f0fbc7a2..f8c7f8c17a 100644 --- a/flytekit/core/tracker.py +++ b/flytekit/core/tracker.py @@ -10,7 +10,6 @@ from flytekit.configuration.feature_flags import FeatureFlags from flytekit.exceptions import system as _system_exceptions from flytekit.loggers import logger -from flytekit.core.python_function_task import PythonInstanceTask def 
import_module_from_file(module_name, file): @@ -304,6 +303,16 @@ def _task_module_from_callable(f: Callable): return mod, mod_name, name +def isPythonInstance(obj): + for cls in inspect.getmro(type(obj)): + try: + if cls.__name__ == "PythonInstanceTask": + return True + except Exception: + pass + return False + + def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str, str, str]: """ Returns the task-name, absolute module and the string name of the callable. @@ -317,7 +326,7 @@ def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str, elif f.instantiated_in: mod = importlib.import_module(f.instantiated_in) mod_name = mod.__name__ - if isinstance(f, PythonInstanceTask): + if isPythonInstance(f): name = "" else: name = f.lhs diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py index 6cac22c598..cdbafb759a 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py @@ -74,21 +74,21 @@ def delete_sagemaker_deployment(name: str, region: Optional[str] = None): """ sagemaker_delete_endpoint = SagemakerDeleteEndpointTask( name=f"sagemaker-delete-endpoint-{name}", - config={"EndpointName": "{endpoint_name}"}, + config={"EndpointName": "{inputs.endpoint_name}"}, region=region, inputs=kwtypes(endpoint_name=str), ) sagemaker_delete_endpoint_config = SagemakerDeleteEndpointConfigTask( name=f"sagemaker-delete-endpoint-config-{name}", - config={"EndpointConfigName": "{endpoint_config_name}"}, + config={"EndpointConfigName": "{inputs.endpoint_config_name}"}, region=region, inputs=kwtypes(endpoint_config_name=str), ) sagemaker_delete_model = SagemakerDeleteModelTask( name=f"sagemaker-delete-model-{name}", - config={"ModelName": "{model_name}"}, + config={"ModelName": "{inputs.model_name}"}, region=region, 
inputs=kwtypes(model_name=str), ) From 6f6839fc8ac1101dba32137c6f7ed54084d39f7c Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Sat, 3 Feb 2024 23:19:57 +0530 Subject: [PATCH 038/120] instance to instancetask Signed-off-by: Samhita Alla --- flytekit/core/tracker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py index f8c7f8c17a..beedb086ea 100644 --- a/flytekit/core/tracker.py +++ b/flytekit/core/tracker.py @@ -303,7 +303,7 @@ def _task_module_from_callable(f: Callable): return mod, mod_name, name -def isPythonInstance(obj): +def isPythonInstanceTask(obj): for cls in inspect.getmro(type(obj)): try: if cls.__name__ == "PythonInstanceTask": @@ -326,7 +326,7 @@ def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str, elif f.instantiated_in: mod = importlib.import_module(f.instantiated_in) mod_name = mod.__name__ - if isPythonInstance(f): + if isPythonInstanceTask(f): name = "" else: name = f.lhs From 8625e32cc8252bed2afc914bbdd222276f902a13 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 20 Feb 2024 15:00:56 +0530 Subject: [PATCH 039/120] add tests Signed-off-by: Samhita Alla --- plugins/flytekit-aws-sagemaker/README.md | 58 +++++- .../dev-requirements.txt | 2 + .../flytekitplugins/awssagemaker/__init__.py | 2 +- .../flytekitplugins/awssagemaker/agent.py | 4 +- .../awssagemaker/boto3_agent.py | 3 +- .../awssagemaker/boto3_task.py | 5 +- .../flytekitplugins/awssagemaker/task.py | 7 +- plugins/flytekit-aws-sagemaker/setup.py | 2 +- .../tests/test_agent.py | 120 ++++++++++++ .../tests/test_boto3_agent.py | 99 ++++++++++ .../tests/test_boto3_task.py | 52 ++++++ .../flytekit-aws-sagemaker/tests/test_task.py | 174 ++++++++++++++++++ .../tests/test_workflow.py | 55 ++++++ 13 files changed, 573 insertions(+), 10 deletions(-) create mode 100644 plugins/flytekit-aws-sagemaker/dev-requirements.txt create mode 100644 plugins/flytekit-aws-sagemaker/tests/test_agent.py create 
mode 100644 plugins/flytekit-aws-sagemaker/tests/test_boto3_task.py create mode 100644 plugins/flytekit-aws-sagemaker/tests/test_task.py create mode 100644 plugins/flytekit-aws-sagemaker/tests/test_workflow.py diff --git a/plugins/flytekit-aws-sagemaker/README.md b/plugins/flytekit-aws-sagemaker/README.md index 33cd38afef..37cf1f10f2 100644 --- a/plugins/flytekit-aws-sagemaker/README.md +++ b/plugins/flytekit-aws-sagemaker/README.md @@ -1,9 +1,65 @@ # AWS Sagemaker Plugin -The plugin includes a deployment agent that allows you to deploy Sagemaker models, create and inkoke endpoints for inference. +The plugin features a deployment agent enabling you to deploy SageMaker models, create and trigger inference endpoints. +Additionally, you can entirely remove the SageMaker deployment using the `delete_sagemaker_deployment` workflow. To install the plugin, run the following command: ```bash pip install flytekitplugins-awssagemaker ``` + +Here is a sample SageMaker deployment workflow: + +```python +REGION = os.getenv("REGION") +MODEL_NAME = "sagemaker-xgboost" +ENDPOINT_CONFIG_NAME = "sagemaker-xgboost-endpoint-config" +ENDPOINT_NAME = "sagemaker-xgboost-endpoint" + +sagemaker_deployment_wf = create_sagemaker_deployment( + name="sagemaker-deployment", + model_input_types=kwtypes(model_path=str, execution_role_arn=str), + model_config={ + "ModelName": MODEL_NAME, + "PrimaryContainer": { + "Image": "{container.image}", + "ModelDataUrl": "{inputs.model_path}", + }, + "ExecutionRoleArn": "{inputs.execution_role_arn}", + }, + endpoint_config_input_types=kwtypes(instance_type=str), + endpoint_config_config={ + "EndpointConfigName": ENDPOINT_CONFIG_NAME, + "ProductionVariants": [ + { + "VariantName": "variant-name-1", + "ModelName": MODEL_NAME, + "InitialInstanceCount": 1, + "InstanceType": "{inputs.instance_type}", + }, + ], + "AsyncInferenceConfig": { + "OutputConfig": {"S3OutputPath": os.getenv("S3_OUTPUT_PATH")} + }, + }, + endpoint_config={ + "EndpointName": ENDPOINT_NAME, 
+ "EndpointConfigName": ENDPOINT_CONFIG_NAME, + }, + container_image=custom_image, + region=REGION, +) + + +@workflow +def model_deployment_workflow( + model_path: str = os.getenv("MODEL_DATA_URL"), + execution_role_arn: str = os.getenv("EXECUTION_ROLE_ARN"), +) -> str: + return sagemaker_deployment_wf( + model_path=model_path, + execution_role_arn=execution_role_arn, + instance_type="ml.m4.xlarge", + ) +``` diff --git a/plugins/flytekit-aws-sagemaker/dev-requirements.txt b/plugins/flytekit-aws-sagemaker/dev-requirements.txt new file mode 100644 index 0000000000..c63a6da9cb --- /dev/null +++ b/plugins/flytekit-aws-sagemaker/dev-requirements.txt @@ -0,0 +1,2 @@ +pytest-asyncio +pytest-mock diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py index e053f3dacd..e3c2e33ba7 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py @@ -30,5 +30,5 @@ SagemakerModelTask, ) from .boto3_agent import BotoAgent -from .boto3_task import BotoTask +from .boto3_task import BotoConfig, BotoTask from .workflow import create_sagemaker_deployment, delete_sagemaker_deployment diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py index f063a17394..b28ba57029 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py @@ -14,7 +14,7 @@ from flytekit.extend.backend.base_agent import ( AgentBase, - convert_to_flyte_state, + convert_to_flyte_phase, get_agent_secret, AgentRegistry, ) @@ -88,7 +88,7 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - ) current_state = endpoint_status.get("EndpointStatus") - flyte_state = 
convert_to_flyte_state(states[current_state]) + flyte_state = convert_to_flyte_phase(states[current_state]) message = None if current_state == "Failed": diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py index 09295affb0..d7f2bf2a0c 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py @@ -19,7 +19,7 @@ class BotoAgent(AgentBase): """A general purpose boto3 agent that can be used to call any boto3 method.""" def __init__(self): - super().__init__(task_type="sync-boto") + super().__init__(task_type="boto") async def async_create( self, @@ -35,6 +35,7 @@ async def async_create( method = custom["method"] boto3_object = Boto3AgentMixin(service=service, region=region) + result = await boto3_object._call( method=method, config=config, diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py index e4d931eeca..750ec3461d 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py @@ -18,13 +18,14 @@ class BotoConfig(object): class BotoTask(AsyncAgentExecutorMixin, PythonInstanceTask[BotoConfig]): - _TASK_TYPE = "sync-boto" + _TASK_TYPE = "boto" def __init__( self, name: str, task_config: BotoConfig, inputs: Optional[dict[str, Type]] = None, + outputs: Optional[dict[str, Type]] = None, container_image: Optional[Union[str, ImageSpec]] = None, **kwargs, ): @@ -32,7 +33,7 @@ def __init__( name=name, task_config=task_config, task_type=self._TASK_TYPE, - interface=Interface(inputs=inputs, outputs={"result": dict}), + interface=Interface(inputs=inputs, outputs=outputs), container_image=container_image, **kwargs, ) diff --git 
a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py index 13cbf2a1e0..2724a4f8db 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py @@ -7,7 +7,7 @@ from flytekit.core.interface import Interface from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin from .boto3_task import BotoTask, BotoConfig -from flytekit import ImageSpec +from flytekit import ImageSpec, kwtypes class SagemakerModelTask(BotoTask): @@ -35,6 +35,7 @@ def __init__( name=name, task_config=BotoConfig(service="sagemaker", method="create_model", config=config, region=region), inputs=inputs, + outputs=kwtypes(result=dict), container_image=container_image, **kwargs, ) @@ -66,6 +67,7 @@ def __init__( region=region, ), inputs=inputs, + outputs=kwtypes(result=dict), container_image=DefaultImages.default_image(), **kwargs, ) @@ -103,7 +105,7 @@ def __init__( region=region, ), task_type=self._TASK_TYPE, - interface=Interface(inputs=inputs, outputs={"result": str}), + interface=Interface(inputs=inputs, outputs=kwtypes(result=str)), **kwargs, ) @@ -230,6 +232,7 @@ def __init__( region=region, ), inputs=inputs, + outputs=kwtypes(result=dict), container_image=DefaultImages.default_image(), **kwargs, ) diff --git a/plugins/flytekit-aws-sagemaker/setup.py b/plugins/flytekit-aws-sagemaker/setup.py index a328177c0d..8f7d4ee930 100644 --- a/plugins/flytekit-aws-sagemaker/setup.py +++ b/plugins/flytekit-aws-sagemaker/setup.py @@ -5,7 +5,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" # s3fs 2023.9.2 requires aiobotocore~=2.5.4 -plugin_requires = ["flytekit>=1.10.0", "flyteidl>=1.10.7b0", "aioboto3==11.1.1"] +plugin_requires = ["flytekit>=1.10.0", "flyteidl", "aioboto3==11.1.1"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-aws-sagemaker/tests/test_agent.py 
b/plugins/flytekit-aws-sagemaker/tests/test_agent.py new file mode 100644 index 0000000000..bbcdefb7f1 --- /dev/null +++ b/plugins/flytekit-aws-sagemaker/tests/test_agent.py @@ -0,0 +1,120 @@ +from datetime import timedelta +from unittest import mock + +import pytest +import json +from dataclasses import asdict + +from flytekit import FlyteContext, FlyteContextManager +from flytekit.extend.backend.base_agent import AgentRegistry +from flytekit.interfaces.cli_identifiers import Identifier +from flytekit.models import literals +from flytekit.models.core.identifier import ResourceType +from flytekit.models.task import RuntimeMetadata, TaskMetadata, TaskTemplate +from flytekitplugins.awssagemaker.agent import Metadata +from flyteidl.admin.agent_pb2 import RUNNING, DeleteTaskResponse + + +@pytest.mark.asyncio +@mock.patch( + "flytekitplugins.awssagemaker.agent.get_agent_secret", + return_value="mocked_secret", +) +@mock.patch( + "flytekitplugins.awssagemaker.agent.Boto3AgentMixin._call", + return_value={ + "EndpointName": "sagemaker-xgboost-endpoint", + "EndpointArn": "arn:aws:sagemaker:us-east-2:1234567890:endpoint/sagemaker-xgboost-endpoint", + "EndpointConfigName": "sagemaker-xgboost-endpoint-config", + "ProductionVariants": [ + { + "VariantName": "variant-name-1", + "DeployedImages": [ + { + "SpecifiedImage": "1234567890.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost:iL3_jIEY3lQPB4wnlS7HKA..", + "ResolvedImage": "1234567890.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost@sha256:0725042bf15f384c46e93bbf7b22c0502859981fc8830fd3aea4127469e8cf1e", + "ResolutionTime": "2024-01-31T22:14:07.193000+05:30", + } + ], + "CurrentWeight": 1.0, + "DesiredWeight": 1.0, + "CurrentInstanceCount": 1, + "DesiredInstanceCount": 1, + } + ], + "EndpointStatus": "InService", + "CreationTime": "2024-01-31T22:14:06.553000+05:30", + "LastModifiedTime": "2024-01-31T22:16:37.001000+05:30", + "AsyncInferenceConfig": { + "OutputConfig": {"S3OutputPath": 
"s3://sagemaker-agent-xgboost/inference-output/output"} + }, + "ResponseMetadata": { + "RequestId": "50d8bfa7-d84-4bd9-bf11-992832f42793", + "HTTPStatusCode": 200, + "HTTPHeaders": { + "x-amzn-requestid": "50d8bfa7-d840-4bd9-bf11-992832f42793", + "content-type": "application/x-amz-json-1.1", + "content-length": "865", + "date": "Wed, 31 Jan 2024 16:46:38 GMT", + }, + "RetryAttempts": 0, + }, + }, +) +async def test_agent(mock_boto_call, mock_secret): + ctx = FlyteContextManager.current_context() + agent = AgentRegistry.get_agent("sagemaker-endpoint") + task_id = Identifier( + resource_type=ResourceType.TASK, + project="project", + domain="domain", + name="name", + version="version", + ) + task_config = { + "service": "sagemaker", + "config": { + "EndpointName": "sagemaker-endpoint", + "EndpointConfigName": "endpoint-config-name", + }, + "region": "us-east-2", + "method": "create_endpoint", + } + task_metadata = TaskMetadata( + discoverable=True, + runtime=RuntimeMetadata(RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timeout=timedelta(days=1), + retries=literals.RetryStrategy(3), + interruptible=True, + discovery_version="0.1.1b0", + deprecated_error_message="This is deprecated!", + cache_serializable=True, + pod_template_name="A", + ) + + task_template = TaskTemplate( + id=task_id, + custom=task_config, + metadata=task_metadata, + interface=None, + type="sagemaker-endpoint", + ) + output_prefix = FlyteContext.current_context().file_access.get_random_local_directory() + + # CREATE + response = await agent.async_create(ctx, output_prefix, task_template) + + metadata = Metadata(endpoint_name="sagemaker-endpoint", region="us-east-2") + metadata_bytes = json.dumps(asdict(metadata)).encode("utf-8") + assert response.resource_meta == metadata_bytes + + # GET + response = await agent.async_get(ctx, metadata_bytes) + assert response.resource.state == RUNNING + from_json = json.loads(response.resource.outputs.literals["result"].scalar.primitive.string_value) + 
assert from_json["EndpointName"] == "sagemaker-xgboost-endpoint" + assert from_json["EndpointArn"] == "arn:aws:sagemaker:us-east-2:1234567890:endpoint/sagemaker-xgboost-endpoint" + + # DELETE + delete_response = await agent.async_delete(ctx, metadata_bytes) + assert isinstance(delete_response, DeleteTaskResponse) diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py index e69de29bb2..c9a7fb8c51 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py @@ -0,0 +1,99 @@ +from datetime import timedelta +from unittest import mock + +import pytest +from flyteidl.admin.agent_pb2 import SUCCEEDED + +from flytekit import FlyteContext, FlyteContextManager +from flytekit.extend.backend.base_agent import AgentRegistry +from flytekit.interfaces.cli_identifiers import Identifier +from flytekit.models import literals +from flytekit.models.core.identifier import ResourceType +from flytekit.models.task import RuntimeMetadata, TaskMetadata, TaskTemplate + + +@pytest.mark.asyncio +@mock.patch( + "flytekitplugins.awssagemaker.boto3_agent.get_agent_secret", + return_value="mocked_secret", +) +@mock.patch( + "flytekitplugins.awssagemaker.boto3_agent.Boto3AgentMixin._call", + return_value={ + "ResponseMetadata": { + "RequestId": "66f80391-348a-4ee0-9158-508914d16db2", + "HTTPStatusCode": 200.0, + "RetryAttempts": 0.0, + "HTTPHeaders": { + "content-type": "application/x-amz-json-1.1", + "date": "Wed, 31 Jan 2024 16:43:52 GMT", + "x-amzn-requestid": "66f80391-348a-4ee0-9158-508914d16db2", + "content-length": "114", + }, + }, + "EndpointConfigArn": "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config", + }, +) +async def test_agent(mock_boto_call, mock_secret): + ctx = FlyteContextManager.current_context() + agent = AgentRegistry.get_agent("boto") + task_id = Identifier( + resource_type=ResourceType.TASK, + 
project="project", + domain="domain", + name="name", + version="version", + ) + task_config = { + "service": "sagemaker", + "config": { + "EndpointConfigName": "endpoint-config-name", + "ProductionVariants": [ + { + "VariantName": "variant-name-1", + "ModelName": "{inputs.model_name}", + "InitialInstanceCount": 1, + "InstanceType": "ml.m4.xlarge", + }, + ], + "AsyncInferenceConfig": {"OutputConfig": {"S3OutputPath": "{inputs.s3_output_path}"}}, + }, + "region": "us-east-2", + "method": "create_endpoint_config", + } + task_metadata = TaskMetadata( + discoverable=True, + runtime=RuntimeMetadata(RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timeout=timedelta(days=1), + retries=literals.RetryStrategy(3), + interruptible=True, + discovery_version="0.1.1b0", + deprecated_error_message="This is deprecated!", + cache_serializable=True, + pod_template_name="A", + ) + + task_template = TaskTemplate( + id=task_id, + custom=task_config, + metadata=task_metadata, + interface=None, + type="boto", + ) + task_inputs = literals.LiteralMap( + { + "model_name": literals.Literal( + scalar=literals.Scalar(primitive=literals.Primitive(string_value="sagemaker-model")) + ), + "s3_output_path": literals.Literal( + scalar=literals.Scalar(primitive=literals.Primitive(string_value="s3-output-path")) + ), + }, + ) + output_prefix = FlyteContext.current_context().file_access.get_random_local_directory() + + response = await agent.async_create(ctx, output_prefix, task_template, task_inputs) + + assert response.HasField("resource") + assert response.resource.state == SUCCEEDED + assert response.resource.outputs is not None diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_task.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_task.py new file mode 100644 index 0000000000..e5fe6f32c7 --- /dev/null +++ b/plugins/flytekit-aws-sagemaker/tests/test_boto3_task.py @@ -0,0 +1,52 @@ +from flytekitplugins.awssagemaker import BotoConfig, BotoTask + +from flytekit import kwtypes 
+from flytekit.configuration import Image, ImageConfig, SerializationSettings + + +def test_boto_task_and_config(): + boto_task = BotoTask( + name="boto_task", + task_config=BotoConfig( + service="sagemaker", + method="create_model", + config={ + "ModelName": "{inputs.model_name}", + "PrimaryContainer": { + "Image": "{container.image}", + "ModelDataUrl": "{inputs.model_data_url}", + }, + "ExecutionRoleArn": "{inputs.execution_role_arn}", + }, + region="us-east-2", + ), + inputs=kwtypes(model_name=str, model_data_url=str, execution_role_arn=str), + outputs=kwtypes(result=dict), + container_image="1234567890.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost", + ) + + assert len(boto_task.interface.inputs) == 3 + assert len(boto_task.interface.outputs) == 1 + + default_img = Image(name="default", fqn="test", tag="tag") + serialization_settings = SerializationSettings( + project="project", + domain="domain", + version="123", + image_config=ImageConfig(default_image=default_img, images=[default_img]), + env={}, + ) + + retrieved_setttings = boto_task.get_custom(serialization_settings) + + assert retrieved_setttings["service"] == "sagemaker" + assert retrieved_setttings["config"] == { + "ModelName": "{inputs.model_name}", + "PrimaryContainer": { + "Image": "{container.image}", + "ModelDataUrl": "{inputs.model_data_url}", + }, + "ExecutionRoleArn": "{inputs.execution_role_arn}", + } + assert retrieved_setttings["region"] == "us-east-2" + assert retrieved_setttings["method"] == "create_model" diff --git a/plugins/flytekit-aws-sagemaker/tests/test_task.py b/plugins/flytekit-aws-sagemaker/tests/test_task.py new file mode 100644 index 0000000000..3ee270275d --- /dev/null +++ b/plugins/flytekit-aws-sagemaker/tests/test_task.py @@ -0,0 +1,174 @@ +from flytekitplugins.awssagemaker import ( + SagemakerModelTask, + SagemakerDeleteEndpointConfigTask, + SagemakerDeleteEndpointTask, + SagemakerDeleteModelTask, + SagemakerEndpointConfigTask, + SagemakerEndpointTask, + 
SagemakerInvokeEndpointTask, +) + +import pytest +from flytekit import kwtypes +from flytekit.configuration import Image, ImageConfig, SerializationSettings + + +@pytest.mark.parametrize( + "name,config,service,method,inputs,container_image,no_of_inputs,no_of_outputs,region,task", + [ + ( + "sagemaker_model", + { + "ModelName": "{inputs.model_name}", + "PrimaryContainer": { + "Image": "{container.image}", + "ModelDataUrl": "{inputs.model_data_url}", + }, + "ExecutionRoleArn": "{inputs.execution_role_arn}", + }, + "sagemaker", + "create_model", + kwtypes(model_name=str, model_data_url=str, execution_role_arn=str), + "1234567890.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost", + 3, + 1, + "us-east-2", + SagemakerModelTask, + ), + ( + "sagemaker_endpoint_config", + { + "EndpointConfigName": "{inputs.endpoint_config_name}", + "ProductionVariants": [ + { + "VariantName": "variant-name-1", + "ModelName": "{inputs.model_name}", + "InitialInstanceCount": 1, + "InstanceType": "ml.m4.xlarge", + }, + ], + "AsyncInferenceConfig": {"OutputConfig": {"S3OutputPath": "{inputs.s3_output_path}"}}, + }, + "sagemaker", + "create_endpoint_config", + kwtypes(endpoint_config_name=str, model_name=str, s3_output_path=str), + None, + 3, + 1, + "us-east-2", + SagemakerEndpointConfigTask, + ), + ( + "sagemaker_endpoint", + { + "EndpointName": "{inputs.endpoint_name}", + "EndpointConfigName": "{inputs.endpoint_config_name}", + }, + None, + None, + kwtypes(endpoint_name=str, endpoint_config_name=str), + None, + 2, + 1, + "us-east-2", + SagemakerEndpointTask, + ), + ( + "sagemaker_delete_endpoint", + {"EndpointName": "{inputs.endpoint_name}"}, + "sagemaker", + "delete_endpoint", + kwtypes(endpoint_name=str), + None, + 1, + 0, + "us-east-2", + SagemakerDeleteEndpointTask, + ), + ( + "sagemaker_delete_endpoint_config", + {"EndpointConfigName": "{inputs.endpoint_config_name}"}, + "sagemaker", + "delete_endpoint_config", + kwtypes(endpoint_config_name=str), + None, + 1, + 0, + "us-east-2", + 
SagemakerDeleteEndpointConfigTask, + ), + ( + "sagemaker_delete_model", + {"ModelName": "{inputs.model_name}"}, + "sagemaker", + "delete_model", + kwtypes(model_name=str), + None, + 1, + 0, + "us-east-2", + SagemakerDeleteModelTask, + ), + ( + "sagemaker_invoke_endpoint", + { + "EndpointName": "{inputs.endpoint_name}", + "InputLocation": "s3://sagemaker-agent-xgboost/inference_input", + }, + "sagemaker-runtime", + "invoke_endpoint_async", + kwtypes(endpoint_name=str), + None, + 1, + 1, + "us-east-2", + SagemakerInvokeEndpointTask, + ), + ], +) +def test_sagemaker_task( + name, + config, + service, + method, + inputs, + container_image, + no_of_inputs, + no_of_outputs, + region, + task, +): + if container_image: + sagemaker_task = task( + name=name, + config=config, + region=region, + inputs=inputs, + container_image=container_image, + ) + else: + sagemaker_task = task( + name=name, + config=config, + region=region, + inputs=inputs, + ) + + assert len(sagemaker_task.interface.inputs) == no_of_inputs + assert len(sagemaker_task.interface.outputs) == no_of_outputs + + default_img = Image(name="default", fqn="test", tag="tag") + serialization_settings = SerializationSettings( + project="project", + domain="domain", + version="123", + image_config=ImageConfig(default_image=default_img, images=[default_img]), + env={}, + ) + + retrieved_settings = sagemaker_task.get_custom(serialization_settings) + + assert retrieved_settings.get("service") == service + assert retrieved_settings["config"] == config + assert retrieved_settings["region"] == region + assert retrieved_settings.get("method") == method diff --git a/plugins/flytekit-aws-sagemaker/tests/test_workflow.py b/plugins/flytekit-aws-sagemaker/tests/test_workflow.py new file mode 100644 index 0000000000..1e57a38526 --- /dev/null +++ b/plugins/flytekit-aws-sagemaker/tests/test_workflow.py @@ -0,0 +1,55 @@ +from flytekitplugins.awssagemaker import ( + create_sagemaker_deployment, + delete_sagemaker_deployment, +) +from 
flytekit import kwtypes + + +def test_sagemaker_deployment_workflow(): + sagemaker_deployment_wf = create_sagemaker_deployment( + name="sagemaker-deployment", + model_input_types=kwtypes(model_path=str, execution_role_arn=str), + model_config={ + "ModelName": "sagemaker-xgboost", + "PrimaryContainer": { + "Image": "{container.image}", + "ModelDataUrl": "{inputs.model_path}", + }, + "ExecutionRoleArn": "{inputs.execution_role_arn}", + }, + endpoint_config_input_types=kwtypes(instance_type=str), + endpoint_config_config={ + "EndpointConfigName": "sagemaker-xgboost-endpoint-config", + "ProductionVariants": [ + { + "VariantName": "variant-name-1", + "ModelName": "sagemaker-xgboost", + "InitialInstanceCount": 1, + "InstanceType": "{inputs.instance_type}", + }, + ], + "AsyncInferenceConfig": { + "OutputConfig": {"S3OutputPath": "s3://sagemaker-agent-xgboost/inference-output/output"} + }, + }, + endpoint_config={ + "EndpointName": "sagemaker-xgboost-endpoint", + "EndpointConfigName": "sagemaker-xgboost-endpoint-config", + }, + container_image="1234567890.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost", + region="us-east-2", + ) + + assert len(sagemaker_deployment_wf.interface.inputs) == 3 + assert len(sagemaker_deployment_wf.interface.outputs) == 1 + assert len(sagemaker_deployment_wf.nodes) == 3 + + +def test_sagemaker_deployment_deletion_workflow(): + sagemaker_deployment_deletion_wf = delete_sagemaker_deployment( + name="sagemaker-deployment-deletion", region="us-east-2" + ) + + assert len(sagemaker_deployment_deletion_wf.interface.inputs) == 3 + assert len(sagemaker_deployment_deletion_wf.interface.outputs) == 0 + assert len(sagemaker_deployment_deletion_wf.nodes) == 3 From 9f65a99c63ccd105a4fc9d01b51c5392c86625ba Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 20 Feb 2024 15:25:30 +0530 Subject: [PATCH 040/120] ruff isort Signed-off-by: Samhita Alla --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1e91d21bd1..037724f5a0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.1.6 + rev: v0.2.2 hooks: # Run the linter. - id: ruff From 0840f5137385e76e5875f7316f50fe966377ca00 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 20 Feb 2024 15:28:27 +0530 Subject: [PATCH 041/120] ruff isort Signed-off-by: Samhita Alla --- .pre-commit-config.yaml | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 037724f5a0..7d51221146 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ repos: hooks: # Run the linter. - id: ruff - args: [--fix, --show-fixes, --show-source] + args: [--fix, --show-fixes, --output-format=full] # Run the formatter. - id: ruff-format - repo: https://github.com/pre-commit/pre-commit-hooks diff --git a/pyproject.toml b/pyproject.toml index 89d046fc7d..30fa79b943 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,7 +115,7 @@ branch = true [tool.ruff] line-length = 120 select = ["E", "W", "F", "I"] -ignore = [ +lint.ignore = [ # Whitespace before '{symbol}' "E203", # Too many leading # before block comment From e8622ec799a7f379a39a2f7244c728932d7d7c6c Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 20 Feb 2024 15:29:30 +0530 Subject: [PATCH 042/120] isort Signed-off-by: Samhita Alla --- flytekit/core/condition.py | 2 +- .../flytekitplugins/airflow/task.py | 2 +- .../flytekitplugins/awssagemaker/__init__.py | 6 +++--- .../flytekitplugins/awssagemaker/agent.py | 10 +++++----- .../flytekitplugins/awssagemaker/boto3_agent.py | 1 + .../flytekitplugins/awssagemaker/boto3_mixin.py | 2 +- .../flytekitplugins/awssagemaker/boto3_task.py | 3 +-- .../flytekitplugins/awssagemaker/task.py | 8 ++++---- .../flytekitplugins/awssagemaker/workflow.py | 13 
+++++++------ plugins/flytekit-aws-sagemaker/tests/test_agent.py | 8 ++++---- .../tests/test_boto3_mixin.py | 3 ++- plugins/flytekit-aws-sagemaker/tests/test_task.py | 4 ++-- .../flytekit-aws-sagemaker/tests/test_workflow.py | 1 + 13 files changed, 33 insertions(+), 30 deletions(-) diff --git a/flytekit/core/condition.py b/flytekit/core/condition.py index bc7b4df865..50403574c1 100644 --- a/flytekit/core/condition.py +++ b/flytekit/core/condition.py @@ -428,7 +428,7 @@ def transform_to_comp_expr(expr: ComparisonExpression) -> Tuple[_core_cond.Compa def transform_to_boolexpr( - expr: Union[ComparisonExpression, ConjunctionExpression] + expr: Union[ComparisonExpression, ConjunctionExpression], ) -> Tuple[_core_cond.BooleanExpression, typing.List[Promise]]: if isinstance(expr, ConjunctionExpression): cexpr, promises = transform_to_conj_expr(expr) diff --git a/plugins/flytekit-airflow/flytekitplugins/airflow/task.py b/plugins/flytekit-airflow/flytekitplugins/airflow/task.py index 17a023dfdb..cf8f992ad9 100644 --- a/plugins/flytekit-airflow/flytekitplugins/airflow/task.py +++ b/plugins/flytekit-airflow/flytekitplugins/airflow/task.py @@ -137,7 +137,7 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: def _get_airflow_instance( - airflow_obj: AirflowObj + airflow_obj: AirflowObj, ) -> typing.Union[airflow_models.BaseOperator, airflow_sensors.BaseSensorOperator, airflow_triggers.BaseTrigger]: # Set the GET_ORIGINAL_TASK attribute to True so that obj_def will return the original # airflow task instead of the Flyte task. 
diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py index e3c2e33ba7..68cbe0b41b 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py @@ -20,15 +20,15 @@ """ from .agent import SagemakerEndpointAgent +from .boto3_agent import BotoAgent +from .boto3_task import BotoConfig, BotoTask from .task import ( SagemakerDeleteEndpointConfigTask, SagemakerDeleteEndpointTask, SagemakerDeleteModelTask, - SagemakerEndpointTask, SagemakerEndpointConfigTask, + SagemakerEndpointTask, SagemakerInvokeEndpointTask, SagemakerModelTask, ) -from .boto3_agent import BotoAgent -from .boto3_task import BotoConfig, BotoTask from .workflow import create_sagemaker_deployment, delete_sagemaker_deployment diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py index b28ba57029..8a313b62c8 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py @@ -1,7 +1,7 @@ import json from dataclasses import asdict, dataclass -from typing import Optional from datetime import datetime +from typing import Optional import grpc from flyteidl.admin.agent_pb2 import ( @@ -12,17 +12,17 @@ ) from flyteidl.core.tasks_pb2 import TaskTemplate +from flytekit import FlyteContextManager +from flytekit.core.type_engine import TypeEngine from flytekit.extend.backend.base_agent import ( AgentBase, + AgentRegistry, convert_to_flyte_phase, get_agent_secret, - AgentRegistry, ) from flytekit.models.literals import LiteralMap -from flytekit.core.type_engine import TypeEngine -from flytekit import FlyteContextManager -from .boto3_mixin import Boto3AgentMixin +from .boto3_mixin import Boto3AgentMixin states = { "Creating": 
"Running", diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py index d7f2bf2a0c..18a729a276 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py @@ -3,6 +3,7 @@ import grpc from flyteidl.admin.agent_pb2 import SUCCEEDED, CreateTaskResponse, Resource from flyteidl.core.tasks_pb2 import TaskTemplate + from flytekit import FlyteContextManager from flytekit.core.type_engine import TypeEngine from flytekit.extend.backend.base_agent import ( diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_mixin.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_mixin.py index 917a6f8cf7..2d3f358fa0 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_mixin.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_mixin.py @@ -3,8 +3,8 @@ import aioboto3 from flytekit.interaction.string_literals import literal_map_string_repr -from flytekit.models.literals import LiteralMap from flytekit.models import task as _task_model +from flytekit.models.literals import LiteralMap def update_dict_fn(original_dict: Any, update_dict: dict[str, Any]) -> Any: diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py index 750ec3461d..e4ab88ced4 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py @@ -1,11 +1,10 @@ from dataclasses import dataclass from typing import Any, Optional, Type, Union - from flytekit import ImageSpec from flytekit.configuration import SerializationSettings -from flytekit.core.python_function_task import PythonInstanceTask from 
flytekit.core.interface import Interface +from flytekit.core.python_function_task import PythonInstanceTask from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py index 2724a4f8db..45f612e0b5 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py @@ -1,13 +1,13 @@ from dataclasses import dataclass from typing import Any, Optional, Type, Union - -from flytekit.configuration import SerializationSettings, DefaultImages +from flytekit import ImageSpec, kwtypes +from flytekit.configuration import DefaultImages, SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.interface import Interface from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin -from .boto3_task import BotoTask, BotoConfig -from flytekit import ImageSpec, kwtypes + +from .boto3_task import BotoConfig, BotoTask class SagemakerModelTask(BotoTask): diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py index cdbafb759a..a65ccf2697 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py @@ -1,15 +1,16 @@ -from flytekit import Workflow, kwtypes, ImageSpec +from typing import Any, Optional, Type, Union + +from flytekit import ImageSpec, Workflow, kwtypes + from .task import ( - SagemakerModelTask, - SagemakerEndpointConfigTask, - SagemakerDeleteEndpointTask, SagemakerDeleteEndpointConfigTask, + SagemakerDeleteEndpointTask, SagemakerDeleteModelTask, + SagemakerEndpointConfigTask, SagemakerEndpointTask, + SagemakerModelTask, ) -from typing import Any, Optional, Union, Type - def 
create_sagemaker_deployment( name: str, diff --git a/plugins/flytekit-aws-sagemaker/tests/test_agent.py b/plugins/flytekit-aws-sagemaker/tests/test_agent.py index bbcdefb7f1..c75b9a1cc8 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_agent.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_agent.py @@ -1,9 +1,11 @@ +import json +from dataclasses import asdict from datetime import timedelta from unittest import mock import pytest -import json -from dataclasses import asdict +from flyteidl.admin.agent_pb2 import RUNNING, DeleteTaskResponse +from flytekitplugins.awssagemaker.agent import Metadata from flytekit import FlyteContext, FlyteContextManager from flytekit.extend.backend.base_agent import AgentRegistry @@ -11,8 +13,6 @@ from flytekit.models import literals from flytekit.models.core.identifier import ResourceType from flytekit.models.task import RuntimeMetadata, TaskMetadata, TaskTemplate -from flytekitplugins.awssagemaker.agent import Metadata -from flyteidl.admin.agent_pb2 import RUNNING, DeleteTaskResponse @pytest.mark.asyncio diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py index 2d08fded41..5b95d02f0b 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py @@ -1,10 +1,11 @@ import typing +from flytekitplugins.awssagemaker.boto3_mixin import update_dict_fn + from flytekit import FlyteContext, StructuredDataset from flytekit.core.type_engine import TypeEngine from flytekit.interaction.string_literals import literal_map_string_repr from flytekit.types.file import FlyteFile -from flytekitplugins.awssagemaker.boto3_mixin import update_dict_fn def test_inputs(): diff --git a/plugins/flytekit-aws-sagemaker/tests/test_task.py b/plugins/flytekit-aws-sagemaker/tests/test_task.py index 3ee270275d..8d0863cd58 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_task.py +++ 
b/plugins/flytekit-aws-sagemaker/tests/test_task.py @@ -1,14 +1,14 @@ +import pytest from flytekitplugins.awssagemaker import ( - SagemakerModelTask, SagemakerDeleteEndpointConfigTask, SagemakerDeleteEndpointTask, SagemakerDeleteModelTask, SagemakerEndpointConfigTask, SagemakerEndpointTask, SagemakerInvokeEndpointTask, + SagemakerModelTask, ) -import pytest from flytekit import kwtypes from flytekit.configuration import Image, ImageConfig, SerializationSettings diff --git a/plugins/flytekit-aws-sagemaker/tests/test_workflow.py b/plugins/flytekit-aws-sagemaker/tests/test_workflow.py index 1e57a38526..96002b65a2 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_workflow.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_workflow.py @@ -2,6 +2,7 @@ create_sagemaker_deployment, delete_sagemaker_deployment, ) + from flytekit import kwtypes From 131b488be181044f455c4a5f0b31f1278a8849fb Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 20 Feb 2024 18:17:55 +0530 Subject: [PATCH 043/120] add test Signed-off-by: Samhita Alla --- .github/workflows/pythonbuild.yml | 7 +++-- tests/flytekit/unit/core/tracker/e.py | 8 +++++ .../unit/core/tracker/test_tracking.py | 29 ++++++++++++++++--- 3 files changed, 37 insertions(+), 7 deletions(-) create mode 100644 tests/flytekit/unit/core/tracker/e.py diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index afa22ddde7..0b5e60a731 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -63,9 +63,9 @@ jobs: strategy: fail-fast: false matrix: - os: [ ubuntu-latest ] - python-version: [ "3.11" ] - pandas: [ "pandas<2.0.0", "pandas>=2.0.0" ] + os: [ubuntu-latest] + python-version: ["3.11"] + pandas: ["pandas<2.0.0", "pandas>=2.0.0"] steps: - uses: insightsengineering/disk-space-reclaimer@v1 - uses: actions/checkout@v4 @@ -186,6 +186,7 @@ jobs: - flytekit-async-fsspec - flytekit-aws-athena - flytekit-aws-batch + - flytekit-aws-sagemaker # TODO: uncomment this when the 
sagemaker agent is implemented: https://github.com/flyteorg/flyte/issues/4079 # - flytekit-aws-sagemaker - flytekit-bigquery diff --git a/tests/flytekit/unit/core/tracker/e.py b/tests/flytekit/unit/core/tracker/e.py new file mode 100644 index 0000000000..f724df6d61 --- /dev/null +++ b/tests/flytekit/unit/core/tracker/e.py @@ -0,0 +1,8 @@ +from flytekit.core.python_function_task import PythonInstanceTask + + +class E(PythonInstanceTask): + ... + + +e_instantiated = E(name="e-instantiated", task_config={}) diff --git a/tests/flytekit/unit/core/tracker/test_tracking.py b/tests/flytekit/unit/core/tracker/test_tracking.py index b33725436d..19159a348e 100644 --- a/tests/flytekit/unit/core/tracker/test_tracking.py +++ b/tests/flytekit/unit/core/tracker/test_tracking.py @@ -8,6 +8,7 @@ from tests.flytekit.unit.core.tracker import d from tests.flytekit.unit.core.tracker.b import b_local_a, local_b from tests.flytekit.unit.core.tracker.c import b_in_c, c_local_a +from tests.flytekit.unit.core.tracker.e import e_instantiated def test_tracking(): @@ -50,16 +51,31 @@ def convert_to_test(d: dict) -> typing.Tuple[typing.List[str], typing.List]: "core.task": (task, ("flytekit.core.task.task", "flytekit.core.task", "task")), "current-mod-tasks": ( d.tasks, - ("tests.flytekit.unit.core.tracker.d.tasks", "tests.flytekit.unit.core.tracker.d", "tasks"), + ( + "tests.flytekit.unit.core.tracker.d.tasks", + "tests.flytekit.unit.core.tracker.d", + "tasks", + ), + ), + "tasks-core-task": ( + d.task, + ("flytekit.core.task.task", "flytekit.core.task", "task"), ), - "tasks-core-task": (d.task, ("flytekit.core.task.task", "flytekit.core.task", "task")), "tracked-local": ( local_b, - ("tests.flytekit.unit.core.tracker.b.local_b", "tests.flytekit.unit.core.tracker.b", "local_b"), + ( + "tests.flytekit.unit.core.tracker.b.local_b", + "tests.flytekit.unit.core.tracker.b", + "local_b", + ), ), "tracked-b-in-c": ( b_in_c, - ("tests.flytekit.unit.core.tracker.c.b_in_c", 
"tests.flytekit.unit.core.tracker.c", "b_in_c"), + ( + "tests.flytekit.unit.core.tracker.c.b_in_c", + "tests.flytekit.unit.core.tracker.c", + "b_in_c", + ), ), } ) @@ -81,6 +97,11 @@ def test_extract_task_module(test_input, expected): raise +def test_extract_task_module_with_python_instance_task(): + _, _, name, _ = extract_task_module(e_instantiated) + assert name == "" + + local_task = task(d.inner_function) From 7bf91cc594fc50183f2364bb2f9c05a4138cc7d0 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 20 Feb 2024 18:30:19 +0530 Subject: [PATCH 044/120] pin greatexpectations version Signed-off-by: Samhita Alla --- .../great_expectations/schema.py | 28 +++++++++++++++---- plugins/flytekit-greatexpectations/setup.py | 2 +- pyproject.toml | 4 +-- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/plugins/flytekit-greatexpectations/flytekitplugins/great_expectations/schema.py b/plugins/flytekit-greatexpectations/flytekitplugins/great_expectations/schema.py index 3413cdcdd3..e12fa8a99a 100644 --- a/plugins/flytekit-greatexpectations/flytekitplugins/great_expectations/schema.py +++ b/plugins/flytekit-greatexpectations/flytekitplugins/great_expectations/schema.py @@ -88,7 +88,9 @@ def __init__(self): super().__init__(name="GreatExpectations Transformer", t=GreatExpectationsType) @staticmethod - def get_config(t: Type[GreatExpectationsType]) -> Tuple[Type, GreatExpectationsFlyteConfig]: + def get_config( + t: Type[GreatExpectationsType], + ) -> Tuple[Type, GreatExpectationsFlyteConfig]: return t.config() def get_literal_type(self, t: Type[GreatExpectationsType]) -> LiteralType: @@ -138,13 +140,20 @@ def _flyte_schema( # copy parquet file to user-given directory if lv.scalar.structured_dataset: - ctx.file_access.get_data(lv.scalar.structured_dataset.uri, ge_conf.local_file_path, is_multipart=True) + ctx.file_access.get_data( + lv.scalar.structured_dataset.uri, + ge_conf.local_file_path, + is_multipart=True, + ) else: 
ctx.file_access.get_data(lv.scalar.schema.uri, ge_conf.local_file_path, is_multipart=True) temp_dataset = os.path.basename(ge_conf.local_file_path) - return FlyteSchemaTransformer().to_python_value(ctx, lv, expected_python_type), temp_dataset + return ( + FlyteSchemaTransformer().to_python_value(ctx, lv, expected_python_type), + temp_dataset, + ) def _flyte_file( self, @@ -199,7 +208,12 @@ def to_python_value( context = ge.data_context.DataContext(ge_conf.context_root_dir) # type: ignore # determine the type of data connector - selected_datasource = list(filter(lambda x: x["name"] == ge_conf.datasource_name, context.list_datasources())) + selected_datasource = list( + filter( + lambda x: x["name"] == ge_conf.datasource_name, + context.list_datasources(), + ) + ) if not selected_datasource: raise ValueError("Datasource doesn't exist!") @@ -226,7 +240,11 @@ def to_python_value( # FlyteSchema if lv.scalar.schema or lv.scalar.structured_dataset: return_dataset, temp_dataset = self._flyte_schema( - is_runtime=is_runtime, ctx=ctx, ge_conf=ge_conf, lv=lv, expected_python_type=type_conf[0] + is_runtime=is_runtime, + ctx=ctx, + ge_conf=ge_conf, + lv=lv, + expected_python_type=type_conf[0], ) # FlyteFile diff --git a/plugins/flytekit-greatexpectations/setup.py b/plugins/flytekit-greatexpectations/setup.py index f73b515539..0ef3fcf2fc 100644 --- a/plugins/flytekit-greatexpectations/setup.py +++ b/plugins/flytekit-greatexpectations/setup.py @@ -6,7 +6,7 @@ plugin_requires = [ "flytekit>=1.5.0,<2.0.0", - "great-expectations>=0.13.30", + "great-expectations>=0.13.30,<=0.18.8", "sqlalchemy>=1.4.23,<2.0.0", "pyspark==3.3.1", "s3fs<2023.6.0", diff --git a/pyproject.toml b/pyproject.toml index 30fa79b943..3596c73538 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -114,7 +114,7 @@ branch = true [tool.ruff] line-length = 120 -select = ["E", "W", "F", "I"] +lint.select = ["E", "W", "F", "I"] lint.ignore = [ # Whitespace before '{symbol}' "E203", @@ -132,7 +132,7 @@ lint.ignore = [ 
"E731", ] -[tool.ruff.extend-per-file-ignores] +[tool.ruff.lint.extend-per-file-ignores] "*/__init__.py" = [ # unused-import "F401", From 0c4390a1eecce0bb60c79fbd19ba687a8708d3bc Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 20 Feb 2024 19:03:43 +0530 Subject: [PATCH 045/120] update secret name Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/agent.py | 18 +++++++++--------- .../awssagemaker/boto3_agent.py | 6 +++--- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py index 8a313b62c8..2daa29a5c2 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py @@ -67,9 +67,9 @@ async def async_create( config=config, inputs=inputs, region=region, - aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), - aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), - aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), + aws_access_key_id=get_agent_secret(secret_key="aws-access-key"), + aws_secret_access_key=get_agent_secret(secret_key="aws-secret-access-key"), + aws_session_token=get_agent_secret(secret_key="aws-session-token"), ) metadata = Metadata(endpoint_name=config["EndpointName"], region=region) @@ -82,9 +82,9 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - method="describe_endpoint", config={"EndpointName": metadata.endpoint_name}, region=metadata.region, - aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), - aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), - aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), + aws_access_key_id=get_agent_secret(secret_key="aws-access-key"), + aws_secret_access_key=get_agent_secret(secret_key="aws-secret-access-key"), + 
aws_session_token=get_agent_secret(secret_key="aws-session-token"), ) current_state = endpoint_status.get("EndpointStatus") @@ -117,9 +117,9 @@ async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes "delete_endpoint", config={"EndpointName": metadata.endpoint_name}, region=metadata.region, - aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), - aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), - aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), + aws_access_key_id=get_agent_secret(secret_key="aws-access-key"), + aws_secret_access_key=get_agent_secret(secret_key="aws-secret-access-key"), + aws_session_token=get_agent_secret(secret_key="aws-session-token"), ) return DeleteTaskResponse() diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py index 18a729a276..3f64a49076 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py @@ -42,9 +42,9 @@ async def async_create( config=config, container=task_template.container, inputs=inputs, - aws_access_key_id=get_agent_secret(secret_key="AWS_ACCESS_KEY"), - aws_secret_access_key=get_agent_secret(secret_key="AWS_SECRET_ACCESS_KEY"), - aws_session_token=get_agent_secret(secret_key="AWS_SESSION_TOKEN"), + aws_access_key_id=get_agent_secret(secret_key="aws-access-key"), + aws_secret_access_key=get_agent_secret(secret_key="aws-secret-access-key"), + aws_session_token=get_agent_secret(secret_key="aws-session-token"), ) outputs = None From eb30a63e9c494f7deb4c5eb73291a546fd847726 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 20 Feb 2024 21:43:54 +0530 Subject: [PATCH 046/120] add typing dict and update tracker Signed-off-by: Samhita Alla --- flytekit/core/tracker.py | 8 ++++-- .../awssagemaker/boto3_mixin.py | 6 ++-- 
plugins/flytekit-papermill/tests/test_task.py | 28 ++++++++++++++++--- 3 files changed, 32 insertions(+), 10 deletions(-) diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py index bc4e6a2333..a1610588d0 100644 --- a/flytekit/core/tracker.py +++ b/flytekit/core/tracker.py @@ -328,10 +328,12 @@ def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str, elif f.instantiated_in: mod = importlib.import_module(f.instantiated_in) mod_name = mod.__name__ - if isPythonInstanceTask(f): - name = "" - else: + try: name = f.lhs + except Exception: + if not isPythonInstanceTask(f): + raise AssertionError(f"Unable to determine module of {f}") + name = "" else: raise AssertionError(f"Unable to determine module of {f}") else: diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_mixin.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_mixin.py index 2d3f358fa0..9025c25864 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_mixin.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_mixin.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any, Dict, Optional import aioboto3 @@ -7,7 +7,7 @@ from flytekit.models.literals import LiteralMap -def update_dict_fn(original_dict: Any, update_dict: dict[str, Any]) -> Any: +def update_dict_fn(original_dict: Any, update_dict: Dict[str, Any]) -> Any: """ Recursively update a dictionary with values from another dictionary. 
For example, if original_dict is {"EndpointConfigName": "{endpoint_config_name}"}, @@ -88,7 +88,7 @@ def __init__(self, *, service: str, region: Optional[str] = None, **kwargs): async def _call( self, method: str, - config: dict[str, Any], + config: Dict[str, Any], container: Optional[_task_model.Container] = None, inputs: Optional[LiteralMap] = None, region: Optional[str] = None, diff --git a/plugins/flytekit-papermill/tests/test_task.py b/plugins/flytekit-papermill/tests/test_task.py index 8c229f71f9..9c7b778afb 100644 --- a/plugins/flytekit-papermill/tests/test_task.py +++ b/plugins/flytekit-papermill/tests/test_task.py @@ -63,7 +63,11 @@ def test_notebook_task_simple(): sqr, out, render = nb_simple.execute(pi=4) assert sqr == 16.0 assert nb_simple.python_interface.inputs == {"pi": float} - assert nb_simple.python_interface.outputs.keys() == {"square", "out_nb", "out_rendered_nb"} + assert nb_simple.python_interface.outputs.keys() == { + "square", + "out_nb", + "out_rendered_nb", + } assert nb_simple.output_notebook_path == out == _get_nb_path(nb_name, suffix="-out") assert nb_simple.rendered_output_path == render == _get_nb_path(nb_name, suffix="-out", ext=".html") assert ( @@ -86,7 +90,14 @@ def test_notebook_task_multi_values(): assert h == "blah world!" 
assert type(n) == datetime.datetime assert nb.python_interface.inputs == {"x": int, "y": int, "h": str} - assert nb.python_interface.outputs.keys() == {"z", "m", "h", "n", "out_nb", "out_rendered_nb"} + assert nb.python_interface.outputs.keys() == { + "z", + "m", + "h", + "n", + "out_nb", + "out_rendered_nb", + } assert nb.output_notebook_path == out == _get_nb_path(nb_name, suffix="-out") assert nb.rendered_output_path == render == _get_nb_path(nb_name, suffix="-out", ext=".html") @@ -104,7 +115,13 @@ def test_notebook_task_complex(): assert w is not None assert x.x == 10 assert nb.python_interface.inputs == {"n": int, "h": str, "w": str} - assert nb.python_interface.outputs.keys() == {"h", "w", "x", "out_nb", "out_rendered_nb"} + assert nb.python_interface.outputs.keys() == { + "h", + "w", + "x", + "out_nb", + "out_rendered_nb", + } assert nb.output_notebook_path == out == _get_nb_path(nb_name, suffix="-out") assert nb.rendered_output_path == render == _get_nb_path(nb_name, suffix="-out", ext=".html") @@ -241,7 +258,10 @@ def wf(a: float) -> typing.List[float]: def test_register_notebook_task(mock_client, mock_remote): mock_remote._client = mock_client mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash" - mock_remote.return_value.fast_package.return_value = "dummy_md5_bytes", "dummy_native_url" + mock_remote.return_value.fast_package.return_value = ( + "dummy_md5_bytes", + "dummy_native_url", + ) runner = CliRunner() context_manager.FlyteEntities.entities.clear() notebook_task = """ From 236a07256c047ebb3ccba7b7ab146c0fe50bfca6 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 20 Feb 2024 23:13:28 +0530 Subject: [PATCH 047/120] modify tracker test Signed-off-by: Samhita Alla --- tests/flytekit/unit/core/tracker/e.py | 8 ---- .../unit/core/tracker/test_tracking.py | 45 +++++++++++++++++-- 2 files changed, 42 insertions(+), 11 deletions(-) delete mode 100644 tests/flytekit/unit/core/tracker/e.py diff --git 
a/tests/flytekit/unit/core/tracker/e.py b/tests/flytekit/unit/core/tracker/e.py deleted file mode 100644 index f724df6d61..0000000000 --- a/tests/flytekit/unit/core/tracker/e.py +++ /dev/null @@ -1,8 +0,0 @@ -from flytekit.core.python_function_task import PythonInstanceTask - - -class E(PythonInstanceTask): - ... - - -e_instantiated = E(name="e-instantiated", task_config={}) diff --git a/tests/flytekit/unit/core/tracker/test_tracking.py b/tests/flytekit/unit/core/tracker/test_tracking.py index 19159a348e..efa331a496 100644 --- a/tests/flytekit/unit/core/tracker/test_tracking.py +++ b/tests/flytekit/unit/core/tracker/test_tracking.py @@ -4,11 +4,13 @@ from flytekit import task from flytekit.configuration.feature_flags import FeatureFlags +from flytekit.core.base_task import PythonTask +from flytekit.core.python_function_task import PythonInstanceTask from flytekit.core.tracker import extract_task_module +from flytekit.exceptions import system as _system_exceptions from tests.flytekit.unit.core.tracker import d from tests.flytekit.unit.core.tracker.b import b_local_a, local_b from tests.flytekit.unit.core.tracker.c import b_in_c, c_local_a -from tests.flytekit.unit.core.tracker.e import e_instantiated def test_tracking(): @@ -97,10 +99,47 @@ def test_extract_task_module(test_input, expected): raise -def test_extract_task_module_with_python_instance_task(): - _, _, name, _ = extract_task_module(e_instantiated) +class FakePythonInstanceTaskWithExceptionLHS(PythonInstanceTask): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._raises_exception = True + + @property + def lhs(self): + if self._raises_exception: + raise _system_exceptions.FlyteSystemException("Raising an exception") + return "some value" + + +python_instance_task_instantiated = FakePythonInstanceTaskWithExceptionLHS(name="python_instance_task", task_config={}) + + +class FakePythonTaskWithExceptionLHS(PythonTask): + def __init__(self, *args, **kwargs): + 
super().__init__(*args, **kwargs) + self._raises_exception = True + + @property + def lhs(self): + if self._raises_exception: + raise _system_exceptions.FlyteSystemException("Raising an exception") + return "some value" + + +python_task_instantiated = FakePythonTaskWithExceptionLHS( + name="python_task", + task_config={}, + task_type="python-task", +) + + +def test_raise_exception_when_accessing_nonexistent_lhs(): + _, _, name, _ = extract_task_module(python_instance_task_instantiated) assert name == "" + with pytest.raises(AssertionError): + extract_task_module(python_task_instantiated) + local_task = task(d.inner_function) From 8dc4086801e917af2b45a8b2e75cda9d66378dcd Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 21 Feb 2024 15:04:37 +0530 Subject: [PATCH 048/120] add sync agent Signed-off-by: Samhita Alla --- plugins/flytekit-aws-sagemaker/README.md | 2 +- .../dev-requirements.txt | 1 - .../flytekitplugins/awssagemaker/__init__.py | 32 +++++----- .../flytekitplugins/awssagemaker/agent.py | 58 +++++++------------ .../awssagemaker/boto3_agent.py | 24 +++----- .../flytekitplugins/awssagemaker/task.py | 44 +++++++------- .../flytekitplugins/awssagemaker/workflow.py | 28 ++++----- plugins/flytekit-aws-sagemaker/setup.py | 2 +- .../tests/test_agent.py | 27 ++++----- .../tests/test_boto3_agent.py | 15 +++-- .../flytekit-aws-sagemaker/tests/test_task.py | 28 ++++----- 11 files changed, 117 insertions(+), 144 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/README.md b/plugins/flytekit-aws-sagemaker/README.md index 37cf1f10f2..b73e85a110 100644 --- a/plugins/flytekit-aws-sagemaker/README.md +++ b/plugins/flytekit-aws-sagemaker/README.md @@ -1,4 +1,4 @@ -# AWS Sagemaker Plugin +# AWS SageMaker Plugin The plugin features a deployment agent enabling you to deploy SageMaker models, create and trigger inference endpoints. Additionally, you can entirely remove the SageMaker deployment using the `delete_sagemaker_deployment` workflow. 
diff --git a/plugins/flytekit-aws-sagemaker/dev-requirements.txt b/plugins/flytekit-aws-sagemaker/dev-requirements.txt index c63a6da9cb..2d73dba5b4 100644 --- a/plugins/flytekit-aws-sagemaker/dev-requirements.txt +++ b/plugins/flytekit-aws-sagemaker/dev-requirements.txt @@ -1,2 +1 @@ pytest-asyncio -pytest-mock diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py index 68cbe0b41b..1afe8c05b1 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py @@ -7,28 +7,28 @@ BotoAgent BotoTask - SagemakerModelTask - SagemakerEndpointConfigTask - SagemakerEndpointAgent - SagemakerEndpointTask - SagemakerDeleteEndpointConfigTask - SagemakerDeleteEndpointTask - SagemakerDeleteModelTask - SagemakerInvokeEndpointTask + SageMakerModelTask + SageMakerEndpointConfigTask + SageMakerEndpointAgent + SageMakerEndpointTask + SageMakerDeleteEndpointConfigTask + SageMakerDeleteEndpointTask + SageMakerDeleteModelTask + SageMakerInvokeEndpointTask create_sagemaker_deployment delete_sagemaker_deployment """ -from .agent import SagemakerEndpointAgent +from .agent import SageMakerEndpointAgent from .boto3_agent import BotoAgent from .boto3_task import BotoConfig, BotoTask from .task import ( - SagemakerDeleteEndpointConfigTask, - SagemakerDeleteEndpointTask, - SagemakerDeleteModelTask, - SagemakerEndpointConfigTask, - SagemakerEndpointTask, - SagemakerInvokeEndpointTask, - SagemakerModelTask, + SageMakerDeleteEndpointConfigTask, + SageMakerDeleteEndpointTask, + SageMakerDeleteModelTask, + SageMakerEndpointConfigTask, + SageMakerEndpointTask, + SageMakerInvokeEndpointTask, + SageMakerModelTask, ) from .workflow import create_sagemaker_deployment, delete_sagemaker_deployment diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py 
b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py index 2daa29a5c2..7b780ba6aa 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py @@ -1,26 +1,19 @@ import json -from dataclasses import asdict, dataclass +from dataclasses import dataclass from datetime import datetime from typing import Optional -import grpc -from flyteidl.admin.agent_pb2 import ( - CreateTaskResponse, - DeleteTaskResponse, - GetTaskResponse, - Resource, -) -from flyteidl.core.tasks_pb2 import TaskTemplate - from flytekit import FlyteContextManager from flytekit.core.type_engine import TypeEngine from flytekit.extend.backend.base_agent import ( - AgentBase, AgentRegistry, - convert_to_flyte_phase, - get_agent_secret, + AsyncAgentBase, + Resource, + ResourceMeta, ) +from flytekit.extend.backend.utils import convert_to_flyte_phase, get_agent_secret from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate from .boto3_mixin import Boto3AgentMixin @@ -40,24 +33,24 @@ def default(self, o): @dataclass -class Metadata: +class SageMakerEndpointMetadata(ResourceMeta): endpoint_name: str region: str -class SagemakerEndpointAgent(Boto3AgentMixin, AgentBase): +class SageMakerEndpointAgent(Boto3AgentMixin, AsyncAgentBase): """This agent creates an endpoint.""" def __init__(self): - super().__init__(service="sagemaker", task_type="sagemaker-endpoint") - - async def async_create( - self, - context: grpc.ServicerContext, - output_prefix: str, - task_template: TaskTemplate, - inputs: Optional[LiteralMap] = None, - ) -> CreateTaskResponse: + super().__init__( + service="sagemaker", + task_type_name="sagemaker-endpoint", + metadata_type=SageMakerEndpointMetadata, + ) + + async def create( + self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs + ) -> SageMakerEndpointMetadata: custom = task_template.custom config = 
custom["config"] region = custom["region"] @@ -72,12 +65,9 @@ async def async_create( aws_session_token=get_agent_secret(secret_key="aws-session-token"), ) - metadata = Metadata(endpoint_name=config["EndpointName"], region=region) - return CreateTaskResponse(resource_meta=json.dumps(asdict(metadata)).encode("utf-8")) - - async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: - metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) + return SageMakerEndpointMetadata(endpoint_name=config["EndpointName"], region=region) + async def get(self, metadata: SageMakerEndpointMetadata, **kwargs) -> Resource: endpoint_status = await self._call( method="describe_endpoint", config={"EndpointName": metadata.endpoint_name}, @@ -108,11 +98,9 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - } ).to_flyte_idl() - return GetTaskResponse(resource=Resource(state=flyte_state, outputs=res, message=message)) - - async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: - metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) + return Resource(phase=flyte_state, outputs=res, message=message) + async def delete(self, metadata: SageMakerEndpointMetadata, **kwargs): await self._call( "delete_endpoint", config={"EndpointName": metadata.endpoint_name}, @@ -122,7 +110,5 @@ async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes aws_session_token=get_agent_secret(secret_key="aws-session-token"), ) - return DeleteTaskResponse() - -AgentRegistry.register(SagemakerEndpointAgent()) +AgentRegistry.register(SageMakerEndpointAgent()) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py index 3f64a49076..0bc29666f1 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py +++ 
b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py @@ -1,34 +1,28 @@ from typing import Optional -import grpc -from flyteidl.admin.agent_pb2 import SUCCEEDED, CreateTaskResponse, Resource -from flyteidl.core.tasks_pb2 import TaskTemplate +from flyteidl.core.execution_pb2 import TaskExecution from flytekit import FlyteContextManager from flytekit.core.type_engine import TypeEngine from flytekit.extend.backend.base_agent import ( - AgentBase, AgentRegistry, - get_agent_secret, + Resource, + SyncAgentBase, ) +from flytekit.extend.backend.utils import get_agent_secret from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate from .boto3_mixin import Boto3AgentMixin -class BotoAgent(AgentBase): +class BotoAgent(SyncAgentBase): """A general purpose boto3 agent that can be used to call any boto3 method.""" def __init__(self): - super().__init__(task_type="boto") + super().__init__(task_type_name="boto") - async def async_create( - self, - context: grpc.ServicerContext, - output_prefix: str, - task_template: TaskTemplate, - inputs: Optional[LiteralMap] = None, - ) -> CreateTaskResponse: + async def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs) -> Resource: custom = task_template.custom service = custom["service"] config = custom["config"] @@ -61,7 +55,7 @@ async def async_create( } ).to_flyte_idl() - return CreateTaskResponse(resource=Resource(state=SUCCEEDED, outputs=outputs)) + return Resource(phase=TaskExecution.SUCCEEDED, outputs=outputs) AgentRegistry.register(BotoAgent()) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py index 45f612e0b5..e0d3f21ebf 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py @@ -10,7 +10,7 @@ from .boto3_task import BotoConfig, BotoTask 
-class SagemakerModelTask(BotoTask): +class SageMakerModelTask(BotoTask): def __init__( self, name: str, @@ -21,7 +21,7 @@ def __init__( **kwargs, ): """ - Creates a Sagemaker model. + Creates a SageMaker model. :param name: The name of the task. :param config: The configuration to be provided to the boto3 API call. @@ -31,7 +31,7 @@ def __init__( This can be either in Amazon EC2 Container Registry or in a Docker registry that is accessible from the same VPC that you configure for your endpoint. """ - super(SagemakerModelTask, self).__init__( + super(SageMakerModelTask, self).__init__( name=name, task_config=BotoConfig(service="sagemaker", method="create_model", config=config, region=region), inputs=inputs, @@ -41,7 +41,7 @@ def __init__( ) -class SagemakerEndpointConfigTask(BotoTask): +class SageMakerEndpointConfigTask(BotoTask): def __init__( self, name: str, @@ -51,14 +51,14 @@ def __init__( **kwargs, ): """ - Creates a Sagemaker endpoint configuration. + Creates a SageMaker endpoint configuration. :param name: The name of the task. :param config: The configuration to be provided to the boto3 API call. :param region: The region for the boto3 client. :param inputs: The input literal map to be used for updating the configuration. """ - super(SagemakerEndpointConfigTask, self).__init__( + super(SageMakerEndpointConfigTask, self).__init__( name=name, task_config=BotoConfig( service="sagemaker", @@ -74,12 +74,12 @@ def __init__( @dataclass -class SagemakerEndpointMetadata(object): +class SageMakerEndpointMetadata(object): config: dict[str, Any] region: str -class SagemakerEndpointTask(AsyncAgentExecutorMixin, PythonTask[SagemakerEndpointMetadata]): +class SageMakerEndpointTask(AsyncAgentExecutorMixin, PythonTask[SageMakerEndpointMetadata]): _TASK_TYPE = "sagemaker-endpoint" def __init__( @@ -91,7 +91,7 @@ def __init__( **kwargs, ): """ - Creates a Sagemaker endpoint. + Creates a SageMaker endpoint. :param name: The name of the task. 
:param config: The configuration to be provided to the boto3 API call. @@ -100,7 +100,7 @@ def __init__( """ super().__init__( name=name, - task_config=SagemakerEndpointMetadata( + task_config=SageMakerEndpointMetadata( config=config, region=region, ), @@ -113,7 +113,7 @@ def get_custom(self, settings: SerializationSettings) -> dict[str, Any]: return {"config": self.task_config.config, "region": self.task_config.region} -class SagemakerDeleteEndpointTask(BotoTask): +class SageMakerDeleteEndpointTask(BotoTask): def __init__( self, name: str, @@ -123,14 +123,14 @@ def __init__( **kwargs, ): """ - Deletes a Sagemaker endpoint. + Deletes a SageMaker endpoint. :param name: The name of the task. :param config: The configuration to be provided to the boto3 API call. :param region: The region for the boto3 client. :param inputs: The input literal map to be used for updating the configuration. """ - super(SagemakerDeleteEndpointTask, self).__init__( + super(SageMakerDeleteEndpointTask, self).__init__( name=name, task_config=BotoConfig( service="sagemaker", @@ -144,7 +144,7 @@ def __init__( ) -class SagemakerDeleteEndpointConfigTask(BotoTask): +class SageMakerDeleteEndpointConfigTask(BotoTask): def __init__( self, name: str, @@ -154,14 +154,14 @@ def __init__( **kwargs, ): """ - Deletes a Sagemaker endpoint config. + Deletes a SageMaker endpoint config. :param name: The name of the task. :param config: The configuration to be provided to the boto3 API call. :param region: The region for the boto3 client. :param inputs: The input literal map to be used for updating the configuration. 
""" - super(SagemakerDeleteEndpointConfigTask, self).__init__( + super(SageMakerDeleteEndpointConfigTask, self).__init__( name=name, task_config=BotoConfig( service="sagemaker", @@ -175,7 +175,7 @@ def __init__( ) -class SagemakerDeleteModelTask(BotoTask): +class SageMakerDeleteModelTask(BotoTask): def __init__( self, name: str, @@ -185,14 +185,14 @@ def __init__( **kwargs, ): """ - Deletes a Sagemaker model. + Deletes a SageMaker model. :param name: The name of the task. :param config: The configuration to be provided to the boto3 API call. :param region: The region for the boto3 client. :param inputs: The input literal map to be used for updating the configuration. """ - super(SagemakerDeleteModelTask, self).__init__( + super(SageMakerDeleteModelTask, self).__init__( name=name, task_config=BotoConfig( service="sagemaker", @@ -206,7 +206,7 @@ def __init__( ) -class SagemakerInvokeEndpointTask(BotoTask): +class SageMakerInvokeEndpointTask(BotoTask): def __init__( self, name: str, @@ -216,14 +216,14 @@ def __init__( **kwargs, ): """ - Invokes a Sagemaker endpoint. + Invokes a SageMaker endpoint. :param name: The name of the task. :param config: The configuration to be provided to the boto3 API call. :param region: The region for the boto3 client. :param inputs: The input literal map to be used for updating the configuration. 
""" - super(SagemakerInvokeEndpointTask, self).__init__( + super(SageMakerInvokeEndpointTask, self).__init__( name=name, task_config=BotoConfig( service="sagemaker-runtime", diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py index a65ccf2697..04c6a34ec9 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py @@ -3,12 +3,12 @@ from flytekit import ImageSpec, Workflow, kwtypes from .task import ( - SagemakerDeleteEndpointConfigTask, - SagemakerDeleteEndpointTask, - SagemakerDeleteModelTask, - SagemakerEndpointConfigTask, - SagemakerEndpointTask, - SagemakerModelTask, + SageMakerDeleteEndpointConfigTask, + SageMakerDeleteEndpointTask, + SageMakerDeleteModelTask, + SageMakerEndpointConfigTask, + SageMakerEndpointTask, + SageMakerModelTask, ) @@ -24,9 +24,9 @@ def create_sagemaker_deployment( region: Optional[str] = None, ): """ - Creates Sagemaker model, endpoint config and endpoint. + Creates SageMaker model, endpoint config and endpoint. """ - sagemaker_model_task = SagemakerModelTask( + sagemaker_model_task = SageMakerModelTask( name=f"sagemaker-model-{name}", config=model_config, region=region, @@ -34,14 +34,14 @@ def create_sagemaker_deployment( container_image=container_image, ) - endpoint_config_task = SagemakerEndpointConfigTask( + endpoint_config_task = SageMakerEndpointConfigTask( name=f"sagemaker-endpoint-config-{name}", config=endpoint_config_config, region=region, inputs=endpoint_config_input_types, ) - endpoint_task = SagemakerEndpointTask( + endpoint_task = SageMakerEndpointTask( name=f"sagemaker-endpoint-{name}", config=endpoint_config, region=region, @@ -71,23 +71,23 @@ def create_sagemaker_deployment( def delete_sagemaker_deployment(name: str, region: Optional[str] = None): """ - Deletes Sagemaker model, endpoint config and endpoint. 
+ Deletes SageMaker model, endpoint config and endpoint. """ - sagemaker_delete_endpoint = SagemakerDeleteEndpointTask( + sagemaker_delete_endpoint = SageMakerDeleteEndpointTask( name=f"sagemaker-delete-endpoint-{name}", config={"EndpointName": "{inputs.endpoint_name}"}, region=region, inputs=kwtypes(endpoint_name=str), ) - sagemaker_delete_endpoint_config = SagemakerDeleteEndpointConfigTask( + sagemaker_delete_endpoint_config = SageMakerDeleteEndpointConfigTask( name=f"sagemaker-delete-endpoint-config-{name}", config={"EndpointConfigName": "{inputs.endpoint_config_name}"}, region=region, inputs=kwtypes(endpoint_config_name=str), ) - sagemaker_delete_model = SagemakerDeleteModelTask( + sagemaker_delete_model = SageMakerDeleteModelTask( name=f"sagemaker-delete-model-{name}", config={"ModelName": "{inputs.model_name}"}, region=region, diff --git a/plugins/flytekit-aws-sagemaker/setup.py b/plugins/flytekit-aws-sagemaker/setup.py index 8f7d4ee930..1be226bdea 100644 --- a/plugins/flytekit-aws-sagemaker/setup.py +++ b/plugins/flytekit-aws-sagemaker/setup.py @@ -14,7 +14,7 @@ version=__version__, author="flyteorg", author_email="admin@flyte.org", - description="Flytekit AWS Sagemaker plugin", + description="Flytekit AWS SageMaker plugin", namespace_packages=["flytekitplugins"], packages=[f"flytekitplugins.{PLUGIN_NAME}"], install_requires=plugin_requires, diff --git a/plugins/flytekit-aws-sagemaker/tests/test_agent.py b/plugins/flytekit-aws-sagemaker/tests/test_agent.py index c75b9a1cc8..10885408dc 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_agent.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_agent.py @@ -1,13 +1,11 @@ import json -from dataclasses import asdict from datetime import timedelta from unittest import mock import pytest -from flyteidl.admin.agent_pb2 import RUNNING, DeleteTaskResponse -from flytekitplugins.awssagemaker.agent import Metadata +from flyteidl.core.execution_pb2 import TaskExecution +from flytekitplugins.awssagemaker.agent import 
SageMakerEndpointMetadata -from flytekit import FlyteContext, FlyteContextManager from flytekit.extend.backend.base_agent import AgentRegistry from flytekit.interfaces.cli_identifiers import Identifier from flytekit.models import literals @@ -62,7 +60,6 @@ }, ) async def test_agent(mock_boto_call, mock_secret): - ctx = FlyteContextManager.current_context() agent = AgentRegistry.get_agent("sagemaker-endpoint") task_id = Identifier( resource_type=ResourceType.TASK, @@ -99,22 +96,20 @@ async def test_agent(mock_boto_call, mock_secret): interface=None, type="sagemaker-endpoint", ) - output_prefix = FlyteContext.current_context().file_access.get_random_local_directory() # CREATE - response = await agent.async_create(ctx, output_prefix, task_template) - - metadata = Metadata(endpoint_name="sagemaker-endpoint", region="us-east-2") - metadata_bytes = json.dumps(asdict(metadata)).encode("utf-8") - assert response.resource_meta == metadata_bytes + metadata = SageMakerEndpointMetadata(endpoint_name="sagemaker-endpoint", region="us-east-2") + response = await agent.create(task_template) + assert response == metadata # GET - response = await agent.async_get(ctx, metadata_bytes) - assert response.resource.state == RUNNING - from_json = json.loads(response.resource.outputs.literals["result"].scalar.primitive.string_value) + resource = await agent.get(metadata) + assert resource.phase == TaskExecution.SUCCEEDED + + from_json = json.loads(resource.outputs.literals["result"].scalar.primitive.string_value) assert from_json["EndpointName"] == "sagemaker-xgboost-endpoint" assert from_json["EndpointArn"] == "arn:aws:sagemaker:us-east-2:1234567890:endpoint/sagemaker-xgboost-endpoint" # DELETE - delete_response = await agent.async_delete(ctx, metadata_bytes) - assert isinstance(delete_response, DeleteTaskResponse) + delete_response = await agent.delete(metadata) + assert delete_response is None diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py 
b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py index c9a7fb8c51..5e2183d6a7 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py @@ -2,9 +2,8 @@ from unittest import mock import pytest -from flyteidl.admin.agent_pb2 import SUCCEEDED +from flyteidl.core.execution_pb2 import TaskExecution -from flytekit import FlyteContext, FlyteContextManager from flytekit.extend.backend.base_agent import AgentRegistry from flytekit.interfaces.cli_identifiers import Identifier from flytekit.models import literals @@ -35,7 +34,6 @@ }, ) async def test_agent(mock_boto_call, mock_secret): - ctx = FlyteContextManager.current_context() agent = AgentRegistry.get_agent("boto") task_id = Identifier( resource_type=ResourceType.TASK, @@ -90,10 +88,11 @@ async def test_agent(mock_boto_call, mock_secret): ), }, ) - output_prefix = FlyteContext.current_context().file_access.get_random_local_directory() - response = await agent.async_create(ctx, output_prefix, task_template, task_inputs) + resource = await agent.do(task_template, task_inputs) - assert response.HasField("resource") - assert response.resource.state == SUCCEEDED - assert response.resource.outputs is not None + assert resource.phase == TaskExecution.SUCCEEDED + assert ( + resource.outputs.literals["result"].scalar.generic.fields["EndpointConfigArn"].string_value + == "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config" + ) diff --git a/plugins/flytekit-aws-sagemaker/tests/test_task.py b/plugins/flytekit-aws-sagemaker/tests/test_task.py index 8d0863cd58..17a47dee36 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_task.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_task.py @@ -1,12 +1,12 @@ import pytest from flytekitplugins.awssagemaker import ( - SagemakerDeleteEndpointConfigTask, - SagemakerDeleteEndpointTask, - SagemakerDeleteModelTask, - SagemakerEndpointConfigTask, - SagemakerEndpointTask, - 
SagemakerInvokeEndpointTask, - SagemakerModelTask, + SageMakerDeleteEndpointConfigTask, + SageMakerDeleteEndpointTask, + SageMakerDeleteModelTask, + SageMakerEndpointConfigTask, + SageMakerEndpointTask, + SageMakerInvokeEndpointTask, + SageMakerModelTask, ) from flytekit import kwtypes @@ -33,7 +33,7 @@ 3, 1, "us-east-2", - SagemakerModelTask, + SageMakerModelTask, ), ( "sagemaker_endpoint_config", @@ -56,7 +56,7 @@ 3, 1, "us-east-2", - SagemakerEndpointConfigTask, + SageMakerEndpointConfigTask, ), ( "sagemaker_endpoint", @@ -71,7 +71,7 @@ 2, 1, "us-east-2", - SagemakerEndpointTask, + SageMakerEndpointTask, ), ( "sagemaker_delete_endpoint", @@ -83,7 +83,7 @@ 1, 0, "us-east-2", - SagemakerDeleteEndpointTask, + SageMakerDeleteEndpointTask, ), ( "sagemaker_delete_endpoint_config", @@ -95,7 +95,7 @@ 1, 0, "us-east-2", - SagemakerDeleteEndpointConfigTask, + SageMakerDeleteEndpointConfigTask, ), ( "sagemaker_delete_model", @@ -107,7 +107,7 @@ 1, 0, "us-east-2", - SagemakerDeleteModelTask, + SageMakerDeleteModelTask, ), ( "sagemaker_invoke_endpoint", @@ -122,7 +122,7 @@ 1, 1, "us-east-2", - SagemakerInvokeEndpointTask, + SageMakerInvokeEndpointTask, ), ], ) From 976aa1b46841f9583228ab018cfddac9d1ab84bb Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 21 Feb 2024 15:39:44 +0530 Subject: [PATCH 049/120] add name Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/boto3_agent.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py index 0bc29666f1..be0d35a6fd 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py @@ -19,6 +19,8 @@ class BotoAgent(SyncAgentBase): """A general purpose boto3 agent that can be used to call any boto3 method.""" + name = "Boto Agent" + def __init__(self): 
super().__init__(task_type_name="boto") From 4127d9509bfff477ae762cf1e8340c2198f1505f Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 21 Feb 2024 15:41:51 +0530 Subject: [PATCH 050/120] add syncagentexecutormixin Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/boto3_task.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py index e4ab88ced4..810e2d7c29 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py @@ -5,7 +5,7 @@ from flytekit.configuration import SerializationSettings from flytekit.core.interface import Interface from flytekit.core.python_function_task import PythonInstanceTask -from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin +from flytekit.extend.backend.base_agent import SyncAgentExecutorMixin @dataclass @@ -16,7 +16,7 @@ class BotoConfig(object): region: str -class BotoTask(AsyncAgentExecutorMixin, PythonInstanceTask[BotoConfig]): +class BotoTask(SyncAgentExecutorMixin, PythonInstanceTask[BotoConfig]): _TASK_TYPE = "boto" def __init__( From 6c4109ba7da8024f9b52fa2c1f558e3dc928358d Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 21 Feb 2024 15:48:26 +0530 Subject: [PATCH 051/120] modify sync output Signed-off-by: Samhita Alla --- .../awssagemaker/boto3_agent.py | 18 +---------------- .../tests/test_boto3_agent.py | 20 +++++++++++++++---- 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py index be0d35a6fd..07f9564cdd 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py +++ 
b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py @@ -2,8 +2,6 @@ from flyteidl.core.execution_pb2 import TaskExecution -from flytekit import FlyteContextManager -from flytekit.core.type_engine import TypeEngine from flytekit.extend.backend.base_agent import ( AgentRegistry, Resource, @@ -43,21 +41,7 @@ async def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = N aws_session_token=get_agent_secret(secret_key="aws-session-token"), ) - outputs = None - if result: - ctx = FlyteContextManager.current_context() - outputs = LiteralMap( - { - "result": TypeEngine.to_literal( - ctx, - result, - dict, - TypeEngine.to_literal_type(dict), - ) - } - ).to_flyte_idl() - - return Resource(phase=TaskExecution.SUCCEEDED, outputs=outputs) + return Resource(phase=TaskExecution.SUCCEEDED, outputs={"result": result}) AgentRegistry.register(BotoAgent()) diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py index 5e2183d6a7..77a2fe438b 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py @@ -92,7 +92,19 @@ async def test_agent(mock_boto_call, mock_secret): resource = await agent.do(task_template, task_inputs) assert resource.phase == TaskExecution.SUCCEEDED - assert ( - resource.outputs.literals["result"].scalar.generic.fields["EndpointConfigArn"].string_value - == "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config" - ) + assert resource.outputs == { + "result": { + "ResponseMetadata": { + "RequestId": "66f80391-348a-4ee0-9158-508914d16db2", + "HTTPStatusCode": 200.0, + "RetryAttempts": 0.0, + "HTTPHeaders": { + "content-type": "application/x-amz-json-1.1", + "date": "Wed, 31 Jan 2024 16:43:52 GMT", + "x-amzn-requestid": "66f80391-348a-4ee0-9158-508914d16db2", + "content-length": "114", + }, + }, + "EndpointConfigArn": 
"arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config", + } + } From 692c2b47c48f0d20cff30e089c48cccf8189f3ec Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 21 Feb 2024 15:53:14 +0530 Subject: [PATCH 052/120] metadata to resource_meta Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/agent.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py index 7b780ba6aa..2300ded5b6 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py @@ -67,11 +67,11 @@ async def create( return SageMakerEndpointMetadata(endpoint_name=config["EndpointName"], region=region) - async def get(self, metadata: SageMakerEndpointMetadata, **kwargs) -> Resource: + async def get(self, resource_meta: SageMakerEndpointMetadata, **kwargs) -> Resource: endpoint_status = await self._call( method="describe_endpoint", - config={"EndpointName": metadata.endpoint_name}, - region=metadata.region, + config={"EndpointName": resource_meta.endpoint_name}, + region=resource_meta.region, aws_access_key_id=get_agent_secret(secret_key="aws-access-key"), aws_secret_access_key=get_agent_secret(secret_key="aws-secret-access-key"), aws_session_token=get_agent_secret(secret_key="aws-session-token"), @@ -100,11 +100,11 @@ async def get(self, metadata: SageMakerEndpointMetadata, **kwargs) -> Resource: return Resource(phase=flyte_state, outputs=res, message=message) - async def delete(self, metadata: SageMakerEndpointMetadata, **kwargs): + async def delete(self, resource_meta: SageMakerEndpointMetadata, **kwargs): await self._call( "delete_endpoint", - config={"EndpointName": metadata.endpoint_name}, - region=metadata.region, + config={"EndpointName": resource_meta.endpoint_name}, + 
region=resource_meta.region, aws_access_key_id=get_agent_secret(secret_key="aws-access-key"), aws_secret_access_key=get_agent_secret(secret_key="aws-secret-access-key"), aws_session_token=get_agent_secret(secret_key="aws-session-token"), From ba2c16a0f93580c9eb9dadd1e2146eba6ba92c16 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 21 Feb 2024 16:01:26 +0530 Subject: [PATCH 053/120] remote conversion to flyte idl Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/agent.py | 2 +- .../awssagemaker/boto3_agent.py | 18 ++++++++++++++++- .../tests/test_boto3_agent.py | 20 ++++--------------- 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py index 2300ded5b6..2db22013e4 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py @@ -96,7 +96,7 @@ async def get(self, resource_meta: SageMakerEndpointMetadata, **kwargs) -> Resou TypeEngine.to_literal_type(str), ) } - ).to_flyte_idl() + ) return Resource(phase=flyte_state, outputs=res, message=message) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py index 07f9564cdd..3057508502 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py @@ -2,6 +2,8 @@ from flyteidl.core.execution_pb2 import TaskExecution +from flytekit import FlyteContextManager +from flytekit.core.type_engine import TypeEngine from flytekit.extend.backend.base_agent import ( AgentRegistry, Resource, @@ -41,7 +43,21 @@ async def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = N aws_session_token=get_agent_secret(secret_key="aws-session-token"), ) - 
return Resource(phase=TaskExecution.SUCCEEDED, outputs={"result": result}) + outputs = None + if result: + ctx = FlyteContextManager.current_context() + outputs = LiteralMap( + { + "result": TypeEngine.to_literal( + ctx, + result, + dict, + TypeEngine.to_literal_type(dict), + ) + } + ) + + return Resource(phase=TaskExecution.SUCCEEDED, outputs=outputs) AgentRegistry.register(BotoAgent()) diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py index 77a2fe438b..5e2183d6a7 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py @@ -92,19 +92,7 @@ async def test_agent(mock_boto_call, mock_secret): resource = await agent.do(task_template, task_inputs) assert resource.phase == TaskExecution.SUCCEEDED - assert resource.outputs == { - "result": { - "ResponseMetadata": { - "RequestId": "66f80391-348a-4ee0-9158-508914d16db2", - "HTTPStatusCode": 200.0, - "RetryAttempts": 0.0, - "HTTPHeaders": { - "content-type": "application/x-amz-json-1.1", - "date": "Wed, 31 Jan 2024 16:43:52 GMT", - "x-amzn-requestid": "66f80391-348a-4ee0-9158-508914d16db2", - "content-length": "114", - }, - }, - "EndpointConfigArn": "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config", - } - } + assert ( + resource.outputs.literals["result"].scalar.generic.fields["EndpointConfigArn"].string_value + == "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config" + ) From 01b41c9ee2a26ddd92d313548dcd9b59269416e0 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 21 Feb 2024 16:15:52 +0530 Subject: [PATCH 054/120] add output type Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/boto3_task.py | 5 ++--- .../flytekitplugins/awssagemaker/task.py | 3 --- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git 
a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py index 810e2d7c29..5a8873b798 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from typing import Any, Optional, Type, Union -from flytekit import ImageSpec +from flytekit import ImageSpec, kwtypes from flytekit.configuration import SerializationSettings from flytekit.core.interface import Interface from flytekit.core.python_function_task import PythonInstanceTask @@ -24,7 +24,6 @@ def __init__( name: str, task_config: BotoConfig, inputs: Optional[dict[str, Type]] = None, - outputs: Optional[dict[str, Type]] = None, container_image: Optional[Union[str, ImageSpec]] = None, **kwargs, ): @@ -32,7 +31,7 @@ def __init__( name=name, task_config=task_config, task_type=self._TASK_TYPE, - interface=Interface(inputs=inputs, outputs=outputs), + interface=Interface(inputs=inputs, outputs=kwtypes(result=dict)), container_image=container_image, **kwargs, ) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py index e0d3f21ebf..866e5bf4d7 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py @@ -35,7 +35,6 @@ def __init__( name=name, task_config=BotoConfig(service="sagemaker", method="create_model", config=config, region=region), inputs=inputs, - outputs=kwtypes(result=dict), container_image=container_image, **kwargs, ) @@ -67,7 +66,6 @@ def __init__( region=region, ), inputs=inputs, - outputs=kwtypes(result=dict), container_image=DefaultImages.default_image(), **kwargs, ) @@ -232,7 +230,6 @@ def __init__( region=region, ), inputs=inputs, - outputs=kwtypes(result=dict), 
container_image=DefaultImages.default_image(), **kwargs, ) From 2a843accd38a73b5a7ad35b495d1e33f6647e42a Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 21 Feb 2024 21:37:57 +0530 Subject: [PATCH 055/120] floats to ints Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/boto3_agent.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py index 3057508502..be9b975f14 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py @@ -16,6 +16,21 @@ from .boto3_mixin import Boto3AgentMixin +def convert_floats_with_no_fraction_to_ints(data): + if isinstance(data, dict): + for key, value in data.items(): + if isinstance(value, float) and value.is_integer(): + data[key] = int(value) + elif isinstance(value, dict) or isinstance(value, list): + convert_floats_with_no_fraction_to_ints(value) + elif isinstance(data, list): + for i, item in enumerate(data): + if isinstance(item, float) and item.is_integer(): + data[i] = int(item) + elif isinstance(item, dict) or isinstance(item, list): + convert_floats_with_no_fraction_to_ints(item) + + class BotoAgent(SyncAgentBase): """A general purpose boto3 agent that can be used to call any boto3 method.""" @@ -27,7 +42,7 @@ def __init__(self): async def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs) -> Resource: custom = task_template.custom service = custom["service"] - config = custom["config"] + config = convert_floats_with_no_fraction_to_ints(custom["config"]) region = custom["region"] method = custom["method"] From 590f59d574e7eaf0faec17e6d5918f630bec5a7f Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 21 Feb 2024 21:41:16 +0530 Subject: [PATCH 056/120] in place modification 
Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/boto3_agent.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py index be9b975f14..fbb4bd6646 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py @@ -42,7 +42,9 @@ def __init__(self): async def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs) -> Resource: custom = task_template.custom service = custom["service"] - config = convert_floats_with_no_fraction_to_ints(custom["config"]) + raw_config = custom["config"] + convert_floats_with_no_fraction_to_ints(raw_config) + config = raw_config region = custom["region"] method = custom["method"] From 8366cf946e55184e31e9548000fe23918ae1c250 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 21 Feb 2024 21:52:14 +0530 Subject: [PATCH 057/120] chain tasks Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/boto3_agent.py | 1 + .../flytekitplugins/awssagemaker/workflow.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py index fbb4bd6646..2928774637 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py @@ -16,6 +16,7 @@ from .boto3_mixin import Boto3AgentMixin +# https://github.com/flyteorg/flyte/issues/4505 def convert_floats_with_no_fraction_to_ints(data): if isinstance(data, dict): for key, value in data.items(): diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py 
b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py index 04c6a34ec9..608b7a9473 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py @@ -63,7 +63,10 @@ def create_sagemaker_deployment( for param, t in value.items(): wf.add_workflow_input(param, t) input_dict[param] = wf.inputs[param] - nodes.append(wf.add_entity(key, **input_dict)) + node = wf.add_entity(key, **input_dict) + if len(nodes) > 0: + nodes[-1] >> node + nodes.append(node) wf.add_workflow_output("wf_output", nodes[2].outputs["result"], str) return wf @@ -99,17 +102,19 @@ def delete_sagemaker_deployment(name: str, region: Optional[str] = None): wf.add_workflow_input("endpoint_config_name", str) wf.add_workflow_input("model_name", str) - wf.add_entity( + node_t1 = wf.add_entity( sagemaker_delete_endpoint, endpoint_name=wf.inputs["endpoint_name"], ) - wf.add_entity( + node_t2 = wf.add_entity( sagemaker_delete_endpoint_config, endpoint_config_name=wf.inputs["endpoint_config_name"], ) - wf.add_entity( + node_t3 = wf.add_entity( sagemaker_delete_model, model_name=wf.inputs["model_name"], ) + node_t1 >> node_t2 + node_t2 >> node_t3 return wf From e298f40764308483e05c3149456bc06cee65be23 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 21 Feb 2024 22:18:08 +0530 Subject: [PATCH 058/120] great expectations revert Signed-off-by: Samhita Alla --- plugins/flytekit-greatexpectations/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-greatexpectations/setup.py b/plugins/flytekit-greatexpectations/setup.py index 0ef3fcf2fc..f73b515539 100644 --- a/plugins/flytekit-greatexpectations/setup.py +++ b/plugins/flytekit-greatexpectations/setup.py @@ -6,7 +6,7 @@ plugin_requires = [ "flytekit>=1.5.0,<2.0.0", - "great-expectations>=0.13.30,<=0.18.8", + "great-expectations>=0.13.30", "sqlalchemy>=1.4.23,<2.0.0", "pyspark==3.3.1", "s3fs<2023.6.0", From 
1d410269d698d504889fef0a23381da3bf205186 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Fri, 23 Feb 2024 23:04:59 +0530 Subject: [PATCH 059/120] optimize float to int code Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker/boto3_agent.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py index 2928774637..32673f1145 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py @@ -20,16 +20,13 @@ def convert_floats_with_no_fraction_to_ints(data): if isinstance(data, dict): for key, value in data.items(): - if isinstance(value, float) and value.is_integer(): - data[key] = int(value) - elif isinstance(value, dict) or isinstance(value, list): - convert_floats_with_no_fraction_to_ints(value) + data[key] = convert_floats_with_no_fraction_to_ints(value) elif isinstance(data, list): for i, item in enumerate(data): - if isinstance(item, float) and item.is_integer(): - data[i] = int(item) - elif isinstance(item, dict) or isinstance(item, list): - convert_floats_with_no_fraction_to_ints(item) + data[i] = convert_floats_with_no_fraction_to_ints(item) + elif isinstance(data, float) and data.is_integer(): + return int(data) + return data class BotoAgent(SyncAgentBase): From 5de9684fcc45e2ddf020803c6329dbfceb81f218 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Sat, 24 Feb 2024 13:08:45 +0530 Subject: [PATCH 060/120] snake case Signed-off-by: Samhita Alla --- flytekit/core/tracker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py index a1610588d0..5871cbef81 100644 --- a/flytekit/core/tracker.py +++ b/flytekit/core/tracker.py @@ -305,7 +305,7 @@ def _task_module_from_callable(f: Callable): return mod, 
mod_name, name -def isPythonInstanceTask(obj): +def is_python_instance_task(obj): for cls in inspect.getmro(type(obj)): try: if cls.__name__ == "PythonInstanceTask": @@ -331,7 +331,7 @@ def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str, try: name = f.lhs except Exception: - if not isPythonInstanceTask(f): + if not is_python_instance_task(f): raise AssertionError(f"Unable to determine module of {f}") name = "" else: From 6990dc4a69876c6ecbd136d9f699e8701bae753c Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 29 Feb 2024 19:10:18 +0530 Subject: [PATCH 061/120] modify plugin name Signed-off-by: Samhita Alla --- .github/workflows/pythonbuild.yml | 2 +- Dockerfile.agent | 2 +- flytekit/core/tracker.py | 34 +++++++++---------- .../README.md | 10 +++--- .../dev-requirements.txt | 0 .../awssagemaker_inference}/__init__.py | 2 +- .../awssagemaker_inference}/agent.py | 0 .../awssagemaker_inference}/boto3_agent.py | 0 .../awssagemaker_inference}/boto3_mixin.py | 0 .../awssagemaker_inference}/boto3_task.py | 0 .../awssagemaker_inference}/task.py | 0 .../awssagemaker_inference}/workflow.py | 0 .../setup.py | 4 +-- .../tests/__init__.py | 0 .../tests/test_agent.py | 2 +- .../tests/test_boto3_agent.py | 0 .../tests/test_boto3_mixin.py | 2 +- .../tests/test_boto3_task.py | 2 +- .../tests/test_task.py | 2 +- .../tests/test_workflow.py | 2 +- 20 files changed, 31 insertions(+), 33 deletions(-) rename plugins/{flytekit-aws-sagemaker => flytekit-awssagemaker-inference}/README.md (90%) rename plugins/{flytekit-aws-sagemaker => flytekit-awssagemaker-inference}/dev-requirements.txt (100%) rename plugins/{flytekit-aws-sagemaker/flytekitplugins/awssagemaker => flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference}/__init__.py (93%) rename plugins/{flytekit-aws-sagemaker/flytekitplugins/awssagemaker => flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference}/agent.py (100%) rename 
plugins/{flytekit-aws-sagemaker/flytekitplugins/awssagemaker => flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference}/boto3_agent.py (100%) rename plugins/{flytekit-aws-sagemaker/flytekitplugins/awssagemaker => flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference}/boto3_mixin.py (100%) rename plugins/{flytekit-aws-sagemaker/flytekitplugins/awssagemaker => flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference}/boto3_task.py (100%) rename plugins/{flytekit-aws-sagemaker/flytekitplugins/awssagemaker => flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference}/task.py (100%) rename plugins/{flytekit-aws-sagemaker/flytekitplugins/awssagemaker => flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference}/workflow.py (100%) rename plugins/{flytekit-aws-sagemaker => flytekit-awssagemaker-inference}/setup.py (92%) rename plugins/{flytekit-aws-sagemaker => flytekit-awssagemaker-inference}/tests/__init__.py (100%) rename plugins/{flytekit-aws-sagemaker => flytekit-awssagemaker-inference}/tests/test_agent.py (98%) rename plugins/{flytekit-aws-sagemaker => flytekit-awssagemaker-inference}/tests/test_boto3_agent.py (100%) rename plugins/{flytekit-aws-sagemaker => flytekit-awssagemaker-inference}/tests/test_boto3_mixin.py (96%) rename plugins/{flytekit-aws-sagemaker => flytekit-awssagemaker-inference}/tests/test_boto3_task.py (96%) rename plugins/{flytekit-aws-sagemaker => flytekit-awssagemaker-inference}/tests/test_task.py (98%) rename plugins/{flytekit-aws-sagemaker => flytekit-awssagemaker-inference}/tests/test_workflow.py (97%) diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index 0b5e60a731..e5e3a14f21 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -186,7 +186,7 @@ jobs: - flytekit-async-fsspec - flytekit-aws-athena - flytekit-aws-batch - - flytekit-aws-sagemaker + - flytekit-awssagemaker-inference # TODO: uncomment 
this when the sagemaker agent is implemented: https://github.com/flyteorg/flyte/issues/4079 # - flytekit-aws-sagemaker - flytekit-bigquery diff --git a/Dockerfile.agent b/Dockerfile.agent index 9bf5ab72ab..0715221f37 100644 --- a/Dockerfile.agent +++ b/Dockerfile.agent @@ -14,7 +14,7 @@ RUN pip install --no-cache-dir -U flytekit==$VERSION \ flytekitplugins-mmcloud==$VERSION \ flytekitplugins-spark==$VERSION \ flytekitplugins-snowflake==$VERSION \ - flytekitplugins-awssagemaker==$VERSION \ + flytekitplugins-awssagemaker-inference==$VERSION \ && apt-get clean autoclean \ && apt-get autoremove --yes \ && rm -rf /var/lib/{apt,dpkg,cache,log}/ \ diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py index 5871cbef81..07d09364c9 100644 --- a/flytekit/core/tracker.py +++ b/flytekit/core/tracker.py @@ -89,6 +89,16 @@ def __call__(cls, *args, **kwargs): return o +def is_python_instance_task(obj): + for cls in inspect.getmro(type(obj)): + try: + if cls.__name__ == "PythonInstanceTask": + return True + except Exception: + pass + return False + + class TrackedInstance(metaclass=InstanceTrackingMeta): """ Please see the notes for the metaclass above first. 
@@ -171,8 +181,11 @@ def _candidate_name_matches(candidate) -> bool: except ValueError as err: logger.warning(f"Caught ValueError {err} while attempting to auto-assign name") - logger.error(f"Could not find LHS for {self} in {self._instantiated_in}") - raise _system_exceptions.FlyteSystemException(f"Error looking for LHS in {self._instantiated_in}") + if not is_python_instance_task(self): + logger.error(f"Could not find LHS for {self} in {self._instantiated_in}") + raise _system_exceptions.FlyteSystemException(f"Error looking for LHS in {self._instantiated_in}") + else: + return "" def isnested(func: Callable) -> bool: @@ -305,16 +318,6 @@ def _task_module_from_callable(f: Callable): return mod, mod_name, name -def is_python_instance_task(obj): - for cls in inspect.getmro(type(obj)): - try: - if cls.__name__ == "PythonInstanceTask": - return True - except Exception: - pass - return False - - def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str, str, str]: """ Returns the task-name, absolute module and the string name of the callable. 
@@ -328,12 +331,7 @@ def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str, elif f.instantiated_in: mod = importlib.import_module(f.instantiated_in) mod_name = mod.__name__ - try: - name = f.lhs - except Exception: - if not is_python_instance_task(f): - raise AssertionError(f"Unable to determine module of {f}") - name = "" + name = f.lhs else: raise AssertionError(f"Unable to determine module of {f}") else: diff --git a/plugins/flytekit-aws-sagemaker/README.md b/plugins/flytekit-awssagemaker-inference/README.md similarity index 90% rename from plugins/flytekit-aws-sagemaker/README.md rename to plugins/flytekit-awssagemaker-inference/README.md index b73e85a110..020c757fc8 100644 --- a/plugins/flytekit-aws-sagemaker/README.md +++ b/plugins/flytekit-awssagemaker-inference/README.md @@ -1,4 +1,4 @@ -# AWS SageMaker Plugin +# AWS SageMaker Inference Plugin The plugin features a deployment agent enabling you to deploy SageMaker models, create and trigger inference endpoints. Additionally, you can entirely remove the SageMaker deployment using the `delete_sagemaker_deployment` workflow. 
@@ -6,16 +6,16 @@ Additionally, you can entirely remove the SageMaker deployment using the `delete To install the plugin, run the following command: ```bash -pip install flytekitplugins-awssagemaker +pip install flytekitplugins-awssagemaker-inference ``` Here is a sample SageMaker deployment workflow: ```python REGION = os.getenv("REGION") -MODEL_NAME = "sagemaker-xgboost" -ENDPOINT_CONFIG_NAME = "sagemaker-xgboost-endpoint-config" -ENDPOINT_NAME = "sagemaker-xgboost-endpoint" +MODEL_NAME = "xgboost" +ENDPOINT_CONFIG_NAME = "xgboost-endpoint-config" +ENDPOINT_NAME = "xgboost-endpoint" sagemaker_deployment_wf = create_sagemaker_deployment( name="sagemaker-deployment", diff --git a/plugins/flytekit-aws-sagemaker/dev-requirements.txt b/plugins/flytekit-awssagemaker-inference/dev-requirements.txt similarity index 100% rename from plugins/flytekit-aws-sagemaker/dev-requirements.txt rename to plugins/flytekit-awssagemaker-inference/dev-requirements.txt diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/__init__.py similarity index 93% rename from plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py rename to plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/__init__.py index 1afe8c05b1..169abc2e76 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/__init__.py @@ -1,5 +1,5 @@ """ -.. currentmodule:: flytekitplugins.awssagemaker +.. currentmodule:: flytekitplugins.awssagemaker_inference .. 
autosummary:: :template: custom.rst diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/agent.py similarity index 100% rename from plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/agent.py rename to plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/agent.py diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py similarity index 100% rename from plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_agent.py rename to plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_mixin.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py similarity index 100% rename from plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_mixin.py rename to plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py similarity index 100% rename from plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/boto3_task.py rename to plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/task.py similarity index 100% rename from plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/task.py rename to plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/task.py diff --git 
a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/workflow.py similarity index 100% rename from plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/workflow.py rename to plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/workflow.py diff --git a/plugins/flytekit-aws-sagemaker/setup.py b/plugins/flytekit-awssagemaker-inference/setup.py similarity index 92% rename from plugins/flytekit-aws-sagemaker/setup.py rename to plugins/flytekit-awssagemaker-inference/setup.py index 1be226bdea..7db550b92c 100644 --- a/plugins/flytekit-aws-sagemaker/setup.py +++ b/plugins/flytekit-awssagemaker-inference/setup.py @@ -1,6 +1,6 @@ from setuptools import setup -PLUGIN_NAME = "awssagemaker" +PLUGIN_NAME = "awssagemaker-inference" microlib_name = f"flytekitplugins-{PLUGIN_NAME}" @@ -14,7 +14,7 @@ version=__version__, author="flyteorg", author_email="admin@flyte.org", - description="Flytekit AWS SageMaker plugin", + description="Flytekit AWS SageMaker Inference Plugin", namespace_packages=["flytekitplugins"], packages=[f"flytekitplugins.{PLUGIN_NAME}"], install_requires=plugin_requires, diff --git a/plugins/flytekit-aws-sagemaker/tests/__init__.py b/plugins/flytekit-awssagemaker-inference/tests/__init__.py similarity index 100% rename from plugins/flytekit-aws-sagemaker/tests/__init__.py rename to plugins/flytekit-awssagemaker-inference/tests/__init__.py diff --git a/plugins/flytekit-aws-sagemaker/tests/test_agent.py b/plugins/flytekit-awssagemaker-inference/tests/test_agent.py similarity index 98% rename from plugins/flytekit-aws-sagemaker/tests/test_agent.py rename to plugins/flytekit-awssagemaker-inference/tests/test_agent.py index 10885408dc..20e3390323 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_agent.py +++ b/plugins/flytekit-awssagemaker-inference/tests/test_agent.py @@ -4,7 +4,7 @@ import pytest from flyteidl.core.execution_pb2 
import TaskExecution -from flytekitplugins.awssagemaker.agent import SageMakerEndpointMetadata +from flytekitplugins.awssagemaker_inference.agent import SageMakerEndpointMetadata from flytekit.extend.backend.base_agent import AgentRegistry from flytekit.interfaces.cli_identifiers import Identifier diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py b/plugins/flytekit-awssagemaker-inference/tests/test_boto3_agent.py similarity index 100% rename from plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py rename to plugins/flytekit-awssagemaker-inference/tests/test_boto3_agent.py diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py b/plugins/flytekit-awssagemaker-inference/tests/test_boto3_mixin.py similarity index 96% rename from plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py rename to plugins/flytekit-awssagemaker-inference/tests/test_boto3_mixin.py index 5b95d02f0b..7a801c3a16 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py +++ b/plugins/flytekit-awssagemaker-inference/tests/test_boto3_mixin.py @@ -1,6 +1,6 @@ import typing -from flytekitplugins.awssagemaker.boto3_mixin import update_dict_fn +from flytekitplugins.awssagemaker_inference.boto3_mixin import update_dict_fn from flytekit import FlyteContext, StructuredDataset from flytekit.core.type_engine import TypeEngine diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_task.py b/plugins/flytekit-awssagemaker-inference/tests/test_boto3_task.py similarity index 96% rename from plugins/flytekit-aws-sagemaker/tests/test_boto3_task.py rename to plugins/flytekit-awssagemaker-inference/tests/test_boto3_task.py index e5fe6f32c7..78dce7eae3 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_boto3_task.py +++ b/plugins/flytekit-awssagemaker-inference/tests/test_boto3_task.py @@ -1,4 +1,4 @@ -from flytekitplugins.awssagemaker import BotoConfig, BotoTask +from flytekitplugins.awssagemaker_inference import BotoConfig, BotoTask from flytekit 
import kwtypes from flytekit.configuration import Image, ImageConfig, SerializationSettings diff --git a/plugins/flytekit-aws-sagemaker/tests/test_task.py b/plugins/flytekit-awssagemaker-inference/tests/test_task.py similarity index 98% rename from plugins/flytekit-aws-sagemaker/tests/test_task.py rename to plugins/flytekit-awssagemaker-inference/tests/test_task.py index 17a47dee36..fdeb6cab93 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_task.py +++ b/plugins/flytekit-awssagemaker-inference/tests/test_task.py @@ -1,5 +1,5 @@ import pytest -from flytekitplugins.awssagemaker import ( +from flytekitplugins.awssagemaker_inference import ( SageMakerDeleteEndpointConfigTask, SageMakerDeleteEndpointTask, SageMakerDeleteModelTask, diff --git a/plugins/flytekit-aws-sagemaker/tests/test_workflow.py b/plugins/flytekit-awssagemaker-inference/tests/test_workflow.py similarity index 97% rename from plugins/flytekit-aws-sagemaker/tests/test_workflow.py rename to plugins/flytekit-awssagemaker-inference/tests/test_workflow.py index 96002b65a2..739405ee6e 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_workflow.py +++ b/plugins/flytekit-awssagemaker-inference/tests/test_workflow.py @@ -1,4 +1,4 @@ -from flytekitplugins.awssagemaker import ( +from flytekitplugins.awssagemaker_inference import ( create_sagemaker_deployment, delete_sagemaker_deployment, ) From 0b1510459141da05f78eaede123d5423d465fc57 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 29 Feb 2024 19:19:37 +0530 Subject: [PATCH 062/120] modify plugin name Signed-off-by: Samhita Alla --- plugins/flytekit-awssagemaker-inference/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-awssagemaker-inference/setup.py b/plugins/flytekit-awssagemaker-inference/setup.py index 7db550b92c..e91f13afdb 100644 --- a/plugins/flytekit-awssagemaker-inference/setup.py +++ b/plugins/flytekit-awssagemaker-inference/setup.py @@ -1,6 +1,6 @@ from setuptools import setup -PLUGIN_NAME = 
"awssagemaker-inference" +PLUGIN_NAME = "awssagemaker_inference" microlib_name = f"flytekitplugins-{PLUGIN_NAME}" From 64ad2e252f1171c43f822223b28d0680a1c3a6dc Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 29 Feb 2024 22:05:08 +0530 Subject: [PATCH 063/120] modify tracker tests Signed-off-by: Samhita Alla --- flytekit/core/tracker.py | 2 +- .../unit/core/tracker/test_tracking.py | 19 ++++++++----------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py index 07d09364c9..490b458bbf 100644 --- a/flytekit/core/tracker.py +++ b/flytekit/core/tracker.py @@ -140,7 +140,7 @@ def find_lhs(self) -> str: raise _system_exceptions.FlyteSystemException(f"Object {self} does not have an _instantiated in") logger.debug(f"Looking for LHS for {self} from {self._instantiated_in}") - m = importlib.import_module(self._instantiated_in) + m = importlib.import_module(self.instantiated_in) for k in dir(m): try: if getattr(m, k) is self: diff --git a/tests/flytekit/unit/core/tracker/test_tracking.py b/tests/flytekit/unit/core/tracker/test_tracking.py index efa331a496..e028be477b 100644 --- a/tests/flytekit/unit/core/tracker/test_tracking.py +++ b/tests/flytekit/unit/core/tracker/test_tracking.py @@ -7,7 +7,7 @@ from flytekit.core.base_task import PythonTask from flytekit.core.python_function_task import PythonInstanceTask from flytekit.core.tracker import extract_task_module -from flytekit.exceptions import system as _system_exceptions +from flytekit.exceptions.system import FlyteSystemException from tests.flytekit.unit.core.tracker import d from tests.flytekit.unit.core.tracker.b import b_local_a, local_b from tests.flytekit.unit.core.tracker.c import b_in_c, c_local_a @@ -102,13 +102,11 @@ def test_extract_task_module(test_input, expected): class FakePythonInstanceTaskWithExceptionLHS(PythonInstanceTask): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._raises_exception = True @property 
- def lhs(self): - if self._raises_exception: - raise _system_exceptions.FlyteSystemException("Raising an exception") - return "some value" + def instantiated_in(self) -> str: + # random module + return "tests.flytekit.unit.exceptions.test_base" python_instance_task_instantiated = FakePythonInstanceTaskWithExceptionLHS(name="python_instance_task", task_config={}) @@ -120,10 +118,9 @@ def __init__(self, *args, **kwargs): self._raises_exception = True @property - def lhs(self): - if self._raises_exception: - raise _system_exceptions.FlyteSystemException("Raising an exception") - return "some value" + def instantiated_in(self) -> str: + # random module + return "tests.flytekit.unit.exceptions.test_base" python_task_instantiated = FakePythonTaskWithExceptionLHS( @@ -137,7 +134,7 @@ def test_raise_exception_when_accessing_nonexistent_lhs(): _, _, name, _ = extract_task_module(python_instance_task_instantiated) assert name == "" - with pytest.raises(AssertionError): + with pytest.raises(FlyteSystemException): extract_task_module(python_task_instantiated) From 327cbdd9b10196e0da807da88d6a698ffbc1744f Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Mon, 4 Mar 2024 15:34:25 +0530 Subject: [PATCH 064/120] fix tests, revert tracker changes and remove pythoninstancetask Signed-off-by: Samhita Alla --- .github/workflows/pythonbuild.yml | 2 - docs/source/plugins/awssagemaker.rst | 13 ++--- flytekit/core/base_task.py | 43 ++++++++++++--- flytekit/core/task.py | 54 ++++++++++++++++--- flytekit/core/tracker.py | 17 +----- plugins/README.md | 5 +- .../awssagemaker_inference/boto3_agent.py | 2 +- .../awssagemaker_inference/boto3_mixin.py | 17 +++--- .../awssagemaker_inference/boto3_task.py | 15 ++++-- .../awssagemaker_inference/task.py | 24 ++++----- .../awssagemaker_inference/workflow.py | 8 +-- .../flytekit-awssagemaker-inference/setup.py | 2 +- .../tests/test_agent.py | 4 +- .../tests/test_boto3_agent.py | 5 +- .../tests/test_boto3_mixin.py | 6 +-- .../tests/test_task.py | 18 
+++---- .../tests/test_workflow.py | 4 +- plugins/setup.py | 2 +- .../unit/core/tracker/test_tracking.py | 42 --------------- 19 files changed, 146 insertions(+), 137 deletions(-) diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index e5e3a14f21..92233ef6fb 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -187,8 +187,6 @@ jobs: - flytekit-aws-athena - flytekit-aws-batch - flytekit-awssagemaker-inference - # TODO: uncomment this when the sagemaker agent is implemented: https://github.com/flyteorg/flyte/issues/4079 - # - flytekit-aws-sagemaker - flytekit-bigquery - flytekit-dask - flytekit-data-fsspec diff --git a/docs/source/plugins/awssagemaker.rst b/docs/source/plugins/awssagemaker.rst index b8ded38bee..c0862c7dc4 100644 --- a/docs/source/plugins/awssagemaker.rst +++ b/docs/source/plugins/awssagemaker.rst @@ -1,17 +1,12 @@ -:orphan: +.. _awssagemaker_inference: -.. TODO: Will need to add this document back to the plugins/index.rst file -.. when sagemaker agent work is done: https://github.com/flyteorg/flyte/issues/4079 - -.. _awssagemaker: - -################################################### +########################### AWS Sagemaker API reference -################################################### +########################### .. tags:: Integration, MachineLearning, AWS -.. automodule:: flytekitplugins.awssagemaker +.. 
automodule:: flytekitplugins.awssagemaker_inference :no-members: :no-inherited-members: :no-special-members: diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 54f546ef1d..fd71f86211 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -23,7 +23,20 @@ import warnings from abc import abstractmethod from dataclasses import dataclass -from typing import Any, Coroutine, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union, cast +from typing import ( + Any, + Coroutine, + Dict, + Generic, + List, + Optional, + OrderedDict, + Tuple, + Type, + TypeVar, + Union, + cast, +) from flyteidl.core import tasks_pb2 @@ -147,9 +160,6 @@ class IgnoreOutputs(Exception): """ This exception should be used to indicate that the outputs generated by this can be safely ignored. This is useful in case of distributed training or peer-to-peer parallel algorithms. - - For example look at Sagemaker training, e.g. - :py:class:`plugins.awssagemaker.flytekitplugins.awssagemaker.training.SagemakerBuiltinAlgorithmsTask`. 
""" pass @@ -278,7 +288,12 @@ def local_execute( logger.info("Cache miss, task will be executed now") outputs_literal_map = self.sandbox_execute(ctx, input_literal_map) # TODO: need `native_inputs` - LocalTaskCache.set(self.name, self.metadata.cache_version, input_literal_map, outputs_literal_map) + LocalTaskCache.set( + self.name, + self.metadata.cache_version, + input_literal_map, + outputs_literal_map, + ) logger.info( f"Cache set for task named {self.name}, cache version {self.metadata.cache_version} " f"and inputs: {input_literal_map}" @@ -448,7 +463,10 @@ def __init__( self._task_config = task_config if disable_deck is not None: - warnings.warn("disable_deck was deprecated in 1.10.0, please use enable_deck instead", FutureWarning) + warnings.warn( + "disable_deck was deprecated in 1.10.0, please use enable_deck instead", + FutureWarning, + ) # Confirm that disable_deck and enable_deck do not contradict each other if disable_deck is not None and enable_deck is not None: @@ -647,7 +665,13 @@ def dispatch_execute( # If executed inside of a workflow being executed locally, then run the coroutine to get the # actual results. return asyncio.run( - self._async_execute(native_inputs, native_outputs, ctx, exec_ctx, new_user_params) + self._async_execute( + native_inputs, + native_outputs, + ctx, + exec_ctx, + new_user_params, + ) ) return self._async_execute(native_inputs, native_outputs, ctx, exec_ctx, new_user_params) @@ -660,7 +684,10 @@ def dispatch_execute( # Short circuit the translation to literal map because what's returned may be a dj spec (or an # already-constructed LiteralMap if the dynamic task was a no-op), not python native values # dynamic_execute returns a literal map in local execute so this also gets triggered. 
- if isinstance(native_outputs, (_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec)): + if isinstance( + native_outputs, + (_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec), + ): return native_outputs literals_map, native_outputs_as_map = self._output_to_literal_map(native_outputs, exec_ctx) diff --git a/flytekit/core/task.py b/flytekit/core/task.py index a99fbf599e..aeb0e90273 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -2,7 +2,18 @@ import datetime as _datetime from functools import update_wrapper -from typing import Any, Callable, Dict, Iterable, List, Optional, Type, TypeVar, Union, overload +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Type, + TypeVar, + Union, + overload, +) from flytekit.core import launch_plan as _annotated_launchplan from flytekit.core import workflow as _annotated_workflow @@ -31,8 +42,7 @@ class TaskPlugins(object): # Plugin_object_type is a derivative of ``PythonFunctionTask`` Examples of available task plugins include different query-based plugins such as - :py:class:`flytekitplugins.athena.task.AthenaTask` and :py:class:`flytekitplugins.hive.task.HiveTask`, ML tools like - :py:class:`plugins.awssagemaker.flytekitplugins.awssagemaker.training.SagemakerBuiltinAlgorithmsTask`, kubeflow + :py:class:`flytekitplugins.athena.task.AthenaTask` and :py:class:`flytekitplugins.hive.task.HiveTask`, kubeflow operators like :py:class:`plugins.kfpytorch.flytekitplugins.kfpytorch.task.PyTorchFunctionTask` and :py:class:`plugins.kftensorflow.flytekitplugins.kftensorflow.task.TensorflowFunctionTask`, and generic plugins like :py:class:`flytekitplugins.pod.task.PodFunctionTask` which doesn't integrate with third party tools or services. 
@@ -102,7 +112,13 @@ def task( secret_requests: Optional[List[Secret]] = ..., execution_mode: PythonFunctionTask.ExecutionBehavior = ..., node_dependency_hints: Optional[ - Iterable[Union[PythonFunctionTask, _annotated_launchplan.LaunchPlan, _annotated_workflow.WorkflowBase]] + Iterable[ + Union[ + PythonFunctionTask, + _annotated_launchplan.LaunchPlan, + _annotated_workflow.WorkflowBase, + ] + ] ] = ..., task_resolver: Optional[TaskResolverMixin] = ..., docs: Optional[Documentation] = ..., @@ -133,7 +149,13 @@ def task( secret_requests: Optional[List[Secret]] = ..., execution_mode: PythonFunctionTask.ExecutionBehavior = ..., node_dependency_hints: Optional[ - Iterable[Union[PythonFunctionTask, _annotated_launchplan.LaunchPlan, _annotated_workflow.WorkflowBase]] + Iterable[ + Union[ + PythonFunctionTask, + _annotated_launchplan.LaunchPlan, + _annotated_workflow.WorkflowBase, + ] + ] ] = ..., task_resolver: Optional[TaskResolverMixin] = ..., docs: Optional[Documentation] = ..., @@ -163,7 +185,13 @@ def task( secret_requests: Optional[List[Secret]] = None, execution_mode: PythonFunctionTask.ExecutionBehavior = PythonFunctionTask.ExecutionBehavior.DEFAULT, node_dependency_hints: Optional[ - Iterable[Union[PythonFunctionTask, _annotated_launchplan.LaunchPlan, _annotated_workflow.WorkflowBase]] + Iterable[ + Union[ + PythonFunctionTask, + _annotated_launchplan.LaunchPlan, + _annotated_workflow.WorkflowBase, + ] + ] ] = None, task_resolver: Optional[TaskResolverMixin] = None, docs: Optional[Documentation] = None, @@ -172,7 +200,11 @@ def task( pod_template: Optional["PodTemplate"] = None, pod_template_name: Optional[str] = None, accelerator: Optional[BaseAccelerator] = None, -) -> Union[Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]], PythonFunctionTask[T], Callable[..., FuncOut]]: +) -> Union[ + Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]], + PythonFunctionTask[T], + Callable[..., FuncOut], +]: """ This is the core decorator to use for any task 
type in flytekit. @@ -337,7 +369,13 @@ class ReferenceTask(ReferenceEntity, PythonFunctionTask): # type: ignore """ def __init__( - self, project: str, domain: str, name: str, version: str, inputs: Dict[str, type], outputs: Dict[str, Type] + self, + project: str, + domain: str, + name: str, + version: str, + inputs: Dict[str, type], + outputs: Dict[str, Type], ): super().__init__(TaskReference(project, domain, name, version), inputs, outputs) diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py index 490b458bbf..8afcef512a 100644 --- a/flytekit/core/tracker.py +++ b/flytekit/core/tracker.py @@ -89,16 +89,6 @@ def __call__(cls, *args, **kwargs): return o -def is_python_instance_task(obj): - for cls in inspect.getmro(type(obj)): - try: - if cls.__name__ == "PythonInstanceTask": - return True - except Exception: - pass - return False - - class TrackedInstance(metaclass=InstanceTrackingMeta): """ Please see the notes for the metaclass above first. @@ -181,11 +171,8 @@ def _candidate_name_matches(candidate) -> bool: except ValueError as err: logger.warning(f"Caught ValueError {err} while attempting to auto-assign name") - if not is_python_instance_task(self): - logger.error(f"Could not find LHS for {self} in {self._instantiated_in}") - raise _system_exceptions.FlyteSystemException(f"Error looking for LHS in {self._instantiated_in}") - else: - return "" + logger.error(f"Could not find LHS for {self} in {self._instantiated_in}") + raise _system_exceptions.FlyteSystemException(f"Error looking for LHS in {self._instantiated_in}") def isnested(func: Callable) -> bool: diff --git a/plugins/README.md b/plugins/README.md index d738c5b5a4..1dc2714fa3 100644 --- a/plugins/README.md +++ b/plugins/README.md @@ -6,8 +6,7 @@ All the Flytekit plugins maintained by the core team are added here. 
It is not n | Plugin | Installation | Description | Version | Type | |------------------------------|-----------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------| -| AWS Sagemaker Training | ```bash pip install flytekitplugins-awssagemaker ``` | Installs SDK to author Sagemaker built-in and custom training jobs in python | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-awssagemaker.svg)](https://pypi.python.org/pypi/flytekitplugins-awssagemaker/) | Backend | -| dask | ```bash pip install flytekitplugins-dask ``` | Installs SDK to author dask jobs that can be executed natively on Kubernetes using the Flyte backend plugin | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-awssagemaker.svg)](https://pypi.python.org/pypi/flytekitplugins-dask/) | Backend | +| dask | ```bash pip install flytekitplugins-dask ``` | Installs SDK to author dask jobs that can be executed natively on Kubernetes using the Flyte backend plugin | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-dask.svg)](https://pypi.python.org/pypi/flytekitplugins-dask/) | Backend | | Hive Queries | ```bash pip install flytekitplugins-hive ``` | Installs SDK to author Hive Queries that can be executed on a configured hive backend using Flyte backend plugin | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-hive.svg)](https://pypi.python.org/pypi/flytekitplugins-hive/) | Backend | | K8s distributed PyTorch Jobs | ```bash pip install flytekitplugins-kfpytorch ``` | Installs SDK to author Distributed pyTorch Jobs in python using Kubeflow PyTorch Operator | [![PyPI version 
fury.io](https://badge.fury.io/py/flytekitplugins-kfpytorch.svg)](https://pypi.python.org/pypi/flytekitplugins-kfpytorch/) | Backend | | K8s native tensorflow Jobs | ```bash pip install flytekitplugins-kftensorflow ``` | Installs SDK to author Distributed tensorflow Jobs in python using Kubeflow Tensorflow Operator | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-kftensorflow.svg)](https://pypi.python.org/pypi/flytekitplugins-kftensorflow/) | Backend | @@ -145,6 +144,6 @@ Try using `gsed` instead of `sed` if you are on a Mac. Also this only works of c - Example of TaskTemplate plugin which also allows plugin writers to supply a prebuilt container for runtime: [flytekit-sqlalchemy](./flytekit-sqlalchemy/) - Example of a SQL backend plugin where the actual query invocation is done by a backend plugin: [flytekit-snowflake](./flytekit-snowflake/) - Example of a Meta plugin that can wrap other tasks: [flytekit-papermill](./flytekit-papermill/) -- Example of a plugin that modifies the execution command: [flytekit-spark](./flytekit-spark/) OR [flytekit-aws-sagemaker](./flytekit-aws-sagemaker/) +- Example of a plugin that modifies the execution command: [flytekit-spark](./flytekit-spark/) - Example that allows executing the user container with some other context modifications: [flytekit-kf-tensorflow](./flytekit-kf-tensorflow/) - Example of a Persistence Plugin that allows data to be stored to different persistence layers: [flytekit-data-fsspec](./flytekit-data-fsspec/) diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py index 32673f1145..c4e05698e2 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py @@ -51,7 +51,7 @@ async def do(self, task_template: 
TaskTemplate, inputs: Optional[LiteralMap] = N result = await boto3_object._call( method=method, config=config, - container=task_template.container, + images=custom["images"], inputs=inputs, aws_access_key_id=get_agent_secret(secret_key="aws-access-key"), aws_secret_access_key=get_agent_secret(secret_key="aws-secret-access-key"), diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py index 9025c25864..7cbc9f12e8 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py @@ -3,7 +3,6 @@ import aioboto3 from flytekit.interaction.string_literals import literal_map_string_repr -from flytekit.models import task as _task_model from flytekit.models.literals import LiteralMap @@ -89,7 +88,7 @@ async def _call( self, method: str, config: Dict[str, Any], - container: Optional[_task_model.Container] = None, + images: Optional[dict[str, str]] = None, inputs: Optional[LiteralMap] = None, region: Optional[str] = None, aws_access_key_id: Optional[str] = None, @@ -101,14 +100,14 @@ async def _call( :param method: The boto3 method to invoke, e.g., create_endpoint_config. :param config: The configuration for the method, e.g., {"EndpointConfigName": "my-endpoint-config"}. The config - may contain placeholders replaced by values from inputs and container. + may contain placeholders replaced by values from inputs. 
For example, if the config is - {"EndpointConfigName": "{inputs.endpoint_config_name}", "EndpointName": "{endpoint_name}", - "Image": "{container.image}"} - the inputs contain a string literal for endpoint_config_name, and the container has the image, + {"EndpointConfigName": "{inputs.endpoint_config_name}", "EndpointName": "{inputs.endpoint_name}", + "Image": "{images.primary_container_image}"}, + the inputs contain a string literal for endpoint_config_name and endpoint_name and images contain primary_container_image, then the config will be updated to {"EndpointConfigName": "my-endpoint-config", "EndpointName": "my-endpoint", "Image": "my-image"} before invoking the boto3 method. - :param container: Container retrieved from the task template. + :param images: A dict of Docker images to use, for example, when deploying a model on SageMaker. :param inputs: The inputs for the task being created. :param region: The region for the boto3 client. If not provided, the region specified in the constructor will be used. :param aws_access_key_id: The access key ID to use to access the AWS resources. 
@@ -118,8 +117,8 @@ async def _call( args = {} if inputs: args["inputs"] = literal_map_string_repr(inputs) - if container: - args["container"] = {"image": container.image} + if images: + args["images"] = images updated_config = update_dict_fn(config, args) diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py index 5a8873b798..c4057f4c4c 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py @@ -3,9 +3,10 @@ from flytekit import ImageSpec, kwtypes from flytekit.configuration import SerializationSettings +from flytekit.core.base_task import PythonTask from flytekit.core.interface import Interface -from flytekit.core.python_function_task import PythonInstanceTask from flytekit.extend.backend.base_agent import SyncAgentExecutorMixin +from flytekit.image_spec.image_spec import ImageBuildEngine @dataclass @@ -14,9 +15,10 @@ class BotoConfig(object): method: str config: dict[str, Any] region: str + images: Optional[dict[str, Union[str, ImageSpec]]] = None -class BotoTask(SyncAgentExecutorMixin, PythonInstanceTask[BotoConfig]): +class BotoTask(SyncAgentExecutorMixin, PythonTask[BotoConfig]): _TASK_TYPE = "boto" def __init__( @@ -24,7 +26,6 @@ def __init__( name: str, task_config: BotoConfig, inputs: Optional[dict[str, Type]] = None, - container_image: Optional[Union[str, ImageSpec]] = None, **kwargs, ): super().__init__( @@ -32,14 +33,20 @@ def __init__( task_config=task_config, task_type=self._TASK_TYPE, interface=Interface(inputs=inputs, outputs=kwtypes(result=dict)), - container_image=container_image, **kwargs, ) def get_custom(self, settings: SerializationSettings) -> dict[str, Any]: + images = self.task_config.images + if images is not None: + [ImageBuildEngine.build(image) 
for image in images.values() if isinstance(image, ImageSpec)] + + print(images) + return { "service": self.task_config.service, "config": self.task_config.config, "region": self.task_config.region, "method": self.task_config.method, + "images": images, } diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/task.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/task.py index 866e5bf4d7..ede71775f9 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/task.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/task.py @@ -2,7 +2,7 @@ from typing import Any, Optional, Type, Union from flytekit import ImageSpec, kwtypes -from flytekit.configuration import DefaultImages, SerializationSettings +from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.interface import Interface from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin @@ -16,8 +16,8 @@ def __init__( name: str, config: dict[str, Any], region: Optional[str], + images: dict[str, Union[str, ImageSpec]], inputs: Optional[dict[str, Type]] = None, - container_image: Optional[Union[str, ImageSpec]] = None, **kwargs, ): """ @@ -27,15 +27,20 @@ def __init__( :param config: The configuration to be provided to the boto3 API call. :param region: The region for the boto3 client. :param inputs: The input literal map to be used for updating the configuration. - :param container_image: The path where inference code is stored. - This can be either in Amazon EC2 Container Registry or in a Docker registry - that is accessible from the same VPC that you configure for your endpoint. + :param image: The path where the inference code is stored can either be in the Amazon EC2 Container Registry + or in a Docker registry that is accessible from the same VPC that you configure for your endpoint. 
""" + super(SageMakerModelTask, self).__init__( name=name, - task_config=BotoConfig(service="sagemaker", method="create_model", config=config, region=region), + task_config=BotoConfig( + service="sagemaker", + method="create_model", + config=config, + region=region, + images=images, + ), inputs=inputs, - container_image=container_image, **kwargs, ) @@ -66,7 +71,6 @@ def __init__( region=region, ), inputs=inputs, - container_image=DefaultImages.default_image(), **kwargs, ) @@ -137,7 +141,6 @@ def __init__( region=region, ), inputs=inputs, - container_image=DefaultImages.default_image(), **kwargs, ) @@ -168,7 +171,6 @@ def __init__( region=region, ), inputs=inputs, - container_image=DefaultImages.default_image(), **kwargs, ) @@ -199,7 +201,6 @@ def __init__( region=region, ), inputs=inputs, - container_image=DefaultImages.default_image(), **kwargs, ) @@ -230,6 +231,5 @@ def __init__( region=region, ), inputs=inputs, - container_image=DefaultImages.default_image(), **kwargs, ) diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/workflow.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/workflow.py index 608b7a9473..5888889256 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/workflow.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/workflow.py @@ -1,6 +1,6 @@ -from typing import Any, Optional, Type, Union +from typing import Any, Optional, Type -from flytekit import ImageSpec, Workflow, kwtypes +from flytekit import Workflow, kwtypes from .task import ( SageMakerDeleteEndpointConfigTask, @@ -17,10 +17,10 @@ def create_sagemaker_deployment( model_config: dict[str, Any], endpoint_config_config: dict[str, Any], endpoint_config: dict[str, Any], + images: dict[str, Any], model_input_types: Optional[dict[str, Type]] = None, endpoint_config_input_types: Optional[dict[str, Type]] = None, endpoint_input_types: 
Optional[dict[str, Type]] = None, - container_image: Optional[Union[str, ImageSpec]] = None, region: Optional[str] = None, ): """ @@ -31,7 +31,7 @@ def create_sagemaker_deployment( config=model_config, region=region, inputs=model_input_types, - container_image=container_image, + images=images, ) endpoint_config_task = SageMakerEndpointConfigTask( diff --git a/plugins/flytekit-awssagemaker-inference/setup.py b/plugins/flytekit-awssagemaker-inference/setup.py index e91f13afdb..047756bc05 100644 --- a/plugins/flytekit-awssagemaker-inference/setup.py +++ b/plugins/flytekit-awssagemaker-inference/setup.py @@ -5,7 +5,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" # s3fs 2023.9.2 requires aiobotocore~=2.5.4 -plugin_requires = ["flytekit>=1.10.0", "flyteidl", "aioboto3==11.1.1"] +plugin_requires = ["flytekit>1.10.7", "flyteidl>=1.11.0b0", "aioboto3==11.1.1"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-awssagemaker-inference/tests/test_agent.py b/plugins/flytekit-awssagemaker-inference/tests/test_agent.py index 20e3390323..5f5dbc9d3f 100644 --- a/plugins/flytekit-awssagemaker-inference/tests/test_agent.py +++ b/plugins/flytekit-awssagemaker-inference/tests/test_agent.py @@ -15,11 +15,11 @@ @pytest.mark.asyncio @mock.patch( - "flytekitplugins.awssagemaker.agent.get_agent_secret", + "flytekitplugins.awssagemaker_inference.agent.get_agent_secret", return_value="mocked_secret", ) @mock.patch( - "flytekitplugins.awssagemaker.agent.Boto3AgentMixin._call", + "flytekitplugins.awssagemaker_inference.agent.Boto3AgentMixin._call", return_value={ "EndpointName": "sagemaker-xgboost-endpoint", "EndpointArn": "arn:aws:sagemaker:us-east-2:1234567890:endpoint/sagemaker-xgboost-endpoint", diff --git a/plugins/flytekit-awssagemaker-inference/tests/test_boto3_agent.py b/plugins/flytekit-awssagemaker-inference/tests/test_boto3_agent.py index 5e2183d6a7..1049787677 100644 --- a/plugins/flytekit-awssagemaker-inference/tests/test_boto3_agent.py +++ 
b/plugins/flytekit-awssagemaker-inference/tests/test_boto3_agent.py @@ -13,11 +13,11 @@ @pytest.mark.asyncio @mock.patch( - "flytekitplugins.awssagemaker.boto3_agent.get_agent_secret", + "flytekitplugins.awssagemaker_inference.boto3_agent.get_agent_secret", return_value="mocked_secret", ) @mock.patch( - "flytekitplugins.awssagemaker.boto3_agent.Boto3AgentMixin._call", + "flytekitplugins.awssagemaker_inference.boto3_agent.Boto3AgentMixin._call", return_value={ "ResponseMetadata": { "RequestId": "66f80391-348a-4ee0-9158-508914d16db2", @@ -58,6 +58,7 @@ async def test_agent(mock_boto_call, mock_secret): }, "region": "us-east-2", "method": "create_endpoint_config", + "images": None, } task_metadata = TaskMetadata( discoverable=True, diff --git a/plugins/flytekit-awssagemaker-inference/tests/test_boto3_mixin.py b/plugins/flytekit-awssagemaker-inference/tests/test_boto3_mixin.py index 7a801c3a16..4c37891431 100644 --- a/plugins/flytekit-awssagemaker-inference/tests/test_boto3_mixin.py +++ b/plugins/flytekit-awssagemaker-inference/tests/test_boto3_mixin.py @@ -65,9 +65,9 @@ def test_inputs(): def test_container(): - original_dict = {"a": "{container.image}"} - container = {"image": "cr.flyte.org/flyteorg/flytekit:py3.11-1.10.3"} + original_dict = {"a": "{images.primary_container_image}"} + images = {"primary_container_image": "cr.flyte.org/flyteorg/flytekit:py3.11-1.10.3"} - result = update_dict_fn(original_dict=original_dict, update_dict={"container": container}) + result = update_dict_fn(original_dict=original_dict, update_dict={"images": images}) assert result == {"a": "cr.flyte.org/flyteorg/flytekit:py3.11-1.10.3"} diff --git a/plugins/flytekit-awssagemaker-inference/tests/test_task.py b/plugins/flytekit-awssagemaker-inference/tests/test_task.py index fdeb6cab93..0213c91bba 100644 --- a/plugins/flytekit-awssagemaker-inference/tests/test_task.py +++ b/plugins/flytekit-awssagemaker-inference/tests/test_task.py @@ -14,14 +14,14 @@ @pytest.mark.parametrize( - 
"name,config,service,method,inputs,container_image,no_of_inputs,no_of_outputs,region,task", + "name,config,service,method,inputs,images,no_of_inputs,no_of_outputs,region,task", [ ( "sagemaker_model", { "ModelName": "{inputs.model_name}", "PrimaryContainer": { - "Image": "{container.image}", + "Image": "{images.primary_container_image}", "ModelDataUrl": "{inputs.model_data_url}", }, "ExecutionRoleArn": "{inputs.execution_role_arn}", @@ -29,7 +29,7 @@ "sagemaker", "create_model", kwtypes(model_name=str, model_data_url=str, execution_role_arn=str), - "1234567890.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost", + {"primary_container_image": "1234567890.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost"}, 3, 1, "us-east-2", @@ -81,7 +81,7 @@ kwtypes(endpoint_name=str), None, 1, - 0, + 1, "us-east-2", SageMakerDeleteEndpointTask, ), @@ -93,7 +93,7 @@ kwtypes(endpoint_config_name=str), None, 1, - 0, + 1, "us-east-2", SageMakerDeleteEndpointConfigTask, ), @@ -105,7 +105,7 @@ kwtypes(model_name=str), None, 1, - 0, + 1, "us-east-2", SageMakerDeleteModelTask, ), @@ -132,19 +132,19 @@ def test_sagemaker_task( service, method, inputs, - container_image, + images, no_of_inputs, no_of_outputs, region, task, ): - if container_image: + if images: sagemaker_task = task( name=name, config=config, region=region, inputs=inputs, - container_image=container_image, + images=images, ) else: sagemaker_task = task( diff --git a/plugins/flytekit-awssagemaker-inference/tests/test_workflow.py b/plugins/flytekit-awssagemaker-inference/tests/test_workflow.py index 739405ee6e..6740855b25 100644 --- a/plugins/flytekit-awssagemaker-inference/tests/test_workflow.py +++ b/plugins/flytekit-awssagemaker-inference/tests/test_workflow.py @@ -13,7 +13,7 @@ def test_sagemaker_deployment_workflow(): model_config={ "ModelName": "sagemaker-xgboost", "PrimaryContainer": { - "Image": "{container.image}", + "Image": "{images.primary_container_image}", "ModelDataUrl": "{inputs.model_path}", }, 
"ExecutionRoleArn": "{inputs.execution_role_arn}", @@ -37,7 +37,7 @@ def test_sagemaker_deployment_workflow(): "EndpointName": "sagemaker-xgboost-endpoint", "EndpointConfigName": "sagemaker-xgboost-endpoint-config", }, - container_image="1234567890.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost", + images={"primary_container_image": "1234567890.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost"}, region="us-east-2", ) diff --git a/plugins/setup.py b/plugins/setup.py index 002514f400..cafbfa4912 100644 --- a/plugins/setup.py +++ b/plugins/setup.py @@ -14,7 +14,7 @@ "flytekitplugins-async-fsspec": "flytekit-async-fsspec", "flytekitplugins-athena": "flytekit-aws-athena", "flytekitplugins-awsbatch": "flytekit-aws-batch", - "flytekitplugins-awssagemaker": "flytekit-aws-sagemaker", + "flytekitplugins-awssagemaker-inference": "flytekit-awssagemaker-inference", "flytekitplugins-bigquery": "flytekit-bigquery", "flytekitplugins-dask": "flytekit-dask", "flytekitplugins-dbt": "flytekit-dbt", diff --git a/tests/flytekit/unit/core/tracker/test_tracking.py b/tests/flytekit/unit/core/tracker/test_tracking.py index e028be477b..25c9c52fc3 100644 --- a/tests/flytekit/unit/core/tracker/test_tracking.py +++ b/tests/flytekit/unit/core/tracker/test_tracking.py @@ -4,10 +4,7 @@ from flytekit import task from flytekit.configuration.feature_flags import FeatureFlags -from flytekit.core.base_task import PythonTask -from flytekit.core.python_function_task import PythonInstanceTask from flytekit.core.tracker import extract_task_module -from flytekit.exceptions.system import FlyteSystemException from tests.flytekit.unit.core.tracker import d from tests.flytekit.unit.core.tracker.b import b_local_a, local_b from tests.flytekit.unit.core.tracker.c import b_in_c, c_local_a @@ -99,45 +96,6 @@ def test_extract_task_module(test_input, expected): raise -class FakePythonInstanceTaskWithExceptionLHS(PythonInstanceTask): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - 
@property - def instantiated_in(self) -> str: - # random module - return "tests.flytekit.unit.exceptions.test_base" - - -python_instance_task_instantiated = FakePythonInstanceTaskWithExceptionLHS(name="python_instance_task", task_config={}) - - -class FakePythonTaskWithExceptionLHS(PythonTask): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._raises_exception = True - - @property - def instantiated_in(self) -> str: - # random module - return "tests.flytekit.unit.exceptions.test_base" - - -python_task_instantiated = FakePythonTaskWithExceptionLHS( - name="python_task", - task_config={}, - task_type="python-task", -) - - -def test_raise_exception_when_accessing_nonexistent_lhs(): - _, _, name, _ = extract_task_module(python_instance_task_instantiated) - assert name == "" - - with pytest.raises(FlyteSystemException): - extract_task_module(python_task_instantiated) - - local_task = task(d.inner_function) From beb758bb5785a4ef599e5469fa33b2b12a13ac30 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Mon, 4 Mar 2024 15:36:54 +0530 Subject: [PATCH 065/120] tracker changes revert Signed-off-by: Samhita Alla --- flytekit/core/tracker.py | 2 +- .../unit/core/tracker/test_tracking.py | 23 ++++--------------- 2 files changed, 5 insertions(+), 20 deletions(-) diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py index 8afcef512a..24ac0ffd06 100644 --- a/flytekit/core/tracker.py +++ b/flytekit/core/tracker.py @@ -130,7 +130,7 @@ def find_lhs(self) -> str: raise _system_exceptions.FlyteSystemException(f"Object {self} does not have an _instantiated in") logger.debug(f"Looking for LHS for {self} from {self._instantiated_in}") - m = importlib.import_module(self.instantiated_in) + m = importlib.import_module(self._instantiated_in) for k in dir(m): try: if getattr(m, k) is self: diff --git a/tests/flytekit/unit/core/tracker/test_tracking.py b/tests/flytekit/unit/core/tracker/test_tracking.py index 25c9c52fc3..b33725436d 100644 --- 
a/tests/flytekit/unit/core/tracker/test_tracking.py +++ b/tests/flytekit/unit/core/tracker/test_tracking.py @@ -50,31 +50,16 @@ def convert_to_test(d: dict) -> typing.Tuple[typing.List[str], typing.List]: "core.task": (task, ("flytekit.core.task.task", "flytekit.core.task", "task")), "current-mod-tasks": ( d.tasks, - ( - "tests.flytekit.unit.core.tracker.d.tasks", - "tests.flytekit.unit.core.tracker.d", - "tasks", - ), - ), - "tasks-core-task": ( - d.task, - ("flytekit.core.task.task", "flytekit.core.task", "task"), + ("tests.flytekit.unit.core.tracker.d.tasks", "tests.flytekit.unit.core.tracker.d", "tasks"), ), + "tasks-core-task": (d.task, ("flytekit.core.task.task", "flytekit.core.task", "task")), "tracked-local": ( local_b, - ( - "tests.flytekit.unit.core.tracker.b.local_b", - "tests.flytekit.unit.core.tracker.b", - "local_b", - ), + ("tests.flytekit.unit.core.tracker.b.local_b", "tests.flytekit.unit.core.tracker.b", "local_b"), ), "tracked-b-in-c": ( b_in_c, - ( - "tests.flytekit.unit.core.tracker.c.b_in_c", - "tests.flytekit.unit.core.tracker.c", - "b_in_c", - ), + ("tests.flytekit.unit.core.tracker.c.b_in_c", "tests.flytekit.unit.core.tracker.c", "b_in_c"), ), } ) From 783a57d9075898d02de1dc82625269d77a655133 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Mon, 4 Mar 2024 16:19:20 +0530 Subject: [PATCH 066/120] dict to Dict Signed-off-by: Samhita Alla --- plugins/README.md | 1 + plugins/flytekit-awssagemaker-inference/README.md | 4 ++-- .../flytekitplugins/awssagemaker_inference/boto3_mixin.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/plugins/README.md b/plugins/README.md index 1dc2714fa3..e83ad4b012 100644 --- a/plugins/README.md +++ b/plugins/README.md @@ -6,6 +6,7 @@ All the Flytekit plugins maintained by the core team are added here. 
It is not n | Plugin | Installation | Description | Version | Type | |------------------------------|-----------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------| +| AWS SageMaker inference | ```bash pip install flytekitplugins-awssagemaker-inference``` | Deploy SageMaker models, create and trigger inference endpoints. | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-awssagemaker-inference.svg)](https://pypi.python.org/pypi/flytekitplugins-awssagemaker-inference/) | Python | | dask | ```bash pip install flytekitplugins-dask ``` | Installs SDK to author dask jobs that can be executed natively on Kubernetes using the Flyte backend plugin | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-dask.svg)](https://pypi.python.org/pypi/flytekitplugins-dask/) | Backend | | Hive Queries | ```bash pip install flytekitplugins-hive ``` | Installs SDK to author Hive Queries that can be executed on a configured hive backend using Flyte backend plugin | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-hive.svg)](https://pypi.python.org/pypi/flytekitplugins-hive/) | Backend | | K8s distributed PyTorch Jobs | ```bash pip install flytekitplugins-kfpytorch ``` | Installs SDK to author Distributed pyTorch Jobs in python using Kubeflow PyTorch Operator | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-kfpytorch.svg)](https://pypi.python.org/pypi/flytekitplugins-kfpytorch/) | Backend | diff --git a/plugins/flytekit-awssagemaker-inference/README.md b/plugins/flytekit-awssagemaker-inference/README.md index 020c757fc8..f7f9b3fa6f 100644 --- a/plugins/flytekit-awssagemaker-inference/README.md +++ 
b/plugins/flytekit-awssagemaker-inference/README.md @@ -23,7 +23,7 @@ sagemaker_deployment_wf = create_sagemaker_deployment( model_config={ "ModelName": MODEL_NAME, "PrimaryContainer": { - "Image": "{container.image}", + "Image": "{images.primary_container_image}", "ModelDataUrl": "{inputs.model_path}", }, "ExecutionRoleArn": "{inputs.execution_role_arn}", @@ -47,7 +47,7 @@ sagemaker_deployment_wf = create_sagemaker_deployment( "EndpointName": ENDPOINT_NAME, "EndpointConfigName": ENDPOINT_CONFIG_NAME, }, - container_image=custom_image, + images={"primary_container_image": custom_image}, region=REGION, ) diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py index 7cbc9f12e8..3fe6b8e874 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py @@ -88,7 +88,7 @@ async def _call( self, method: str, config: Dict[str, Any], - images: Optional[dict[str, str]] = None, + images: Optional[Dict[str, str]] = None, inputs: Optional[LiteralMap] = None, region: Optional[str] = None, aws_access_key_id: Optional[str] = None, From fd2a0227781aacaed209b6ea74bff24925c2f73f Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Mon, 4 Mar 2024 16:30:14 +0530 Subject: [PATCH 067/120] make images optional Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker_inference/boto3_task.py | 2 -- .../flytekitplugins/awssagemaker_inference/task.py | 2 +- .../flytekitplugins/awssagemaker_inference/workflow.py | 2 +- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py index c4057f4c4c..56827889fc 100644 
--- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py @@ -41,8 +41,6 @@ def get_custom(self, settings: SerializationSettings) -> dict[str, Any]: if images is not None: [ImageBuildEngine.build(image) for image in images.values() if isinstance(image, ImageSpec)] - print(images) - return { "service": self.task_config.service, "config": self.task_config.config, diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/task.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/task.py index ede71775f9..27e0744dc3 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/task.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/task.py @@ -16,7 +16,7 @@ def __init__( name: str, config: dict[str, Any], region: Optional[str], - images: dict[str, Union[str, ImageSpec]], + images: Optional[dict[str, Union[str, ImageSpec]]] = None, inputs: Optional[dict[str, Type]] = None, **kwargs, ): diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/workflow.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/workflow.py index 5888889256..273722483a 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/workflow.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/workflow.py @@ -17,7 +17,7 @@ def create_sagemaker_deployment( model_config: dict[str, Any], endpoint_config_config: dict[str, Any], endpoint_config: dict[str, Any], - images: dict[str, Any], + images: Optional[dict[str, Any]] = None, model_input_types: Optional[dict[str, Type]] = None, endpoint_config_input_types: Optional[dict[str, Type]] = None, endpoint_input_types: Optional[dict[str, Type]] = None, From 
bb0fcaf4421a46fa8f8b748472c103977684a3f0 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Mon, 4 Mar 2024 17:27:12 +0530 Subject: [PATCH 068/120] add image_name Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker_inference/boto3_task.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py index 56827889fc..531b47758f 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py @@ -39,7 +39,13 @@ def __init__( def get_custom(self, settings: SerializationSettings) -> dict[str, Any]: images = self.task_config.images if images is not None: - [ImageBuildEngine.build(image) for image in images.values() if isinstance(image, ImageSpec)] + for key, image in images.items(): + if isinstance(image, ImageSpec): + # Build the image + ImageBuildEngine.build(image) + + # Replace the value in the dictionary with image.image_name() + images[key] = image.image_name() return { "service": self.task_config.service, From 3e0a1b01372c36dcfdb35ae3a95a3d9f07dbabbc Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 5 Mar 2024 16:05:46 +0530 Subject: [PATCH 069/120] debug Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker_inference/agent.py | 2 ++ .../awssagemaker_inference/boto3_agent.py | 1 + .../awssagemaker_inference/boto3_task.py | 10 ++++++++++ 3 files changed, 13 insertions(+) diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/agent.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/agent.py index 2db22013e4..9e97cdacc0 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/agent.py +++ 
b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/agent.py @@ -41,6 +41,8 @@ class SageMakerEndpointMetadata(ResourceMeta): class SageMakerEndpointAgent(Boto3AgentMixin, AsyncAgentBase): """This agent creates an endpoint.""" + name = "SageMaker Endpoint Agent" + def __init__(self): super().__init__( service="sagemaker", diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py index c4e05698e2..eee55fc3c4 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py @@ -38,6 +38,7 @@ def __init__(self): super().__init__(task_type_name="boto") async def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs) -> Resource: + print(custom["config"]) custom = task_template.custom service = custom["service"] raw_config = custom["config"] diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py index 531b47758f..2b663543e5 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py @@ -47,6 +47,16 @@ def get_custom(self, settings: SerializationSettings) -> dict[str, Any]: # Replace the value in the dictionary with image.image_name() images[key] = image.image_name() + print( + { + "service": self.task_config.service, + "config": self.task_config.config, + "region": self.task_config.region, + "method": self.task_config.method, + "images": images, + } + ) + return { "service": self.task_config.service, "config": self.task_config.config, From 
0741e5a637bc8fbfef52fd9f12e96e2b0ccd33ce Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 5 Mar 2024 16:14:29 +0530 Subject: [PATCH 070/120] debug Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker_inference/boto3_agent.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py index eee55fc3c4..1a753b97d4 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py @@ -38,8 +38,9 @@ def __init__(self): super().__init__(task_type_name="boto") async def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs) -> Resource: - print(custom["config"]) custom = task_template.custom + print(custom["config"]) + service = custom["service"] raw_config = custom["config"] convert_floats_with_no_fraction_to_ints(raw_config) From 1bf71e8a0900be295496dfd9e08d1785711e10ed Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 5 Mar 2024 16:30:24 +0530 Subject: [PATCH 071/120] debug Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker_inference/boto3_agent.py | 1 - .../flytekitplugins/awssagemaker_inference/boto3_mixin.py | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py index 1a753b97d4..66ab317474 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py @@ -39,7 +39,6 @@ def __init__(self): async def do(self, task_template: 
TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs) -> Resource: custom = task_template.custom - print(custom["config"]) service = custom["service"] raw_config = custom["config"] diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py index 3fe6b8e874..ff6cc07dbf 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py @@ -117,6 +117,8 @@ async def _call( args = {} if inputs: args["inputs"] = literal_map_string_repr(inputs) + + print(args["images"]) if images: args["images"] = images From 492948efdf766ed50e96ff4f291c1138a7d7af6f Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 5 Mar 2024 17:25:18 +0530 Subject: [PATCH 072/120] debug Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker_inference/boto3_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py index ff6cc07dbf..db513efc44 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py @@ -115,10 +115,10 @@ async def _call( :param aws_session_token: An AWS session token used as part of the credentials to authenticate the user. 
""" args = {} + print(args["inputs"]) if inputs: args["inputs"] = literal_map_string_repr(inputs) - print(args["images"]) if images: args["images"] = images From 2bcb9aef92d2a7af5e6af225ad7efd4478289a76 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 5 Mar 2024 18:20:15 +0530 Subject: [PATCH 073/120] debug Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker_inference/boto3_agent.py | 2 ++ .../flytekitplugins/awssagemaker_inference/boto3_mixin.py | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py index 66ab317474..f39da5778d 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py @@ -47,6 +47,8 @@ async def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = N region = custom["region"] method = custom["method"] + print(boto3_object) + boto3_object = Boto3AgentMixin(service=service, region=region) result = await boto3_object._call( diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py index db513efc44..2e0758d437 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py @@ -115,7 +115,6 @@ async def _call( :param aws_session_token: An AWS session token used as part of the credentials to authenticate the user. 
""" args = {} - print(args["inputs"]) if inputs: args["inputs"] = literal_map_string_repr(inputs) From 34e6d427b761728ce480c54c4157f1bd2338727d Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 5 Mar 2024 18:37:48 +0530 Subject: [PATCH 074/120] debug Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker_inference/boto3_agent.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py index f39da5778d..dd3dfd4715 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py @@ -46,15 +46,16 @@ async def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = N config = raw_config region = custom["region"] method = custom["method"] - - print(boto3_object) + images = custom["images"] boto3_object = Boto3AgentMixin(service=service, region=region) + print(result) + result = await boto3_object._call( method=method, config=config, - images=custom["images"], + images=images, inputs=inputs, aws_access_key_id=get_agent_secret(secret_key="aws-access-key"), aws_secret_access_key=get_agent_secret(secret_key="aws-secret-access-key"), From 6472fd4a9ed578a2caab4a1c2af827a7aed0977d Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 5 Mar 2024 18:46:43 +0530 Subject: [PATCH 075/120] debug Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker_inference/boto3_agent.py | 8 +++++++- .../flytekitplugins/awssagemaker_inference/boto3_mixin.py | 2 ++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py index 
dd3dfd4715..acb2b767f5 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py @@ -50,7 +50,13 @@ async def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = N boto3_object = Boto3AgentMixin(service=service, region=region) - print(result) + print(method) + print(config) + print(images) + print(inputs) + print(get_agent_secret(secret_key="aws-access-key")) + print(get_agent_secret(secret_key="aws-secret-access-key")) + print(get_agent_secret(secret_key="aws-session-token")) result = await boto3_object._call( method=method, diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py index 2e0758d437..bc91e19b8b 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py @@ -115,6 +115,8 @@ async def _call( :param aws_session_token: An AWS session token used as part of the credentials to authenticate the user. 
""" args = {} + print(args["inputs"]) + if inputs: args["inputs"] = literal_map_string_repr(inputs) From d31a0f6d7a9cad3e2c89e69861d68bedf67e66c9 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 5 Mar 2024 18:47:51 +0530 Subject: [PATCH 076/120] debug Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker_inference/boto3_agent.py | 1 + 1 file changed, 1 insertion(+) diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py index acb2b767f5..b8ce7bbe53 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py @@ -57,6 +57,7 @@ async def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = N print(get_agent_secret(secret_key="aws-access-key")) print(get_agent_secret(secret_key="aws-secret-access-key")) print(get_agent_secret(secret_key="aws-session-token")) + print(result) result = await boto3_object._call( method=method, From 0292552466504b54292b84fb2f22c3983a9b360f Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 5 Mar 2024 18:58:01 +0530 Subject: [PATCH 077/120] debug Signed-off-by: Samhita Alla --- .../awssagemaker_inference/boto3_agent.py | 11 +---------- .../awssagemaker_inference/boto3_mixin.py | 4 +--- 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py index b8ce7bbe53..e18bb6ca8e 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py @@ -50,20 +50,11 @@ async def do(self, 
task_template: TaskTemplate, inputs: Optional[LiteralMap] = N boto3_object = Boto3AgentMixin(service=service, region=region) - print(method) - print(config) - print(images) - print(inputs) - print(get_agent_secret(secret_key="aws-access-key")) - print(get_agent_secret(secret_key="aws-secret-access-key")) - print(get_agent_secret(secret_key="aws-session-token")) - print(result) - result = await boto3_object._call( method=method, config=config, images=images, - inputs=inputs, + inputs="", aws_access_key_id=get_agent_secret(secret_key="aws-access-key"), aws_secret_access_key=get_agent_secret(secret_key="aws-secret-access-key"), aws_session_token=get_agent_secret(secret_key="aws-session-token"), diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py index bc91e19b8b..9f298b00e0 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py @@ -3,7 +3,6 @@ import aioboto3 from flytekit.interaction.string_literals import literal_map_string_repr -from flytekit.models.literals import LiteralMap def update_dict_fn(original_dict: Any, update_dict: Dict[str, Any]) -> Any: @@ -89,7 +88,7 @@ async def _call( method: str, config: Dict[str, Any], images: Optional[Dict[str, str]] = None, - inputs: Optional[LiteralMap] = None, + inputs: Optional[str] = None, region: Optional[str] = None, aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, @@ -116,7 +115,6 @@ async def _call( """ args = {} print(args["inputs"]) - if inputs: args["inputs"] = literal_map_string_repr(inputs) From 176eadfe6247e02d0b8f0ef89656bf27e2108874 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 5 Mar 2024 19:08:18 +0530 Subject: [PATCH 078/120] debug Signed-off-by: Samhita Alla --- 
.../flytekitplugins/awssagemaker_inference/boto3_agent.py | 5 ++++- .../flytekitplugins/awssagemaker_inference/boto3_mixin.py | 3 ++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py index e18bb6ca8e..8c5747ebe7 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py @@ -50,11 +50,14 @@ async def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = N boto3_object = Boto3AgentMixin(service=service, region=region) + print(config) + print(result) + result = await boto3_object._call( method=method, config=config, images=images, - inputs="", + inputs=inputs, aws_access_key_id=get_agent_secret(secret_key="aws-access-key"), aws_secret_access_key=get_agent_secret(secret_key="aws-secret-access-key"), aws_session_token=get_agent_secret(secret_key="aws-session-token"), diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py index 9f298b00e0..db513efc44 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py @@ -3,6 +3,7 @@ import aioboto3 from flytekit.interaction.string_literals import literal_map_string_repr +from flytekit.models.literals import LiteralMap def update_dict_fn(original_dict: Any, update_dict: Dict[str, Any]) -> Any: @@ -88,7 +89,7 @@ async def _call( method: str, config: Dict[str, Any], images: Optional[Dict[str, str]] = None, - inputs: Optional[str] = None, + inputs: Optional[LiteralMap] = None, region: 
Optional[str] = None, aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, From 39a2e87f456a6990def069ccc307b06205c30c90 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 5 Mar 2024 19:25:17 +0530 Subject: [PATCH 079/120] debug Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker_inference/boto3_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py index 8c5747ebe7..da8e9b3592 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py @@ -50,7 +50,7 @@ async def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = N boto3_object = Boto3AgentMixin(service=service, region=region) - print(config) + print(images) print(result) result = await boto3_object._call( From 4a370263c0d830b2052f883ccf087882bd74c94a Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 5 Mar 2024 19:39:20 +0530 Subject: [PATCH 080/120] debug Signed-off-by: Samhita Alla --- .../awssagemaker_inference/boto3_agent.py | 2 +- .../awssagemaker_inference/boto3_task.py | 10 ---------- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py index da8e9b3592..b3dd131a88 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py @@ -50,7 +50,7 @@ async def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = N boto3_object = 
Boto3AgentMixin(service=service, region=region) - print(images) + print(get_agent_secret(secret_key="aws-secret-access-key")) print(result) result = await boto3_object._call( diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py index 2b663543e5..531b47758f 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py @@ -47,16 +47,6 @@ def get_custom(self, settings: SerializationSettings) -> dict[str, Any]: # Replace the value in the dictionary with image.image_name() images[key] = image.image_name() - print( - { - "service": self.task_config.service, - "config": self.task_config.config, - "region": self.task_config.region, - "method": self.task_config.method, - "images": images, - } - ) - return { "service": self.task_config.service, "config": self.task_config.config, From dd5af68ad49ff6c21c760bbe247439046c7e12c2 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 5 Mar 2024 20:05:49 +0530 Subject: [PATCH 081/120] debug Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker_inference/boto3_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py index b3dd131a88..743d6fb0a3 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py @@ -50,7 +50,7 @@ async def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = N boto3_object = Boto3AgentMixin(service=service, region=region) - 
print(get_agent_secret(secret_key="aws-secret-access-key")) + print(get_agent_secret(secret_key="aws-access-key")) print(result) result = await boto3_object._call( From feb3f893c5b1eeab185268fdb557b573842ae0eb Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 5 Mar 2024 21:35:55 +0530 Subject: [PATCH 082/120] dict to Dict Signed-off-by: Samhita Alla --- .../awssagemaker_inference/boto3_agent.py | 3 -- .../awssagemaker_inference/boto3_mixin.py | 2 +- .../awssagemaker_inference/boto3_task.py | 10 +++--- .../awssagemaker_inference/task.py | 36 +++++++++---------- .../awssagemaker_inference/workflow.py | 16 ++++----- 5 files changed, 32 insertions(+), 35 deletions(-) diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py index 743d6fb0a3..37484862b8 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py @@ -50,9 +50,6 @@ async def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = N boto3_object = Boto3AgentMixin(service=service, region=region) - print(get_agent_secret(secret_key="aws-access-key")) - print(result) - result = await boto3_object._call( method=method, config=config, diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py index db513efc44..8d8fd36f11 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py @@ -115,7 +115,7 @@ async def _call( :param aws_session_token: An AWS session token used as part of the credentials to authenticate the user. 
""" args = {} - print(args["inputs"]) + if inputs: args["inputs"] = literal_map_string_repr(inputs) diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py index 531b47758f..4bfbdffa84 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Optional, Type, Union +from typing import Any, Dict, Optional, Type, Union from flytekit import ImageSpec, kwtypes from flytekit.configuration import SerializationSettings @@ -13,9 +13,9 @@ class BotoConfig(object): service: str method: str - config: dict[str, Any] + config: Dict[str, Any] region: str - images: Optional[dict[str, Union[str, ImageSpec]]] = None + images: Optional[Dict[str, Union[str, ImageSpec]]] = None class BotoTask(SyncAgentExecutorMixin, PythonTask[BotoConfig]): @@ -25,7 +25,7 @@ def __init__( self, name: str, task_config: BotoConfig, - inputs: Optional[dict[str, Type]] = None, + inputs: Optional[Dict[str, Type]] = None, **kwargs, ): super().__init__( @@ -36,7 +36,7 @@ def __init__( **kwargs, ) - def get_custom(self, settings: SerializationSettings) -> dict[str, Any]: + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: images = self.task_config.images if images is not None: for key, image in images.items(): diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/task.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/task.py index 27e0744dc3..9acda6319c 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/task.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/task.py @@ -1,5 +1,5 
@@ from dataclasses import dataclass -from typing import Any, Optional, Type, Union +from typing import Any, Dict, Optional, Type, Union from flytekit import ImageSpec, kwtypes from flytekit.configuration import SerializationSettings @@ -14,10 +14,10 @@ class SageMakerModelTask(BotoTask): def __init__( self, name: str, - config: dict[str, Any], + config: Dict[str, Any], region: Optional[str], - images: Optional[dict[str, Union[str, ImageSpec]]] = None, - inputs: Optional[dict[str, Type]] = None, + images: Optional[Dict[str, Union[str, ImageSpec]]] = None, + inputs: Optional[Dict[str, Type]] = None, **kwargs, ): """ @@ -49,9 +49,9 @@ class SageMakerEndpointConfigTask(BotoTask): def __init__( self, name: str, - config: dict[str, Any], + config: Dict[str, Any], region: Optional[str], - inputs: Optional[dict[str, Type]] = None, + inputs: Optional[Dict[str, Type]] = None, **kwargs, ): """ @@ -77,7 +77,7 @@ def __init__( @dataclass class SageMakerEndpointMetadata(object): - config: dict[str, Any] + config: Dict[str, Any] region: str @@ -87,9 +87,9 @@ class SageMakerEndpointTask(AsyncAgentExecutorMixin, PythonTask[SageMakerEndpoin def __init__( self, name: str, - config: dict[str, Any], + config: Dict[str, Any], region: Optional[str], - inputs: Optional[dict[str, Type]] = None, + inputs: Optional[Dict[str, Type]] = None, **kwargs, ): """ @@ -111,7 +111,7 @@ def __init__( **kwargs, ) - def get_custom(self, settings: SerializationSettings) -> dict[str, Any]: + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: return {"config": self.task_config.config, "region": self.task_config.region} @@ -119,9 +119,9 @@ class SageMakerDeleteEndpointTask(BotoTask): def __init__( self, name: str, - config: dict[str, Any], + config: Dict[str, Any], region: Optional[str], - inputs: Optional[dict[str, Type]] = None, + inputs: Optional[Dict[str, Type]] = None, **kwargs, ): """ @@ -149,9 +149,9 @@ class SageMakerDeleteEndpointConfigTask(BotoTask): def __init__( self, name: 
str, - config: dict[str, Any], + config: Dict[str, Any], region: Optional[str], - inputs: Optional[dict[str, Type]] = None, + inputs: Optional[Dict[str, Type]] = None, **kwargs, ): """ @@ -179,9 +179,9 @@ class SageMakerDeleteModelTask(BotoTask): def __init__( self, name: str, - config: dict[str, Any], + config: Dict[str, Any], region: Optional[str], - inputs: Optional[dict[str, Type]] = None, + inputs: Optional[Dict[str, Type]] = None, **kwargs, ): """ @@ -209,9 +209,9 @@ class SageMakerInvokeEndpointTask(BotoTask): def __init__( self, name: str, - config: dict[str, Any], + config: Dict[str, Any], region: Optional[str], - inputs: Optional[dict[str, Type]] = None, + inputs: Optional[Dict[str, Type]] = None, **kwargs, ): """ diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/workflow.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/workflow.py index 273722483a..511cc39444 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/workflow.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/workflow.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Type +from typing import Any, Dict, Optional, Type from flytekit import Workflow, kwtypes @@ -14,13 +14,13 @@ def create_sagemaker_deployment( name: str, - model_config: dict[str, Any], - endpoint_config_config: dict[str, Any], - endpoint_config: dict[str, Any], - images: Optional[dict[str, Any]] = None, - model_input_types: Optional[dict[str, Type]] = None, - endpoint_config_input_types: Optional[dict[str, Type]] = None, - endpoint_input_types: Optional[dict[str, Type]] = None, + model_config: Dict[str, Any], + endpoint_config_config: Dict[str, Any], + endpoint_config: Dict[str, Any], + images: Optional[Dict[str, Any]] = None, + model_input_types: Optional[Dict[str, Type]] = None, + endpoint_config_input_types: Optional[Dict[str, Type]] = None, + endpoint_input_types: 
Optional[Dict[str, Type]] = None, region: Optional[str] = None, ): """ From bead8a9558a6bc7d93880d19f1916b532a8a1587 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 5 Mar 2024 21:37:18 +0530 Subject: [PATCH 083/120] state to phase Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker_inference/agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/agent.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/agent.py index 9e97cdacc0..18535b9ba6 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/agent.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/agent.py @@ -80,7 +80,7 @@ async def get(self, resource_meta: SageMakerEndpointMetadata, **kwargs) -> Resou ) current_state = endpoint_status.get("EndpointStatus") - flyte_state = convert_to_flyte_phase(states[current_state]) + flyte_phase = convert_to_flyte_phase(states[current_state]) message = None if current_state == "Failed": @@ -100,7 +100,7 @@ async def get(self, resource_meta: SageMakerEndpointMetadata, **kwargs) -> Resou } ) - return Resource(phase=flyte_state, outputs=res, message=message) + return Resource(phase=flyte_phase, outputs=res, message=message) async def delete(self, resource_meta: SageMakerEndpointMetadata, **kwargs): await self._call( From 1b6a329a428ce63869b60270bd54d1b3a6cb8c67 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 5 Mar 2024 21:58:44 +0530 Subject: [PATCH 084/120] add encode mode to secrets.get Signed-off-by: Samhita Alla --- flytekit/extend/backend/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/extend/backend/utils.py b/flytekit/extend/backend/utils.py index b20c9fdf66..986a7ca2f9 100644 --- a/flytekit/extend/backend/utils.py +++ b/flytekit/extend/backend/utils.py @@ -38,7 +38,7 @@ def is_terminal_phase(phase: 
TaskExecution.Phase) -> bool: def get_agent_secret(secret_key: str) -> str: - return flytekit.current_context().secrets.get(secret_key) + return flytekit.current_context().secrets.get(secret_key, encode_mode="rb") def render_task_template(tt: TaskTemplate, file_prefix: str) -> TaskTemplate: From 49e68e207757a246277d22ff987b4cb8495ede20 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 5 Mar 2024 22:20:48 +0530 Subject: [PATCH 085/120] revert: add encode mode to secrets.get Signed-off-by: Samhita Alla --- flytekit/extend/backend/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/extend/backend/utils.py b/flytekit/extend/backend/utils.py index 986a7ca2f9..b20c9fdf66 100644 --- a/flytekit/extend/backend/utils.py +++ b/flytekit/extend/backend/utils.py @@ -38,7 +38,7 @@ def is_terminal_phase(phase: TaskExecution.Phase) -> bool: def get_agent_secret(secret_key: str) -> str: - return flytekit.current_context().secrets.get(secret_key, encode_mode="rb") + return flytekit.current_context().secrets.get(secret_key) def render_task_template(tt: TaskTemplate, file_prefix: str) -> TaskTemplate: From 8c81fe0d4f89ca892268a820543549fbadd34096 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 5 Mar 2024 23:33:39 +0530 Subject: [PATCH 086/120] add decode Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker_inference/boto3_agent.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py index 37484862b8..b3a4b7478a 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py @@ -55,9 +55,9 @@ async def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = N config=config, 
images=images, inputs=inputs, - aws_access_key_id=get_agent_secret(secret_key="aws-access-key"), - aws_secret_access_key=get_agent_secret(secret_key="aws-secret-access-key"), - aws_session_token=get_agent_secret(secret_key="aws-session-token"), + aws_access_key_id=get_agent_secret(secret_key="aws-access-key").decode("utf-8"), + aws_secret_access_key=get_agent_secret(secret_key="aws-secret-access-key").decode("utf-8"), + aws_session_token=get_agent_secret(secret_key="aws-session-token").decode("utf-8"), ) outputs = None From 284436347d21af3a1eee86434c18ccd0b8072835 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 6 Mar 2024 10:11:03 +0530 Subject: [PATCH 087/120] revert decode Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker_inference/boto3_agent.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py index b3a4b7478a..37484862b8 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py +++ b/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py @@ -55,9 +55,9 @@ async def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = N config=config, images=images, inputs=inputs, - aws_access_key_id=get_agent_secret(secret_key="aws-access-key").decode("utf-8"), - aws_secret_access_key=get_agent_secret(secret_key="aws-secret-access-key").decode("utf-8"), - aws_session_token=get_agent_secret(secret_key="aws-session-token").decode("utf-8"), + aws_access_key_id=get_agent_secret(secret_key="aws-access-key"), + aws_secret_access_key=get_agent_secret(secret_key="aws-secret-access-key"), + aws_session_token=get_agent_secret(secret_key="aws-session-token"), ) outputs = None From 2218ed1bea852bf94a0c2abb27b0e3c536f01d6f Mon Sep 17 00:00:00 2001 From: 
Samhita Alla Date: Fri, 15 Mar 2024 15:02:17 +0530 Subject: [PATCH 088/120] ergonomic improvements; change plugin name Signed-off-by: Samhita Alla --- .github/workflows/pythonbuild.yml | 2 +- Dockerfile.agent | 2 +- docs/source/plugins/awssagemaker.rst | 2 +- plugins/README.md | 71 +++++++++++-------- .../README.md | 13 +++- .../dev-requirements.txt | 0 .../awssagemaker_inference/__init__.py | 1 + .../awssagemaker_inference/agent.py | 6 +- .../awssagemaker_inference/boto3_agent.py | 10 +-- .../awssagemaker_inference/boto3_mixin.py | 0 .../awssagemaker_inference/boto3_task.py | 0 .../awssagemaker_inference/task.py | 32 +++++++++ .../awssagemaker_inference/workflow.py | 0 .../setup.py | 8 +-- .../tests/__init__.py | 0 .../tests/test_boto3_agent.py | 0 .../tests/test_boto3_mixin.py | 0 .../tests/test_boto3_task.py | 0 .../tests/test_inference_agent.py} | 0 .../tests/test_inference_task.py} | 0 .../tests/test_inference_workflow.py} | 0 .../pydantic/deserialization.py | 2 +- plugins/setup.py | 2 +- 23 files changed, 100 insertions(+), 51 deletions(-) rename plugins/{flytekit-awssagemaker-inference => flytekit-aws-sagemaker}/README.md (85%) rename plugins/{flytekit-awssagemaker-inference => flytekit-aws-sagemaker}/dev-requirements.txt (100%) rename plugins/{flytekit-awssagemaker-inference => flytekit-aws-sagemaker}/flytekitplugins/awssagemaker_inference/__init__.py (97%) rename plugins/{flytekit-awssagemaker-inference => flytekit-aws-sagemaker}/flytekitplugins/awssagemaker_inference/agent.py (95%) rename plugins/{flytekit-awssagemaker-inference => flytekit-aws-sagemaker}/flytekitplugins/awssagemaker_inference/boto3_agent.py (92%) rename plugins/{flytekit-awssagemaker-inference => flytekit-aws-sagemaker}/flytekitplugins/awssagemaker_inference/boto3_mixin.py (100%) rename plugins/{flytekit-awssagemaker-inference => flytekit-aws-sagemaker}/flytekitplugins/awssagemaker_inference/boto3_task.py (100%) rename plugins/{flytekit-awssagemaker-inference => 
flytekit-aws-sagemaker}/flytekitplugins/awssagemaker_inference/task.py (86%) rename plugins/{flytekit-awssagemaker-inference => flytekit-aws-sagemaker}/flytekitplugins/awssagemaker_inference/workflow.py (100%) rename plugins/{flytekit-awssagemaker-inference => flytekit-aws-sagemaker}/setup.py (82%) rename plugins/{flytekit-awssagemaker-inference => flytekit-aws-sagemaker}/tests/__init__.py (100%) rename plugins/{flytekit-awssagemaker-inference => flytekit-aws-sagemaker}/tests/test_boto3_agent.py (100%) rename plugins/{flytekit-awssagemaker-inference => flytekit-aws-sagemaker}/tests/test_boto3_mixin.py (100%) rename plugins/{flytekit-awssagemaker-inference => flytekit-aws-sagemaker}/tests/test_boto3_task.py (100%) rename plugins/{flytekit-awssagemaker-inference/tests/test_agent.py => flytekit-aws-sagemaker/tests/test_inference_agent.py} (100%) rename plugins/{flytekit-awssagemaker-inference/tests/test_task.py => flytekit-aws-sagemaker/tests/test_inference_task.py} (100%) rename plugins/{flytekit-awssagemaker-inference/tests/test_workflow.py => flytekit-aws-sagemaker/tests/test_inference_workflow.py} (100%) diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index c71ed21c86..efa0770e05 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -215,7 +215,7 @@ jobs: - flytekit-async-fsspec - flytekit-aws-athena - flytekit-aws-batch - - flytekit-awssagemaker-inference + - flytekit-aws-sagemaker - flytekit-bigquery - flytekit-dask - flytekit-data-fsspec diff --git a/Dockerfile.agent b/Dockerfile.agent index 4c3afee13e..69390c631a 100644 --- a/Dockerfile.agent +++ b/Dockerfile.agent @@ -13,7 +13,7 @@ RUN pip install --no-cache-dir -U flytekit==$VERSION \ flytekitplugins-bigquery==$VERSION \ flytekitplugins-chatgpt==$VERSION \ flytekitplugins-snowflake==$VERSION \ - flytekitplugins-awssagemaker-inference==$VERSION \ + flytekitplugins-aws-sagemaker==$VERSION \ && apt-get clean autoclean \ && apt-get autoremove 
--yes \ && rm -rf /var/lib/{apt,dpkg,cache,log}/ \ diff --git a/docs/source/plugins/awssagemaker.rst b/docs/source/plugins/awssagemaker.rst index c0862c7dc4..bc29334aa4 100644 --- a/docs/source/plugins/awssagemaker.rst +++ b/docs/source/plugins/awssagemaker.rst @@ -1,4 +1,4 @@ -.. _awssagemaker_inference: +.. _aws_sagemaker: ########################### AWS Sagemaker API reference diff --git a/plugins/README.md b/plugins/README.md index e83ad4b012..b5d23957bc 100644 --- a/plugins/README.md +++ b/plugins/README.md @@ -4,31 +4,33 @@ All the Flytekit plugins maintained by the core team are added here. It is not n ## Currently Available Plugins 🔌 -| Plugin | Installation | Description | Version | Type | -|------------------------------|-----------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------| -| AWS SageMaker inference | ```bash pip install flytekitplugins-awssagemaker-inference``` | Deploy SageMaker models, create and trigger inference endpoints. 
| [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-awssagemaker-inference.svg)](https://pypi.python.org/pypi/flytekitplugins-awssagemaker-inference/) | Python | -| dask | ```bash pip install flytekitplugins-dask ``` | Installs SDK to author dask jobs that can be executed natively on Kubernetes using the Flyte backend plugin | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-dask.svg)](https://pypi.python.org/pypi/flytekitplugins-dask/) | Backend | -| Hive Queries | ```bash pip install flytekitplugins-hive ``` | Installs SDK to author Hive Queries that can be executed on a configured hive backend using Flyte backend plugin | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-hive.svg)](https://pypi.python.org/pypi/flytekitplugins-hive/) | Backend | -| K8s distributed PyTorch Jobs | ```bash pip install flytekitplugins-kfpytorch ``` | Installs SDK to author Distributed pyTorch Jobs in python using Kubeflow PyTorch Operator | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-kfpytorch.svg)](https://pypi.python.org/pypi/flytekitplugins-kfpytorch/) | Backend | -| K8s native tensorflow Jobs | ```bash pip install flytekitplugins-kftensorflow ``` | Installs SDK to author Distributed tensorflow Jobs in python using Kubeflow Tensorflow Operator | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-kftensorflow.svg)](https://pypi.python.org/pypi/flytekitplugins-kftensorflow/) | Backend | -| K8s native MPI Jobs | ```bash pip install flytekitplugins-kfmpi ``` | Installs SDK to author Distributed MPI Jobs in python using Kubeflow MPI Operator | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-kfmpi.svg)](https://pypi.python.org/pypi/flytekitplugins-kfmpi/) | Backend | -| Papermill based Tasks | ```bash pip install flytekitplugins-papermill ``` | Execute entire notebooks as Flyte Tasks and pass inputs and outputs between them and python tasks | [![PyPI version 
fury.io](https://badge.fury.io/py/flytekitplugins-papermill.svg)](https://pypi.python.org/pypi/flytekitplugins-papermill/) | Flytekit-only | -| Pod Tasks | ```bash pip install flytekitplugins-pod ``` | Installs SDK to author Pods in python. These pods can have multiple containers, use volumes and have non exiting side-cars | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-pod.svg)](https://pypi.python.org/pypi/flytekitplugins-pod/) | Flytekit-only | -| spark | ```bash pip install flytekitplugins-spark ``` | Installs SDK to author Spark jobs that can be executed natively on Kubernetes with a supported backend Flyte plugin | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-spark.svg)](https://pypi.python.org/pypi/flytekitplugins-spark/) | Backend | -| AWS Athena Queries | ```bash pip install flytekitplugins-athena ``` | Installs SDK to author queries executed on AWS Athena | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-athena.svg)](https://pypi.python.org/pypi/flytekitplugins-athena/) | Backend | -| DOLT | ```bash pip install flytekitplugins-dolt ``` | Read & write dolt data sets and use dolt tables as native types | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-dolt.svg)](https://pypi.python.org/pypi/flytekitplugins-dolt/) | Flytekit-only | -| Pandera | ```bash pip install flytekitplugins-pandera ``` | Use Pandera schemas as native Flyte types, which enable data quality checks. 
| [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-pandera.svg)](https://pypi.python.org/pypi/flytekitplugins-pandera/) | Flytekit-only | -| SQLAlchemy | ```bash pip install flytekitplugins-sqlalchemy ``` | Write queries for any database that supports SQLAlchemy | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-sqlalchemy.svg)](https://pypi.python.org/pypi/flytekitplugins-sqlalchemy/) | Flytekit-only | -| Great Expectations | ```bash pip install flytekitplugins-great-expectations``` | Enforce data quality for various data types within Flyte | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-great-expectations.svg)](https://pypi.python.org/pypi/flytekitplugins-great-expectations/) | Flytekit-only | -| Snowflake | ```bash pip install flytekitplugins-snowflake``` | Use Snowflake as a 'data warehouse-as-a-service' within Flyte | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-snowflake.svg)](https://pypi.python.org/pypi/flytekitplugins-snowflake/) | Backend | -| dbt | ```bash pip install flytekitplugins-dbt``` | Run dbt within Flyte | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-dbt.svg)](https://pypi.python.org/pypi/flytekitplugins-dbt/) | Flytekit-only | -| Huggingface | ```bash pip install flytekitplugins-huggingface``` | Read & write Hugginface Datasets as Flyte StructuredDatasets | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-huggingface.svg)](https://pypi.python.org/pypi/flytekitplugins-huggingface/) | Flytekit-only | -| DuckDB | ```bash pip install flytekitplugins-duckdb``` | Run analytical workloads with ease using DuckDB | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-duckdb.svg)](https://pypi.python.org/pypi/flytekitplugins-duckdb/) | Flytekit-only | +| Plugin | Installation | Description | Version | Type | +| ---------------------------- | ----------------------------------------------------- | 
-------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------- | +| AWS SageMaker | `bash pip install flytekitplugins-aws-sagemaker` | Deploy SageMaker models and manage inference endpoints with ease. | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-aws-sagemaker.svg)](https://pypi.python.org/pypi/flytekitplugins-aws-sagemaker/) | Python | +| dask | `bash pip install flytekitplugins-dask ` | Installs SDK to author dask jobs that can be executed natively on Kubernetes using the Flyte backend plugin | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-dask.svg)](https://pypi.python.org/pypi/flytekitplugins-dask/) | Backend | +| Hive Queries | `bash pip install flytekitplugins-hive ` | Installs SDK to author Hive Queries that can be executed on a configured hive backend using Flyte backend plugin | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-hive.svg)](https://pypi.python.org/pypi/flytekitplugins-hive/) | Backend | +| K8s distributed PyTorch Jobs | `bash pip install flytekitplugins-kfpytorch ` | Installs SDK to author Distributed pyTorch Jobs in python using Kubeflow PyTorch Operator | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-kfpytorch.svg)](https://pypi.python.org/pypi/flytekitplugins-kfpytorch/) | Backend | +| K8s native tensorflow Jobs | `bash pip install flytekitplugins-kftensorflow ` | Installs SDK to author Distributed tensorflow Jobs in python using Kubeflow Tensorflow Operator | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-kftensorflow.svg)](https://pypi.python.org/pypi/flytekitplugins-kftensorflow/) | Backend | +| K8s native MPI Jobs | `bash pip install flytekitplugins-kfmpi ` | Installs SDK to author Distributed MPI Jobs in 
python using Kubeflow MPI Operator | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-kfmpi.svg)](https://pypi.python.org/pypi/flytekitplugins-kfmpi/) | Backend | +| Papermill based Tasks | `bash pip install flytekitplugins-papermill ` | Execute entire notebooks as Flyte Tasks and pass inputs and outputs between them and python tasks | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-papermill.svg)](https://pypi.python.org/pypi/flytekitplugins-papermill/) | Flytekit-only | +| Pod Tasks | `bash pip install flytekitplugins-pod ` | Installs SDK to author Pods in python. These pods can have multiple containers, use volumes and have non exiting side-cars | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-pod.svg)](https://pypi.python.org/pypi/flytekitplugins-pod/) | Flytekit-only | +| spark | `bash pip install flytekitplugins-spark ` | Installs SDK to author Spark jobs that can be executed natively on Kubernetes with a supported backend Flyte plugin | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-spark.svg)](https://pypi.python.org/pypi/flytekitplugins-spark/) | Backend | +| AWS Athena Queries | `bash pip install flytekitplugins-athena ` | Installs SDK to author queries executed on AWS Athena | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-athena.svg)](https://pypi.python.org/pypi/flytekitplugins-athena/) | Backend | +| DOLT | `bash pip install flytekitplugins-dolt ` | Read & write dolt data sets and use dolt tables as native types | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-dolt.svg)](https://pypi.python.org/pypi/flytekitplugins-dolt/) | Flytekit-only | +| Pandera | `bash pip install flytekitplugins-pandera ` | Use Pandera schemas as native Flyte types, which enable data quality checks. 
| [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-pandera.svg)](https://pypi.python.org/pypi/flytekitplugins-pandera/) | Flytekit-only | +| SQLAlchemy | `bash pip install flytekitplugins-sqlalchemy ` | Write queries for any database that supports SQLAlchemy | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-sqlalchemy.svg)](https://pypi.python.org/pypi/flytekitplugins-sqlalchemy/) | Flytekit-only | +| Great Expectations | `bash pip install flytekitplugins-great-expectations` | Enforce data quality for various data types within Flyte | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-great-expectations.svg)](https://pypi.python.org/pypi/flytekitplugins-great-expectations/) | Flytekit-only | +| Snowflake | `bash pip install flytekitplugins-snowflake` | Use Snowflake as a 'data warehouse-as-a-service' within Flyte | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-snowflake.svg)](https://pypi.python.org/pypi/flytekitplugins-snowflake/) | Backend | +| dbt | `bash pip install flytekitplugins-dbt` | Run dbt within Flyte | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-dbt.svg)](https://pypi.python.org/pypi/flytekitplugins-dbt/) | Flytekit-only | +| Huggingface | `bash pip install flytekitplugins-huggingface` | Read & write Hugginface Datasets as Flyte StructuredDatasets | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-huggingface.svg)](https://pypi.python.org/pypi/flytekitplugins-huggingface/) | Flytekit-only | +| DuckDB | `bash pip install flytekitplugins-duckdb` | Run analytical workloads with ease using DuckDB | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-duckdb.svg)](https://pypi.python.org/pypi/flytekitplugins-duckdb/) | Flytekit-only | ## Have a Plugin Idea? 💡 + Please [file an issue](https://github.com/flyteorg/flyte/issues/new?assignees=&labels=untriaged%2Cplugins&template=backend-plugin-request.md&title=%5BPlugin%5D). 
## Development 💻 + Flytekit plugins are structured as micro-libs and can be authored in an independent repository. > Refer to the [Python microlibs](https://medium.com/@jherreras/python-microlibs-5be9461ad979) blog to understand the idea of microlibs. @@ -36,15 +38,18 @@ Flytekit plugins are structured as micro-libs and can be authored in an independ The plugins maintained by the core team can be found in this repository and provide a simple way of discovery. ## Unit tests 🧪 + Plugins should have their own unit tests. ## Guidelines 📜 + Some guidelines to help you write the Flytekit plugins better. 1. The folder name has to be `flytekit-*`, e.g., `flytekit-hive`. In case you want to group for a specific service, then use `flytekit-aws-athena`. 2. Flytekit plugins use a concept called [Namespace packages](https://packaging.python.org/guides/creating-and-discovering-plugins/#using-namespace-packages), and thus, the package structure is essential. Please use the following Python package structure: + ``` flytekit-myplugin/ - README.md @@ -55,7 +60,8 @@ Some guidelines to help you write the Flytekit plugins better. - tests - __init__.py ``` - *NOTE:* the inner package `flytekitplugins` DOES NOT have an `__init__.py` file. + + _NOTE:_ the inner package `flytekitplugins` DOES NOT have an `__init__.py` file. 3. The published packages have to be named `flytekitplugins-{package-name}`, where `{package-name}` is a unique identifier for the plugin. @@ -113,23 +119,25 @@ setup( # entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, ) ``` -5. Each plugin should have a README.md, which describes how to install it with a simple example. For example, refer to flytekit-greatexpectations' [README](./flytekit-greatexpectations/README.md). -6. Each plugin should have its own tests' package. *NOTE:* `tests` folder should have an `__init__.py` file. +5. Each plugin should have a README.md, which describes how to install it with a simple example. 
For example, refer to flytekit-greatexpectations' [README](./flytekit-greatexpectations/README.md). + +6. Each plugin should have its own tests' package. _NOTE:_ `tests` folder should have an `__init__.py` file. -7. There may be some cases where you might want to auto-load some of your modules when the plugin is installed. This is especially true for `data-plugins` and `type-plugins`. -In such a case, you can add a special directive in the `setup.py` which will instruct Flytekit to automatically load the prescribed modules. +7. There may be some cases where you might want to auto-load some of your modules when the plugin is installed. This is especially true for `data-plugins` and `type-plugins`. + In such a case, you can add a special directive in the `setup.py` which will instruct Flytekit to automatically load the prescribed modules. - Following shows an excerpt from the `flytekit-data-fsspec` plugin's setup.py file. + Following shows an excerpt from the `flytekit-data-fsspec` plugin's setup.py file. - ```python - setup( - entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, - ) + ```python + setup( + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, + ) - ``` + ``` ### Flytekit Version Pinning + Currently we advocate pinning to minor releases of flytekit. To bump the pins across the board, `cd plugins/` and then update the command below with the appropriate range and run @@ -140,6 +148,7 @@ for f in $(ls **/setup.py); do sed -i "s/flytekit>.*,<1.1/flytekit>=1.1.0b0,<1.2 Try using `gsed` instead of `sed` if you are on a Mac. Also this only works of course for setup files that start with the version in your sed command. There may be plugins that have different pins to start out with. 
## References 📚 + - Example of a simple Python task that allows adding only Python side functionality: [flytekit-greatexpectations](./flytekit-greatexpectations/) - Example of a TypeTransformer or a Type Plugin: [flytekit-pandera](./flytekit-pandera/). These plugins add new types to Flyte and tell Flyte how to transform them and add additional features through types. Flyte is a multi-lang system, and type transformers allow marshaling between Flytekit and backend and other languages. - Example of TaskTemplate plugin which also allows plugin writers to supply a prebuilt container for runtime: [flytekit-sqlalchemy](./flytekit-sqlalchemy/) diff --git a/plugins/flytekit-awssagemaker-inference/README.md b/plugins/flytekit-aws-sagemaker/README.md similarity index 85% rename from plugins/flytekit-awssagemaker-inference/README.md rename to plugins/flytekit-aws-sagemaker/README.md index f7f9b3fa6f..4ac862e6e6 100644 --- a/plugins/flytekit-awssagemaker-inference/README.md +++ b/plugins/flytekit-aws-sagemaker/README.md @@ -1,17 +1,24 @@ -# AWS SageMaker Inference Plugin +# AWS SageMaker Plugin -The plugin features a deployment agent enabling you to deploy SageMaker models, create and trigger inference endpoints. +The plugin currently features a SageMaker deployment agent. + +## Inference + +The deployment agent enables you to deploy models, create and trigger inference endpoints. Additionally, you can entirely remove the SageMaker deployment using the `delete_sagemaker_deployment` workflow. 
To install the plugin, run the following command: ```bash -pip install flytekitplugins-awssagemaker-inference +pip install flytekitplugins-aws-sagemaker ``` Here is a sample SageMaker deployment workflow: ```python +from flytekitplugins.awssagemaker_inference import create_sagemaker_deployment + + REGION = os.getenv("REGION") MODEL_NAME = "xgboost" ENDPOINT_CONFIG_NAME = "xgboost-endpoint-config" diff --git a/plugins/flytekit-awssagemaker-inference/dev-requirements.txt b/plugins/flytekit-aws-sagemaker/dev-requirements.txt similarity index 100% rename from plugins/flytekit-awssagemaker-inference/dev-requirements.txt rename to plugins/flytekit-aws-sagemaker/dev-requirements.txt diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/__init__.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/__init__.py similarity index 97% rename from plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/__init__.py rename to plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/__init__.py index 169abc2e76..cbb05d5178 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/__init__.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/__init__.py @@ -30,5 +30,6 @@ SageMakerEndpointTask, SageMakerInvokeEndpointTask, SageMakerModelTask, + triton_image_uri, ) from .workflow import create_sagemaker_deployment, delete_sagemaker_deployment diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py similarity index 95% rename from plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/agent.py rename to plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py index 18535b9ba6..798dc656d7 100644 --- 
a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py @@ -54,8 +54,8 @@ async def create( self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs ) -> SageMakerEndpointMetadata: custom = task_template.custom - config = custom["config"] - region = custom["region"] + config = custom.get("config") + region = custom.get("region") await self._call( method="create_endpoint", @@ -67,7 +67,7 @@ async def create( aws_session_token=get_agent_secret(secret_key="aws-session-token"), ) - return SageMakerEndpointMetadata(endpoint_name=config["EndpointName"], region=region) + return SageMakerEndpointMetadata(endpoint_name=config.get("EndpointName"), region=region) async def get(self, resource_meta: SageMakerEndpointMetadata, **kwargs) -> Resource: endpoint_status = await self._call( diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py similarity index 92% rename from plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py rename to plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py index 37484862b8..5fe7914286 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py @@ -40,13 +40,13 @@ def __init__(self): async def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, **kwargs) -> Resource: custom = task_template.custom - service = custom["service"] - raw_config = custom["config"] + service = custom.get("service") + raw_config = custom.get("config") convert_floats_with_no_fraction_to_ints(raw_config) config = raw_config - region = custom["region"] - method = 
custom["method"] - images = custom["images"] + region = custom.get("region") + method = custom.get("method") + images = custom.get("images") boto3_object = Boto3AgentMixin(service=service, region=region) diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py similarity index 100% rename from plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_mixin.py rename to plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_task.py similarity index 100% rename from plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/boto3_task.py rename to plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_task.py diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py similarity index 86% rename from plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/task.py rename to plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py index 9acda6319c..faefb73005 100644 --- a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py @@ -9,6 +9,33 @@ from .boto3_task import BotoConfig, BotoTask +account_id_map = { + "us-east-1": "785573368785", + "us-east-2": "007439368137", + "us-west-1": "710691900526", + "us-west-2": "301217895009", + "eu-west-1": "802834080501", + "eu-west-2": "205493899709", + "eu-west-3": "254080097072", + "eu-north-1": "601324751636", + "eu-south-1": "966458181534", + 
"eu-central-1": "746233611703", + "ap-east-1": "110948597952", + "ap-south-1": "763008648453", + "ap-northeast-1": "941853720454", + "ap-northeast-2": "151534178276", + "ap-southeast-1": "324986816169", + "ap-southeast-2": "355873309152", + "cn-northwest-1": "474822919863", + "cn-north-1": "472730292857", + "sa-east-1": "756306329178", + "ca-central-1": "464438896020", + "me-south-1": "836785723513", + "af-south-1": "774647643957", +} + +triton_image_uri = "{account_id}.dkr.ecr.{region}.{base}/sagemaker-tritonserver:21.08-py3" + class SageMakerModelTask(BotoTask): def __init__( @@ -31,6 +58,11 @@ def __init__( or in a Docker registry that is accessible from the same VPC that you configure for your endpoint. """ + for image_name, image in images.items(): + if "{region}" in image: + base = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com" + images[image_name] = image.format(account_id=account_id_map[region], region=region, base=base) + super(SageMakerModelTask, self).__init__( name=name, task_config=BotoConfig( diff --git a/plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/workflow.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py similarity index 100% rename from plugins/flytekit-awssagemaker-inference/flytekitplugins/awssagemaker_inference/workflow.py rename to plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py diff --git a/plugins/flytekit-awssagemaker-inference/setup.py b/plugins/flytekit-aws-sagemaker/setup.py similarity index 82% rename from plugins/flytekit-awssagemaker-inference/setup.py rename to plugins/flytekit-aws-sagemaker/setup.py index 047756bc05..059c05e661 100644 --- a/plugins/flytekit-awssagemaker-inference/setup.py +++ b/plugins/flytekit-aws-sagemaker/setup.py @@ -1,6 +1,6 @@ from setuptools import setup -PLUGIN_NAME = "awssagemaker_inference" +PLUGIN_NAME = "awssagemaker" microlib_name = f"flytekitplugins-{PLUGIN_NAME}" @@ -14,9 +14,9 
@@ version=__version__, author="flyteorg", author_email="admin@flyte.org", - description="Flytekit AWS SageMaker Inference Plugin", + description="Flytekit AWS SageMaker Plugin", namespace_packages=["flytekitplugins"], - packages=[f"flytekitplugins.{PLUGIN_NAME}"], + packages=["flytekitplugins.awssagemaker_inference"], install_requires=plugin_requires, license="apache2", python_requires=">=3.8", @@ -34,5 +34,5 @@ "Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries :: Python Modules", ], - entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, + entry_points={"flytekit.plugins": ["awssagemaker_inference=flytekitplugins.awssagemaker_inference"]}, ) diff --git a/plugins/flytekit-awssagemaker-inference/tests/__init__.py b/plugins/flytekit-aws-sagemaker/tests/__init__.py similarity index 100% rename from plugins/flytekit-awssagemaker-inference/tests/__init__.py rename to plugins/flytekit-aws-sagemaker/tests/__init__.py diff --git a/plugins/flytekit-awssagemaker-inference/tests/test_boto3_agent.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py similarity index 100% rename from plugins/flytekit-awssagemaker-inference/tests/test_boto3_agent.py rename to plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py diff --git a/plugins/flytekit-awssagemaker-inference/tests/test_boto3_mixin.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py similarity index 100% rename from plugins/flytekit-awssagemaker-inference/tests/test_boto3_mixin.py rename to plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py diff --git a/plugins/flytekit-awssagemaker-inference/tests/test_boto3_task.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_task.py similarity index 100% rename from plugins/flytekit-awssagemaker-inference/tests/test_boto3_task.py rename to plugins/flytekit-aws-sagemaker/tests/test_boto3_task.py diff --git a/plugins/flytekit-awssagemaker-inference/tests/test_agent.py 
b/plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py similarity index 100% rename from plugins/flytekit-awssagemaker-inference/tests/test_agent.py rename to plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py diff --git a/plugins/flytekit-awssagemaker-inference/tests/test_task.py b/plugins/flytekit-aws-sagemaker/tests/test_inference_task.py similarity index 100% rename from plugins/flytekit-awssagemaker-inference/tests/test_task.py rename to plugins/flytekit-aws-sagemaker/tests/test_inference_task.py diff --git a/plugins/flytekit-awssagemaker-inference/tests/test_workflow.py b/plugins/flytekit-aws-sagemaker/tests/test_inference_workflow.py similarity index 100% rename from plugins/flytekit-awssagemaker-inference/tests/test_workflow.py rename to plugins/flytekit-aws-sagemaker/tests/test_inference_workflow.py diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py index a7d602b402..7f0d6c5832 100644 --- a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py @@ -11,7 +11,7 @@ pydantic = lazy_module("pydantic") # this field is used by pydantic to get the validator method -PYDANTIC_VALIDATOR_METHOD_NAME = pydantic.BaseModel.__get_validators__.__name__ +PYDANTIC_VALIDATOR_METHOD_NAME = pydantic.BaseModel.__get_pydantic_core_schema__.__name__ PythonType = TypeVar("PythonType") # target type of the deserialization diff --git a/plugins/setup.py b/plugins/setup.py index cafbfa4912..f1d5b5bf2e 100644 --- a/plugins/setup.py +++ b/plugins/setup.py @@ -14,7 +14,7 @@ "flytekitplugins-async-fsspec": "flytekit-async-fsspec", "flytekitplugins-athena": "flytekit-aws-athena", "flytekitplugins-awsbatch": "flytekit-aws-batch", - "flytekitplugins-awssagemaker-inference": "flytekit-awssagemaker-inference", + "flytekitplugins-aws-sagemaker": "flytekit-aws-sagemaker", 
"flytekitplugins-bigquery": "flytekit-bigquery", "flytekitplugins-dask": "flytekit-dask", "flytekitplugins-dbt": "flytekit-dbt", From c8b8e380e8f23b9d3d25b094c3b11092e8f94be2 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Fri, 15 Mar 2024 15:13:40 +0530 Subject: [PATCH 089/120] change plugin name Signed-off-by: Samhita Alla --- Dockerfile.agent | 2 +- plugins/README.md | 2 +- plugins/flytekit-aws-sagemaker/README.md | 2 +- plugins/setup.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Dockerfile.agent b/Dockerfile.agent index 69390c631a..886e4af613 100644 --- a/Dockerfile.agent +++ b/Dockerfile.agent @@ -13,7 +13,7 @@ RUN pip install --no-cache-dir -U flytekit==$VERSION \ flytekitplugins-bigquery==$VERSION \ flytekitplugins-chatgpt==$VERSION \ flytekitplugins-snowflake==$VERSION \ - flytekitplugins-aws-sagemaker==$VERSION \ + flytekitplugins-awssagemaker==$VERSION \ && apt-get clean autoclean \ && apt-get autoremove --yes \ && rm -rf /var/lib/{apt,dpkg,cache,log}/ \ diff --git a/plugins/README.md b/plugins/README.md index b5d23957bc..f12b17cb3f 100644 --- a/plugins/README.md +++ b/plugins/README.md @@ -6,7 +6,7 @@ All the Flytekit plugins maintained by the core team are added here. It is not n | Plugin | Installation | Description | Version | Type | | ---------------------------- | ----------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------- | -| AWS SageMaker | `bash pip install flytekitplugins-aws-sagemaker` | Deploy SageMaker models and manage inference endpoints with ease. 
| [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-aws-sagemaker.svg)](https://pypi.python.org/pypi/flytekitplugins-aws-sagemaker/) | Python | +| AWS SageMaker | `bash pip install flytekitplugins-awssagemaker` | Deploy SageMaker models and manage inference endpoints with ease. | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-awssagemaker.svg)](https://pypi.python.org/pypi/flytekitplugins-awssagemaker/) | Python | | dask | `bash pip install flytekitplugins-dask ` | Installs SDK to author dask jobs that can be executed natively on Kubernetes using the Flyte backend plugin | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-dask.svg)](https://pypi.python.org/pypi/flytekitplugins-dask/) | Backend | | Hive Queries | `bash pip install flytekitplugins-hive ` | Installs SDK to author Hive Queries that can be executed on a configured hive backend using Flyte backend plugin | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-hive.svg)](https://pypi.python.org/pypi/flytekitplugins-hive/) | Backend | | K8s distributed PyTorch Jobs | `bash pip install flytekitplugins-kfpytorch ` | Installs SDK to author Distributed pyTorch Jobs in python using Kubeflow PyTorch Operator | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-kfpytorch.svg)](https://pypi.python.org/pypi/flytekitplugins-kfpytorch/) | Backend | diff --git a/plugins/flytekit-aws-sagemaker/README.md b/plugins/flytekit-aws-sagemaker/README.md index 4ac862e6e6..b8eacf0914 100644 --- a/plugins/flytekit-aws-sagemaker/README.md +++ b/plugins/flytekit-aws-sagemaker/README.md @@ -10,7 +10,7 @@ Additionally, you can entirely remove the SageMaker deployment using the `delete To install the plugin, run the following command: ```bash -pip install flytekitplugins-aws-sagemaker +pip install flytekitplugins-awssagemaker ``` Here is a sample SageMaker deployment workflow: diff --git a/plugins/setup.py b/plugins/setup.py index f1d5b5bf2e..002514f400 
100644 --- a/plugins/setup.py +++ b/plugins/setup.py @@ -14,7 +14,7 @@ "flytekitplugins-async-fsspec": "flytekit-async-fsspec", "flytekitplugins-athena": "flytekit-aws-athena", "flytekitplugins-awsbatch": "flytekit-aws-batch", - "flytekitplugins-aws-sagemaker": "flytekit-aws-sagemaker", + "flytekitplugins-awssagemaker": "flytekit-aws-sagemaker", "flytekitplugins-bigquery": "flytekit-bigquery", "flytekitplugins-dask": "flytekit-dask", "flytekitplugins-dbt": "flytekit-dbt", From 1d7821fe2b36bdb7fac163ac1fce1618754b6007 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Fri, 15 Mar 2024 15:47:49 +0530 Subject: [PATCH 090/120] nit Signed-off-by: Samhita Alla --- plugins/flytekit-aws-sagemaker/setup.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/setup.py b/plugins/flytekit-aws-sagemaker/setup.py index 059c05e661..ba1d767c32 100644 --- a/plugins/flytekit-aws-sagemaker/setup.py +++ b/plugins/flytekit-aws-sagemaker/setup.py @@ -1,6 +1,7 @@ from setuptools import setup PLUGIN_NAME = "awssagemaker" +INFERENCE_PACKAGE = "awssagemaker_inference" microlib_name = f"flytekitplugins-{PLUGIN_NAME}" @@ -16,7 +17,7 @@ author_email="admin@flyte.org", description="Flytekit AWS SageMaker Plugin", namespace_packages=["flytekitplugins"], - packages=["flytekitplugins.awssagemaker_inference"], + packages=[f"flytekitplugins.{INFERENCE_PACKAGE}"], install_requires=plugin_requires, license="apache2", python_requires=">=3.8", @@ -34,5 +35,5 @@ "Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries :: Python Modules", ], - entry_points={"flytekit.plugins": ["awssagemaker_inference=flytekitplugins.awssagemaker_inference"]}, + entry_points={"flytekit.plugins": [f"{INFERENCE_PACKAGE}=flytekitplugins.{INFERENCE_PACKAGE}"]}, ) From fb8f04a59dd8c709c35d9adc6889b7ed227c2aaa Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Fri, 15 Mar 2024 16:32:59 +0530 Subject: [PATCH 091/120] image check Signed-off-by: 
Samhita Alla --- .../flytekitplugins/awssagemaker_inference/task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py index faefb73005..92bb98026b 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py @@ -59,7 +59,7 @@ def __init__( """ for image_name, image in images.items(): - if "{region}" in image: + if isinstance(image, str) and "{region}" in image: base = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com" images[image_name] = image.format(account_id=account_id_map[region], region=region, base=base) From 71c1a4bde9b8daf0938aa875b3d1ac7aab2a22ec Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Fri, 15 Mar 2024 22:18:32 +0530 Subject: [PATCH 092/120] add api docs Signed-off-by: Samhita Alla --- docs/source/plugins/awssagemaker.rst | 8 ++++---- docs/source/plugins/index.rst | 2 ++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/source/plugins/awssagemaker.rst b/docs/source/plugins/awssagemaker.rst index bc29334aa4..dc16dc1dbb 100644 --- a/docs/source/plugins/awssagemaker.rst +++ b/docs/source/plugins/awssagemaker.rst @@ -1,8 +1,8 @@ -.. _aws_sagemaker: +.. _awssagemaker_inference: -########################### -AWS Sagemaker API reference -########################### +##################################### +AWS Sagemaker Inference API reference +##################################### .. 
tags:: Integration, MachineLearning, AWS diff --git a/docs/source/plugins/index.rst b/docs/source/plugins/index.rst index 06e4d7cd58..c2f6599e03 100644 --- a/docs/source/plugins/index.rst +++ b/docs/source/plugins/index.rst @@ -30,6 +30,7 @@ Plugin API reference * :ref:`Vaex ` - Vaex API reference * :ref:`MLflow ` - MLflow API reference * :ref:`DuckDB ` - DuckDB API reference +* :ref:`SageMaker Inference <awssagemaker_inference>` - SageMaker Inference API reference .. toctree:: :maxdepth: 2 @@ -61,3 +62,4 @@ Plugin API reference Vaex MLflow DuckDB + SageMaker Inference <awssagemaker_inference> From aea16ded2197c115612df694a16c8c8e385ef613 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Fri, 15 Mar 2024 22:21:04 +0530 Subject: [PATCH 093/120] add api docs Signed-off-by: Samhita Alla --- .../plugins/{awssagemaker.rst => awssagemaker_inference.rst} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename docs/source/plugins/{awssagemaker.rst => awssagemaker_inference.rst} (100%) diff --git a/docs/source/plugins/awssagemaker.rst b/docs/source/plugins/awssagemaker_inference.rst similarity index 100% rename from docs/source/plugins/awssagemaker.rst rename to docs/source/plugins/awssagemaker_inference.rst From 41f117ad6edf7344b271c0cd9d594f90b31c5c14 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Mon, 18 Mar 2024 12:04:21 +0530 Subject: [PATCH 094/120] incorporate Kevin's suggestions Signed-off-by: Samhita Alla --- .../awssagemaker_inference/agent.py | 14 +------------- .../awssagemaker_inference/boto3_agent.py | 14 +------------- plugins/flytekit-aws-sagemaker/setup.py | 8 +++----- .../tests/test_boto3_agent.py | 2 +- .../tests/test_inference_agent.py | 2 +- 5 files changed, 7 insertions(+), 33 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py index 798dc656d7..23a4ced379 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py +++ 
b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py @@ -3,8 +3,6 @@ from datetime import datetime from typing import Optional -from flytekit import FlyteContextManager -from flytekit.core.type_engine import TypeEngine from flytekit.extend.backend.base_agent import ( AgentRegistry, AsyncAgentBase, @@ -88,17 +86,7 @@ async def get(self, resource_meta: SageMakerEndpointMetadata, **kwargs) -> Resou res = None if current_state == "InService": - ctx = FlyteContextManager.current_context() - res = LiteralMap( - { - "result": TypeEngine.to_literal( - ctx, - json.dumps(endpoint_status, cls=DateTimeEncoder), - str, - TypeEngine.to_literal_type(str), - ) - } - ) + res = {"result": json.dumps(endpoint_status, cls=DateTimeEncoder)} return Resource(phase=flyte_phase, outputs=res, message=message) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py index 5fe7914286..adb1772248 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_agent.py @@ -2,8 +2,6 @@ from flyteidl.core.execution_pb2 import TaskExecution -from flytekit import FlyteContextManager -from flytekit.core.type_engine import TypeEngine from flytekit.extend.backend.base_agent import ( AgentRegistry, Resource, @@ -62,17 +60,7 @@ async def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = N outputs = None if result: - ctx = FlyteContextManager.current_context() - outputs = LiteralMap( - { - "result": TypeEngine.to_literal( - ctx, - result, - dict, - TypeEngine.to_literal_type(dict), - ) - } - ) + outputs = {"result": result} return Resource(phase=TaskExecution.SUCCEEDED, outputs=outputs) diff --git a/plugins/flytekit-aws-sagemaker/setup.py b/plugins/flytekit-aws-sagemaker/setup.py index ba1d767c32..cdc4b816b6 100644 --- 
a/plugins/flytekit-aws-sagemaker/setup.py +++ b/plugins/flytekit-aws-sagemaker/setup.py @@ -5,8 +5,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -# s3fs 2023.9.2 requires aiobotocore~=2.5.4 -plugin_requires = ["flytekit>1.10.7", "flyteidl>=1.11.0b0", "aioboto3==11.1.1"] +plugin_requires = ["flytekit>=1.11.0", "aioboto3>=12.3.0"] __version__ = "0.0.0+develop" @@ -20,15 +19,14 @@ packages=[f"flytekitplugins.{INFERENCE_PACKAGE}"], install_requires=plugin_requires, license="apache2", - python_requires=">=3.8", + python_requires=">=3.10", classifiers=[ "Intended Audience :: Science/Research", "Intended Audience :: Developers", "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Software Development", diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py index 1049787677..d7bc04656b 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py @@ -94,6 +94,6 @@ async def test_agent(mock_boto_call, mock_secret): assert resource.phase == TaskExecution.SUCCEEDED assert ( - resource.outputs.literals["result"].scalar.generic.fields["EndpointConfigArn"].string_value + resource.outputs["result"]["EndpointConfigArn"] == "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config" ) diff --git a/plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py b/plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py index 5f5dbc9d3f..b05f23d062 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py 
@@ -106,7 +106,7 @@ async def test_agent(mock_boto_call, mock_secret): resource = await agent.get(metadata) assert resource.phase == TaskExecution.SUCCEEDED - from_json = json.loads(resource.outputs.literals["result"].scalar.primitive.string_value) + from_json = json.loads(resource.outputs["result"]) assert from_json["EndpointName"] == "sagemaker-xgboost-endpoint" assert from_json["EndpointArn"] == "arn:aws:sagemaker:us-east-2:1234567890:endpoint/sagemaker-xgboost-endpoint" From 67fe7f6366e8b787b0b399ad4a2210543205d3f1 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 20 Mar 2024 15:23:30 +0530 Subject: [PATCH 095/120] handle scenario when the same input is present in the wf already Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker_inference/workflow.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py index 511cc39444..fd9bde80de 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py @@ -61,7 +61,9 @@ def create_sagemaker_deployment( input_dict = {} if isinstance(value, dict): for param, t in value.items(): - wf.add_workflow_input(param, t) + # Handles the scenario when the same input is present during different API calls. 
+ if param not in wf.inputs.keys(): + wf.add_workflow_input(param, t) input_dict[param] = wf.inputs[param] node = wf.add_entity(key, **input_dict) if len(nodes) > 0: From e2dd67269b320a02456d251f2408143f115bba27 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 20 Mar 2024 16:02:31 +0530 Subject: [PATCH 096/120] add support for region to be a user-provided input at run-time Signed-off-by: Samhita Alla --- .../awssagemaker_inference/__init__.py | 3 +- .../awssagemaker_inference/agent.py | 9 +--- .../awssagemaker_inference/boto3_mixin.py | 41 +++++++++++++++- .../awssagemaker_inference/boto3_task.py | 2 +- .../awssagemaker_inference/task.py | 48 ++++--------------- 5 files changed, 52 insertions(+), 51 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/__init__.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/__init__.py index cbb05d5178..e907455182 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/__init__.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/__init__.py @@ -30,6 +30,7 @@ SageMakerEndpointTask, SageMakerInvokeEndpointTask, SageMakerModelTask, - triton_image_uri, ) from .workflow import create_sagemaker_deployment, delete_sagemaker_deployment + +triton_image_uri = "{account_id}.dkr.ecr.{region}.{base}/sagemaker-tritonserver:21.08-py3" diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py index 23a4ced379..0a7d4578d5 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py @@ -1,5 +1,4 @@ import json -from dataclasses import dataclass from datetime import datetime from typing import Optional @@ -7,13 +6,13 @@ AgentRegistry, AsyncAgentBase, Resource, - ResourceMeta, ) from 
flytekit.extend.backend.utils import convert_to_flyte_phase, get_agent_secret from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate from .boto3_mixin import Boto3AgentMixin +from .task import SageMakerEndpointMetadata states = { "Creating": "Running", @@ -30,12 +29,6 @@ def default(self, o): return json.JSONEncoder.default(self, o) -@dataclass -class SageMakerEndpointMetadata(ResourceMeta): - endpoint_name: str - region: str - - class SageMakerEndpointAgent(Boto3AgentMixin, AsyncAgentBase): """This agent creates an endpoint.""" diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py index 8d8fd36f11..bdbfd3b653 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py @@ -5,6 +5,31 @@ from flytekit.interaction.string_literals import literal_map_string_repr from flytekit.models.literals import LiteralMap +account_id_map = { + "us-east-1": "785573368785", + "us-east-2": "007439368137", + "us-west-1": "710691900526", + "us-west-2": "301217895009", + "eu-west-1": "802834080501", + "eu-west-2": "205493899709", + "eu-west-3": "254080097072", + "eu-north-1": "601324751636", + "eu-south-1": "966458181534", + "eu-central-1": "746233611703", + "ap-east-1": "110948597952", + "ap-south-1": "763008648453", + "ap-northeast-1": "941853720454", + "ap-northeast-2": "151534178276", + "ap-southeast-1": "324986816169", + "ap-southeast-2": "355873309152", + "cn-northwest-1": "474822919863", + "cn-north-1": "472730292857", + "sa-east-1": "756306329178", + "ca-central-1": "464438896020", + "me-south-1": "836785723513", + "af-south-1": "774647643957", +} + def update_dict_fn(original_dict: Any, update_dict: Dict[str, Any]) -> Any: """ @@ -118,6 +143,20 @@ async def _call( if inputs: 
args["inputs"] = literal_map_string_repr(inputs) + input_region = args["inputs"].get("region") + + final_region = input_region or region or self._region + if not final_region: + raise ValueError("Region parameter is required.") + + for image_name, image in images.items(): + if isinstance(image, str) and "{region}" in image: + base = "amazonaws.com.cn" if final_region.startswith("cn-") else "amazonaws.com" + images[image_name] = image.format( + account_id=account_id_map[final_region], + region=final_region, + base=base, + ) if images: args["images"] = images @@ -128,7 +167,7 @@ async def _call( session = aioboto3.Session() async with session.client( service_name=self._service, - region_name=self._region or region, + region_name=final_region, aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, aws_session_token=aws_session_token, diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_task.py index 4bfbdffa84..2e7c8f5b7b 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_task.py @@ -14,7 +14,7 @@ class BotoConfig(object): service: str method: str config: Dict[str, Any] - region: str + region: Optional[str] = None images: Optional[Dict[str, Union[str, ImageSpec]]] = None diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py index 92bb98026b..6372efc03d 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py @@ -9,40 +9,13 @@ from .boto3_task import BotoConfig, BotoTask -account_id_map = { - "us-east-1": "785573368785", - "us-east-2": "007439368137", - "us-west-1": 
"710691900526", - "us-west-2": "301217895009", - "eu-west-1": "802834080501", - "eu-west-2": "205493899709", - "eu-west-3": "254080097072", - "eu-north-1": "601324751636", - "eu-south-1": "966458181534", - "eu-central-1": "746233611703", - "ap-east-1": "110948597952", - "ap-south-1": "763008648453", - "ap-northeast-1": "941853720454", - "ap-northeast-2": "151534178276", - "ap-southeast-1": "324986816169", - "ap-southeast-2": "355873309152", - "cn-northwest-1": "474822919863", - "cn-north-1": "472730292857", - "sa-east-1": "756306329178", - "ca-central-1": "464438896020", - "me-south-1": "836785723513", - "af-south-1": "774647643957", -} - -triton_image_uri = "{account_id}.dkr.ecr.{region}.{base}/sagemaker-tritonserver:21.08-py3" - class SageMakerModelTask(BotoTask): def __init__( self, name: str, config: Dict[str, Any], - region: Optional[str], + region: Optional[str] = None, images: Optional[Dict[str, Union[str, ImageSpec]]] = None, inputs: Optional[Dict[str, Type]] = None, **kwargs, @@ -58,11 +31,6 @@ def __init__( or in a Docker registry that is accessible from the same VPC that you configure for your endpoint. 
""" - for image_name, image in images.items(): - if isinstance(image, str) and "{region}" in image: - base = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com" - images[image_name] = image.format(account_id=account_id_map[region], region=region, base=base) - super(SageMakerModelTask, self).__init__( name=name, task_config=BotoConfig( @@ -82,7 +50,7 @@ def __init__( self, name: str, config: Dict[str, Any], - region: Optional[str], + region: Optional[str] = None, inputs: Optional[Dict[str, Type]] = None, **kwargs, ): @@ -110,7 +78,7 @@ def __init__( @dataclass class SageMakerEndpointMetadata(object): config: Dict[str, Any] - region: str + region: Optional[str] = None class SageMakerEndpointTask(AsyncAgentExecutorMixin, PythonTask[SageMakerEndpointMetadata]): @@ -120,7 +88,7 @@ def __init__( self, name: str, config: Dict[str, Any], - region: Optional[str], + region: Optional[str] = None, inputs: Optional[Dict[str, Type]] = None, **kwargs, ): @@ -152,7 +120,7 @@ def __init__( self, name: str, config: Dict[str, Any], - region: Optional[str], + region: Optional[str] = None, inputs: Optional[Dict[str, Type]] = None, **kwargs, ): @@ -182,7 +150,7 @@ def __init__( self, name: str, config: Dict[str, Any], - region: Optional[str], + region: Optional[str] = None, inputs: Optional[Dict[str, Type]] = None, **kwargs, ): @@ -212,7 +180,7 @@ def __init__( self, name: str, config: Dict[str, Any], - region: Optional[str], + region: Optional[str] = None, inputs: Optional[Dict[str, Type]] = None, **kwargs, ): @@ -242,7 +210,7 @@ def __init__( self, name: str, config: Dict[str, Any], - region: Optional[str], + region: Optional[str] = None, inputs: Optional[Dict[str, Type]] = None, **kwargs, ): From eeeab7169994d5a90823796de1693e4e8fed4c4c Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 20 Mar 2024 18:26:57 +0530 Subject: [PATCH 097/120] modify workflow code to accommodate providing regions at runtime Signed-off-by: Samhita Alla --- 
.../awssagemaker_inference/workflow.py | 87 ++++++++++++++----- 1 file changed, 63 insertions(+), 24 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py index fd9bde80de..ecb6cb4d3d 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py @@ -22,15 +22,29 @@ def create_sagemaker_deployment( endpoint_config_input_types: Optional[Dict[str, Type]] = None, endpoint_input_types: Optional[Dict[str, Type]] = None, region: Optional[str] = None, + region_at_runtime: bool = False, ): """ Creates SageMaker model, endpoint config and endpoint. + + :param model_config: Configuration for the SageMaker model creation API call. + :param endpoint_config_config: Configuration for the SageMaker endpoint configuration creation API call. + :param endpoint_config: Configuration for the SageMaker endpoint creation API call. + :param images: A dictionary of images for SageMaker model creation. + :param model_input_types: Mapping of SageMaker model configuration inputs to their types. + :param endpoint_config_input_types: Mapping of SageMaker endpoint configuration inputs to their types. + :param endpoint_input_types: Mapping of SageMaker endpoint inputs to their types. + :param region: The region for SageMaker API calls. + :param region_at_runtime: Set this to True if you want to provide the region at runtime. 
""" + if not any((region, region_at_runtime)): + raise ValueError("Region parameter is required.") + sagemaker_model_task = SageMakerModelTask( name=f"sagemaker-model-{name}", config=model_config, region=region, - inputs=model_input_types, + inputs=(model_input_types.update({"region": str}) if region_at_runtime else model_input_types), images=images, ) @@ -38,14 +52,16 @@ def create_sagemaker_deployment( name=f"sagemaker-endpoint-config-{name}", config=endpoint_config_config, region=region, - inputs=endpoint_config_input_types, + inputs=( + endpoint_config_input_types.update({"region": str}) if region_at_runtime else endpoint_config_input_types + ), ) endpoint_task = SageMakerEndpointTask( name=f"sagemaker-endpoint-{name}", config=endpoint_config, region=region, - inputs=endpoint_input_types, + inputs=(endpoint_input_types.update({"region": str}) if region_at_runtime else endpoint_input_types), ) wf = Workflow(name=f"sagemaker-deploy-{name}") @@ -56,6 +72,9 @@ def create_sagemaker_deployment( endpoint_task: endpoint_input_types, } + if region_at_runtime: + wf.add_workflow_input("region", str) + nodes = [] for key, value in inputs.items(): input_dict = {} @@ -65,6 +84,8 @@ def create_sagemaker_deployment( if param not in wf.inputs.keys(): wf.add_workflow_input(param, t) input_dict[param] = wf.inputs[param] + if region_at_runtime: + input_dict["region"] = wf.inputs["region"] node = wf.add_entity(key, **input_dict) if len(nodes) > 0: nodes[-1] >> node @@ -74,49 +95,67 @@ def create_sagemaker_deployment( return wf -def delete_sagemaker_deployment(name: str, region: Optional[str] = None): +def delete_sagemaker_deployment(name: str, region: Optional[str] = None, region_at_runtime: bool = False): """ Deletes SageMaker model, endpoint config and endpoint. + + :param name: The prefix to be added to the task names. + :param region: The region to use for SageMaker API calls. + :param region_at_runtime: Set this to True if you want to provide the region at runtime. 
""" + if not any((region, region_at_runtime)): + raise ValueError("Region parameter is required.") + sagemaker_delete_endpoint = SageMakerDeleteEndpointTask( name=f"sagemaker-delete-endpoint-{name}", config={"EndpointName": "{inputs.endpoint_name}"}, region=region, - inputs=kwtypes(endpoint_name=str), + inputs=(kwtypes(endpoint_name=str, region=str) if region_at_runtime else kwtypes(endpoint_name=str)), ) sagemaker_delete_endpoint_config = SageMakerDeleteEndpointConfigTask( name=f"sagemaker-delete-endpoint-config-{name}", config={"EndpointConfigName": "{inputs.endpoint_config_name}"}, region=region, - inputs=kwtypes(endpoint_config_name=str), + inputs=( + kwtypes(endpoint_config_name=str, region=str) if region_at_runtime else kwtypes(endpoint_config_name=str) + ), ) sagemaker_delete_model = SageMakerDeleteModelTask( name=f"sagemaker-delete-model-{name}", config={"ModelName": "{inputs.model_name}"}, region=region, - inputs=kwtypes(model_name=str), + inputs=(kwtypes(model_name=str, region=str) if region_at_runtime else kwtypes(model_name=str)), ) wf = Workflow(name=f"sagemaker-delete-endpoint-wf-{name}") - wf.add_workflow_input("endpoint_name", str) - wf.add_workflow_input("endpoint_config_name", str) - wf.add_workflow_input("model_name", str) - node_t1 = wf.add_entity( - sagemaker_delete_endpoint, - endpoint_name=wf.inputs["endpoint_name"], - ) - node_t2 = wf.add_entity( - sagemaker_delete_endpoint_config, - endpoint_config_name=wf.inputs["endpoint_config_name"], - ) - node_t3 = wf.add_entity( - sagemaker_delete_model, - model_name=wf.inputs["model_name"], - ) - node_t1 >> node_t2 - node_t2 >> node_t3 + if region_at_runtime: + wf.add_workflow_input("region", str) + + inputs = { + sagemaker_delete_endpoint: "endpoint_name", + sagemaker_delete_endpoint_config: "endpoint_config_name", + sagemaker_delete_model: "model_name", + } + + nodes = [] + for key, value in inputs.items(): + wf.add_workflow_input(value, str) + node = wf.add_entity( + key, + **( + { + value: 
wf.inputs[value], + "region": wf.inputs["region"], + } + if region_at_runtime + else {value: wf.inputs[value]} + ), + ) + if len(nodes) > 0: + nodes[-1] >> node + nodes.append(node) return wf From 6b954177bff75e1551fc7ba4735841a0d9a680f9 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 20 Mar 2024 21:45:43 +0530 Subject: [PATCH 098/120] code optimization and add region support to workflows Signed-off-by: Samhita Alla --- .../awssagemaker_inference/workflow.py | 133 ++++++++++-------- 1 file changed, 73 insertions(+), 60 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py index ecb6cb4d3d..1a893e60f3 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py @@ -12,6 +12,17 @@ ) +def create_deployment_task( + name: str, + task_type: Any, + config: Dict[str, Any], + region: str, + inputs: Optional[Dict[str, Type]], + images: Optional[Dict[str, Any]], +) -> Any: + return task_type(name=name, config=config, region=region, inputs=inputs, images=images) + + def create_sagemaker_deployment( name: str, model_config: Dict[str, Any], @@ -40,43 +51,45 @@ def create_sagemaker_deployment( if not any((region, region_at_runtime)): raise ValueError("Region parameter is required.") - sagemaker_model_task = SageMakerModelTask( - name=f"sagemaker-model-{name}", - config=model_config, - region=region, - inputs=(model_input_types.update({"region": str}) if region_at_runtime else model_input_types), - images=images, - ) - - endpoint_config_task = SageMakerEndpointConfigTask( - name=f"sagemaker-endpoint-config-{name}", - config=endpoint_config_config, - region=region, - inputs=( - endpoint_config_input_types.update({"region": str}) if region_at_runtime else endpoint_config_input_types - ), - ) - - endpoint_task = 
SageMakerEndpointTask( - name=f"sagemaker-endpoint-{name}", - config=endpoint_config, - region=region, - inputs=(endpoint_input_types.update({"region": str}) if region_at_runtime else endpoint_input_types), - ) - wf = Workflow(name=f"sagemaker-deploy-{name}") - inputs = { - sagemaker_model_task: model_input_types, - endpoint_config_task: endpoint_config_input_types, - endpoint_task: endpoint_input_types, - } - if region_at_runtime: + model_input_types.update({"region": str}) + endpoint_config_input_types.update({"region": str}) + endpoint_input_types.update({"region": str}) wf.add_workflow_input("region", str) + inputs = { + SageMakerModelTask: { + "input_types": model_input_types, + "name": "sagemaker-model", + "images": True, + "config": model_config, + }, + SageMakerEndpointConfigTask: { + "input_types": endpoint_config_input_types, + "name": "sagemaker-endpoint-config", + "images": False, + "config": endpoint_config_config, + }, + SageMakerEndpointTask: { + "input_types": endpoint_input_types, + "name": "sagemaker-endpoint", + "images": False, + "config": endpoint_config, + }, + } + nodes = [] for key, value in inputs.items(): + obj = create_sagemaker_deployment( + name=f"{value['name']}-{name}", + task_type=key, + config=value["config"], + region=region, + inputs=value["input_types"], + images=images if value["images"] else None, + ) input_dict = {} if isinstance(value, dict): for param, t in value.items(): @@ -84,9 +97,7 @@ def create_sagemaker_deployment( if param not in wf.inputs.keys(): wf.add_workflow_input(param, t) input_dict[param] = wf.inputs[param] - if region_at_runtime: - input_dict["region"] = wf.inputs["region"] - node = wf.add_entity(key, **input_dict) + node = wf.add_entity(obj, **input_dict) if len(nodes) > 0: nodes[-1] >> node nodes.append(node) @@ -95,6 +106,22 @@ def create_sagemaker_deployment( return wf +def create_delete_task( + name: str, + task_type: Any, + config: Dict[str, Any], + region: str, + value: str, + region_at_runtime: bool, 
+) -> Any: + return task_type( + name=name, + config=config, + region=region, + inputs=(kwtypes(**{value: str, "region": str}) if region_at_runtime else kwtypes(value=str)), + ) + + def delete_sagemaker_deployment(name: str, region: Optional[str] = None, region_at_runtime: bool = False): """ Deletes SageMaker model, endpoint config and endpoint. @@ -106,45 +133,31 @@ def delete_sagemaker_deployment(name: str, region: Optional[str] = None, region_ if not any((region, region_at_runtime)): raise ValueError("Region parameter is required.") - sagemaker_delete_endpoint = SageMakerDeleteEndpointTask( - name=f"sagemaker-delete-endpoint-{name}", - config={"EndpointName": "{inputs.endpoint_name}"}, - region=region, - inputs=(kwtypes(endpoint_name=str, region=str) if region_at_runtime else kwtypes(endpoint_name=str)), - ) - - sagemaker_delete_endpoint_config = SageMakerDeleteEndpointConfigTask( - name=f"sagemaker-delete-endpoint-config-{name}", - config={"EndpointConfigName": "{inputs.endpoint_config_name}"}, - region=region, - inputs=( - kwtypes(endpoint_config_name=str, region=str) if region_at_runtime else kwtypes(endpoint_config_name=str) - ), - ) - - sagemaker_delete_model = SageMakerDeleteModelTask( - name=f"sagemaker-delete-model-{name}", - config={"ModelName": "{inputs.model_name}"}, - region=region, - inputs=(kwtypes(model_name=str, region=str) if region_at_runtime else kwtypes(model_name=str)), - ) - wf = Workflow(name=f"sagemaker-delete-endpoint-wf-{name}") if region_at_runtime: wf.add_workflow_input("region", str) inputs = { - sagemaker_delete_endpoint: "endpoint_name", - sagemaker_delete_endpoint_config: "endpoint_config_name", - sagemaker_delete_model: "model_name", + SageMakerDeleteEndpointTask: "endpoint_name", + SageMakerDeleteEndpointConfigTask: "endpoint_config_name", + SageMakerDeleteModelTask: "model_name", } nodes = [] for key, value in inputs.items(): + obj = create_delete_task( + name=f"sagemaker-delete-{value.replace('_name').replace('_', 
'-')}-{name}", + task_type=key, + config={value.title().replace("_", ""): f"{{inputs.{value}}}"}, + region=region, + value=value, + region_at_runtime=region_at_runtime, + ) + wf.add_workflow_input(value, str) node = wf.add_entity( - key, + obj, **( { value: wf.inputs[value], From 9ef751373b7e1d8304673ae02fbb4ede2a7fd790 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 20 Mar 2024 21:52:12 +0530 Subject: [PATCH 099/120] nit Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker_inference/workflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py index 1a893e60f3..19992da219 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py @@ -82,7 +82,7 @@ def create_sagemaker_deployment( nodes = [] for key, value in inputs.items(): - obj = create_sagemaker_deployment( + obj = create_deployment_task( name=f"{value['name']}-{name}", task_type=key, config=value["config"], From 8f1a2e39a5831a4e8955500aab34b20d5186a5ea Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 20 Mar 2024 22:00:25 +0530 Subject: [PATCH 100/120] fixed an input_types bug Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker_inference/workflow.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py index 19992da219..ec56eecb9e 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py @@ -82,17 +82,18 @@ def create_sagemaker_deployment( nodes = [] for key, value in 
inputs.items(): + input_types = value["input_types"] obj = create_deployment_task( name=f"{value['name']}-{name}", task_type=key, config=value["config"], region=region, - inputs=value["input_types"], + inputs=input_types, images=images if value["images"] else None, ) input_dict = {} - if isinstance(value, dict): - for param, t in value.items(): + if isinstance(input_types, dict): + for param, t in input_types.items(): # Handles the scenario when the same input is present during different API calls. if param not in wf.inputs.keys(): wf.add_workflow_input(param, t) From 346c8944686e17fbc2858125a9f50f0a70d072fb Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 21 Mar 2024 10:17:02 +0530 Subject: [PATCH 101/120] fix images bug Signed-off-by: Samhita Alla --- .../awssagemaker_inference/boto3_mixin.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py index bdbfd3b653..c3bc1d5c8a 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py @@ -149,14 +149,20 @@ async def _call( if not final_region: raise ValueError("Region parameter is required.") - for image_name, image in images.items(): - if isinstance(image, str) and "{region}" in image: - base = "amazonaws.com.cn" if final_region.startswith("cn-") else "amazonaws.com" - images[image_name] = image.format( - account_id=account_id_map[final_region], - region=final_region, - base=base, + if images: + base = "amazonaws.com.cn" if final_region.startswith("cn-") else "amazonaws.com" + images = { + image_name: ( + image.format( + account_id=account_id_map[final_region], + region=final_region, + base=base, + ) + if isinstance(image, str) and "{region}" in image + else image ) + for 
image_name, image in images.items() + } if images: args["images"] = images From ef053caeae8e623c72ce03bb8c51770558940ea0 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 21 Mar 2024 11:38:19 +0530 Subject: [PATCH 102/120] replace endpoint_name with config Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker_inference/agent.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py index 0a7d4578d5..389559d97e 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py @@ -58,12 +58,12 @@ async def create( aws_session_token=get_agent_secret(secret_key="aws-session-token"), ) - return SageMakerEndpointMetadata(endpoint_name=config.get("EndpointName"), region=region) + return SageMakerEndpointMetadata(config=config, region=region) async def get(self, resource_meta: SageMakerEndpointMetadata, **kwargs) -> Resource: endpoint_status = await self._call( method="describe_endpoint", - config={"EndpointName": resource_meta.endpoint_name}, + config={"EndpointName": resource_meta.config.get("EndpointName")}, region=resource_meta.region, aws_access_key_id=get_agent_secret(secret_key="aws-access-key"), aws_secret_access_key=get_agent_secret(secret_key="aws-secret-access-key"), @@ -86,7 +86,7 @@ async def get(self, resource_meta: SageMakerEndpointMetadata, **kwargs) -> Resou async def delete(self, resource_meta: SageMakerEndpointMetadata, **kwargs): await self._call( "delete_endpoint", - config={"EndpointName": resource_meta.endpoint_name}, + config={"EndpointName": resource_meta.config.get("EndpointName")}, region=resource_meta.region, aws_access_key_id=get_agent_secret(secret_key="aws-access-key"), aws_secret_access_key=get_agent_secret(secret_key="aws-secret-access-key"), 
From 04fca760fd545617dec55480f945dfce073b2ae7 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 21 Mar 2024 11:41:23 +0530 Subject: [PATCH 103/120] input_region default Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker_inference/boto3_mixin.py | 1 + 1 file changed, 1 insertion(+) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py index c3bc1d5c8a..8b9e2fd60a 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py @@ -140,6 +140,7 @@ async def _call( :param aws_session_token: An AWS session token used as part of the credentials to authenticate the user. """ args = {} + input_region = None if inputs: args["inputs"] = literal_map_string_repr(inputs) From 03f4507372f0402a049f766b29ef8d8a887468b6 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 21 Mar 2024 11:56:28 +0530 Subject: [PATCH 104/120] add inputs Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker_inference/agent.py | 4 +++- .../flytekitplugins/awssagemaker_inference/task.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py index 389559d97e..ecd283fed8 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/agent.py @@ -58,12 +58,13 @@ async def create( aws_session_token=get_agent_secret(secret_key="aws-session-token"), ) - return SageMakerEndpointMetadata(config=config, region=region) + return SageMakerEndpointMetadata(config=config, region=region, inputs=inputs) async def get(self, resource_meta: SageMakerEndpointMetadata, 
**kwargs) -> Resource: endpoint_status = await self._call( method="describe_endpoint", config={"EndpointName": resource_meta.config.get("EndpointName")}, + inputs=resource_meta.inputs, region=resource_meta.region, aws_access_key_id=get_agent_secret(secret_key="aws-access-key"), aws_secret_access_key=get_agent_secret(secret_key="aws-secret-access-key"), @@ -88,6 +89,7 @@ async def delete(self, resource_meta: SageMakerEndpointMetadata, **kwargs): "delete_endpoint", config={"EndpointName": resource_meta.config.get("EndpointName")}, region=resource_meta.region, + inputs=resource_meta.inputs, aws_access_key_id=get_agent_secret(secret_key="aws-access-key"), aws_secret_access_key=get_agent_secret(secret_key="aws-secret-access-key"), aws_session_token=get_agent_secret(secret_key="aws-session-token"), diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py index 6372efc03d..ce920f4b00 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py @@ -6,6 +6,7 @@ from flytekit.core.base_task import PythonTask from flytekit.core.interface import Interface from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin +from flytekit.models.literals import LiteralMap from .boto3_task import BotoConfig, BotoTask @@ -79,6 +80,7 @@ def __init__( class SageMakerEndpointMetadata(object): config: Dict[str, Any] region: Optional[str] = None + inputs: Optional[LiteralMap] = None class SageMakerEndpointTask(AsyncAgentExecutorMixin, PythonTask[SageMakerEndpointMetadata]): From 9df6cf55e04a45993cf0dbb434cc58cd35bcb115 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 21 Mar 2024 12:06:03 +0530 Subject: [PATCH 105/120] replace bug Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker_inference/workflow.py | 2 +- 1 file changed, 1 insertion(+), 
1 deletion(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py index ec56eecb9e..221f0f9b1d 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py @@ -148,7 +148,7 @@ def delete_sagemaker_deployment(name: str, region: Optional[str] = None, region_ nodes = [] for key, value in inputs.items(): obj = create_delete_task( - name=f"sagemaker-delete-{value.replace('_name').replace('_', '-')}-{name}", + name=f"sagemaker-delete-{value.replace('_name', '').replace('_', '-')}-{name}", task_type=key, config={value.title().replace("_", ""): f"{{inputs.{value}}}"}, region=region, From 79a1d685b15ef4b29dd2997c19e9b1d4aa9a3595 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 21 Mar 2024 12:21:27 +0530 Subject: [PATCH 106/120] add return type Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker_inference/workflow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py index 221f0f9b1d..5c20c81699 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py @@ -34,7 +34,7 @@ def create_sagemaker_deployment( endpoint_input_types: Optional[Dict[str, Type]] = None, region: Optional[str] = None, region_at_runtime: bool = False, -): +) -> Workflow: """ Creates SageMaker model, endpoint config and endpoint. 
@@ -123,7 +123,7 @@ def create_delete_task( ) -def delete_sagemaker_deployment(name: str, region: Optional[str] = None, region_at_runtime: bool = False): +def delete_sagemaker_deployment(name: str, region: Optional[str] = None, region_at_runtime: bool = False) -> Workflow: """ Deletes SageMaker model, endpoint config and endpoint. From 5864d96dbcace6a2c3d4fc1e5c486a329f73a145 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 21 Mar 2024 12:32:05 +0530 Subject: [PATCH 107/120] kwtypes fix bug Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker_inference/workflow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py index 5c20c81699..585e1b985c 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py @@ -119,7 +119,7 @@ def create_delete_task( name=name, config=config, region=region, - inputs=(kwtypes(**{value: str, "region": str}) if region_at_runtime else kwtypes(value=str)), + inputs=(kwtypes(**{value: str, "region": str}) if region_at_runtime else kwtypes(**{value: str})), ) @@ -134,7 +134,7 @@ def delete_sagemaker_deployment(name: str, region: Optional[str] = None, region_ if not any((region, region_at_runtime)): raise ValueError("Region parameter is required.") - wf = Workflow(name=f"sagemaker-delete-endpoint-wf-{name}") + wf = Workflow(name=f"sagemaker-delete-deployment-{name}") if region_at_runtime: wf.add_workflow_input("region", str) From 6ed332cb64138e232968ff883017074b078985b4 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 21 Mar 2024 13:04:42 +0530 Subject: [PATCH 108/120] add cache_ignore_input_vars to task template in tests Signed-off-by: Samhita Alla --- plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py | 1 + 
plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py | 1 + 2 files changed, 2 insertions(+) diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py index d7bc04656b..7be62e216c 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py @@ -70,6 +70,7 @@ async def test_agent(mock_boto_call, mock_secret): deprecated_error_message="This is deprecated!", cache_serializable=True, pod_template_name="A", + cache_ignore_input_vars=(), ) task_template = TaskTemplate( diff --git a/plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py b/plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py index b05f23d062..2a2499cba7 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py @@ -87,6 +87,7 @@ async def test_agent(mock_boto_call, mock_secret): deprecated_error_message="This is deprecated!", cache_serializable=True, pod_template_name="A", + cache_ignore_input_vars=(), ) task_template = TaskTemplate( From b22e8ffa725af2617b3899de47f8a0adaff939df Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 21 Mar 2024 14:31:48 +0530 Subject: [PATCH 109/120] fix test Signed-off-by: Samhita Alla --- .../flytekit-aws-sagemaker/tests/test_inference_agent.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py b/plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py index 2a2499cba7..7b1d342479 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py @@ -99,7 +99,13 @@ async def test_agent(mock_boto_call, mock_secret): ) # CREATE - metadata = SageMakerEndpointMetadata(endpoint_name="sagemaker-endpoint", region="us-east-2") + metadata = SageMakerEndpointMetadata( + config={ + 
"EndpointName": "sagemaker-endpoint", + "EndpointConfigName": "sagemaker-endpoint-config", + }, + region="us-east-2", + ) response = await agent.create(task_template) assert response == metadata From 85a9c083bb2de9ee0df56ee0d39841516d71d6fb Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 21 Mar 2024 14:34:02 +0530 Subject: [PATCH 110/120] fix test Signed-off-by: Samhita Alla --- plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py b/plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py index 7b1d342479..e4003c0735 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_inference_agent.py @@ -72,7 +72,7 @@ async def test_agent(mock_boto_call, mock_secret): "service": "sagemaker", "config": { "EndpointName": "sagemaker-endpoint", - "EndpointConfigName": "endpoint-config-name", + "EndpointConfigName": "sagemaker-endpoint-config", }, "region": "us-east-2", "method": "create_endpoint", From fda1dd90e681ce38c6e0dcbe32f787a232de0ba0 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 21 Mar 2024 23:08:48 +0530 Subject: [PATCH 111/120] add test case for boto3 call method Signed-off-by: Samhita Alla --- .../awssagemaker_inference/boto3_mixin.py | 12 +++-- .../awssagemaker_inference/workflow.py | 34 +++++++++----- .../tests/test_boto3_mixin.py | 44 +++++++++++++++++++ .../tests/test_inference_task.py | 15 +++++++ .../tests/test_inference_workflow.py | 43 +++++++++++++++++- 5 files changed, 132 insertions(+), 16 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py index 8b9e2fd60a..9a64c6485d 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py +++ 
b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py @@ -61,7 +61,9 @@ def update_dict_fn(original_dict: Any, update_dict: Dict[str, Any]) -> Any: try: update_dict_copy = update_dict_copy[key] except Exception: - raise ValueError(f"Could not find the key {key} in {update_dict_copy}.") + raise ValueError( + f"Could not find the key {key} in {update_dict_copy}." + ) return update_dict_copy @@ -151,7 +153,11 @@ async def _call( raise ValueError("Region parameter is required.") if images: - base = "amazonaws.com.cn" if final_region.startswith("cn-") else "amazonaws.com" + base = ( + "amazonaws.com.cn" + if final_region.startswith("cn-") + else "amazonaws.com" + ) images = { image_name: ( image.format( @@ -170,7 +176,7 @@ async def _call( updated_config = update_dict_fn(config, args) - # Asynchronouse Boto3 session + # Asynchronous Boto3 session session = aioboto3.Session() async with session.client( service_name=self._service, diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py index 585e1b985c..55300ac64a 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Type +from typing import Any, Dict, Optional, Type, Tuple from flytekit import Workflow, kwtypes @@ -19,8 +19,16 @@ def create_deployment_task( region: str, inputs: Optional[Dict[str, Type]], images: Optional[Dict[str, Any]], -) -> Any: - return task_type(name=name, config=config, region=region, inputs=inputs, images=images) + region_at_runtime: bool, +) -> Tuple[Any, Optional[Dict[str, Type]]]: + if region_at_runtime: + if inputs: + inputs.update({"region": str}) + else: + inputs = kwtypes(region=str) + return task_type( + name=name, config=config, region=region, 
inputs=inputs, images=images + ), inputs def create_sagemaker_deployment( @@ -54,9 +62,6 @@ def create_sagemaker_deployment( wf = Workflow(name=f"sagemaker-deploy-{name}") if region_at_runtime: - model_input_types.update({"region": str}) - endpoint_config_input_types.update({"region": str}) - endpoint_input_types.update({"region": str}) wf.add_workflow_input("region", str) inputs = { @@ -83,17 +88,18 @@ def create_sagemaker_deployment( nodes = [] for key, value in inputs.items(): input_types = value["input_types"] - obj = create_deployment_task( + obj, new_input_types = create_deployment_task( name=f"{value['name']}-{name}", task_type=key, config=value["config"], region=region, inputs=input_types, images=images if value["images"] else None, + region_at_runtime=region_at_runtime, ) input_dict = {} - if isinstance(input_types, dict): - for param, t in input_types.items(): + if isinstance(new_input_types, dict): + for param, t in new_input_types.items(): # Handles the scenario when the same input is present during different API calls. if param not in wf.inputs.keys(): wf.add_workflow_input(param, t) @@ -119,11 +125,17 @@ def create_delete_task( name=name, config=config, region=region, - inputs=(kwtypes(**{value: str, "region": str}) if region_at_runtime else kwtypes(**{value: str})), + inputs=( + kwtypes(**{value: str, "region": str}) + if region_at_runtime + else kwtypes(**{value: str}) + ), ) -def delete_sagemaker_deployment(name: str, region: Optional[str] = None, region_at_runtime: bool = False) -> Workflow: +def delete_sagemaker_deployment( + name: str, region: Optional[str] = None, region_at_runtime: bool = False +) -> Workflow: """ Deletes SageMaker model, endpoint config and endpoint. 
diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py index 4c37891431..21c34c8f88 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py @@ -6,6 +6,10 @@ from flytekit.core.type_engine import TypeEngine from flytekit.interaction.string_literals import literal_map_string_repr from flytekit.types.file import FlyteFile +from flytekitplugins.awssagemaker_inference.boto3_mixin import Boto3AgentMixin +from flytekitplugins.awssagemaker_inference import triton_image_uri +import pytest +from unittest.mock import patch, AsyncMock def test_inputs(): @@ -71,3 +75,43 @@ def test_container(): result = update_dict_fn(original_dict=original_dict, update_dict={"images": images}) assert result == {"a": "cr.flyte.org/flyteorg/flytekit:py3.11-1.10.3"} + + +@pytest.mark.asyncio +@patch("flytekitplugins.awssagemaker_inference.boto3_mixin.aioboto3.Session") +async def test_call(mock_session): + mixin = Boto3AgentMixin(service="sagemaker") + + mock_client = AsyncMock() + mock_session.return_value.client.return_value.__aenter__.return_value = mock_client + mock_method = mock_client.create_model + + config = { + "ModelName": "{inputs.model_name}", + "PrimaryContainer": { + "Image": "{images.primary_container_image}", + "ModelDataUrl": "s3://sagemaker-agent-xgboost/model.tar.gz", + }, + } + inputs = TypeEngine.dict_to_literal_map( + FlyteContext.current_context(), + {"model_name": "xgboost", "region": "us-west-2"}, + {"model_name": str, "region": str}, + ) + + result = await mixin._call( + method="create_model", + config=config, + inputs=inputs, + images={"primary_container_image": triton_image_uri}, + ) + + mock_method.assert_called_with( + ModelName="xgboost", + PrimaryContainer={ + "Image": "301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tritonserver:21.08-py3", + "ModelDataUrl": "s3://sagemaker-agent-xgboost/model.tar.gz", + }, 
+ ) + + assert result == mock_method.return_value diff --git a/plugins/flytekit-aws-sagemaker/tests/test_inference_task.py b/plugins/flytekit-aws-sagemaker/tests/test_inference_task.py index 0213c91bba..93e61d909d 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_inference_task.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_inference_task.py @@ -124,6 +124,21 @@ "us-east-2", SageMakerInvokeEndpointTask, ), + ( + "sagemaker_invoke_endpoint_with_region_at_runtime", + { + "EndpointName": "{inputs.endpoint_name}", + "InputLocation": "s3://sagemaker-agent-xgboost/inference_input", + }, + "sagemaker-runtime", + "invoke_endpoint_async", + kwtypes(endpoint_name=str, region=str), + None, + 2, + 1, + None, + SageMakerInvokeEndpointTask, + ), ], ) def test_sagemaker_task( diff --git a/plugins/flytekit-aws-sagemaker/tests/test_inference_workflow.py b/plugins/flytekit-aws-sagemaker/tests/test_inference_workflow.py index 6740855b25..85578b1a57 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_inference_workflow.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_inference_workflow.py @@ -45,12 +45,51 @@ def test_sagemaker_deployment_workflow(): assert len(sagemaker_deployment_wf.interface.outputs) == 1 assert len(sagemaker_deployment_wf.nodes) == 3 +def test_sagemaker_deployment_workflow_with_region_at_runtime(): + sagemaker_deployment_wf = create_sagemaker_deployment( + name="sagemaker-deployment-region-runtime", + model_input_types=kwtypes(model_path=str, execution_role_arn=str), + model_config={ + "ModelName": "sagemaker-xgboost", + "PrimaryContainer": { + "Image": "{images.primary_container_image}", + "ModelDataUrl": "{inputs.model_path}", + }, + "ExecutionRoleArn": "{inputs.execution_role_arn}", + }, + endpoint_config_input_types=kwtypes(instance_type=str), + endpoint_config_config={ + "EndpointConfigName": "sagemaker-xgboost-endpoint-config", + "ProductionVariants": [ + { + "VariantName": "variant-name-1", + "ModelName": "sagemaker-xgboost", + 
"InitialInstanceCount": 1, + "InstanceType": "{inputs.instance_type}", + }, + ], + "AsyncInferenceConfig": { + "OutputConfig": {"S3OutputPath": "s3://sagemaker-agent-xgboost/inference-output/output"} + }, + }, + endpoint_config={ + "EndpointName": "sagemaker-xgboost-endpoint", + "EndpointConfigName": "sagemaker-xgboost-endpoint-config", + }, + images={"primary_container_image": "1234567890.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost"}, + region_at_runtime=True, + ) + + assert len(sagemaker_deployment_wf.interface.inputs) == 4 + assert len(sagemaker_deployment_wf.interface.outputs) == 1 + assert len(sagemaker_deployment_wf.nodes) == 3 + def test_sagemaker_deployment_deletion_workflow(): sagemaker_deployment_deletion_wf = delete_sagemaker_deployment( - name="sagemaker-deployment-deletion", region="us-east-2" + name="sagemaker-deployment-deletion", region_at_runtime=True ) - assert len(sagemaker_deployment_deletion_wf.interface.inputs) == 3 + assert len(sagemaker_deployment_deletion_wf.interface.inputs) == 4 assert len(sagemaker_deployment_deletion_wf.interface.outputs) == 0 assert len(sagemaker_deployment_deletion_wf.nodes) == 3 From 191fca369a04d04ebcaba9cedf75e414f2899a94 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 21 Mar 2024 23:10:41 +0530 Subject: [PATCH 112/120] lint Signed-off-by: Samhita Alla --- .../awssagemaker_inference/workflow.py | 19 +++++++------------ .../tests/test_inference_workflow.py | 1 + 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py index 55300ac64a..87a27c7497 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/workflow.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Type, Tuple +from typing import Any, Dict, Optional, 
Tuple, Type from flytekit import Workflow, kwtypes @@ -26,9 +26,10 @@ def create_deployment_task( inputs.update({"region": str}) else: inputs = kwtypes(region=str) - return task_type( - name=name, config=config, region=region, inputs=inputs, images=images - ), inputs + return ( + task_type(name=name, config=config, region=region, inputs=inputs, images=images), + inputs, + ) def create_sagemaker_deployment( @@ -125,17 +126,11 @@ def create_delete_task( name=name, config=config, region=region, - inputs=( - kwtypes(**{value: str, "region": str}) - if region_at_runtime - else kwtypes(**{value: str}) - ), + inputs=(kwtypes(**{value: str, "region": str}) if region_at_runtime else kwtypes(**{value: str})), ) -def delete_sagemaker_deployment( - name: str, region: Optional[str] = None, region_at_runtime: bool = False -) -> Workflow: +def delete_sagemaker_deployment(name: str, region: Optional[str] = None, region_at_runtime: bool = False) -> Workflow: """ Deletes SageMaker model, endpoint config and endpoint. 
diff --git a/plugins/flytekit-aws-sagemaker/tests/test_inference_workflow.py b/plugins/flytekit-aws-sagemaker/tests/test_inference_workflow.py index 85578b1a57..1b5f1bebbd 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_inference_workflow.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_inference_workflow.py @@ -45,6 +45,7 @@ def test_sagemaker_deployment_workflow(): assert len(sagemaker_deployment_wf.interface.outputs) == 1 assert len(sagemaker_deployment_wf.nodes) == 3 + def test_sagemaker_deployment_workflow_with_region_at_runtime(): sagemaker_deployment_wf = create_sagemaker_deployment( name="sagemaker-deployment-region-runtime", From 0975917433351915bfa84f810ec2484a150eadd4 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 21 Mar 2024 23:13:03 +0530 Subject: [PATCH 113/120] lint Signed-off-by: Samhita Alla --- .../flytekit-aws-sagemaker/tests/test_boto3_mixin.py | 12 +++++++----- .../tests/test_inference_workflow.py | 5 +---- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py index 21c34c8f88..94dd7f9a6f 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py @@ -1,15 +1,17 @@ import typing +from unittest.mock import AsyncMock, patch -from flytekitplugins.awssagemaker_inference.boto3_mixin import update_dict_fn +import pytest +from flytekitplugins.awssagemaker_inference import triton_image_uri +from flytekitplugins.awssagemaker_inference.boto3_mixin import ( + Boto3AgentMixin, + update_dict_fn, +) from flytekit import FlyteContext, StructuredDataset from flytekit.core.type_engine import TypeEngine from flytekit.interaction.string_literals import literal_map_string_repr from flytekit.types.file import FlyteFile -from flytekitplugins.awssagemaker_inference.boto3_mixin import Boto3AgentMixin -from flytekitplugins.awssagemaker_inference import 
triton_image_uri -import pytest -from unittest.mock import patch, AsyncMock def test_inputs(): diff --git a/plugins/flytekit-aws-sagemaker/tests/test_inference_workflow.py b/plugins/flytekit-aws-sagemaker/tests/test_inference_workflow.py index 1b5f1bebbd..f98bb557fa 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_inference_workflow.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_inference_workflow.py @@ -1,7 +1,4 @@ -from flytekitplugins.awssagemaker_inference import ( - create_sagemaker_deployment, - delete_sagemaker_deployment, -) +from flytekitplugins.awssagemaker_inference import create_sagemaker_deployment, delete_sagemaker_deployment from flytekit import kwtypes From 36f431577861b2c2ad6f80b362d239af6d910a5f Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 21 Mar 2024 23:20:09 +0530 Subject: [PATCH 114/120] lint Signed-off-by: Samhita Alla --- plugins/flytekit-aws-sagemaker/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/README.md b/plugins/flytekit-aws-sagemaker/README.md index b8eacf0914..2d57333353 100644 --- a/plugins/flytekit-aws-sagemaker/README.md +++ b/plugins/flytekit-aws-sagemaker/README.md @@ -30,7 +30,7 @@ sagemaker_deployment_wf = create_sagemaker_deployment( model_config={ "ModelName": MODEL_NAME, "PrimaryContainer": { - "Image": "{images.primary_container_image}", + "Image": "{images.deployment_image}", "ModelDataUrl": "{inputs.model_path}", }, "ExecutionRoleArn": "{inputs.execution_role_arn}", @@ -54,7 +54,7 @@ sagemaker_deployment_wf = create_sagemaker_deployment( "EndpointName": ENDPOINT_NAME, "EndpointConfigName": ENDPOINT_CONFIG_NAME, }, - images={"primary_container_image": custom_image}, + images={"deployment_image": custom_image}, region=REGION, ) From bef6a4cfadcf75547bf39be8491ff9eca77e5ac1 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 21 Mar 2024 23:22:29 +0530 Subject: [PATCH 115/120] lint Signed-off-by: Samhita Alla --- .pre-commit-config.yaml | 2 
+- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7d51221146..e42d33348f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.2.2 + rev: v0.3.3 hooks: # Run the linter. - id: ruff From 2fecd3710675667595ac44d4faf2f706b169cbcb Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 21 Mar 2024 23:23:20 +0530 Subject: [PATCH 116/120] lint Signed-off-by: Samhita Alla --- plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py index 94dd7f9a6f..98c5686e2d 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py @@ -91,7 +91,7 @@ async def test_call(mock_session): config = { "ModelName": "{inputs.model_name}", "PrimaryContainer": { - "Image": "{images.primary_container_image}", + "Image": "{images.image}", "ModelDataUrl": "s3://sagemaker-agent-xgboost/model.tar.gz", }, } @@ -105,7 +105,7 @@ async def test_call(mock_session): method="create_model", config=config, inputs=inputs, - images={"primary_container_image": triton_image_uri}, + images={"image": triton_image_uri}, ) mock_method.assert_called_with( From a833ff165a0847d549940b84b86627c4ef6f47e4 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 21 Mar 2024 23:28:17 +0530 Subject: [PATCH 117/120] ruff version Signed-off-by: Samhita Alla --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e42d33348f..7d51221146 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. 
- rev: v0.3.3 + rev: v0.2.2 hooks: # Run the linter. - id: ruff From c7bb5ace7950cf4f73d5a288ea54294922481160 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 21 Mar 2024 23:30:31 +0530 Subject: [PATCH 118/120] ruff version Signed-off-by: Samhita Alla --- .../awssagemaker_inference/boto3_mixin.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py index 9a64c6485d..40f2ad393d 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py @@ -61,9 +61,7 @@ def update_dict_fn(original_dict: Any, update_dict: Dict[str, Any]) -> Any: try: update_dict_copy = update_dict_copy[key] except Exception: - raise ValueError( - f"Could not find the key {key} in {update_dict_copy}." - ) + raise ValueError(f"Could not find the key {key} in {update_dict_copy}.") return update_dict_copy @@ -153,11 +151,7 @@ async def _call( raise ValueError("Region parameter is required.") if images: - base = ( - "amazonaws.com.cn" - if final_region.startswith("cn-") - else "amazonaws.com" - ) + base = "amazonaws.com.cn" if final_region.startswith("cn-") else "amazonaws.com" images = { image_name: ( image.format( From b86b4c22f80a324e8502392984f69c3253e683a7 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Fri, 22 Mar 2024 12:00:30 +0530 Subject: [PATCH 119/120] nit Signed-off-by: Samhita Alla --- plugins/README.md | 10 +++++----- .../awssagemaker_inference/boto3_mixin.py | 2 -- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/plugins/README.md b/plugins/README.md index f12b17cb3f..81d3ad9530 100644 --- a/plugins/README.md +++ b/plugins/README.md @@ -129,12 +129,12 @@ setup( Following shows an excerpt from the `flytekit-data-fsspec` plugin's setup.py file. 
- ```python - setup( - entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, - ) +```python +setup( + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, +) - ``` +``` ### Flytekit Version Pinning diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py index 40f2ad393d..045124afd0 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/boto3_mixin.py @@ -164,8 +164,6 @@ async def _call( ) for image_name, image in images.items() } - - if images: args["images"] = images updated_config = update_dict_fn(config, args) From 5a906a34c07eae025de5a35dcbe6721a300c4adf Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Fri, 22 Mar 2024 12:08:45 +0530 Subject: [PATCH 120/120] docstring update Signed-off-by: Samhita Alla --- .../flytekitplugins/awssagemaker_inference/task.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py index ce920f4b00..4ed538a410 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker_inference/task.py @@ -27,9 +27,8 @@ def __init__( :param name: The name of the task. :param config: The configuration to be provided to the boto3 API call. :param region: The region for the boto3 client. + :param images: Images for SageMaker model creation. :param inputs: The input literal map to be used for updating the configuration. 
- :param image: The path where the inference code is stored can either be in the Amazon EC2 Container Registry - or in a Docker registry that is accessible from the same VPC that you configure for your endpoint. """ super(SageMakerModelTask, self).__init__(