diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index 210605e493..49dc69e7bf 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -130,8 +130,6 @@ jobs: # See: https://github.com/flyteorg/flytekit/actions/runs/4493746408/jobs/7905368664 - python-version: 3.11 plugin-names: "flytekit-whylogs" - - steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/plugins/flytekit-kf-mpi/README.md b/plugins/flytekit-kf-mpi/README.md index 35c9444c42..db475868eb 100644 --- a/plugins/flytekit-kf-mpi/README.md +++ b/plugins/flytekit-kf-mpi/README.md @@ -8,4 +8,67 @@ To install the plugin, run the following command: pip install flytekitplugins-kfmpi ``` -_Example coming soon!_ +## Code Example +MPI usage: +```python + @task( + task_config=MPIJob( + launcher=Launcher( + replicas=1, + ), + worker=Worker( + replicas=5, + requests=Resources(cpu="2", mem="2Gi"), + limits=Resources(cpu="4", mem="2Gi"), + ), + slots=2, + ), + cache=True, + requests=Resources(cpu="1"), + cache_version="1", + ) + def my_mpi_task(x: int, y: str) -> int: + return x +``` + + +Horovod Usage: +You can override the command of a replica group by: +```python + @task( + task_config=HorovodJob( + launcher=Launcher( + replicas=1, + requests=Resources(cpu="1"), + limits=Resources(cpu="2"), + ), + worker=Worker( + replicas=1, + command=["/usr/sbin/sshd", "-De", "-f", "/home/jobuser/.sshd_config"], + restart_policy=RestartPolicy.NEVER, + ), + slots=2, + verbose=False, + log_level="INFO", + ), + ) + def my_horovod_task(): + ... +``` + + + + +## Upgrade MPI Plugin from V0 to V1 +MPI plugin is now updated from v0 to v1 to enable more configuration options. +To migrate from v0 to v1, change the following: +1. Update flytepropeller to v1.6.0 +2. Update flytekit version to v1.6.2 +3. Update your code from: +``` + task_config=MPIJob(num_workers=10), +``` +to +``` + task_config=MPIJob(worker=Worker(replicas=10)), +``` diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py index df5c74288e..7d2107c8ae 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py @@ -10,4 +10,4 @@ MPIJob """ -from .task import HorovodJob, MPIJob +from .task import CleanPodPolicy, HorovodJob, Launcher, MPIJob, RestartPolicy, RunPolicy, Worker diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py index e1c1be0a03..20179c7376 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py @@ -2,62 +2,89 @@ This Plugin adds the capability of running distributed MPI training to Flyte using backend plugins, natively on Kubernetes. It leverages `MPI Job `_ Plugin from kubeflow. """ -from dataclasses import dataclass -from typing import Any, Callable, Dict, List +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Union -from flyteidl.plugins import mpi_pb2 as _mpi_task +from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common +from flyteidl.plugins.kubeflow import mpi_pb2 as mpi_task from google.protobuf.json_format import MessageToDict -from flytekit import PythonFunctionTask +from flytekit import PythonFunctionTask, Resources from flytekit.configuration import SerializationSettings +from flytekit.core.resources import convert_resources_to_resource_model from flytekit.extend import TaskPlugins -from flytekit.models import common as _common -class MPIJobModel(_common.FlyteIdlEntity): - """Model definition for MPI the plugin +@dataclass +class RestartPolicy(Enum): + """ + RestartPolicy describes how the replicas should be restarted + """ - Args: - num_workers: integer determining the number of worker replicas spawned in the cluster for this job - (in addition to 1 master). + ALWAYS = kubeflow_common.RESTART_POLICY_ALWAYS + FAILURE = kubeflow_common.RESTART_POLICY_ON_FAILURE + NEVER = kubeflow_common.RESTART_POLICY_NEVER + + +@dataclass +class CleanPodPolicy(Enum): + """ + CleanPodPolicy describes how to deal with pods when the job is finished. + """ - num_launcher_replicas: Number of launcher server replicas to use + NONE = kubeflow_common.CLEANPOD_POLICY_NONE + ALL = kubeflow_common.CLEANPOD_POLICY_ALL + RUNNING = kubeflow_common.CLEANPOD_POLICY_RUNNING - slots: Number of slots per worker used in hostfile - .. note:: - Please use resources=Resources(cpu="1"...) to specify per worker resource +@dataclass +class RunPolicy: + """ + RunPolicy describes some policy to apply to the execution of a kubeflow job. + Args: + clean_pod_policy: Defines the policy for cleaning up pods after the PyTorchJob completes. Default to None. + ttl_seconds_after_finished (int): Defines the TTL for cleaning up finished PyTorchJobs. + active_deadline_seconds (int): Specifies the duration (in seconds) since startTime during which the job. + can remain active before it is terminated. Must be a positive integer. This setting applies only to pods. + where restartPolicy is OnFailure or Always. + backoff_limit (int): Number of retries before marking this job as failed. """ - def __init__(self, num_workers, num_launcher_replicas, slots): - self._num_workers = num_workers - self._num_launcher_replicas = num_launcher_replicas - self._slots = slots + clean_pod_policy: CleanPodPolicy = None + ttl_seconds_after_finished: Optional[int] = None + active_deadline_seconds: Optional[int] = None + backoff_limit: Optional[int] = None - @property - def num_workers(self): - return self._num_workers - @property - def num_launcher_replicas(self): - return self._num_launcher_replicas +@dataclass +class Worker: + """ + Worker replica configuration. Worker command can be customized. If not specified, the worker will use + default command generated by the mpi operator. + """ - @property - def slots(self): - return self._slots + command: Optional[List[str]] = None + image: Optional[str] = None + requests: Optional[Resources] = None + limits: Optional[Resources] = None + replicas: Optional[int] = None + restart_policy: Optional[RestartPolicy] = None - def to_flyte_idl(self): - return _mpi_task.DistributedMPITrainingTask( - num_workers=self.num_workers, num_launcher_replicas=self.num_launcher_replicas, slots=self.slots - ) - @classmethod - def from_flyte_idl(cls, pb2_object): - return cls( - num_workers=pb2_object.num_workers, - num_launcher_replicas=pb2_object.num_launcher_replicas, - slots=pb2_object.slots, - ) +@dataclass +class Launcher: + """ + Launcher replica configuration. Launcher command can be customized. If not specified, the launcher will use + the command specified in the task signature. + """ + + command: Optional[List[str]] = None + image: Optional[str] = None + requests: Optional[Resources] = None + limits: Optional[Resources] = None + replicas: Optional[int] = None + restart_policy: Optional[RestartPolicy] = None @dataclass @@ -67,18 +94,21 @@ class MPIJob(object): to run distributed training on k8s with MPI Args: - num_workers: integer determining the number of worker replicas spawned in the cluster for this job - (in addition to 1 master). - - num_launcher_replicas: Number of launcher server replicas to use - - slots: Number of slots per worker used in hostfile - + launcher: Configuration for the launcher replica group. + worker: Configuration for the worker replica group. + run_policy: Configuration for the run policy. + slots: The number of slots per worker used in the hostfile. + num_launcher_replicas: [DEPRECATED] The number of launcher server replicas to use. This argument is deprecated. + num_workers: [DEPRECATED] The number of worker replicas to spawn in the cluster for this job """ - slots: int - num_launcher_replicas: int = 1 - num_workers: int = 1 + launcher: Launcher = field(default_factory=lambda: Launcher()) + worker: Worker = field(default_factory=lambda: Worker()) + run_policy: Optional[RunPolicy] = field(default_factory=lambda: None) + slots: int = 1 + # Support v0 config for backwards compatibility + num_launcher_replicas: Optional[int] = None + num_workers: Optional[int] = None class MPIFunctionTask(PythonFunctionTask[MPIJob]): @@ -110,6 +140,22 @@ class MPIFunctionTask(PythonFunctionTask[MPIJob]): ] def __init__(self, task_config: MPIJob, task_function: Callable, **kwargs): + if task_config.num_workers and task_config.worker.replicas: + raise ValueError( + "Cannot specify both `num_workers` and `worker.replicas`. Please use `worker.replicas` as `num_workers` is depreacated." + ) + if task_config.num_workers is None and task_config.worker.replicas is None: + raise ValueError( + "Must specify either `num_workers` or `worker.replicas`. Please use `worker.replicas` as `num_workers` is depreacated." + ) + if task_config.num_launcher_replicas and task_config.launcher.replicas: + raise ValueError( + "Cannot specify both `num_workers` and `launcher.replicas`. Please use `launcher.replicas` as `num_launcher_replicas` is depreacated." + ) + if task_config.num_launcher_replicas is None and task_config.launcher.replicas is None: + raise ValueError( + "Must specify either `num_workers` or `launcher.replicas`. Please use `launcher.replicas` as `num_launcher_replicas` is depreacated." + ) super().__init__( task_config=task_config, task_function=task_function, @@ -117,27 +163,87 @@ def __init__(self, task_config: MPIJob, task_function: Callable, **kwargs): **kwargs, ) + def _convert_replica_spec( + self, replica_config: Union[Launcher, Worker] + ) -> mpi_task.DistributedMPITrainingReplicaSpec: + resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits) + return mpi_task.DistributedMPITrainingReplicaSpec( + command=replica_config.command, + replicas=replica_config.replicas, + image=replica_config.image, + resources=resources.to_flyte_idl() if resources else None, + restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, + ) + + def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy: + return kubeflow_common.RunPolicy( + clean_pod_policy=run_policy.clean_pod_policy.value if run_policy.clean_pod_policy else None, + ttl_seconds_after_finished=run_policy.ttl_seconds_after_finished, + active_deadline_seconds=run_policy.active_deadline_seconds, + backoff_limit=run_policy.active_deadline_seconds, + ) + + def _get_base_command(self, settings: SerializationSettings) -> List[str]: + return super().get_command(settings) + def get_command(self, settings: SerializationSettings) -> List[str]: - cmd = super().get_command(settings) - num_procs = self.task_config.num_workers * self.task_config.slots + cmd = self._get_base_command(settings) + if self.task_config.num_workers: + num_workers = self.task_config.num_workers + else: + num_workers = self.task_config.worker.replicas + num_procs = num_workers * self.task_config.slots mpi_cmd = self._MPI_BASE_COMMAND + ["-np", f"{num_procs}"] + ["python", settings.entrypoint_settings.path] + cmd # the hostfile is set automatically by MPIOperator using env variable OMPI_MCA_orte_default_hostfile return mpi_cmd def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - job = MPIJobModel( - num_workers=self.task_config.num_workers, - num_launcher_replicas=self.task_config.num_launcher_replicas, + worker = self._convert_replica_spec(self.task_config.worker) + if self.task_config.num_workers: + worker.replicas = self.task_config.num_workers + + launcher = self._convert_replica_spec(self.task_config.launcher) + if self.task_config.num_launcher_replicas: + launcher.replicas = self.task_config.num_launcher_replicas + + run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None + mpi_job = mpi_task.DistributedMPITrainingTask( + worker_replicas=worker, + launcher_replicas=launcher, slots=self.task_config.slots, + run_policy=run_policy, ) - return MessageToDict(job.to_flyte_idl()) + return MessageToDict(mpi_job) @dataclass class HorovodJob(object): - slots: int - num_launcher_replicas: int = 1 - num_workers: int = 1 + """ + Configuration for an executable `Horovod Job using MPI operator`_. Use this + to run distributed training on k8s with MPI. For more info, check out Running Horovod`_. + + Args: + worker: Worker configuration for the job. + launcher: Launcher configuration for the job. + run_policy: Configuration for the run policy. + slots: Number of slots per worker used in hostfile (default: 1). + verbose: Optional flag indicating whether to enable verbose logging (default: False). + log_level: Optional string specifying the log level (default: "INFO"). + discovery_script_path: Path to the discovery script used for host discovery (default: "/etc/mpi/discover_hosts.sh"). + num_launcher_replicas: [DEPRECATED] The number of launcher server replicas to use. This argument is deprecated. Please use launcher.replicas instead. + num_workers: [DEPRECATED] The number of worker replicas to spawn in the cluster for this job. Please use worker.replicas instead. + """ + + worker: Worker = field(default_factory=lambda: Worker()) + launcher: Launcher = field(default_factory=lambda: Launcher()) + run_policy: Optional[RunPolicy] = field(default_factory=lambda: None) + slots: int = 1 + verbose: Optional[bool] = False + log_level: Optional[str] = "INFO" + discovery_script_path: Optional[str] = "/etc/mpi/discover_hosts.sh" + # Support v0 config for backwards compatibility + num_launcher_replicas: Optional[int] = None + num_workers: Optional[int] = None class HorovodFunctionTask(MPIFunctionTask): @@ -146,11 +252,8 @@ class HorovodFunctionTask(MPIFunctionTask): """ # Customize your setup here. Please ensure the cmd, path, volume, etc are available in the pod. - ssh_command = "/usr/sbin/sshd -De -f /home/jobuser/.sshd_config" - discovery_script_path = "/etc/mpi/discover_hosts.sh" def __init__(self, task_config: HorovodJob, task_function: Callable, **kwargs): - super().__init__( task_config=task_config, task_function=task_function, @@ -158,23 +261,21 @@ def __init__(self, task_config: HorovodJob, task_function: Callable, **kwargs): ) def get_command(self, settings: SerializationSettings) -> List[str]: - cmd = super().get_command(settings) + cmd = self._get_base_command(settings) mpi_cmd = self._get_horovod_prefix() + cmd return mpi_cmd - def get_config(self, settings: SerializationSettings) -> Dict[str, str]: - config = super().get_config(settings) - return {**config, "worker_spec_command": self.ssh_command} - def _get_horovod_prefix(self) -> List[str]: - np = self.task_config.num_workers * self.task_config.slots + np = self.task_config.worker.replicas * self.task_config.slots + verbose = "--verbose" if self.task_config.verbose is True else "" + log_level = self.task_config.log_level base_cmd = [ "horovodrun", "-np", f"{np}", - "--verbose", + f"{verbose}", "--log-level", - "INFO", + f"{log_level}", "--network-interface", "eth0", "--min-np", @@ -184,7 +285,7 @@ def _get_horovod_prefix(self) -> List[str]: "--slots-per-host", f"{self.task_config.slots}", "--host-discovery-script", - self.discovery_script_path, + self.task_config.discovery_script_path, ] return base_cmd diff --git a/plugins/flytekit-kf-mpi/requirements.txt b/plugins/flytekit-kf-mpi/requirements.txt index f96f963645..2c3e5c48fc 100644 --- a/plugins/flytekit-kf-mpi/requirements.txt +++ b/plugins/flytekit-kf-mpi/requirements.txt @@ -1,27 +1,67 @@ # -# This file is autogenerated by pip-compile with Python 3.11 +# This file is autogenerated by pip-compile with Python 3.8 # by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-kfmpi # via -r requirements.in +adlfs==2023.4.0 + # via flytekit +aiobotocore==2.5.0 + # via s3fs +aiohttp==3.8.4 + # via + # adlfs + # aiobotocore + # gcsfs + # s3fs +aioitertools==0.11.0 + # via aiobotocore +aiosignal==1.3.1 + # via aiohttp arrow==1.2.3 # via jinja2-time +async-timeout==4.0.2 + # via aiohttp +attrs==23.1.0 + # via aiohttp +azure-core==1.26.4 + # via + # adlfs + # azure-identity + # azure-storage-blob +azure-datalake-store==0.0.53 + # via adlfs +azure-identity==1.13.0 + # via adlfs +azure-storage-blob==12.16.0 + # via adlfs binaryornot==0.4.4 # via cookiecutter +botocore==1.29.76 + # via aiobotocore +cachetools==5.3.0 + # via google-auth certifi==2022.12.7 - # via requests + # via + # kubernetes + # requests cffi==1.15.1 - # via cryptography + # via + # azure-datalake-store + # cryptography chardet==5.1.0 # via binaryornot charset-normalizer==3.1.0 - # via requests + # via + # aiohttp + # requests click==8.1.3 # via # cookiecutter # flytekit + # rich-click cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 @@ -29,11 +69,16 @@ cookiecutter==2.1.1 croniter==1.3.8 # via flytekit cryptography==39.0.2 - # via pyopenssl + # via + # azure-identity + # azure-storage-blob + # msal + # pyjwt + # pyopenssl dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 - # via retry + # via gcsfs deprecated==1.2.13 # via flytekit diskcache==5.4.0 @@ -44,20 +89,53 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 # via flytekit -flyteidl==1.3.14 +flyteidl==1.5.5 + # via flytekit +flytekit==1.6.1 + # via flytekitplugins-kfmpi +frozenlist==1.3.3 + # via + # aiohttp + # aiosignal +fsspec==2023.5.0 # via + # adlfs # flytekit - # flytekitplugins-kfmpi -flytekit==1.3.1 - # via flytekitplugins-kfmpi + # gcsfs + # s3fs +gcsfs==2023.5.0 + # via flytekit gitdb==4.0.10 # via gitpython gitpython==3.1.31 # via flytekit +google-api-core==2.11.0 + # via + # google-cloud-core + # google-cloud-storage +google-auth==2.18.0 + # via + # gcsfs + # google-api-core + # google-auth-oauthlib + # google-cloud-core + # google-cloud-storage + # kubernetes +google-auth-oauthlib==1.0.0 + # via gcsfs +google-cloud-core==2.3.2 + # via google-cloud-storage +google-cloud-storage==2.9.0 + # via gcsfs +google-crc32c==1.5.0 + # via google-resumable-media +google-resumable-media==2.5.0 + # via google-cloud-storage googleapis-common-protos==1.59.0 # via # flyteidl # flytekit + # google-api-core # grpcio-status grpcio==1.51.3 # via @@ -66,11 +144,17 @@ grpcio==1.51.3 grpcio-status==1.51.3 # via flytekit idna==3.4 - # via requests + # via + # requests + # yarl importlib-metadata==6.1.0 # via # flytekit # keyring +importlib-resources==5.12.0 + # via keyring +isodate==0.6.1 + # via azure-storage-blob jaraco-classes==3.2.3 # via keyring jinja2==3.1.2 @@ -79,10 +163,16 @@ jinja2==3.1.2 # jinja2-time jinja2-time==0.2.0 # via cookiecutter +jmespath==1.0.1 + # via botocore joblib==1.2.0 # via flytekit keyring==23.13.1 # via flytekit +kubernetes==26.1.0 + # via flytekit +markdown-it-py==2.2.0 + # via rich markupsafe==2.1.2 # via jinja2 marshmallow==3.19.0 @@ -94,8 +184,21 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit +mdurl==0.1.2 + # via markdown-it-py more-itertools==9.1.0 # via jaraco-classes +msal==1.22.0 + # via + # azure-datalake-store + # azure-identity + # msal-extensions +msal-extensions==1.0.0 + # via azure-identity +multidict==6.0.4 + # via + # aiohttp + # yarl mypy-extensions==1.0.0 # via typing-inspect natsort==8.3.1 @@ -105,33 +208,48 @@ numpy==1.23.5 # flytekit # pandas # pyarrow +oauthlib==3.2.2 + # via requests-oauthlib packaging==23.0 # via # docker # marshmallow pandas==1.5.3 # via flytekit +portalocker==2.7.0 + # via msal-extensions protobuf==4.22.1 # via # flyteidl + # google-api-core # googleapis-common-protos # grpcio-status # protoc-gen-swagger protoc-gen-swagger==0.1.0 # via flyteidl -py==1.11.0 - # via retry pyarrow==10.0.1 # via flytekit +pyasn1==0.5.0 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.3.0 + # via google-auth pycparser==2.21 # via cffi +pygments==2.15.1 + # via rich +pyjwt[crypto]==2.7.0 + # via msal pyopenssl==23.0.0 # via flytekit python-dateutil==2.8.2 # via # arrow + # botocore # croniter # flytekit + # kubernetes # pandas python-json-logger==2.0.7 # via flytekit @@ -147,21 +265,48 @@ pyyaml==6.0 # via # cookiecutter # flytekit + # kubernetes # responses regex==2022.10.31 # via docker-image-py requests==2.28.2 # via + # azure-core + # azure-datalake-store # cookiecutter # docker # flytekit + # gcsfs + # google-api-core + # google-cloud-storage + # kubernetes + # msal + # requests-oauthlib # responses +requests-oauthlib==1.3.1 + # via + # google-auth-oauthlib + # kubernetes responses==0.23.1 # via flytekit -retry==0.9.2 +rich==13.3.5 + # via + # flytekit + # rich-click +rich-click==1.6.1 + # via flytekit +rsa==4.9 + # via google-auth +s3fs==2023.5.0 # via flytekit six==1.16.0 - # via python-dateutil + # via + # azure-core + # azure-identity + # google-auth + # isodate + # kubernetes + # python-dateutil smmap==5.0.0 # via gitdb sortedcontainers==2.4.0 @@ -174,23 +319,40 @@ types-pyyaml==6.0.12.8 # via responses typing-extensions==4.5.0 # via + # aioitertools + # azure-core + # azure-storage-blob # flytekit + # rich # typing-inspect typing-inspect==0.8.0 # via dataclasses-json urllib3==1.26.15 # via + # botocore # docker # flytekit + # google-auth + # kubernetes # requests # responses websocket-client==1.5.1 - # via docker + # via + # docker + # kubernetes wheel==0.40.0 # via flytekit wrapt==1.15.0 # via + # aiobotocore # deprecated # flytekit +yarl==1.9.2 + # via aiohttp zipp==3.15.0 - # via importlib-metadata + # via + # importlib-metadata + # importlib-resources + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-kf-mpi/setup.py b/plugins/flytekit-kf-mpi/setup.py index 566506069c..05efff84b0 100644 --- a/plugins/flytekit-kf-mpi/setup.py +++ b/plugins/flytekit-kf-mpi/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "flyteidl>=0.21.4"] +plugin_requires = ["flytekit>=1.6.1,<2.0.0"] __version__ = "0.0.0+develop" @@ -18,7 +18,7 @@ packages=[f"flytekitplugins.{PLUGIN_NAME}"], install_requires=plugin_requires, license="apache2", - python_requires=">=3.6", + python_requires=">=3.8", classifiers=[ "Intended Audience :: Science/Research", "Intended Audience :: Developers", diff --git a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py index 7732d520c2..f6eb2655f6 100644 --- a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py +++ b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py @@ -1,22 +1,25 @@ -from flytekitplugins.kfmpi.task import HorovodJob, MPIJob, MPIJobModel +import pytest +from flytekitplugins.kfmpi import CleanPodPolicy, HorovodJob, Launcher, MPIJob, RestartPolicy, RunPolicy, Worker +from flytekitplugins.kfmpi.task import MPIFunctionTask from flytekit import Resources, task from flytekit.configuration import Image, ImageConfig, SerializationSettings -def test_mpi_model_task(): - job = MPIJobModel( - num_workers=1, - num_launcher_replicas=1, - slots=1, +@pytest.fixture +def serialization_settings() -> SerializationSettings: + default_img = Image(name="default", fqn="test", tag="tag") + settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={"FOO": "baz"}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), ) - assert job.num_workers == 1 - assert job.num_launcher_replicas == 1 - assert job.slots == 1 - assert job.from_flyte_idl(job.to_flyte_idl()) + return settings -def test_mpi_task(): +def test_mpi_task(serialization_settings: SerializationSettings): @task( task_config=MPIJob(num_workers=10, num_launcher_replicas=10, slots=1), requests=Resources(cpu="1"), @@ -30,37 +33,165 @@ def my_mpi_task(x: int, y: str) -> int: assert my_mpi_task.task_config is not None - default_img = Image(name="default", fqn="test", tag="tag") - settings = SerializationSettings( - project="project", - domain="domain", - version="version", - env={"FOO": "baz"}, - image_config=ImageConfig(default_image=default_img, images=[default_img]), + assert my_mpi_task.get_custom(serialization_settings) == { + "launcherReplicas": {"replicas": 10, "resources": {}}, + "workerReplicas": {"replicas": 10, "resources": {}}, + "slots": 1, + } + assert my_mpi_task.task_type == "mpi" + + +def test_mpi_task_with_default_config(serialization_settings: SerializationSettings): + task_config = MPIJob( + worker=Worker(replicas=1), + launcher=Launcher(replicas=1), + ) + + @task( + task_config=task_config, + cache=True, + requests=Resources(cpu="1"), + cache_version="1", + ) + def my_mpi_task(x: int, y: str) -> int: + return x + + assert my_mpi_task(x=10, y="hello") == 10 + + assert my_mpi_task.task_config is not None + assert my_mpi_task.task_type == "mpi" + assert my_mpi_task.resources.limits == Resources() + assert my_mpi_task.resources.requests == Resources(cpu="1") + assert " ".join(my_mpi_task.get_command(serialization_settings)).startswith( + " ".join(MPIFunctionTask._MPI_BASE_COMMAND + ["-np", "1"]) + ) + + expected_dict = { + "launcherReplicas": { + "replicas": 1, + "resources": {}, + }, + "workerReplicas": { + "replicas": 1, + "resources": {}, + }, + "slots": 1, + } + assert my_mpi_task.get_custom(serialization_settings) == expected_dict + + +def test_mpi_task_with_custom_config(serialization_settings: SerializationSettings): + task_config = MPIJob( + launcher=Launcher( + replicas=1, + requests=Resources(cpu="1"), + limits=Resources(cpu="2"), + image="launcher:latest", + ), + worker=Worker( + replicas=5, + requests=Resources(cpu="2", mem="2Gi"), + limits=Resources(cpu="4", mem="2Gi"), + image="worker:latest", + restart_policy=RestartPolicy.NEVER, + ), + run_policy=RunPolicy( + clean_pod_policy=CleanPodPolicy.ALL, + ), + slots=2, ) - assert my_mpi_task.get_custom(settings) == {"numLauncherReplicas": 10, "numWorkers": 10, "slots": 1} + @task( + task_config=task_config, + cache=True, + requests=Resources(cpu="1"), + cache_version="1", + ) + def my_mpi_task(x: int, y: str) -> int: + return x + + assert my_mpi_task(x=10, y="hello") == 10 + + assert my_mpi_task.task_config is not None assert my_mpi_task.task_type == "mpi" + assert my_mpi_task.resources.limits == Resources() + assert my_mpi_task.resources.requests == Resources(cpu="1") + assert " ".join(my_mpi_task.get_command(serialization_settings)).startswith( + " ".join(MPIFunctionTask._MPI_BASE_COMMAND + ["-np", "1"]) + ) + + expected_custom_dict = { + "launcherReplicas": { + "replicas": 1, + "image": "launcher:latest", + "resources": { + "requests": [{"name": "CPU", "value": "1"}], + "limits": [{"name": "CPU", "value": "2"}], + }, + }, + "workerReplicas": { + "replicas": 5, + "image": "worker:latest", + "resources": { + "requests": [ + {"name": "CPU", "value": "2"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + "limits": [ + {"name": "CPU", "value": "4"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + }, + }, + "slots": 2, + "runPolicy": {"cleanPodPolicy": "CLEANPOD_POLICY_ALL"}, + } + assert my_mpi_task.get_custom(serialization_settings) == expected_custom_dict -def test_horovod_task(): +def test_horovod_task(serialization_settings): @task( - task_config=HorovodJob(num_workers=5, num_launcher_replicas=5, slots=1), + task_config=HorovodJob( + launcher=Launcher( + replicas=1, + requests=Resources(cpu="1"), + limits=Resources(cpu="2"), + ), + worker=Worker( + replicas=1, + command=["/usr/sbin/sshd", "-De", "-f", "/home/jobuser/.sshd_config"], + restart_policy=RestartPolicy.NEVER, + ), + slots=2, + verbose=False, + log_level="INFO", + ), ) def my_horovod_task(): ... - default_img = Image(name="default", fqn="test", tag="tag") - settings = SerializationSettings( - project="project", - domain="domain", - version="version", - env={"FOO": "baz"}, - image_config=ImageConfig(default_image=default_img, images=[default_img]), - ) - cmd = my_horovod_task.get_command(settings) + cmd = my_horovod_task.get_command(serialization_settings) assert "horovodrun" in cmd - config = my_horovod_task.get_config(settings) - assert "/usr/sbin/sshd" in config["worker_spec_command"] - custom = my_horovod_task.get_custom(settings) - assert isinstance(custom, dict) is True + assert "--verbose" not in cmd + assert "--log-level" in cmd + assert "INFO" in cmd + expected_dict = { + "launcherReplicas": { + "replicas": 1, + "resources": { + "requests": [ + {"name": "CPU", "value": "1"}, + ], + "limits": [ + {"name": "CPU", "value": "2"}, + ], + }, + }, + "workerReplicas": { + "replicas": 1, + "resources": {}, + "command": ["/usr/sbin/sshd", "-De", "-f", "/home/jobuser/.sshd_config"], + }, + "slots": 2, + } + assert my_horovod_task.get_custom(serialization_settings) == expected_dict diff --git a/plugins/flytekit-kf-pytorch/README.md b/plugins/flytekit-kf-pytorch/README.md index 7de27502bf..c1516d3248 100644 --- a/plugins/flytekit-kf-pytorch/README.md +++ b/plugins/flytekit-kf-pytorch/README.md @@ -14,3 +14,42 @@ pip install flytekitplugins-kfpytorch To set up PyTorch operator in the Flyte deployment's backend, follow the [PyTorch Operator Setup](https://docs.flyte.org/en/latest/deployment/plugin_setup/pytorch_operator.html) guide. An [example](https://docs.flyte.org/projects/cookbook/en/latest/auto/integrations/kubernetes/kfpytorch/pytorch_mnist.html#sphx-glr-auto-integrations-kubernetes-kfpytorch-pytorch-mnist-py) showcasing PyTorch operator can be found in the documentation. + +## Code Example +```python +from flytekitplugins.kfpytorch import PyTorch, Worker, Master, RestartPolicy, RunPolicy, CleanPodPolicy + +@task( + task_config = PyTorch( + worker=Worker( + replicas=5, + requests=Resources(cpu="2", mem="2Gi"), + limits=Resources(cpu="4", mem="2Gi"), + image="worker:latest", + restart_policy=RestartPolicy.FAILURE, + ), + master=Master( + restart_policy=RestartPolicy.ALWAYS, + ), + ) + image="test_image", + resources=Resources(cpu="1", mem="1Gi"), +) +def pytorch_job(): + ... +``` + + +## Upgrade Pytorch Plugin from V0 to V1 +Pytorch plugin is now updated from v0 to v1 to enable more configuration options. +To migrate from v0 to v1, change the following: +1. Update flytepropeller to v1.6.0 +2. Update flytekit version to v1.6.2 +3. Update your code from: + ``` + task_config=Pytorch(num_workers=10), + ``` + to: + ``` + task_config=PyTorch(worker=Worker(replicas=10)), + ``` diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py index cb9add7302..d56e1f83d9 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py @@ -11,4 +11,4 @@ Elastic """ -from .task import Elastic, PyTorch +from .task import CleanPodPolicy, Elastic, Master, PyTorch, RestartPolicy, RunPolicy, Worker diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index aea2c9a2e6..6625263db1 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -3,34 +3,104 @@ Kubernetes. It leverages `Pytorch Job `_ Plugin from kubeflow. """ import os -from dataclasses import dataclass +from dataclasses import dataclass, field +from enum import Enum from typing import Any, Callable, Dict, Optional, Union import cloudpickle -from flyteidl.plugins.pytorch_pb2 import DistributedPyTorchTrainingTask +from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common +from flyteidl.plugins.kubeflow import pytorch_pb2 as pytorch_task from google.protobuf.json_format import MessageToDict import flytekit -from flytekit import PythonFunctionTask +from flytekit import PythonFunctionTask, Resources from flytekit.configuration import SerializationSettings +from flytekit.core.resources import convert_resources_to_resource_model from flytekit.extend import IgnoreOutputs, TaskPlugins TORCH_IMPORT_ERROR_MESSAGE = "PyTorch is not installed. Please install `flytekitplugins-kfpytorch['elastic']`." @dataclass -class PyTorch(object): +class RestartPolicy(Enum): + """ + RestartPolicy describes how the replicas should be restarted """ - Configuration for an executable `Pytorch Job `_. Use this - to run distributed pytorch training on k8s + ALWAYS = kubeflow_common.RESTART_POLICY_ALWAYS + FAILURE = kubeflow_common.RESTART_POLICY_ON_FAILURE + NEVER = kubeflow_common.RESTART_POLICY_NEVER + + +@dataclass +class CleanPodPolicy(Enum): + """ + CleanPodPolicy describes how to deal with pods when the job is finished. + """ + + NONE = kubeflow_common.CLEANPOD_POLICY_NONE + ALL = kubeflow_common.CLEANPOD_POLICY_ALL + RUNNING = kubeflow_common.CLEANPOD_POLICY_RUNNING + + +@dataclass +class RunPolicy: + """ + RunPolicy describes some policy to apply to the execution of a kubeflow job. Args: - num_workers: integer determining the number of worker replicas spawned in the cluster for this job - (in addition to 1 master). + clean_pod_policy (int): Defines the policy for cleaning up pods after the PyTorchJob completes. Default to None. + ttl_seconds_after_finished (int): Defines the TTL for cleaning up finished PyTorchJobs. + active_deadline_seconds (int): Specifies the duration (in seconds) since startTime during which the job. + can remain active before it is terminated. Must be a positive integer. This setting applies only to pods. + where restartPolicy is OnFailure or Always. + backoff_limit (int): Number of retries before marking this job as failed. + """ + + clean_pod_policy: CleanPodPolicy = None + ttl_seconds_after_finished: Optional[int] = None + active_deadline_seconds: Optional[int] = None + backoff_limit: Optional[int] = None + + +@dataclass +class Worker: + image: Optional[str] = None + requests: Optional[Resources] = None + limits: Optional[Resources] = None + replicas: Optional[int] = None + restart_policy: Optional[RestartPolicy] = None + +@dataclass +class Master: + """ + Configuration for master replica group. Master should always have 1 replica, so we don't need a `replicas` field + """ + + image: Optional[str] = None + requests: Optional[Resources] = None + limits: Optional[Resources] = None + restart_policy: Optional[RestartPolicy] = None + + +@dataclass +class PyTorch(object): + """ + Configuration for an executable `PyTorch Job `_. Use this + to run distributed PyTorch training on Kubernetes. + + Args: + master: Configuration for the master replica group. + worker: Configuration for the worker replica group. + run_policy: Configuration for the run policy. + num_workers: [DEPRECATED] This argument is deprecated. Use `worker.replicas` instead. """ - num_workers: int + master: Master = field(default_factory=lambda: Master()) + worker: Worker = field(default_factory=lambda: Worker()) + run_policy: Optional[RunPolicy] = field(default_factory=lambda: None) + # Support v0 config for backwards compatibility + num_workers: Optional[int] = None @dataclass @@ -67,6 +137,14 @@ class PyTorchFunctionTask(PythonFunctionTask[PyTorch]): _PYTORCH_TASK_TYPE = "pytorch" def __init__(self, task_config: PyTorch, task_function: Callable, **kwargs): + if task_config.num_workers and task_config.worker.replicas: + raise ValueError( + "Cannot specify both `num_workers` and `worker.replicas`. Please use `worker.replicas` as `num_workers` is depreacated." + ) + if task_config.num_workers is None and task_config.worker.replicas is None: + raise ValueError( + "Must specify either `num_workers` or `worker.replicas`. Please use `worker.replicas` as `num_workers` is depreacated." + ) super().__init__( task_config, task_function, @@ -74,9 +152,42 @@ def __init__(self, task_config: PyTorch, task_function: Callable, **kwargs): **kwargs, ) + def _convert_replica_spec( + self, replica_config: Union[Master, Worker] + ) -> pytorch_task.DistributedPyTorchTrainingReplicaSpec: + resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits) + replicas = 1 + # Master should always have 1 replica + if not isinstance(replica_config, Master): + replicas = replica_config.replicas + return pytorch_task.DistributedPyTorchTrainingReplicaSpec( + replicas=replicas, + image=replica_config.image, + resources=resources.to_flyte_idl() if resources else None, + restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, + ) + + def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy: + return kubeflow_common.RunPolicy( + clean_pod_policy=run_policy.clean_pod_policy.value if run_policy.clean_pod_policy else None, + ttl_seconds_after_finished=run_policy.ttl_seconds_after_finished, + active_deadline_seconds=run_policy.active_deadline_seconds, + backoff_limit=run_policy.active_deadline_seconds, + ) + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - job = DistributedPyTorchTrainingTask(workers=self.task_config.num_workers) - return MessageToDict(job) + worker = self._convert_replica_spec(self.task_config.worker) + # support v0 config for backwards compatibility + if self.task_config.num_workers: + worker.replicas = self.task_config.num_workers + + run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None + pytorch_job = pytorch_task.DistributedPyTorchTrainingTask( + worker_replicas=worker, + master_replicas=self._convert_replica_spec(self.task_config.master), + run_policy=run_policy, + ) + return MessageToDict(pytorch_job) # Register the Pytorch Plugin into the flytekit core plugin system @@ -236,8 +347,10 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any] nproc_per_node=self.task_config.nproc_per_node, max_restarts=self.task_config.max_restarts, ) - job = DistributedPyTorchTrainingTask( - workers=self.max_nodes, + job = pytorch_task.DistributedPyTorchTrainingTask( + worker_replicas=pytorch_task.DistributedPyTorchTrainingReplicaSpec( + replicas=self.max_nodes, + ), elastic_config=elastic_config, ) return MessageToDict(job) diff --git a/plugins/flytekit-kf-pytorch/requirements.txt b/plugins/flytekit-kf-pytorch/requirements.txt index 46973017ac..85d7ed0ab3 100644 --- a/plugins/flytekit-kf-pytorch/requirements.txt +++ b/plugins/flytekit-kf-pytorch/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.9 +# This file is autogenerated by pip-compile with Python 3.8 # by the following command: # # pip-compile requirements.in @@ -63,6 +63,7 @@ click==8.1.3 # via # cookiecutter # flytekit + # rich-click cloudpickle==2.2.1 # via # flytekit @@ -93,11 +94,9 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 # via flytekit -flyteidl==1.3.19 - # via - # flytekit - # flytekitplugins-kfpytorch -flytekit==1.5.0 +flyteidl==1.5.5 + # via flytekit +flytekit==1.6.1 # via flytekitplugins-kfpytorch frozenlist==1.3.3 # via @@ -157,6 +156,8 @@ importlib-metadata==6.6.0 # via # flytekit # keyring +importlib-resources==5.12.0 + # via keyring isodate==0.6.1 # via azure-storage-blob jaraco-classes==3.2.3 @@ -175,6 +176,8 @@ keyring==23.13.1 # via flytekit kubernetes==26.1.0 # via flytekit +markdown-it-py==2.2.0 + # via rich markupsafe==2.1.2 # via jinja2 marshmallow==3.19.0 @@ -186,6 +189,8 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit +mdurl==0.1.2 + # via markdown-it-py more-itertools==9.1.0 # via jaraco-classes msal==1.22.0 @@ -236,6 +241,8 @@ pyasn1-modules==0.3.0 # via google-auth pycparser==2.21 # via cffi +pygments==2.15.1 + # via rich pyjwt[crypto]==2.6.0 # via # adal @@ -290,6 +297,12 @@ requests-oauthlib==1.3.1 # kubernetes responses==0.23.1 # via flytekit +rich==13.3.5 + # via + # flytekit + # rich-click +rich-click==1.6.1 + # via flytekit rsa==4.9 # via google-auth s3fs==2023.4.0 @@ -318,6 +331,7 @@ typing-extensions==4.5.0 # azure-core # azure-storage-blob # flytekit + # rich # typing-inspect typing-inspect==0.8.0 # via dataclasses-json @@ -343,7 +357,9 @@ wrapt==1.15.0 yarl==1.9.2 # via aiohttp zipp==3.15.0 - # via importlib-metadata + # via + # importlib-metadata + # importlib-resources # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/plugins/flytekit-kf-pytorch/setup.py b/plugins/flytekit-kf-pytorch/setup.py index 23543ac7bc..048524dbfc 100644 --- a/plugins/flytekit-kf-pytorch/setup.py +++ b/plugins/flytekit-kf-pytorch/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["cloudpickle", "flytekit>=1.3.0,<2.0.0", "flyteidl>=1.3.19"] +plugin_requires = ["cloudpickle", "flytekit>=1.6.1"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py b/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py index 00eb6c0953..ecdf9e375c 100644 --- a/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py +++ b/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py @@ -1,10 +1,24 @@ -from flytekitplugins.kfpytorch.task import PyTorch +import pytest +from flytekitplugins.kfpytorch.task import Master, PyTorch, RestartPolicy, Worker from flytekit import Resources, task from flytekit.configuration import Image, ImageConfig, SerializationSettings -def test_pytorch_task(): +@pytest.fixture +def serialization_settings() -> SerializationSettings: + default_img = Image(name="default", fqn="test", tag="tag") + settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={"FOO": "baz"}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + return settings + + +def test_pytorch_task(serialization_settings: SerializationSettings): @task( task_config=PyTorch(num_workers=10), cache=True, @@ -18,16 +32,97 @@ def my_pytorch_task(x: int, y: str) -> int: assert my_pytorch_task.task_config is not None - default_img = Image(name="default", fqn="test", tag="tag") - settings = SerializationSettings( - project="project", - domain="domain", - version="version", - env={"FOO": "baz"}, - image_config=ImageConfig(default_image=default_img, images=[default_img]), + assert my_pytorch_task.get_custom(serialization_settings) == { + "workerReplicas": {"replicas": 10, "resources": {}}, + "masterReplicas": {"replicas": 1, "resources": {}}, + } + assert my_pytorch_task.resources.limits == Resources() + assert my_pytorch_task.resources.requests == Resources(cpu="1") + assert my_pytorch_task.task_type == "pytorch" + + +def test_pytorch_task_with_default_config(serialization_settings: SerializationSettings): + task_config = PyTorch(worker=Worker(replicas=1)) + + @task( + task_config=task_config, + cache=True, + requests=Resources(cpu="1"), + cache_version="1", ) + def my_pytorch_task(x: int, y: str) -> int: + return x - assert my_pytorch_task.get_custom(settings) == {"workers": 10} + assert my_pytorch_task(x=10, y="hello") == 10 + + assert my_pytorch_task.task_config is not None + assert my_pytorch_task.task_type == "pytorch" assert my_pytorch_task.resources.limits == Resources() assert my_pytorch_task.resources.requests == Resources(cpu="1") + + expected_dict = { + "masterReplicas": { + "replicas": 1, + "resources": {}, + }, + "workerReplicas": { + "replicas": 1, + "resources": {}, + }, + } + assert my_pytorch_task.get_custom(serialization_settings) == expected_dict + + +def test_pytorch_task_with_custom_config(serialization_settings: SerializationSettings): + task_config = PyTorch( + worker=Worker( + replicas=5, + requests=Resources(cpu="2", mem="2Gi"), + limits=Resources(cpu="4", mem="2Gi"), + image="worker:latest", + restart_policy=RestartPolicy.FAILURE, + ), + master=Master( + restart_policy=RestartPolicy.ALWAYS, + ), + ) + + @task( + task_config=task_config, + cache=True, + requests=Resources(cpu="1"), + cache_version="1", + ) + def my_pytorch_task(x: int, y: str) -> int: + return x + + assert my_pytorch_task(x=10, y="hello") == 10 + + assert my_pytorch_task.task_config is not None assert my_pytorch_task.task_type == "pytorch" + assert my_pytorch_task.resources.limits == Resources() + assert my_pytorch_task.resources.requests == Resources(cpu="1") + + expected_custom_dict = { + "workerReplicas": { + "replicas": 5, + "image": "worker:latest", + "resources": { + "requests": [ + {"name": "CPU", "value": "2"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + "limits": [ + {"name": "CPU", "value": "4"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + }, + "restartPolicy": "RESTART_POLICY_ON_FAILURE", + }, + "masterReplicas": { + "resources": {}, + "replicas": 1, + "restartPolicy": "RESTART_POLICY_ALWAYS", + }, + } + assert my_pytorch_task.get_custom(serialization_settings) == expected_custom_dict diff --git a/plugins/flytekit-kf-tensorflow/README.md b/plugins/flytekit-kf-tensorflow/README.md index 9e4c26fa70..d059624f03 100644 --- a/plugins/flytekit-kf-tensorflow/README.md +++ b/plugins/flytekit-kf-tensorflow/README.md @@ -8,4 +8,47 @@ To install the plugin, run the following command: pip install flytekitplugins-kftensorflow ``` -_Example coming soon!_ +## Code Example +To build a TFJob with: +10 workers with restart policy as failed and 2 CPU and 2Gi Memory +1 ps replica with resources the same as task defined resources +1 chief replica with resources the same as task defined resources and restart policy as always +run policy as clean up all pods after job is finished. + +You code: +```python +from flytekitplugins.kftensorflow import PS, Chief, CleanPodPolicy, RestartPolicy, RunPolicy, TfJob, Worker + +@task( + task_config=TfJob( + worker=Worker( + replicas=5, + requests=Resources(cpu="2", mem="2Gi"), + limits=Resources(cpu="2", mem="2Gi"), + restart_policy=RestartPolicy.FAILURE, + ), + ps=PS(replicas=1), + chief=Chief(replicas=1, restart_policy=RestartPolicy.ALWAYS), + run_policy=RunPolicy(clean_pod_policy=CleanPodPolicy.RUNNING), + ), + image="test_image", + resources=Resources(cpu="1", mem="1Gi"), +) +def tf_job(): + ... +``` + + +## Upgrade TensorFlow Plugin from V0 to V1 +Tensorflow plugin is now updated from v0 to v1 to enable more configuration options. +To migrate from v0 to v1, change the following: +1. Update flytepropeller to v1.6.0 +2. Update flytekit version to v1.6.2 +3. Update your code from: + ``` + task_config=TfJob(num_workers=10, num_ps_replicas=1, num_chief_replicas=1), + ``` + to: + ``` + task_config=TfJob(worker=Worker(replicas=10), ps=PS(replicas=1), chief=Chief(replicas=1)), + ``` diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/__init__.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/__init__.py index 02dec6cc7d..81a4cbc248 100644 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/__init__.py +++ b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/__init__.py @@ -10,4 +10,4 @@ TfJob """ -from .task import TfJob +from .task import PS, Chief, CleanPodPolicy, RestartPolicy, RunPolicy, TfJob, Worker diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models.py deleted file mode 100644 index 87d7bb7b90..0000000000 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/models.py +++ /dev/null @@ -1,35 +0,0 @@ -from flyteidl.plugins import tensorflow_pb2 as _tensorflow_task - -from flytekit.models import common as _common - - -class TensorFlowJob(_common.FlyteIdlEntity): - def __init__(self, workers_count, ps_replicas_count, chief_replicas_count): - self._workers_count = workers_count - self._ps_replicas_count = ps_replicas_count - self._chief_replicas_count = chief_replicas_count - - @property - def workers_count(self): - return self._workers_count - - @property - def ps_replicas_count(self): - return self._ps_replicas_count - - @property - def chief_replicas_count(self): - return self._chief_replicas_count - - def to_flyte_idl(self): - return _tensorflow_task.DistributedTensorflowTrainingTask( - workers=self.workers_count, ps_replicas=self.ps_replicas_count, chief_replicas=self.chief_replicas_count - ) - - @classmethod - def from_flyte_idl(cls, pb2_object): - return cls( - workers_count=pb2_object.workers, - ps_replicas_count=pb2_object.ps_replicas, - chief_replicas_count=pb2_object.chief_replicas, - ) diff --git a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py index 03855e3095..bd6a97a293 100644 --- a/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py +++ b/plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py @@ -2,37 +2,113 @@ This Plugin adds the capability of running distributed tensorflow training to Flyte using backend plugins, natively on Kubernetes. It leverages `TF Job `_ Plugin from kubeflow. """ -from dataclasses import dataclass -from typing import Any, Callable, Dict +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable, Dict, Optional, Union +from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common +from flyteidl.plugins.kubeflow import tensorflow_pb2 as tensorflow_task from google.protobuf.json_format import MessageToDict -from flytekit import PythonFunctionTask +from flytekit import PythonFunctionTask, Resources from flytekit.configuration import SerializationSettings +from flytekit.core.resources import convert_resources_to_resource_model from flytekit.extend import TaskPlugins -from .models import TensorFlowJob + +@dataclass +class RestartPolicy(Enum): + """ + RestartPolicy describes how the replicas should be restarted + """ + + ALWAYS = kubeflow_common.RESTART_POLICY_ALWAYS + FAILURE = kubeflow_common.RESTART_POLICY_ON_FAILURE + NEVER = kubeflow_common.RESTART_POLICY_NEVER @dataclass -class TfJob(object): +class CleanPodPolicy(Enum): + """ + CleanPodPolicy describes how to deal with pods when the job is finished. """ - Configuration for an executable `TF Job `_. Use this - to run distributed tensorflow training on k8s (with parameter server) + + NONE = kubeflow_common.CLEANPOD_POLICY_NONE + ALL = kubeflow_common.CLEANPOD_POLICY_ALL + RUNNING = kubeflow_common.CLEANPOD_POLICY_RUNNING + + +@dataclass +class RunPolicy: + """ + RunPolicy describes a set of policies to apply to the execution of a Kubeflow job. Args: - num_workers: integer determining the number of worker replicas spawned in the cluster for this job - (in addition to 1 master). + clean_pod_policy: The policy for cleaning up pods after the job completes. Defaults to None. + ttl_seconds_after_finished: The time-to-live (TTL) in seconds for cleaning up finished jobs. + active_deadline_seconds: The duration (in seconds) since startTime during which the job can remain + active before it is terminated. Must be a positive integer. This setting applies only to pods + where restartPolicy is OnFailure or Always. + backoff_limit: The number of retries before marking this job as failed. + """ + + clean_pod_policy: CleanPodPolicy = None + ttl_seconds_after_finished: Optional[int] = None + active_deadline_seconds: Optional[int] = None + backoff_limit: Optional[int] = None + + +@dataclass +class Chief: + image: Optional[str] = None + requests: Optional[Resources] = None + limits: Optional[Resources] = None + replicas: Optional[int] = None + restart_policy: Optional[RestartPolicy] = None + + +@dataclass +class PS: + image: Optional[str] = None + requests: Optional[Resources] = None + limits: Optional[Resources] = None + replicas: Optional[int] = None + restart_policy: Optional[RestartPolicy] = None - num_ps_replicas: Number of Parameter server replicas to use - num_chief_replicas: Number of chief replicas to use +@dataclass +class Worker: + image: Optional[str] = None + requests: Optional[Resources] = None + limits: Optional[Resources] = None + replicas: Optional[int] = None + restart_policy: Optional[RestartPolicy] = None + +@dataclass +class TfJob: """ + Configuration for an executable `TensorFlow Job `_. Use this + to run distributed TensorFlow training on Kubernetes. - num_workers: int - num_ps_replicas: int - num_chief_replicas: int + Args: + chief: Configuration for the chief replica group. + ps: Configuration for the parameter server (PS) replica group. + worker: Configuration for the worker replica group. + run_policy: Configuration for the run policy. + num_workers: [DEPRECATED] This argument is deprecated. Use `worker.replicas` instead. + num_ps_replicas: [DEPRECATED] This argument is deprecated. Use `ps.replicas` instead. + num_chief_replicas: [DEPRECATED] This argument is deprecated. Use `chief.replicas` instead. + """ + + chief: Chief = field(default_factory=lambda: Chief()) + ps: PS = field(default_factory=lambda: PS()) + worker: Worker = field(default_factory=lambda: Worker()) + run_policy: Optional[RunPolicy] = field(default_factory=lambda: None) + # Support v0 config for backwards compatibility + num_workers: Optional[int] = None + num_ps_replicas: Optional[int] = None + num_chief_replicas: Optional[int] = None class TensorflowFunctionTask(PythonFunctionTask[TfJob]): @@ -44,20 +120,79 @@ class TensorflowFunctionTask(PythonFunctionTask[TfJob]): _TF_JOB_TASK_TYPE = "tensorflow" def __init__(self, task_config: TfJob, task_function: Callable, **kwargs): + if task_config.num_workers and task_config.worker.replicas: + raise ValueError( + "Cannot specify both `num_workers` and `worker.replicas`. Please use `worker.replicas` as `num_workers` is depreacated." + ) + if task_config.num_workers is None and task_config.worker.replicas is None: + raise ValueError( + "Must specify either `num_workers` or `worker.replicas`. Please use `worker.replicas` as `num_workers` is depreacated." + ) + if task_config.num_chief_replicas and task_config.chief.replicas: + raise ValueError( + "Cannot specify both `num_workers` and `chief.replicas`. Please use `chief.replicas` as `num_chief_replicas` is depreacated." + ) + if task_config.num_chief_replicas is None and task_config.chief.replicas is None: + raise ValueError( + "Must specify either `num_workers` or `chief.replicas`. Please use `chief.replicas` as `num_chief_replicas` is depreacated." + ) + if task_config.num_ps_replicas and task_config.ps.replicas: + raise ValueError( + "Cannot specify both `num_workers` and `ps.replicas`. Please use `ps.replicas` as `num_ps_replicas` is depreacated." + ) + if task_config.num_ps_replicas is None and task_config.ps.replicas is None: + raise ValueError( + "Must specify either `num_workers` or `ps.replicas`. Please use `ps.replicas` as `num_ps_replicas` is depreacated." + ) super().__init__( task_type=self._TF_JOB_TASK_TYPE, task_config=task_config, task_function=task_function, + task_type_version=1, **kwargs, ) + def _convert_replica_spec( + self, replica_config: Union[Chief, PS, Worker] + ) -> tensorflow_task.DistributedTensorflowTrainingReplicaSpec: + resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits) + return tensorflow_task.DistributedTensorflowTrainingReplicaSpec( + replicas=replica_config.replicas, + image=replica_config.image, + resources=resources.to_flyte_idl() if resources else None, + restart_policy=replica_config.restart_policy.value if replica_config.restart_policy else None, + ) + + def _convert_run_policy(self, run_policy: RunPolicy) -> kubeflow_common.RunPolicy: + return kubeflow_common.RunPolicy( + clean_pod_policy=run_policy.clean_pod_policy.value if run_policy.clean_pod_policy.value else None, + ttl_seconds_after_finished=run_policy.ttl_seconds_after_finished, + active_deadline_seconds=run_policy.active_deadline_seconds, + backoff_limit=run_policy.active_deadline_seconds, + ) + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - job = TensorFlowJob( - workers_count=self.task_config.num_workers, - ps_replicas_count=self.task_config.num_ps_replicas, - chief_replicas_count=self.task_config.num_chief_replicas, + chief = self._convert_replica_spec(self.task_config.chief) + if self.task_config.num_chief_replicas: + chief.replicas = self.task_config.num_chief_replicas + + worker = self._convert_replica_spec(self.task_config.worker) + if self.task_config.num_workers: + worker.replicas = self.task_config.num_workers + + ps = self._convert_replica_spec(self.task_config.ps) + if self.task_config.num_ps_replicas: + ps.replicas = self.task_config.num_ps_replicas + + run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None + training_task = tensorflow_task.DistributedTensorflowTrainingTask( + chief_replicas=chief, + worker_replicas=worker, + ps_replicas=ps, + run_policy=run_policy, ) - return MessageToDict(job.to_flyte_idl()) + + return MessageToDict(training_task) # Register the Tensorflow Plugin into the flytekit core plugin system diff --git a/plugins/flytekit-kf-tensorflow/requirements.txt b/plugins/flytekit-kf-tensorflow/requirements.txt index 064ffd73d5..8f67a26831 100644 --- a/plugins/flytekit-kf-tensorflow/requirements.txt +++ b/plugins/flytekit-kf-tensorflow/requirements.txt @@ -1,27 +1,67 @@ # -# This file is autogenerated by pip-compile with Python 3.11 +# This file is autogenerated by pip-compile with Python 3.8 # by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-kftensorflow # via -r requirements.in +adlfs==2023.4.0 + # via flytekit +aiobotocore==2.5.0 + # via s3fs +aiohttp==3.8.4 + # via + # adlfs + # aiobotocore + # gcsfs + # s3fs +aioitertools==0.11.0 + # via aiobotocore +aiosignal==1.3.1 + # via aiohttp arrow==1.2.3 # via jinja2-time +async-timeout==4.0.2 + # via aiohttp +attrs==23.1.0 + # via aiohttp +azure-core==1.26.4 + # via + # adlfs + # azure-identity + # azure-storage-blob +azure-datalake-store==0.0.53 + # via adlfs +azure-identity==1.13.0 + # via adlfs +azure-storage-blob==12.16.0 + # via adlfs binaryornot==0.4.4 # via cookiecutter +botocore==1.29.76 + # via aiobotocore +cachetools==5.3.0 + # via google-auth certifi==2022.12.7 - # via requests + # via + # kubernetes + # requests cffi==1.15.1 - # via cryptography + # via + # azure-datalake-store + # cryptography chardet==5.1.0 # via binaryornot charset-normalizer==3.1.0 - # via requests + # via + # aiohttp + # requests click==8.1.3 # via # cookiecutter # flytekit + # rich-click cloudpickle==2.2.1 # via flytekit cookiecutter==2.1.1 @@ -29,11 +69,16 @@ cookiecutter==2.1.1 croniter==1.3.8 # via flytekit cryptography==39.0.2 - # via pyopenssl + # via + # azure-identity + # azure-storage-blob + # msal + # pyjwt + # pyopenssl dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 - # via retry + # via gcsfs deprecated==1.2.13 # via flytekit diskcache==5.4.0 @@ -44,18 +89,53 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 # via flytekit -flyteidl==1.3.14 +flyteidl==1.5.5 # via flytekit -flytekit==1.3.1 +flytekit==1.6.1 # via flytekitplugins-kftensorflow +frozenlist==1.3.3 + # via + # aiohttp + # aiosignal +fsspec==2023.5.0 + # via + # adlfs + # flytekit + # gcsfs + # s3fs +gcsfs==2023.5.0 + # via flytekit gitdb==4.0.10 # via gitpython gitpython==3.1.31 # via flytekit +google-api-core==2.11.0 + # via + # google-cloud-core + # google-cloud-storage +google-auth==2.18.0 + # via + # gcsfs + # google-api-core + # google-auth-oauthlib + # google-cloud-core + # google-cloud-storage + # kubernetes +google-auth-oauthlib==1.0.0 + # via gcsfs +google-cloud-core==2.3.2 + # via google-cloud-storage +google-cloud-storage==2.9.0 + # via gcsfs +google-crc32c==1.5.0 + # via google-resumable-media +google-resumable-media==2.5.0 + # via google-cloud-storage googleapis-common-protos==1.59.0 # via # flyteidl # flytekit + # google-api-core # grpcio-status grpcio==1.51.3 # via @@ -64,11 +144,17 @@ grpcio==1.51.3 grpcio-status==1.51.3 # via flytekit idna==3.4 - # via requests + # via + # requests + # yarl importlib-metadata==6.1.0 # via # flytekit # keyring +importlib-resources==5.12.0 + # via keyring +isodate==0.6.1 + # via azure-storage-blob jaraco-classes==3.2.3 # via keyring jinja2==3.1.2 @@ -77,10 +163,16 @@ jinja2==3.1.2 # jinja2-time jinja2-time==0.2.0 # via cookiecutter +jmespath==1.0.1 + # via botocore joblib==1.2.0 # via flytekit keyring==23.13.1 # via flytekit +kubernetes==26.1.0 + # via flytekit +markdown-it-py==2.2.0 + # via rich markupsafe==2.1.2 # via jinja2 marshmallow==3.19.0 @@ -92,8 +184,21 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit +mdurl==0.1.2 + # via markdown-it-py more-itertools==9.1.0 # via jaraco-classes +msal==1.22.0 + # via + # azure-datalake-store + # azure-identity + # msal-extensions +msal-extensions==1.0.0 + # via azure-identity +multidict==6.0.4 + # via + # aiohttp + # yarl mypy-extensions==1.0.0 # via typing-inspect natsort==8.3.1 @@ -103,33 +208,48 @@ numpy==1.23.5 # flytekit # pandas # pyarrow +oauthlib==3.2.2 + # via requests-oauthlib packaging==23.0 # via # docker # marshmallow pandas==1.5.3 # via flytekit +portalocker==2.7.0 + # via msal-extensions protobuf==4.22.1 # via # flyteidl + # google-api-core # googleapis-common-protos # grpcio-status # protoc-gen-swagger protoc-gen-swagger==0.1.0 # via flyteidl -py==1.11.0 - # via retry pyarrow==10.0.1 # via flytekit +pyasn1==0.5.0 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.3.0 + # via google-auth pycparser==2.21 # via cffi +pygments==2.15.1 + # via rich +pyjwt[crypto]==2.7.0 + # via msal pyopenssl==23.0.0 # via flytekit python-dateutil==2.8.2 # via # arrow + # botocore # croniter # flytekit + # kubernetes # pandas python-json-logger==2.0.7 # via flytekit @@ -145,21 +265,48 @@ pyyaml==6.0 # via # cookiecutter # flytekit + # kubernetes # responses regex==2022.10.31 # via docker-image-py requests==2.28.2 # via + # azure-core + # azure-datalake-store # cookiecutter # docker # flytekit + # gcsfs + # google-api-core + # google-cloud-storage + # kubernetes + # msal + # requests-oauthlib # responses +requests-oauthlib==1.3.1 + # via + # google-auth-oauthlib + # kubernetes responses==0.23.1 # via flytekit -retry==0.9.2 +rich==13.3.5 + # via + # flytekit + # rich-click +rich-click==1.6.1 + # via flytekit +rsa==4.9 + # via google-auth +s3fs==2023.5.0 # via flytekit six==1.16.0 - # via python-dateutil + # via + # azure-core + # azure-identity + # google-auth + # isodate + # kubernetes + # python-dateutil smmap==5.0.0 # via gitdb sortedcontainers==2.4.0 @@ -172,23 +319,40 @@ types-pyyaml==6.0.12.8 # via responses typing-extensions==4.5.0 # via + # aioitertools + # azure-core + # azure-storage-blob # flytekit + # rich # typing-inspect typing-inspect==0.8.0 # via dataclasses-json urllib3==1.26.15 # via + # botocore # docker # flytekit + # google-auth + # kubernetes # requests # responses websocket-client==1.5.1 - # via docker + # via + # docker + # kubernetes wheel==0.40.0 # via flytekit wrapt==1.15.0 # via + # aiobotocore # deprecated # flytekit +yarl==1.9.2 + # via aiohttp zipp==3.15.0 - # via importlib-metadata + # via + # importlib-metadata + # importlib-resources + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-kf-tensorflow/setup.py b/plugins/flytekit-kf-tensorflow/setup.py index 4614b90497..79c1ade31d 100644 --- a/plugins/flytekit-kf-tensorflow/setup.py +++ b/plugins/flytekit-kf-tensorflow/setup.py @@ -4,8 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -# TODO: Requirements are missing, add them back in later. -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0"] +plugin_requires = ["flytekit>=1.6.1"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py b/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py index 2bcfcda550..d863d3fdc4 100644 --- a/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py +++ b/plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py @@ -1,9 +1,173 @@ -from flytekitplugins.kftensorflow import TfJob +import pytest +from flytekitplugins.kftensorflow import PS, Chief, CleanPodPolicy, RestartPolicy, RunPolicy, TfJob, Worker from flytekit import Resources, task from flytekit.configuration import Image, ImageConfig, SerializationSettings +@pytest.fixture +def serialization_settings() -> SerializationSettings: + default_img = Image(name="default", fqn="test", tag="tag") + settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={"FOO": "baz"}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + return settings + + +def test_tensorflow_task_with_default_config(serialization_settings: SerializationSettings): + task_config = TfJob( + worker=Worker(replicas=1), + chief=Chief(replicas=0), + ps=PS(replicas=0), + ) + + @task( + task_config=task_config, + cache=True, + requests=Resources(cpu="1"), + cache_version="1", + ) + def my_tensorflow_task(x: int, y: str) -> int: + return x + + assert my_tensorflow_task(x=10, y="hello") == 10 + + assert my_tensorflow_task.task_config is not None + assert my_tensorflow_task.task_type == "tensorflow" + assert my_tensorflow_task.resources.limits == Resources() + assert my_tensorflow_task.resources.requests == Resources(cpu="1") + + expected_dict = { + "chiefReplicas": { + "resources": {}, + }, + "workerReplicas": { + "replicas": 1, + "resources": {}, + }, + "psReplicas": { + "resources": {}, + }, + } + assert my_tensorflow_task.get_custom(serialization_settings) == expected_dict + + +def test_tensorflow_task_with_custom_config(serialization_settings: SerializationSettings): + task_config = TfJob( + chief=Chief( + replicas=1, + requests=Resources(cpu="1"), + limits=Resources(cpu="2"), + image="chief:latest", + ), + worker=Worker( + replicas=5, + requests=Resources(cpu="2", mem="2Gi"), + limits=Resources(cpu="4", mem="2Gi"), + image="worker:latest", + restart_policy=RestartPolicy.FAILURE, + ), + ps=PS( + replicas=2, + restart_policy=RestartPolicy.ALWAYS, + ), + ) + + @task( + task_config=task_config, + cache=True, + requests=Resources(cpu="1"), + cache_version="1", + ) + def my_tensorflow_task(x: int, y: str) -> int: + return x + + assert my_tensorflow_task(x=10, y="hello") == 10 + + assert my_tensorflow_task.task_config is not None + assert my_tensorflow_task.task_type == "tensorflow" + assert my_tensorflow_task.resources.limits == Resources() + assert my_tensorflow_task.resources.requests == Resources(cpu="1") + + expected_custom_dict = { + "chiefReplicas": { + "replicas": 1, + "image": "chief:latest", + "resources": { + "requests": [{"name": "CPU", "value": "1"}], + "limits": [{"name": "CPU", "value": "2"}], + }, + }, + "workerReplicas": { + "replicas": 5, + "image": "worker:latest", + "resources": { + "requests": [ + {"name": "CPU", "value": "2"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + "limits": [ + {"name": "CPU", "value": "4"}, + {"name": "MEMORY", "value": "2Gi"}, + ], + }, + "restartPolicy": "RESTART_POLICY_ON_FAILURE", + }, + "psReplicas": { + "resources": {}, + "replicas": 2, + "restartPolicy": "RESTART_POLICY_ALWAYS", + }, + } + assert my_tensorflow_task.get_custom(serialization_settings) == expected_custom_dict + + +def test_tensorflow_task_with_run_policy(serialization_settings: SerializationSettings): + task_config = TfJob( + worker=Worker(replicas=1), + ps=PS(replicas=0), + chief=Chief(replicas=0), + run_policy=RunPolicy(clean_pod_policy=CleanPodPolicy.RUNNING), + ) + + @task( + task_config=task_config, + cache=True, + requests=Resources(cpu="1"), + cache_version="1", + ) + def my_tensorflow_task(x: int, y: str) -> int: + return x + + assert my_tensorflow_task(x=10, y="hello") == 10 + + assert my_tensorflow_task.task_config is not None + assert my_tensorflow_task.task_type == "tensorflow" + assert my_tensorflow_task.resources.limits == Resources() + assert my_tensorflow_task.resources.requests == Resources(cpu="1") + + expected_dict = { + "chiefReplicas": { + "resources": {}, + }, + "workerReplicas": { + "replicas": 1, + "resources": {}, + }, + "psReplicas": { + "resources": {}, + }, + "runPolicy": { + "cleanPodPolicy": "CLEANPOD_POLICY_RUNNING", + }, + } + assert my_tensorflow_task.get_custom(serialization_settings) == expected_dict + + def test_tensorflow_task(): @task( task_config=TfJob(num_workers=10, num_ps_replicas=1, num_chief_replicas=1), @@ -27,7 +191,21 @@ def my_tensorflow_task(x: int, y: str) -> int: image_config=ImageConfig(default_image=default_img, images=[default_img]), ) - assert my_tensorflow_task.get_custom(settings) == {"workers": 10, "psReplicas": 1, "chiefReplicas": 1} + expected_dict = { + "chiefReplicas": { + "replicas": 1, + "resources": {}, + }, + "workerReplicas": { + "replicas": 10, + "resources": {}, + }, + "psReplicas": { + "replicas": 1, + "resources": {}, + }, + } + assert my_tensorflow_task.get_custom(settings) == expected_dict assert my_tensorflow_task.resources.limits == Resources() assert my_tensorflow_task.resources.requests == Resources(cpu="1") assert my_tensorflow_task.task_type == "tensorflow"