diff --git a/CODEOWNERS b/CODEOWNERS index 9389524869..a9aab29ffd 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,3 +1,3 @@ # These owners will be the default owners for everything in # the repo. Unless a later match takes precedence. -* @wild-endeavor @kumare3 @eapolinario @pingsutw +* @wild-endeavor @kumare3 @eapolinario @pingsutw @cosmicBboy diff --git a/doc-requirements.in b/doc-requirements.in index 2850232418..17862869cf 100644 --- a/doc-requirements.in +++ b/doc-requirements.in @@ -46,3 +46,4 @@ whylabs-client # whylogs ray # ray scikit-learn # scikit-learn vaex # vaex +mlflow # mlflow diff --git a/doc-requirements.txt b/doc-requirements.txt index 7c92fcb018..8eb39a5a1e 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -14,6 +14,8 @@ aiosignal==1.3.1 # via ray alabaster==0.7.12 # via sphinx +alembic==1.9.1 + # via mlflow altair==4.2.0 # via great-expectations ansiwrap==0.8.4 @@ -42,13 +44,13 @@ arrow==1.2.3 # jinja2-time astroid==2.12.13 # via sphinx-autoapi -astropy==5.1.1 +astropy==5.2 # via vaex-astro asttokens==2.2.1 # via stack-data astunparse==1.6.3 # via tensorflow -attrs==22.1.0 +attrs==22.2.0 # via # jsonschema # ray @@ -65,11 +67,11 @@ beautifulsoup4==4.11.1 # sphinx-material binaryornot==0.4.4 # via cookiecutter -blake3==0.3.1 +blake3==0.3.3 # via vaex-core bleach==5.0.1 # via nbconvert -botocore==1.29.26 +botocore==1.29.44 # via -r doc-requirements.in bqplot==0.12.36 # via vaex-jupyter @@ -93,12 +95,15 @@ chardet==5.1.0 # via binaryornot charset-normalizer==2.1.1 # via requests -click==8.0.4 +click==8.1.3 # via # cookiecutter # dask + # databricks-cli + # flask # flytekit # great-expectations + # mlflow # papermill # ray # sphinx-click @@ -107,6 +112,8 @@ cloudpickle==2.2.0 # via # dask # flytekit + # mlflow + # shap # vaex-core colorama==0.4.6 # via great-expectations @@ -120,7 +127,7 @@ cookiecutter==2.1.1 # via flytekit croniter==1.3.8 # via flytekit -cryptography==38.0.4 +cryptography==39.0.0 # via # -r doc-requirements.in # 
great-expectations @@ -129,8 +136,10 @@ css-html-js-minify==2.5.5 # via sphinx-material cycler==0.11.0 # via matplotlib -dask==2022.12.0 +dask==2022.12.1 # via vaex-core +databricks-cli==0.17.4 + # via mlflow dataclasses-json==0.5.7 # via # dolt-integrations @@ -150,7 +159,9 @@ diskcache==5.4.0 distlib==0.3.6 # via virtualenv docker==6.0.1 - # via flytekit + # via + # flytekit + # mlflow docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 @@ -168,6 +179,7 @@ entrypoints==0.4 # via # altair # jupyter-client + # mlflow # papermill executing==1.2.0 # via stack-data @@ -175,14 +187,16 @@ fastapi==0.88.0 # via vaex-server fastjsonschema==2.16.2 # via nbformat -filelock==3.8.2 +filelock==3.9.0 # via # ray # vaex-core # virtualenv -flatbuffers==22.12.6 +flask==2.2.2 + # via mlflow +flatbuffers==23.1.4 # via tensorflow -flyteidl==1.3.1 +flyteidl==1.3.2 # via flytekit fonttools==4.38.0 # via matplotlib @@ -205,11 +219,14 @@ future==0.18.2 # via vaex-core gast==0.5.3 # via tensorflow +gitdb==4.0.10 + # via gitpython +gitpython==3.1.30 + # via mlflow google-api-core[grpc]==2.11.0 # via # -r doc-requirements.in # google-cloud-bigquery - # google-cloud-bigquery-storage # google-cloud-core google-auth==2.15.0 # via @@ -222,10 +239,8 @@ google-auth-oauthlib==0.4.6 # via tensorboard google-cloud==0.34.0 # via -r doc-requirements.in -google-cloud-bigquery==3.4.0 +google-cloud-bigquery==3.4.1 # via -r doc-requirements.in -google-cloud-bigquery-storage==2.16.2 - # via google-cloud-bigquery google-cloud-core==2.3.2 # via google-cloud-bigquery google-crc32c==1.5.0 @@ -234,12 +249,13 @@ google-pasta==0.2.0 # via tensorflow google-resumable-media==2.4.0 # via google-cloud-bigquery -googleapis-common-protos==1.57.0 +googleapis-common-protos==1.57.1 # via # flyteidl + # flytekit # google-api-core # grpcio-status -great-expectations==0.15.37 +great-expectations==0.15.42 # via -r doc-requirements.in greenlet==2.0.1 # via sqlalchemy @@ -256,6 +272,8 @@ grpcio-status==1.51.1 # via # 
flytekit # google-api-core +gunicorn==20.1.0 + # via mlflow h11==0.14.0 # via uvicorn h5py==3.7.0 @@ -266,7 +284,7 @@ htmlmin==0.1.12 # via pandas-profiling httptools==0.5.0 # via uvicorn -identify==2.5.9 +identify==2.5.12 # via pre-commit idna==3.4 # via @@ -277,17 +295,19 @@ imagehash==4.3.1 # via visions imagesize==1.4.1 # via sphinx -importlib-metadata==5.1.0 +importlib-metadata==5.2.0 # via + # flask # flytekit # great-expectations # keyring # markdown + # mlflow # nbconvert # sphinx ipydatawidgets==4.3.2 # via pythreejs -ipykernel==6.19.2 +ipykernel==6.19.4 # via # ipywidgets # jupyter @@ -299,7 +319,7 @@ ipyleaflet==0.17.2 # via vaex-jupyter ipympl==0.9.2 # via vaex-jupyter -ipython==8.7.0 +ipython==8.8.0 # via # great-expectations # ipykernel @@ -320,7 +340,7 @@ ipyvuetify==1.8.4 # via vaex-jupyter ipywebrtc==0.6.0 # via ipyvolume -ipywidgets==8.0.3 +ipywidgets==8.0.4 # via # bqplot # great-expectations @@ -333,6 +353,8 @@ ipywidgets==8.0.3 # pythreejs isoduration==20.11.0 # via jsonschema +itsdangerous==2.1.2 + # via flask jaraco-classes==3.2.3 # via keyring jedi==0.18.2 @@ -342,9 +364,11 @@ jinja2==3.1.2 # altair # branca # cookiecutter + # flask # great-expectations # jinja2-time # jupyter-server + # mlflow # nbclassic # nbconvert # notebook @@ -387,7 +411,7 @@ jupyter-client==7.4.8 # qtconsole jupyter-console==6.4.4 # via jupyter -jupyter-core==5.1.0 +jupyter-core==5.1.2 # via # jupyter-client # jupyter-server @@ -399,27 +423,27 @@ jupyter-core==5.1.0 # qtconsole jupyter-events==0.5.0 # via jupyter-server -jupyter-server==2.0.1 +jupyter-server==2.0.6 # via # nbclassic # notebook-shim -jupyter-server-terminals==0.4.2 +jupyter-server-terminals==0.4.3 # via jupyter-server jupyterlab-pygments==0.2.2 # via nbconvert -jupyterlab-widgets==3.0.4 +jupyterlab-widgets==3.0.5 # via ipywidgets keras==2.8.0 # via tensorflow keras-preprocessing==1.1.2 # via tensorflow -keyring==23.11.0 +keyring==23.13.1 # via flytekit kiwisolver==1.4.4 # via matplotlib 
kubernetes==25.3.0 # via -r doc-requirements.in -lazy-object-proxy==1.8.0 +lazy-object-proxy==1.9.0 # via astroid libclang==14.0.6 # via tensorflow @@ -427,17 +451,21 @@ llvmlite==0.39.1 # via numba locket==1.0.0 # via partd -lxml==4.9.1 +lxml==4.9.2 # via sphinx-material makefun==1.15.0 # via great-expectations +mako==1.2.4 + # via alembic markdown==3.4.1 # via # -r doc-requirements.in + # mlflow # tensorboard markupsafe==2.1.1 # via # jinja2 + # mako # nbconvert # werkzeug marshmallow==3.19.0 @@ -453,6 +481,7 @@ marshmallow-jsonschema==0.13.0 matplotlib==3.6.2 # via # ipympl + # mlflow # pandas-profiling # phik # seaborn @@ -465,13 +494,15 @@ mistune==2.0.4 # via # great-expectations # nbconvert -modin==0.17.1 +mlflow==2.1.1 + # via -r doc-requirements.in +modin==0.18.0 # via -r doc-requirements.in more-itertools==9.0.0 # via jaraco-classes msgpack==1.0.4 # via ray -multimethod==1.9 +multimethod==1.9.1 # via # pandas-profiling # visions @@ -485,13 +516,13 @@ nbclient==0.7.2 # via # nbconvert # papermill -nbconvert==7.2.6 +nbconvert==7.2.7 # via # jupyter # jupyter-server # nbclassic # notebook -nbformat==5.7.0 +nbformat==5.7.1 # via # great-expectations # jupyter-server @@ -518,13 +549,16 @@ notebook==6.5.2 notebook-shim==0.2.2 # via nbclassic numba==0.56.4 - # via vaex-ml + # via + # shap + # vaex-ml numpy==1.23.5 # via # altair # astropy # bqplot # contourpy + # flytekit # great-expectations # h5py # imagehash @@ -533,6 +567,7 @@ numpy==1.23.5 # ipyvolume # keras-preprocessing # matplotlib + # mlflow # modin # numba # opt-einsum @@ -549,6 +584,7 @@ numpy==1.23.5 # scikit-learn # scipy # seaborn + # shap # statsmodels # tensorboard # tensorflow @@ -556,7 +592,9 @@ numpy==1.23.5 # visions # xarray oauthlib==3.2.2 - # via requests-oauthlib + # via + # databricks-cli + # requests-oauthlib opt-einsum==3.3.0 # via tensorflow packaging==21.3 @@ -570,10 +608,12 @@ packaging==21.3 # jupyter-server # marshmallow # matplotlib + # mlflow # modin # nbconvert # pandera # 
qtpy + # shap # sphinx # statsmodels # xarray @@ -584,16 +624,18 @@ pandas==1.5.2 # dolt-integrations # flytekit # great-expectations + # mlflow # modin # pandas-profiling # pandera # phik # seaborn + # shap # statsmodels # vaex-core # visions # xarray -pandas-profiling==3.5.0 +pandas-profiling==3.6.2 # via -r doc-requirements.in pandera==0.13.4 # via -r doc-requirements.in @@ -613,7 +655,7 @@ phik==0.12.3 # via pandas-profiling pickleshare==0.7.5 # via ipython -pillow==9.3.0 +pillow==9.4.0 # via # imagehash # ipympl @@ -621,13 +663,13 @@ pillow==9.3.0 # matplotlib # vaex-viz # visions -platformdirs==2.6.0 +platformdirs==2.6.2 # via # jupyter-core # virtualenv plotly==5.11.0 # via -r doc-requirements.in -pre-commit==2.20.0 +pre-commit==2.21.0 # via sphinx-tags progressbar2==4.2.0 # via vaex-core @@ -641,17 +683,15 @@ prompt-toolkit==3.0.36 # ipython # jupyter-console proto-plus==1.22.1 - # via - # google-cloud-bigquery - # google-cloud-bigquery-storage -protobuf==4.21.11 + # via google-cloud-bigquery +protobuf==4.21.12 # via # flyteidl # google-api-core # google-cloud-bigquery - # google-cloud-bigquery-storage # googleapis-common-protos # grpcio-status + # mlflow # proto-plus # protoc-gen-swagger # ray @@ -677,7 +717,7 @@ py4j==0.10.9.5 pyarrow==10.0.1 # via # flytekit - # google-cloud-bigquery + # mlflow # vaex-core pyasn1==0.4.8 # via @@ -687,7 +727,7 @@ pyasn1-modules==0.2.8 # via google-auth pycparser==2.21 # via cffi -pydantic==1.10.2 +pydantic==1.10.4 # via # fastapi # great-expectations @@ -696,7 +736,7 @@ pydantic==1.10.2 # vaex-core pyerfa==2.0.0.1 # via astropy -pygments==2.13.0 +pygments==2.14.0 # via # furo # ipython @@ -706,14 +746,16 @@ pygments==2.13.0 # rich # sphinx # sphinx-prompt -pyopenssl==22.1.0 +pyjwt==2.6.0 + # via databricks-cli +pyopenssl==23.0.0 # via flytekit pyparsing==3.0.9 # via # great-expectations # matplotlib # packaging -pyrsistent==0.19.2 +pyrsistent==0.19.3 # via jsonschema pyspark==3.3.1 # via -r doc-requirements.in @@ -746,11 
+788,12 @@ pythreejs==2.4.1 # via ipyvolume pytimeparse==1.1.8 # via flytekit -pytz==2022.6 +pytz==2022.7 # via # babel # flytekit # great-expectations + # mlflow # pandas pytz-deprecation-shim==0.1.0.post0 # via tzlocal @@ -764,6 +807,7 @@ pyyaml==6.0 # flytekit # jupyter-events # kubernetes + # mlflow # pandas-profiling # papermill # pre-commit @@ -783,13 +827,16 @@ qtconsole==5.4.0 # via jupyter qtpy==2.3.0 # via qtconsole -ray==2.1.0 +querystring-parser==1.2.4 + # via mlflow +ray==2.2.0 # via -r doc-requirements.in regex==2022.10.31 # via docker-image-py requests==2.28.1 # via # cookiecutter + # databricks-cli # docker # flytekit # google-api-core @@ -797,6 +844,7 @@ requests==2.28.1 # great-expectations # ipyvolume # kubernetes + # mlflow # pandas-profiling # papermill # ray @@ -817,7 +865,7 @@ rfc3339-validator==0.1.4 # via jsonschema rfc3986-validator==0.1.1 # via jsonschema -rich==12.6.0 +rich==13.0.0 # via vaex-core rsa==4.9 # via google-auth @@ -826,37 +874,50 @@ ruamel-yaml==0.17.17 ruamel-yaml-clib==0.2.7 # via ruamel-yaml scikit-learn==1.2.0 - # via -r doc-requirements.in + # via + # -r doc-requirements.in + # mlflow + # shap scipy==1.9.3 # via # great-expectations # imagehash + # mlflow # pandas-profiling # phik # scikit-learn + # shap # statsmodels -seaborn==0.12.1 +seaborn==0.12.2 # via pandas-profiling send2trash==1.8.0 # via # jupyter-server # nbclassic # notebook +shap==0.41.0 + # via mlflow six==1.16.0 # via # asttokens # astunparse # bleach + # databricks-cli # google-auth # google-pasta # keras-preprocessing # kubernetes # patsy # python-dateutil + # querystring-parser # rfc3339-validator # sphinx-code-include # tensorflow # vaex-core +slicer==0.0.7 + # via shap +smmap==5.0.0 + # via gitdb sniffio==1.3.0 # via anyio snowballstemmer==2.2.0 @@ -917,8 +978,13 @@ sphinxcontrib-serializinghtml==1.1.5 # via sphinx sphinxcontrib-yt==0.2.2 # via -r doc-requirements.in -sqlalchemy==1.4.44 - # via -r doc-requirements.in +sqlalchemy==1.4.46 + # via + # 
-r doc-requirements.in + # alembic + # mlflow +sqlparse==0.4.3 + # via mlflow stack-data==0.6.2 # via ipython starlette==0.22.0 @@ -928,7 +994,9 @@ statsd==3.3.0 statsmodels==0.13.5 # via pandas-profiling tabulate==0.9.0 - # via vaex-core + # via + # databricks-cli + # vaex-core tangled-up-in-unicode==0.2.0 # via visions tenacity==8.1.0 @@ -945,9 +1013,9 @@ tensorflow==2.8.1 # via -r doc-requirements.in tensorflow-estimator==2.8.0 # via tensorflow -tensorflow-io-gcs-filesystem==0.28.0 +tensorflow-io-gcs-filesystem==0.29.0 # via tensorflow -termcolor==2.1.1 +termcolor==2.2.0 # via tensorflow terminado==0.17.1 # via @@ -964,15 +1032,13 @@ threadpoolctl==3.1.0 tinycss2==1.2.1 # via nbconvert toml==0.10.2 - # via - # pre-commit - # responses + # via responses toolz==0.12.0 # via # altair # dask # partd -torch==1.13.0 +torch==1.13.1 # via -r doc-requirements.in tornado==6.2 # via @@ -988,7 +1054,8 @@ tqdm==4.64.1 # great-expectations # pandas-profiling # papermill -traitlets==5.7.0 + # shap +traitlets==5.8.0 # via # bqplot # comm @@ -1109,7 +1176,9 @@ websocket-client==1.4.2 websockets==10.4 # via uvicorn werkzeug==2.2.2 - # via tensorboard + # via + # flask + # tensorboard wheel==0.38.4 # via # astunparse @@ -1117,11 +1186,11 @@ wheel==0.38.4 # tensorboard whylabs-client==0.4.2 # via -r doc-requirements.in -whylogs==1.1.16 +whylogs==1.1.20 # via -r doc-requirements.in whylogs-sketching==3.4.1.dev3 # via whylogs -widgetsnbextension==4.0.4 +widgetsnbextension==4.0.5 # via ipywidgets wrapt==1.14.1 # via diff --git a/docs/Makefile b/docs/Makefile index e61723ad76..afa73807cb 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -18,3 +18,7 @@ help: # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
%: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + + +clean: + rm -rf ./build ./source/generated diff --git a/docs/source/plugins/index.rst b/docs/source/plugins/index.rst index bf0b03fb95..008f2b4bbe 100644 --- a/docs/source/plugins/index.rst +++ b/docs/source/plugins/index.rst @@ -29,6 +29,7 @@ Plugin API reference * :ref:`Ray ` - Ray API reference * :ref:`DBT ` - DBT API reference * :ref:`Vaex ` - Vaex API reference +* :ref:`MLflow ` - MLflow API reference .. toctree:: :maxdepth: 2 @@ -59,3 +60,4 @@ Plugin API reference Ray DBT Vaex + MLflow diff --git a/docs/source/plugins/mlflow.rst b/docs/source/plugins/mlflow.rst new file mode 100644 index 0000000000..60d1a7c66b --- /dev/null +++ b/docs/source/plugins/mlflow.rst @@ -0,0 +1,9 @@ +.. _mlflow: + +################################################### +MLflow API reference +################################################### + +.. tags:: Integration, MachineLearning, Tracking + +.. automodule:: flytekitplugins.mlflow diff --git a/flytekit/__init__.py b/flytekit/__init__.py index 1a9f74f114..5ba43b78dc 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -190,6 +190,7 @@ from flytekit.models.common import Annotations, AuthRole, Labels from flytekit.models.core.execution import WorkflowExecutionPhase from flytekit.models.core.types import BlobType +from flytekit.models.documentation import Description, Documentation, SourceCode from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType from flytekit.types import directory, file, numpy, schema diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index 7c4439d83d..6c8f54e9ce 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -10,11 +10,13 @@ import grpc import requests as _requests from flyteidl.admin.project_pb2 import ProjectListRequest +from flyteidl.admin.signal_pb2 import SignalList, SignalListRequest, SignalSetRequest, 
SignalSetResponse from flyteidl.service import admin_pb2_grpc as _admin_service from flyteidl.service import auth_pb2 from flyteidl.service import auth_pb2_grpc as auth_service from flyteidl.service import dataproxy_pb2 as _dataproxy_pb2 from flyteidl.service import dataproxy_pb2_grpc as dataproxy_service +from flyteidl.service import signal_pb2_grpc as signal_service from flyteidl.service.dataproxy_pb2_grpc import DataProxyServiceStub from google.protobuf.json_format import MessageToJson as _MessageToJson @@ -145,6 +147,7 @@ def __init__(self, cfg: PlatformConfig, **kwargs): ) self._stub = _admin_service.AdminServiceStub(self._channel) self._auth_stub = auth_service.AuthMetadataServiceStub(self._channel) + self._signal = signal_service.SignalServiceStub(self._channel) try: resp = self._auth_stub.GetPublicClientConfig(auth_pb2.PublicClientAuthConfigRequest()) self._public_client_config = resp @@ -406,6 +409,20 @@ def get_task(self, get_object_request): """ return self._stub.GetTask(get_object_request, metadata=self._metadata) + @_handle_rpc_error(retry=True) + def set_signal(self, signal_set_request: SignalSetRequest) -> SignalSetResponse: + """ + This sets a signal + """ + return self._signal.SetSignal(signal_set_request, metadata=self._metadata) + + @_handle_rpc_error(retry=True) + def list_signals(self, signal_list_request: SignalListRequest) -> SignalList: + """ + This lists signals + """ + return self._signal.ListSignals(signal_list_request, metadata=self._metadata) + #################################################################################################################### # # Workflow Endpoints diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index cb228b788a..5d65f8c3ca 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -463,7 +463,7 @@ class GCSConfig(object): gsutil_parallelism: bool = False @classmethod - def auto(self, config_file: typing.Union[str, ConfigFile] = None) 
-> GCSConfig: + def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> GCSConfig: config_file = get_config_file(config_file) kwargs = {} kwargs = set_if_exists(kwargs, "gsutil_parallelism", _internal.GCP.GSUTIL_PARALLELISM.read(config_file)) @@ -647,6 +647,7 @@ class SerializationSettings(object): domain: typing.Optional[str] = None version: typing.Optional[str] = None env: Optional[Dict[str, str]] = None + git_repo: Optional[str] = None python_interpreter: str = DEFAULT_RUNTIME_PYTHON_INTERPRETER flytekit_virtualenv_root: Optional[str] = None fast_serialization_settings: Optional[FastSerializationSettings] = None @@ -719,6 +720,7 @@ def new_builder(self) -> Builder: version=self.version, image_config=self.image_config, env=self.env.copy() if self.env else None, + git_repo=self.git_repo, flytekit_virtualenv_root=self.flytekit_virtualenv_root, python_interpreter=self.python_interpreter, fast_serialization_settings=self.fast_serialization_settings, @@ -768,6 +770,7 @@ class Builder(object): version: str image_config: ImageConfig env: Optional[Dict[str, str]] = None + git_repo: Optional[str] = None flytekit_virtualenv_root: Optional[str] = None python_interpreter: Optional[str] = None fast_serialization_settings: Optional[FastSerializationSettings] = None @@ -783,6 +786,7 @@ def build(self) -> SerializationSettings: version=self.version, image_config=self.image_config, env=self.env, + git_repo=self.git_repo, flytekit_virtualenv_root=self.flytekit_virtualenv_root, python_interpreter=self.python_interpreter, fast_serialization_settings=self.fast_serialization_settings, diff --git a/flytekit/configuration/file.py b/flytekit/configuration/file.py index 467f660d42..793917cffe 100644 --- a/flytekit/configuration/file.py +++ b/flytekit/configuration/file.py @@ -224,7 +224,7 @@ def legacy_config(self) -> _configparser.ConfigParser: return self._legacy_config @property - def yaml_config(self) -> typing.Dict[str, Any]: + def yaml_config(self) -> typing.Dict[str, 
typing.Any]: return self._yaml_config diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index dccbaec803..491bed4385 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -45,6 +45,7 @@ from flytekit.models import literals as _literal_models from flytekit.models import task as _task_model from flytekit.models.core import workflow as _workflow_model +from flytekit.models.documentation import Description, Documentation from flytekit.models.interface import Variable from flytekit.models.security import SecurityContext @@ -156,6 +157,7 @@ def __init__( metadata: Optional[TaskMetadata] = None, task_type_version=0, security_ctx: Optional[SecurityContext] = None, + docs: Optional[Documentation] = None, **kwargs, ): self._task_type = task_type @@ -164,6 +166,7 @@ def __init__( self._metadata = metadata if metadata else TaskMetadata() self._task_type_version = task_type_version self._security_ctx = security_ctx + self._docs = docs FlyteEntities.entities.append(self) @@ -195,6 +198,10 @@ def task_type_version(self) -> int: def security_context(self) -> SecurityContext: return self._security_ctx + @property + def docs(self) -> Documentation: + return self._docs + def get_type_for_input_var(self, k: str, v: Any) -> type: """ Returns the python native type for the given input variable @@ -390,6 +397,17 @@ def __init__( self._environment = environment if environment else {} self._task_config = task_config self._disable_deck = disable_deck + if self._python_interface.docstring: + if self.docs is None: + self._docs = Documentation( + short_description=self._python_interface.docstring.short_description, + long_description=Description(value=self._python_interface.docstring.long_description), + ) + else: + if self._python_interface.docstring.short_description: + self._docs.short_description = self._python_interface.docstring.short_description + if self._python_interface.docstring.long_description: + self._docs.long_description = 
Description(value=self._python_interface.docstring.long_description) # TODO lets call this interface and the other as flyte_interface? @property diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index 63d7c8106f..954c1ae409 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -207,7 +207,6 @@ def transform_interface_to_typed_interface( """ if interface is None: return None - if interface.docstring is None: input_descriptions = output_descriptions = {} else: diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 6e5b0a6b6a..bb24181338 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -7,6 +7,7 @@ from flytekit.core.python_function_task import PythonFunctionTask from flytekit.core.reference_entity import ReferenceEntity, TaskReference from flytekit.core.resources import Resources +from flytekit.models.documentation import Documentation from flytekit.models.security import Secret @@ -89,6 +90,7 @@ def task( secret_requests: Optional[List[Secret]] = None, execution_mode: Optional[PythonFunctionTask.ExecutionBehavior] = PythonFunctionTask.ExecutionBehavior.DEFAULT, task_resolver: Optional[TaskResolverMixin] = None, + docs: Optional[Documentation] = None, disable_deck: bool = True, ) -> Union[Callable, PythonFunctionTask]: """ @@ -179,6 +181,7 @@ def foo2(): :param execution_mode: This is mainly for internal use. Please ignore. It is filled in automatically. :param task_resolver: Provide a custom task resolver. 
:param disable_deck: If true, this task will not output deck html file + :param docs: Documentation about this task """ def wrapper(fn) -> PythonFunctionTask: @@ -204,6 +207,7 @@ def wrapper(fn) -> PythonFunctionTask: execution_mode=execution_mode, task_resolver=task_resolver, disable_deck=disable_deck, + docs=docs, ) update_wrapper(task_instance, fn) return task_instance diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 468a5aa7ea..a1a1581e96 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -39,6 +39,7 @@ from flytekit.models import interface as _interface_models from flytekit.models import literals as _literal_models from flytekit.models.core import workflow as _workflow_model +from flytekit.models.documentation import Description, Documentation GLOBAL_START_NODE = Node( id=_common_constants.GLOBAL_INPUT_NODE_ID, @@ -168,6 +169,7 @@ def __init__( workflow_metadata: WorkflowMetadata, workflow_metadata_defaults: WorkflowMetadataDefaults, python_interface: Interface, + docs: Optional[Documentation] = None, **kwargs, ): self._name = name @@ -179,6 +181,20 @@ def __init__( self._unbound_inputs = set() self._nodes = [] self._output_bindings: List[_literal_models.Binding] = [] + self._docs = docs + + if self._python_interface.docstring: + if self.docs is None: + self._docs = Documentation( + short_description=self._python_interface.docstring.short_description, + long_description=Description(value=self._python_interface.docstring.long_description), + ) + else: + if self._python_interface.docstring.short_description: + self._docs.short_description = self._python_interface.docstring.short_description + if self._python_interface.docstring.long_description: + self._docs.long_description = Description(value=self._python_interface.docstring.long_description) + FlyteEntities.entities.append(self) super().__init__(**kwargs) @@ -186,6 +202,10 @@ def name(self) -> str: return self._name + @property + def docs(self): + return self._docs
+ @property def short_name(self) -> str: return extract_obj_name(self._name) @@ -571,7 +591,8 @@ def __init__( workflow_function: Callable, metadata: Optional[WorkflowMetadata], default_metadata: Optional[WorkflowMetadataDefaults], - docstring: Docstring = None, + docstring: Optional[Docstring] = None, + docs: Optional[Documentation] = None, ): name, _, _, _ = extract_task_module(workflow_function) self._workflow_function = workflow_function @@ -586,6 +607,7 @@ def __init__( workflow_metadata=metadata, workflow_metadata_defaults=default_metadata, python_interface=native_interface, + docs=docs, ) @property @@ -690,6 +712,7 @@ def workflow( _workflow_function=None, failure_policy: Optional[WorkflowFailurePolicy] = None, interruptible: bool = False, + docs: Optional[Documentation] = None, ): """ This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG @@ -718,6 +741,7 @@ def workflow( :param _workflow_function: This argument is implicitly passed and represents the decorated function. 
:param failure_policy: Use the options in flytekit.WorkflowFailurePolicy :param interruptible: Whether or not tasks launched from this workflow are by default interruptible + :param docs: Description entity for the workflow """ def wrapper(fn): @@ -730,6 +754,7 @@ def wrapper(fn): metadata=workflow_metadata, default_metadata=workflow_metadata_defaults, docstring=Docstring(callable_=fn), + docs=docs, ) workflow_instance.compile() update_wrapper(workflow_instance, fn) diff --git a/flytekit/models/admin/workflow.py b/flytekit/models/admin/workflow.py index f34e692123..e40307b6ba 100644 --- a/flytekit/models/admin/workflow.py +++ b/flytekit/models/admin/workflow.py @@ -1,13 +1,21 @@ +import typing + from flyteidl.admin import workflow_pb2 as _admin_workflow from flytekit.models import common as _common from flytekit.models.core import compiler as _compiler_models from flytekit.models.core import identifier as _identifier from flytekit.models.core import workflow as _core_workflow +from flytekit.models.documentation import Documentation class WorkflowSpec(_common.FlyteIdlEntity): - def __init__(self, template, sub_workflows): + def __init__( + self, + template: _core_workflow.WorkflowTemplate, + sub_workflows: typing.List[_core_workflow.WorkflowTemplate], + docs: typing.Optional[Documentation] = None, + ): """ This object fully encapsulates the specification of a workflow :param flytekit.models.core.workflow.WorkflowTemplate template: @@ -15,6 +23,7 @@ def __init__(self, template, sub_workflows): """ self._template = template self._sub_workflows = sub_workflows + self._docs = docs @property def template(self): @@ -30,6 +39,13 @@ def sub_workflows(self): """ return self._sub_workflows + @property + def docs(self): + """ + :rtype: Description entity for the workflow + """ + return self._docs + def to_flyte_idl(self): """ :rtype: flyteidl.admin.workflow_pb2.WorkflowSpec @@ -37,6 +53,7 @@ def to_flyte_idl(self): return _admin_workflow.WorkflowSpec( 
template=self._template.to_flyte_idl(), sub_workflows=[s.to_flyte_idl() for s in self._sub_workflows], + description=self._docs.to_flyte_idl() if self._docs else None, ) @classmethod @@ -48,6 +65,7 @@ def from_flyte_idl(cls, pb2_object): return cls( _core_workflow.WorkflowTemplate.from_flyte_idl(pb2_object.template), [_core_workflow.WorkflowTemplate.from_flyte_idl(s) for s in pb2_object.sub_workflows], + Documentation.from_flyte_idl(pb2_object.description) if pb2_object.description else None, ) diff --git a/flytekit/models/documentation.py b/flytekit/models/documentation.py new file mode 100644 index 0000000000..e1bae8122e --- /dev/null +++ b/flytekit/models/documentation.py @@ -0,0 +1,93 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Optional + +from flyteidl.admin import description_entity_pb2 + +from flytekit.models import common as _common_models + + +@dataclass +class Description(_common_models.FlyteIdlEntity): + """ + Full user description with formatting preserved. This can be rendered + by clients, such as the console or command line tools with in-tact + formatting. 
+ """ + + class DescriptionFormat(Enum): + UNKNOWN = 0 + MARKDOWN = 1 + HTML = 2 + RST = 3 + + value: Optional[str] = None + uri: Optional[str] = None + icon_link: Optional[str] = None + format: DescriptionFormat = DescriptionFormat.RST + + def to_flyte_idl(self): + return description_entity_pb2.Description( + value=self.value if self.value else None, + uri=self.uri if self.uri else None, + format=self.format.value, + icon_link=self.icon_link, + ) + + @classmethod + def from_flyte_idl(cls, pb2_object: description_entity_pb2.Description) -> "Description": + return cls( + value=pb2_object.value if pb2_object.value else None, + uri=pb2_object.uri if pb2_object.uri else None, + format=Description.DescriptionFormat(pb2_object.format), + icon_link=pb2_object.icon_link if pb2_object.icon_link else None, + ) + + +@dataclass +class SourceCode(_common_models.FlyteIdlEntity): + """ + Link to source code used to define this task or workflow. + """ + + link: Optional[str] = None + + def to_flyte_idl(self): + return description_entity_pb2.SourceCode(link=self.link) + + @classmethod + def from_flyte_idl(cls, pb2_object: description_entity_pb2.SourceCode) -> "SourceCode": + return cls(link=pb2_object.link) if pb2_object.link else None + + +@dataclass +class Documentation(_common_models.FlyteIdlEntity): + """ + DescriptionEntity contains detailed description for the task/workflow/launch plan. + Documentation could provide insight into the algorithms, business use case, etc. + Args: + short_description (str): One-liner overview of the entity. + long_description (Optional[Description]): Full user description with formatting preserved. 
+ source_code (Optional[SourceCode]): link to source code used to define this entity + """ + + short_description: Optional[str] = None + long_description: Optional[Description] = None + source_code: Optional[SourceCode] = None + + def to_flyte_idl(self): + return description_entity_pb2.DescriptionEntity( + short_description=self.short_description, + long_description=self.long_description.to_flyte_idl() if self.long_description else None, + source_code=self.source_code.to_flyte_idl() if self.source_code else None, + ) + + @classmethod + def from_flyte_idl(cls, pb2_object: description_entity_pb2.DescriptionEntity) -> "Documentation": + return cls( + short_description=pb2_object.short_description, + long_description=Description.from_flyte_idl(pb2_object.long_description) + if pb2_object.long_description + else None, + source_code=SourceCode.from_flyte_idl(pb2_object.source_code) if pb2_object.source_code else None, + ) diff --git a/flytekit/models/task.py b/flytekit/models/task.py index f2ff5efd89..2129cdd88f 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -13,6 +13,7 @@ from flytekit.models import literals as _literals from flytekit.models import security as _sec from flytekit.models.core import identifier as _identifier +from flytekit.models.documentation import Documentation class Resources(_common.FlyteIdlEntity): @@ -480,11 +481,13 @@ def from_flyte_idl(cls, pb2_object): class TaskSpec(_common.FlyteIdlEntity): - def __init__(self, template): + def __init__(self, template: TaskTemplate, docs: typing.Optional[Documentation] = None): """ :param TaskTemplate template: + :param Documentation docs: """ self._template = template + self._docs = docs @property def template(self): @@ -493,11 +496,20 @@ def template(self): """ return self._template + @property + def docs(self): + """ + :rtype: Description entity for the task + """ + return self._docs + def to_flyte_idl(self): """ :rtype: flyteidl.admin.tasks_pb2.TaskSpec """ - return 
_admin_task.TaskSpec(template=self.template.to_flyte_idl()) + return _admin_task.TaskSpec( + template=self.template.to_flyte_idl(), description=self.docs.to_flyte_idl() if self.docs else None + ) @classmethod def from_flyte_idl(cls, pb2_object): @@ -505,7 +517,10 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.admin.tasks_pb2.TaskSpec pb2_object: :rtype: TaskSpec """ - return cls(TaskTemplate.from_flyte_idl(pb2_object.template)) + return cls( + TaskTemplate.from_flyte_idl(pb2_object.template), + Documentation.from_flyte_idl(pb2_object.description) if pb2_object.description else None, + ) class Task(_common.FlyteIdlEntity): diff --git a/flytekit/remote/__init__.py b/flytekit/remote/__init__.py index 174928a5b4..4d6f172586 100644 --- a/flytekit/remote/__init__.py +++ b/flytekit/remote/__init__.py @@ -51,9 +51,9 @@ :toctree: generated/ :nosignatures: - ~task.FlyteTask - ~workflow.FlyteWorkflow - ~launch_plan.FlyteLaunchPlan + ~entities.FlyteTask + ~entities.FlyteWorkflow + ~entities.FlyteLaunchPlan .. _remote-flyte-entity-components: @@ -65,9 +65,9 @@ :toctree: generated/ :nosignatures: - ~nodes.FlyteNode - ~component_nodes.FlyteTaskNode - ~component_nodes.FlyteWorkflowNode + ~entities.FlyteNode + ~entities.FlyteTaskNode + ~entities.FlyteWorkflowNode .. 
_remote-flyte-execution-objects: diff --git a/flytekit/remote/entities.py b/flytekit/remote/entities.py index 0c745c11bb..c9de5aea33 100644 --- a/flytekit/remote/entities.py +++ b/flytekit/remote/entities.py @@ -334,6 +334,12 @@ def promote_from_model( return cls(new_if_else_block), converted_sub_workflows +class FlyteGateNode(_workflow_model.GateNode): + @classmethod + def promote_from_model(cls, model: _workflow_model.GateNode): + return cls(model.signal, model.sleep, model.approve) + + class FlyteNode(_hash_mixin.HashOnReferenceMixin, _workflow_model.Node): """A class encapsulating a remote Flyte node.""" @@ -343,22 +349,23 @@ def __init__( upstream_nodes, bindings, metadata, - task_node: FlyteTaskNode = None, - workflow_node: FlyteWorkflowNode = None, - branch_node: FlyteBranchNode = None, + task_node: Optional[FlyteTaskNode] = None, + workflow_node: Optional[FlyteWorkflowNode] = None, + branch_node: Optional[FlyteBranchNode] = None, + gate_node: Optional[FlyteGateNode] = None, ): - if not task_node and not workflow_node and not branch_node: + if not task_node and not workflow_node and not branch_node and not gate_node: raise _user_exceptions.FlyteAssertion( - "An Flyte node must have one of task|workflow|branch entity specified at once" + "An Flyte node must have one of task|workflow|branch|gate entity specified at once" ) - # todo: wip - flyte_branch_node is a hack, it should be a Condition, but backing out a Condition object from - # the compiled IfElseBlock is cumbersome, shouldn't do it if we can get away with it. + # TODO: Revisit flyte_branch_node and flyte_gate_node, should they be another type like Condition instead + # of a node? 
if task_node: self._flyte_entity = task_node.flyte_task elif workflow_node: self._flyte_entity = workflow_node.flyte_workflow or workflow_node.flyte_launch_plan else: - self._flyte_entity = branch_node + self._flyte_entity = branch_node or gate_node super(FlyteNode, self).__init__( id=id, @@ -369,6 +376,7 @@ def __init__( task_node=task_node, workflow_node=workflow_node, branch_node=branch_node, + gate_node=gate_node, ) self._upstream = upstream_nodes @@ -412,7 +420,7 @@ def promote_from_model( remote_logger.warning(f"Should not call promote from model on a start node or end node {model}") return None, converted_sub_workflows - flyte_task_node, flyte_workflow_node, flyte_branch_node = None, None, None + flyte_task_node, flyte_workflow_node, flyte_branch_node, flyte_gate_node = None, None, None, None if model.task_node is not None: if model.task_node.reference_id not in tasks: raise RuntimeError( @@ -435,6 +443,8 @@ def promote_from_model( tasks, converted_sub_workflows, ) + elif model.gate_node is not None: + flyte_gate_node = FlyteGateNode.promote_from_model(model.gate_node) else: raise _system_exceptions.FlyteSystemException( f"Bad Node model, neither task nor workflow detected, node: {model}" @@ -459,6 +469,7 @@ def promote_from_model( task_node=flyte_task_node, workflow_node=flyte_workflow_node, branch_node=flyte_branch_node, + gate_node=flyte_gate_node, ), converted_sub_workflows, ) diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 6473d46ec9..23c9803b07 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -17,7 +17,9 @@ from dataclasses import asdict, dataclass from datetime import datetime, timedelta +from flyteidl.admin.signal_pb2 import Signal, SignalListRequest, SignalSetRequest from flyteidl.core import literals_pb2 as literals_pb2 +from git import Repo from flytekit import Literal from flytekit.clients.friendly import SynchronousFlyteClient @@ -40,11 +42,12 @@ from flytekit.models import launch_plan as 
launch_plan_models from flytekit.models import literals as literal_models from flytekit.models import task as task_models +from flytekit.models import types as type_models from flytekit.models.admin import common as admin_common_models from flytekit.models.admin import workflow as admin_workflow_models from flytekit.models.admin.common import Sort from flytekit.models.core import workflow as workflow_model -from flytekit.models.core.identifier import Identifier, ResourceType, WorkflowExecutionIdentifier +from flytekit.models.core.identifier import Identifier, ResourceType, SignalIdentifier, WorkflowExecutionIdentifier from flytekit.models.core.workflow import NodeMetadata from flytekit.models.execution import ( ExecutionMetadata, @@ -119,6 +122,17 @@ def _get_entity_identifier( ) + +def _get_git_repo_url(source_path): + """ + Get git repo URL from remote.origin.url + """ + try: + return "github.com/" + Repo(source_path).remotes.origin.url.split(".git")[0].split(":")[-1] + except Exception: + # If the file isn't in the git repo, we can't get the url from git config + return "" + + class FlyteRemote(object): """Main entrypoint for programmatically accessing a Flyte remote backend. @@ -350,6 +364,69 @@ def fetch_execution(self, project: str = None, domain: str = None, name: str = N # Listing Entities # ###################### + def list_signals( + self, + execution_name: str, + project: typing.Optional[str] = None, + domain: typing.Optional[str] = None, + limit: int = 100, + filters: typing.Optional[typing.List[filter_models.Filter]] = None, + ) -> typing.List[Signal]: + """ + :param execution_name: The name of the execution. This is the tail-end of the URL when looking at the workflow execution. + :param project: The execution project, will default to the Remote's default project. + :param domain: The execution domain, will default to the Remote's default domain.
+ :param limit: The number of signals to fetch + :param filters: Optional list of filters + """ + wf_exec_id = WorkflowExecutionIdentifier( + project=project or self.default_project, domain=domain or self.default_domain, name=execution_name + ) + req = SignalListRequest(workflow_execution_id=wf_exec_id.to_flyte_idl(), limit=limit, filters=filters) + resp = self.client.list_signals(req) + s = resp.signals + return s + + def set_signal( + self, + signal_id: str, + execution_name: str, + value: typing.Union[literal_models.Literal, typing.Any], + project: typing.Optional[str] = None, + domain: typing.Optional[str] = None, + python_type: typing.Optional[typing.Type] = None, + literal_type: typing.Optional[type_models.LiteralType] = None, + ): + """ + :param signal_id: The name of the signal, this is the key used in the approve() or wait_for_input() call. + :param execution_name: The name of the execution. This is the tail-end of the URL when looking + at the workflow execution. + :param value: This is either a Literal or a Python value which FlyteRemote will invoke the TypeEngine to + convert into a Literal. This argument is only valid for wait_for_input type signals. + :param project: The execution project, will default to the Remote's default project. + :param domain: The execution domain, will default to the Remote's default domain. + :param python_type: Provide a python type to help with conversion if the value you provided is not a Literal.
+ :param literal_type: Provide a Flyte literal type to help with conversion if the value you provided + is not a Literal + """ + wf_exec_id = WorkflowExecutionIdentifier( + project=project or self.default_project, domain=domain or self.default_domain, name=execution_name + ) + if isinstance(value, Literal): + remote_logger.debug(f"Using provided {value} as existing Literal value") + lit = value + else: + lt = literal_type or ( + TypeEngine.to_literal_type(python_type) if python_type else TypeEngine.to_literal_type(type(value)) + ) + lit = TypeEngine.to_literal(self.context, value, python_type or type(value), lt) + remote_logger.debug(f"Converted {value} to literal {lit} using literal type {lt}") + + req = SignalSetRequest(id=SignalIdentifier(signal_id, wf_exec_id).to_flyte_idl(), value=lit.to_flyte_idl()) + + # Response is empty currently, nothing to give back to the user. + self.client.set_signal(req) + def recent_executions( self, project: typing.Optional[str] = None, @@ -725,11 +802,11 @@ def register_script( filename="scriptmode.tar.gz", ), ) - serialization_settings = SerializationSettings( project=project, domain=domain, image_config=image_config, + git_repo=_get_git_repo_url(source_path), fast_serialization_settings=FastSerializationSettings( enabled=True, destination_dir=destination_dir, diff --git a/flytekit/tools/repo.py b/flytekit/tools/repo.py index 50bac67844..3c9fe64068 100644 --- a/flytekit/tools/repo.py +++ b/flytekit/tools/repo.py @@ -12,7 +12,7 @@ from flytekit.models import launch_plan from flytekit.models.core.identifier import Identifier from flytekit.remote import FlyteRemote -from flytekit.remote.remote import RegistrationSkipped +from flytekit.remote.remote import RegistrationSkipped, _get_git_repo_url from flytekit.tools import fast_registration, module_loader from flytekit.tools.script_mode import _find_project_root from flytekit.tools.serialize_helpers import get_registrable_entities, persist_registrable_entities @@ -162,7 +162,7 @@ def 
load_packages_and_modules( :param options: :return: The common detected root path, the output of _find_project_root """ - + ss.git_repo = _get_git_repo_url(project_root) pkgs_and_modules = [] for pm in pkgs_or_mods: p = Path(pm).resolve() diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index f0ad5e96c6..5ec249fa4b 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -1,9 +1,10 @@ +import sys import typing from collections import OrderedDict from dataclasses import dataclass from typing import Callable, Dict, List, Optional, Tuple, Union -from flytekit import PythonFunctionTask +from flytekit import PythonFunctionTask, SourceCode from flytekit.configuration import SerializationSettings from flytekit.core import constants as _common_constants from flytekit.core.base_task import PythonTask @@ -23,6 +24,7 @@ from flytekit.models import launch_plan as _launch_plan_models from flytekit.models import security from flytekit.models.admin import workflow as admin_workflow_models +from flytekit.models.admin.workflow import WorkflowSpec from flytekit.models.core import identifier as _identifier_model from flytekit.models.core import workflow as _core_wf from flytekit.models.core import workflow as workflow_model @@ -211,7 +213,8 @@ def get_serializable_task( ) if settings.should_fast_serialize() and isinstance(entity, PythonAutoContainerTask): entity.reset_command_fn() - return TaskSpec(template=tt) + + return TaskSpec(template=tt, docs=entity.docs) def get_serializable_workflow( @@ -295,8 +298,9 @@ def get_serializable_workflow( nodes=serialized_nodes, outputs=entity.output_bindings, ) + return admin_workflow_models.WorkflowSpec( - template=wf_t, sub_workflows=sorted(set(sub_wfs), key=lambda x: x.short_string()) + template=wf_t, sub_workflows=sorted(set(sub_wfs), key=lambda x: x.short_string()), docs=entity.docs ) @@ -658,6 +662,11 @@ def get_serializable( elif isinstance(entity, BranchNode): cp_entity = 
get_serializable_branch_node(entity_mapping, settings, entity, options) + elif isinstance(entity, GateNode): + # NOTE(review): gate node serialization is not implemented yet; fail loudly instead of + # leaving an interactive debugger breakpoint (ipdb.set_trace) in committed code. + raise NotImplementedError(f"Serialization is not yet implemented for gate node {entity}") + + elif isinstance(entity, FlyteTask) or isinstance(entity, FlyteWorkflow): if entity.should_register: if isinstance(entity, FlyteTask): @@ -678,6 +687,16 @@ def get_serializable( else: raise Exception(f"Non serializable type found {type(entity)} Entity {entity}") + if isinstance(entity, TaskSpec) or isinstance(entity, WorkflowSpec): + # 1. Check if the size of long description exceeds 16KB (NOTE(review): the check below uses 16 * 1024 * 1024 = 16MB while the error message says 16KB, and sys.getsizeof measures Python object size rather than encoded byte length — confirm the intended limit) + # 2. Extract the repo URL from the git config, and assign it to the link of the source code of the description entity + if entity.docs and entity.docs.long_description: + if entity.docs.long_description.value: + if sys.getsizeof(entity.docs.long_description.value) > 16 * 1024 * 1024: + raise ValueError( + "Long Description of the flyte entity exceeds the 16KB size limit. Please specify the uri in the long description instead." + ) + entity.docs.source_code = SourceCode(link=settings.git_repo) # This needs to be at the bottom not the top - i.e. dependent tasks get added before the workflow containing it entity_mapping[entity] = cp_entity return cp_entity diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index ec6b367c20..0e4649203a 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -484,7 +484,7 @@ def register_for_protocol( if (default_format_for_type or default_for_type) and h.supported_format != GENERIC_FORMAT: if h.python_type in cls.DEFAULT_FORMATS and not override: if cls.DEFAULT_FORMATS[h.python_type] != h.supported_format: - logger.debug( + logger.info( f"Not using handler {h} with format {h.supported_format} as default for {h.python_type}, {cls.DEFAULT_FORMATS[h.python_type]} already specified."
) else: diff --git a/plugins/flytekit-mlflow/README.md b/plugins/flytekit-mlflow/README.md new file mode 100644 index 0000000000..6cbee9cf59 --- /dev/null +++ b/plugins/flytekit-mlflow/README.md @@ -0,0 +1,22 @@ +# Flytekit MLflow Plugin + +MLflow enables us to log parameters, code, and results in machine learning experiments and compare them using an interactive UI. +This MLflow plugin enables seamless use of MLFlow within Flyte, and render the metrics and parameters on Flyte Deck. + +To install the plugin, run the following command: + +```bash +pip install flytekitplugins-mlflow +``` + +Example +```python +from flytekit import task, workflow +from flytekitplugins.mlflow import mlflow_autolog +import mlflow + +@task(disable_deck=False) +@mlflow_autolog(framework=mlflow.keras) +def train_model(): + ... +``` diff --git a/plugins/flytekit-mlflow/dev-requirements.in b/plugins/flytekit-mlflow/dev-requirements.in new file mode 100644 index 0000000000..0f57144081 --- /dev/null +++ b/plugins/flytekit-mlflow/dev-requirements.in @@ -0,0 +1 @@ +tensorflow diff --git a/plugins/flytekit-mlflow/dev-requirements.txt b/plugins/flytekit-mlflow/dev-requirements.txt new file mode 100644 index 0000000000..6ad9be49bb --- /dev/null +++ b/plugins/flytekit-mlflow/dev-requirements.txt @@ -0,0 +1,122 @@ +# +# This file is autogenerated by pip-compile with python 3.9 +# To update, run: +# +# pip-compile dev-requirements.in +# +absl-py==1.3.0 + # via + # tensorboard + # tensorflow +astunparse==1.6.3 + # via tensorflow +cachetools==5.2.0 + # via google-auth +certifi==2022.9.24 + # via requests +charset-normalizer==2.1.1 + # via requests +flatbuffers==22.10.26 + # via tensorflow +gast==0.4.0 + # via tensorflow +google-auth==2.14.1 + # via + # google-auth-oauthlib + # tensorboard +google-auth-oauthlib==0.4.6 + # via tensorboard +google-pasta==0.2.0 + # via tensorflow +grpcio==1.50.0 + # via + # tensorboard + # tensorflow +h5py==3.7.0 + # via tensorflow +idna==3.4 + # via requests 
+importlib-metadata==5.0.0 + # via markdown +keras==2.10.0 + # via tensorflow +keras-preprocessing==1.1.2 + # via tensorflow +libclang==14.0.6 + # via tensorflow +markdown==3.4.1 + # via tensorboard +markupsafe==2.1.1 + # via werkzeug +numpy==1.23.4 + # via + # h5py + # keras-preprocessing + # opt-einsum + # tensorboard + # tensorflow +oauthlib==3.2.2 + # via requests-oauthlib +opt-einsum==3.3.0 + # via tensorflow +packaging==21.3 + # via tensorflow +protobuf==3.19.6 + # via + # tensorboard + # tensorflow +pyasn1==0.4.8 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.2.8 + # via google-auth +pyparsing==3.0.9 + # via packaging +requests==2.28.1 + # via + # requests-oauthlib + # tensorboard +requests-oauthlib==1.3.1 + # via google-auth-oauthlib +rsa==4.9 + # via google-auth +six==1.16.0 + # via + # astunparse + # google-auth + # google-pasta + # grpcio + # keras-preprocessing + # tensorflow +tensorboard==2.10.1 + # via tensorflow +tensorboard-data-server==0.6.1 + # via tensorboard +tensorboard-plugin-wit==1.8.1 + # via tensorboard +tensorflow==2.10.0 + # via -r dev-requirements.in +tensorflow-estimator==2.10.0 + # via tensorflow +tensorflow-io-gcs-filesystem==0.27.0 + # via tensorflow +termcolor==2.1.0 + # via tensorflow +typing-extensions==4.4.0 + # via tensorflow +urllib3==1.26.12 + # via requests +werkzeug==2.2.2 + # via tensorboard +wheel==0.38.3 + # via + # astunparse + # tensorboard +wrapt==1.14.1 + # via tensorflow +zipp==3.10.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-mlflow/flytekitplugins/mlflow/__init__.py b/plugins/flytekit-mlflow/flytekitplugins/mlflow/__init__.py new file mode 100644 index 0000000000..98e84547e0 --- /dev/null +++ b/plugins/flytekit-mlflow/flytekitplugins/mlflow/__init__.py @@ -0,0 +1,13 @@ +""" +.. currentmodule:: flytekitplugins.mlflow + +This plugin enables seamless integration between Flyte and mlflow. + +.. 
autosummary:: + :template: custom.rst + :toctree: generated/ + + mlflow_autolog +""" + +from .tracking import mlflow_autolog diff --git a/plugins/flytekit-mlflow/flytekitplugins/mlflow/tracking.py b/plugins/flytekit-mlflow/flytekitplugins/mlflow/tracking.py new file mode 100644 index 0000000000..b58aa4a120 --- /dev/null +++ b/plugins/flytekit-mlflow/flytekitplugins/mlflow/tracking.py @@ -0,0 +1,140 @@ +import typing +from functools import partial, wraps + +import mlflow +import pandas +import pandas as pd +import plotly.graph_objects as go +from mlflow import MlflowClient +from mlflow.entities.metric import Metric +from plotly.subplots import make_subplots + +import flytekit +from flytekit import FlyteContextManager +from flytekit.bin.entrypoint import get_one_of +from flytekit.core.context_manager import ExecutionState +from flytekit.deck import TopFrameRenderer + + +def metric_to_df(metrics: typing.List[Metric]) -> pd.DataFrame: + """ + Converts mlflow Metric object to a dataframe of 2 columns ['timestamp', 'value'] + """ + t = [] + v = [] + for m in metrics: + t.append(m.timestamp) + v.append(m.value) + return pd.DataFrame(list(zip(t, v)), columns=["timestamp", "value"]) + + +def get_run_metrics(c: MlflowClient, run_id: str) -> typing.Dict[str, pandas.DataFrame]: + """ + Extracts all metrics and returns a dictionary of metric name to the list of metric for the given run_id + """ + r = c.get_run(run_id) + metrics = {} + for k in r.data.metrics.keys(): + metrics[k] = metric_to_df(metrics=c.get_metric_history(run_id=run_id, key=k)) + return metrics + + +def get_run_params(c: MlflowClient, run_id: str) -> typing.Optional[pd.DataFrame]: + """ + Extracts all parameters and returns a dictionary of metric name to the list of metric for the given run_id + """ + r = c.get_run(run_id) + name = [] + value = [] + if r.data.params == {}: + return None + for k, v in r.data.params.items(): + name.append(k) + value.append(v) + return pd.DataFrame(list(zip(name, value)), 
columns=["name", "value"]) + + +def plot_metrics(metrics: typing.Dict[str, pandas.DataFrame]) -> typing.Optional[go.Figure]: + v = len(metrics) + if v == 0: + return None + + # Initialize figure with subplots + fig = make_subplots(rows=v, cols=1, subplot_titles=list(metrics.keys())) + + # Add traces + row = 1 + for k, v in metrics.items(): + v["timestamp"] = (v["timestamp"] - v["timestamp"][0]) / 1000 + fig.add_trace(go.Scatter(x=v["timestamp"], y=v["value"], name=k), row=row, col=1) + row = row + 1 + + fig.update_xaxes(title_text="Time (s)") + fig.update_layout(height=700, width=900) + return fig + + +def mlflow_autolog(fn=None, *, framework=mlflow.sklearn, experiment_name: typing.Optional[str] = None): + """MLFlow decorator to enable autologging of training metrics. + + This decorator can be used as a nested decorator for a ``@task`` and it will automatically enable mlflow autologging, + for the given ``framework``. By default autologging is enabled for ``sklearn``. + + .. code-block:: python + + @task + @mlflow_autolog(framework=mlflow.tensorflow) + def my_tensorflow_trainer(): + ... + + One benefit of doing so is that the mlflow metrics are then rendered inline using FlyteDecks and can be viewed + in jupyter notebook, as well as in hosted Flyte environment: + + .. code-block:: python + + # jupyter notebook cell + with flytekit.new_context() as ctx: + my_tensorflow_trainer() + ctx.get_deck() # IPython.display + + When the task is called in a Flyte backend, the decorator starts a new MLFlow run using the Flyte execution name + by default, or a user-provided ``experiment_name`` in the decorator. + + :param fn: Function to generate autologs for. + :param framework: The mlflow module to use for autologging + :param experiment_name: The MLFlow experiment name. If not provided, uses the Flyte execution name. 
+ """ + + @wraps(fn) + def wrapper(*args, **kwargs): + framework.autolog() + params = FlyteContextManager.current_context().user_space_params + ctx = FlyteContextManager.current_context() + + experiment = experiment_name or "local workflow" + run_name = None # MLflow will generate random name if value is None + + if ctx.execution_state.mode != ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION: + experiment = experiment_name and f"{get_one_of('FLYTE_INTERNAL_EXECUTION_WORKFLOW', '_F_WF')}" + run_name = f"{params.execution_id.name}.{params.task_id.name.split('.')[-1]}" + + mlflow.set_experiment(experiment) + with mlflow.start_run(run_name=run_name): + out = fn(*args, **kwargs) + run = mlflow.active_run() + if run is not None: + client = MlflowClient() + run_id = run.info.run_id + metrics = get_run_metrics(client, run_id) + figure = plot_metrics(metrics) + if figure: + flytekit.Deck("mlflow metrics", figure.to_html()) + params = get_run_params(client, run_id) + if params is not None: + flytekit.Deck("mlflow params", TopFrameRenderer(max_rows=10).to_html(params)) + return out + + if fn is None: + return partial(mlflow_autolog, framework=framework, experiment_name=experiment_name) + + return wrapper diff --git a/plugins/flytekit-mlflow/requirements.in b/plugins/flytekit-mlflow/requirements.in new file mode 100644 index 0000000000..cbe58e3885 --- /dev/null +++ b/plugins/flytekit-mlflow/requirements.in @@ -0,0 +1,3 @@ +. 
+-e file:.#egg=flytekitplugins-mlflow +grpcio-status<1.49.0 diff --git a/plugins/flytekit-mlflow/requirements.txt b/plugins/flytekit-mlflow/requirements.txt new file mode 100644 index 0000000000..03873c05f5 --- /dev/null +++ b/plugins/flytekit-mlflow/requirements.txt @@ -0,0 +1,274 @@ +# +# This file is autogenerated by pip-compile with python 3.9 +# To update, run: +# +# pip-compile requirements.in +# +-e file:.#egg=flytekitplugins-mlflow + # via -r requirements.in +alembic==1.8.1 + # via mlflow +arrow==1.2.3 + # via jinja2-time +binaryornot==0.4.4 + # via cookiecutter +certifi==2022.9.24 + # via requests +cffi==1.15.1 + # via cryptography +chardet==5.0.0 + # via binaryornot +charset-normalizer==2.1.1 + # via requests +click==8.1.3 + # via + # cookiecutter + # databricks-cli + # flask + # flytekit + # mlflow +cloudpickle==2.2.0 + # via + # flytekit + # mlflow +cookiecutter==2.1.1 + # via flytekit +croniter==1.3.7 + # via flytekit +cryptography==38.0.3 + # via pyopenssl +databricks-cli==0.17.3 + # via mlflow +dataclasses-json==0.5.7 + # via flytekit +decorator==5.1.1 + # via retry +deprecated==1.2.13 + # via flytekit +diskcache==5.4.0 + # via flytekit +docker==6.0.1 + # via + # flytekit + # mlflow +docker-image-py==0.1.12 + # via flytekit +docstring-parser==0.15 + # via flytekit +entrypoints==0.4 + # via mlflow +flask==2.2.2 + # via + # mlflow + # prometheus-flask-exporter +flyteidl==1.1.22 + # via flytekit +flytekit==1.2.3 + # via flytekitplugins-mlflow +gitdb==4.0.9 + # via gitpython +gitpython==3.1.29 + # via mlflow +googleapis-common-protos==1.56.4 + # via + # flyteidl + # grpcio-status +greenlet==2.0.1 + # via sqlalchemy +grpcio==1.50.0 + # via + # flytekit + # grpcio-status +grpcio-status==1.48.2 + # via + # -r requirements.in + # flytekit +gunicorn==20.1.0 + # via mlflow +idna==3.4 + # via requests +importlib-metadata==5.0.0 + # via + # flask + # flytekit + # keyring + # mlflow +itsdangerous==2.1.2 + # via flask +jaraco-classes==3.2.3 + # via keyring 
+jinja2==3.1.2 + # via + # cookiecutter + # flask + # jinja2-time +jinja2-time==0.2.0 + # via cookiecutter +joblib==1.2.0 + # via flytekit +keyring==23.11.0 + # via flytekit +mako==1.2.3 + # via alembic +markupsafe==2.1.1 + # via + # jinja2 + # mako + # werkzeug +marshmallow==3.18.0 + # via + # dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.13.0 + # via flytekit +mlflow==1.30.0 + # via flytekitplugins-mlflow +more-itertools==9.0.0 + # via jaraco-classes +mypy-extensions==0.4.3 + # via typing-inspect +natsort==8.2.0 + # via flytekit +numpy==1.23.4 + # via + # mlflow + # pandas + # pyarrow + # scipy +oauthlib==3.2.2 + # via databricks-cli +packaging==21.3 + # via + # docker + # marshmallow + # mlflow +pandas==1.5.1 + # via + # flytekit + # mlflow +plotly==5.11.0 + # via flytekitplugins-mlflow +prometheus-client==0.15.0 + # via prometheus-flask-exporter +prometheus-flask-exporter==0.20.3 + # via mlflow +protobuf==3.20.3 + # via + # flyteidl + # flytekit + # googleapis-common-protos + # grpcio-status + # mlflow + # protoc-gen-swagger +protoc-gen-swagger==0.1.0 + # via flyteidl +py==1.11.0 + # via retry +pyarrow==6.0.1 + # via flytekit +pycparser==2.21 + # via cffi +pyjwt==2.6.0 + # via databricks-cli +pyopenssl==22.1.0 + # via flytekit +pyparsing==3.0.9 + # via packaging +python-dateutil==2.8.2 + # via + # arrow + # croniter + # flytekit + # pandas +python-json-logger==2.0.4 + # via flytekit +python-slugify==6.1.2 + # via cookiecutter +pytimeparse==1.1.8 + # via flytekit +pytz==2022.6 + # via + # flytekit + # mlflow + # pandas +pyyaml==6.0 + # via + # cookiecutter + # flytekit + # mlflow +querystring-parser==1.2.4 + # via mlflow +regex==2022.10.31 + # via docker-image-py +requests==2.28.1 + # via + # cookiecutter + # databricks-cli + # docker + # flytekit + # mlflow + # responses +responses==0.22.0 + # via flytekit +retry==0.9.2 + # via flytekit +scipy==1.9.3 + # via mlflow 
+six==1.16.0 + # via + # databricks-cli + # grpcio + # python-dateutil + # querystring-parser +smmap==5.0.0 + # via gitdb +sortedcontainers==2.4.0 + # via flytekit +sqlalchemy==1.4.43 + # via + # alembic + # mlflow +sqlparse==0.4.3 + # via mlflow +statsd==3.3.0 + # via flytekit +tabulate==0.9.0 + # via databricks-cli +tenacity==8.1.0 + # via plotly +text-unidecode==1.3 + # via python-slugify +toml==0.10.2 + # via responses +types-toml==0.10.8 + # via responses +typing-extensions==4.4.0 + # via + # flytekit + # typing-inspect +typing-inspect==0.8.0 + # via dataclasses-json +urllib3==1.26.12 + # via + # docker + # flytekit + # requests + # responses +websocket-client==1.4.2 + # via docker +werkzeug==2.2.2 + # via flask +wheel==0.38.3 + # via flytekit +wrapt==1.14.1 + # via + # deprecated + # flytekit +zipp==3.10.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-mlflow/setup.py b/plugins/flytekit-mlflow/setup.py new file mode 100644 index 0000000000..2033ce5d27 --- /dev/null +++ b/plugins/flytekit-mlflow/setup.py @@ -0,0 +1,36 @@ +from setuptools import setup + +PLUGIN_NAME = "mlflow" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=1.1.0,<2.0.0", "plotly", "mlflow"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="This package enables seamless use of MLFlow within Flyte", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.7", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + 
"Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/plugins/flytekit-mlflow/tests/__init__.py b/plugins/flytekit-mlflow/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-mlflow/tests/test_mlflow_tracking.py b/plugins/flytekit-mlflow/tests/test_mlflow_tracking.py new file mode 100644 index 0000000000..b196327d8d --- /dev/null +++ b/plugins/flytekit-mlflow/tests/test_mlflow_tracking.py @@ -0,0 +1,32 @@ +import mlflow +import tensorflow as tf +from flytekitplugins.mlflow import mlflow_autolog + +import flytekit +from flytekit import task + + +@task(disable_deck=False) +@mlflow_autolog(framework=mlflow.keras) +def train_model(epochs: int): + fashion_mnist = tf.keras.datasets.fashion_mnist + (train_images, train_labels), (_, _) = fashion_mnist.load_data() + train_images = train_images / 255.0 + + model = tf.keras.Sequential( + [ + tf.keras.layers.Flatten(input_shape=(28, 28)), + tf.keras.layers.Dense(128, activation="relu"), + tf.keras.layers.Dense(10), + ] + ) + + model.compile( + optimizer="adam", loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"] + ) + model.fit(train_images, train_labels, epochs=epochs) + + +def test_local_exec(): + train_model(epochs=1) + assert len(flytekit.current_context().decks) == 4 # mlflow metrics, params, input, and output diff --git a/requirements.txt b/requirements.txt index 5623078a25..22d31976ed 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # -# This file is autogenerated by pip-compile with Python 3.7 -# by the following command: +# This file is autogenerated by pip-compile with python 3.9 +# To update, run: # # make requirements.txt # @@ -66,9 +66,7 @@ idna==3.4 # via 
requests importlib-metadata==5.1.0 # via - # click # flytekit - # jsonschema # keyring jaraco-classes==3.2.3 # via keyring @@ -108,7 +106,6 @@ natsort==8.2.0 numpy==1.21.6 # via # -r requirements.in - # flytekit # pandas # pyarrow packaging==21.3 @@ -194,10 +191,7 @@ types-toml==0.10.8.1 # via responses typing-extensions==4.4.0 # via - # arrow # flytekit - # importlib-metadata - # responses # typing-inspect typing-inspect==0.8.0 # via dataclasses-json diff --git a/setup.py b/setup.py index 3d9710004f..b42f8a8490 100644 --- a/setup.py +++ b/setup.py @@ -82,6 +82,7 @@ # TODO: We should remove mentions to the deprecated numpy # aliases. More details in https://github.com/flyteorg/flyte/issues/3166 "numpy<1.24.0", + "gitpython", ], extras_require=extras_require, scripts=[ diff --git a/tests/flytekit/unit/clients/test_raw.py b/tests/flytekit/unit/clients/test_raw.py index b3f1807b96..10a7e09333 100644 --- a/tests/flytekit/unit/clients/test_raw.py +++ b/tests/flytekit/unit/clients/test_raw.py @@ -40,12 +40,13 @@ def get_admin_stub_mock() -> mock.MagicMock: return auth_stub_mock +@mock.patch("flytekit.clients.raw.signal_service") @mock.patch("flytekit.clients.raw.dataproxy_service") @mock.patch("flytekit.clients.raw.auth_service") @mock.patch("flytekit.clients.raw._admin_service") @mock.patch("flytekit.clients.raw.grpc.insecure_channel") @mock.patch("flytekit.clients.raw.grpc.secure_channel") -def test_client_set_token(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy): +def test_client_set_token(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy, mock_signal): mock_secure_channel.return_value = True mock_channel.return_value = True mock_admin.AdminServiceStub.return_value = True @@ -73,6 +74,7 @@ def test_refresh_credentials_from_command(mock_call_to_external_process, mock_ad mock_set_access_token.assert_called_with(token, client.public_client_config.authorization_metadata_key) 
+@mock.patch("flytekit.clients.raw.signal_service") @mock.patch("flytekit.clients.raw.dataproxy_service") @mock.patch("flytekit.clients.raw.get_basic_authorization_header") @mock.patch("flytekit.clients.raw.get_token") @@ -88,6 +90,7 @@ def test_refresh_client_credentials_aka_basic( mock_get_token, mock_get_basic_header, mock_dataproxy, + mock_signal, ): mock_secure_channel.return_value = True mock_channel.return_value = True @@ -112,12 +115,13 @@ def test_refresh_client_credentials_aka_basic( assert client._metadata[0][0] == "authorization" +@mock.patch("flytekit.clients.raw.signal_service") @mock.patch("flytekit.clients.raw.dataproxy_service") @mock.patch("flytekit.clients.raw.auth_service") @mock.patch("flytekit.clients.raw._admin_service") @mock.patch("flytekit.clients.raw.grpc.insecure_channel") @mock.patch("flytekit.clients.raw.grpc.secure_channel") -def test_raises(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy): +def test_raises(mock_secure_channel, mock_channel, mock_admin, mock_admin_auth, mock_dataproxy, mock_signal): mock_secure_channel.return_value = True mock_channel.return_value = True mock_admin.AdminServiceStub.return_value = True diff --git a/tests/flytekit/unit/core/test_context_manager.py b/tests/flytekit/unit/core/test_context_manager.py index 98af80638a..6e68c9d4be 100644 --- a/tests/flytekit/unit/core/test_context_manager.py +++ b/tests/flytekit/unit/core/test_context_manager.py @@ -207,7 +207,7 @@ def test_serialization_settings_transport(): ss = SerializationSettings.from_transport(tp) assert ss is not None assert ss == serialization_settings - assert len(tp) == 376 + assert len(tp) == 388 def test_exec_params(): diff --git a/tests/flytekit/unit/core/test_gate.py b/tests/flytekit/unit/core/test_gate.py index a4689ed814..c92e1c9e19 100644 --- a/tests/flytekit/unit/core/test_gate.py +++ b/tests/flytekit/unit/core/test_gate.py @@ -13,7 +13,8 @@ from flytekit.core.task import task from flytekit.core.type_engine 
import TypeEngine from flytekit.core.workflow import workflow -from flytekit.tools.translator import get_serializable +from flytekit.remote.entities import FlyteWorkflow +from flytekit.tools.translator import gather_dependent_entities, get_serializable default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = SerializationSettings( @@ -290,3 +291,35 @@ def cond_wf(a: int) -> float: x = cond_wf(a=3) assert x == 6 assert stdin.read() == "" + + +def test_promote(): + @task + def t1(a: int) -> int: + return a + 5 + + @task + def t2(a: int) -> int: + return a + 6 + + @workflow + def wf(a: int) -> typing.Tuple[int, int, int]: + zzz = sleep(timedelta(seconds=10)) + x = t1(a=a) + s1 = wait_for_input("my-signal-name", timeout=timedelta(hours=1), expected_type=bool) + s2 = wait_for_input("my-signal-name-2", timeout=timedelta(hours=2), expected_type=int) + z = t1(a=5) + y = t2(a=s2) + q = t2(a=approve(y, "approvalfory", timeout=timedelta(hours=2))) + zzz >> x + x >> s1 + s1 >> z + + return y, z, q + + entries = OrderedDict() + wf_spec = get_serializable(entries, serialization_settings, wf) + tts, wf_specs, lp_specs = gather_dependent_entities(entries) + + fwf = FlyteWorkflow.promote_from_model(wf_spec.template, tasks=tts) + assert fwf.template.nodes[2].gate_node is not None diff --git a/tests/flytekit/unit/core/test_interface.py b/tests/flytekit/unit/core/test_interface.py index 442851a8a2..db05de0ddb 100644 --- a/tests/flytekit/unit/core/test_interface.py +++ b/tests/flytekit/unit/core/test_interface.py @@ -4,6 +4,7 @@ from typing_extensions import Annotated # type: ignore +from flytekit import task from flytekit.core import context_manager from flytekit.core.docstring import Docstring from flytekit.core.interface import ( @@ -320,3 +321,20 @@ def z(a: Foo) -> Foo: assert params.parameters["a"].default is None assert our_interface.outputs["o0"].__origin__ == FlytePickle assert our_interface.inputs["a"].__origin__ == FlytePickle + + +def 
test_doc_string(): + @task + def t1(a: int) -> int: + """Set the temperature value. + + The value of the temp parameter is stored as a value in + the class variable temperature. + """ + return a + + assert t1.docs.short_description == "Set the temperature value." + assert ( + t1.docs.long_description.value + == "The value of the temp parameter is stored as a value in\nthe class variable temperature." + ) diff --git a/tests/flytekit/unit/core/test_signal.py b/tests/flytekit/unit/core/test_signal.py new file mode 100644 index 0000000000..a37da8955f --- /dev/null +++ b/tests/flytekit/unit/core/test_signal.py @@ -0,0 +1,42 @@ +from flyteidl.admin.signal_pb2 import Signal, SignalList +from mock import MagicMock + +from flytekit.configuration import Config +from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.type_engine import TypeEngine +from flytekit.models.core.identifier import SignalIdentifier, WorkflowExecutionIdentifier +from flytekit.remote.remote import FlyteRemote + + +def test_remote_list_signals(): + ctx = FlyteContextManager.current_context() + wfeid = WorkflowExecutionIdentifier("p", "d", "execid") + signal_id = SignalIdentifier(signal_id="sigid", execution_id=wfeid).to_flyte_idl() + lt = TypeEngine.to_literal_type(int) + signal = Signal( + id=signal_id, + type=lt.to_flyte_idl(), + value=TypeEngine.to_literal(ctx, 3, int, lt).to_flyte_idl(), + ) + + mock_client = MagicMock() + mock_client.list_signals.return_value = SignalList(signals=[signal], token="") + + remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") + remote._client = mock_client + res = remote.list_signals("execid", "p", "d", limit=10) + assert len(res) == 1 + + +def test_remote_set_signal(): + mock_client = MagicMock() + + def checker(request): + assert request.id.signal_id == "sigid" + assert request.value.scalar.primitive.integer == 3 + + mock_client.set_signal.side_effect = checker + + remote = FlyteRemote(config=Config.auto(), 
default_project="p1", default_domain="d1") + remote._client = mock_client + remote.set_signal("sigid", "execid", 3) diff --git a/tests/flytekit/unit/models/test_documentation.py b/tests/flytekit/unit/models/test_documentation.py new file mode 100644 index 0000000000..7702df0452 --- /dev/null +++ b/tests/flytekit/unit/models/test_documentation.py @@ -0,0 +1,29 @@ +from flytekit.models.documentation import Description, Documentation, SourceCode + + +def test_long_description(): + value = "long" + icon_link = "http://icon" + obj = Description(value=value, icon_link=icon_link) + assert Description.from_flyte_idl(obj.to_flyte_idl()) == obj + assert obj.value == value + assert obj.icon_link == icon_link + assert obj.format == Description.DescriptionFormat.RST + + +def test_source_code(): + link = "https://github.com/flyteorg/flytekit" + obj = SourceCode(link=link) + assert SourceCode.from_flyte_idl(obj.to_flyte_idl()) == obj + assert obj.link == link + + +def test_documentation(): + short_description = "short" + long_description = Description(value="long", icon_link="http://icon") + source_code = SourceCode(link="https://github.com/flyteorg/flytekit") + obj = Documentation(short_description=short_description, long_description=long_description, source_code=source_code) + assert Documentation.from_flyte_idl(obj.to_flyte_idl()) == obj + assert obj.short_description == short_description + assert obj.long_description == long_description + assert obj.source_code == source_code diff --git a/tests/flytekit/unit/models/test_tasks.py b/tests/flytekit/unit/models/test_tasks.py index fcebf465f9..fed32b63aa 100644 --- a/tests/flytekit/unit/models/test_tasks.py +++ b/tests/flytekit/unit/models/test_tasks.py @@ -7,6 +7,7 @@ import flytekit.models.interface as interface_models import flytekit.models.literals as literal_models +from flytekit import Description, Documentation, SourceCode from flytekit.models import literals, task, types from flytekit.models.core import identifier from 
tests.flytekit.common import parameterizers @@ -123,6 +124,60 @@ def test_task_template(in_tuple): assert obj.config == {"a": "b"} +def test_task_spec(): + task_metadata = task.TaskMetadata( + True, + task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + ) + + int_type = types.LiteralType(types.SimpleType.INTEGER) + interfaces = interface_models.TypedInterface( + {"a": interface_models.Variable(int_type, "description1")}, + { + "b": interface_models.Variable(int_type, "description2"), + "c": interface_models.Variable(int_type, "description3"), + }, + ) + + resource = [task.Resources.ResourceEntry(task.Resources.ResourceName.CPU, "1")] + resources = task.Resources(resource, resource) + + template = task.TaskTemplate( + identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"), + "python", + task_metadata, + interfaces, + {"a": 1, "b": {"c": 2, "d": 3}}, + container=task.Container( + "my_image", + ["this", "is", "a", "cmd"], + ["this", "is", "an", "arg"], + resources, + {"a": "b"}, + {"d": "e"}, + ), + config={"a": "b"}, + ) + + short_description = "short" + long_description = Description(value="long", icon_link="http://icon") + source_code = SourceCode(link="https://github.com/flyteorg/flytekit") + docs = Documentation( + short_description=short_description, long_description=long_description, source_code=source_code + ) + + obj = task.TaskSpec(template, docs) + assert task.TaskSpec.from_flyte_idl(obj.to_flyte_idl()) == obj + assert obj.docs == docs + assert obj.template == template + + def test_task_template__k8s_pod_target(): int_type = types.LiteralType(types.SimpleType.INTEGER) obj = task.TaskTemplate( diff --git a/tests/flytekit/unit/models/test_workflow_closure.py b/tests/flytekit/unit/models/test_workflow_closure.py index 3a42f5af81..d229d0d5c9 100644 --- 
a/tests/flytekit/unit/models/test_workflow_closure.py +++ b/tests/flytekit/unit/models/test_workflow_closure.py @@ -5,8 +5,10 @@ from flytekit.models import task as _task from flytekit.models import types as _types from flytekit.models import workflow_closure as _workflow_closure +from flytekit.models.admin.workflow import WorkflowSpec from flytekit.models.core import identifier as _identifier from flytekit.models.core import workflow as _workflow +from flytekit.models.documentation import Description, Documentation, SourceCode def test_workflow_closure(): @@ -81,3 +83,16 @@ def test_workflow_closure(): obj2 = _workflow_closure.WorkflowClosure.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 + + short_description = "short" + long_description = Description(value="long", icon_link="http://icon") + source_code = SourceCode(link="https://github.com/flyteorg/flytekit") + docs = Documentation( + short_description=short_description, long_description=long_description, source_code=source_code + ) + + workflow_spec = WorkflowSpec(template=template, sub_workflows=[], docs=docs) + assert WorkflowSpec.from_flyte_idl(workflow_spec.to_flyte_idl()) == workflow_spec + assert workflow_spec.docs.short_description == short_description + assert workflow_spec.docs.long_description == long_description + assert workflow_spec.docs.source_code == source_code