From ba29604fdf8751610a8f52c1cf99521e2631eeca Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 3 Nov 2021 23:01:05 -0700 Subject: [PATCH] Move plugin models to be alongside plugins, remove legacy Sagemaker, Pytorch, notebook, and Tensorflow plugins (#736) Signed-off-by: Yee Hing Tong --- .github/workflows/pythonbuild.yml | 2 +- Makefile | 2 - flytekit/common/tasks/pytorch_task.py | 65 --- .../sagemaker/built_in_training_job_task.py | 109 ---- .../sagemaker/custom_training_job_task.py | 173 ------ .../tasks/sagemaker/distributed_training.py | 102 ---- .../common/tasks/sagemaker/hpo_job_task.py | 114 ---- flytekit/common/tasks/sagemaker/types.py | 7 - flytekit/common/tasks/tensorflow_task.py | 69 --- flytekit/contrib/notebook/__init__.py | 0 flytekit/contrib/notebook/helper.py | 43 -- flytekit/contrib/notebook/supported_types.py | 12 - flytekit/contrib/notebook/tasks.py | 502 ---------------- flytekit/models/presto.py | 5 + flytekit/models/qubole.py | 5 + flytekit/models/sagemaker/__init__.py | 0 flytekit/models/task.py | 59 +- flytekit/plugins/__init__.py | 8 - flytekit/sdk/sagemaker/__init__.py | 0 flytekit/sdk/sagemaker/task.py | 153 ----- flytekit/sdk/tasks.py | 322 ----------- flytekit_scripts/flytekit_sagemaker_runner.py | 94 --- .../flytekitplugins/awssagemaker/__init__.py | 6 +- .../flytekitplugins/awssagemaker/hpo.py | 7 +- .../awssagemaker/models}/__init__.py | 0 .../awssagemaker/models}/hpo_job.py | 3 +- .../awssagemaker/models}/parameter_ranges.py | 0 .../awssagemaker/models}/training_job.py | 0 .../flytekitplugins/awssagemaker/training.py | 3 +- .../flytekit-aws-sagemaker/tests/test_hpo.py | 18 +- .../tests}/test_hpo_job.py | 2 +- .../tests}/test_parameter_ranges.py | 3 +- .../tests/test_training.py | 12 +- .../tests}/test_training_job.py | 2 +- .../flytekitplugins/kfpytorch/models.py | 23 + .../flytekitplugins/kfpytorch/task.py | 5 +- .../flytekitplugins/kftensorflow/models.py | 35 ++ .../flytekitplugins/kftensorflow/task.py | 5 +- .../flytekitplugins/spark/models.py | 147 +++++ .../flytekitplugins/spark/task.py | 5 +- setup.py | 1 - tests/flytekit/common/workflows/sagemaker.py | 121 ---- .../unit/sdk/tasks/test_pytorch_task.py | 48 -- .../unit/sdk/tasks/test_sagemaker_tasks.py | 540 ------------------ .../unit/sdk/tasks/test_tensorflow_task.py | 55 -- .../scripts/test_flytekit_sagemaker_runner.py | 37 -- 46 files changed, 261 insertions(+), 2663 deletions(-) delete mode 100644 flytekit/common/tasks/pytorch_task.py delete mode 100644 flytekit/common/tasks/sagemaker/built_in_training_job_task.py delete mode 100644 flytekit/common/tasks/sagemaker/custom_training_job_task.py delete mode 100644 flytekit/common/tasks/sagemaker/distributed_training.py delete mode 100644 flytekit/common/tasks/sagemaker/hpo_job_task.py delete mode 100644 flytekit/common/tasks/sagemaker/types.py delete mode 100644 flytekit/common/tasks/tensorflow_task.py delete mode 100644 flytekit/contrib/notebook/__init__.py delete mode 100644 flytekit/contrib/notebook/helper.py delete mode 100644 flytekit/contrib/notebook/supported_types.py delete mode 100644 flytekit/contrib/notebook/tasks.py delete mode 100644 flytekit/models/sagemaker/__init__.py delete mode 100644 flytekit/sdk/sagemaker/__init__.py delete mode 100644 flytekit/sdk/sagemaker/task.py delete mode 100644 flytekit_scripts/flytekit_sagemaker_runner.py rename {flytekit/common/tasks/sagemaker => plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models}/__init__.py (100%) rename {flytekit/models/sagemaker => 
plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models}/hpo_job.py (99%) rename {flytekit/models/sagemaker => plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models}/parameter_ranges.py (100%) rename {flytekit/models/sagemaker => plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models}/training_job.py (100%) rename {tests/flytekit/unit/models/sagemaker => plugins/flytekit-aws-sagemaker/tests}/test_hpo_job.py (97%) rename {tests/flytekit/unit/models/sagemaker => plugins/flytekit-aws-sagemaker/tests}/test_parameter_ranges.py (98%) rename {tests/flytekit/unit/models/sagemaker => plugins/flytekit-aws-sagemaker/tests}/test_training_job.py (98%) create mode 100644 plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/models.py create mode 100644 plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models.py create mode 100644 plugins/flytekit-spark/flytekitplugins/spark/models.py delete mode 100644 tests/flytekit/common/workflows/sagemaker.py delete mode 100644 tests/flytekit/unit/sdk/tasks/test_pytorch_task.py delete mode 100644 tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py delete mode 100644 tests/flytekit/unit/sdk/tasks/test_tensorflow_task.py delete mode 100644 tests/scripts/test_flytekit_sagemaker_runner.py diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index 1332980df5..924e33fa63 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -39,7 +39,7 @@ jobs: pip freeze - name: Test with coverage run: | - coverage run -m pytest tests/flytekit/unit tests/scripts + coverage run -m pytest tests/flytekit/unit - name: Integration Tests with coverage # https://github.com/actions/runner/issues/241#issuecomment-577360161 shell: 'script -q -e -c "bash {0}"' diff --git a/Makefile b/Makefile index 935a28024f..9e10e0df45 100644 --- a/Makefile +++ b/Makefile @@ -48,12 +48,10 @@ spellcheck: ## Runs a spellchecker over all code and documentation .PHONY: test test: lint ## Run tests pytest tests/flytekit/unit - pytest tests/scripts .PHONY: unit_test unit_test: pytest tests/flytekit/unit - pytest tests/scripts requirements-spark2.txt: export CUSTOM_COMPILE_COMMAND := make requirements-spark2.txt requirements-spark2.txt: requirements-spark2.in install-piptools diff --git a/flytekit/common/tasks/pytorch_task.py b/flytekit/common/tasks/pytorch_task.py deleted file mode 100644 index 17c2cd7ea1..0000000000 --- a/flytekit/common/tasks/pytorch_task.py +++ /dev/null @@ -1,65 +0,0 @@ -from google.protobuf.json_format import MessageToDict as _MessageToDict - -from flytekit.common.tasks import sdk_runnable as _sdk_runnable -from flytekit.models import task as _task_models - - -class SdkRunnablePytorchContainer(_sdk_runnable.SdkRunnableContainer): - @property - def args(self): - """ - Override args to remove the injection of command prefixes - :rtype: list[Text] - """ - return self._args - - -class SdkPyTorchTask(_sdk_runnable.SdkRunnableTask): - def __init__( - self, - task_function, - task_type, - discovery_version, - retries, - interruptible, - deprecated, - discoverable, - timeout, - workers_count, - per_replica_storage_request, - per_replica_cpu_request, - per_replica_gpu_request, - per_replica_memory_request, - per_replica_storage_limit, - per_replica_cpu_limit, - per_replica_gpu_limit, - per_replica_memory_limit, - environment, - ): - pytorch_job = _task_models.PyTorchJob(workers_count=workers_count).to_flyte_idl() - super(SdkPyTorchTask, self).__init__( - task_function=task_function, - 
task_type=task_type, - discovery_version=discovery_version, - retries=retries, - interruptible=interruptible, - deprecated=deprecated, - storage_request=per_replica_storage_request, - cpu_request=per_replica_cpu_request, - gpu_request=per_replica_gpu_request, - memory_request=per_replica_memory_request, - storage_limit=per_replica_storage_limit, - cpu_limit=per_replica_cpu_limit, - gpu_limit=per_replica_gpu_limit, - memory_limit=per_replica_memory_limit, - discoverable=discoverable, - timeout=timeout, - environment=environment, - custom=_MessageToDict(pytorch_job), - ) - - def _get_container_definition(self, **kwargs): - """ - :rtype: SdkRunnablePytorchContainer - """ - return super(SdkPyTorchTask, self)._get_container_definition(cls=SdkRunnablePytorchContainer, **kwargs) diff --git a/flytekit/common/tasks/sagemaker/built_in_training_job_task.py b/flytekit/common/tasks/sagemaker/built_in_training_job_task.py deleted file mode 100644 index 356f933c3e..0000000000 --- a/flytekit/common/tasks/sagemaker/built_in_training_job_task.py +++ /dev/null @@ -1,109 +0,0 @@ -import datetime as _datetime - -from google.protobuf.json_format import MessageToDict - -from flytekit import __version__ -from flytekit.common import interface as _interface -from flytekit.common.constants import SdkTaskType -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.tasks import task as _sdk_task -from flytekit.models import interface as _interface_model -from flytekit.models import literals as _literal_models -from flytekit.models import task as _task_models -from flytekit.models import types as _idl_types -from flytekit.models.core import types as _core_types -from flytekit.models.sagemaker import training_job as _training_job_models - - -def _content_type_to_blob_format(content_type: _training_job_models) -> str: - if content_type == _training_job_models.InputContentType.TEXT_CSV: - return "csv" - else: - raise _user_exceptions.FlyteValueException("Unsupported InputContentType: {}".format(content_type)) - - -class SdkBuiltinAlgorithmTrainingJobTask(_sdk_task.SdkTask): - def __init__( - self, - training_job_resource_config: _training_job_models.TrainingJobResourceConfig, - algorithm_specification: _training_job_models.AlgorithmSpecification, - retries: int = 0, - cacheable: bool = False, - cache_version: str = "", - ): - """ - - :param training_job_resource_config: The options to configure the training job - :param algorithm_specification: The options to configure the target algorithm of the training - :param retries: Number of retries to attempt - :param cacheable: The flag to set if the user wants the output of the task execution to be cached - :param cache_version: String describing the caching version for task discovery purposes - """ - # Use the training job model as a measure of type checking - self._training_job_model = _training_job_models.TrainingJob( - algorithm_specification=algorithm_specification, - training_job_resource_config=training_job_resource_config, - ) - - # Setting flyte-level timeout to 0, and let SageMaker takes the StoppingCondition and terminate the training - # job gracefully - timeout = _datetime.timedelta(seconds=0) - - super(SdkBuiltinAlgorithmTrainingJobTask, self).__init__( - type=SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK, - metadata=_task_models.TaskMetadata( - runtime=_task_models.RuntimeMetadata( - type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, - version=__version__, - flavor="sagemaker", - ), - discoverable=cacheable, - timeout=timeout, - 
retries=_literal_models.RetryStrategy(retries=retries), - interruptible=False, - discovery_version=cache_version, - deprecated_error_message="", - ), - interface=_interface.TypedInterface( - inputs={ - "static_hyperparameters": _interface_model.Variable( - type=_idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT), - description="", - ), - "train": _interface_model.Variable( - type=_idl_types.LiteralType( - blob=_core_types.BlobType( - format=_content_type_to_blob_format(algorithm_specification.input_content_type), - dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, - ), - ), - description="", - ), - "validation": _interface_model.Variable( - type=_idl_types.LiteralType( - blob=_core_types.BlobType( - format=_content_type_to_blob_format(algorithm_specification.input_content_type), - dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, - ), - ), - description="", - ), - }, - outputs={ - "model": _interface_model.Variable( - type=_idl_types.LiteralType( - blob=_core_types.BlobType( - format="", - dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, - ) - ), - description="", - ) - }, - ), - custom=MessageToDict(self._training_job_model.to_flyte_idl()), - ) - - @property - def training_job_model(self) -> _training_job_models.TrainingJob: - return self._training_job_model diff --git a/flytekit/common/tasks/sagemaker/custom_training_job_task.py b/flytekit/common/tasks/sagemaker/custom_training_job_task.py deleted file mode 100644 index fb7d8cac98..0000000000 --- a/flytekit/common/tasks/sagemaker/custom_training_job_task.py +++ /dev/null @@ -1,173 +0,0 @@ -import logging as _logging -import typing as _typing - -import six as _six -from google.protobuf.json_format import MessageToDict - -from flytekit.common.constants import SdkTaskType -from flytekit.common.core.identifier import WorkflowExecutionIdentifier -from flytekit.common.exceptions import scopes as _exception_scopes -from flytekit.common.tasks import sdk_runnable as _sdk_runnable -from flytekit.common.tasks.sagemaker import distributed_training as _sm_distribution -from flytekit.common.tasks.sagemaker.distributed_training import DefaultOutputPersistPredicate -from flytekit.models.sagemaker import training_job as _training_job_models - - -class CustomTrainingJobTask(_sdk_runnable.SdkRunnableTask): - """ - CustomTrainJobTask defines a python task that can run on SageMaker bring your own container. - - """ - - def __init__( - self, - task_function, - cache_version, - retries, - deprecated, - storage_request, - cpu_request, - gpu_request, - memory_request, - storage_limit, - cpu_limit, - gpu_limit, - memory_limit, - cache, - timeout, - environment, - algorithm_specification: _training_job_models.AlgorithmSpecification, - training_job_resource_config: _training_job_models.TrainingJobResourceConfig, - output_persist_predicate: _typing.Callable = DefaultOutputPersistPredicate(), - ): - """ - :param task_function: Function container user code. This will be executed via the SDK's engine. 
- :param Text cache_version: string describing the version for task discovery purposes - :param int retries: Number of retries to attempt - :param Text deprecated: - :param Text storage_request: - :param Text cpu_request: - :param Text gpu_request: - :param Text memory_request: - :param Text storage_limit: - :param Text cpu_limit: - :param Text gpu_limit: - :param Text memory_limit: - :param bool cache: - :param datetime.timedelta timeout: - :param dict[Text, Text] environment: - :param _training_job_models.AlgorithmSpecification algorithm_specification: - :param _training_job_models.TrainingJobResourceConfig training_job_resource_config: - :param _typing.Callable output_persist_predicate: - """ - - self._output_persist_predicate = output_persist_predicate - - # Use the training job model as a measure of type checking - self._training_job_model = _training_job_models.TrainingJob( - algorithm_specification=algorithm_specification, training_job_resource_config=training_job_resource_config - ) - - super().__init__( - task_function=task_function, - task_type=SdkTaskType.SAGEMAKER_CUSTOM_TRAINING_JOB_TASK, - discovery_version=cache_version, - retries=retries, - interruptible=False, - deprecated=deprecated, - storage_request=storage_request, - cpu_request=cpu_request, - gpu_request=gpu_request, - memory_request=memory_request, - storage_limit=storage_limit, - cpu_limit=cpu_limit, - gpu_limit=gpu_limit, - memory_limit=memory_limit, - discoverable=cache, - timeout=timeout, - environment=environment, - custom=MessageToDict(self._training_job_model.to_flyte_idl()), - ) - - @property - def output_persist_predicate(self): - return self._output_persist_predicate - - @property - def training_job_model(self) -> _training_job_models.TrainingJob: - return self._training_job_model - - def _is_distributed(self): - return ( - self.training_job_model.training_job_resource_config - and self.training_job_model.training_job_resource_config.instance_count > 1 - ) - - @_exception_scopes.system_entry_point - def execute(self, context, inputs): - """ - :param flytekit.engines.common.EngineContext context: - :param flytekit.models.literals.LiteralMap inputs: - :rtype: dict[Text, flytekit.models.common.FlyteIdlEntity] - :returns: This function must return a dictionary mapping 'filenames' to Flyte Interface Entities. These - entities will be used by the engine to pass data from node to node, populate metadata, etc. etc.. Each - engine will have different behavior. For instance, the Flyte engine will upload the entities to a remote - working directory (with the names provided), which will in turn allow Flyte Propeller to push along the - workflow. Where as local engine will merely feed the outputs directly into the next node. - """ - engine_context = _sm_distribution.DistributedTrainingEngineContext( - execution_date=context.execution_date, - tmp_dir=context.working_directory, - stats=context.stats, - execution_id=context.execution_id, - logging=context.logging, - raw_output_data_prefix=context.raw_output_data_prefix, - distributed_training_context=_sm_distribution.get_sagemaker_distributed_training_context_from_env(), - ) - - ret = super().execute(engine_context, inputs) - - # In a single-node training case, we always wants to persist the output. - # In a distributed-training case, whether or not flytekit wants to persist the output depends on the return - # value of the predicate. 
- if self._is_distributed() is False or ( - self._is_distributed() - and self._output_persist_predicate - and self.output_persist_predicate(engine_context.distributed_training_context) is True - ): - return ret - else: - _logging.info( - "Output_persist_predicate() returns False for this instance. " - "The output of this task will not be persisted" - ) - return {} - - def _execute_user_code(self, context, inputs): - """ - :param flytekit.engines.common.tasks.sagemaker.distribution.DistributedTrainingEngineContext context: - :param dict[Text, T] inputs: This variable is a bit of a misnomer, since it's both inputs and outputs. The - dictionary passed here will be passed to the user-defined function, and will have values that are a - variety of types. The T's here are Python std values for inputs. If there isn't a native Python type for - something (like Schema or Blob), they are the Flyte classes. For outputs they are OutputReferences. - (Note that these are not the same OutputReferences as in BindingData's) - :rtype: Any: the returned object from user code. - :returns: This function must return a dictionary mapping 'filenames' to Flyte Interface Entities. These - entities will be used by the engine to pass data from node to node, populate metadata, etc. etc.. Each - engine will have different behavior. For instance, the Flyte engine will upload the entities to a remote - working directory (with the names provided), which will in turn allow Flyte Propeller to push along the - workflow. Where as local engine will merely feed the outputs directly into the next node. - """ - - return _exception_scopes.user_entry_point(self.task_function)( - _sm_distribution.DistributedTrainingExecutionParam( - execution_date=context.execution_date, - # TODO: it might be better to consider passing the full struct - execution_id=_six.text_type(WorkflowExecutionIdentifier.promote_from_model(context.execution_id)), - stats=context.stats, - logging=context.logging, - tmp_dir=context.working_directory, - distributed_training_context=context.distributed_training_context, - ), - **inputs - ) diff --git a/flytekit/common/tasks/sagemaker/distributed_training.py b/flytekit/common/tasks/sagemaker/distributed_training.py deleted file mode 100644 index 3cd3bf4d37..0000000000 --- a/flytekit/common/tasks/sagemaker/distributed_training.py +++ /dev/null @@ -1,102 +0,0 @@ -import json as _json -import os as _os - -import retry as _retry - -from flytekit.common.tasks import sdk_runnable as _sdk_runnable -from flytekit.engines import common as _common_engine - -SM_RESOURCE_CONFIG_FILE = "/opt/ml/input/config/resourceconfig.json" -SM_ENV_VAR_CURRENT_HOST = "SM_CURRENT_HOST" -SM_ENV_VAR_HOSTS = "SM_HOSTS" -SM_ENV_VAR_NETWORK_INTERFACE_NAME = "SM_NETWORK_INTERFACE_NAME" - - -# SageMaker suggests "Hostname information might not be immediately available to the processing container. -# We recommend adding a retry policy on hostname resolution operations as nodes become available in the cluster." 
-# https://docs.aws.amazon.com/sagemaker/latest/dg/build-your-own-processing-container.html#byoc-config -@_retry.retry(exceptions=KeyError, delay=1, tries=10, backoff=1) -def get_sagemaker_distributed_training_context_from_env() -> dict: - distributed_training_context = {} - if ( - not _os.environ.get(SM_ENV_VAR_CURRENT_HOST) - or not _os.environ.get(SM_ENV_VAR_HOSTS) - or not _os.environ.get(SM_ENV_VAR_NETWORK_INTERFACE_NAME) - ): - raise KeyError - - distributed_training_context[DistributedTrainingContextKey.CURRENT_HOST] = _os.environ.get(SM_ENV_VAR_CURRENT_HOST) - distributed_training_context[DistributedTrainingContextKey.HOSTS] = _json.loads(_os.environ.get(SM_ENV_VAR_HOSTS)) - distributed_training_context[DistributedTrainingContextKey.NETWORK_INTERFACE_NAME] = _os.environ.get( - SM_ENV_VAR_NETWORK_INTERFACE_NAME - ) - - return distributed_training_context - - -@_retry.retry(exceptions=FileNotFoundError, delay=1, tries=10, backoff=1) -def get_sagemaker_distributed_training_context_from_file() -> dict: - with open(SM_RESOURCE_CONFIG_FILE, "r") as rc_file: - return _json.load(rc_file) - - -# The default output-persisting predicate. -# With this predicate, only the copy running on the first host in the list of hosts would persist its output -class DefaultOutputPersistPredicate(object): - def __call__(self, distributed_training_context): - return ( - distributed_training_context[DistributedTrainingContextKey.CURRENT_HOST] - == distributed_training_context[DistributedTrainingContextKey.HOSTS][0] - ) - - -class DistributedTrainingContextKey(object): - CURRENT_HOST = "current_host" - HOSTS = "hosts" - NETWORK_INTERFACE_NAME = "network_interface_name" - - -class DistributedTrainingEngineContext(_common_engine.EngineContext): - def __init__( - self, - execution_date, - tmp_dir, - stats, - execution_id, - logging, - raw_output_data_prefix=None, - distributed_training_context=None, - ): - super().__init__( - execution_date=execution_date, - tmp_dir=tmp_dir, - stats=stats, - execution_id=execution_id, - logging=logging, - raw_output_data_prefix=raw_output_data_prefix, - ) - self._distributed_training_context = distributed_training_context - - @property - def distributed_training_context(self) -> dict: - return self._distributed_training_context - - -class DistributedTrainingExecutionParam(_sdk_runnable.ExecutionParameters): - def __init__(self, execution_date, tmp_dir, stats, execution_id, logging, distributed_training_context): - - super().__init__( - execution_date=execution_date, tmp_dir=tmp_dir, stats=stats, execution_id=execution_id, logging=logging - ) - - self._distributed_training_context = distributed_training_context - - @property - def distributed_training_context(self): - """ - This contains the resource information for distributed training. Currently this information is only available - for SageMaker training jobs. 
- - :rtype: dict - """ - return self._distributed_training_context diff --git a/flytekit/common/tasks/sagemaker/hpo_job_task.py b/flytekit/common/tasks/sagemaker/hpo_job_task.py deleted file mode 100644 index fd4d0a3ce1..0000000000 --- a/flytekit/common/tasks/sagemaker/hpo_job_task.py +++ /dev/null @@ -1,114 +0,0 @@ -import datetime as _datetime -import typing - -from google.protobuf.json_format import MessageToDict - -from flytekit import __version__ -from flytekit.common import interface as _interface -from flytekit.common.constants import SdkTaskType -from flytekit.common.tasks import task as _sdk_task -from flytekit.common.tasks.sagemaker.built_in_training_job_task import SdkBuiltinAlgorithmTrainingJobTask -from flytekit.common.tasks.sagemaker.custom_training_job_task import CustomTrainingJobTask -from flytekit.common.tasks.sagemaker.types import HyperparameterTuningJobConfig, ParameterRange -from flytekit.models import interface as _interface_model -from flytekit.models import literals as _literal_models -from flytekit.models import task as _task_models -from flytekit.models import types as _types_models -from flytekit.models.core import types as _core_types -from flytekit.models.sagemaker import hpo_job as _hpo_job_model - - -class SdkSimpleHyperparameterTuningJobTask(_sdk_task.SdkTask): - def __init__( - self, - max_number_of_training_jobs: int, - max_parallel_training_jobs: int, - training_job: typing.Union[SdkBuiltinAlgorithmTrainingJobTask, CustomTrainingJobTask], - retries: int = 0, - cacheable: bool = False, - cache_version: str = "", - tunable_parameters: typing.List[str] = None, - ): - """ - :param int max_number_of_training_jobs: The maximum number of training jobs that can be launched by this - hyperparameter tuning job - :param int max_parallel_training_jobs: The maximum number of training jobs that can launched by this hyperparameter - tuning job in parallel - :param typing.Union[SdkBuiltinAlgorithmTrainingJobTask, CustomTrainingJobTask] training_job: The reference to the training job definition - :param int retries: Number of retries to attempt - :param bool cacheable: The flag to set if the user wants the output of the task execution to be cached - :param str cache_version: String describing the caching version for task discovery purposes - :param typing.List[str] tunable_parameters: A list of parameters that to tune. If you are tuning a built-int - algorithm, refer to the algorithm's documentation to understand the possible values for the tunable - parameters. E.g. Refer to https://docs.aws.amazon.com/sagemaker/latest/dg/IC-Hyperparameter.html for the - list of hyperparameters for Image Classification built-in algorithm. If you are passing a custom - training job, the list of tunable parameters must be a strict subset of the list of inputs defined on - that job. Refer to https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-ranges.html - for the list of supported hyperparameter types. 
- """ - # Use the training job model as a measure of type checking - hpo_job = _hpo_job_model.HyperparameterTuningJob( - max_number_of_training_jobs=max_number_of_training_jobs, - max_parallel_training_jobs=max_parallel_training_jobs, - training_job=training_job.training_job_model, - ).to_flyte_idl() - - # Setting flyte-level timeout to 0, and let SageMaker respect the StoppingCondition of - # the underlying training job - # TODO: Discuss whether this is a viable interface or contract - timeout = _datetime.timedelta(seconds=0) - - inputs = {} - inputs.update(training_job.interface.inputs) - inputs.update( - { - "hyperparameter_tuning_job_config": _interface_model.Variable( - HyperparameterTuningJobConfig.to_flyte_literal_type(), - "", - ), - } - ) - - if tunable_parameters: - inputs.update( - { - param: _interface_model.Variable(ParameterRange.to_flyte_literal_type(), "") - for param in tunable_parameters - } - ) - - super().__init__( - type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK, - metadata=_task_models.TaskMetadata( - runtime=_task_models.RuntimeMetadata( - type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, - version=__version__, - flavor="sagemaker", - ), - discoverable=cacheable, - timeout=timeout, - retries=_literal_models.RetryStrategy(retries=retries), - interruptible=False, - discovery_version=cache_version, - deprecated_error_message="", - ), - interface=_interface.TypedInterface( - inputs=inputs, - outputs={ - "model": _interface_model.Variable( - type=_types_models.LiteralType( - blob=_core_types.BlobType( - format="", - dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, - ) - ), - description="", - ) - }, - ), - custom=MessageToDict(hpo_job), - ) - - def __call__(self, *args, **kwargs): - # Overriding the call function just so we clear up the docs and avoid IDEs complaining about the signature. 
- return super().__call__(*args, **kwargs) diff --git a/flytekit/common/tasks/sagemaker/types.py b/flytekit/common/tasks/sagemaker/types.py deleted file mode 100644 index 5efa8fd6bf..0000000000 --- a/flytekit/common/tasks/sagemaker/types.py +++ /dev/null @@ -1,7 +0,0 @@ -from flytekit.models.sagemaker import hpo_job as _hpo_models -from flytekit.models.sagemaker import parameter_ranges as _parameter_range_models -from flytekit.sdk import types as _sdk_types - -HyperparameterTuningJobConfig = _sdk_types.Types.GenericProto(_hpo_models.HyperparameterTuningJobConfig) - -ParameterRange = _sdk_types.Types.GenericProto(_parameter_range_models.ParameterRangeOneOf) diff --git a/flytekit/common/tasks/tensorflow_task.py b/flytekit/common/tasks/tensorflow_task.py deleted file mode 100644 index 17b9b12ae5..0000000000 --- a/flytekit/common/tasks/tensorflow_task.py +++ /dev/null @@ -1,69 +0,0 @@ -from google.protobuf.json_format import MessageToDict as _MessageToDict - -from flytekit.common.tasks import sdk_runnable as _sdk_runnable -from flytekit.models import task as _task_models - - -class SdkRunnableTensorflowContainer(_sdk_runnable.SdkRunnableContainer): - @property - def args(self): - """ - Override args to remove the injection of command prefixes - :rtype: list[Text] - """ - return self._args - - -class SdkTensorFlowTask(_sdk_runnable.SdkRunnableTask): - def __init__( - self, - task_function, - task_type, - cache_version, - retries, - interruptible, - deprecated, - cache, - timeout, - workers_count, - ps_replicas_count, - chief_replicas_count, - per_replica_storage_request, - per_replica_cpu_request, - per_replica_gpu_request, - per_replica_memory_request, - per_replica_storage_limit, - per_replica_cpu_limit, - per_replica_gpu_limit, - per_replica_memory_limit, - environment, - ): - tensorflow_job = _task_models.TensorFlowJob( - workers_count=workers_count, ps_replicas_count=ps_replicas_count, chief_replicas_count=chief_replicas_count - ).to_flyte_idl() - super(SdkTensorFlowTask, self).__init__( - task_function=task_function, - task_type=task_type, - discovery_version=cache_version, - retries=retries, - interruptible=interruptible, - deprecated=deprecated, - storage_request=per_replica_storage_request, - cpu_request=per_replica_cpu_request, - gpu_request=per_replica_gpu_request, - memory_request=per_replica_memory_request, - storage_limit=per_replica_storage_limit, - cpu_limit=per_replica_cpu_limit, - gpu_limit=per_replica_gpu_limit, - memory_limit=per_replica_memory_limit, - discoverable=cache, - timeout=timeout, - environment=environment, - custom=_MessageToDict(tensorflow_job), - ) - - def _get_container_definition(self, **kwargs): - """ - :rtype: SdkRunnableTensorflowContainer - """ - return super(SdkTensorFlowTask, self)._get_container_definition(cls=SdkRunnableTensorflowContainer, **kwargs) diff --git a/flytekit/contrib/notebook/__init__.py b/flytekit/contrib/notebook/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/flytekit/contrib/notebook/helper.py b/flytekit/contrib/notebook/helper.py deleted file mode 100644 index 203731f0b6..0000000000 --- a/flytekit/contrib/notebook/helper.py +++ /dev/null @@ -1,43 +0,0 @@ -import os as _os - -import six as _six - -from flytekit.common.types.helpers import pack_python_std_map_to_literal_map as _packer -from flytekit.contrib.notebook.supported_types import notebook_types_map as _notebook_types_map -from flytekit.plugins import pyspark as _pyspark - - -def record_outputs(outputs=None): - """ - Converts/Records outputs 
generated by users in their Notebook as Flyte Types. - """ - if outputs is None: - return _packer({}, {}) - tm = {} - for k, v in _six.iteritems(outputs): - t = type(v) - if t not in _notebook_types_map: - raise ValueError( - "Currently only primitive types {} are supported for recording from notebook".format( - _notebook_types_map - ) - ) - tm[k] = _notebook_types_map[t] - return _packer(outputs, tm).to_flyte_idl() - - -# TODO: Support Client Mode -def get_spark_context(spark_conf): - """ - outputs: SparkContext - Returns appropriate SparkContext based on whether invoked via a Notebook or a Flyte workflow. - """ - # We run in cluster-mode in Flyte. - # Ref https://github.com/lyft/flyteplugins/blob/master/go/tasks/v1/flytek8s/k8s_resource_adds.go#L46 - if "FLYTE_INTERNAL_EXECUTION_ID" in _os.environ: - return _pyspark.SparkContext() - - # Add system spark-conf for local/notebook based execution. - spark_conf.add(("spark.master", "local")) - conf = _pyspark.SparkConf().setAll(spark_conf) - return _pyspark.SparkContext(conf=conf) diff --git a/flytekit/contrib/notebook/supported_types.py b/flytekit/contrib/notebook/supported_types.py deleted file mode 100644 index ac0977197a..0000000000 --- a/flytekit/contrib/notebook/supported_types.py +++ /dev/null @@ -1,12 +0,0 @@ -import datetime as _datetime - -from flytekit.common.types import primitives as _primitives - -notebook_types_map = { - int: _primitives.Integer, - bool: _primitives.Boolean, - float: _primitives.Float, - str: _primitives.String, - _datetime.datetime: _primitives.Datetime, - _datetime.timedelta: _primitives.Timedelta, -} diff --git a/flytekit/contrib/notebook/tasks.py b/flytekit/contrib/notebook/tasks.py deleted file mode 100644 index 5836717574..0000000000 --- a/flytekit/contrib/notebook/tasks.py +++ /dev/null @@ -1,502 +0,0 @@ -import datetime as _datetime -import importlib as _importlib -import inspect as _inspect -import json as _json -import os as _os -import sys as _sys - -import six as _six -from google.protobuf import json_format as _json_format -from google.protobuf import text_format as _text_format - -from flytekit import __version__ -from flytekit.bin import entrypoint as _entrypoint -from flytekit.common import constants as _constants -from flytekit.common import interface as _interface2 -from flytekit.common.exceptions import scopes as _exception_scopes -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.tasks import output as _task_output -from flytekit.common.tasks import sdk_runnable as _sdk_runnable -from flytekit.common.tasks import spark_task as _spark_task -from flytekit.common.tasks import task as _base_tasks -from flytekit.common.types import helpers as _type_helpers -from flytekit.contrib.notebook.supported_types import notebook_types_map as _notebook_types_map -from flytekit.engines import loader as _engine_loader -from flytekit.models import interface as _interface -from flytekit.models import literals as _literal_models -from flytekit.models import task as _task_models -from flytekit.plugins import papermill as _pm -from flytekit.sdk.spark_types import SparkType as _spark_type -from flytekit.sdk.types import Types as _Types - -OUTPUT_NOTEBOOK = "output_notebook" - - -def python_notebook( - notebook_path="", - inputs={}, - outputs={}, - cache_version="", - retries=0, - deprecated="", - storage_request=None, - cpu_request=None, - gpu_request=None, - memory_request=None, - storage_limit=None, - cpu_limit=None, - gpu_limit=None, - memory_limit=None, - cache=False, - 
timeout=None, - environment=None, - cls=None, -): - """ - Decorator to create a Python Notebook Task definition. - - :rtype: SdkNotebookTask - """ - return SdkNotebookTask( - notebook_path=notebook_path, - inputs=inputs, - outputs=outputs, - task_type=_constants.SdkTaskType.PYTHON_TASK, - discovery_version=cache_version, - retries=retries, - deprecated=deprecated, - storage_request=storage_request, - cpu_request=cpu_request, - gpu_request=gpu_request, - memory_request=memory_request, - storage_limit=storage_limit, - cpu_limit=cpu_limit, - gpu_limit=gpu_limit, - memory_limit=memory_limit, - discoverable=cache, - timeout=timeout or _datetime.timedelta(seconds=0), - environment=environment, - custom={}, - ) - - -class SdkNotebookTask(_base_tasks.SdkTask): - - """ - This class includes the additional logic for building a task that executes Notebooks. - - """ - - def __init__( - self, - notebook_path, - inputs, - outputs, - task_type, - discovery_version, - retries, - deprecated, - storage_request, - cpu_request, - gpu_request, - memory_request, - storage_limit, - cpu_limit, - gpu_limit, - memory_limit, - discoverable, - timeout, - environment, - custom, - ): - - if _os.path.isabs(notebook_path) is False: - # Find absolute path for the notebook. - task_module = _importlib.import_module(_find_instance_module()) - module_path = _os.path.dirname(task_module.__file__) - notebook_path = _os.path.normpath(_os.path.join(module_path, notebook_path)) - - self._notebook_path = notebook_path - - super(SdkNotebookTask, self).__init__( - task_type, - _task_models.TaskMetadata( - discoverable, - _task_models.RuntimeMetadata( - _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, - __version__, - "notebook", - ), - timeout, - _literal_models.RetryStrategy(retries), - False, - discovery_version, - deprecated, - ), - _interface2.TypedInterface({}, {}), - custom, - container=self._get_container_definition( - storage_request=storage_request, - cpu_request=cpu_request, - gpu_request=gpu_request, - memory_request=memory_request, - storage_limit=storage_limit, - cpu_limit=cpu_limit, - gpu_limit=gpu_limit, - memory_limit=memory_limit, - environment=environment, - ), - ) - # Add Inputs - if inputs is not None: - inputs(self) - - # Add outputs - if outputs is not None: - outputs(self) - - # Add a Notebook output as a Blob. - self.interface.outputs.update( - output_notebook=_interface.Variable(_Types.Blob.to_flyte_literal_type(), OUTPUT_NOTEBOOK) - ) - - def _validate_inputs(self, inputs): - """ - :param dict[Text, flytekit.models.interface.Variable] inputs: Input variables to validate - :raises: flytekit.common.exceptions.user.FlyteValidationException - """ - for k, v in _six.iteritems(inputs): - sdk_type = _type_helpers.get_sdk_type_from_literal_type(v.type) - if sdk_type not in _notebook_types_map.values(): - raise _user_exceptions.FlyteValidationException( - "Input Type '{}' not supported. Only Primitives are supported for notebook.".format(sdk_type) - ) - super(SdkNotebookTask, self)._validate_inputs(inputs) - - def _validate_outputs(self, outputs): - """ - :param dict[Text, flytekit.models.interface.Variable] inputs: Input variables to validate - :raises: flytekit.common.exceptions.user.FlyteValidationException - """ - - # Add output_notebook as an implicit output to the task. - - for k, v in _six.iteritems(outputs): - - if k == OUTPUT_NOTEBOOK: - raise ValueError( - "{} is a reserved output keyword. 
Please use a different output name.".format(OUTPUT_NOTEBOOK) - ) - - sdk_type = _type_helpers.get_sdk_type_from_literal_type(v.type) - if sdk_type not in _notebook_types_map.values(): - raise _user_exceptions.FlyteValidationException( - "Output Type '{}' not supported. Only Primitives are supported for notebook.".format(sdk_type) - ) - super(SdkNotebookTask, self)._validate_outputs(outputs) - - @_exception_scopes.system_entry_point - def add_inputs(self, inputs): - """ - Adds the inputs to this task. This can be called multiple times, but it will fail if an input with a given - name is added more than once, a name collides with an output, or if the name doesn't exist as an arg name in - the wrapped function. - :param dict[Text, flytekit.models.interface.Variable] inputs: names and variables - """ - self._validate_inputs(inputs) - self.interface.inputs.update(inputs) - - @_exception_scopes.system_entry_point - def add_outputs(self, outputs): - """ - Adds the outputs to this task. This can be called multiple times, but it will fail if an output with a given - name is added more than once, a name collides with an output, or if the name doesn't exist as an arg name in - the wrapped function. - :param dict[Text, flytekit.models.interface.Variable] outputs: names and variables - """ - self._validate_outputs(outputs) - self.interface.outputs.update(outputs) - - @_exception_scopes.system_entry_point - def unit_test(self, **input_map): - """ - :param dict[Text, T] input_map: Python Std input from users. We will cast these to the appropriate Flyte - literals. - :returns: Depends on the behavior of the specific task in the unit engine. - """ - - return ( - _engine_loader.get_engine("unit") - .get_task(self) - .execute( - _type_helpers.pack_python_std_map_to_literal_map( - input_map, - { - k: _type_helpers.get_sdk_type_from_literal_type(v.type) - for k, v in _six.iteritems(self.interface.inputs) - }, - ) - ) - ) - - @_exception_scopes.system_entry_point - def local_execute(self, **input_map): - """ - :param dict[Text, T] input_map: Python Std input from users. We will cast these to the appropriate Flyte - literals. - :rtype: dict[Text, T] - :returns: The output produced by this task in Python standard format. - """ - return ( - _engine_loader.get_engine("local") - .get_task(self) - .execute( - _type_helpers.pack_python_std_map_to_literal_map( - input_map, - { - k: _type_helpers.get_sdk_type_from_literal_type(v.type) - for k, v in _six.iteritems(self.interface.inputs) - }, - ) - ) - ) - - @_exception_scopes.system_entry_point - def execute(self, context, inputs): - """ - :param flytekit.engines.common.EngineContext context: - :param flytekit.models.literals.LiteralMap inputs: - :rtype: dict[Text, flytekit.models.common.FlyteIdlEntity] - :returns: This function must return a dictionary mapping 'filenames' to Flyte Interface Entities. These - entities will be used by the engine to pass data from node to node, populate metadata, etc. etc.. Each - engine will have different behavior. For instance, the Flyte engine will upload the entities to a remote - working directory (with the names provided), which will in turn allow Flyte Propeller to push along the - workflow. Where as local engine will merely feed the outputs directly into the next node. 
- """ - inputs_dict = _type_helpers.unpack_literal_map_to_sdk_python_std( - inputs, - {k: _type_helpers.get_sdk_type_from_literal_type(v.type) for k, v in _six.iteritems(self.interface.inputs)}, - ) - - input_notebook_path = self._notebook_path - # Execute Notebook via Papermill. - output_notebook_path = input_notebook_path.split(".ipynb")[0] + "-out.ipynb" - _pm.execute_notebook(input_notebook_path, output_notebook_path, parameters=inputs_dict) - - # Parse Outputs from Notebook. - outputs = None - with open(output_notebook_path) as json_file: - data = _json.load(json_file) - for p in data["cells"]: - meta = p["metadata"] - if "outputs" in meta["tags"]: - outputs = " ".join(p["outputs"][0]["data"]["text/plain"]) - - if outputs is not None: - dict = _literal_models._literals_pb2.LiteralMap() - _text_format.Parse(outputs, dict) - - # Add output_notebook as an output to the task. - output_notebook = _task_output.OutputReference( - _type_helpers.get_sdk_type_from_literal_type(_Types.Blob.to_flyte_literal_type()) - ) - output_notebook.set(output_notebook_path) - - output_literal_map = _literal_models.LiteralMap.from_flyte_idl(dict) - output_literal_map.literals[OUTPUT_NOTEBOOK] = output_notebook.sdk_value - - return {_constants.OUTPUT_FILE_NAME: output_literal_map} - - @property - def container(self): - """ - If not None, the target of execution should be a container. - :rtype: Container - """ - - # Find task_name - task_module = _importlib.import_module(self.instantiated_in) - for k in dir(task_module): - if getattr(task_module, k) is self: - task_name = k - break - - self._container._args = [ - "pyflyte-execute", - "--task-module", - self.instantiated_in, - "--task-name", - task_name, - "--inputs", - "{{.input}}", - "--output-prefix", - "{{.outputPrefix}}", - "--raw-output-data-prefix", - "{{.rawOutputDataPrefix}}", - ] - return self._container - - def _get_container_definition( - self, - storage_request=None, - cpu_request=None, - gpu_request=None, - memory_request=None, - storage_limit=None, - cpu_limit=None, - gpu_limit=None, - memory_limit=None, - environment=None, - **kwargs - ): - """ - :param Text storage_request: - :param Text cpu_request: - :param Text gpu_request: - :param Text memory_request: - :param Text storage_limit: - :param Text cpu_limit: - :param Text gpu_limit: - :param Text memory_limit: - :param dict[Text, Text] environment: - :rtype: flytekit.models.task.Container - """ - - storage_limit = storage_limit or storage_request - cpu_limit = cpu_limit or cpu_request - gpu_limit = gpu_limit or gpu_request - memory_limit = memory_limit or memory_request - - resources = _sdk_runnable.SdkRunnableContainer.get_resources( - storage_request, cpu_request, gpu_request, memory_request, storage_limit, cpu_limit, gpu_limit, memory_limit - ) - - return _sdk_runnable.SdkRunnableContainer( - command=[], - args=[], - resources=resources, - env=environment, - config={}, - ) - - -def spark_notebook( - notebook_path, - inputs={}, - outputs={}, - spark_conf=None, - cache_version="", - retries=0, - deprecated="", - cache=False, - timeout=None, - environment=None, -): - """ - Decorator to create a Notebook spark task. This task will connect to a Spark cluster, configure the environment, - and then execute the code within the notebook_path as the Spark driver program. 
- """ - return SdkNotebookSparkTask( - notebook_path=notebook_path, - inputs=inputs, - outputs=outputs, - spark_conf=spark_conf, - discovery_version=cache_version, - retries=retries, - deprecated=deprecated, - discoverable=cache, - timeout=timeout or _datetime.timedelta(seconds=0), - environment=environment or {}, - ) - - -def _find_instance_module(): - frame = _inspect.currentframe() - while frame: - if frame.f_code.co_name == "": - return frame.f_globals["__name__"] - frame = frame.f_back - return None - - -class SdkNotebookSparkTask(SdkNotebookTask): - - """ - This class includes the additional logic for building a task that executes Spark Notebooks. - - """ - - def __init__( - self, - notebook_path, - inputs, - outputs, - spark_conf, - discovery_version, - retries, - deprecated, - discoverable, - timeout, - environment=None, - ): - - spark_exec_path = _os.path.abspath(_entrypoint.__file__) - if spark_exec_path.endswith(".pyc"): - spark_exec_path = spark_exec_path[:-1] - - if spark_conf is None: - # Parse spark_conf from notebook if not set at task_level. - with open(notebook_path) as json_file: - data = _json.load(json_file) - for p in data["cells"]: - meta = p["metadata"] - if "tags" in meta: - if "conf" in meta["tags"]: - sc_str = " ".join(p["source"]) - ldict = {} - exec(sc_str, globals(), ldict) - spark_conf = ldict["spark_conf"] - - spark_job = _task_models.SparkJob( - spark_conf=spark_conf, - main_class="", - spark_type=_spark_type.PYTHON, - hadoop_conf={}, - application_file="local://" + spark_exec_path, - executor_path=_sys.executable, - ).to_flyte_idl() - - super(SdkNotebookSparkTask, self).__init__( - notebook_path, - inputs, - outputs, - _constants.SdkTaskType.SPARK_TASK, - discovery_version, - retries, - deprecated, - "", - "", - "", - "", - "", - "", - "", - "", - discoverable, - timeout, - environment, - _json_format.MessageToDict(spark_job), - ) - - def _get_container_definition(self, environment=None, **kwargs): - """ - :rtype: flytekit.models.task.Container - """ - - return _spark_task.SdkRunnableSparkContainer( - command=[], - args=[], - resources=_task_models.Resources(limits=[], requests=[]), - env=environment or {}, - config={}, - ) diff --git a/flytekit/models/presto.py b/flytekit/models/presto.py index 870bcce901..b5d13f7288 100644 --- a/flytekit/models/presto.py +++ b/flytekit/models/presto.py @@ -1,3 +1,8 @@ +""" +This is a deprecated module. Model files for plugins should go alongside the microlib. +See ``plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/models.py`` as an example. +""" + ## Todo - change this to qubole_presto once Luis's PR gets merged # from flyteidl.plugins import qubole_presto as _qubole from flyteidl.plugins import presto_pb2 as _presto diff --git a/flytekit/models/qubole.py b/flytekit/models/qubole.py index 2247d6e5fa..97d4c0d3b1 100644 --- a/flytekit/models/qubole.py +++ b/flytekit/models/qubole.py @@ -1,3 +1,8 @@ +""" +This is a deprecated module. Model files for plugins should go alongside the microlib. +See ``plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/models.py`` as an example. 
+""" + from flyteidl.plugins import qubole_pb2 as _qubole from flytekit.models import common as _common diff --git a/flytekit/models/sagemaker/__init__.py b/flytekit/models/sagemaker/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/flytekit/models/task.py b/flytekit/models/task.py index 0358abfe9d..c20872096d 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -6,9 +6,7 @@ from flyteidl.core import compiler_pb2 as _compiler from flyteidl.core import literals_pb2 as _literals_pb2 from flyteidl.core import tasks_pb2 as _core_task -from flyteidl.plugins import pytorch_pb2 as _pytorch_task from flyteidl.plugins import spark_pb2 as _spark_task -from flyteidl.plugins import tensorflow_pb2 as _tensorflow_task from google.protobuf import json_format as _json_format from google.protobuf import struct_pb2 as _struct @@ -605,6 +603,11 @@ def from_flyte_idl(cls, pb2_object): class SparkJob(_common.FlyteIdlEntity): + """ + This model is deprecated and will be removed in 1.0.0. Please use the definition in the + flytekit spark plugin instead. + """ + def __init__( self, spark_type, @@ -1096,55 +1099,3 @@ def from_flyte_idl(cls, pb2_object): annotations=pb2_object.annotations, labels=pb2_object.labels, ) - - -class PyTorchJob(_common.FlyteIdlEntity): - def __init__(self, workers_count): - self._workers_count = workers_count - - @property - def workers_count(self): - return self._workers_count - - def to_flyte_idl(self): - return _pytorch_task.DistributedPyTorchTrainingTask( - workers=self.workers_count, - ) - - @classmethod - def from_flyte_idl(cls, pb2_object): - return cls( - workers_count=pb2_object.workers, - ) - - -class TensorFlowJob(_common.FlyteIdlEntity): - def __init__(self, workers_count, ps_replicas_count, chief_replicas_count): - self._workers_count = workers_count - self._ps_replicas_count = ps_replicas_count - self._chief_replicas_count = chief_replicas_count - - @property - def workers_count(self): - return self._workers_count - - @property - def ps_replicas_count(self): - return self._ps_replicas_count - - @property - def chief_replicas_count(self): - return self._chief_replicas_count - - def to_flyte_idl(self): - return _tensorflow_task.DistributedTensorflowTrainingTask( - workers=self.workers_count, ps_replicas=self.ps_replicas_count, chief_replicas=self.chief_replicas_count - ) - - @classmethod - def from_flyte_idl(cls, pb2_object): - return cls( - workers_count=pb2_object.workers, - ps_replicas_count=pb2_object.ps_replicas, - chief_replicas_count=pb2_object.chief_replicas, - ) diff --git a/flytekit/plugins/__init__.py b/flytekit/plugins/__init__.py index e20bf5eb36..61333933f1 100644 --- a/flytekit/plugins/__init__.py +++ b/flytekit/plugins/__init__.py @@ -19,10 +19,6 @@ hmsclient = _lazy_loader.lazy_load_module("hmsclient") # type: _lazy_loader._LazyLoadModule type(hmsclient).add_sub_module("genthrift.hive_metastore.ttypes") -sagemaker_training = _lazy_loader.lazy_load_module("sagemaker_training") # type: _lazy_loader._LazyLoadModule - -papermill = _lazy_loader.lazy_load_module("papermill") # type: _lazy_loader._LazyLoadModule - _lazy_loader.LazyLoadPlugin("spark", ["pyspark>=2.4.0,<3.0.0"], [pyspark]) _lazy_loader.LazyLoadPlugin("spark3", ["pyspark>=3.0.0"], [pyspark]) @@ -36,7 +32,3 @@ ) _lazy_loader.LazyLoadPlugin("hive_sensor", ["hmsclient>=0.0.1,<1.0.0"], [hmsclient]) - -_lazy_loader.LazyLoadPlugin("sagemaker", ["sagemaker-training>=3.6.2,<4.0.0"], [sagemaker_training]) - -_lazy_loader.LazyLoadPlugin("papermill", 
["papermill>=2.0.0,<3.0.0"], [papermill]) diff --git a/flytekit/sdk/sagemaker/__init__.py b/flytekit/sdk/sagemaker/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/flytekit/sdk/sagemaker/task.py b/flytekit/sdk/sagemaker/task.py deleted file mode 100644 index cde721467b..0000000000 --- a/flytekit/sdk/sagemaker/task.py +++ /dev/null @@ -1,153 +0,0 @@ -import datetime as _datetime -import typing - -from flytekit.common.tasks.sagemaker.custom_training_job_task import CustomTrainingJobTask -from flytekit.common.tasks.sagemaker.distributed_training import DefaultOutputPersistPredicate -from flytekit.models.sagemaker import training_job as _training_job_models - - -def custom_training_job_task( - _task_function=None, - algorithm_specification: _training_job_models.AlgorithmSpecification = None, - training_job_resource_config: _training_job_models.TrainingJobResourceConfig = None, - cache_version: str = "", - retries: int = 0, - deprecated: str = "", - storage_request: str = None, - cpu_request: str = None, - gpu_request: str = None, - memory_request: str = None, - storage_limit: str = None, - cpu_limit: str = None, - gpu_limit: str = None, - memory_limit: str = None, - cache: bool = False, - timeout: _datetime.timedelta = None, - environment: typing.Dict[str, str] = None, - cls: typing.Type = None, - output_persist_predicate: typing.Callable = DefaultOutputPersistPredicate(), -): - """ - Decorator to create a Custom Training Job definition. This task will run as a single unit of work on the platform. - - .. code-block:: python - - @inputs(int_list=[Types.Integer]) - @outputs(sum_of_list=Types.Integer - @custom_task - def my_task(wf_params, int_list, sum_of_list): - sum_of_list.set(sum(int_list)) - - :param _task_function: this is the decorated method and shouldn't be declared explicitly. The function must - take a first argument, and then named arguments matching those defined in @inputs and @outputs. No keyword - arguments are allowed for wrapped task functions. - - :param _training_job_models.AlgorithmSpecification algorithm_specification: This represents the algorithm specification - - :param _training_job_models.TrainingJobResourceConfig training_job_resource_config: This represents the training job config. - - :param Text cache_version: [optional] string representing logical version for discovery. This field should be - updated whenever the underlying algorithm changes. - - .. note:: - - This argument is required to be a non-empty string if `cache` is True. - - :param int retries: [optional] integer determining number of times task can be retried on - :py:exc:`flytekit.sdk.exceptions.RecoverableException` or transient platform failures. Defaults - to 0. - - .. note:: - - If retries > 0, the task must be able to recover from any remote state created within the user code. It is - strongly recommended that tasks are written to be idempotent. - - :param Text deprecated: [optional] string that should be provided if this task is deprecated. The string - will be logged as a warning so it should contain information regarding how to update to a newer task. - - :param Text storage_request: [optional] Kubernetes resource string for lower-bound of disk storage space - for the task to run. Default is set by platform-level configuration. - - .. note:: - - This is currently not supported by the platform. - - :param Text cpu_request: [optional] Kubernetes resource string for lower-bound of cores for the task to execute. - This can be set to a fractional portion of a CPU. 
Default is set by platform-level configuration. - - TODO: Add links to resource string documentation for Kubernetes - - :param Text gpu_request: [optional] Kubernetes resource string for lower-bound of desired GPUs. - Default is set by platform-level configuration. - - TODO: Add links to resource string documentation for Kubernetes - - :param Text memory_request: [optional] Kubernetes resource string for lower-bound of physical memory - necessary for the task to execute. Default is set by platform-level configuration. - - TODO: Add links to resource string documentation for Kubernetes - - :param Text storage_limit: [optional] Kubernetes resource string for upper-bound of disk storage space - for the task to run. This amount is not guaranteed! If not specified, it is set equal to storage_request. - - .. note:: - - This is currently not supported by the platform. - - :param Text cpu_limit: [optional] Kubernetes resource string for upper-bound of cores for the task to execute. - This can be set to a fractional portion of a CPU. This amount is not guaranteed! If not specified, - it is set equal to cpu_request. - - :param Text gpu_limit: [optional] Kubernetes resource string for upper-bound of desired GPUs. This amount is not - guaranteed! If not specified, it is set equal to gpu_request. - - :param Text memory_limit: [optional] Kubernetes resource string for upper-bound of physical memory - necessary for the task to execute. This amount is not guaranteed! If not specified, it is set equal to - memory_request. - - :param bool cache: [optional] boolean describing if the outputs of this task should be cached and - re-usable. - - :param datetime.timedelta timeout: [optional] describes how long the task should be allowed to - run at max before triggering a retry (if retries are enabled). By default, tasks are allowed to run - indefinitely. If a null timedelta is passed (i.e. timedelta(seconds=0)), the task will not timeout. - - :param dict[Text,Text] environment: [optional] environment variables to set when executing this task. - - :param cls: This can be used to override the task implementation with a user-defined extension. The class - provided must be a subclass of flytekit.common.tasks.sdk_runnable.SdkRunnableTask. A user can use this to - inject bespoke logic into the base Flyte programming model. - - :param Callable output_persist_predicate: [optional] This callable should return a boolean and is used to indicate whether - the current copy (i.e., an instance of the task running on a particular node inside the worker pool) would - write output. 
- - :rtype: flytekit.common.tasks.sagemaker.custom_training_job_task.CustomTrainingJobTask - """ - - def wrapper(fn): - return (cls or CustomTrainingJobTask)( - task_function=fn, - cache_version=cache_version, - retries=retries, - deprecated=deprecated, - storage_request=storage_request, - cpu_request=cpu_request, - gpu_request=gpu_request, - memory_request=memory_request, - storage_limit=storage_limit, - cpu_limit=cpu_limit, - gpu_limit=gpu_limit, - memory_limit=memory_limit, - cache=cache, - timeout=timeout or _datetime.timedelta(seconds=0), - environment=environment, - algorithm_specification=algorithm_specification, - training_job_resource_config=training_job_resource_config, - output_persist_predicate=output_persist_predicate, - ) - - if _task_function: - return wrapper(_task_function) - else: - return wrapper diff --git a/flytekit/sdk/tasks.py b/flytekit/sdk/tasks.py index 06845fc936..dd43055077 100644 --- a/flytekit/sdk/tasks.py +++ b/flytekit/sdk/tasks.py @@ -6,15 +6,12 @@ from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.tasks import generic_spark_task as _sdk_generic_spark_task from flytekit.common.tasks import hive_task as _sdk_hive_tasks -from flytekit.common.tasks import pytorch_task as _sdk_pytorch_tasks from flytekit.common.tasks import sdk_dynamic as _sdk_dynamic from flytekit.common.tasks import sdk_runnable as _sdk_runnable_tasks from flytekit.common.tasks import sidecar_task as _sdk_sidecar_tasks from flytekit.common.tasks import spark_task as _sdk_spark_tasks from flytekit.common.tasks import task as _task -from flytekit.common.tasks import tensorflow_task as _sdk_tensorflow_tasks from flytekit.common.types import helpers as _type_helpers -from flytekit.contrib.notebook import tasks as _nb_tasks from flytekit.models import interface as _interface_model from flytekit.sdk.spark_types import SparkType as _spark_type @@ -1206,322 +1203,3 @@ def wrapper(fn): return wrapper(_task_function) else: return wrapper - - -def pytorch_task( - _task_function=None, - cache_version="", - retries=0, - interruptible=False, - deprecated="", - cache=False, - timeout=None, - workers_count=1, - per_replica_storage_request="", - per_replica_cpu_request="", - per_replica_gpu_request="", - per_replica_memory_request="", - per_replica_storage_limit="", - per_replica_cpu_limit="", - per_replica_gpu_limit="", - per_replica_memory_limit="", - environment=None, - cls=None, -): - """ - Decorator to create a Pytorch Task definition. This task will submit PyTorchJob (see https://github.com/kubeflow/pytorch-operator) - defined by the code within the _task_function to k8s cluster. - - .. code-block:: python - - @inputs(int_list=[Types.Integer]) - @outputs(result=Types.Integer - @pytorch_task( - workers_count=2, - per_replica_cpu_request="500m", - per_replica_memory_request="4Gi", - per_replica_memory_limit="8Gi", - per_replica_gpu_limit="1", - ) - def my_pytorch_job(wf_params, int_list, result): - pass - - :param _task_function: this is the decorated method and shouldn't be declared explicitly. The function must - take a first argument, and then named arguments matching those defined in @inputs and @outputs. No keyword - arguments are allowed for wrapped task functions. - - :param Text cache_version: [optional] string representing logical version for discovery. This field should be - updated whenever the underlying algorithm changes. - - .. note:: - - This argument is required to be a non-empty string if `cache` is True. 
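The removed SDK decorators above all share the same optional-argument decorator shape: `_task_function` is filled in when the decorator is applied bare, and left `None` when it is called with keyword arguments, in which case the inner `wrapper` is returned instead. A minimal, self-contained sketch of that pattern, with illustrative names that are not part of flytekit:

.. code-block:: python

    # Minimal sketch of the optional-argument decorator pattern used by the removed
    # SDK task decorators; `my_task_decorator` and `greet` are illustrative names only.
    import functools


    def my_task_decorator(_task_function=None, retries=0):
        def wrapper(fn):
            @functools.wraps(fn)
            def inner(*args, **kwargs):
                # `retries` is available here via the closure, whether it came from an
                # explicit decorator call or from the default value.
                return fn(*args, **kwargs)

            return inner

        if _task_function:
            # Used bare: @my_task_decorator
            return wrapper(_task_function)
        # Used with arguments: @my_task_decorator(retries=2)
        return wrapper


    @my_task_decorator
    def greet(name):
        return f"hello {name}"


    @my_task_decorator(retries=2)
    def greet_again(name):
        return f"hello again {name}"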
- - :param int retries: [optional] integer determining number of times task can be retried on - :py:exc:`flytekit.sdk.exceptions.RecoverableException` or transient platform failures. Defaults - to 0. - - .. note:: - - If retries > 0, the task must be able to recover from any remote state created within the user code. It is - strongly recommended that tasks are written to be idempotent. - - :param bool interruptible: [optional] boolean describing if the task is interruptible. - - :param Text deprecated: [optional] string that should be provided if this task is deprecated. The string - will be logged as a warning so it should contain information regarding how to update to a newer task. - - :param bool cache: [optional] boolean describing if the outputs of this task should be cached and - re-usable. - - :param datetime.timedelta timeout: [optional] describes how long the task should be allowed to - run at max before triggering a retry (if retries are enabled). By default, tasks are allowed to run - indefinitely. If a null timedelta is passed (i.e. timedelta(seconds=0)), the task will not timeout. - - :param int workers_count: integer determining the number of worker replicas spawned in the cluster for this job - (in addition to 1 master). - - :param Text per_replica_storage_request: [optional] Kubernetes resource string for lower-bound of disk storage space - for each replica spawned for this job (i.e. both for master and workers). Default is set by platform-level configuration. - - .. note:: - - This is currently not supported by the platform. - - :param Text per_replica_cpu_request: [optional] Kubernetes resource string for lower-bound of cores for each replica - spawned for this job (i.e. both for master and workers). - This can be set to a fractional portion of a CPU. Default is set by platform-level configuration. - - TODO: Add links to resource string documentation for Kubernetes - - :param Text per_replica_gpu_request: [optional] Kubernetes resource string for lower-bound of desired GPUs for each - replica spawned for this job (i.e. both for master and workers). - Default is set by platform-level configuration. - - TODO: Add links to resource string documentation for Kubernetes - - :param Text per_replica_memory_request: [optional] Kubernetes resource string for lower-bound of physical memory - necessary for each replica spawned for this job (i.e. both for master and workers). Default is set by platform-level configuration. - - TODO: Add links to resource string documentation for Kubernetes - - :param Text per_replica_storage_limit: [optional] Kubernetes resource string for upper-bound of disk storage space - for each replica spawned for this job (i.e. both for master and workers). - This amount is not guaranteed! If not specified, it is set equal to storage_request. - - .. note:: - - This is currently not supported by the platform. - - :param Text per_replica_cpu_limit: [optional] Kubernetes resource string for upper-bound of cores for each replica - spawned for this job (i.e. both for master and workers). - This can be set to a fractional portion of a CPU. This amount is not guaranteed! If not specified, - it is set equal to cpu_request. - - :param Text per_replica_gpu_limit: [optional] Kubernetes resource string for upper-bound of desired GPUs for each - replica spawned for this job (i.e. both for master and workers). - This amount is not guaranteed! If not specified, it is set equal to gpu_request. 
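For reference, the replacement path for the removed `@pytorch_task` decorator is the `flytekit-kf-pytorch` plugin touched later in this patch. A hedged sketch of its usage, assuming `PyTorch` is importable from `flytekitplugins.kfpytorch` and exposes the `num_workers` field referenced in that plugin's `task.py`:

.. code-block:: python

    # Assumed plugin-style replacement for the removed @pytorch_task decorator.
    # PyTorch(num_workers=...) mirrors the task_config referenced in kfpytorch/task.py.
    from flytekit import task
    from flytekitplugins.kfpytorch import PyTorch


    @task(task_config=PyTorch(num_workers=2))
    def train(epochs: int) -> float:
        # A real task would launch distributed PyTorch training here.
        return 0.0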
- - :param Text per_replica_memory_limit: [optional] Kubernetes resource string for upper-bound of physical memory - necessary for each replica spawned for this job (i.e. both for master and workers). - This amount is not guaranteed! If not specified, it is set equal to memory_request. - - :param dict[Text,Text] environment: [optional] environment variables to set when executing this task. - - :param cls: This can be used to override the task implementation with a user-defined extension. The class - provided must be a subclass of flytekit.common.tasks.sdk_runnable.SdkRunnableTask. A user can use this to - inject bespoke logic into the base Flyte programming model. - - :rtype: flytekit.common.tasks.sdk_runnable.SdkRunnableTask - """ - - def wrapper(fn): - return (cls or _sdk_pytorch_tasks.SdkPyTorchTask)( - task_function=fn, - task_type=_common_constants.SdkTaskType.PYTORCH_TASK, - discovery_version=cache_version, - retries=retries, - interruptible=interruptible, - deprecated=deprecated, - discoverable=cache, - timeout=timeout or _datetime.timedelta(seconds=0), - workers_count=workers_count, - per_replica_storage_request=per_replica_storage_request, - per_replica_cpu_request=per_replica_cpu_request, - per_replica_gpu_request=per_replica_gpu_request, - per_replica_memory_request=per_replica_memory_request, - per_replica_storage_limit=per_replica_storage_limit, - per_replica_cpu_limit=per_replica_cpu_limit, - per_replica_gpu_limit=per_replica_gpu_limit, - per_replica_memory_limit=per_replica_memory_limit, - environment=environment or {}, - ) - - if _task_function: - return wrapper(_task_function) - else: - return wrapper - - -def tensorflow_task( - _task_function=None, - cache_version="", - retries=0, - interruptible=False, - deprecated="", - cache=False, - timeout=None, - workers_count=1, - ps_replicas_count=None, - chief_replicas_count=None, - per_replica_storage_request="", - per_replica_cpu_request="", - per_replica_gpu_request="", - per_replica_memory_request="", - per_replica_storage_limit="", - per_replica_cpu_limit="", - per_replica_gpu_limit="", - per_replica_memory_limit="", - environment=None, - cls=None, -): - """ - Decorator to create a Tensorflow Task definition. This task will submit TFJob (see https://github.com/kubeflow/tf-operator) - defined by the code within the _task_function to k8s cluster. - - .. code-block:: python - - @inputs(int_list=[Types.Integer]) - @outputs(result=Types.Integer - @tensorflow_task( - workers_count=2, - ps_replicas_count=1, - chief_replicas_count=1, - per_replica_cpu_request="500m", - per_replica_memory_request="4Gi", - per_replica_memory_limit="8Gi", - per_replica_gpu_limit="1", - ) - def my_tensorflow_job(wf_params, int_list, result): - pass - - :param _task_function: this is the decorated method and shouldn't be declared explicitly. The function must - take a first argument, and then named arguments matching those defined in @inputs and @outputs. No keyword - arguments are allowed for wrapped task functions. - - :param Text cache_version: [optional] string representing logical version for discovery. This field should be - updated whenever the underlying algorithm changes. - - .. note:: - - This argument is required to be a non-empty string if `cache` is True. - - :param int retries: [optional] integer determining number of times task can be retried on - :py:exc:`flytekit.sdk.exceptions.RecoverableException` or transient platform failures. Defaults - to 0. - - .. 
note:: - - If retries > 0, the task must be able to recover from any remote state created within the user code. It is - strongly recommended that tasks are written to be idempotent. - - :param bool interruptible: [optional] boolean describing if the task is interruptible. - - :param Text deprecated: [optional] string that should be provided if this task is deprecated. The string - will be logged as a warning so it should contain information regarding how to update to a newer task. - - :param bool cache: [optional] boolean describing if the outputs of this task should be cached and - re-usable. - - :param datetime.timedelta timeout: [optional] describes how long the task should be allowed to - run at max before triggering a retry (if retries are enabled). By default, tasks are allowed to run - indefinitely. If a null timedelta is passed (i.e. timedelta(seconds=0)), the task will not timeout. - - :param int workers_count: integer determining the number of worker replicas spawned in the cluster for this job - - :param int ps_replicas_count: integer determining the number of parameter server replicas spawned in the cluster for this job - - :param int chief_replicas_count: integer determining the number of chief server replicas spawned in the cluster for this job - - :param Text per_replica_storage_request: [optional] Kubernetes resource string for lower-bound of disk storage space - for each replica spawned for this job (i.e. both for parameter, chief server and workers). Default is set by platform-level configuration. - - .. note:: - - This is currently not supported by the platform. - - :param Text per_replica_cpu_request: [optional] Kubernetes resource string for lower-bound of cores for each replica - spawned for this job (i.e. both for parameter, chief server and workers). - This can be set to a fractional portion of a CPU. Default is set by platform-level configuration. - - TODO: Add links to resource string documentation for Kubernetes - - :param Text per_replica_gpu_request: [optional] Kubernetes resource string for lower-bound of desired GPUs for each - replica spawned for this job (i.e. both for parameter, chief server and workers). - Default is set by platform-level configuration. - - TODO: Add links to resource string documentation for Kubernetes - - :param Text per_replica_memory_request: [optional] Kubernetes resource string for lower-bound of physical memory - necessary for each replica spawned for this job (i.e. both for parameter, chief server and workers). Default is set by platform-level configuration. - - TODO: Add links to resource string documentation for Kubernetes - - :param Text per_replica_storage_limit: [optional] Kubernetes resource string for upper-bound of disk storage space - for each replica spawned for this job (i.e. both for parameter, chief server and workers). - This amount is not guaranteed! If not specified, it is set equal to storage_request. - - .. note:: - - This is currently not supported by the platform. - - :param Text per_replica_cpu_limit: [optional] Kubernetes resource string for upper-bound of cores for each replica - spawned for this job (i.e. both for parameter, chief server and workers). - This can be set to a fractional portion of a CPU. This amount is not guaranteed! If not specified, - it is set equal to cpu_request. - - :param Text per_replica_gpu_limit: [optional] Kubernetes resource string for upper-bound of desired GPUs for each - replica spawned for this job (i.e. both for parameter, chief server and workers). 
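Likewise, the removed `@tensorflow_task` decorator corresponds to the `flytekit-kf-tensorflow` plugin updated later in this patch. A hedged sketch, assuming `TfJob` is importable from `flytekitplugins.kftensorflow` with the replica-count fields referenced in its `task.py`:

.. code-block:: python

    # Assumed plugin-style replacement for the removed @tensorflow_task decorator.
    # The three counts mirror the task_config fields referenced in kftensorflow/task.py.
    from flytekit import task
    from flytekitplugins.kftensorflow import TfJob


    @task(task_config=TfJob(num_workers=2, num_ps_replicas=1, num_chief_replicas=1))
    def train(epochs: int) -> float:
        # A real task would run distributed TensorFlow training here.
        return 0.0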
- This amount is not guaranteed! If not specified, it is set equal to gpu_request. - - :param Text per_replica_memory_limit: [optional] Kubernetes resource string for upper-bound of physical memory - necessary for each replica spawned for this job (i.e. both for parameter, chief server and workers). - This amount is not guaranteed! If not specified, it is set equal to memory_request. - - :param dict[Text,Text] environment: [optional] environment variables to set when executing this task. - - :param cls: This can be used to override the task implementation with a user-defined extension. The class - provided must be a subclass of flytekit.common.tasks.sdk_runnable.SdkRunnableTask. A user can use this to - inject bespoke logic into the base Flyte programming model. - - :rtype: flytekit.common.tasks.sdk_runnable.SdkRunnableTask - """ - - def wrapper(fn): - return (cls or _sdk_tensorflow_tasks.SdkTensorFlowTask)( - task_function=fn, - task_type=_common_constants.SdkTaskType.TENSORFLOW_TASK, - cache_version=cache_version, - retries=retries, - interruptible=interruptible, - deprecated=deprecated, - cache=cache, - timeout=timeout or _datetime.timedelta(seconds=0), - workers_count=workers_count, - ps_replicas_count=ps_replicas_count, - chief_replicas_count=chief_replicas_count, - per_replica_storage_request=per_replica_storage_request, - per_replica_cpu_request=per_replica_cpu_request, - per_replica_gpu_request=per_replica_gpu_request, - per_replica_memory_request=per_replica_memory_request, - per_replica_storage_limit=per_replica_storage_limit, - per_replica_cpu_limit=per_replica_cpu_limit, - per_replica_gpu_limit=per_replica_gpu_limit, - per_replica_memory_limit=per_replica_memory_limit, - environment=environment or {}, - ) - - if _task_function: - return wrapper(_task_function) - else: - return wrapper diff --git a/flytekit_scripts/flytekit_sagemaker_runner.py b/flytekit_scripts/flytekit_sagemaker_runner.py deleted file mode 100644 index d4989e611e..0000000000 --- a/flytekit_scripts/flytekit_sagemaker_runner.py +++ /dev/null @@ -1,94 +0,0 @@ -import argparse -import logging -import os -import subprocess -import sys - -from natsort import natsorted - -FLYTE_ARG_PREFIX = "--__FLYTE" -FLYTE_ENV_VAR_PREFIX = f"{FLYTE_ARG_PREFIX}_ENV_VAR_" -FLYTE_CMD_PREFIX = f"{FLYTE_ARG_PREFIX}_CMD_" -FLYTE_ARG_SUFFIX = "__" - - -# This script is the "entrypoint" script for SageMaker. An environment variable must be set on the container (typically -# in the Dockerfile) of SAGEMAKER_PROGRAM=flytekit_sagemaker_runner.py. When the container is launched in SageMaker, -# it'll run `train flytekit_sagemaker_runner.py `, the responsibility of this script is then to decode -# the known hyperparameters (passed as command line args) to recreate the original command that will actually run the -# virtual environment and execute the intended task (e.g. 
`service_venv pyflyte-execute --task-module ....`) - -# An example for a valid command: -# python flytekit_sagemaker_runner.py --__FLYTE_ENV_VAR_env1__ val1 --__FLYTE_ENV_VAR_env2__ val2 -# --__FLYTE_CMD_0_service_venv__ __FLYTE_CMD_DUMMY_VALUE__ -# --__FLYTE_CMD_1_pyflyte-execute__ __FLYTE_CMD_DUMMY_VALUE__ -# --__FLYTE_CMD_2_--task-module__ __FLYTE_CMD_DUMMY_VALUE__ -# --__FLYTE_CMD_3_blah__ __FLYTE_CMD_DUMMY_VALUE__ -# --__FLYTE_CMD_4_--task-name__ __FLYTE_CMD_DUMMY_VALUE__ -# --__FLYTE_CMD_5_bloh__ __FLYTE_CMD_DUMMY_VALUE__ -# --__FLYTE_CMD_6_--output-prefix__ __FLYTE_CMD_DUMMY_VALUE__ -# --__FLYTE_CMD_7_s3://fake-bucket__ __FLYTE_CMD_DUMMY_VALUE__ -# --__FLYTE_CMD_8_--inputs__ __FLYTE_CMD_DUMMY_VALUE__ -# --__FLYTE_CMD_9_s3://fake-bucket__ __FLYTE_CMD_DUMMY_VALUE__ - - -def parse_args(cli_args): - parser = argparse.ArgumentParser(description="Running sagemaker task") - args, unknowns = parser.parse_known_args(cli_args) - - # Parse the command line and env vars - flyte_cmd = [] - env_vars = {} - i = 0 - - while i < len(unknowns): - unknown = unknowns[i] - logging.info(f"Processing argument {unknown}") - if unknown.startswith(FLYTE_CMD_PREFIX) and unknown.endswith(FLYTE_ARG_SUFFIX): - processed = unknown[len(FLYTE_CMD_PREFIX) :][: -len(FLYTE_ARG_SUFFIX)] - # Parse the format `1_--task-module` - parts = processed.split("_", maxsplit=1) - flyte_cmd.append((parts[0], parts[1])) - i += 1 - elif unknown.startswith(FLYTE_ENV_VAR_PREFIX) and unknown.endswith(FLYTE_ARG_SUFFIX): - processed = unknown[len(FLYTE_ENV_VAR_PREFIX) :][: -len(FLYTE_ARG_SUFFIX)] - i += 1 - if unknowns[i].startswith(FLYTE_ARG_PREFIX) is False: - env_vars[processed] = unknowns[i] - i += 1 - else: - # To prevent SageMaker from ignoring our __FLYTE_CMD_*__ hyperparameters, we need to set a dummy value - # which serves as a placeholder for each of them. 
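To make the decoding above concrete, here is a small self-contained sketch of the same prefix/suffix stripping and natural-sort reordering; the argument values are invented for illustration and are not taken from a real SageMaker job:

.. code-block:: python

    # Sketch of the hyperparameter decoding done by the removed runner script.
    from natsort import natsorted

    raw_args = [
        "--__FLYTE_CMD_0_pyflyte-execute__", "__FLYTE_CMD_DUMMY_VALUE__",
        "--__FLYTE_CMD_1_--inputs__", "__FLYTE_CMD_DUMMY_VALUE__",
        "--__FLYTE_CMD_10_s3://fake-bucket__", "__FLYTE_CMD_DUMMY_VALUE__",
    ]

    indexed_cmd = []
    for arg in raw_args:
        if arg.startswith("--__FLYTE_CMD_") and arg.endswith("__"):
            # Strip the wrapper, then split "<index>_<token>" once on "_"
            index, token = arg[len("--__FLYTE_CMD_"):-len("__")].split("_", maxsplit=1)
            indexed_cmd.append((index, token))

    # natsorted keeps index "10" after "1"; a plain lexicographic sort would not.
    cmd = [token for _, token in natsorted(indexed_cmd, key=lambda pair: pair[0])]
    print(cmd)  # ['pyflyte-execute', '--inputs', 's3://fake-bucket']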
The dummy value placeholder `__FLYTE_CMD_DUMMY_VALUE__` - # falls into this branch and will be ignored - i += 1 - - return flyte_cmd, env_vars - - -def sort_flyte_cmd(flyte_cmd): - # Order the cmd using the index (the first element in each tuple) - flyte_cmd = natsorted(flyte_cmd, key=lambda x: x[0]) - flyte_cmd = [x[1] for x in flyte_cmd] - return flyte_cmd - - -def set_env_vars(env_vars): - for key, val in env_vars.items(): - os.environ[key] = val - - -def run(cli_args): - flyte_cmd, env_vars = parse_args(cli_args) - flyte_cmd = sort_flyte_cmd(flyte_cmd) - set_env_vars(env_vars) - - logging.info(f"Cmd:{flyte_cmd}") - logging.info(f"Env vars:{env_vars}") - - # Launching a subprocess with the selected entrypoint script and the rest of the arguments - logging.info(f"Launching command: {flyte_cmd}") - subprocess.run(flyte_cmd, stdout=sys.stdout, stderr=sys.stderr, encoding="utf-8", check=True) - - -if __name__ == "__main__": - run(sys.argv) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py index c3a705f48e..03e1e38b0d 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py @@ -1,18 +1,18 @@ -from flytekit.models.sagemaker.hpo_job import ( +from flytekitplugins.awssagemaker.models.hpo_job import ( HyperparameterTuningJobConfig, HyperparameterTuningObjective, HyperparameterTuningObjectiveType, HyperparameterTuningStrategy, TrainingJobEarlyStoppingType, ) -from flytekit.models.sagemaker.parameter_ranges import ( +from flytekitplugins.awssagemaker.models.parameter_ranges import ( CategoricalParameterRange, ContinuousParameterRange, HyperparameterScalingType, IntegerParameterRange, ParameterRangeOneOf, ) -from flytekit.models.sagemaker.training_job import ( +from flytekitplugins.awssagemaker.models.training_job import ( AlgorithmName, AlgorithmSpecification, DistributedProtocol, diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/hpo.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/hpo.py index cdf144d23e..629c8c588d 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/hpo.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/hpo.py @@ -12,11 +12,12 @@ from flytekit.common.types import primitives from flytekit.extend import DictTransformer, PythonTask, SerializationSettings, TypeEngine, TypeTransformer from flytekit.models.literals import Literal -from flytekit.models.sagemaker import hpo_job as _hpo_job_model -from flytekit.models.sagemaker import parameter_ranges as _params -from flytekit.models.sagemaker import training_job as _training_job_model from flytekit.models.types import LiteralType +from .models import hpo_job as _hpo_job_model +from .models import parameter_ranges as _params +from .models import training_job as _training_job_model + @dataclass class HPOJob(object): diff --git a/flytekit/common/tasks/sagemaker/__init__.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/__init__.py similarity index 100% rename from flytekit/common/tasks/sagemaker/__init__.py rename to plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/__init__.py diff --git a/flytekit/models/sagemaker/hpo_job.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/hpo_job.py similarity index 99% rename from flytekit/models/sagemaker/hpo_job.py rename to 
plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/hpo_job.py index ea484f26cf..16b11c4bf1 100644 --- a/flytekit/models/sagemaker/hpo_job.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/hpo_job.py @@ -1,7 +1,8 @@ from flyteidl.plugins.sagemaker import hyperparameter_tuning_job_pb2 as _pb2_hpo_job from flytekit.models import common as _common -from flytekit.models.sagemaker import training_job as _training_job + +from . import training_job as _training_job class HyperparameterTuningObjectiveType(object): diff --git a/flytekit/models/sagemaker/parameter_ranges.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/parameter_ranges.py similarity index 100% rename from flytekit/models/sagemaker/parameter_ranges.py rename to plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/parameter_ranges.py diff --git a/flytekit/models/sagemaker/training_job.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/training_job.py similarity index 100% rename from flytekit/models/sagemaker/training_job.py rename to plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/training_job.py diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/training.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/training.py index df28d4e9e5..557165259a 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/training.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/training.py @@ -9,10 +9,11 @@ import flytekit from flytekit import ExecutionParameters, FlyteContextManager, PythonFunctionTask, kwtypes from flytekit.extend import ExecutionState, IgnoreOutputs, Interface, PythonTask, SerializationSettings, TaskPlugins -from flytekit.models.sagemaker import training_job as _training_job_models from flytekit.types.directory.types import FlyteDirectory from flytekit.types.file import FlyteFile +from .models import training_job as _training_job_models + @dataclass class SagemakerTrainingJobConfig(object): diff --git a/plugins/flytekit-aws-sagemaker/tests/test_hpo.py b/plugins/flytekit-aws-sagemaker/tests/test_hpo.py index 65cbac0343..28a226e696 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_hpo.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_hpo.py @@ -9,18 +9,22 @@ ParameterRangesTransformer, SagemakerHPOTask, ) -from flytekitplugins.awssagemaker.training import SagemakerBuiltinAlgorithmsTask, SagemakerTrainingJobConfig - -from flytekit import FlyteContext -from flytekit.common.types.primitives import Generic -from flytekit.models.sagemaker.hpo_job import ( +from flytekitplugins.awssagemaker.models.hpo_job import ( HyperparameterTuningJobConfig, HyperparameterTuningObjective, HyperparameterTuningObjectiveType, TrainingJobEarlyStoppingType, ) -from flytekit.models.sagemaker.parameter_ranges import IntegerParameterRange, ParameterRangeOneOf -from flytekit.models.sagemaker.training_job import AlgorithmName, AlgorithmSpecification, TrainingJobResourceConfig +from flytekitplugins.awssagemaker.models.parameter_ranges import IntegerParameterRange, ParameterRangeOneOf +from flytekitplugins.awssagemaker.models.training_job import ( + AlgorithmName, + AlgorithmSpecification, + TrainingJobResourceConfig, +) +from flytekitplugins.awssagemaker.training import SagemakerBuiltinAlgorithmsTask, SagemakerTrainingJobConfig + +from flytekit import FlyteContext +from flytekit.common.types.primitives import Generic from .test_training import _get_reg_settings diff --git 
a/tests/flytekit/unit/models/sagemaker/test_hpo_job.py b/plugins/flytekit-aws-sagemaker/tests/test_hpo_job.py similarity index 97% rename from tests/flytekit/unit/models/sagemaker/test_hpo_job.py rename to plugins/flytekit-aws-sagemaker/tests/test_hpo_job.py index 4b38672300..494eecd2ab 100644 --- a/tests/flytekit/unit/models/sagemaker/test_hpo_job.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_hpo_job.py @@ -1,4 +1,4 @@ -from flytekit.models.sagemaker import hpo_job, training_job +from flytekitplugins.awssagemaker.models import hpo_job, training_job def test_hyperparameter_tuning_objective(): diff --git a/tests/flytekit/unit/models/sagemaker/test_parameter_ranges.py b/plugins/flytekit-aws-sagemaker/tests/test_parameter_ranges.py similarity index 98% rename from tests/flytekit/unit/models/sagemaker/test_parameter_ranges.py rename to plugins/flytekit-aws-sagemaker/tests/test_parameter_ranges.py index e10899bdb7..6d33388c33 100644 --- a/tests/flytekit/unit/models/sagemaker/test_parameter_ranges.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_parameter_ranges.py @@ -1,8 +1,7 @@ import unittest import pytest - -from flytekit.models.sagemaker import parameter_ranges +from flytekitplugins.awssagemaker.models import parameter_ranges # assert statements cannot be written inside lambda expressions. This is a convenient function to work around that. diff --git a/plugins/flytekit-aws-sagemaker/tests/test_training.py b/plugins/flytekit-aws-sagemaker/tests/test_training.py index 84c596cd64..3e6ada1ff5 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_training.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_training.py @@ -4,18 +4,18 @@ import pytest from flytekitplugins.awssagemaker.distributed_training import setup_envars_for_testing +from flytekitplugins.awssagemaker.models.training_job import ( + AlgorithmName, + AlgorithmSpecification, + DistributedProtocol, + TrainingJobResourceConfig, +) from flytekitplugins.awssagemaker.training import SagemakerBuiltinAlgorithmsTask, SagemakerTrainingJobConfig import flytekit from flytekit import task from flytekit.common.tasks.sdk_runnable import ExecutionParameters from flytekit.extend import Image, ImageConfig, SerializationSettings -from flytekit.models.sagemaker.training_job import ( - AlgorithmName, - AlgorithmSpecification, - DistributedProtocol, - TrainingJobResourceConfig, -) def _get_reg_settings(): diff --git a/tests/flytekit/unit/models/sagemaker/test_training_job.py b/plugins/flytekit-aws-sagemaker/tests/test_training_job.py similarity index 98% rename from tests/flytekit/unit/models/sagemaker/test_training_job.py rename to plugins/flytekit-aws-sagemaker/tests/test_training_job.py index 271669b16c..8774857b1f 100644 --- a/tests/flytekit/unit/models/sagemaker/test_training_job.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_training_job.py @@ -1,6 +1,6 @@ import unittest -from flytekit.models.sagemaker import training_job +from flytekitplugins.awssagemaker.models import training_job def test_training_job_resource_config(): diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/models.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/models.py new file mode 100644 index 0000000000..517f4a9eb6 --- /dev/null +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/models.py @@ -0,0 +1,23 @@ +from flyteidl.plugins import pytorch_pb2 as _pytorch_task + +from flytekit.models import common as _common + + +class PyTorchJob(_common.FlyteIdlEntity): + def __init__(self, workers_count): + self._workers_count = 
workers_count + + @property + def workers_count(self): + return self._workers_count + + def to_flyte_idl(self): + return _pytorch_task.DistributedPyTorchTrainingTask( + workers=self.workers_count, + ) + + @classmethod + def from_flyte_idl(cls, pb2_object): + return cls( + workers_count=pb2_object.workers, + ) diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index 64a26e5408..c72f615307 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -9,7 +9,8 @@ from flytekit import PythonFunctionTask from flytekit.extend import SerializationSettings, TaskPlugins -from flytekit.models import task as _task_model + +from .models import PyTorchJob @dataclass @@ -44,7 +45,7 @@ def __init__(self, task_config: PyTorch, task_function: Callable, **kwargs): ) def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - job = _task_model.PyTorchJob(workers_count=self.task_config.num_workers) + job = PyTorchJob(workers_count=self.task_config.num_workers) return MessageToDict(job.to_flyte_idl()) diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models.py new file mode 100644 index 0000000000..87d7bb7b90 --- /dev/null +++ b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models.py @@ -0,0 +1,35 @@ +from flyteidl.plugins import tensorflow_pb2 as _tensorflow_task + +from flytekit.models import common as _common + + +class TensorFlowJob(_common.FlyteIdlEntity): + def __init__(self, workers_count, ps_replicas_count, chief_replicas_count): + self._workers_count = workers_count + self._ps_replicas_count = ps_replicas_count + self._chief_replicas_count = chief_replicas_count + + @property + def workers_count(self): + return self._workers_count + + @property + def ps_replicas_count(self): + return self._ps_replicas_count + + @property + def chief_replicas_count(self): + return self._chief_replicas_count + + def to_flyte_idl(self): + return _tensorflow_task.DistributedTensorflowTrainingTask( + workers=self.workers_count, ps_replicas=self.ps_replicas_count, chief_replicas=self.chief_replicas_count + ) + + @classmethod + def from_flyte_idl(cls, pb2_object): + return cls( + workers_count=pb2_object.workers, + ps_replicas_count=pb2_object.ps_replicas, + chief_replicas_count=pb2_object.chief_replicas, + ) diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py index 1d1526410e..f8b767a631 100644 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py +++ b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py @@ -9,7 +9,8 @@ from flytekit import PythonFunctionTask from flytekit.extend import SerializationSettings, TaskPlugins -from flytekit.models import task as _task_model + +from .models import TensorFlowJob @dataclass @@ -50,7 +51,7 @@ def __init__(self, task_config: TfJob, task_function: Callable, **kwargs): ) def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - job = _task_model.TensorFlowJob( + job = TensorFlowJob( workers_count=self.task_config.num_workers, ps_replicas_count=self.task_config.num_ps_replicas, chief_replicas_count=self.task_config.num_chief_replicas, diff --git a/plugins/flytekit-spark/flytekitplugins/spark/models.py 
b/plugins/flytekit-spark/flytekitplugins/spark/models.py new file mode 100644 index 0000000000..d03949f1ac --- /dev/null +++ b/plugins/flytekit-spark/flytekitplugins/spark/models.py @@ -0,0 +1,147 @@ +import typing + +from flyteidl.plugins import spark_pb2 as _spark_task + +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.models import common as _common +from flytekit.sdk.spark_types import SparkType as _spark_type + + +class SparkJob(_common.FlyteIdlEntity): + def __init__( + self, + spark_type, + application_file, + main_class, + spark_conf, + hadoop_conf, + executor_path, + ): + """ + This defines a SparkJob target. It will execute the appropriate SparkJob. + + :param application_file: The main application file to execute. + :param dict[Text, Text] spark_conf: A definition of key-value pairs for spark config for the job. + :param dict[Text, Text] hadoop_conf: A definition of key-value pairs for hadoop config for the job. + """ + self._application_file = application_file + self._spark_type = spark_type + self._main_class = main_class + self._executor_path = executor_path + self._spark_conf = spark_conf + self._hadoop_conf = hadoop_conf + + def with_overrides( + self, new_spark_conf: typing.Dict[str, str] = None, new_hadoop_conf: typing.Dict[str, str] = None + ) -> "SparkJob": + if not new_spark_conf: + new_spark_conf = self.spark_conf + + if not new_hadoop_conf: + new_hadoop_conf = self.hadoop_conf + + return SparkJob( + spark_type=self.spark_type, + application_file=self.application_file, + main_class=self.main_class, + spark_conf=new_spark_conf, + hadoop_conf=new_hadoop_conf, + executor_path=self.executor_path, + ) + + @property + def main_class(self): + """ + The main class to execute + :rtype: Text + """ + return self._main_class + + @property + def spark_type(self): + """ + Spark Job Type + :rtype: Text + """ + return self._spark_type + + @property + def application_file(self): + """ + The main application file to execute + :rtype: Text + """ + return self._application_file + + @property + def executor_path(self): + """ + The python executable to use + :rtype: Text + """ + return self._executor_path + + @property + def spark_conf(self): + """ + A definition of key-value pairs for spark config for the job. + :rtype: dict[Text, Text] + """ + return self._spark_conf + + @property + def hadoop_conf(self): + """ + A definition of key-value pairs for hadoop config for the job. 
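For context, the `SparkJob` model added here is what the Spark task plugin serializes from its user-facing config; a hedged sketch of that usage, assuming `Spark` is importable from `flytekitplugins.spark` with the `spark_conf` field referenced in the plugin's `task.py`:

.. code-block:: python

    # Assumed usage of the Spark plugin's task_config; the conf values are examples only.
    from flytekit import task
    from flytekitplugins.spark import Spark


    @task(task_config=Spark(spark_conf={"spark.executor.instances": "2"}))
    def estimate_pi(partitions: int) -> float:
        # A real task would obtain a SparkSession from the plugin's execution context here.
        return 3.14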
+ :rtype: dict[Text, Text] + """ + return self._hadoop_conf + + def to_flyte_idl(self): + """ + :rtype: flyteidl.plugins.spark_pb2.SparkJob + """ + + if self.spark_type == _spark_type.PYTHON: + application_type = _spark_task.SparkApplication.PYTHON + elif self.spark_type == _spark_type.JAVA: + application_type = _spark_task.SparkApplication.JAVA + elif self.spark_type == _spark_type.SCALA: + application_type = _spark_task.SparkApplication.SCALA + elif self.spark_type == _spark_type.R: + application_type = _spark_task.SparkApplication.R + else: + raise _user_exceptions.FlyteValidationException("Invalid Spark Application Type Specified") + + return _spark_task.SparkJob( + applicationType=application_type, + mainApplicationFile=self.application_file, + mainClass=self.main_class, + executorPath=self.executor_path, + sparkConf=self.spark_conf, + hadoopConf=self.hadoop_conf, + ) + + @classmethod + def from_flyte_idl(cls, pb2_object): + """ + :param flyteidl.plugins.spark_pb2.SparkJob pb2_object: + :rtype: SparkJob + """ + + application_type = _spark_type.PYTHON + if pb2_object.type == _spark_task.SparkApplication.JAVA: + application_type = _spark_type.JAVA + elif pb2_object.type == _spark_task.SparkApplication.SCALA: + application_type = _spark_type.SCALA + elif pb2_object.type == _spark_task.SparkApplication.R: + application_type = _spark_type.R + + return cls( + type=application_type, + spark_conf=pb2_object.sparkConf, + application_file=pb2_object.mainApplicationFile, + main_class=pb2_object.mainClass, + hadoop_conf=pb2_object.hadoopConf, + executor_path=pb2_object.executorPath, + ) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index a2893a06ad..bb20da8b48 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -9,9 +9,10 @@ from flytekit import FlyteContextManager, PythonFunctionTask from flytekit.common.tasks.sdk_runnable import ExecutionParameters from flytekit.extend import ExecutionState, SerializationSettings, TaskPlugins -from flytekit.models import task as _task_model from flytekit.sdk.spark_types import SparkType +from .models import SparkJob + @dataclass class Spark(object): @@ -91,7 +92,7 @@ def __init__(self, task_config: Spark, task_function: Callable, **kwargs): self.sess: Optional[SparkSession] = None def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - job = _task_model.SparkJob( + job = SparkJob( spark_conf=self.task_config.spark_conf, hadoop_conf=self.task_config.hadoop_conf, application_file="local://" + settings.entrypoint_settings.path if settings.entrypoint_settings else "", diff --git a/setup.py b/setup.py index bf730e2c67..ae3591fa85 100644 --- a/setup.py +++ b/setup.py @@ -100,7 +100,6 @@ scripts=[ "flytekit_scripts/flytekit_build_image.sh", "flytekit_scripts/flytekit_venv", - "flytekit_scripts/flytekit_sagemaker_runner.py", "flytekit/bin/entrypoint.py", ], license="apache2", diff --git a/tests/flytekit/common/workflows/sagemaker.py b/tests/flytekit/common/workflows/sagemaker.py deleted file mode 100644 index 044bba29f2..0000000000 --- a/tests/flytekit/common/workflows/sagemaker.py +++ /dev/null @@ -1,121 +0,0 @@ -import os as _os - -from flytekit import configuration as _configuration -from flytekit.common.tasks.sagemaker import hpo_job_task -from flytekit.common.tasks.sagemaker.built_in_training_job_task import SdkBuiltinAlgorithmTrainingJobTask -from flytekit.common.tasks.sagemaker.types import 
HyperparameterTuningJobConfig -from flytekit.models.sagemaker.hpo_job import HyperparameterTuningJobConfig as _HyperparameterTuningJobConfig -from flytekit.models.sagemaker.hpo_job import ( - HyperparameterTuningObjective, - HyperparameterTuningObjectiveType, - HyperparameterTuningStrategy, - TrainingJobEarlyStoppingType, -) -from flytekit.models.sagemaker.parameter_ranges import ( - ContinuousParameterRange, - HyperparameterScalingType, - IntegerParameterRange, -) -from flytekit.models.sagemaker.training_job import ( - AlgorithmName, - AlgorithmSpecification, - InputContentType, - InputMode, - TrainingJobResourceConfig, -) -from flytekit.sdk.types import Types -from flytekit.sdk.workflow import Input, workflow_class - -example_hyperparams = { - "base_score": "0.5", - "booster": "gbtree", - "csv_weights": "0", - "dsplit": "row", - "grow_policy": "depthwise", - "lambda_bias": "0.0", - "max_bin": "256", - "max_leaves": "0", - "normalize_type": "tree", - "objective": "reg:linear", - "one_drop": "0", - "prob_buffer_row": "1.0", - "process_type": "default", - "rate_drop": "0.0", - "refresh_leaf": "1", - "sample_type": "uniform", - "scale_pos_weight": "1.0", - "silent": "0", - "sketch_eps": "0.03", - "skip_drop": "0.0", - "tree_method": "auto", - "tweedie_variance_power": "1.5", - "updater": "grow_colmaker,prune", -} - -builtin_algorithm_training_job_task2 = SdkBuiltinAlgorithmTrainingJobTask( - training_job_resource_config=TrainingJobResourceConfig( - instance_type="ml.m4.xlarge", - instance_count=1, - volume_size_in_gb=25, - ), - algorithm_specification=AlgorithmSpecification( - input_mode=InputMode.FILE, - input_content_type=InputContentType.TEXT_CSV, - algorithm_name=AlgorithmName.XGBOOST, - algorithm_version="0.72", - ), -) - -simple_xgboost_hpo_job_task = hpo_job_task.SdkSimpleHyperparameterTuningJobTask( - training_job=builtin_algorithm_training_job_task2, - max_number_of_training_jobs=10, - max_parallel_training_jobs=5, - cache_version="1", - retries=2, - cacheable=True, - tunable_parameters=["num_round", "max_depth", "gamma"], -) - - -@workflow_class -class SageMakerHPO(object): - train_dataset = Input(Types.MultiPartCSV, default="s3://somelocation") - validation_dataset = Input(Types.MultiPartCSV, default="s3://somelocation") - static_hyperparameters = Input(Types.Generic, default=example_hyperparams) - hyperparameter_tuning_job_config = Input( - HyperparameterTuningJobConfig, - default=_HyperparameterTuningJobConfig( - tuning_strategy=HyperparameterTuningStrategy.BAYESIAN, - tuning_objective=HyperparameterTuningObjective( - objective_type=HyperparameterTuningObjectiveType.MINIMIZE, - metric_name="validation:error", - ), - training_job_early_stopping_type=TrainingJobEarlyStoppingType.AUTO, - ), - ) - - a = simple_xgboost_hpo_job_task( - train=train_dataset, - validation=validation_dataset, - static_hyperparameters=static_hyperparameters, - hyperparameter_tuning_job_config=hyperparameter_tuning_job_config, - num_round=IntegerParameterRange(min_value=2, max_value=8, scaling_type=HyperparameterScalingType.LINEAR), - max_depth=IntegerParameterRange(min_value=5, max_value=7, scaling_type=HyperparameterScalingType.LINEAR), - gamma=ContinuousParameterRange(min_value=0.0, max_value=0.3, scaling_type=HyperparameterScalingType.LINEAR), - ) - - -sagemaker_hpo_lp = SageMakerHPO.create_launch_plan() - -with _configuration.TemporaryConfiguration( - _os.path.join( - _os.path.dirname(_os.path.realpath(__file__)), - "../../common/configs/local.config", - ), - internal_overrides={"image": 
"myflyteimage:v123", "project": "myflyteproject", "domain": "development"}, -): - print("Printing WF definition") - print(SageMakerHPO) - - print("Printing LP definition") - print(sagemaker_hpo_lp) diff --git a/tests/flytekit/unit/sdk/tasks/test_pytorch_task.py b/tests/flytekit/unit/sdk/tasks/test_pytorch_task.py deleted file mode 100644 index 8472cde292..0000000000 --- a/tests/flytekit/unit/sdk/tasks/test_pytorch_task.py +++ /dev/null @@ -1,48 +0,0 @@ -import datetime as _datetime - -from flytekit.common import constants as _common_constants -from flytekit.common.tasks import pytorch_task as _pytorch_task -from flytekit.common.tasks import sdk_runnable as _sdk_runnable -from flytekit.models import types as _type_models -from flytekit.models.core import identifier as _identifier -from flytekit.sdk.tasks import inputs, outputs, pytorch_task -from flytekit.sdk.types import Types - - -@inputs(in1=Types.Integer) -@outputs(out1=Types.String) -@pytorch_task(workers_count=1) -def simple_pytorch_task(wf_params, sc, in1, out1): - pass - - -simple_pytorch_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "name", "version") - - -def test_simple_pytorch_task(): - assert isinstance(simple_pytorch_task, _pytorch_task.SdkPyTorchTask) - assert isinstance(simple_pytorch_task, _sdk_runnable.SdkRunnableTask) - assert simple_pytorch_task.interface.inputs["in1"].description == "" - assert simple_pytorch_task.interface.inputs["in1"].type == _type_models.LiteralType( - simple=_type_models.SimpleType.INTEGER - ) - assert simple_pytorch_task.interface.outputs["out1"].description == "" - assert simple_pytorch_task.interface.outputs["out1"].type == _type_models.LiteralType( - simple=_type_models.SimpleType.STRING - ) - assert simple_pytorch_task.type == _common_constants.SdkTaskType.PYTORCH_TASK - assert simple_pytorch_task.task_function_name == "simple_pytorch_task" - assert simple_pytorch_task.task_module == __name__ - assert simple_pytorch_task.metadata.timeout == _datetime.timedelta(seconds=0) - assert simple_pytorch_task.metadata.deprecated_error_message == "" - assert simple_pytorch_task.metadata.discoverable is False - assert simple_pytorch_task.metadata.discovery_version == "" - assert simple_pytorch_task.metadata.retries.retries == 0 - assert len(simple_pytorch_task.container.resources.limits) == 0 - assert len(simple_pytorch_task.container.resources.requests) == 0 - assert simple_pytorch_task.custom["workers"] == 1 - # Should strip out the venv component of the args. 
- assert simple_pytorch_task._get_container_definition().args[0] == "pyflyte-execute" - - pb2 = simple_pytorch_task.to_flyte_idl() - assert pb2.custom["workers"] == 1 diff --git a/tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py b/tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py deleted file mode 100644 index 19b4b8e2a1..0000000000 --- a/tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py +++ /dev/null @@ -1,540 +0,0 @@ -import datetime as _datetime -import os as _os -import unittest -from unittest import mock - -import retry.api -from flyteidl.plugins.sagemaker.training_job_pb2 import TrainingJobResourceConfig as _pb2_TrainingJobResourceConfig -from google.protobuf.json_format import ParseDict - -from flytekit.common import constants as _common_constants -from flytekit.common import utils as _utils -from flytekit.common.core.identifier import WorkflowExecutionIdentifier -from flytekit.common.tasks import task as _sdk_task -from flytekit.common.tasks.sagemaker import distributed_training as _sm_distribution -from flytekit.common.tasks.sagemaker import hpo_job_task -from flytekit.common.tasks.sagemaker.built_in_training_job_task import SdkBuiltinAlgorithmTrainingJobTask -from flytekit.common.tasks.sagemaker.custom_training_job_task import CustomTrainingJobTask -from flytekit.common.tasks.sagemaker.hpo_job_task import ( - HyperparameterTuningJobConfig, - SdkSimpleHyperparameterTuningJobTask, -) -from flytekit.common.types import helpers as _type_helpers -from flytekit.engines import common as _common_engine -from flytekit.engines.unit.mock_stats import MockStats -from flytekit.models import literals as _literals -from flytekit.models import types as _idl_types -from flytekit.models.core import identifier as _identifier -from flytekit.models.core import types as _core_types -from flytekit.models.sagemaker.hpo_job import HyperparameterTuningJobConfig as _HyperparameterTuningJobConfig -from flytekit.models.sagemaker.hpo_job import ( - HyperparameterTuningObjective, - HyperparameterTuningObjectiveType, - HyperparameterTuningStrategy, - TrainingJobEarlyStoppingType, -) -from flytekit.models.sagemaker.parameter_ranges import ( - ContinuousParameterRange, - HyperparameterScalingType, - IntegerParameterRange, - ParameterRangeOneOf, -) -from flytekit.models.sagemaker.training_job import ( - AlgorithmName, - AlgorithmSpecification, - InputContentType, - InputMode, - MetricDefinition, - TrainingJobResourceConfig, -) -from flytekit.sdk import types as _sdk_types -from flytekit.sdk.sagemaker.task import custom_training_job_task -from flytekit.sdk.tasks import inputs, outputs -from flytekit.sdk.types import Types -from flytekit.sdk.workflow import Input, workflow_class - -example_hyperparams = { - "base_score": "0.5", - "booster": "gbtree", - "csv_weights": "0", - "dsplit": "row", - "grow_policy": "depthwise", - "lambda_bias": "0.0", - "max_bin": "256", - "max_leaves": "0", - "normalize_type": "tree", - "objective": "reg:linear", - "one_drop": "0", - "prob_buffer_row": "1.0", - "process_type": "default", - "rate_drop": "0.0", - "refresh_leaf": "1", - "sample_type": "uniform", - "scale_pos_weight": "1.0", - "silent": "0", - "sketch_eps": "0.03", - "skip_drop": "0.0", - "tree_method": "auto", - "tweedie_variance_power": "1.5", - "updater": "grow_colmaker,prune", -} - - -def test_builtin_algorithm_training_job_task(): - builtin_algorithm_training_job_task = SdkBuiltinAlgorithmTrainingJobTask( - training_job_resource_config=TrainingJobResourceConfig( - instance_type="ml.m4.xlarge", - 
instance_count=1, - volume_size_in_gb=25, - ), - algorithm_specification=AlgorithmSpecification( - input_mode=InputMode.FILE, - input_content_type=InputContentType.TEXT_CSV, - algorithm_name=AlgorithmName.XGBOOST, - algorithm_version="0.72", - ), - ) - - builtin_algorithm_training_job_task._id = _identifier.Identifier( - _identifier.ResourceType.TASK, "my_project", "my_domain", "my_name", "my_version" - ) - assert isinstance(builtin_algorithm_training_job_task, SdkBuiltinAlgorithmTrainingJobTask) - assert isinstance(builtin_algorithm_training_job_task, _sdk_task.SdkTask) - assert builtin_algorithm_training_job_task.interface.inputs["train"].description == "" - assert builtin_algorithm_training_job_task.interface.inputs["train"].type == _idl_types.LiteralType( - blob=_core_types.BlobType( - format="csv", - dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, - ) - ) - assert ( - builtin_algorithm_training_job_task.interface.inputs["train"].type - == _sdk_types.Types.MultiPartCSV.to_flyte_literal_type() - ) - assert builtin_algorithm_training_job_task.interface.inputs["validation"].description == "" - assert ( - builtin_algorithm_training_job_task.interface.inputs["validation"].type - == _sdk_types.Types.MultiPartCSV.to_flyte_literal_type() - ) - assert builtin_algorithm_training_job_task.interface.inputs["train"].type == _idl_types.LiteralType( - blob=_core_types.BlobType( - format="csv", - dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, - ) - ) - assert builtin_algorithm_training_job_task.interface.inputs["static_hyperparameters"].description == "" - assert ( - builtin_algorithm_training_job_task.interface.inputs["static_hyperparameters"].type - == _sdk_types.Types.Generic.to_flyte_literal_type() - ) - assert builtin_algorithm_training_job_task.interface.outputs["model"].description == "" - assert ( - builtin_algorithm_training_job_task.interface.outputs["model"].type - == _sdk_types.Types.Blob.to_flyte_literal_type() - ) - assert builtin_algorithm_training_job_task.type == _common_constants.SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK - assert builtin_algorithm_training_job_task.metadata.timeout == _datetime.timedelta(seconds=0) - assert builtin_algorithm_training_job_task.metadata.deprecated_error_message == "" - assert builtin_algorithm_training_job_task.metadata.discoverable is False - assert builtin_algorithm_training_job_task.metadata.discovery_version == "" - assert builtin_algorithm_training_job_task.metadata.retries.retries == 0 - assert "metricDefinitions" not in builtin_algorithm_training_job_task.custom["algorithmSpecification"].keys() - - ParseDict( - builtin_algorithm_training_job_task.custom["trainingJobResourceConfig"], - _pb2_TrainingJobResourceConfig(), - ) # fails the test if it cannot be parsed - - -builtin_algorithm_training_job_task2 = SdkBuiltinAlgorithmTrainingJobTask( - training_job_resource_config=TrainingJobResourceConfig( - instance_type="ml.m4.xlarge", - instance_count=1, - volume_size_in_gb=25, - ), - algorithm_specification=AlgorithmSpecification( - input_mode=InputMode.FILE, - input_content_type=InputContentType.TEXT_CSV, - algorithm_name=AlgorithmName.XGBOOST, - algorithm_version="0.72", - metric_definitions=[MetricDefinition(name="Validation error", regex="validation:error")], - ), -) - -simple_xgboost_hpo_job_task = hpo_job_task.SdkSimpleHyperparameterTuningJobTask( - training_job=builtin_algorithm_training_job_task2, - max_number_of_training_jobs=10, - max_parallel_training_jobs=5, - cache_version="1", - retries=2, - 
cacheable=True, - tunable_parameters=["num_round", "gamma", "max_depth"], -) - -simple_xgboost_hpo_job_task._id = _identifier.Identifier( - _identifier.ResourceType.TASK, "my_project", "my_domain", "my_name", "my_version" -) - - -def test_simple_hpo_job_task(): - assert isinstance(simple_xgboost_hpo_job_task, SdkSimpleHyperparameterTuningJobTask) - assert isinstance(simple_xgboost_hpo_job_task, _sdk_task.SdkTask) - # Checking if the input of the underlying SdkTrainingJobTask has been embedded - assert simple_xgboost_hpo_job_task.interface.inputs["train"].description == "" - assert ( - simple_xgboost_hpo_job_task.interface.inputs["train"].type - == _sdk_types.Types.MultiPartCSV.to_flyte_literal_type() - ) - assert simple_xgboost_hpo_job_task.interface.inputs["train"].type == _idl_types.LiteralType( - blob=_core_types.BlobType( - format="csv", - dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, - ) - ) - assert simple_xgboost_hpo_job_task.interface.inputs["validation"].description == "" - assert ( - simple_xgboost_hpo_job_task.interface.inputs["validation"].type - == _sdk_types.Types.MultiPartCSV.to_flyte_literal_type() - ) - assert simple_xgboost_hpo_job_task.interface.inputs["validation"].type == _idl_types.LiteralType( - blob=_core_types.BlobType( - format="csv", - dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, - ) - ) - assert simple_xgboost_hpo_job_task.interface.inputs["static_hyperparameters"].description == "" - assert ( - simple_xgboost_hpo_job_task.interface.inputs["static_hyperparameters"].type - == _sdk_types.Types.Generic.to_flyte_literal_type() - ) - - # Checking if the hpo-specific input is defined - assert simple_xgboost_hpo_job_task.interface.inputs["hyperparameter_tuning_job_config"].description == "" - assert ( - simple_xgboost_hpo_job_task.interface.inputs["hyperparameter_tuning_job_config"].type - == HyperparameterTuningJobConfig.to_flyte_literal_type() - ) - assert simple_xgboost_hpo_job_task.interface.outputs["model"].description == "" - assert simple_xgboost_hpo_job_task.interface.outputs["model"].type == _sdk_types.Types.Blob.to_flyte_literal_type() - assert simple_xgboost_hpo_job_task.type == _common_constants.SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK - - # Checking if the spec of the TrainingJob is embedded into the custom field - # of this SdkSimpleHyperparameterTuningJobTask - assert simple_xgboost_hpo_job_task.to_flyte_idl().custom["trainingJob"] == ( - builtin_algorithm_training_job_task2.to_flyte_idl().custom - ) - - assert simple_xgboost_hpo_job_task.metadata.timeout == _datetime.timedelta(seconds=0) - assert simple_xgboost_hpo_job_task.metadata.discoverable is True - assert simple_xgboost_hpo_job_task.metadata.discovery_version == "1" - assert simple_xgboost_hpo_job_task.metadata.retries.retries == 2 - - assert simple_xgboost_hpo_job_task.metadata.deprecated_error_message == "" - assert "metricDefinitions" in simple_xgboost_hpo_job_task.custom["trainingJob"]["algorithmSpecification"].keys() - assert len(simple_xgboost_hpo_job_task.custom["trainingJob"]["algorithmSpecification"]["metricDefinitions"]) == 1 - """ - These are attributes for SdkRunnable. 
We will need these when supporting CustomTrainingJobTask and CustomHPOJobTask - assert simple_xgboost_hpo_job_task.task_module == __name__ - assert simple_xgboost_hpo_job_task._get_container_definition().args[0] == 'pyflyte-execute' - """ - - -def test_custom_training_job(): - @inputs(input_1=Types.Integer) - @outputs(model=Types.Blob) - @custom_training_job_task( - training_job_resource_config=TrainingJobResourceConfig( - instance_type="ml.m4.xlarge", - instance_count=1, - volume_size_in_gb=25, - ), - algorithm_specification=AlgorithmSpecification( - input_mode=InputMode.FILE, - input_content_type=InputContentType.TEXT_CSV, - metric_definitions=[MetricDefinition(name="Validation error", regex="validation:error")], - ), - ) - def my_task(wf_params, input_1, model): - pass - - assert type(my_task) == CustomTrainingJobTask - - -def test_simple_hpo_job_task_interface(): - @workflow_class - class MyWf(object): - train_dataset = Input(Types.Blob) - validation_dataset = Input(Types.Blob) - static_hyperparameters = Input(Types.Generic) - hyperparameter_tuning_job_config = Input( - HyperparameterTuningJobConfig, - default=_HyperparameterTuningJobConfig( - tuning_strategy=HyperparameterTuningStrategy.BAYESIAN, - tuning_objective=HyperparameterTuningObjective( - objective_type=HyperparameterTuningObjectiveType.MINIMIZE, - metric_name="validation:error", - ), - training_job_early_stopping_type=TrainingJobEarlyStoppingType.AUTO, - ), - ) - - a = simple_xgboost_hpo_job_task( - train=train_dataset, - validation=validation_dataset, - static_hyperparameters=static_hyperparameters, - hyperparameter_tuning_job_config=hyperparameter_tuning_job_config, - num_round=ParameterRangeOneOf( - IntegerParameterRange(min_value=3, max_value=10, scaling_type=HyperparameterScalingType.LINEAR) - ), - max_depth=ParameterRangeOneOf( - IntegerParameterRange(min_value=5, max_value=7, scaling_type=HyperparameterScalingType.LINEAR) - ), - gamma=ParameterRangeOneOf( - ContinuousParameterRange(min_value=0.0, max_value=0.3, scaling_type=HyperparameterScalingType.LINEAR) - ), - ) - - assert MyWf.nodes[0].inputs[2].binding.scalar.generic is not None - - -# Defining a output-persist predicate to indicate if the copy of the instance should persist its output -def predicate(distributed_training_context): - return ( - distributed_training_context[_sm_distribution.DistributedTrainingContextKey.CURRENT_HOST] - == distributed_training_context[_sm_distribution.DistributedTrainingContextKey.HOSTS][1] - ) - - -def dontretry(f, *args, **kwargs): - return f() - - -class SingleNodeCustomTrainingJobTaskTests(unittest.TestCase): - @mock.patch.dict("os.environ", {}) - def setUp(self): - with _utils.AutoDeletingTempDir("input_dir") as input_dir: - self._task_input = _literals.LiteralMap( - {"input_1": _literals.Literal(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=1)))} - ) - - self._context = _common_engine.EngineContext( - execution_id=WorkflowExecutionIdentifier(project="unit_test", domain="unit_test", name="unit_test"), - execution_date=_datetime.datetime.utcnow(), - stats=MockStats(), - logging=None, - tmp_dir=input_dir.name, - ) - - @inputs(input_1=Types.Integer) - @outputs(model=Types.Blob) - @custom_training_job_task( - training_job_resource_config=TrainingJobResourceConfig( - instance_type="ml.m4.xlarge", - instance_count=1, - volume_size_in_gb=25, - ), - algorithm_specification=AlgorithmSpecification( - input_mode=InputMode.FILE, - input_content_type=InputContentType.TEXT_CSV, - 
metric_definitions=[MetricDefinition(name="Validation error", regex="validation:error")], - ), - ) - def my_single_node_task(wf_params, input_1, model): - pass - - self._my_single_node_task = my_single_node_task - assert type(self._my_single_node_task) == CustomTrainingJobTask - - def test_output_persistence(self): - # In single-node training on SageMaker, the distributed training env vars are still filled in - with mock.patch.dict( - _os.environ, - { - _sm_distribution.SM_ENV_VAR_CURRENT_HOST: "algo-0", - _sm_distribution.SM_ENV_VAR_HOSTS: '["algo-0"]', - _sm_distribution.SM_ENV_VAR_NETWORK_INTERFACE_NAME: "eth0", - }, - clear=True, - ): - ret = self._my_single_node_task.execute(self._context, self._task_input) - assert _common_constants.OUTPUT_FILE_NAME in ret.keys() - - -class DistributedCustomTrainingJobTaskTests(unittest.TestCase): - @mock.patch.dict("os.environ", {}) - def setUp(self): - with _utils.AutoDeletingTempDir("input_dir") as input_dir: - self._task_input = _literals.LiteralMap( - {"input_1": _literals.Literal(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=1)))} - ) - - self._context = _common_engine.EngineContext( - execution_id=WorkflowExecutionIdentifier(project="unit_test", domain="unit_test", name="unit_test"), - execution_date=_datetime.datetime.utcnow(), - stats=MockStats(), - logging=None, - tmp_dir=input_dir.name, - ) - - # Defining the distributed training task without specifying an output-persist - # predicate (so it will use the default) - @inputs(input_1=Types.Integer) - @outputs(model=Types.Blob) - @custom_training_job_task( - training_job_resource_config=TrainingJobResourceConfig( - instance_type="ml.m4.xlarge", - instance_count=2, - volume_size_in_gb=25, - ), - algorithm_specification=AlgorithmSpecification( - input_mode=InputMode.FILE, - input_content_type=InputContentType.TEXT_CSV, - metric_definitions=[MetricDefinition(name="Validation error", regex="validation:error")], - ), - ) - def my_distributed_task(wf_params, input_1, model): - pass - - self._my_distributed_task = my_distributed_task - assert type(self._my_distributed_task) == CustomTrainingJobTask - - def test_missing_current_host_in_distributed_training_context_keys_lead_to_keyerrors(self): - with mock.patch.dict( - _os.environ, - { - _sm_distribution.SM_ENV_VAR_HOSTS: '["algo-0", "algo-1", "algo-2"]', - _sm_distribution.SM_ENV_VAR_NETWORK_INTERFACE_NAME: "eth0", - }, - clear=True, - ): - # eliminate the wait in unittest https://stackoverflow.com/a/32698175 - with mock.patch.object(retry.api, "__retry_internal", dontretry): - self.assertRaises(KeyError, self._my_distributed_task.execute, self._context, self._task_input) - - def test_missing_hosts_distributed_training_context_keys_lead_to_keyerrors(self): - with mock.patch.dict( - _os.environ, - { - _sm_distribution.SM_ENV_VAR_CURRENT_HOST: "algo-1", - _sm_distribution.SM_ENV_VAR_NETWORK_INTERFACE_NAME: "eth0", - }, - clear=True, - ): - # eliminate the wait in unittest https://stackoverflow.com/a/32698175 - with mock.patch.object(retry.api, "__retry_internal", dontretry): - self.assertRaises(KeyError, self._my_distributed_task.execute, self._context, self._task_input) - - def test_missing_network_interface_name_in_distributed_training_context_keys_lead_to_keyerrors(self): - with mock.patch.dict( - _os.environ, - { - _sm_distribution.SM_ENV_VAR_CURRENT_HOST: "algo-1", - _sm_distribution.SM_ENV_VAR_HOSTS: '["algo-0", "algo-1", "algo-2"]', - }, - clear=True, - ): - # eliminate the wait in unittest https://stackoverflow.com/a/32698175 - 
with mock.patch.object(retry.api, "__retry_internal", dontretry): - self.assertRaises(KeyError, self._my_distributed_task.execute, self._context, self._task_input) - - def test_with_default_predicate_with_rank0_master(self): - with mock.patch.dict( - _os.environ, - { - _sm_distribution.SM_ENV_VAR_CURRENT_HOST: "algo-0", - _sm_distribution.SM_ENV_VAR_HOSTS: '["algo-0", "algo-1", "algo-2"]', - _sm_distribution.SM_ENV_VAR_NETWORK_INTERFACE_NAME: "eth0", - }, - clear=True, - ): - # execute the distributed task with its distributed_training_context == None - ret = self._my_distributed_task.execute(self._context, self._task_input) - assert _common_constants.OUTPUT_FILE_NAME in ret.keys() - - def test_with_default_predicate_with_rank1_master(self): - with mock.patch.dict( - _os.environ, - { - _sm_distribution.SM_ENV_VAR_CURRENT_HOST: "algo-1", - _sm_distribution.SM_ENV_VAR_HOSTS: '["algo-0", "algo-1", "algo-2"]', - _sm_distribution.SM_ENV_VAR_NETWORK_INTERFACE_NAME: "eth0", - }, - clear=True, - ): - ret = self._my_distributed_task.execute(self._context, self._task_input) - assert not ret - - def test_with_custom_predicate_with_none_dist_context(self): - with mock.patch.dict( - _os.environ, - { - _sm_distribution.SM_ENV_VAR_CURRENT_HOST: "algo-1", - _sm_distribution.SM_ENV_VAR_HOSTS: '["algo-0", "algo-1", "algo-2"]', - _sm_distribution.SM_ENV_VAR_NETWORK_INTERFACE_NAME: "eth0", - }, - clear=True, - ): - - self._my_distributed_task._output_persist_predicate = predicate - # execute the distributed task with its distributed_training_context == None - ret = self._my_distributed_task.execute(self._context, self._task_input) - assert ret - assert _common_constants.OUTPUT_FILE_NAME in ret.keys() - - def test_with_custom_predicate_with_valid_dist_context(self): - with mock.patch.dict( - _os.environ, - { - _sm_distribution.SM_ENV_VAR_CURRENT_HOST: "algo-1", - _sm_distribution.SM_ENV_VAR_HOSTS: '["algo-0", "algo-1", "algo-2"]', - _sm_distribution.SM_ENV_VAR_NETWORK_INTERFACE_NAME: "eth0", - }, - clear=True, - ): - # fill in the distributed_training_context to the context object and execute again - self._my_distributed_task._output_persist_predicate = predicate - ret = self._my_distributed_task.execute(self._context, self._task_input) - assert _common_constants.OUTPUT_FILE_NAME in ret.keys() - python_std_output_map = _type_helpers.unpack_literal_map_to_sdk_python_std( - ret[_common_constants.OUTPUT_FILE_NAME] - ) - assert "model" in python_std_output_map.keys() - - def test_if_wf_param_has_dist_context(self): - with mock.patch.dict( - _os.environ, - { - _sm_distribution.SM_ENV_VAR_CURRENT_HOST: "algo-1", - _sm_distribution.SM_ENV_VAR_HOSTS: '["algo-0", "algo-1", "algo-2"]', - _sm_distribution.SM_ENV_VAR_NETWORK_INTERFACE_NAME: "eth0", - }, - clear=True, - ): - - # This test is making sure that the distributed_training_context is successfully passed into the - # task_function. 
- # Specifically, we want to make sure the _execute_user_code() of the CustomTrainingJobTask class does the - # thing that it is supposed to do - - @inputs(input_1=Types.Integer) - @outputs(model=Types.Blob) - @custom_training_job_task( - training_job_resource_config=TrainingJobResourceConfig( - instance_type="ml.m4.xlarge", - instance_count=2, - volume_size_in_gb=25, - ), - algorithm_specification=AlgorithmSpecification( - input_mode=InputMode.FILE, - input_content_type=InputContentType.TEXT_CSV, - metric_definitions=[MetricDefinition(name="Validation error", regex="validation:error")], - ), - ) - def my_distributed_task_with_valid_dist_training_context(wf_params, input_1, model): - if not wf_params.distributed_training_context: - raise ValueError - - try: - my_distributed_task_with_valid_dist_training_context.execute(self._context, self._task_input) - except ValueError: - self.fail("The distributed_training_context is not passed into task function successfully") diff --git a/tests/flytekit/unit/sdk/tasks/test_tensorflow_task.py b/tests/flytekit/unit/sdk/tasks/test_tensorflow_task.py deleted file mode 100644 index 9e6caf29d3..0000000000 --- a/tests/flytekit/unit/sdk/tasks/test_tensorflow_task.py +++ /dev/null @@ -1,55 +0,0 @@ -import datetime as _datetime - -from flytekit.common import constants as _common_constants -from flytekit.common.tasks import sdk_runnable as _sdk_runnable -from flytekit.common.tasks import tensorflow_task as _tensorflow_task -from flytekit.models import types as _type_models -from flytekit.models.core import identifier as _identifier -from flytekit.sdk.tasks import inputs, outputs, tensorflow_task -from flytekit.sdk.types import Types - - -@inputs(in1=Types.Integer) -@outputs(out1=Types.String) -@tensorflow_task(workers_count=2, ps_replicas_count=1, chief_replicas_count=1) -def simple_tensorflow_task(wf_params, sc, in1, out1): - pass - - -simple_tensorflow_task._id = _identifier.Identifier( - _identifier.ResourceType.TASK, "project", "domain", "name", "version" -) - - -def test_simple_tensorflow_task(): - assert isinstance(simple_tensorflow_task, _tensorflow_task.SdkTensorFlowTask) - assert isinstance(simple_tensorflow_task, _sdk_runnable.SdkRunnableTask) - assert simple_tensorflow_task.interface.inputs["in1"].description == "" - assert simple_tensorflow_task.interface.inputs["in1"].type == _type_models.LiteralType( - simple=_type_models.SimpleType.INTEGER - ) - assert simple_tensorflow_task.interface.outputs["out1"].description == "" - assert simple_tensorflow_task.interface.outputs["out1"].type == _type_models.LiteralType( - simple=_type_models.SimpleType.STRING - ) - assert simple_tensorflow_task.type == _common_constants.SdkTaskType.TENSORFLOW_TASK - assert simple_tensorflow_task.task_function_name == "simple_tensorflow_task" - assert simple_tensorflow_task.task_module == __name__ - assert simple_tensorflow_task.metadata.timeout == _datetime.timedelta(seconds=0) - assert simple_tensorflow_task.metadata.deprecated_error_message == "" - assert simple_tensorflow_task.metadata.discoverable is False - assert simple_tensorflow_task.metadata.discovery_version == "" - assert simple_tensorflow_task.metadata.retries.retries == 0 - assert len(simple_tensorflow_task.container.resources.limits) == 0 - assert len(simple_tensorflow_task.container.resources.requests) == 0 - assert simple_tensorflow_task.custom["workers"] == 2 - assert simple_tensorflow_task.custom["psReplicas"] == 1 - assert simple_tensorflow_task.custom["chiefReplicas"] == 1 - - # Should strip out the venv 
component of the args. - assert simple_tensorflow_task._get_container_definition().args[0] == "pyflyte-execute" - - pb2 = simple_tensorflow_task.to_flyte_idl() - assert pb2.custom["workers"] == 2 - assert pb2.custom["psReplicas"] == 1 - assert pb2.custom["chiefReplicas"] == 1 diff --git a/tests/scripts/test_flytekit_sagemaker_runner.py b/tests/scripts/test_flytekit_sagemaker_runner.py deleted file mode 100644 index 60771ccab9..0000000000 --- a/tests/scripts/test_flytekit_sagemaker_runner.py +++ /dev/null @@ -1,37 +0,0 @@ -import os -import sys -from unittest import mock - -from flytekit_scripts.flytekit_sagemaker_runner import run as _flyte_sagemaker_run - -cmd = [] -cmd.extend(["--__FLYTE_ENV_VAR_env1__", "val1"]) -cmd.extend(["--__FLYTE_ENV_VAR_env2__", "val2"]) -cmd.extend(["--__FLYTE_CMD_0_service_venv__", "__FLYTE_CMD_DUMMY_VALUE__"]) -cmd.extend(["--__FLYTE_CMD_1_pyflyte-execute__", "__FLYTE_CMD_DUMMY_VALUE__"]) -cmd.extend(["--__FLYTE_CMD_2_--task-module__", "__FLYTE_CMD_DUMMY_VALUE__"]) -cmd.extend(["--__FLYTE_CMD_3_blah__", "__FLYTE_CMD_DUMMY_VALUE__"]) -cmd.extend(["--__FLYTE_CMD_4_--task-name__", "__FLYTE_CMD_DUMMY_VALUE__"]) -cmd.extend(["--__FLYTE_CMD_5_bloh__", "__FLYTE_CMD_DUMMY_VALUE__"]) -cmd.extend(["--__FLYTE_CMD_6_--output-prefix__", "__FLYTE_CMD_DUMMY_VALUE__"]) -cmd.extend(["--__FLYTE_CMD_7_s3://fake-bucket__", "__FLYTE_CMD_DUMMY_VALUE__"]) -cmd.extend(["--__FLYTE_CMD_8_--inputs__", "__FLYTE_CMD_DUMMY_VALUE__"]) -cmd.extend(["--__FLYTE_CMD_9_s3://fake-bucket__", "__FLYTE_CMD_DUMMY_VALUE__"]) - - -@mock.patch.dict("os.environ") -@mock.patch("subprocess.run") -def test(mock_subprocess_run): - _flyte_sagemaker_run(cmd) - assert "env1" in os.environ - assert "env2" in os.environ - assert os.environ["env1"] == "val1" - assert os.environ["env2"] == "val2" - mock_subprocess_run.assert_called_with( - "service_venv pyflyte-execute --task-module blah --task-name bloh " - "--output-prefix s3://fake-bucket --inputs s3://fake-bucket".split(), - stdout=sys.stdout, - stderr=sys.stderr, - encoding="utf-8", - check=True, - )
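
Note: the SageMaker model classes exercised by the deleted tests above now live under the
flytekitplugins.awssagemaker.models package inside the flytekit-aws-sagemaker plugin. A minimal
sketch of the plugin-side import path follows; the module locations track the file renames in this
patch, the constructor arguments are copied from the removed tests, and which symbols the plugin's
__init__.py re-exports is an assumption here, so imports go straight to the models module.

    # Sketch only: construct the SageMaker training-job models from their new
    # plugin location. Paths follow the renames in this patch; the argument
    # values are taken from the deleted tests and are illustrative.
    from flytekitplugins.awssagemaker.models.training_job import (
        AlgorithmSpecification,
        InputContentType,
        InputMode,
        MetricDefinition,
        TrainingJobResourceConfig,
    )

    resources = TrainingJobResourceConfig(
        instance_type="ml.m4.xlarge",
        instance_count=1,
        volume_size_in_gb=25,
    )
    algorithm = AlgorithmSpecification(
        input_mode=InputMode.FILE,
        input_content_type=InputContentType.TEXT_CSV,
        metric_definitions=[MetricDefinition(name="Validation error", regex="validation:error")],
    )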