diff --git a/.github/workflows/docs_build.yml b/.github/workflows/docs_build.yml index 4fd71ce3b0..9bb57306bb 100644 --- a/.github/workflows/docs_build.yml +++ b/.github/workflows/docs_build.yml @@ -23,4 +23,4 @@ jobs: run: | sudo apt-get install python3-sphinx pip install -r doc-requirements.txt - SPHINXOPTS="-W" cd docs && make html + cd docs && SPHINXOPTS="-W" make html diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index 75e356ab0a..ab39bc4e02 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -65,6 +65,7 @@ jobs: - flytekit-deck-standard - flytekit-dolt - flytekit-duckdb + - flytekit-envd - flytekit-greatexpectations - flytekit-hive - flytekit-k8s-pod diff --git a/.github/workflows/pythonpublish.yml b/.github/workflows/pythonpublish.yml index 097d82323e..169e0e58b7 100644 --- a/.github/workflows/pythonpublish.yml +++ b/.github/workflows/pythonpublish.yml @@ -47,11 +47,28 @@ jobs: run: | make -C plugins build_all_plugins make -C plugins publish_all_plugins - # Added sleep because PYPI take some time in publish - - name: Sleep for 180 seconds - uses: jakejarvis/wait-action@master - with: - time: '180s' + - name: Sleep until pypi is available + id: pypiwait + run: | + # from refs/tags/v1.2.3 get 1.2.3 and make sure it's not an empty string + VERSION=$(echo $GITHUB_REF | sed 's#.*/v##') + if [ -z "$VERSION" ] + then + echo "No tagged version found, exiting" + exit 1 + fi + LINK="https://pypi.org/project/flytekit/${VERSION}" + for i in {1..60}; do + if curl -L -I -s -f ${LINK} >/dev/null; then + echo "Found pypi" + exit 0 + else + echo "Did not find - Retrying in 10 seconds..." + sleep 10 + fi + done + exit 1 + shell: bash outputs: version: ${{ steps.bump.outputs.version }} @@ -120,6 +137,91 @@ jobs: tags: ${{ steps.sqlalchemy-names.outputs.tags }} build-args: | VERSION=${{ needs.deploy.outputs.version }} - file: ./plugins/flytekit-sqlalchemy/Dockerfile.py${{ matrix.python-version }} + PYTHON_VERSION=${{ matrix.python-version }} + file: ./plugins/flytekit-sqlalchemy/Dockerfile + cache-from: type=gha + cache-to: type=gha,mode=max + + build-and-push-external-plugin-service-images: + runs-on: ubuntu-latest + needs: deploy + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: "0" + - name: Set up QEMU + uses: docker/setup-qemu-action@v1 + - name: Set up Docker Buildx + id: buildx + uses: docker/setup-buildx-action@v1 + - name: Login to GitHub Container Registry + if: ${{ github.event_name == 'release' }} + uses: docker/login-action@v1 + with: + registry: ghcr.io + username: "${{ secrets.FLYTE_BOT_USERNAME }}" + password: "${{ secrets.FLYTE_BOT_PAT }}" + - name: Prepare External Plugin Service Image Names + id: external-plugin-service-names + uses: docker/metadata-action@v3 + with: + images: | + ghcr.io/${{ github.repository_owner }}/external-plugin-service + tags: | + latest + ${{ github.sha }} + ${{ needs.deploy.outputs.version }} + - name: Push External Plugin Service Image to GitHub Registry + uses: docker/build-push-action@v2 + with: + context: "." 
+ platforms: linux/arm64, linux/amd64 + push: ${{ github.event_name == 'release' }} + tags: ${{ steps.external-plugin-service-names.outputs.tags }} + build-args: | + VERSION=${{ needs.deploy.outputs.version }} + file: ./Dockerfile.external-plugin-service + cache-from: type=gha + cache-to: type=gha,mode=max + + build-and-push-spark-images: + runs-on: ubuntu-latest + needs: deploy + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: "0" + - name: Set up QEMU + uses: docker/setup-qemu-action@v1 + - name: Set up Docker Buildx + id: buildx + uses: docker/setup-buildx-action@v1 + - name: Login to GitHub Container Registry + if: ${{ github.event_name == 'release' }} + uses: docker/login-action@v1 + with: + registry: ghcr.io + username: "${{ secrets.FLYTE_BOT_USERNAME }}" + password: "${{ secrets.FLYTE_BOT_PAT }}" + - name: Prepare Spark Image Names + id: spark-names + uses: docker/metadata-action@v3 + with: + images: | + ghcr.io/${{ github.repository_owner }}/flytekit + tags: | + spark-latest + spark-${{ github.sha }} + spark-${{ needs.deploy.outputs.version }} + - name: Push Spark Image to GitHub Registry + uses: docker/build-push-action@v2 + with: + context: "./plugins/flytekit-spark/" + platforms: linux/arm64, linux/amd64 + push: ${{ github.event_name == 'release' }} + tags: ${{ steps.spark-names.outputs.tags }} + build-args: | + VERSION=${{ needs.deploy.outputs.version }} + file: ./plugins/flytekit-spark/Dockerfile cache-from: type=gha cache-to: type=gha,mode=max diff --git a/.gitignore b/.gitignore index a4fe02503e..b2e20249a8 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,4 @@ htmlcov *.ipynb *dat docs/source/_tags/ +.hypothesis diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1fd6e6b648..3007f6e64d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,3 +22,7 @@ repos: rev: v0.8.0.4 hooks: - id: shellcheck +- repo: https://github.com/conorfalvey/check_pdb_hook + rev: 0.0.9 + hooks: + - id: check_pdb_hook diff --git a/.readthedocs.yml b/.readthedocs.yml index 19b1898e94..18f4292317 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -9,6 +9,8 @@ build: os: ubuntu-20.04 tools: python: "3.9" + apt_packages: + - graphviz # Build documentation in the docs/ directory with Sphinx sphinx: diff --git a/Dockerfile b/Dockerfile index 9aa462781c..257fcb5143 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,8 +16,7 @@ RUN apt-get update && apt-get install build-essential -y RUN pip install -U flytekit==$VERSION \ flytekitplugins-pod==$VERSION \ flytekitplugins-deck-standard==$VERSION \ - flytekitplugins-data-fsspec[aws]==$VERSION \ - flytekitplugins-data-fsspec[gcp]==$VERSION \ + flytekitplugins-envd==$VERSION \ scikit-learn RUN useradd -u 1000 flytekit diff --git a/Dockerfile.dev b/Dockerfile.dev new file mode 100644 index 0000000000..b7c5104bbc --- /dev/null +++ b/Dockerfile.dev @@ -0,0 +1,32 @@ +# This Dockerfile is here to help with end-to-end testing +# From flytekit +# $ docker build -f Dockerfile.dev --build-arg PYTHON_VERSION=3.10 -t localhost:30000/flytekittest:someversion . +# $ docker push localhost:30000/flytekittest:someversion +# From your test user code +# $ pyflyte run --image localhost:30000/flytekittest:someversion + +ARG PYTHON_VERSION +FROM python:${PYTHON_VERSION}-slim-buster + +MAINTAINER Flyte Team +LABEL org.opencontainers.image.source https://github.com/flyteorg/flytekit + +WORKDIR /root + +ARG VERSION + +RUN apt-get update && apt-get install build-essential vim -y + +COPY . 
/flytekit + +# Pod tasks should be exposed in the default image +RUN pip install -e /flytekit +RUN pip install -e /flytekit/plugins/flytekit-k8s-pod +RUN pip install -e /flytekit/plugins/flytekit-deck-standard +RUN pip install scikit-learn + +ENV PYTHONPATH "/flytekit:/flytekit/plugins/flytekit-k8s-pod:/flytekit/plugins/flytekit-deck-standard:" + +RUN useradd -u 1000 flytekit +RUN chown flytekit: /root +USER flytekit diff --git a/Dockerfile.external-plugin-service b/Dockerfile.external-plugin-service new file mode 100644 index 0000000000..2194f5de23 --- /dev/null +++ b/Dockerfile.external-plugin-service @@ -0,0 +1,9 @@ +FROM python:3.9-slim-buster + +MAINTAINER Flyte Team +LABEL org.opencontainers.image.source=https://github.com/flyteorg/flytekit + +ARG VERSION +RUN pip install -U flytekit==$VERSION flytekitplugins-bigquery==$VERSION + +CMD pyflyte serve --port 8000 diff --git a/Makefile b/Makefile index eb9f43cdb6..09b924c691 100644 --- a/Makefile +++ b/Makefile @@ -26,16 +26,18 @@ setup: install-piptools ## Install requirements .PHONY: fmt fmt: ## Format code with black and isort + autoflake --remove-all-unused-imports --ignore-init-module-imports --ignore-pass-after-docstring --in-place -r flytekit plugins tests pre-commit run black --all-files || true pre-commit run isort --all-files || true .PHONY: lint lint: ## Run linters - mypy flytekit/core || true - mypy flytekit/types || true - mypy tests/flytekit/unit/core || true - # Exclude setup.py to fix error: Duplicate module named "setup" - mypy plugins --exclude setup.py || true + mypy flytekit/core + mypy flytekit/types + # allow-empty-bodies: Allow empty body in function. + # disable-error-code="annotation-unchecked": Remove the warning "By default the bodies of untyped functions are not checked". + # Mypy raises a warning because it cannot determine the type from the dataclass, even though we specify the type in the dataclass. + mypy --allow-empty-bodies --disable-error-code="annotation-unchecked" tests/flytekit/unit/core pre-commit run --all-files .PHONY: spellcheck diff --git a/dev-requirements.in b/dev-requirements.in index 8655b92bae..a912d8c9d9 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -2,6 +2,7 @@ git+https://github.com/flyteorg/pytest-flyte@main#egg=pytest-flyte coverage[toml] +hypothesis joblib mock pytest @@ -21,3 +22,7 @@ tensorflow==2.8.1; platform_machine!='arm64' or platform_system!='Darwin' # we put this constraint while we do not have per-environment requirements files torch<=1.12.1 scikit-learn +types-protobuf<4 +types-croniter +types-mock +autoflake diff --git a/doc-requirements.in b/doc-requirements.in index e17b05ec5b..b495e7616f 100644 --- a/doc-requirements.in +++ b/doc-requirements.in @@ -1,6 +1,7 @@ .
-e file:.#egg=flytekit +grpcio<=1.49.1 git+https://github.com/flyteorg/furo@main sphinx sphinx-gallery @@ -11,7 +12,7 @@ sphinx-autoapi sphinx-copybutton sphinx_fontawesome sphinx-panels -sphinxcontrib-yt +sphinxcontrib-youtube cryptography google-api-core[grpc] scikit-learn diff --git a/doc-requirements.txt b/doc-requirements.txt index 2616e043df..a635188c5f 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -10,8 +10,24 @@ absl-py==1.4.0 # via # tensorboard # tensorflow +adal==1.2.7 + # via azure-datalake-store +adlfs==2023.1.0 + # via flytekit +aiobotocore==2.5.0 + # via s3fs +aiohttp==3.8.4 + # via + # adlfs + # aiobotocore + # gcsfs + # s3fs +aioitertools==0.11.0 + # via aiobotocore aiosignal==1.3.1 - # via ray + # via + # aiohttp + # ray alabaster==0.7.13 # via sphinx alembic==1.9.2 @@ -46,11 +62,25 @@ asttokens==2.2.1 # via stack-data astunparse==1.6.3 # via tensorflow +async-timeout==4.0.2 + # via aiohttp attrs==22.2.0 # via + # aiohttp # jsonschema # ray # visions +azure-core==1.26.4 + # via + # adlfs + # azure-identity + # azure-storage-blob +azure-datalake-store==0.0.52 + # via adlfs +azure-identity==1.12.0 + # via adlfs +azure-storage-blob==12.16.0 + # via adlfs babel==2.11.0 # via sphinx backcall==0.2.0 @@ -67,8 +97,10 @@ blake3==0.3.3 # via vaex-core bleach==6.0.0 # via nbconvert -botocore==1.29.61 - # via -r doc-requirements.in +botocore==1.29.76 + # via + # -r doc-requirements.in + # aiobotocore bqplot==0.12.36 # via # ipyvolume @@ -86,13 +118,16 @@ certifi==2022.12.7 cffi==1.15.1 # via # argon2-cffi-bindings + # azure-datalake-store # cryptography cfgv==3.3.1 # via pre-commit chardet==5.0.0 # via binaryornot charset-normalizer==3.0.1 - # via requests + # via + # aiohttp + # requests click==8.1.3 # via # cookiecutter @@ -128,7 +163,12 @@ croniter==1.3.7 cryptography==39.0.0 # via # -r doc-requirements.in + # adal + # azure-identity + # azure-storage-blob # great-expectations + # msal + # pyjwt # pyopenssl # secretstorage css-html-js-minify==2.5.5 @@ -150,8 +190,8 @@ debugpy==1.6.6 # via ipykernel decorator==5.1.1 # via + # gcsfs # ipython - # retry defusedxml==0.7.1 # via nbconvert deprecated==1.2.13 @@ -197,11 +237,9 @@ filelock==3.9.0 # virtualenv flask==2.2.2 # via mlflow -flatbuffers==2.0.7 - # via - # tensorflow - # tf2onnx -flyteidl==1.2.9 +flatbuffers==23.1.21 + # via tensorflow +flyteidl==1.3.16 # via flytekit fonttools==4.38.0 # via matplotlib @@ -211,19 +249,26 @@ frozendict==2.3.4 # via vaex-core frozenlist==1.3.3 # via + # aiohttp # aiosignal # ray -fsspec==2023.1.0 +fsspec==2023.4.0 # via # -r doc-requirements.in + # adlfs # dask + # flytekit + # gcsfs # modin + # s3fs furo @ git+https://github.com/flyteorg/furo@main # via -r doc-requirements.in future==0.18.3 # via vaex-core gast==0.5.3 # via tensorflow +gcsfs==2023.4.0 + # via flytekit gitdb==4.0.10 # via gitpython gitpython==3.1.30 @@ -235,27 +280,38 @@ google-api-core[grpc]==2.11.0 # -r doc-requirements.in # google-cloud-bigquery # google-cloud-core + # google-cloud-storage google-auth==2.16.0 # via + # gcsfs # google-api-core # google-auth-oauthlib # google-cloud-core + # google-cloud-storage # kubernetes # tensorboard google-auth-oauthlib==0.4.6 - # via tensorboard + # via + # gcsfs + # tensorboard google-cloud==0.34.0 # via -r doc-requirements.in google-cloud-bigquery==3.5.0 # via -r doc-requirements.in google-cloud-core==2.3.2 - # via google-cloud-bigquery + # via + # google-cloud-bigquery + # google-cloud-storage +google-cloud-storage==2.8.0 + # via gcsfs google-crc32c==1.5.0 # via 
google-resumable-media google-pasta==0.2.0 # via tensorflow google-resumable-media==2.4.1 - # via google-cloud-bigquery + # via + # google-cloud-bigquery + # google-cloud-storage googleapis-common-protos==1.58.0 # via # flyteidl @@ -299,6 +355,7 @@ idna==3.4 # anyio # jsonschema # requests + # yarl imagehash==4.3.1 # via visions imagesize==1.4.1 @@ -364,6 +421,8 @@ ipywidgets==8.0.4 # ipyvue # jupyter # pythreejs +isodate==0.6.1 + # via azure-storage-blob isoduration==20.11.0 # via jsonschema itsdangerous==2.1.2 @@ -526,10 +585,20 @@ modin==0.18.1 # via -r doc-requirements.in more-itertools==9.0.0 # via jaraco-classes +msal==1.22.0 + # via + # azure-identity + # msal-extensions +msal-extensions==1.0.0 + # via azure-identity msgpack==1.0.4 # via # distributed # ray +multidict==6.0.4 + # via + # aiohttp + # yarl multimethod==1.9.1 # via # visions @@ -720,7 +789,9 @@ platformdirs==2.6.2 # virtualenv plotly==5.13.0 # via -r doc-requirements.in -pre-commit==3.0.2 +portalocker==2.7.0 + # via msal-extensions +pre-commit==3.0.4 # via sphinx-tags progressbar2==4.2.0 # via vaex-core @@ -766,8 +837,6 @@ ptyprocess==0.7.0 # terminado pure-eval==0.2.2 # via stack-data -py==1.11.0 - # via retry py4j==0.10.9.5 # via pyspark pyarrow==6.0.1 @@ -802,8 +871,11 @@ pygments==2.14.0 # rich # sphinx # sphinx-prompt -pyjwt==2.6.0 - # via databricks-cli +pyjwt[crypto]==2.6.0 + # via + # adal + # databricks-cli + # msal pyopenssl==23.0.0 # via flytekit pyparsing==3.0.9 @@ -816,6 +888,7 @@ pyspark==3.3.1 # via -r doc-requirements.in python-dateutil==2.8.2 # via + # adal # arrow # botocore # croniter @@ -891,16 +964,22 @@ regex==2022.10.31 # via docker-image-py requests==2.28.2 # via + # adal + # azure-core + # azure-datalake-store # cookiecutter # databricks-cli # docker # flytekit + # gcsfs # google-api-core # google-cloud-bigquery + # google-cloud-storage # great-expectations # ipyvolume # kubernetes # mlflow + # msal # papermill # ray # requests-oauthlib @@ -916,8 +995,6 @@ requests-oauthlib==1.3.1 # kubernetes responses==0.22.0 # via flytekit -retry==0.9.2 - # via flytekit rfc3339-validator==0.1.4 # via # jsonschema @@ -927,14 +1004,18 @@ rfc3986-validator==0.1.1 # jsonschema # jupyter-events rich==13.3.1 - # via vaex-core + # via + # flytekit + # vaex-core rsa==4.9 # via google-auth ruamel-yaml==0.17.17 # via great-expectations ruamel-yaml-clib==0.2.7 # via ruamel-yaml -scikit-learn==1.1.1 +s3fs==2023.4.0 + # via flytekit +scikit-learn==1.2.1 # via # -r doc-requirements.in # mlflow @@ -966,11 +1047,13 @@ six==1.16.0 # via # asttokens # astunparse + # azure-core + # azure-identity # bleach # databricks-cli # google-auth # google-pasta - # grpcio + # isodate # keras-preprocessing # kubernetes # patsy @@ -1047,7 +1130,7 @@ sphinxcontrib-qthelp==1.0.3 # via sphinx sphinxcontrib-serializinghtml==1.1.5 # via sphinx -sphinxcontrib-yt==0.2.2 +sphinxcontrib-youtube==1.2.0 # via -r doc-requirements.in sqlalchemy==1.4.46 # via @@ -1167,7 +1250,10 @@ types-toml==0.10.8.1 # via responses typing-extensions==4.4.0 # via + # aioitertools # astroid + # azure-core + # azure-storage-blob # flytekit # great-expectations # onnx @@ -1275,6 +1361,7 @@ widgetsnbextension==4.0.5 # via ipywidgets wrapt==1.14.1 # via + # aiobotocore # astroid # deprecated # flytekit @@ -1284,6 +1371,8 @@ xarray==2023.1.0 # via vaex-jupyter xyzservices==2022.9.0 # via ipyleaflet +yarl==1.8.2 + # via aiohttp ydata-profiling==4.0.0 # via pandas-profiling zict==2.2.0 diff --git a/docs/source/conf.py b/docs/source/conf.py index 205bcb8838..6c0663f6b5 100644 
--- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -57,7 +57,7 @@ "sphinx-prompt", "sphinx_copybutton", "sphinx_panels", - "sphinxcontrib.yt", + "sphinxcontrib.youtube", "sphinx_tags", "sphinx_click", ] diff --git a/docs/source/design/authoring.rst b/docs/source/design/authoring.rst index 46a094399d..3118a9c71b 100644 --- a/docs/source/design/authoring.rst +++ b/docs/source/design/authoring.rst @@ -1,40 +1,42 @@ .. _design-authoring: -####################### +################### Authoring Structure -####################### +################### .. tags:: Design, Basic -One of the core features of Flytekit is to enable users to write tasks and workflows. In this section, we will understand how it works internally. - -.. note:: - - Please refer to the `design doc `__. +Flytekit's main focus is to provide users with the ability to create their own tasks and workflows. +In this section, we'll take a closer look at how it works under the hood. ********************* Types and Type Engine ********************* -Flyte has its own type system, which is codified `in the IDL `__. Python has its own type system despite being a dynamic language, which is primarily explained in `PEP 484 `_. Flytekit needs to build a medium to bridge the gap between these two type systems. -Type Engine -============= -This primariliy happens through the :py:class:`flytekit.extend.TypeEngine`. This engine works by invoking a series of :py:class:`TypeTransformers `. Each transformer is responsible for providing the functionality that the engine requires for a given native Python type. +Flyte uses its own type system, which is defined in the `IDL `__. +Despite being a dynamic language, Python also has its own type system which is primarily explained in `PEP 484 `__. +Therefore, Flytekit needs to establish a means of bridging the gap between these two type systems. +This is primarily accomplished through the use of :py:class:`flytekit.extend.TypeEngine`. +The ``TypeEngine`` works by invoking a series of :py:class:`TypeTransformers `. +Each transformer is responsible for providing the functionality that the engine requires for a given native Python type. ***************** Callable Entities ***************** -:ref:`Tasks `, :ref:`workflows `, and :ref:`launch plans ` form the core of the Flyte user experience. Each of these concepts is backed by one or more Python classes. These classes in turn, are instantiated by decorators (in the case of tasks and workflow) or a regular Python call (in the case of launch plans). + +The Flyte user experience is built around three main concepts: :ref:`Tasks `, :ref:`workflows `, and :ref:`launch plans `. +Each of these concepts is supported by one or more Python classes, which are instantiated by decorators (in the case of tasks and workflows) or a regular Python call (in the case of launch plans). Tasks ===== -This is the current task class hierarchy: + +Here is the existing hierarchy of task classes: .. inheritance-diagram:: flytekit.core.python_function_task.PythonFunctionTask flytekit.core.python_function_task.PythonInstanceTask flytekit.extras.sqlite3.task.SQLite3Task - :parts: 1 :top-classes: flytekit.core.base_task.Task + :parts: 1 -Please see the documentation on each of the classes for details. +For more information on each of the classes, please refer to the corresponding documentation. .. autoclass:: flytekit.core.base_task.Task :noindex: @@ -48,10 +50,10 @@ Please see the documentation on each of the classes for details. ..
autoclass:: flytekit.core.python_function_task.PythonFunctionTask :noindex: - Workflows ========== -There are two workflow classes, and both inherit from the :py:class:`WorkflowBase ` class. + +There exist two workflow classes, both of which derive from the ``WorkflowBase`` class. .. autoclass:: flytekit.core.workflow.PythonFunctionWorkflow :noindex: @@ -59,10 +61,10 @@ There are two workflow classes, and both inherit from the :py:class:`WorkflowBas .. autoclass:: flytekit.core.workflow.ImperativeWorkflow :noindex: +Launch Plans +============ -Launch Plan -=========== -There is only one :py:class:`LaunchPlan ` class. +There exists one :py:class:`LaunchPlan ` class. .. autoclass:: flytekit.core.launch_plan.LaunchPlan :noindex: @@ -72,12 +74,13 @@ There is only one :py:class:`LaunchPlan ` ****************** Exception Handling ****************** -Exception handling takes place along two dimensions: -* System vs. User: We try to differentiate between user exceptions and Flytekit/system-level exceptions. For instance, if Flytekit fails to upload its outputs, that's a system exception. If the user raises a ``ValueError`` because of an unexpected input in the task code, that's a user exception. -* Recoverable vs. Non-recoverable: Recoverable errors will be retried and counted against the task's retry count. Non-recoverable errors will simply fail. System exceptions are by default recoverable (since there's a good chance it was just a blip). +Exception handling occurs along two dimensions: + +* System vs. User: We distinguish between Flytekit/system-level exceptions and user exceptions. For instance, if Flytekit encounters an issue while uploading outputs, it is considered a system exception. On the other hand, if a user raises a ``ValueError`` due to an unexpected input in the task code, it is classified as a user exception. +* Recoverable vs. Non-recoverable: Recoverable errors are retried and counted towards the task's retry count, while non-recoverable errors simply fail. System exceptions are recoverable by default since they are usually temporary. -Here's the user exception tree. Feel free to raise any of these exception classes. Note that the ``FlyteRecoverableException`` is the only recoverable exception. All others, along with all the non-Flytekit defined exceptions, are non-recoverable. +The following is the user exception tree, which users can raise as needed. It is important to note that only ``FlyteRecoverableException`` is a recoverable exception. All other exceptions, including non-Flytekit defined exceptions, are non-recoverable. .. inheritance-diagram:: flytekit.exceptions.user.FlyteValidationException flytekit.exceptions.user.FlyteEntityAlreadyExistsException flytekit.exceptions.user.FlyteValueException flytekit.exceptions.user.FlyteTimeout flytekit.exceptions.user.FlyteAuthenticationException flytekit.exceptions.user.FlyteRecoverableException :parts: 1 @@ -85,36 +88,42 @@ Here's the user exception tree. Feel free to raise any of these exception classe Implementation ============== -For those who want to dig deeper, take a look at the :py:class:`flytekit.common.exceptions.scopes.FlyteScopedException` classes. -There are two decorators that are interspersed throughout the codebase. + +If you wish to delve deeper, you can explore the ``FlyteScopedException`` classes. + +There are two decorators that are used throughout the codebase. .. autofunction:: flytekit.exceptions.scopes.system_entry_point .. 
autofunction:: flytekit.exceptions.scopes.user_entry_point -************** +************* Call Patterns -************** -The above-mentioned entities (tasks, workflows, and launch plan) are callable. They can be invoked to yield a unit (or units) of work in Flyte. +************* -In Pythonic terms, when you add ``()`` to the end of one of the entities, it invokes the ``__call__`` method on the object. +The entities mentioned above (tasks, workflows, and launch plans) are callable and can be invoked to generate one or more units of work in Flyte. -What happens when a callable entity is called depends on the current context, specifically the current :py:class:`flytekit.FlyteContext` +In Pythonic terminology, adding ``()`` to the end of an entity invokes the ``__call__`` method on the object. -Raw Task Execution -=================== -This is what happens when a task is just run as part of a unit test. The ``@task`` decorator actually turns the decorated function into an instance of the ``PythonFunctionTask`` object, but when a user calls the ``task()`` outside of a workflow, the original function is called without any interference by Flytekit. +The behavior that occurs when a callable entity is invoked is dependent on the current context, specifically the current :py:class:`flytekit.FlyteContext`. -Task Execution Inside Workflow -=============================== -When a workflow is run locally (say as a part of a unit test), certain changes occur in the ``task``. +Raw task execution +================== -Before going further, there is a special object that's worth mentioning, the :py:class:`flytekit.extend.Promise`. +When a task is executed as part of a unit test, the ``@task`` decorator transforms the decorated function into an instance of the ``PythonFunctionTask`` object. +However, when a user invokes the ``task()`` function outside of a workflow, the original function is called without any intervention from Flytekit. + +Task execution inside a workflow +================================ + +When a workflow is executed locally (for instance, as part of a unit test), some modifications are made to the task. + +Before proceeding, it is worth noting a special object, the :py:class:`flytekit.extend.Promise`. .. autoclass:: flytekit.core.promise.Promise :noindex: -Let's assume we have a workflow like :: +Consider the following workflow: :: @task def t1(a: int) -> Tuple[int, str]: @@ -130,19 +139,23 @@ Let's assume we have a workflow like :: d = t2(a=y, b=b) return x, d -As discussed in the Promise object's documentation, when a task is called from inside a workflow, the Python native values returned by the raw underlying functions are first converted into Flyte IDL literals and then wrapped inside ``Promise`` objects. One ``Promise`` is created for every return variable. +As stated in the documentation for the Promise object, when a task is invoked within a workflow, the Python native values returned by the underlying functions are first converted into Flyte IDL literals and then encapsulated inside Promise objects. +One Promise object is created for each return variable. -When the next task is called, the logic is triggered to unwrap these Promises. +When the next task is invoked, the values are extracted from these Promises. Compilation =========== -When a workflow is compiled, instead of producing Promise objects that wrap literal values, they wrap a :py:class:`flytekit.core.promise.NodeOutput` instead. This helps track data dependency between tasks. 
+ +During the workflow compilation process, instead of generating Promise objects that encapsulate literal values, Flytekit generates Promise objects that encapsulate a :py:class:`flytekit.core.promise.NodeOutput`. +This approach aids in tracking the data dependencies between tasks. Branch Skip =========== -If a :py:func:`flytekit.conditional` is determined to be false, then Flytekit will skip calling the task. This avoids running the unintended task. +If the condition specified in a :py:func:`flytekit.conditional` evaluates to ``False``, Flytekit will avoid invoking the corresponding task. +This prevents the unintended execution of the task. .. note:: - We discussed about a task's execution pattern above. The same pattern can be applied to workflows and launch plans too! + The execution pattern that we discussed for tasks can be applied to workflows and launch plans as well! diff --git a/flytekit/__init__.py b/flytekit/__init__.py index 20433c2ad1..4bbac570f3 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -26,6 +26,7 @@ map_task ~core.workflow.ImperativeWorkflow ~core.node_creation.create_node + ~core.promise.NodeOutput FlyteContextManager Running Locally @@ -195,6 +196,10 @@ import sys from typing import Generator +from rich import traceback + +from flytekit.lazy_import.lazy_module import lazy_module + if sys.version_info < (3, 10): from importlib_metadata import entry_points else: @@ -206,7 +211,6 @@ from flytekit.core.condition import conditional from flytekit.core.container_task import ContainerTask from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager -from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.gate import approve, sleep, wait_for_input from flytekit.core.hash import HashMethod @@ -223,8 +227,7 @@ from flytekit.core.workflow import ImperativeWorkflow as Workflow from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow from flytekit.deck import Deck -from flytekit.extras import pytorch, sklearn, tensorflow -from flytekit.extras.persistence import GCSPersistence, HttpPersistence, S3Persistence +from flytekit.image_spec import ImageSpec from flytekit.loggers import logger from flytekit.models.common import Annotations, AuthRole, Labels from flytekit.models.core.execution import WorkflowExecutionPhase @@ -232,7 +235,7 @@ from flytekit.models.documentation import Description, Documentation, SourceCode from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType -from flytekit.types import directory, file, numpy, schema +from flytekit.types import directory, file from flytekit.types.structured.structured_dataset import ( StructuredDataset, StructuredDatasetFormat, @@ -299,3 +302,6 @@ def load_implicit_plugins(): # Load all implicit plugins load_implicit_plugins() + +# Pretty-print exception messages +traceback.install(width=None, extra_lines=0) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index ca7a6cf20d..a9b7c313f0 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -10,7 +10,6 @@ import click as _click from flyteidl.core import literals_pb2 as _literals_pb2 -from flytekit import PythonFunctionTask from flytekit.configuration import ( SERIALIZED_CONTEXT_ENV_VAR, FastSerializationSettings, @@ -23,7 +22,7 @@ from flytekit.core.checkpointer import SyncCheckpoint from flytekit.core.context_manager import
ExecutionParameters, ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.data_persistence import FileAccessProvider -from flytekit.core.map_task import MapPythonTask +from flytekit.core.map_task import MapTaskResolver from flytekit.core.promise import VoidPromise from flytekit.exceptions import scopes as _scoped_exceptions from flytekit.exceptions import scopes as _scopes @@ -391,12 +390,8 @@ def _execute_map_task( with setup_execution( raw_output_data_prefix, checkpoint_path, prev_checkpoint, dynamic_addl_distro, dynamic_dest_dir ) as ctx: - resolver_obj = load_object_from_module(resolver) - # Use the resolver to load the actual task object - _task_def = resolver_obj.load_task(loader_args=resolver_args) - if not isinstance(_task_def, PythonFunctionTask): - raise Exception("Map tasks cannot be run with instance tasks.") - map_task = MapPythonTask(_task_def, max_concurrency) + mtr = MapTaskResolver() + map_task = mtr.load_task(loader_args=resolver_args, max_concurrency=max_concurrency) task_index = _compute_array_job_index() output_prefix = os.path.join(output_prefix, str(task_index)) diff --git a/flytekit/clients/auth/auth_client.py b/flytekit/clients/auth/auth_client.py index 94afa13612..ec1fd4d3e1 100644 --- a/flytekit/clients/auth/auth_client.py +++ b/flytekit/clients/auth/auth_client.py @@ -269,9 +269,11 @@ def _credentials_from_response(self, auth_token_resp) -> Credentials: raise ValueError('Expected "access_token" in response from oauth server') if "refresh_token" in response_body: refresh_token = response_body["refresh_token"] + if "expires_in" in response_body: + expires_in = response_body["expires_in"] access_token = response_body["access_token"] - return Credentials(access_token, refresh_token, self._endpoint) + return Credentials(access_token, refresh_token, self._endpoint, expires_in=expires_in) def _request_access_token(self, auth_code) -> Credentials: if self._state != auth_code.state: diff --git a/flytekit/clients/auth/authenticator.py b/flytekit/clients/auth/authenticator.py index 183c1787cd..9582c901d8 100644 --- a/flytekit/clients/auth/authenticator.py +++ b/flytekit/clients/auth/authenticator.py @@ -1,12 +1,10 @@ -import base64 import logging import subprocess import typing from abc import abstractmethod from dataclasses import dataclass -import requests - +from . 
import token_client from .auth_client import AuthorizationClient from .exceptions import AccessTokenNotFoundError, AuthenticationError from .keyring import Credentials, KeyringStore @@ -22,6 +20,7 @@ class ClientConfig: authorization_endpoint: str redirect_uri: str client_id: str + device_authorization_endpoint: typing.Optional[str] = None scopes: typing.List[str] = None header_key: str = "authorization" @@ -155,67 +154,25 @@ class ClientCredentialsAuthenticator(Authenticator): This Authenticator uses ClientId and ClientSecret to authenticate """ - _utf_8 = "utf-8" - def __init__( self, endpoint: str, client_id: str, client_secret: str, cfg_store: ClientConfigStore, - header_key: str = None, + header_key: typing.Optional[str] = None, + scopes: typing.Optional[typing.List[str]] = None, ): if not client_id or not client_secret: raise ValueError("Client ID and Client SECRET both are required.") cfg = cfg_store.get_client_config() self._token_endpoint = cfg.token_endpoint - self._scopes = cfg.scopes + # Use scopes from `flytekit.configuration.PlatformConfig` if passed + self._scopes = scopes or cfg.scopes self._client_id = client_id self._client_secret = client_secret super().__init__(endpoint, cfg.header_key or header_key) - @staticmethod - def get_token(token_endpoint: str, authorization_header: str, scopes: typing.List[str]) -> typing.Tuple[str, int]: - """ - :rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration - in seconds - """ - headers = { - "Authorization": authorization_header, - "Cache-Control": "no-cache", - "Accept": "application/json", - "Content-Type": "application/x-www-form-urlencoded", - } - body = { - "grant_type": "client_credentials", - } - if scopes is not None: - body["scope"] = ",".join(scopes) - response = requests.post(token_endpoint, data=body, headers=headers) - if response.status_code != 200: - logging.error("Non-200 ({}) received from IDP: {}".format(response.status_code, response.text)) - raise AuthenticationError("Non-200 received from IDP") - - response = response.json() - return response["access_token"], response["expires_in"] - - @staticmethod - def get_basic_authorization_header(client_id: str, client_secret: str) -> str: - """ - This function transforms the client id and the client secret into a header that conforms with http basic auth. - It joins the id and the secret with a : then base64 encodes it, then adds the appropriate text - - :param client_id: str - :param client_secret: str - :rtype: str - """ - concated = "{}:{}".format(client_id, client_secret) - return "Basic {}".format( - base64.b64encode(concated.encode(ClientCredentialsAuthenticator._utf_8)).decode( - ClientCredentialsAuthenticator._utf_8 - ) - ) - def refresh_credentials(self): """ This function is used by the _handle_rpc_error() decorator, depending on the AUTH_MODE config object. This handler @@ -229,7 +186,56 @@ def refresh_credentials(self): # Note that unlike the Pkce flow, the client ID does not come from Admin. 
logging.debug(f"Basic authorization flow with client id {self._client_id} scope {scopes}") - authorization_header = self.get_basic_authorization_header(self._client_id, self._client_secret) - token, expires_in = self.get_token(token_endpoint, authorization_header, scopes) + authorization_header = token_client.get_basic_authorization_header(self._client_id, self._client_secret) + token, expires_in = token_client.get_token(token_endpoint, scopes, authorization_header) logging.info("Retrieved new token, expires in {}".format(expires_in)) self._creds = Credentials(token) + + +class DeviceCodeAuthenticator(Authenticator): + """ + This Authenticator implements the Device Code authorization flow useful for headless user authentication. + + Examples described + - https://developer.okta.com/docs/guides/device-authorization-grant/main/ + - https://auth0.com/docs/get-started/authentication-and-authorization-flow/device-authorization-flow#device-flow + """ + + def __init__( + self, + endpoint: str, + cfg_store: ClientConfigStore, + header_key: typing.Optional[str] = None, + audience: typing.Optional[str] = None, + ): + self._audience = audience + cfg = cfg_store.get_client_config() + self._client_id = cfg.client_id + self._device_auth_endpoint = cfg.device_authorization_endpoint + self._scope = cfg.scopes + self._token_endpoint = cfg.token_endpoint + if self._device_auth_endpoint is None: + raise AuthenticationError( + "Device Authentication is not available on the Flyte backend / authentication server" + ) + super().__init__( + endpoint=endpoint, header_key=header_key or cfg.header_key, credentials=KeyringStore.retrieve(endpoint) + ) + + def refresh_credentials(self): + resp = token_client.get_device_code(self._device_auth_endpoint, self._client_id, self._audience, self._scope) + print( + f""" +To Authenticate navigate in a browser to the following URL: {resp.verification_uri} and enter code: {resp.user_code} +OR copy paste the following URL: {resp.verification_uri_complete} + """ + ) + try: + # Currently the refresh token is not retreived. We may want to add support for refreshTokens so that + # access tokens can be refreshed for once authenticated machines + token, expires_in = token_client.poll_token_endpoint(resp, self._token_endpoint, client_id=self._client_id) + self._creds = Credentials(access_token=token, expires_in=expires_in, for_endpoint=self._endpoint) + KeyringStore.store(self._creds) + except Exception: + KeyringStore.delete(self._endpoint) + raise diff --git a/flytekit/clients/auth/exceptions.py b/flytekit/clients/auth/exceptions.py index 6e790e47a4..5086c5b6e1 100644 --- a/flytekit/clients/auth/exceptions.py +++ b/flytekit/clients/auth/exceptions.py @@ -12,3 +12,11 @@ class AuthenticationError(RuntimeError): """ pass + + +class AuthenticationPending(RuntimeError): + """ + This is raised if the token endpoint returns authentication pending + """ + + pass diff --git a/flytekit/clients/auth/keyring.py b/flytekit/clients/auth/keyring.py index c2b19c46b6..79f5e86c68 100644 --- a/flytekit/clients/auth/keyring.py +++ b/flytekit/clients/auth/keyring.py @@ -15,6 +15,7 @@ class Credentials(object): access_token: str refresh_token: str = "na" for_endpoint: str = "flyte-default" + expires_in: typing.Optional[int] = None class KeyringStore: @@ -39,7 +40,7 @@ def store(credentials: Credentials) -> Credentials: credentials.access_token, ) except NoKeyringError as e: - logging.warning(f"KeyRing not available, tokens will not be cached. 
Error: {e}") + logging.debug(f"KeyRing not available, tokens will not be cached. Error: {e}") return credentials @staticmethod @@ -48,7 +49,7 @@ def retrieve(for_endpoint: str) -> typing.Optional[Credentials]: refresh_token = _keyring.get_password(for_endpoint, KeyringStore._refresh_token_key) access_token = _keyring.get_password(for_endpoint, KeyringStore._access_token_key) except NoKeyringError as e: - logging.warning(f"KeyRing not available, tokens will not be cached. Error: {e}") + logging.debug(f"KeyRing not available, tokens will not be cached. Error: {e}") return None if not access_token: @@ -61,4 +62,4 @@ def delete(for_endpoint: str): _keyring.delete_password(for_endpoint, KeyringStore._access_token_key) _keyring.delete_password(for_endpoint, KeyringStore._refresh_token_key) except NoKeyringError as e: - logging.warning(f"KeyRing not available, tokens will not be cached. Error: {e}") + logging.debug(f"KeyRing not available, tokens will not be cached. Error: {e}") diff --git a/flytekit/clients/auth/token_client.py b/flytekit/clients/auth/token_client.py new file mode 100644 index 0000000000..7cbb42a13e --- /dev/null +++ b/flytekit/clients/auth/token_client.py @@ -0,0 +1,154 @@ +import base64 +import enum +import logging +import time +import typing +import urllib.parse +from dataclasses import dataclass +from datetime import datetime, timedelta + +import requests + +from flytekit.clients.auth.exceptions import AuthenticationError, AuthenticationPending + +utf_8 = "utf-8" + +# Errors that Token endpoint will return +error_slow_down = "slow_down" +error_auth_pending = "authorization_pending" + + +# Grant Types +class GrantType(str, enum.Enum): + CLIENT_CREDS = "client_credentials" + DEVICE_CODE = "urn:ietf:params:oauth:grant-type:device_code" + + +@dataclass +class DeviceCodeResponse: + """ + Response from device auth flow endpoint + {'device_code': 'code', + 'user_code': 'BNDJJFXL', + 'verification_uri': 'url', + 'verification_uri_complete': 'url', + 'expires_in': 600, + 'interval': 5} + """ + + device_code: str + user_code: str + verification_uri: str + verification_uri_complete: str + expires_in: int + interval: int + + @classmethod + def from_json_response(cls, j: typing.Dict) -> "DeviceCodeResponse": + return cls( + device_code=j["device_code"], + user_code=j["user_code"], + verification_uri=j["verification_uri"], + verification_uri_complete=j["verification_uri_complete"], + expires_in=j["expires_in"], + interval=j["interval"], + ) + + +def get_basic_authorization_header(client_id: str, client_secret: str) -> str: + """ + This function transforms the client id and the client secret into a header that conforms with http basic auth. + It joins the id and the secret with a : then base64 encodes it, then adds the appropriate text. Secrets are + first URL encoded to escape illegal characters. 
+ + :param client_id: str + :param client_secret: str + :rtype: str + """ + encoded = urllib.parse.quote_plus(client_secret) + concatenated = "{}:{}".format(client_id, encoded) + return "Basic {}".format(base64.b64encode(concatenated.encode(utf_8)).decode(utf_8)) + + +def get_token( + token_endpoint: str, + scopes: typing.Optional[typing.List[str]] = None, + authorization_header: typing.Optional[str] = None, + client_id: typing.Optional[str] = None, + device_code: typing.Optional[str] = None, + grant_type: GrantType = GrantType.CLIENT_CREDS, +) -> typing.Tuple[str, int]: + """ + :rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration + in seconds + """ + headers = { + "Cache-Control": "no-cache", + "Accept": "application/json", + "Content-Type": "application/x-www-form-urlencoded", + } + if authorization_header: + headers["Authorization"] = authorization_header + body = { + "grant_type": grant_type.value, + } + if client_id: + body["client_id"] = client_id + if device_code: + body["device_code"] = device_code + if scopes is not None: + body["scope"] = ",".join(scopes) + + response = requests.post(token_endpoint, data=body, headers=headers) + if not response.ok: + j = response.json() + if "error" in j: + err = j["error"] + if err == error_auth_pending or err == error_slow_down: + raise AuthenticationPending(f"Token not yet available, try again in some time {err}") + logging.error("Status Code ({}) received from IDP: {}".format(response.status_code, response.text)) + raise AuthenticationError("Status Code ({}) received from IDP: {}".format(response.status_code, response.text)) + + j = response.json() + return j["access_token"], j["expires_in"] + + +def get_device_code( + device_auth_endpoint: str, + client_id: str, + audience: typing.Optional[str] = None, + scope: typing.Optional[typing.List[str]] = None, +) -> DeviceCodeResponse: + """ + Retrieves the device authentication code that can be used to authenticate the request using a browser on a + separate device + """ + payload = {"client_id": client_id, "scope": scope, "audience": audience} + resp = requests.post(device_auth_endpoint, payload) + if not resp.ok: + raise AuthenticationError(f"Unable to retrieve Device Authentication Code for {payload}, Reason {resp.reason}") + return DeviceCodeResponse.from_json_response(resp.json()) + + +def poll_token_endpoint(resp: DeviceCodeResponse, token_endpoint: str, client_id: str) -> typing.Tuple[str, int]: + tick = datetime.now() + interval = timedelta(seconds=resp.interval) + end_time = tick + timedelta(seconds=resp.expires_in) + while tick < end_time: + try: + access_token, expires_in = get_token( + token_endpoint, + grant_type=GrantType.DEVICE_CODE, + client_id=client_id, + device_code=resp.device_code, + ) + print("Authentication successful!") + return access_token, expires_in + except AuthenticationPending: + ... 
+ except Exception: + raise + print("Authentication Pending...") + time.sleep(interval.total_seconds()) + tick = tick + interval + raise AuthenticationError("Authentication failed!") diff --git a/flytekit/clients/auth_helper.py b/flytekit/clients/auth_helper.py index 41fc5c025f..93bd883324 100644 --- a/flytekit/clients/auth_helper.py +++ b/flytekit/clients/auth_helper.py @@ -12,6 +12,7 @@ ClientConfigStore, ClientCredentialsAuthenticator, CommandAuthenticator, + DeviceCodeAuthenticator, PKCEAuthenticator, ) from flytekit.clients.grpc_utils.auth_interceptor import AuthUnaryInterceptor @@ -41,6 +42,7 @@ def get_client_config(self) -> ClientConfig: client_id=public_client_config.client_id, scopes=public_client_config.scopes, header_key=public_client_config.authorization_metadata_key or None, + device_authorization_endpoint=oauth2_metadata.device_authorization_endpoint, ) @@ -69,6 +71,7 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth client_id=cfg.client_id, client_secret=cfg.client_credentials_secret, cfg_store=cfg_store, + scopes=cfg.scopes, ) elif cfg_auth == AuthType.EXTERNAL_PROCESS or cfg_auth == AuthType.EXTERNALCOMMAND: client_cfg = None @@ -78,6 +81,8 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth command=cfg.command, header_key=client_cfg.header_key if client_cfg else None, ) + elif cfg_auth == AuthType.DEVICEFLOW: + return DeviceCodeAuthenticator(endpoint=cfg.endpoint, cfg_store=cfg_store, audience=cfg.audience) else: raise ValueError( f"Invalid auth mode [{cfg_auth}] specified." f"Please update the creds config to use a valid value" diff --git a/flytekit/clients/friendly.py b/flytekit/clients/friendly.py index d542af5f7e..2b15dfbd50 100644 --- a/flytekit/clients/friendly.py +++ b/flytekit/clients/friendly.py @@ -1007,7 +1007,7 @@ def get_upload_signed_url( def get_download_signed_url( self, native_url: str, expires_in: datetime.timedelta = None - ) -> _data_proxy_pb2.CreateUploadLocationResponse: + ) -> _data_proxy_pb2.CreateDownloadLocationRequest: expires_in_pb = None if expires_in: expires_in_pb = Duration() diff --git a/flytekit/clis/flyte_cli/main.py b/flytekit/clis/flyte_cli/main.py index 21aec1c4ad..47d5c8c641 100644 --- a/flytekit/clis/flyte_cli/main.py +++ b/flytekit/clis/flyte_cli/main.py @@ -1167,7 +1167,6 @@ def terminate_execution(host, insecure, cause, urn=None): raise _click.UsageError('Missing option "-u" / "--urn" or missing pipe inputs.') except KeyboardInterrupt: _sys.stdout.flush() - pass else: _terminate_one_execution(client, urn, cause) diff --git a/flytekit/clis/sdk_in_container/backfill.py b/flytekit/clis/sdk_in_container/backfill.py index 234b03499f..49c2667d5b 100644 --- a/flytekit/clis/sdk_in_container/backfill.py +++ b/flytekit/clis/sdk_in_container/backfill.py @@ -1,7 +1,7 @@ import typing from datetime import datetime, timedelta -import click +import rich_click as click from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context from flytekit.clis.sdk_in_container.run import DateTimeType, DurationParamType @@ -10,8 +10,8 @@ The backfill command generates and registers a new workflow based on the input launchplan to run an automated backfill. The workflow can be managed using the Flyte UI and can be canceled, relaunched, and recovered. -- ``launchplan`` refers to the name of the Launchplan -- ``launchplan_version`` is optional and should be a valid version for a Launchplan version. 
+ - ``launchplan`` refers to the name of the Launchplan + - ``launchplan_version`` is optional and should be a valid version for a Launchplan version. """ @@ -168,11 +168,12 @@ def backfill( execute=execute, parallel=parallel, ) - if entity: - console_url = remote.generate_console_url(entity) - if execute: - click.secho(f"\n Execution launched {console_url} to see execution in the console.", fg="green") - return - click.secho(f"\n Workflow registered at {console_url}", fg="green") + if dry_run: + return + console_url = remote.generate_console_url(entity) + if execute: + click.secho(f"\n Execution launched {console_url} to see execution in the console.", fg="green") + return + click.secho(f"\n Workflow registered at {console_url}", fg="green") except StopIteration as e: click.secho(f"{e.value}", fg="red") diff --git a/flytekit/clis/sdk_in_container/build.py b/flytekit/clis/sdk_in_container/build.py new file mode 100644 index 0000000000..3e18535268 --- /dev/null +++ b/flytekit/clis/sdk_in_container/build.py @@ -0,0 +1,127 @@ +import os +import pathlib +import typing + +import rich_click as click +from typing_extensions import OrderedDict + +from flytekit.clis.sdk_in_container.constants import CTX_MODULE, CTX_PROJECT_ROOT +from flytekit.clis.sdk_in_container.run import RUN_LEVEL_PARAMS_KEY, get_entities_in_file, load_naive_entity +from flytekit.configuration import ImageConfig, SerializationSettings +from flytekit.core.base_task import PythonTask +from flytekit.core.workflow import PythonFunctionWorkflow +from flytekit.tools.script_mode import _find_project_root +from flytekit.tools.translator import get_serializable + + +def get_workflow_command_base_params() -> typing.List[click.Option]: + """ + Return the set of base parameters added to every pyflyte build workflow subcommand. + """ + return [ + click.Option( + param_decls=["--fast"], + required=False, + is_flag=True, + default=False, + help="Use fast serialization. The image won't contain the source code. The value is false by default.", + ), + ] + + +def build_command(ctx: click.Context, entity: typing.Union[PythonFunctionWorkflow, PythonTask]): + """ + Returns a function that is used to implement WorkflowCommand and build an image for flyte workflows. + """ + + def _build(*args, **kwargs): + m = OrderedDict() + options = None + run_level_params = ctx.obj[RUN_LEVEL_PARAMS_KEY] + + project, domain = run_level_params.get("project"), run_level_params.get("domain") + serialization_settings = SerializationSettings( + project=project, + domain=domain, + image_config=ImageConfig.auto_default_image(), + ) + if not run_level_params.get("fast"): + serialization_settings.source_root = ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_PROJECT_ROOT) + + _ = get_serializable(m, settings=serialization_settings, entity=entity, options=options) + + return _build + + +class WorkflowCommand(click.MultiCommand): + """ + click multicommand at the python file layer, subcommands should be all the workflows in the file. + """ + + def __init__(self, filename: str, *args, **kwargs): + super().__init__(*args, **kwargs) + self._filename = pathlib.Path(filename).resolve() + + def list_commands(self, ctx): + entities = get_entities_in_file(self._filename.__str__()) + return entities.all() + + def get_command(self, ctx, exe_entity): + """ + This command uses the filename with which this command was created, and the string name of the entity passed + after the Python filename on the command line, to load the Python object, and then return the Command that + click should run. 
+ :param ctx: The click Context object. + :param exe_entity: string of the flyte entity provided by the user. Should be the name of a workflow, or task + function. + :return: + """ + rel_path = os.path.relpath(self._filename) + if rel_path.startswith(".."): + raise ValueError( + f"You must call pyflyte from the same or parent dir, {self._filename} not under {os.getcwd()}" + ) + + project_root = _find_project_root(self._filename) + rel_path = self._filename.relative_to(project_root) + module = os.path.splitext(rel_path)[0].replace(os.path.sep, ".") + + ctx.obj[RUN_LEVEL_PARAMS_KEY][CTX_PROJECT_ROOT] = project_root + ctx.obj[RUN_LEVEL_PARAMS_KEY][CTX_MODULE] = module + + entity = load_naive_entity(module, exe_entity, project_root) + + cmd = click.Command( + name=exe_entity, + callback=build_command(ctx, entity), + help=f"Build an image for {module}.{exe_entity}.", + ) + return cmd + + +class BuildCommand(click.MultiCommand): + """ + A click command group for building an image for flyte workflows & tasks in a file. + """ + + def __init__(self, *args, **kwargs): + params = get_workflow_command_base_params() + super().__init__(*args, params=params, **kwargs) + + def list_commands(self, ctx): + return [str(p) for p in pathlib.Path(".").glob("*.py") if str(p) != "__init__.py"] + + def get_command(self, ctx, filename): + if ctx.obj: + ctx.obj[RUN_LEVEL_PARAMS_KEY] = ctx.params + return WorkflowCommand(filename, name=filename, help="Build an image for [workflow|task]") + + +_build_help = """ +This command can build an image for a workflow or a task from the command line, for fully self-contained scripts. +""" + +build = BuildCommand( + name="build", + help=_build_help, +) diff --git a/flytekit/clis/sdk_in_container/constants.py b/flytekit/clis/sdk_in_container/constants.py index d228babf43..67391abb4d 100644 --- a/flytekit/clis/sdk_in_container/constants.py +++ b/flytekit/clis/sdk_in_container/constants.py @@ -1,4 +1,4 @@ -import click as _click +import rich_click as _click CTX_PROJECT = "project" CTX_DOMAIN = "domain" @@ -10,6 +10,7 @@ CTX_PROJECT_ROOT = "project_root" CTX_MODULE = "module" CTX_VERBOSE = "verbose" +CTX_COPY_ALL = "copy_all" project_option = _click.option( diff --git a/flytekit/clis/sdk_in_container/helpers.py b/flytekit/clis/sdk_in_container/helpers.py index 6ac451be92..4c66a2046c 100644 --- a/flytekit/clis/sdk_in_container/helpers.py +++ b/flytekit/clis/sdk_in_container/helpers.py @@ -1,7 +1,7 @@ from dataclasses import replace from typing import Optional -import click +import rich_click as click from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE from flytekit.configuration import Config, ImageConfig, get_config_file diff --git a/flytekit/clis/sdk_in_container/init.py b/flytekit/clis/sdk_in_container/init.py index 1ec2f57c32..627b393578 100644 --- a/flytekit/clis/sdk_in_container/init.py +++ b/flytekit/clis/sdk_in_container/init.py @@ -1,4 +1,4 @@ -import click +import rich_click as click from cookiecutter.main import cookiecutter diff --git a/flytekit/clis/sdk_in_container/launchplan.py b/flytekit/clis/sdk_in_container/launchplan.py new file mode 100644 index 0000000000..2d33e2e3d7 --- /dev/null +++ b/flytekit/clis/sdk_in_container/launchplan.py @@ -0,0 +1,74 @@ +import rich_click as click + +from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context +from flytekit.models.launch_plan import LaunchPlanState + +_launchplan_help = """ +The launchplan command activates or deactivates a specified or the latest version of the launchplan. 
+ +If ``--activate`` is chosen then the previous version of the launchplan will be deactivated. + +- ``launchplan`` refers to the name of the Launchplan +- ``launchplan_version`` is optional and should be a valid version for a Launchplan version. If not specified the latest will be used. +""" + + +@click.command("launchplan", help=_launchplan_help) +@click.option( + "-p", + "--project", + required=False, + type=str, + default="flytesnacks", + help="Fetch launchplan from this project", +) +@click.option( + "-d", + "--domain", + required=False, + type=str, + default="development", + help="Fetch launchplan from this domain", +) +@click.option( + "--activate/--deactivate", + required=True, + type=bool, + is_flag=True, + help="Activate or Deactivate the launchplan", +) +@click.argument( + "launchplan", + required=True, + type=str, +) +@click.argument( + "launchplan-version", + required=False, + type=str, + default=None, +) +@click.pass_context +def launchplan( + ctx: click.Context, + project: str, + domain: str, + activate: bool, + launchplan: str, + launchplan_version: str, +): + remote = get_and_save_remote_with_click_context(ctx, project, domain) + try: + launchplan = remote.fetch_launch_plan( + project=project, + domain=domain, + name=launchplan, + version=launchplan_version, + ) + state = LaunchPlanState.ACTIVE if activate else LaunchPlanState.INACTIVE + remote.client.update_launch_plan(id=launchplan.id, state=state) + click.secho( + f"\n Launchplan was set to {LaunchPlanState.enum_to_string(state)}: {launchplan.name}:{launchplan.id.version}", + fg="green", + ) + except StopIteration as e: + click.secho(f"{e.value}", fg="red") diff --git a/flytekit/clis/sdk_in_container/local_cache.py b/flytekit/clis/sdk_in_container/local_cache.py index 0dbdc9c621..b0923b842a 100644 --- a/flytekit/clis/sdk_in_container/local_cache.py +++ b/flytekit/clis/sdk_in_container/local_cache.py @@ -1,4 +1,4 @@ -import click +import rich_click as click from flytekit.core.local_cache import LocalTaskCache diff --git a/flytekit/clis/sdk_in_container/package.py b/flytekit/clis/sdk_in_container/package.py index e457b3d649..f1c6f526e2 100644 --- a/flytekit/clis/sdk_in_container/package.py +++ b/flytekit/clis/sdk_in_container/package.py @@ -1,6 +1,6 @@ import os -import click +import rich_click as click from flytekit.clis.helpers import display_help_with_error from flytekit.clis.sdk_in_container import constants diff --git a/flytekit/clis/sdk_in_container/pyflyte.py b/flytekit/clis/sdk_in_container/pyflyte.py index 5e1136d14c..d9a8fb0c2a 100644 --- a/flytekit/clis/sdk_in_container/pyflyte.py +++ b/flytekit/clis/sdk_in_container/pyflyte.py @@ -1,18 +1,21 @@ import typing -import click import grpc +import rich_click as click from google.protobuf.json_format import MessageToJson from flytekit import configuration from flytekit.clis.sdk_in_container.backfill import backfill +from flytekit.clis.sdk_in_container.build import build from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE, CTX_PACKAGES, CTX_VERBOSE from flytekit.clis.sdk_in_container.init import init +from flytekit.clis.sdk_in_container.launchplan import launchplan from flytekit.clis.sdk_in_container.local_cache import local_cache from flytekit.clis.sdk_in_container.package import package from flytekit.clis.sdk_in_container.register import register from flytekit.clis.sdk_in_container.run import run from flytekit.clis.sdk_in_container.serialize import serialize +from flytekit.clis.sdk_in_container.serve import serve from flytekit.configuration.internal 
import LocalSDK from flytekit.exceptions.base import FlyteException from flytekit.exceptions.user import FlyteInvalidInputException @@ -36,8 +39,8 @@ def validate_package(ctx, param, values): def pretty_print_grpc_error(e: grpc.RpcError): if isinstance(e, grpc._channel._InactiveRpcError): # noqa - click.secho(f"RPC Failed, with Status: {e.code()}", fg="red") - click.secho(f"\tdetails: {e.details()}", fg="magenta") + click.secho(f"RPC Failed, with Status: {e.code()}", fg="red", bold=True) + click.secho(f"\tdetails: {e.details()}", fg="magenta", bold=True) click.secho(f"\tDebug string {e.debug_error_string()}", dim=True) return @@ -51,12 +54,11 @@ def pretty_print_exception(e: Exception): raise e if isinstance(e, FlyteException): + click.secho(f"Failed with Exception Code: {e._ERROR_CODE}", fg="red") # noqa if isinstance(e, FlyteInvalidInputException): click.secho("Request rejected by the API, due to Invalid input.", fg="red") - click.secho(f"\tReason: {str(e)}", dim=True) click.secho(f"\tInput Request: {MessageToJson(e.request)}", dim=True) - return - click.secho(f"Failed with Exception: Reason: {e._ERROR_CODE}", fg="red") # noqa + cause = e.__cause__ if cause: if isinstance(cause, grpc.RpcError): @@ -72,7 +74,7 @@ def pretty_print_exception(e: Exception): click.secho(f"Failed with Unknown Exception {type(e)} Reason: {e}", fg="red") # noqa -class ErrorHandlingCommand(click.Group): +class ErrorHandlingCommand(click.RichGroup): def invoke(self, ctx: click.Context) -> typing.Any: try: return super().invoke(ctx) @@ -132,6 +134,9 @@ def main(ctx, pkgs: typing.List[str], config: str, verbose: bool): main.add_command(run) main.add_command(register) main.add_command(backfill) +main.add_command(serve) +main.add_command(build) +main.add_command(launchplan) main.epilog if __name__ == "__main__": diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py index 30c955e351..afc7aeb99e 100644 --- a/flytekit/clis/sdk_in_container/register.py +++ b/flytekit/clis/sdk_in_container/register.py @@ -1,7 +1,7 @@ import os import typing -import click +import rich_click as click from flytekit.clis.helpers import display_help_with_error from flytekit.clis.sdk_in_container import constants diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 136831c0bc..9c7228ec46 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -9,7 +9,8 @@ from dataclasses import dataclass from typing import cast -import click +import rich_click as click +import yaml from dataclasses_json import DataClassJsonMixin from pytimeparse import parse from typing_extensions import get_args @@ -17,6 +18,7 @@ from flytekit import BlobType, Literal, Scalar from flytekit.clis.sdk_in_container.constants import ( CTX_CONFIG_FILE, + CTX_COPY_ALL, CTX_DOMAIN, CTX_MODULE, CTX_PROJECT, @@ -37,7 +39,7 @@ from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase from flytekit.models import literals from flytekit.models.interface import Variable -from flytekit.models.literals import Blob, BlobMetadata, Primitive, Union +from flytekit.models.literals import Blob, BlobMetadata, LiteralCollection, LiteralMap, Primitive, Union from flytekit.models.types import LiteralType, SimpleType from flytekit.remote.executions import FlyteWorkflowExecution from flytekit.tools import module_loader, script_mode @@ -55,10 +57,6 @@ def remove_prefix(text, prefix): return text -class JsonParamType(click.ParamType): - name = "json object" - - 
@dataclass class Directory(object): dir_path: str @@ -81,7 +79,7 @@ def convert( raise ValueError( f"Currently only directories containing one file are supported, found [{len(files)}] files found in {p.resolve()}" ) - return Directory(dir_path=value, local_file=files[0].resolve()) + return Directory(dir_path=str(p), local_file=files[0].resolve()) raise click.BadParameter(f"parameter should be a valid directory path, {value}") @@ -134,6 +132,33 @@ def convert( return datetime.timedelta(seconds=parse(value)) +class JsonParamType(click.ParamType): + name = "json object OR json/yaml file path" + + def convert( + self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] + ) -> typing.Any: + if value is None: + raise click.BadParameter("None value cannot be converted to a Json type.") + if type(value) == dict or type(value) == list: + return value + try: + return json.loads(value) + except Exception: # noqa + try: + # We failed to load the json, so we'll try to load it as a file + if os.path.exists(value): + # if the value is a yaml file, we'll try to load it as yaml + if value.endswith(".yaml") or value.endswith(".yml"): + with open(value, "r") as f: + return yaml.safe_load(f) + with open(value, "r") as f: + return json.load(f) + raise + except json.JSONDecodeError as e: + raise click.BadParameter(f"parameter {param} should be a valid json object, {value}, error: {e}") + + @dataclass class DefaultConverter(object): click_type: click.ParamType @@ -215,16 +240,18 @@ def is_bool(self) -> bool: return self._literal_type.simple == SimpleType.BOOLEAN return False - def get_uri_for_dir(self, value: Directory, remote_filename: typing.Optional[str] = None): + def get_uri_for_dir( + self, ctx: typing.Optional[click.Context], value: Directory, remote_filename: typing.Optional[str] = None + ): uri = value.dir_path if self._remote and value.local: md5, _ = script_mode.hash_file(value.local_file) if not remote_filename: remote_filename = value.local_file.name - df_remote_location = self._create_upload_fn(filename=remote_filename, content_md5=md5) - self._flyte_ctx.file_access.put_data(value.local_file, df_remote_location.signed_url) - uri = df_remote_location.native_url[: -len(remote_filename)] + remote = ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] + _, native_url = remote.upload_file(value.local_file) + uri = native_url[: -len(remote_filename)] return uri @@ -232,7 +259,7 @@ def convert_to_structured_dataset( self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: Directory ) -> Literal: - uri = self.get_uri_for_dir(value, "00000.parquet") + uri = self.get_uri_for_dir(ctx, value, "00000.parquet") lit = Literal( scalar=Scalar( @@ -254,15 +281,13 @@ def convert_to_blob( value: typing.Union[Directory, FileParam], ) -> Literal: if isinstance(value, Directory): - uri = self.get_uri_for_dir(value) + uri = self.get_uri_for_dir(ctx, value) else: uri = value.filepath if self._remote and value.local: fp = pathlib.Path(value.filepath) - md5, _ = script_mode.hash_file(value.filepath) - df_remote_location = self._create_upload_fn(filename=fp.name, content_md5=md5) - self._flyte_ctx.file_access.put_data(fp, df_remote_location.signed_url) - uri = df_remote_location.native_url + remote = ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] + _, uri = remote.upload_file(fp) lit = Literal( scalar=Scalar( @@ -299,6 +324,68 @@ def convert_to_union( logging.debug(f"Failed to convert python type {python_type} to literal type {variant}", e) raise ValueError(f"Failed to convert 
python type {self._python_type} to literal type {lt}") + def convert_to_list( + self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: list + ) -> Literal: + """ + Convert a python list into a Flyte Literal + """ + if not value: + raise click.BadParameter("Expected non-empty list") + if not isinstance(value, list): + raise click.BadParameter(f"Expected json list '[...]', parsed value is {type(value)}") + converter = FlyteLiteralConverter( + ctx, + self._flyte_ctx, + self._literal_type.collection_type, + type(value[0]), + self._create_upload_fn, + ) + lt = Literal(collection=LiteralCollection([])) + for v in value: + click_val = converter._click_type.convert(v, param, ctx) + lt.collection.literals.append(converter.convert_to_literal(ctx, param, click_val)) + return lt + + def convert_to_map( + self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: dict + ) -> Literal: + """ + Convert a python dict into a Flyte Literal. + It is assumed that the click parameter type is a JsonParamType. The map is also assumed to be univariate. + """ + if not value: + raise click.BadParameter("Expected non-empty dict") + if not isinstance(value, dict): + raise click.BadParameter(f"Expected json dict '{{...}}', parsed value is {type(value)}") + converter = FlyteLiteralConverter( + ctx, + self._flyte_ctx, + self._literal_type.map_value_type, + type(value[list(value.keys())[0]]), + self._create_upload_fn, + ) + lt = Literal(map=LiteralMap({})) + for k, v in value.items(): + click_val = converter._click_type.convert(v, param, ctx) + lt.map.literals[k] = converter.convert_to_literal(ctx, param, click_val) + return lt + + def convert_to_struct( + self, + ctx: typing.Optional[click.Context], + param: typing.Optional[click.Parameter], + value: typing.Union[dict, typing.Any], + ) -> Literal: + """ + Convert the loaded json object to a Flyte Literal struct type. 
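+
+        A rough sketch of the expected flow, using a hypothetical ``@dataclass_json`` dataclass ``MyConfig``: a CLI
+        value such as ``--config '{"a": 1}'`` is parsed by ``JsonParamType`` into a ``dict`` and is rebuilt here via
+        ``MyConfig.from_json(json.dumps(value))`` before being handed to ``TypeEngine.to_literal``.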
+ """ + if type(value) != self._python_type: + o = cast(DataClassJsonMixin, self._python_type).from_json(json.dumps(value)) + else: + o = value + return TypeEngine.to_literal(self._flyte_ctx, o, self._python_type, self._literal_type) + def convert_to_literal( self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: typing.Any ) -> Literal: @@ -308,30 +395,18 @@ def convert_to_literal( if self._literal_type.blob: return self.convert_to_blob(ctx, param, value) - if self._literal_type.collection_type or self._literal_type.map_value_type: - # TODO Does not support nested flytefile, flyteschema types - v = json.loads(value) if isinstance(value, str) else value - if self._literal_type.collection_type and not isinstance(v, list): - raise click.BadParameter(f"Expected json list '[...]', parsed value is {type(v)}") - if self._literal_type.map_value_type and not isinstance(v, dict): - raise click.BadParameter("Expected json map '{}', parsed value is {%s}" % type(v)) - return TypeEngine.to_literal(self._flyte_ctx, v, self._python_type, self._literal_type) + if self._literal_type.collection_type: + return self.convert_to_list(ctx, param, value) + + if self._literal_type.map_value_type: + return self.convert_to_map(ctx, param, value) if self._literal_type.union_type: return self.convert_to_union(ctx, param, value) if self._literal_type.simple or self._literal_type.enum_type: if self._literal_type.simple and self._literal_type.simple == SimpleType.STRUCT: - if self._python_type == dict: - if type(value) != str: - # The type of default value is dict, so we have to convert it to json string - value = json.dumps(value) - o = json.loads(value) - elif type(value) != self._python_type: - o = cast(DataClassJsonMixin, self._python_type).from_json(value) - else: - o = value - return TypeEngine.to_literal(self._flyte_ctx, o, self._python_type, self._literal_type) + return self.convert_to_struct(ctx, param, value) return Literal(scalar=self._converter.convert(value, self._python_type)) if self._literal_type.schema: @@ -342,10 +417,15 @@ def convert_to_literal( ) def convert(self, ctx, param, value) -> typing.Union[Literal, typing.Any]: - lit = self.convert_to_literal(ctx, param, value) - if not self._remote: - return TypeEngine.to_python_value(self._flyte_ctx, lit, self._python_type) - return lit + try: + lit = self.convert_to_literal(ctx, param, value) + if not self._remote: + return TypeEngine.to_python_value(self._flyte_ctx, lit, self._python_type) + return lit + except click.BadParameter: + raise + except Exception as e: + raise click.BadParameter(f"Failed to convert param {param}, {value} to {self._python_type}") from e def to_click_option( @@ -368,6 +448,13 @@ def to_click_option( if literal_converter.is_bool() and not default_val: default_val = False + if literal_var.type.simple == SimpleType.STRUCT: + if default_val: + if type(default_val) == dict or type(default_val) == list: + default_val = json.dumps(default_val) + else: + default_val = cast(DataClassJsonMixin, default_val).to_json() + return click.Option( param_decls=[f"--{input_name}"], type=literal_converter.click_type, @@ -426,6 +513,13 @@ def get_workflow_command_base_params() -> typing.List[click.Option]: default="/root", help="Directory inside the image where the tar file containing the code will be copied to", ), + click.Option( + param_decls=["--copy-all", "copy_all"], + required=False, + is_flag=True, + default=False, + help="Copy all files in the source root directory to the destination directory", + ), 
click.Option( param_decls=["-i", "--image", "image_config"], required=False, @@ -557,6 +651,7 @@ def _run(*args, **kwargs): destination_dir=run_level_params.get("destination_dir"), source_path=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_PROJECT_ROOT), module_name=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_MODULE), + copy_all=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_COPY_ALL), ) options = None @@ -586,7 +681,7 @@ def _run(*args, **kwargs): return _run -class WorkflowCommand(click.MultiCommand): +class WorkflowCommand(click.RichGroup): """ click multicommand at the python file layer, subcommands should be all the workflows in the file. """ @@ -654,7 +749,7 @@ def get_command(self, ctx, exe_entity): return cmd -class RunCommand(click.MultiCommand): +class RunCommand(click.RichGroup): """ A click command group for registering and executing flyte workflows & tasks in a file. """ diff --git a/flytekit/clis/sdk_in_container/serialize.py b/flytekit/clis/sdk_in_container/serialize.py index eef055ad3c..0c328248e5 100644 --- a/flytekit/clis/sdk_in_container/serialize.py +++ b/flytekit/clis/sdk_in_container/serialize.py @@ -3,7 +3,7 @@ import typing from enum import Enum as _Enum -import click +import rich_click as click from flytekit.clis.sdk_in_container import constants from flytekit.clis.sdk_in_container.constants import CTX_PACKAGES @@ -69,7 +69,7 @@ def serialize_all( serialize_to_folder(pkgs, serialization_settings, local_source_root, folder) -@click.group("serialize") +@click.group("serialize", cls=click.RichGroup) @click.option( "--image", required=False, @@ -124,7 +124,7 @@ def serialize(ctx, image, local_source_root, in_container_config_path, in_contai ctx.obj[CTX_PYTHON_INTERPRETER] = sys.executable -@click.command("workflows") +@click.command("workflows", cls=click.RichCommand) # For now let's just assume that the directory needs to exist. If you're docker run -v'ing, docker will create the # directory for you so it shouldn't be a problem. @click.option("-f", "--folder", type=click.Path(exists=True)) @@ -148,13 +148,13 @@ def workflows(ctx, folder=None): ) -@click.group("fast") +@click.group("fast", cls=click.RichGroup) @click.pass_context def fast(ctx): pass -@click.command("workflows") +@click.command("workflows", cls=click.RichCommand) @click.option( "--deref-symlinks", default=False, diff --git a/flytekit/clis/sdk_in_container/serve.py b/flytekit/clis/sdk_in_container/serve.py new file mode 100644 index 0000000000..71b539d36c --- /dev/null +++ b/flytekit/clis/sdk_in_container/serve.py @@ -0,0 +1,46 @@ +from concurrent import futures + +import click +import grpc +from flyteidl.service.external_plugin_service_pb2_grpc import add_ExternalPluginServiceServicer_to_server + +from flytekit.extend.backend.external_plugin_service import BackendPluginServer + +_serve_help = """Start a grpc server for the external plugin service.""" + + +@click.command("serve", help=_serve_help) +@click.option( + "--port", + default="8000", + is_flag=False, + type=int, + help="Grpc port for the external plugin service", +) +@click.option( + "--worker", + default="10", + is_flag=False, + type=int, + help="Number of workers for the grpc server", +) +@click.option( + "--timeout", + default=None, + is_flag=False, + type=int, + help="It will wait for the specified number of seconds before shutting down grpc server. It should only be used " + "for testing.", +) +@click.pass_context +def serve(_: click.Context, port, worker, timeout): + """ + Start a grpc server for the external plugin service. 
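+
+    A minimal invocation sketch (the port and worker values are only examples)::
+
+        pyflyte serve --port 8000 --worker 10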
+ """ + click.secho("Starting the external plugin service...", fg="blue") + server = grpc.server(futures.ThreadPoolExecutor(max_workers=worker)) + add_ExternalPluginServiceServicer_to_server(BackendPluginServer(), server) + + server.add_insecure_port(f"[::]:{port}") + server.start() + server.wait_for_termination(timeout=timeout) diff --git a/tests/flytekit/unit/extras/persistence/__init__.py b/flytekit/clis/sdk_in_container/utils.py similarity index 100% rename from tests/flytekit/unit/extras/persistence/__init__.py rename to flytekit/clis/sdk_in_container/utils.py diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index afed857a26..cd41708bff 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -143,12 +143,14 @@ from io import BytesIO from typing import Dict, List, Optional +import yaml from dataclasses_json import dataclass_json -from docker_image import reference from flytekit.configuration import internal as _internal from flytekit.configuration.default_images import DefaultImages from flytekit.configuration.file import ConfigEntry, ConfigFile, get_config_file, read_file_if_exists, set_if_exists +from flytekit.image_spec import ImageSpec +from flytekit.image_spec.image_spec import ImageBuildEngine from flytekit.loggers import logger PROJECT_PLACEHOLDER = "{{ registration.project }}" @@ -205,6 +207,15 @@ def look_up_image_info(name: str, tag: str, optional_tag: bool = False) -> Image :param Text tag: e.g. somedocker.com/myimage:someversion123 :rtype: Text """ + from docker_image import reference + + if pathlib.Path(tag).is_file(): + with open(tag, "r") as f: + image_spec_dict = yaml.safe_load(f) + image_spec = ImageSpec(**image_spec_dict) + ImageBuildEngine.build(image_spec) + tag = image_spec.image_name() + ref = reference.Reference.parse(tag) if not optional_tag and ref["tag"] is None: raise AssertionError(f"Incorrectly formatted image {tag}, missing tag value") @@ -344,6 +355,7 @@ class AuthType(enum.Enum): CLIENTSECRET = "ClientSecret" PKCE = "Pkce" EXTERNALCOMMAND = "ExternalCommand" + DEVICEFLOW = "DeviceFlow" @dataclass(init=True, repr=True, eq=True, frozen=True) @@ -352,7 +364,7 @@ class PlatformConfig(object): This object contains the settings to talk to a Flyte backend (the DNS location of your Admin server basically). :param endpoint: DNS for Flyte backend - :param insecure: Whether to use SSL + :param insecure: Whether or not to use SSL :param insecure_skip_verify: Whether to skip SSL certificate verification :param console_endpoint: endpoint for console if different from Flyte backend :param command: This command is executed to return a token using an external process @@ -376,6 +388,7 @@ class PlatformConfig(object): client_credentials_secret: typing.Optional[str] = None scopes: List[str] = field(default_factory=list) auth_mode: AuthType = AuthType.STANDARD + audience: typing.Optional[str] = None rpc_retries: int = 3 @classmethod @@ -697,6 +710,7 @@ class SerializationSettings(object): fast_serialization_settings (Optional[FastSerializationSettings]): If the code is being serialized so that it can be fast registered (and thus omit building a Docker image) this object contains additional parameters for serialization. + source_root (Optional[str]): The root directory of the source code. 
""" image_config: ImageConfig @@ -708,6 +722,7 @@ class SerializationSettings(object): python_interpreter: str = DEFAULT_RUNTIME_PYTHON_INTERPRETER flytekit_virtualenv_root: Optional[str] = None fast_serialization_settings: Optional[FastSerializationSettings] = None + source_root: Optional[str] = None def __post_init__(self): if self.flytekit_virtualenv_root is None: @@ -781,6 +796,7 @@ def new_builder(self) -> Builder: flytekit_virtualenv_root=self.flytekit_virtualenv_root, python_interpreter=self.python_interpreter, fast_serialization_settings=self.fast_serialization_settings, + source_root=self.source_root, ) def should_fast_serialize(self) -> bool: @@ -831,6 +847,7 @@ class Builder(object): flytekit_virtualenv_root: Optional[str] = None python_interpreter: Optional[str] = None fast_serialization_settings: Optional[FastSerializationSettings] = None + source_root: Optional[str] = None def with_fast_serialization_settings(self, fss: fast_serialization_settings) -> SerializationSettings.Builder: self.fast_serialization_settings = fss @@ -847,4 +864,5 @@ def build(self) -> SerializationSettings: flytekit_virtualenv_root=self.flytekit_virtualenv_root, python_interpreter=self.python_interpreter, fast_serialization_settings=self.fast_serialization_settings, + source_root=self.source_root, ) diff --git a/flytekit/configuration/default_images.py b/flytekit/configuration/default_images.py index 8c01041eed..625e69d9ae 100644 --- a/flytekit/configuration/default_images.py +++ b/flytekit/configuration/default_images.py @@ -30,14 +30,19 @@ def default_image(cls) -> str: def find_image_for( cls, python_version: typing.Optional[PythonVersion] = None, flytekit_version: typing.Optional[str] = None ) -> str: + if python_version is None: + python_version = PythonVersion((sys.version_info.major, sys.version_info.minor)) + + return cls._DEFAULT_IMAGE_PREFIXES[python_version] + ( + flytekit_version.replace("v", "") if flytekit_version else cls.get_version_suffix() + ) + + @classmethod + def get_version_suffix(cls) -> str: from flytekit import __version__ if not __version__ or __version__ == "0.0.0+develop": version_suffix = "latest" else: version_suffix = __version__ - if python_version is None: - python_version = PythonVersion((sys.version_info.major, sys.version_info.minor)) - return cls._DEFAULT_IMAGE_PREFIXES[python_version] + ( - flytekit_version.replace("v", "") if flytekit_version else version_suffix - ) + return version_suffix diff --git a/flytekit/core/base_sql_task.py b/flytekit/core/base_sql_task.py index 7fcdc15a50..30b73223a9 100644 --- a/flytekit/core/base_sql_task.py +++ b/flytekit/core/base_sql_task.py @@ -1,5 +1,5 @@ import re -from typing import Any, Dict, Optional, Type, TypeVar +from typing import Any, Dict, Optional, Tuple, Type, TypeVar from flytekit.core.base_task import PythonTask, TaskMetadata from flytekit.core.interface import Interface @@ -22,11 +22,11 @@ def __init__( self, name: str, query_template: str, + task_config: Optional[T] = None, task_type="sql_task", - inputs: Optional[Dict[str, Type]] = None, + inputs: Optional[Dict[str, Tuple[Type, Any]]] = None, metadata: Optional[TaskMetadata] = None, - task_config: Optional[T] = None, - outputs: Dict[str, Type] = None, + outputs: Optional[Dict[str, Type]] = None, **kwargs, ): """ diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 2cf8032a6f..1f9e27f735 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -21,10 +21,16 @@ import datetime from abc import abstractmethod from dataclasses 
import dataclass -from typing import Any, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union +from typing import Any, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union, cast from flytekit.configuration import SerializationSettings -from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager, FlyteEntities +from flytekit.core.context_manager import ( + ExecutionParameters, + ExecutionState, + FlyteContext, + FlyteContextManager, + FlyteEntities, +) from flytekit.core.interface import Interface, transform_interface_to_typed_interface from flytekit.core.local_cache import LocalTaskCache from flytekit.core.promise import ( @@ -37,8 +43,8 @@ translate_inputs_to_literals, ) from flytekit.core.tracker import TrackedInstance -from flytekit.core.type_engine import TypeEngine -from flytekit.deck.deck import Deck +from flytekit.core.type_engine import TypeEngine, TypeTransformerFailedError +from flytekit.core.utils import timeit from flytekit.loggers import logger from flytekit.models import dynamic_job as _dynamic_job from flytekit.models import interface as _interface_models @@ -156,7 +162,7 @@ def __init__( self, task_type: str, name: str, - interface: Optional[_interface_models.TypedInterface] = None, + interface: _interface_models.TypedInterface, metadata: Optional[TaskMetadata] = None, task_type_version=0, security_ctx: Optional[SecurityContext] = None, @@ -174,7 +180,7 @@ def __init__( FlyteEntities.entities.append(self) @property - def interface(self) -> Optional[_interface_models.TypedInterface]: + def interface(self) -> _interface_models.TypedInterface: return self._interface @property @@ -239,12 +245,17 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr # Promises as essentially inputs from previous task executions # native constants are just bound to this specific task (default values for a task input) # Also along with promises and constants, there could be dictionary or list of promises or constants - kwargs = translate_inputs_to_literals( - ctx, - incoming_values=kwargs, - flyte_interface_types=self.interface.inputs, # type: ignore - native_types=self.get_input_types(), - ) + try: + kwargs = translate_inputs_to_literals( + ctx, + incoming_values=kwargs, + flyte_interface_types=self.interface.inputs, + native_types=self.get_input_types(), # type: ignore + ) + except TypeTransformerFailedError as exc: + msg = f"Failed to convert inputs of task '{self.name}':\n {exc}" + logger.error(msg) + raise TypeError(msg) from exc input_literal_map = _literal_models.LiteralMap(literals=kwargs) # if metadata.cache is set, check memoized version @@ -289,8 +300,8 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr vals = [Promise(var, outputs_literals[var]) for var in output_names] return create_task_output(vals, self.python_interface) - def __call__(self, *args, **kwargs): - return flyte_entity_call_handler(self, *args, **kwargs) + def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]: + return flyte_entity_call_handler(self, *args, **kwargs) # type: ignore def compile(self, ctx: FlyteContext, *args, **kwargs): raise Exception("not implemented") @@ -334,8 +345,8 @@ def sandbox_execute( """ Call dispatch_execute, in the context of a local sandbox execution. Not invoked during runtime. 
""" - es = ctx.execution_state - b = es.user_space_params.with_task_sandbox() + es = cast(ExecutionState, ctx.execution_state) + b = cast(ExecutionParameters, es.user_space_params).with_task_sandbox() ctx = ctx.current_context().with_execution_state(es.with_params(user_space_params=b.build())).build() return self.dispatch_execute(ctx, input_literal_map) @@ -384,7 +395,7 @@ def __init__( self, task_type: str, name: str, - task_config: T, + task_config: Optional[T], interface: Optional[Interface] = None, environment: Optional[Dict[str, str]] = None, disable_deck: bool = True, @@ -421,9 +432,13 @@ def __init__( ) else: if self._python_interface.docstring.short_description: - self._docs.short_description = self._python_interface.docstring.short_description + cast( + Documentation, self._docs + ).short_description = self._python_interface.docstring.short_description if self._python_interface.docstring.long_description: - self._docs.long_description = Description(value=self._python_interface.docstring.long_description) + cast(Documentation, self._docs).long_description = Description( + value=self._python_interface.docstring.long_description + ) # TODO lets call this interface and the other as flyte_interface? @property @@ -434,25 +449,25 @@ def python_interface(self) -> Interface: return self._python_interface @property - def task_config(self) -> T: + def task_config(self) -> Optional[T]: """ Returns the user-specified task config which is used for plugin-specific handling of the task. """ return self._task_config - def get_type_for_input_var(self, k: str, v: Any) -> Optional[Type[Any]]: + def get_type_for_input_var(self, k: str, v: Any) -> Type[Any]: """ Returns the python type for an input variable by name. """ return self._python_interface.inputs[k] - def get_type_for_output_var(self, k: str, v: Any) -> Optional[Type[Any]]: + def get_type_for_output_var(self, k: str, v: Any) -> Type[Any]: """ Returns the python type for the specified output variable by name. """ return self._python_interface.outputs[k] - def get_input_types(self) -> Optional[Dict[str, type]]: + def get_input_types(self) -> Dict[str, type]: """ Returns the names and python types as a dictionary for the inputs of this task. 
""" @@ -498,18 +513,28 @@ def dispatch_execute( # Create another execution context with the new user params, but let's keep the same working dir with FlyteContextManager.with_context( - ctx.with_execution_state(ctx.execution_state.with_params(user_space_params=new_user_params)) + ctx.with_execution_state( + cast(ExecutionState, ctx.execution_state).with_params(user_space_params=new_user_params) + ) # type: ignore ) as exec_ctx: # TODO We could support default values here too - but not part of the plan right now # Translate the input literals to Python native - native_inputs = TypeEngine.literal_map_to_kwargs(exec_ctx, input_literal_map, self.python_interface.inputs) + try: + native_inputs = TypeEngine.literal_map_to_kwargs( + exec_ctx, input_literal_map, self.python_interface.inputs + ) + except Exception as exc: + msg = f"Failed to convert inputs of task '{self.name}':\n {exc}" + logger.error(msg) + raise type(exc)(msg) from exc # TODO: Logger should auto inject the current context information to indicate if the task is running within # a workflow or a subworkflow etc logger.info(f"Invoking {self.name} with inputs: {native_inputs}") try: - native_outputs = self.execute(**native_inputs) + with timeit("Execute user level code"): + native_outputs = self.execute(**native_inputs) except Exception as e: logger.exception(f"Exception when executing {e}") raise e @@ -546,22 +571,26 @@ def dispatch_execute( # We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption # built into the IDL that all the values of a literal map are of the same type. - literals = {} - for k, v in native_outputs_as_map.items(): - literal_type = self._outputs_interface[k].type - py_type = self.get_type_for_output_var(k, v) - - if isinstance(v, tuple): - raise TypeError(f"Output({k}) in task{self.name} received a tuple {v}, instead of {py_type}") - try: - literals[k] = TypeEngine.to_literal(exec_ctx, v, py_type, literal_type) - except Exception as e: - logger.error(f"Failed to convert return value for var {k} with error {type(e)}: {e}") - raise TypeError( - f"Failed to convert return value for var {k} for function {self.name} with error {type(e)}: {e}" - ) from e + with timeit("Translate the output to literals"): + literals = {} + for i, (k, v) in enumerate(native_outputs_as_map.items()): + literal_type = self._outputs_interface[k].type + py_type = self.get_type_for_output_var(k, v) + + if isinstance(v, tuple): + raise TypeError(f"Output({k}) in task '{self.name}' received a tuple {v}, instead of {py_type}") + try: + literals[k] = TypeEngine.to_literal(exec_ctx, v, py_type, literal_type) + except Exception as e: + # only show the name of output key if it's user-defined (by default Flyte names these as "o") + key = k if k != f"o{i}" else i + msg = f"Failed to convert outputs of task '{self.name}' at position {key}:\n {e}" + logger.error(msg) + raise TypeError(msg) from e if self._disable_deck is False: + from flytekit.deck.deck import Deck + INPUT = "input" OUTPUT = "output" @@ -579,7 +608,7 @@ def dispatch_execute( # After the execute has been successfully completed return outputs_literal_map - def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: + def pre_execute(self, user_params: Optional[ExecutionParameters]) -> Optional[ExecutionParameters]: # type: ignore """ This is the method that will be invoked directly before executing the task method and before all the inputs are converted. 
One particular case where this is useful is if the context is to be modified for the user process @@ -597,7 +626,7 @@ def execute(self, **kwargs) -> Any: """ pass - def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: + def post_execute(self, user_params: Optional[ExecutionParameters], rval: Any) -> Any: """ Post execute is called after the execution has completed, with the user_params and can be used to clean-up, or alter the outputs to match the intended tasks outputs. If not overridden, then this function is a No-op diff --git a/flytekit/core/checkpointer.py b/flytekit/core/checkpointer.py index c1eb933ec6..4b4cfd16f3 100644 --- a/flytekit/core/checkpointer.py +++ b/flytekit/core/checkpointer.py @@ -126,7 +126,7 @@ def save(self, cp: typing.Union[Path, str, io.BufferedReader]): fa.upload_directory(str(cp), self._checkpoint_dest) else: fname = cp.stem + cp.suffix - rpath = fa._default_remote.construct_path(False, False, self._checkpoint_dest, fname) + rpath = fa._default_remote.sep.join([str(self._checkpoint_dest), fname]) fa.upload(str(cp), rpath) return @@ -138,7 +138,7 @@ def save(self, cp: typing.Union[Path, str, io.BufferedReader]): with dest_cp.open("wb") as f: f.write(cp.read()) - rpath = fa._default_remote.construct_path(False, False, self._checkpoint_dest, self.TMP_DST_PATH) + rpath = fa._default_remote.sep.join([str(self._checkpoint_dest), self.TMP_DST_PATH]) fa.upload(str(dest_cp), rpath) def read(self) -> typing.Optional[bytes]: diff --git a/flytekit/core/class_based_resolver.py b/flytekit/core/class_based_resolver.py index d47820f811..49970d5623 100644 --- a/flytekit/core/class_based_resolver.py +++ b/flytekit/core/class_based_resolver.py @@ -19,7 +19,7 @@ def __init__(self, *args, **kwargs): def name(self) -> str: return "ClassStorageTaskResolver" - def get_all_tasks(self) -> List[PythonAutoContainerTask]: + def get_all_tasks(self) -> List[PythonAutoContainerTask]: # type:ignore return self.mapping def add(self, t: PythonAutoContainerTask): @@ -33,7 +33,7 @@ def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask: idx = int(loader_args[0]) return self.mapping[idx] - def loader_args(self, settings: SerializationSettings, t: PythonAutoContainerTask) -> List[str]: + def loader_args(self, settings: SerializationSettings, t: PythonAutoContainerTask) -> List[str]: # type: ignore """ This is responsible for turning an instance of a task into args that the load_task function can reconstitute. 
""" diff --git a/flytekit/core/condition.py b/flytekit/core/condition.py index b5cae86923..76553db702 100644 --- a/flytekit/core/condition.py +++ b/flytekit/core/condition.py @@ -111,7 +111,7 @@ def end_branch(self) -> Optional[Union[Condition, Promise, Tuple[Promise], VoidP return self._compute_outputs(n) return self._condition - def if_(self, expr: bool) -> Case: + def if_(self, expr: Union[ComparisonExpression, ConjunctionExpression]) -> Case: return self._condition._if(expr) def compute_output_vars(self) -> typing.Optional[typing.List[str]]: @@ -360,7 +360,7 @@ def create_branch_node_promise_var(node_id: str, var: str) -> str: return f"{node_id}.{var}" -def merge_promises(*args: Promise) -> typing.List[Promise]: +def merge_promises(*args: Optional[Promise]) -> typing.List[Promise]: node_vars: typing.Set[typing.Tuple[str, str]] = set() merged_promises: typing.List[Promise] = [] for p in args: @@ -414,7 +414,7 @@ def transform_to_boolexpr( def to_case_block(c: Case) -> Tuple[Union[_core_wf.IfBlock], typing.List[Promise]]: - expr, promises = transform_to_boolexpr(c.expr) + expr, promises = transform_to_boolexpr(cast(Union[ComparisonExpression, ConjunctionExpression], c.expr)) n = c.output_promise.ref.node # type: ignore return _core_wf.IfBlock(condition=expr, then_node=n), promises diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index d470fb54fe..fd604004d6 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -1,16 +1,18 @@ from enum import Enum -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Tuple, Type from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask, TaskMetadata from flytekit.core.interface import Interface +from flytekit.core.pod_template import PodTemplate from flytekit.core.resources import Resources, ResourceSpec -from flytekit.core.utils import _get_container_definition +from flytekit.core.utils import _get_container_definition, _serialize_pod_spec from flytekit.models import task as _task_model from flytekit.models.security import Secret, SecurityContext +_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name" + -# TODO: do we need pod_template here? Seems that it is a raw container not running in pods class ContainerTask(PythonTask): """ This is an intermediate class that represents Flyte Tasks that run a container at execution time. 
This is the vast @@ -36,17 +38,19 @@ def __init__( name: str, image: str, command: List[str], - inputs: Optional[Dict[str, Type]] = None, + inputs: Optional[Dict[str, Tuple[Type, Any]]] = None, metadata: Optional[TaskMetadata] = None, - arguments: List[str] = None, - outputs: Dict[str, Type] = None, + arguments: Optional[List[str]] = None, + outputs: Optional[Dict[str, Type]] = None, requests: Optional[Resources] = None, limits: Optional[Resources] = None, - input_data_dir: str = None, - output_data_dir: str = None, + input_data_dir: Optional[str] = None, + output_data_dir: Optional[str] = None, metadata_format: MetadataFormat = MetadataFormat.JSON, - io_strategy: IOStrategy = None, + io_strategy: Optional[IOStrategy] = None, secret_requests: Optional[List[Secret]] = None, + pod_template: Optional["PodTemplate"] = None, + pod_template_name: Optional[str] = None, **kwargs, ): sec_ctx = None @@ -55,6 +59,11 @@ def __init__( if not isinstance(s, Secret): raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}") sec_ctx = SecurityContext(secrets=secret_requests) + + # pod_template_name overwrites the metadata.pod_template_name + metadata = metadata or TaskMetadata() + metadata.pod_template_name = pod_template_name + super().__init__( task_type="raw-container", name=name, @@ -74,6 +83,7 @@ def __init__( self._resources = ResourceSpec( requests=requests if requests else Resources(), limits=limits if limits else Resources() ) + self.pod_template = pod_template @property def resources(self) -> ResourceSpec: @@ -91,19 +101,29 @@ def execute(self, **kwargs) -> Any: return None def get_container(self, settings: SerializationSettings) -> _task_model.Container: + # if pod_template is specified, return None here but in get_k8s_pod, return pod_template merged with container + if self.pod_template is not None: + return None + + return self._get_container(settings) + + def _get_data_loading_config(self) -> _task_model.DataLoadingConfig: + return _task_model.DataLoadingConfig( + input_path=self._input_data_dir, + output_path=self._output_data_dir, + format=self._md_format.value, + enabled=True, + io_strategy=self._io_strategy.value if self._io_strategy else None, + ) + + def _get_container(self, settings: SerializationSettings) -> _task_model.Container: env = settings.env or {} env = {**env, **self.environment} if self.environment else env return _get_container_definition( image=self._image, command=self._cmd, args=self._args, - data_loading_config=_task_model.DataLoadingConfig( - input_path=self._input_data_dir, - output_path=self._output_data_dir, - format=self._md_format.value, - enabled=True, - io_strategy=self._io_strategy.value if self._io_strategy else None, - ), + data_loading_config=self._get_data_loading_config(), environment=env, storage_request=self.resources.requests.storage, ephemeral_storage_request=self.resources.requests.ephemeral_storage, @@ -116,3 +136,20 @@ def get_container(self, settings: SerializationSettings) -> _task_model.Containe gpu_limit=self.resources.limits.gpu, memory_limit=self.resources.limits.mem, ) + + def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod: + if self.pod_template is None: + return None + return _task_model.K8sPod( + pod_spec=_serialize_pod_spec(self.pod_template, self._get_container(settings)), + metadata=_task_model.K8sObjectMetadata( + labels=self.pod_template.labels, + annotations=self.pod_template.annotations, + ), + data_config=self._get_data_loading_config(), + ) + + def get_config(self, 
settings: SerializationSettings) -> Optional[Dict[str, str]]: + if self.pod_template is None: + return {} + return {_PRIMARY_CONTAINER_NAME_FIELD: self.pod_template.primary_container_name} diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index 7e4600b3bb..e2923bfc7f 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -27,7 +27,6 @@ from enum import Enum from typing import Generator, List, Optional, Union -from flytekit.clients import friendly as friendly_client # noqa from flytekit.configuration import Config, SecretsConfig, SerializationSettings from flytekit.core import mock_stats, utils from flytekit.core.checkpointer import Checkpoint, SyncCheckpoint @@ -39,7 +38,8 @@ from flytekit.models.core import identifier as _identifier if typing.TYPE_CHECKING: - from flytekit.deck.deck import Deck + from flytekit import Deck + from flytekit.clients import friendly as friendly_client # noqa # TODO: resolve circular import from flytekit.core.python_auto_container import TaskResolverMixin @@ -48,7 +48,7 @@ flyte_context_Var: ContextVar[typing.List[FlyteContext]] = ContextVar("", default=[]) if typing.TYPE_CHECKING: - from flytekit.core.base_task import TaskResolverMixin + from flytekit.core.base_task import Task, TaskResolverMixin # Identifier fields use placeholders for registration-time substitution. @@ -84,7 +84,7 @@ class Builder(object): decks: List[Deck] raw_output_prefix: Optional[str] = None execution_id: typing.Optional[_identifier.WorkflowExecutionIdentifier] = None - working_dir: typing.Optional[utils.AutoDeletingTempDir] = None + working_dir: typing.Optional[str] = None checkpoint: typing.Optional[Checkpoint] = None execution_date: typing.Optional[datetime] = None logging: Optional[_logging.Logger] = None @@ -108,7 +108,7 @@ def add_attr(self, key: str, v: typing.Any) -> ExecutionParameters.Builder: def build(self) -> ExecutionParameters: if not isinstance(self.working_dir, utils.AutoDeletingTempDir): - pathlib.Path(self.working_dir).mkdir(parents=True, exist_ok=True) + pathlib.Path(typing.cast(str, self.working_dir)).mkdir(parents=True, exist_ok=True) return ExecutionParameters( execution_date=self.execution_date, stats=self.stats, @@ -123,14 +123,14 @@ def build(self) -> ExecutionParameters: ) @staticmethod - def new_builder(current: ExecutionParameters = None) -> Builder: + def new_builder(current: Optional[ExecutionParameters] = None) -> Builder: return ExecutionParameters.Builder(current=current) def with_task_sandbox(self) -> Builder: prefix = self.working_directory if isinstance(self.working_directory, utils.AutoDeletingTempDir): prefix = self.working_directory.name - task_sandbox_dir = tempfile.mkdtemp(prefix=prefix) + task_sandbox_dir = tempfile.mkdtemp(prefix=prefix) # type: ignore p = pathlib.Path(task_sandbox_dir) cp_dir = p.joinpath("__cp") cp_dir.mkdir(exist_ok=True) @@ -202,12 +202,10 @@ def raw_output_prefix(self) -> str: return self._raw_output_prefix @property - def working_directory(self) -> utils.AutoDeletingTempDir: + def working_directory(self) -> str: """ A handle to a special working directory for easily producing temporary files. 
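+
+        A usage sketch from inside a task (``data.txt`` is an arbitrary example name):
+
+        .. code-block:: python
+
+            wd = flytekit.current_context().working_directory
+            path = os.path.join(wd, "data.txt")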
- TODO: Usage examples - TODO: This does not always return a AutoDeletingTempDir """ return self._working_directory @@ -264,10 +262,24 @@ def decks(self) -> typing.List: @property def default_deck(self) -> Deck: - from flytekit.deck.deck import Deck + from flytekit import Deck return Deck("default") + @property + def timeline_deck(self) -> "TimeLineDeck": # type: ignore + from flytekit.deck.deck import TimeLineDeck + + time_line_deck = None + for deck in self.decks: + if isinstance(deck, TimeLineDeck): + time_line_deck = deck + break + if time_line_deck is None: + time_line_deck = TimeLineDeck("Timeline") + + return time_line_deck + def __getattr__(self, attr_name: str) -> typing.Any: """ This houses certain task specific context. For example in Spark, it houses the SparkSession, etc @@ -287,7 +299,7 @@ def get(self, key: str) -> typing.Any: """ Returns task specific context if present else raise an error. The returned context will match the key """ - return self.__getattr__(attr_name=key) + return self.__getattr__(attr_name=key) # type: ignore class SecretsManager(object): @@ -331,13 +343,13 @@ def __getattr__(self, item: str) -> _GroupSecrets: """ return self._GroupSecrets(item, self) - def get(self, group: str, key: str) -> str: + def get(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str: """ Retrieves a secret using the resolution order -> Env followed by file. If not found raises a ValueError """ - self.check_group_key(group, key) - env_var = self.get_secrets_env_var(group, key) - fpath = self.get_secrets_file(group, key) + self.check_group_key(group) + env_var = self.get_secrets_env_var(group, key, group_version) + fpath = self.get_secrets_file(group, key, group_version) v = os.environ.get(env_var) if v is not None: return v @@ -348,26 +360,27 @@ def get(self, group: str, key: str) -> str: f"Unable to find secret for key {key} in group {group} " f"in Env Var:{env_var} and FilePath: {fpath}" ) - def get_secrets_env_var(self, group: str, key: str) -> str: + def get_secrets_env_var(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str: """ Returns a string that matches the ENV Variable to look for the secrets """ - self.check_group_key(group, key) - return f"{self._env_prefix}{group.upper()}_{key.upper()}" + self.check_group_key(group) + l = [k.upper() for k in filter(None, (group, group_version, key))] + return f"{self._env_prefix}{'_'.join(l)}" - def get_secrets_file(self, group: str, key: str) -> str: + def get_secrets_file(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str: """ Returns a path that matches the file to look for the secrets """ - self.check_group_key(group, key) - return os.path.join(self._base_dir, group.lower(), f"{self._file_prefix}{key.lower()}") + self.check_group_key(group) + l = [k.lower() for k in filter(None, (group, group_version, key))] + l[-1] = f"{self._file_prefix}{l[-1]}" + return os.path.join(self._base_dir, *l) @staticmethod - def check_group_key(group: str, key: str): + def check_group_key(group: str): if group is None or group == "": raise ValueError("secrets group is a mandatory field.") - if key is None or key == "": - raise ValueError("secrets key is a mandatory field.") @dataclass(frozen=True) @@ -467,14 +480,14 @@ class Mode(Enum): LOCAL_TASK_EXECUTION = 3 mode: Optional[ExecutionState.Mode] - working_dir: os.PathLike + working_dir: Union[os.PathLike, str] engine_dir: Optional[Union[os.PathLike, str]] branch_eval_mode: 
Optional[BranchEvalMode] user_space_params: Optional[ExecutionParameters] def __init__( self, - working_dir: os.PathLike, + working_dir: Union[os.PathLike, str], mode: Optional[ExecutionState.Mode] = None, engine_dir: Optional[Union[os.PathLike, str]] = None, branch_eval_mode: Optional[BranchEvalMode] = None, @@ -538,7 +551,7 @@ class FlyteContext(object): file_access: FileAccessProvider level: int = 0 - flyte_client: Optional[friendly_client.SynchronousFlyteClient] = None + flyte_client: Optional["friendly_client.SynchronousFlyteClient"] = None compilation_state: Optional[CompilationState] = None execution_state: Optional[ExecutionState] = None serialization_settings: Optional[SerializationSettings] = None @@ -607,7 +620,7 @@ def new_execution_state(self, working_dir: Optional[os.PathLike] = None) -> Exec return ExecutionState(working_dir=working_dir, user_space_params=self.user_space_params) @staticmethod - def current_context() -> Optional[FlyteContext]: + def current_context() -> FlyteContext: """ This method exists only to maintain backwards compatibility. Please use ``FlyteContextManager.current_context()`` instead. @@ -639,7 +652,7 @@ def get_deck(self) -> typing.Union[str, "IPython.core.display.HTML"]: # type:ig """ from flytekit.deck.deck import _get_deck - return _get_deck(self.execution_state.user_space_params) + return _get_deck(typing.cast(ExecutionState, self.execution_state).user_space_params) @dataclass class Builder(object): @@ -647,7 +660,7 @@ class Builder(object): level: int = 0 compilation_state: Optional[CompilationState] = None execution_state: Optional[ExecutionState] = None - flyte_client: Optional[friendly_client.SynchronousFlyteClient] = None + flyte_client: Optional["friendly_client.SynchronousFlyteClient"] = None serialization_settings: Optional[SerializationSettings] = None in_a_condition: bool = False @@ -726,7 +739,7 @@ class FlyteContextManager(object): FlyteContextManager manages the execution context within Flytekit. It holds global state of either compilation or Execution. It is not thread-safe and can only be run as a single threaded application currently. Context's within Flytekit is useful to manage compilation state and execution state. Refer to ``CompilationState`` - and ``ExecutionState`` for for information. FlyteContextManager provides a singleton stack to manage these contexts. + and ``ExecutionState`` for more information. FlyteContextManager provides a singleton stack to manage these contexts. 
Typical usage is @@ -852,7 +865,7 @@ class FlyteEntities(object): registration process """ - entities = [] + entities: List[Union["LaunchPlan", Task, "WorkflowBase"]] = [] # type: ignore FlyteContextManager.initialize() diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index d407b3528b..2bd86ce896 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -14,303 +14,51 @@ :template: custom.rst :nosignatures: - DataPersistence - DataPersistencePlugins - DiskPersistence FileAccessProvider - UnsupportedPersistenceOp """ - import os import pathlib -import re -import shutil -import sys import tempfile import typing -from abc import abstractmethod -from shutil import copyfile -from typing import Dict, Union +from typing import Any, Dict, Union, cast from uuid import UUID +import fsspec +from fsspec.utils import get_protocol + +from flytekit import configuration from flytekit.configuration import DataConfig -from flytekit.core.utils import PerformanceTimer -from flytekit.exceptions.user import FlyteAssertion, FlyteValueException +from flytekit.core.utils import timeit +from flytekit.exceptions.user import FlyteAssertion from flytekit.interfaces.random import random from flytekit.loggers import logger -CURRENT_PYTHON = sys.version_info[:2] -THREE_SEVEN = (3, 7) - - -class UnsupportedPersistenceOp(Exception): - """ - This exception is raised for all methods when a method is not supported by the data persistence layer - """ - - def __init__(self, message: str): - super(UnsupportedPersistenceOp, self).__init__(message) - - -class DataPersistence(object): - """ - Base abstract type for all DataPersistence operations. This can be extended using the flytekitplugins architecture - """ - - def __init__(self, name: str, default_prefix: typing.Optional[str] = None, **kwargs): - self._name = name - self._default_prefix = default_prefix - - @property - def name(self) -> str: - return self._name - - @property - def default_prefix(self) -> typing.Optional[str]: - return self._default_prefix - - def listdir(self, path: str, recursive: bool = False) -> typing.Generator[str, None, None]: - """ - Returns true if the given path exists, else false - """ - raise UnsupportedPersistenceOp(f"Listing a directory is not supported by the persistence plugin {self.name}") - - @abstractmethod - def exists(self, path: str) -> bool: - """ - Returns true if the given path exists, else false - """ - pass - - @abstractmethod - def get(self, from_path: str, to_path: str, recursive: bool = False): - """ - Retrieves data from from_path and writes to the given to_path (to_path is locally accessible) - """ - pass - - @abstractmethod - def put(self, from_path: str, to_path: str, recursive: bool = False): - """ - Stores data from from_path and writes to the given to_path (from_path is locally accessible) - """ - pass - - @abstractmethod - def construct_path(self, add_protocol: bool, add_prefix: bool, *paths: str) -> str: - """ - if add_protocol is true then is prefixed else - Constructs a path in the format *args - delim is dependent on the storage medium. - each of the args is joined with the delim - """ - pass - - -class DataPersistencePlugins(object): - """ - DataPersistencePlugins is the core plugin registry that stores all DataPersistence plugins. To add a new plugin use - - .. code-block:: python - - DataPersistencePlugins.register_plugin("s3:/", DataPersistence(), force=True|False) - - These plugins should always be registered. 
Follow the plugin registration guidelines to auto-discover your plugins. - """ - - _PLUGINS: Dict[str, typing.Type[DataPersistence]] = {} - - @classmethod - def register_plugin(cls, protocol: str, plugin: typing.Type[DataPersistence], force: bool = False): - """ - Registers the supplied plugin for the specified protocol if one does not already exist. - If one exists and force is default or False, then a TypeError is raised. - If one does not exist then it is registered - If one exists, but force == True then the existing plugin is overridden - """ - if protocol in cls._PLUGINS: - p = cls._PLUGINS[protocol] - if p == plugin: - return - if not force: - raise TypeError( - f"Cannot register plugin {plugin.name} for protocol {protocol} as plugin {p.name} is already" - f" registered for the same protocol. You can force register the new plugin by passing force=True" - ) - - cls._PLUGINS[protocol] = plugin - - @staticmethod - def get_protocol(url: str): - # copy from fsspec https://github.com/fsspec/filesystem_spec/blob/fe09da6942ad043622212927df7442c104fe7932/fsspec/utils.py#L387-L391 - parts = re.split(r"(\:\:|\://)", url, 1) - if len(parts) > 1: - return parts[0] - logger.info("Setting protocol to file") - return "file" - - @classmethod - def find_plugin(cls, path: str) -> typing.Type[DataPersistence]: - """ - Returns a plugin for the given protocol, else raise a TypeError - """ - for k, p in cls._PLUGINS.items(): - if cls.get_protocol(path) == k.replace("://", "") or path.startswith(k): - return p - raise TypeError(f"No plugin found for matching protocol of path {path}") - - @classmethod - def print_all_plugins(cls): - """ - Prints all the plugins and their associated protocoles - """ - for k, p in cls._PLUGINS.items(): - print(f"Plugin {p.name} registered for protocol {k}") - - @classmethod - def is_supported_protocol(cls, protocol: str) -> bool: - """ - Returns true if the given protocol is has a registered plugin for it - """ - return protocol in cls._PLUGINS - - @classmethod - def supported_protocols(cls) -> typing.List[str]: - return [k for k in cls._PLUGINS.keys()] - - -class DiskPersistence(DataPersistence): - """ - The simplest form of persistence that is available with default flytekit - Disk-based persistence. - This will store all data locally and retrieve the data from local. This is helpful for local execution and simulating - runs. 
- """ - - PROTOCOL = "file://" +# Refer to https://github.com/fsspec/s3fs/blob/50bafe4d8766c3b2a4e1fc09669cf02fb2d71454/s3fs/core.py#L198 +# for key and secret +_FSSPEC_S3_KEY_ID = "key" +_FSSPEC_S3_SECRET = "secret" +_ANON = "anon" - def __init__(self, default_prefix: typing.Optional[str] = None, **kwargs): - super().__init__(name="local", default_prefix=default_prefix, **kwargs) - @staticmethod - def _make_local_path(path): - if not os.path.exists(path): - try: - pathlib.Path(path).mkdir(parents=True, exist_ok=True) - except OSError: # Guard against race condition - if not os.path.isdir(path): - raise +def s3_setup_args(s3_cfg: configuration.S3Config, anonymous: bool = False): + kwargs: Dict[str, Any] = { + "cache_regions": True, + } + if s3_cfg.access_key_id: + kwargs[_FSSPEC_S3_KEY_ID] = s3_cfg.access_key_id - @staticmethod - def strip_file_header(path: str) -> str: - """ - Drops file:// if it exists from the file - """ - if path.startswith("file://"): - return path.replace("file://", "", 1) - return path + if s3_cfg.secret_access_key: + kwargs[_FSSPEC_S3_SECRET] = s3_cfg.secret_access_key - def listdir(self, path: str, recursive: bool = False) -> typing.Generator[str, None, None]: - if not recursive: - files = os.listdir(self.strip_file_header(path)) - for f in files: - yield f - return - - for root, subdirs, files in os.walk(self.strip_file_header(path)): - for f in files: - yield os.path.join(root, f) - return - - def exists(self, path: str): - return os.path.exists(self.strip_file_header(path)) - - def copy_tree(self, from_path: str, to_path: str): - # TODO: Remove this code after support for 3.7 is dropped and inline this function back - # 3.7 doesn't have dirs_exist_ok - if CURRENT_PYTHON == THREE_SEVEN: - tp = pathlib.Path(self.strip_file_header(to_path)) - if tp.exists(): - if not tp.is_dir(): - raise FlyteValueException(tp, f"Target {tp} exists but is not a dir") - files = os.listdir(tp) - if len(files) != 0: - logger.debug(f"Deleting existing target dir {tp} with files {files}") - shutil.rmtree(tp) - shutil.copytree(self.strip_file_header(from_path), self.strip_file_header(to_path)) - else: - # copytree will overwrite existing files in the to_path - shutil.copytree(self.strip_file_header(from_path), self.strip_file_header(to_path), dirs_exist_ok=True) + # S3fs takes this as a special arg + if s3_cfg.endpoint is not None: + kwargs["client_kwargs"] = {"endpoint_url": s3_cfg.endpoint} - def get(self, from_path: str, to_path: str, recursive: bool = False): - if from_path != to_path: - if recursive: - self.copy_tree(from_path, to_path) - else: - copyfile(self.strip_file_header(from_path), self.strip_file_header(to_path)) + if anonymous: + kwargs[_ANON] = True - def put(self, from_path: str, to_path: str, recursive: bool = False): - if from_path != to_path: - if recursive: - self.copy_tree(from_path, to_path) - else: - # Emulate s3's flat storage by automatically creating directory path - self._make_local_path(os.path.dirname(self.strip_file_header(to_path))) - # Write the object to a local file in the temp local folder - copyfile(self.strip_file_header(from_path), self.strip_file_header(to_path)) - - def construct_path(self, _: bool, add_prefix: bool, *args: str) -> str: - # Ignore add_protocol for now. 
Only complicates things - if add_prefix: - prefix = self.default_prefix if self.default_prefix else "" - return os.path.join(prefix, *args) - return os.path.join(*args) - - -def stringify_path(filepath): - """ - Copied from `filesystem_spec `__ - - Attempt to convert a path-like object to a string. - Parameters - ---------- - filepath: object to be converted - Returns - ------- - filepath_str: maybe a string version of the object - Notes - ----- - Objects supporting the fspath protocol (Python 3.6+) are coerced - according to its __fspath__ method. - For backwards compatibility with older Python version, pathlib.Path - objects are specially coerced. - Any other object is passed through unchanged, which includes bytes, - strings, buffers, or anything else that's not even path-like. - """ - if isinstance(filepath, str): - return filepath - elif hasattr(filepath, "__fspath__"): - return filepath.__fspath__() - elif isinstance(filepath, pathlib.Path): - return str(filepath) - elif hasattr(filepath, "path"): - return filepath.path - else: - return filepath - - -def split_protocol(urlpath): - """ - Copied from `filesystem_spec `__ - Return protocol, path pair - """ - urlpath = stringify_path(urlpath) - if "://" in urlpath: - protocol, path = urlpath.split("://", 1) - if len(protocol) > 1: - # excludes Windows paths - return protocol, path - return None, urlpath + return kwargs class FileAccessProvider(object): @@ -335,13 +83,18 @@ def __init__( local_sandbox_dir_appended = os.path.join(local_sandbox_dir, "local_flytekit") self._local_sandbox_dir = pathlib.Path(local_sandbox_dir_appended) self._local_sandbox_dir.mkdir(parents=True, exist_ok=True) - self._local = DiskPersistence(default_prefix=local_sandbox_dir_appended) + self._local = fsspec.filesystem(None) - self._default_remote = DataPersistencePlugins.find_plugin(raw_output_prefix)( - default_prefix=raw_output_prefix, data_config=data_config - ) - self._raw_output_prefix = raw_output_prefix self._data_config = data_config if data_config else DataConfig.auto() + self._default_protocol = get_protocol(raw_output_prefix) + self._default_remote = cast(fsspec.AbstractFileSystem, self.get_filesystem(self._default_protocol)) + if os.name == "nt" and raw_output_prefix.startswith("file://"): + raise FlyteAssertion("Cannot use the file:// prefix on Windows.") + self._raw_output_prefix = ( + raw_output_prefix + if raw_output_prefix.endswith(self.sep(self._default_remote)) + else raw_output_prefix + self.sep(self._default_remote) + ) @property def raw_output_prefix(self) -> str: @@ -351,38 +104,120 @@ def raw_output_prefix(self) -> str: def data_config(self) -> DataConfig: return self._data_config + def get_filesystem( + self, protocol: typing.Optional[str] = None, anonymous: bool = False, **kwargs + ) -> typing.Optional[fsspec.AbstractFileSystem]: + if not protocol: + return self._default_remote + if protocol == "file": + kwargs["auto_mkdir"] = True + elif protocol == "s3": + s3kwargs = s3_setup_args(self._data_config.s3, anonymous=anonymous) + s3kwargs.update(kwargs) + return fsspec.filesystem(protocol, **s3kwargs) # type: ignore + elif protocol == "gs": + if anonymous: + kwargs["token"] = _ANON + return fsspec.filesystem(protocol, **kwargs) # type: ignore + + # Preserve old behavior of returning None for file systems that don't have an explicit anonymous option. 
+ if anonymous: + return None + + return fsspec.filesystem(protocol, **kwargs) # type: ignore + + def get_filesystem_for_path(self, path: str = "", anonymous: bool = False, **kwargs) -> fsspec.AbstractFileSystem: + protocol = get_protocol(path) + return self.get_filesystem(protocol, anonymous=anonymous, **kwargs) + @staticmethod def is_remote(path: Union[str, os.PathLike]) -> bool: """ - Deprecated. Lets find a replacement + Deprecated. Let's find a replacement """ - protocol, _ = split_protocol(path) + protocol = get_protocol(path) if protocol is None: return False return protocol != "file" @property def local_sandbox_dir(self) -> os.PathLike: + """ + This is a context based temp dir. + """ return self._local_sandbox_dir @property - def local_access(self) -> DiskPersistence: + def local_access(self) -> fsspec.AbstractFileSystem: return self._local - def construct_random_path( - self, persist: DataPersistence, file_path_or_file_name: typing.Optional[str] = None - ) -> str: + @staticmethod + def strip_file_header(path: str, trim_trailing_sep: bool = False) -> str: """ - Use file_path_or_file_name, when you want a random directory, but want to preserve the leaf file name + Drops file:// if it exists from the file """ - key = UUID(int=random.getrandbits(128)).hex - if file_path_or_file_name: - _, tail = os.path.split(file_path_or_file_name) - if tail: - return persist.construct_path(False, True, key, tail) - else: - logger.warning(f"No filename detected in {file_path_or_file_name}, generating random path") - return persist.construct_path(False, True, key) + if path.startswith("file://"): + return path.replace("file://", "", 1) + return path + + @staticmethod + def recursive_paths(f: str, t: str) -> typing.Tuple[str, str]: + f = os.path.join(f, "") + t = os.path.join(t, "") + return f, t + + def sep(self, file_system: typing.Optional[fsspec.AbstractFileSystem]) -> str: + if file_system is None or file_system.protocol == "file": + return os.sep + return file_system.sep + + def exists(self, path: str) -> bool: + try: + file_system = self.get_filesystem_for_path(path) + return file_system.exists(path) + except OSError as oe: + logger.debug(f"Error in exists checking {path} {oe}") + anon_fs = self.get_filesystem(get_protocol(path), anonymous=True) + if anon_fs is not None: + logger.debug(f"Attempting anonymous exists with {anon_fs}") + return anon_fs.exists(path) + raise oe + + def get(self, from_path: str, to_path: str, recursive: bool = False): + file_system = self.get_filesystem_for_path(from_path) + if recursive: + from_path, to_path = self.recursive_paths(from_path, to_path) + try: + if os.name == "nt" and file_system.protocol == "file" and recursive: + import shutil + + return shutil.copytree( + self.strip_file_header(from_path), self.strip_file_header(to_path), dirs_exist_ok=True + ) + return file_system.get(from_path, to_path, recursive=recursive) + except OSError as oe: + logger.debug(f"Error in getting {from_path} to {to_path} rec {recursive} {oe}") + file_system = self.get_filesystem(get_protocol(from_path), anonymous=True) + if file_system is not None: + logger.debug(f"Attempting anonymous get with {file_system}") + return file_system.get(from_path, to_path, recursive=recursive) + raise oe + + def put(self, from_path: str, to_path: str, recursive: bool = False): + file_system = self.get_filesystem_for_path(to_path) + from_path = self.strip_file_header(from_path) + if recursive: + # Only check this for the local filesystem + if file_system.protocol == "file" and not 
file_system.isdir(from_path): + raise FlyteAssertion(f"Source path {from_path} is not a directory") + if os.name == "nt" and file_system.protocol == "file": + import shutil + + return shutil.copytree( + self.strip_file_header(from_path), self.strip_file_header(to_path), dirs_exist_ok=True + ) + from_path, to_path = self.recursive_paths(from_path, to_path) + return file_system.put(from_path, to_path, recursive=recursive) def get_random_remote_path(self, file_path_or_file_name: typing.Optional[str] = None) -> str: """ @@ -391,7 +226,20 @@ def get_random_remote_path(self, file_path_or_file_name: typing.Optional[str] = Use file_path_or_file_name, when you want a random directory, but want to preserve the leaf file name """ - return self.construct_random_path(self._default_remote, file_path_or_file_name) + default_protocol = self._default_remote.protocol + if type(default_protocol) == list: + default_protocol = default_protocol[0] + key = UUID(int=random.getrandbits(128)).hex + tail = "" + if file_path_or_file_name: + _, tail = os.path.split(file_path_or_file_name) + sep = self.sep(self._default_remote) + tail = sep + tail if tail else tail + if default_protocol == "file": + # Special case the local case, users will not expect to see a file:// prefix + return self.strip_file_header(self.raw_output_prefix) + key + tail + + return self._default_remote.unstrip_protocol(self.raw_output_prefix + key + tail) def get_random_remote_directory(self): return self.get_random_remote_path(None) @@ -400,19 +248,19 @@ def get_random_local_path(self, file_path_or_file_name: typing.Optional[str] = N """ Use file_path_or_file_name, when you want a random directory, but want to preserve the leaf file name """ - return self.construct_random_path(self._local, file_path_or_file_name) + key = UUID(int=random.getrandbits(128)).hex + tail = "" + if file_path_or_file_name: + _, tail = os.path.split(file_path_or_file_name) + if tail: + return os.path.join(self._local_sandbox_dir, key, tail) + return os.path.join(self._local_sandbox_dir, key) def get_random_local_directory(self) -> str: _dir = self.get_random_local_path(None) pathlib.Path(_dir).mkdir(parents=True, exist_ok=True) return _dir - def exists(self, path: str) -> bool: - """ - checks if the given path exists - """ - return DataPersistencePlugins.find_plugin(path)().exists(path) - def download_directory(self, remote_path: str, local_path: str): """ Downloads directory from given remote to local path @@ -439,39 +287,36 @@ def upload_directory(self, local_path: str, remote_path: str): """ return self.put_data(local_path, remote_path, is_multipart=True) - def get_data(self, remote_path: str, local_path: str, is_multipart=False): + @timeit("Download data to local from remote") + def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False): """ - :param Text remote_path: - :param Text local_path: - :param bool is_multipart: + :param remote_path: + :param local_path: + :param is_multipart: """ try: - with PerformanceTimer(f"Copying ({remote_path} -> {local_path})"): - pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True) - data_persistence_plugin = DataPersistencePlugins.find_plugin(remote_path) - data_persistence_plugin(data_config=self.data_config).get( - remote_path, local_path, recursive=is_multipart - ) + pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True) + self.get(remote_path, to_path=local_path, recursive=is_multipart) except Exception as ex: raise FlyteAssertion( f"Failed to get data from {remote_path} to 
{local_path} (recursive={is_multipart}).\n\n" f"Original exception: {str(ex)}" ) - def put_data(self, local_path: Union[str, os.PathLike], remote_path: str, is_multipart=False): + @timeit("Upload data to remote") + def put_data(self, local_path: Union[str, os.PathLike], remote_path: str, is_multipart: bool = False): """ The implication here is that we're always going to put data to the remote location, so we .remote to ensure we don't use the true local proxy if the remote path is a file:// - :param Text local_path: - :param Text remote_path: - :param bool is_multipart: + :param local_path: + :param remote_path: + :param is_multipart: """ try: - with PerformanceTimer(f"Writing ({local_path} -> {remote_path})"): - DataPersistencePlugins.find_plugin(remote_path)(data_config=self.data_config).put( - local_path, remote_path, recursive=is_multipart - ) + local_path = str(local_path) + + self.put(cast(str, local_path), remote_path, recursive=is_multipart) except Exception as ex: raise FlyteAssertion( f"Failed to put data from {local_path} to {remote_path} (recursive={is_multipart}).\n\n" @@ -479,9 +324,6 @@ def put_data(self, local_path: Union[str, os.PathLike], remote_path: str, is_mul ) from ex -DataPersistencePlugins.register_plugin("file://", DiskPersistence) -DataPersistencePlugins.register_plugin("/", DiskPersistence) - flyte_tmp_dir = tempfile.mkdtemp(prefix="flyte-") default_local_file_access_provider = FileAccessProvider( local_sandbox_dir=os.path.join(flyte_tmp_dir, "sandbox"), diff --git a/flytekit/core/docstring.py b/flytekit/core/docstring.py index 420f26f8f5..fa9d9caec2 100644 --- a/flytekit/core/docstring.py +++ b/flytekit/core/docstring.py @@ -4,7 +4,7 @@ class Docstring(object): - def __init__(self, docstring: str = None, callable_: Callable = None): + def __init__(self, docstring: Optional[str] = None, callable_: Optional[Callable] = None): if docstring is not None: self._parsed_docstring = parse(docstring) else: diff --git a/flytekit/core/gate.py b/flytekit/core/gate.py index b6cb7ca2b6..bc3ab1d3fd 100644 --- a/flytekit/core/gate.py +++ b/flytekit/core/gate.py @@ -53,7 +53,7 @@ def __init__( ) else: # We don't know how to find the python interface here, approve() sets it below, See the code. - self._python_interface = None + self._python_interface = None # type: ignore @property def name(self) -> str: @@ -105,7 +105,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr return p # Assume this is an approval operation since that's the only remaining option. - msg = f"Pausing execution for {self.name}, literal value is:\n{self._upstream_item.val}\nContinue?" + msg = f"Pausing execution for {self.name}, literal value is:\n{typing.cast(Promise, self._upstream_item).val}\nContinue?" 
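Stepping back to the FileAccessProvider rework above: a minimal, local-only sketch of how the new fsspec-backed put_data/get_data path might be exercised. The constructor arguments mirror the default_local_file_access_provider shown in the diff; the temporary directories and file names below are hypothetical, not taken from this PR.

```python
import os
import tempfile

from flytekit.core.data_persistence import FileAccessProvider

# Hypothetical scratch locations (not from the PR).
tmp = tempfile.mkdtemp(prefix="flyte-example-")
fa = FileAccessProvider(
    local_sandbox_dir=os.path.join(tmp, "sandbox"),
    raw_output_prefix=os.path.join(tmp, "raw"),
)

src = os.path.join(tmp, "hello.txt")
with open(src, "w") as f:
    f.write("hello")

# put_data/get_data now route through fsspec filesystems chosen by the path's protocol;
# plain local paths resolve to the "file" filesystem with auto_mkdir enabled.
remote = fa.get_random_remote_path("hello.txt")
fa.put_data(src, remote)
fa.get_data(remote, os.path.join(tmp, "hello_copy.txt"))
```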
proceed = click.confirm(msg, default=True) if proceed: # We need to return a promise here, and a promise is what should've been passed in by the call in approve() @@ -167,6 +167,7 @@ def approve(upstream_item: Union[Tuple[Promise], Promise, VoidPromise], name: st raise ValueError("You can't use approval on a task that doesn't return anything.") ctx = FlyteContextManager.current_context() + upstream_item = typing.cast(Promise, upstream_item) if ctx.compilation_state is not None and ctx.compilation_state.mode == 1: if not upstream_item.ref.node.flyte_entity.python_interface: raise ValueError( diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index 954c1ae409..eae7a8e0cf 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -5,7 +5,7 @@ import inspect import typing from collections import OrderedDict -from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union, cast from typing_extensions import Annotated, get_args, get_origin, get_type_hints @@ -21,6 +21,28 @@ T = typing.TypeVar("T") +def repr_kv(k: str, v: Union[Type, Tuple[Type, Any]]) -> str: + if isinstance(v, tuple): + if v[1]: + return f"{k}: {v[0]}={v[1]}" + return f"{k}: {v[0]}" + return f"{k}: {v}" + + +def repr_type_signature(io: Union[Dict[str, Tuple[Type, Any]], Dict[str, Type]]) -> str: + """ + Converts an inputs and outputs to a type signature + """ + s = "(" + i = 0 + for k, v in io.items(): + if i > 0: + s += ", " + s += repr_kv(k, v) + i = i + 1 + return s + ")" + + class Interface(object): """ A Python native interface object, like inspect.signature but simpler. @@ -28,8 +50,8 @@ class Interface(object): def __init__( self, - inputs: typing.Optional[typing.Dict[str, Union[Type, Tuple[Type, Any]], None]] = None, - outputs: typing.Optional[typing.Dict[str, Type]] = None, + inputs: Union[Optional[Dict[str, Type]], Optional[Dict[str, Tuple[Type, Any]]]] = None, + outputs: Union[Optional[Dict[str, Type]], Optional[Dict[str, Optional[Type]]]] = None, output_tuple_name: Optional[str] = None, docstring: Optional[Docstring] = None, ): @@ -43,21 +65,23 @@ def __init__( primarily used when handling one-element NamedTuples. :param docstring: Docstring of the annotated @task or @workflow from which the interface derives from. """ - self._inputs = {} + self._inputs: Union[Dict[str, Tuple[Type, Any]], Dict[str, Type]] = {} # type: ignore if inputs: for k, v in inputs.items(): - if isinstance(v, Tuple) and len(v) > 1: - self._inputs[k] = v + if type(v) is tuple and len(cast(Tuple, v)) > 1: + self._inputs[k] = v # type: ignore else: - self._inputs[k] = (v, None) - self._outputs = outputs if outputs else {} + self._inputs[k] = (v, None) # type: ignore + self._outputs = outputs if outputs else {} # type: ignore self._output_tuple_name = output_tuple_name if outputs: variables = [k for k in outputs.keys()] # TODO: This class is a duplicate of the one in create_task_outputs. Over time, we should move to this one. - class Output(collections.namedtuple(output_tuple_name or "DefaultNamedTupleOutput", variables)): + class Output( # type: ignore + collections.namedtuple(output_tuple_name or "DefaultNamedTupleOutput", variables) # type: ignore + ): # type: ignore """ This class can be used in two different places. For multivariate-return entities this class is used to rewrap the outputs so that our with_overrides function can work. 
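As an aside on the repr_kv / repr_type_signature helpers added above: they render an interface's inputs and outputs as a compact signature string, which the Interface string representation added later in this diff builds on. A small illustrative snippet; the example dictionaries are hypothetical:

```python
from flytekit.core.interface import repr_type_signature

# Inputs are stored as (type, default) tuples; outputs as plain types.
inputs = {"a": (int, None), "b": (str, "hello")}
outputs = {"o0": str}

print(repr_type_signature(inputs))   # (a: <class 'int'>, b: <class 'str'>=hello)
print(repr_type_signature(outputs))  # (o0: <class 'str'>)
```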
@@ -90,7 +114,7 @@ def __rshift__(self, *args, **kwargs): self._docstring = docstring @property - def output_tuple(self) -> Optional[Type[collections.namedtuple]]: + def output_tuple(self) -> Type[collections.namedtuple]: # type: ignore return self._output_tuple_class @property @@ -98,7 +122,7 @@ def output_tuple_name(self) -> Optional[str]: return self._output_tuple_name @property - def inputs(self) -> typing.Dict[str, Type]: + def inputs(self) -> Dict[str, type]: r = {} for k, v in self._inputs.items(): r[k] = v[0] @@ -111,8 +135,8 @@ def output_names(self) -> Optional[List[str]]: return None @property - def inputs_with_defaults(self) -> typing.Dict[str, Tuple[Type, Any]]: - return self._inputs + def inputs_with_defaults(self) -> Dict[str, Tuple[Type, Any]]: + return cast(Dict[str, Tuple[Type, Any]], self._inputs) @property def default_inputs_as_kwargs(self) -> Dict[str, Any]: @@ -120,13 +144,13 @@ def default_inputs_as_kwargs(self) -> Dict[str, Any]: @property def outputs(self) -> typing.Dict[str, type]: - return self._outputs + return self._outputs # type: ignore @property def docstring(self) -> Optional[Docstring]: return self._docstring - def remove_inputs(self, vars: List[str]) -> Interface: + def remove_inputs(self, vars: Optional[List[str]]) -> Interface: """ This method is useful in removing some variables from the Flyte backend inputs specification, as these are implicit local only inputs or will be supplied by the library at runtime. For example, spark-session etc @@ -151,7 +175,7 @@ def with_inputs(self, extra_inputs: Dict[str, Type]) -> Interface: for k, v in extra_inputs.items(): if k in new_inputs: raise ValueError(f"Input {k} cannot be added as it already exists in the interface") - new_inputs[k] = v + cast(Dict[str, Type], new_inputs)[k] = v return Interface(new_inputs, self._outputs, docstring=self.docstring) def with_outputs(self, extra_outputs: Dict[str, Type]) -> Interface: @@ -167,6 +191,12 @@ def with_outputs(self, extra_outputs: Dict[str, Type]) -> Interface: new_outputs[k] = v return Interface(self._inputs, new_outputs) + def __str__(self): + return f"{repr_type_signature(self._inputs)} -> {repr_type_signature(self._outputs)}" + + def __repr__(self): + return str(self) + def transform_inputs_to_parameters( ctx: context_manager.FlyteContext, interface: Interface @@ -220,7 +250,7 @@ def transform_interface_to_typed_interface( return _interface_models.TypedInterface(inputs_map, outputs_map) -def transform_types_to_list_of_type(m: Dict[str, type]) -> Dict[str, type]: +def transform_types_to_list_of_type(m: Dict[str, type], bound_inputs: typing.Set[str]) -> Dict[str, type]: """ Converts a given variables to be collections of their type. This is useful for array jobs / map style code. It will create a collection of types even if any one these types is not a collection type @@ -230,6 +260,10 @@ def transform_types_to_list_of_type(m: Dict[str, type]) -> Dict[str, type]: all_types_are_collection = True for k, v in m.items(): + if k in bound_inputs: + # Skip the inputs that are bound. 
If they are bound, it does not matter if they are collection or + # singletons + continue v_type = type(v) if v_type != typing.List and v_type != list: all_types_are_collection = False @@ -240,33 +274,40 @@ om = {} for k, v in m.items(): - om[k] = typing.List[v] + if k in bound_inputs: + om[k] = v + else: + om[k] = typing.List[v] # type: ignore return om # type: ignore -def transform_interface_to_list_interface(interface: Interface) -> Interface: +def transform_interface_to_list_interface(interface: Interface, bound_inputs: typing.Set[str]) -> Interface: """ Takes a single task interface and interpolates it to an array interface - to allow performing distributed python map like functions + :param interface: Interface to be upgraded to a list interface + :param bound_inputs: fixed inputs that should not be upgraded to a list and will be maintained as scalars. """ - map_inputs = transform_types_to_list_of_type(interface.inputs) - map_outputs = transform_types_to_list_of_type(interface.outputs) + map_inputs = transform_types_to_list_of_type(interface.inputs, bound_inputs) + map_outputs = transform_types_to_list_of_type(interface.outputs, set()) return Interface(inputs=map_inputs, outputs=map_outputs) -def _change_unrecognized_type_to_pickle(t: Type[T]) -> typing.Union[Tuple[Type[T]], Type[T], Annotated]: +def _change_unrecognized_type_to_pickle(t: Type[T]) -> typing.Union[Tuple[Type[T]], Type[T]]: try: if hasattr(t, "__origin__") and hasattr(t, "__args__"): - if get_origin(t) is list: - return typing.List[_change_unrecognized_type_to_pickle(t.__args__[0])] - elif get_origin(t) is dict and t.__args__[0] == str: - return typing.Dict[str, _change_unrecognized_type_to_pickle(t.__args__[1])] - elif get_origin(t) is typing.Union: - return typing.Union[tuple(_change_unrecognized_type_to_pickle(v) for v in get_args(t))] - elif get_origin(t) is Annotated: + ot = get_origin(t) + args = getattr(t, "__args__") + if ot is list: + return typing.List[_change_unrecognized_type_to_pickle(args[0])] # type: ignore + elif ot is dict and args[0] == str: + return typing.Dict[str, _change_unrecognized_type_to_pickle(args[1])] # type: ignore + elif ot is typing.Union: + return typing.Union[tuple(_change_unrecognized_type_to_pickle(v) for v in get_args(t))] # type: ignore + elif ot is Annotated: base_type, *config = get_args(t) - return Annotated[(_change_unrecognized_type_to_pickle(base_type), *config)] + return Annotated[(_change_unrecognized_type_to_pickle(base_type), *config)] # type: ignore TypeEngine.get_transformer(t) except ValueError: logger.warning( @@ -286,7 +327,6 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Doc For now the fancy object, maybe in the future a dumb object.
""" - type_hints = get_type_hints(fn, include_extras=True) signature = inspect.signature(fn) return_annotation = type_hints.get("return", None) @@ -294,12 +334,12 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Doc outputs = extract_return_annotation(return_annotation) for k, v in outputs.items(): outputs[k] = _change_unrecognized_type_to_pickle(v) # type: ignore - inputs = OrderedDict() + inputs: Dict[str, Tuple[Type, Any]] = OrderedDict() for k, v in signature.parameters.items(): # type: ignore annotation = type_hints.get(k, None) default = v.default if v.default is not inspect.Parameter.empty else None # Inputs with default values are currently ignored, we may want to look into that in the future - inputs[k] = (_change_unrecognized_type_to_pickle(annotation), default) + inputs[k] = (_change_unrecognized_type_to_pickle(annotation), default) # type: ignore # This is just for typing.NamedTuples - in those cases, the user can select a name to call the NamedTuple. We # would like to preserve that name in our custom collections.namedtuple. @@ -325,23 +365,24 @@ def transform_variable_map( if variable_map: for k, v in variable_map.items(): res[k] = transform_type(v, descriptions.get(k, k)) - sub_type: Type[T] = v + sub_type: type = v if hasattr(v, "__origin__") and hasattr(v, "__args__"): - if v.__origin__ is list: - sub_type = v.__args__[0] - elif v.__origin__ is dict: - sub_type = v.__args__[1] - if hasattr(sub_type, "__origin__") and sub_type.__origin__ is FlytePickle: - if hasattr(sub_type.python_type(), "__name__"): - res[k].type.metadata = {"python_class_name": sub_type.python_type().__name__} - elif hasattr(sub_type.python_type(), "_name"): + if getattr(v, "__origin__") is list: + sub_type = getattr(v, "__args__")[0] + elif getattr(v, "__origin__") is dict: + sub_type = getattr(v, "__args__")[1] + if hasattr(sub_type, "__origin__") and getattr(sub_type, "__origin__") is FlytePickle: + original_type = cast(FlytePickle, sub_type).python_type() + if hasattr(original_type, "__name__"): + res[k].type.metadata = {"python_class_name": original_type.__name__} + elif hasattr(original_type, "_name"): # If the class doesn't have the __name__ attribute, like typing.Sequence, use _name instead. - res[k].type.metadata = {"python_class_name": sub_type.python_type()._name} + res[k].type.metadata = {"python_class_name": original_type._name} return res -def transform_type(x: type, description: str = None) -> _interface_models.Variable: +def transform_type(x: type, description: Optional[str] = None) -> _interface_models.Variable: return _interface_models.Variable(type=TypeEngine.to_literal_type(x), description=description) @@ -393,13 +434,13 @@ def t(a: int, b: str) -> Dict[str, int]: ... # This statement results in true for typing.Namedtuple, single and void return types, so this # handles Options 1, 2. Even though NamedTuple for us is multi-valued, it's a single value for Python - if isinstance(return_annotation, Type) or isinstance(return_annotation, TypeVar): + if isinstance(return_annotation, Type) or isinstance(return_annotation, TypeVar): # type: ignore # isinstance / issubclass does not work for Namedtuple. 
# Options 1 and 2 bases = return_annotation.__bases__ # type: ignore if len(bases) == 1 and bases[0] == tuple and hasattr(return_annotation, "_fields"): logger.debug(f"Task returns named tuple {return_annotation}") - return dict(get_type_hints(return_annotation, include_extras=True)) + return dict(get_type_hints(cast(Type, return_annotation), include_extras=True)) if hasattr(return_annotation, "__origin__") and return_annotation.__origin__ is tuple: # type: ignore # Handle option 3 @@ -419,7 +460,7 @@ def t(a: int, b: str) -> Dict[str, int]: ... else: # Handle all other single return types logger.debug(f"Task returns unnamed native tuple {return_annotation}") - return {default_output_name(): return_annotation} + return {default_output_name(): cast(Type, return_annotation)} def remap_shared_output_descriptions(output_descriptions: Dict[str, str], outputs: Dict[str, Type]) -> Dict[str, str]: diff --git a/flytekit/core/launch_plan.py b/flytekit/core/launch_plan.py index 550dc1919e..86011f1253 100644 --- a/flytekit/core/launch_plan.py +++ b/flytekit/core/launch_plan.py @@ -74,7 +74,7 @@ def wf(a: int, c: str) -> str: # The reason we cache is simply because users may get the default launch plan twice for a single Workflow. We # don't want to create two defaults, could be confusing. - CACHE = {} + CACHE: typing.Dict[str, LaunchPlan] = {} @staticmethod def get_default_launch_plan(ctx: FlyteContext, workflow: _annotated_workflow.WorkflowBase) -> LaunchPlan: @@ -107,16 +107,16 @@ def create( cls, name: str, workflow: _annotated_workflow.WorkflowBase, - default_inputs: Dict[str, Any] = None, - fixed_inputs: Dict[str, Any] = None, - schedule: _schedule_model.Schedule = None, - notifications: List[_common_models.Notification] = None, - labels: _common_models.Labels = None, - annotations: _common_models.Annotations = None, - raw_output_data_config: _common_models.RawOutputDataConfig = None, - max_parallelism: int = None, - security_context: typing.Optional[security.SecurityContext] = None, - auth_role: _common_models.AuthRole = None, + default_inputs: Optional[Dict[str, Any]] = None, + fixed_inputs: Optional[Dict[str, Any]] = None, + schedule: Optional[_schedule_model.Schedule] = None, + notifications: Optional[List[_common_models.Notification]] = None, + labels: Optional[_common_models.Labels] = None, + annotations: Optional[_common_models.Annotations] = None, + raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None, + max_parallelism: Optional[int] = None, + security_context: Optional[security.SecurityContext] = None, + auth_role: Optional[_common_models.AuthRole] = None, ) -> LaunchPlan: ctx = FlyteContextManager.current_context() default_inputs = default_inputs or {} @@ -130,7 +130,7 @@ def create( temp_inputs = {} for k, v in default_inputs.items(): temp_inputs[k] = (workflow.python_interface.inputs[k], v) - temp_interface = Interface(inputs=temp_inputs, outputs={}) + temp_interface = Interface(inputs=temp_inputs, outputs={}) # type: ignore temp_signature = transform_inputs_to_parameters(ctx, temp_interface) wf_signature_parameters._parameters.update(temp_signature.parameters) @@ -185,16 +185,16 @@ def get_or_create( cls, workflow: _annotated_workflow.WorkflowBase, name: Optional[str] = None, - default_inputs: Dict[str, Any] = None, - fixed_inputs: Dict[str, Any] = None, - schedule: _schedule_model.Schedule = None, - notifications: List[_common_models.Notification] = None, - labels: _common_models.Labels = None, - annotations: _common_models.Annotations = None, - 
raw_output_data_config: _common_models.RawOutputDataConfig = None, - max_parallelism: int = None, - security_context: typing.Optional[security.SecurityContext] = None, - auth_role: _common_models.AuthRole = None, + default_inputs: Optional[Dict[str, Any]] = None, + fixed_inputs: Optional[Dict[str, Any]] = None, + schedule: Optional[_schedule_model.Schedule] = None, + notifications: Optional[List[_common_models.Notification]] = None, + labels: Optional[_common_models.Labels] = None, + annotations: Optional[_common_models.Annotations] = None, + raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None, + max_parallelism: Optional[int] = None, + security_context: Optional[security.SecurityContext] = None, + auth_role: Optional[_common_models.AuthRole] = None, ) -> LaunchPlan: """ This function offers a friendlier interface for creating launch plans. If the name for the launch plan is not @@ -298,13 +298,13 @@ def __init__( workflow: _annotated_workflow.WorkflowBase, parameters: _interface_models.ParameterMap, fixed_inputs: _literal_models.LiteralMap, - schedule: _schedule_model.Schedule = None, - notifications: List[_common_models.Notification] = None, - labels: _common_models.Labels = None, - annotations: _common_models.Annotations = None, - raw_output_data_config: _common_models.RawOutputDataConfig = None, - max_parallelism: typing.Optional[int] = None, - security_context: typing.Optional[security.SecurityContext] = None, + schedule: Optional[_schedule_model.Schedule] = None, + notifications: Optional[List[_common_models.Notification]] = None, + labels: Optional[_common_models.Labels] = None, + annotations: Optional[_common_models.Annotations] = None, + raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None, + max_parallelism: Optional[int] = None, + security_context: Optional[security.SecurityContext] = None, ): self._name = name self._workflow = workflow @@ -313,7 +313,7 @@ def __init__( self._parameters = _interface_models.ParameterMap(parameters=parameters) self._fixed_inputs = fixed_inputs # See create() for additional information - self._saved_inputs = {} + self._saved_inputs: Dict[str, Any] = {} self._schedule = schedule self._notifications = notifications or [] @@ -328,16 +328,15 @@ def __init__( def clone_with( self, name: str, - parameters: _interface_models.ParameterMap = None, - fixed_inputs: _literal_models.LiteralMap = None, - schedule: _schedule_model.Schedule = None, - notifications: List[_common_models.Notification] = None, - labels: _common_models.Labels = None, - annotations: _common_models.Annotations = None, - raw_output_data_config: _common_models.RawOutputDataConfig = None, - auth_role: _common_models.AuthRole = None, - max_parallelism: int = None, - security_context: typing.Optional[security.SecurityContext] = None, + parameters: Optional[_interface_models.ParameterMap] = None, + fixed_inputs: Optional[_literal_models.LiteralMap] = None, + schedule: Optional[_schedule_model.Schedule] = None, + notifications: Optional[List[_common_models.Notification]] = None, + labels: Optional[_common_models.Labels] = None, + annotations: Optional[_common_models.Annotations] = None, + raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None, + max_parallelism: Optional[int] = None, + security_context: Optional[security.SecurityContext] = None, ) -> LaunchPlan: return LaunchPlan( name=name, @@ -349,7 +348,6 @@ def clone_with( labels=labels or self.labels, annotations=annotations or self.annotations, 
raw_output_data_config=raw_output_data_config or self.raw_output_data_config, - auth_role=auth_role or self._auth_role, max_parallelism=max_parallelism or self.max_parallelism, security_context=security_context or self.security_context, ) @@ -407,11 +405,11 @@ def raw_output_data_config(self) -> Optional[_common_models.RawOutputDataConfig] return self._raw_output_data_config @property - def max_parallelism(self) -> typing.Optional[int]: + def max_parallelism(self) -> Optional[int]: return self._max_parallelism @property - def security_context(self) -> typing.Optional[security.SecurityContext]: + def security_context(self) -> Optional[security.SecurityContext]: return self._security_context def construct_node_metadata(self) -> _workflow_model.NodeMetadata: diff --git a/flytekit/core/local_cache.py b/flytekit/core/local_cache.py index 11cb3b926c..e0b205ca5b 100644 --- a/flytekit/core/local_cache.py +++ b/flytekit/core/local_cache.py @@ -1,10 +1,12 @@ from typing import Optional -import joblib from diskcache import Cache +from flytekit import lazy_module from flytekit.models.literals import Literal, LiteralCollection, LiteralMap +joblib = lazy_module("joblib") + # Location on the filesystem where serialized objects will be stored # TODO: read from config CACHE_LOCATION = "~/.flyte/local-cache" diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index 3b5c0a09ca..b40b5029bb 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -2,71 +2,92 @@ Flytekit map tasks specify how to run a single task across a list of inputs. Map tasks themselves are constructed with a reference task as well as run-time parameters that limit execution concurrency and failure tolerations. """ - +import functools +import hashlib +import logging import os import typing from contextlib import contextmanager -from itertools import count -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Set from flytekit.configuration import SerializationSettings from flytekit.core import tracker -from flytekit.core.base_task import PythonTask +from flytekit.core.base_task import PythonTask, Task, TaskResolverMixin from flytekit.core.constants import SdkTaskType from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.interface import transform_interface_to_list_interface from flytekit.core.python_function_task import PythonFunctionTask +from flytekit.core.tracker import TrackedInstance +from flytekit.core.utils import timeit from flytekit.exceptions import scopes as exception_scopes from flytekit.models.array_job import ArrayJob from flytekit.models.interface import Variable from flytekit.models.task import Container, K8sPod, Sql +from flytekit.tools.module_loader import load_object_from_module class MapPythonTask(PythonTask): """ A MapPythonTask defines a :py:class:`flytekit.PythonTask` which specifies how to run an inner :py:class:`flytekit.PythonFunctionTask` across a range of inputs in parallel. - TODO: support lambda functions """ - # To support multiple map tasks declared around identical python function tasks, we keep a global count of - # MapPythonTask instances to uniquely differentiate map task names for each declared instance. 
- _ids = count(0) - def __init__( self, - python_function_task: PythonFunctionTask, - concurrency: int = None, - min_success_ratio: float = None, + python_function_task: typing.Union[PythonFunctionTask, functools.partial], + concurrency: Optional[int] = None, + min_success_ratio: Optional[float] = None, + bound_inputs: Optional[Set[str]] = None, **kwargs, ): """ + Wrapper that creates a MapPythonTask + :param python_function_task: This argument is implicitly passed and represents the repeatable function :param concurrency: If specified, this limits the number of mapped tasks than can run in parallel to the given - batch size + batch size :param min_success_ratio: If specified, this determines the minimum fraction of total jobs which can complete - successfully before terminating this task and marking it successful. + successfully before terminating this task and marking it successful + :param bound_inputs: List[str] specifies a list of variable names within the interface of python_function_task, + that are already bound and should not be considered as list inputs, but scalar values. This is mostly + useful at runtime and is passed in by MapTaskResolver. This field is not required when a `partial` method + is specified. The bound_vars will be auto-deduced from the `partial.keywords`. """ - if len(python_function_task.python_interface.inputs.keys()) > 1: - raise ValueError("Map tasks only accept python function tasks with 0 or 1 inputs") - - if len(python_function_task.python_interface.outputs.keys()) > 1: + self._partial = None + if isinstance(python_function_task, functools.partial): + # TODO: We should be able to support partial tasks with lists as inputs + for arg in python_function_task.keywords.values(): + if isinstance(arg, list): + raise ValueError("Map tasks do not support partial tasks with lists as inputs. 
") + self._partial = python_function_task + actual_task = self._partial.func + else: + actual_task = python_function_task + + if not isinstance(actual_task, PythonFunctionTask): + raise ValueError("Map tasks can only compose of Python Functon Tasks currently") + + if len(actual_task.python_interface.outputs.keys()) > 1: raise ValueError("Map tasks only accept python function tasks with 0 or 1 outputs") - collection_interface = transform_interface_to_list_interface(python_function_task.python_interface) - instance = next(self._ids) - _, mod, f, _ = tracker.extract_task_module(python_function_task.task_function) - name = f"{mod}.mapper_{f}_{instance}" - - self._cmd_prefix = None - self._run_task = python_function_task - self._max_concurrency = concurrency - self._min_success_ratio = min_success_ratio - self._array_task_interface = python_function_task.python_interface - if "metadata" not in kwargs and python_function_task.metadata: - kwargs["metadata"] = python_function_task.metadata - if "security_ctx" not in kwargs and python_function_task.security_context: - kwargs["security_ctx"] = python_function_task.security_context + self._bound_inputs: typing.Set[str] = set(bound_inputs) if bound_inputs else set() + if self._partial: + self._bound_inputs = set(self._partial.keywords.keys()) + + collection_interface = transform_interface_to_list_interface(actual_task.python_interface, self._bound_inputs) + self._run_task: PythonFunctionTask = actual_task + _, mod, f, _ = tracker.extract_task_module(actual_task.task_function) + h = hashlib.md5(collection_interface.__str__().encode("utf-8")).hexdigest() + name = f"{mod}.map_{f}_{h}" + + self._cmd_prefix: typing.Optional[typing.List[str]] = None + self._max_concurrency: typing.Optional[int] = concurrency + self._min_success_ratio: typing.Optional[float] = min_success_ratio + self._array_task_interface = actual_task.python_interface + if "metadata" not in kwargs and actual_task.metadata: + kwargs["metadata"] = actual_task.metadata + if "security_ctx" not in kwargs and actual_task.security_context: + kwargs["security_ctx"] = actual_task.security_context super().__init__( name=name, interface=collection_interface, @@ -76,7 +97,15 @@ def __init__( **kwargs, ) + @property + def bound_inputs(self) -> Set[str]: + return self._bound_inputs + def get_command(self, settings: SerializationSettings) -> List[str]: + """ + TODO ADD bound variables to the resolver. Maybe we need a different resolver? + """ + mt = MapTaskResolver() container_args = [ "pyflyte-map-execute", "--inputs", @@ -90,9 +119,9 @@ def get_command(self, settings: SerializationSettings) -> List[str]: "--prev-checkpoint", "{{.prevCheckpointPrefix}}", "--resolver", - self._run_task.task_resolver.location, + mt.name(), "--", - *self._run_task.task_resolver.loader_args(settings, self._run_task), + *mt.loader_args(settings, self), ] if self._cmd_prefix: @@ -100,7 +129,7 @@ def get_command(self, settings: SerializationSettings) -> List[str]: return container_args def set_command_prefix(self, cmd: typing.Optional[typing.List[str]]): - self._cmd_prefix = cmd # type: ignore + self._cmd_prefix = cmd @contextmanager def prepare_target(self): @@ -135,6 +164,18 @@ def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str] def run_task(self) -> PythonFunctionTask: return self._run_task + def __call__(self, *args, **kwargs): + """ + This call method modifies the kwargs and adds kwargs from partial. + This is mostly done in the local_execute and compilation only. 
+ At runtime, the map_task is created with all the inputs filled in. to support this, we have modified + the map_task interface in the constructor. + """ + if self._partial: + """If partial exists, then mix-in all partial values""" + kwargs = {**self._partial.keywords, **kwargs} + return super().__call__(*args, **kwargs) + def execute(self, **kwargs) -> Any: ctx = FlyteContextManager.current_context() if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION: @@ -149,8 +190,8 @@ def _compute_array_job_index() -> int: environment variable and the offset (if one's set). The offset will be set and used when the user request that the job runs in a number of slots less than the size of the input. """ - return int(os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET", 0)) + int( - os.environ.get(os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME")) + return int(os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET", "0")) + int( + os.environ.get(os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME", "0"), "0") ) @property @@ -168,7 +209,7 @@ def _outputs_interface(self) -> Dict[Any, Variable]: return self.interface.outputs return self._run_task.interface.outputs - def get_type_for_output_var(self, k: str, v: Any) -> Optional[Type[Any]]: + def get_type_for_output_var(self, k: str, v: Any) -> type: """ We override this method from flytekit.core.base_task Task because the dispatch_execute method uses this interface to construct outputs. Each instance of an container_array task will however produce outputs @@ -181,7 +222,7 @@ def get_type_for_output_var(self, k: str, v: Any) -> Optional[Type[Any]]: return self._python_interface.outputs[k] return self._run_task._python_interface.outputs[k] - def _execute_map_task(self, ctx: FlyteContext, **kwargs) -> Any: + def _execute_map_task(self, _: FlyteContext, **kwargs) -> Any: """ This is called during ExecutionState.Mode.TASK_EXECUTION executions, that is executions orchestrated by the Flyte platform. Individual instances of the map task, aka array task jobs are passed the full set of inputs but @@ -191,7 +232,11 @@ def _execute_map_task(self, ctx: FlyteContext, **kwargs) -> Any: task_index = self._compute_array_job_index() map_task_inputs = {} for k in self.interface.inputs.keys(): - map_task_inputs[k] = kwargs[k][task_index] + v = kwargs[k] + if isinstance(v, list) and k not in self.bound_inputs: + map_task_inputs[k] = v[task_index] + else: + map_task_inputs[k] = v return exception_scopes.user_entry_point(self._run_task.execute)(**map_task_inputs) def _raw_execute(self, **kwargs) -> Any: @@ -213,7 +258,11 @@ def _raw_execute(self, **kwargs) -> Any: for i in range(len(kwargs[any_input_key])): single_instance_inputs = {} for k in self.interface.inputs.keys(): - single_instance_inputs[k] = kwargs[k][i] + v = kwargs[k] + if isinstance(v, list) and k not in self.bound_inputs: + single_instance_inputs[k] = kwargs[k][i] + else: + single_instance_inputs[k] = kwargs[k] o = exception_scopes.user_entry_point(self._run_task.execute)(**single_instance_inputs) if outputs_expected: outputs.append(o) @@ -221,7 +270,12 @@ def _raw_execute(self, **kwargs) -> Any: return outputs -def map_task(task_function: PythonFunctionTask, concurrency: int = 0, min_success_ratio: float = 1.0, **kwargs): +def map_task( + task_function: typing.Union[PythonFunctionTask, functools.partial], + concurrency: int = 0, + min_success_ratio: float = 1.0, + **kwargs, +): """ Use a map task for parallelizable tasks that run across a list of an input type. 
A map task can be composed of any individual :py:class:`flytekit.PythonFunctionTask`. @@ -267,8 +321,64 @@ def map_task(task_function: PythonFunctionTask, concurrency: int = 0, min_succes successfully before terminating this task and marking it successful. """ - if not isinstance(task_function, PythonFunctionTask): - raise ValueError( - f"Only Flyte python task types are supported in map tasks currently, received {type(task_function)}" - ) return MapPythonTask(task_function, concurrency=concurrency, min_success_ratio=min_success_ratio, **kwargs) + + +class MapTaskResolver(TrackedInstance, TaskResolverMixin): + """ + Special resolver that is used for MapTasks. + This exists because it is possible that MapTasks are created using nested "partial" subtasks. + When a maptask is created its interface is interpolated from the interface of the subtask - the interpolation, + simply converts every input into a list/collection input. + + For example: + interface -> (i: int, j: str) -> str => map_task interface -> (i: List[int], j: List[str]) -> List[str] + + But in cases in which `j` is bound to a fixed value by using `functools.partial` we need a way to ensure that + the interface is not simply interpolated, but only the unbound inputs are interpolated. + + .. code-block:: python + + def foo((i: int, j: str) -> str: + ... + + mt = map_task(functools.partial(foo, j=10)) + + print(mt.interface) + + output: + + (i: List[int], j: str) -> List[str] + + But, at runtime this information is lost. To reconstruct this, we use MapTaskResolver that records the "bound vars" + and then at runtime reconstructs the interface with this knowledge + """ + + def name(self) -> str: + return "MapTaskResolver" + + @timeit("Load map task") + def load_task(self, loader_args: List[str], max_concurrency: int = 0) -> MapPythonTask: + """ + Loader args should be of the form + vars "var1,var2,.." resolver "resolver" [resolver_args] + """ + _, bound_vars, _, resolver, *resolver_args = loader_args + logging.info(f"MapTask found task resolver {resolver} and arguments {resolver_args}") + resolver_obj = load_object_from_module(resolver) + # Use the resolver to load the actual task object + _task_def = resolver_obj.load_task(loader_args=resolver_args) + bound_inputs = set(bound_vars.split(",")) + return MapPythonTask(python_function_task=_task_def, max_concurrency=max_concurrency, bound_inputs=bound_inputs) + + def loader_args(self, settings: SerializationSettings, t: MapPythonTask) -> List[str]: # type:ignore + return [ + "vars", + f'{",".join(t.bound_inputs)}', + "resolver", + t.run_task.task_resolver.location, + *t.run_task.task_resolver.loader_args(settings, t.run_task), + ] + + def get_all_tasks(self) -> List[Task]: + raise NotImplementedError("MapTask resolver cannot return every instance of the map task") diff --git a/flytekit/core/node.py b/flytekit/core/node.py index 617790746f..c9a547efdb 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -84,7 +84,9 @@ def metadata(self) -> _workflow_model.NodeMetadata: def with_overrides(self, *args, **kwargs): if "node_name" in kwargs: - self._id = kwargs["node_name"] + # Convert the node name into a DNS-compliant. 
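Returning to the map-task changes: the MapTaskResolver docstring above sketches partial binding, and the snippet below shows the same idea end to end. The multiply task and wf workflow are hypothetical examples, not code from this PR:

```python
import functools
from typing import List

from flytekit import map_task, task, workflow


@task
def multiply(x: int, factor: int) -> int:
    return x * factor


@workflow
def wf(xs: List[int]) -> List[int]:
    # "factor" is bound via functools.partial, so it stays a scalar input;
    # only "x" is turned into a list by the map-task interface transformation.
    return map_task(functools.partial(multiply, factor=10))(x=xs)
```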
+ # https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#dns-subdomain-names + self._id = _dnsify(kwargs["node_name"]) if "aliases" in kwargs: alias_dict = kwargs["aliases"] if not isinstance(alias_dict, dict): @@ -126,7 +128,7 @@ def with_overrides(self, *args, **kwargs): def _convert_resource_overrides( resources: typing.Optional[Resources], resource_name: str -) -> [_resources_model.ResourceEntry]: +) -> typing.List[_resources_model.ResourceEntry]: if resources is None: return [] if not isinstance(resources, Resources): diff --git a/flytekit/core/node_creation.py b/flytekit/core/node_creation.py index de33393c13..62065f6869 100644 --- a/flytekit/core/node_creation.py +++ b/flytekit/core/node_creation.py @@ -1,7 +1,6 @@ from __future__ import annotations -import collections -from typing import TYPE_CHECKING, Type, Union +from typing import TYPE_CHECKING, Union from flytekit.core.base_task import PythonTask from flytekit.core.context_manager import BranchEvalMode, ExecutionState, FlyteContext @@ -21,7 +20,7 @@ def create_node( entity: Union[PythonTask, LaunchPlan, WorkflowBase, RemoteEntity], *args, **kwargs -) -> Union[Node, VoidPromise, Type[collections.namedtuple]]: +) -> Union[Node, VoidPromise]: """ This is the function you want to call if you need to specify dependencies between tasks that don't consume and/or don't produce outputs. For example, if you have t1() and t2(), both of which do not take in nor produce any @@ -173,9 +172,9 @@ def sub_wf(): if len(output_names) == 1: # See explanation above for why we still tupletize a single element. - return entity.python_interface.output_tuple(results) + return entity.python_interface.output_tuple(results) # type: ignore - return entity.python_interface.output_tuple(*results) + return entity.python_interface.output_tuple(*results) # type: ignore else: raise Exception(f"Cannot use explicit run to call Flyte entities {entity.name}") diff --git a/flytekit/core/pod_template.py b/flytekit/core/pod_template.py index 5e9c746911..98ba92af36 100644 --- a/flytekit/core/pod_template.py +++ b/flytekit/core/pod_template.py @@ -1,22 +1,27 @@ from dataclasses import dataclass -from typing import Dict, Optional - -from kubernetes.client.models import V1PodSpec +from typing import TYPE_CHECKING, Dict, Optional from flytekit.exceptions import user as _user_exceptions +if TYPE_CHECKING: + from kubernetes.client import V1PodSpec + PRIMARY_CONTAINER_DEFAULT_NAME = "primary" -@dataclass +@dataclass(init=True, repr=True, eq=True, frozen=False) class PodTemplate(object): """Custom PodTemplate specification for a Task.""" - pod_spec: V1PodSpec = V1PodSpec(containers=[]) + pod_spec: Optional["V1PodSpec"] = None primary_container_name: str = PRIMARY_CONTAINER_DEFAULT_NAME labels: Optional[Dict[str, str]] = None annotations: Optional[Dict[str, str]] = None def __post_init__(self): + if self.pod_spec is None: + from kubernetes.client import V1PodSpec + + self.pod_spec = V1PodSpec(containers=[]) if not self.primary_container_name: raise _user_exceptions.FlyteValidationException("A primary container name cannot be undefined") diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 3a851a50ea..2a3687c06c 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -10,11 +10,18 @@ from flytekit.core import context_manager as _flyte_context from flytekit.core import interface as flyte_interface from flytekit.core import type_engine -from flytekit.core.context_manager import BranchEvalMode, ExecutionState, FlyteContext, 
FlyteContextManager +from flytekit.core.context_manager import ( + BranchEvalMode, + ExecutionParameters, + ExecutionState, + FlyteContext, + FlyteContextManager, +) from flytekit.core.interface import Interface from flytekit.core.node import Node -from flytekit.core.type_engine import DictTransformer, ListTransformer, TypeEngine +from flytekit.core.type_engine import DictTransformer, ListTransformer, TypeEngine, TypeTransformerFailedError from flytekit.exceptions import user as _user_exceptions +from flytekit.loggers import logger from flytekit.models import interface as _interface_models from flytekit.models import literals as _literal_models from flytekit.models import literals as _literals_models @@ -81,11 +88,17 @@ def extract_value( if lt.collection_type is None: raise TypeError(f"Not a collection type {flyte_literal_type} but got a list {input_val}") try: - sub_type = ListTransformer.get_sub_type(python_type) + sub_type: type = ListTransformer.get_sub_type(python_type) except ValueError: if len(input_val) == 0: raise sub_type = type(input_val[0]) + # To maintain consistency between translate_inputs_to_literals and ListTransformer.to_literal for batchable types, + # directly call ListTransformer.to_literal to batch process the list items. This is necessary because processing + # each list item separately could lead to errors since ListTransformer.to_python_value may treat the literal + # as it is batched for batchable types. + if ListTransformer.is_batchable(python_type): + return TypeEngine.to_literal(ctx, input_val, python_type, lt) literal_list = [extract_value(ctx, v, sub_type, lt.collection_type) for v in input_val] return _literal_models.Literal(collection=_literal_models.LiteralCollection(literals=literal_list)) elif isinstance(input_val, dict): @@ -135,7 +148,10 @@ def extract_value( raise ValueError(f"Received unexpected keyword argument {k}") var = flyte_interface_types[k] t = native_types[k] - result[k] = extract_value(ctx, v, t, var.type) + try: + result[k] = extract_value(ctx, v, t, var.type) + except TypeTransformerFailedError as exc: + raise TypeTransformerFailedError(f"Failed argument '{k}': {exc}") from exc return result @@ -348,7 +364,7 @@ def __hash__(self): return hash(id(self)) def __rshift__(self, other: Union[Promise, VoidPromise]): - if not self.is_ready: + if not self.is_ready and other.ref: self.ref.node.runs_before(other.ref.node) return other @@ -408,10 +424,10 @@ def is_false(self) -> ComparisonExpression: def is_true(self): return self.is_(True) - def __eq__(self, other) -> ComparisonExpression: + def __eq__(self, other) -> ComparisonExpression: # type: ignore return ComparisonExpression(self, ComparisonOps.EQ, other) - def __ne__(self, other) -> ComparisonExpression: + def __ne__(self, other) -> ComparisonExpression: # type: ignore return ComparisonExpression(self, ComparisonOps.NE, other) def __gt__(self, other) -> ComparisonExpression: @@ -455,7 +471,7 @@ def __str__(self): def create_native_named_tuple( ctx: FlyteContext, - promises: Optional[Union[Promise, List[Promise]]], + promises: Union[Tuple[Promise], Promise, VoidPromise, None], entity_interface: Interface, ) -> Optional[Tuple]: """ @@ -471,12 +487,16 @@ def create_native_named_tuple( if isinstance(promises, Promise): k, v = [(k, v) for k, v in entity_interface.outputs.items()][0] # get output native type + # only show the name of output key if it's user-defined (by default Flyte names these as "o") + key = k if k != "o0" else 0 try: return TypeEngine.to_python_value(ctx, promises.val, v) 
except Exception as e: - raise AssertionError(f"Failed to convert value of output {k}, expected type {v}.") from e + raise TypeError( + f"Failed to convert output in position {key} of value {promises.val}, expected type {v}." + ) from e - if len(promises) == 0: + if len(cast(Tuple[Promise], promises)) == 0: return None named_tuple_name = "DefaultNamedTupleOutput" @@ -484,7 +504,7 @@ def create_native_named_tuple( named_tuple_name = entity_interface.output_tuple_name outputs = {} - for p in promises: + for i, p in enumerate(cast(Tuple[Promise], promises)): if not isinstance(p, Promise): raise AssertionError( "Workflow outputs can only be promises that are returned by tasks. Found a value of" @@ -494,11 +514,13 @@ def create_native_named_tuple( try: outputs[p.var] = TypeEngine.to_python_value(ctx, p.val, t) except Exception as e: - raise AssertionError(f"Failed to convert value of output {p.var}, expected type {t}.") from e + # only show the name of output key if it's user-defined (by default Flyte names these as "o") + key = p.var if p.var != f"o{i}" else i + raise TypeError(f"Failed to convert output in position {key} of value {p.val}, expected type {t}.") from e # Should this class be part of the Interface? - t = collections.namedtuple(named_tuple_name, list(outputs.keys())) - return t(**outputs) + nt = collections.namedtuple(named_tuple_name, list(outputs.keys())) # type: ignore + return nt(**outputs) # To create a class that is a named tuple, we might have to create namedtuplemeta and manipulate the tuple @@ -542,7 +564,7 @@ def create_task_output( named_tuple_name = entity_interface.output_tuple_name # Should this class be part of the Interface? - class Output(collections.namedtuple(named_tuple_name, variables)): + class Output(collections.namedtuple(named_tuple_name, variables)): # type: ignore def with_overrides(self, *args, **kwargs): val = self.__getattribute__(self._fields[0]) val.with_overrides(*args, **kwargs) @@ -597,11 +619,22 @@ def binding_data_from_python_std( f"Cannot pass output from task {t_value.task_name} that produces no outputs to a downstream task" ) - elif isinstance(t_value, list): - if expected_literal_type.collection_type is None: - raise AssertionError(f"this should be a list and it is not: {type(t_value)} vs {expected_literal_type}") + elif expected_literal_type.union_type is not None: + for i in range(len(expected_literal_type.union_type.variants)): + try: + lt_type = expected_literal_type.union_type.variants[i] + python_type = get_args(t_value_type)[i] if t_value_type else None + return binding_data_from_python_std(ctx, lt_type, t_value, python_type) + except Exception: + logger.debug( + f"failed to bind data {t_value} with literal type {expected_literal_type.union_type.variants[i]}." + ) + raise AssertionError( + f"Failed to bind data {t_value} with literal type {expected_literal_type.union_type.variants}." 
+ ) - sub_type = ListTransformer.get_sub_type(t_value_type) if t_value_type else None + elif isinstance(t_value, list): + sub_type: Optional[type] = ListTransformer.get_sub_type(t_value_type) if t_value_type else None collection = _literals_models.BindingDataCollection( bindings=[ binding_data_from_python_std(ctx, expected_literal_type.collection_type, t, sub_type) for t in t_value @@ -683,7 +716,7 @@ def ref(self) -> Optional[NodeOutput]: return self._ref def __rshift__(self, other: Union[Promise, VoidPromise]): - if self.ref: + if self.ref and other.ref: self.ref.node.runs_before(other.ref.node) return other @@ -1019,11 +1052,13 @@ def create_and_link_node( class LocallyExecutable(Protocol): - def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]: + def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]: ... -def flyte_entity_call_handler(entity: SupportsNodeCreation, *args, **kwargs): +def flyte_entity_call_handler( + entity: SupportsNodeCreation, *args, **kwargs +) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]: """ This function is the call handler for tasks, workflows, and launch plans (which redirects to the underlying workflow). The logic is the same for all three, but we did not want to create base class, hence this separate @@ -1049,7 +1084,7 @@ def flyte_entity_call_handler(entity: SupportsNodeCreation, *args, **kwargs): for k, v in kwargs.items(): if k not in cast(SupportsNodeCreation, entity).python_interface.inputs: raise ValueError( - f"Received unexpected keyword argument {k} in function {cast(SupportsNodeCreation, entity).name}" + f"Received unexpected keyword argument '{k}' in function '{cast(SupportsNodeCreation, entity).name}'" ) ctx = FlyteContextManager.current_context() @@ -1075,7 +1110,7 @@ def flyte_entity_call_handler(entity: SupportsNodeCreation, *args, **kwargs): ctx.new_execution_state().with_params(mode=ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION) ) ) as child_ctx: - cast(FlyteContext, child_ctx).user_space_params._decks = [] + cast(ExecutionParameters, child_ctx.user_space_params)._decks = [] result = cast(LocallyExecutable, entity).local_execute(child_ctx, **kwargs) expected_outputs = len(cast(SupportsNodeCreation, entity).python_interface.outputs) @@ -1085,7 +1120,9 @@ def flyte_entity_call_handler(entity: SupportsNodeCreation, *args, **kwargs): else: raise Exception(f"Received an output when workflow local execution expected None. 
Received: {result}") - if (1 < expected_outputs == len(result)) or (result is not None and expected_outputs == 1): + if (1 < expected_outputs == len(cast(Tuple[Promise], result))) or ( + result is not None and expected_outputs == 1 + ): return create_native_named_tuple(ctx, result, cast(SupportsNodeCreation, entity).python_interface) raise ValueError( diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 2d05df3c3d..47da6a9729 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -3,12 +3,7 @@ import importlib import re from abc import ABC -from types import ModuleType -from typing import Any, Callable, Dict, List, Optional, TypeVar, Union - -from flyteidl.core import tasks_pb2 as _core_task -from kubernetes.client import ApiClient -from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements +from typing import Callable, Dict, List, Optional, TypeVar, Union from flytekit.configuration import ImageConfig, SerializationSettings from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin @@ -17,7 +12,8 @@ from flytekit.core.resources import Resources, ResourceSpec from flytekit.core.tracked_abc import FlyteTrackedABC from flytekit.core.tracker import TrackedInstance, extract_task_module -from flytekit.core.utils import _get_container_definition +from flytekit.core.utils import _get_container_definition, _serialize_pod_spec, timeit +from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec from flytekit.loggers import logger from flytekit.models import task as _task_model from flytekit.models.security import Secret, SecurityContext @@ -26,10 +22,6 @@ _PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name" -def _sanitize_resource_name(resource: _task_model.Resources.ResourceEntry) -> str: - return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-") - - class PythonAutoContainerTask(PythonTask[T], ABC, metaclass=FlyteTrackedABC): """ A Python AutoContainer task should be used as the base for all extensions that want the user's code to be in the @@ -44,7 +36,7 @@ def __init__( name: str, task_config: T, task_type="python-task", - container_image: Optional[str] = None, + container_image: Optional[Union[str, ImageSpec]] = None, requests: Optional[Resources] = None, limits: Optional[Resources] = None, environment: Optional[Dict[str, str]] = None, @@ -86,7 +78,7 @@ def __init__( raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}") sec_ctx = SecurityContext(secrets=secret_requests) - # pod_template_name overwrites the metedata.pod_template_name + # pod_template_name overwrites the metadata.pod_template_name kwargs["metadata"] = kwargs["metadata"] if "metadata" in kwargs else TaskMetadata() kwargs["metadata"].pod_template_name = pod_template_name @@ -120,11 +112,11 @@ def __init__( self.pod_template = pod_template @property - def task_resolver(self) -> Optional[TaskResolverMixin]: + def task_resolver(self) -> TaskResolverMixin: return self._task_resolver @property - def container_image(self) -> Optional[str]: + def container_image(self) -> Optional[Union[str, ImageSpec]]: return self._container_image @property @@ -189,6 +181,9 @@ def _get_container(self, settings: SerializationSettings) -> _task_model.Contain for elem in (settings.env, self.environment): if elem: env.update(elem) + if settings.fast_serialization_settings is None or not settings.fast_serialization_settings.enabled: + if 
isinstance(self.container_image, ImageSpec): + self.container_image.source_root = settings.source_root return _get_container_definition( image=get_registerable_container_image(self.container_image, settings.image_config), command=[], @@ -207,52 +202,11 @@ def _get_container(self, settings: SerializationSettings) -> _task_model.Contain memory_limit=self.resources.limits.mem, ) - def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any]: - containers = self.pod_template.pod_spec.containers - primary_exists = False - - for container in containers: - if container.name == self.pod_template.primary_container_name: - primary_exists = True - break - - if not primary_exists: - # insert a placeholder primary container if it is not defined in the pod spec. - containers.append(V1Container(name=self.pod_template.primary_container_name)) - final_containers = [] - for container in containers: - # In the case of the primary container, we overwrite specific container attributes - # with the default values used in the regular Python task. - # The attributes include: image, command, args, resource, and env (env is unioned) - if container.name == self.pod_template.primary_container_name: - sdk_default_container = self._get_container(settings) - container.image = sdk_default_container.image - # clear existing commands - container.command = sdk_default_container.command - # also clear existing args - container.args = sdk_default_container.args - limits, requests = {}, {} - for resource in sdk_default_container.resources.limits: - limits[_sanitize_resource_name(resource)] = resource.value - for resource in sdk_default_container.resources.requests: - requests[_sanitize_resource_name(resource)] = resource.value - resource_requirements = V1ResourceRequirements(limits=limits, requests=requests) - if len(limits) > 0 or len(requests) > 0: - # Important! Only copy over resource requirements if they are non-empty. 
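# For context, a hypothetical pod_template of the kind this serialization path consumes; the
# primary container is overwritten with the task defaults while other containers pass through
# unchanged (field values below are illustrative, not from this patch).
from kubernetes.client.models import V1Container, V1PodSpec
from flytekit.core.pod_template import PodTemplate

pod_template = PodTemplate(
    primary_container_name="primary",
    pod_spec=V1PodSpec(
        containers=[
            V1Container(name="primary"),                   # placeholder, filled in from the task
            V1Container(name="sidecar", image="busybox"),  # serialized as-is
        ]
    ),
    labels={"team": "ml"},
)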
- container.resources = resource_requirements - container.env = [V1EnvVar(name=key, value=val) for key, val in sdk_default_container.env.items()] + ( - container.env or [] - ) - final_containers.append(container) - self.pod_template.pod_spec.containers = final_containers - - return ApiClient().sanitize_for_serialization(self.pod_template.pod_spec) - def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod: if self.pod_template is None: return None return _task_model.K8sPod( - pod_spec=self._serialize_pod_spec(settings), + pod_spec=_serialize_pod_spec(self.pod_template, self._get_container(settings)), metadata=_task_model.K8sObjectMetadata( labels=self.pod_template.labels, annotations=self.pod_template.annotations, @@ -274,14 +228,15 @@ class DefaultTaskResolver(TrackedInstance, TaskResolverMixin): def name(self) -> str: return "DefaultTaskResolver" - def load_task(self, loader_args: List[Union[T, ModuleType]]) -> PythonAutoContainerTask: + @timeit("Load task") + def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask: _, task_module, _, task_name, *_ = loader_args - task_module = importlib.import_module(task_module) + task_module = importlib.import_module(name=task_module) # type: ignore task_def = getattr(task_module, task_name) return task_def - def loader_args(self, settings: SerializationSettings, task: PythonAutoContainerTask) -> List[str]: + def loader_args(self, settings: SerializationSettings, task: PythonAutoContainerTask) -> List[str]: # type:ignore from flytekit.core.python_function_task import PythonFunctionTask if isinstance(task, PythonFunctionTask): @@ -291,19 +246,23 @@ def loader_args(self, settings: SerializationSettings, task: PythonAutoContainer _, m, t, _ = extract_task_module(task) return ["task-module", m, "task-name", t] - def get_all_tasks(self) -> List[PythonAutoContainerTask]: + def get_all_tasks(self) -> List[PythonAutoContainerTask]: # type: ignore raise Exception("should not be needed") default_task_resolver = DefaultTaskResolver() -def get_registerable_container_image(img: Optional[str], cfg: ImageConfig) -> str: +def get_registerable_container_image(img: Optional[Union[str, ImageSpec]], cfg: ImageConfig) -> str: """ - :param img: Configured image + :param img: Configured image or image spec :param cfg: Registration configuration :return: """ + if isinstance(img, ImageSpec): + ImageBuildEngine.build(img) + return img.image_name() + if img is not None and img != "": matches = _IMAGE_REPLACE_REGEX.findall(img) if matches is None or len(matches) == 0: diff --git a/flytekit/core/python_customized_container_task.py b/flytekit/core/python_customized_container_task.py index eee0dce9b8..07493886a2 100644 --- a/flytekit/core/python_customized_container_task.py +++ b/flytekit/core/python_customized_container_task.py @@ -21,7 +21,7 @@ TC = TypeVar("TC") -class PythonCustomizedContainerTask(ExecutableTemplateShimTask, PythonTask[TC]): +class PythonCustomizedContainerTask(ExecutableTemplateShimTask, PythonTask[TC]): # type: ignore """ Please take a look at the comments for :py:class`flytekit.extend.ExecutableTemplateShimTask` as well. This class should be subclassed and a custom Executor provided as a default to this parent class constructor @@ -229,7 +229,7 @@ def name(self) -> str: # The return type of this function is different, it should be a Task, but it's not because it doesn't make # sense for ExecutableTemplateShimTask to inherit from Task. 
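# Sketch of the new ImageSpec path accepted by get_registerable_container_image: instead of a
# pre-registered image string, a task can carry an ImageSpec that is built at registration time
# (the package list and registry below are illustrative assumptions).
from flytekit import task
from flytekit.image_spec.image_spec import ImageSpec

sklearn_image = ImageSpec(
    packages=["scikit-learn"],
    registry="localhost:30000",
)

@task(container_image=sklearn_image)
def train() -> float:
    return 0.0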
- def load_task(self, loader_args: List[str]) -> ExecutableTemplateShimTask: + def load_task(self, loader_args: List[str]) -> ExecutableTemplateShimTask: # type: ignore logger.info(f"Task template loader args: {loader_args}") ctx = FlyteContext.current_context() task_template_local_path = os.path.join(ctx.execution_state.working_dir, "task_template.pb") # type: ignore @@ -240,7 +240,7 @@ def load_task(self, loader_args: List[str]) -> ExecutableTemplateShimTask: executor_class = load_object_from_module(loader_args[1]) return ExecutableTemplateShimTask(task_template_model, executor_class) - def loader_args(self, settings: SerializationSettings, t: PythonCustomizedContainerTask) -> List[str]: + def loader_args(self, settings: SerializationSettings, t: PythonCustomizedContainerTask) -> List[str]: # type: ignore return ["{{.taskTemplatePath}}", f"{t.executor_type.__module__}.{t.executor_type.__name__}"] def get_all_tasks(self) -> List[Task]: diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index 81f6739a39..90b10cbc36 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -17,7 +17,7 @@ from abc import ABC from collections import OrderedDict from enum import Enum -from typing import Any, Callable, List, Optional, TypeVar, Union +from typing import Any, Callable, List, Optional, TypeVar, Union, cast from flytekit.core.base_task import Task, TaskResolverMixin from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager @@ -43,7 +43,7 @@ T = TypeVar("T") -class PythonInstanceTask(PythonAutoContainerTask[T], ABC): +class PythonInstanceTask(PythonAutoContainerTask[T], ABC): # type: ignore """ This class should be used as the base class for all Tasks that do not have a user defined function body, but have a platform defined execute method. (Execute needs to be overridden). This base class ensures that the module loader @@ -72,7 +72,7 @@ def __init__( super().__init__(name=name, task_config=task_config, task_type=task_type, task_resolver=task_resolver, **kwargs) -class PythonFunctionTask(PythonAutoContainerTask[T]): +class PythonFunctionTask(PythonAutoContainerTask[T]): # type: ignore """ A Python Function task should be used as the base for all extensions that have a python function. It will automatically detect interface of the python function and when serialized on the hosted Flyte platform handles the @@ -193,10 +193,10 @@ def compile_into_workflow( from flytekit.tools.translator import get_serializable self._create_and_cache_dynamic_workflow() - self._wf.compile(**kwargs) + cast(PythonFunctionWorkflow, self._wf).compile(**kwargs) wf = self._wf - model_entities = OrderedDict() + model_entities: OrderedDict = OrderedDict() # See comment on reference entity checking a bit down below in this function. # This is the only circular dependency between the translator.py module and the rest of the flytekit # authoring experience. @@ -263,12 +263,12 @@ def dynamic_execute(self, task_function: Callable, **kwargs) -> Any: # local_execute directly though since that converts inputs into Promises. 
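# For context, a minimal dynamic workflow of the kind compile_into_workflow/dynamic_execute
# handle (illustrative sketch, not part of this patch): the body runs at execution time, and the
# task calls inside it return promises that are compiled into a sub-workflow.
import typing
from flytekit import dynamic, task

@task
def double(x: int) -> int:
    return x * 2

@dynamic
def fan_out(n: int) -> typing.List[int]:
    return [double(x=i) for i in range(n)]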
logger.debug(f"Executing Dynamic workflow, using raw inputs {kwargs}") self._create_and_cache_dynamic_workflow() - function_outputs = self._wf.execute(**kwargs) + function_outputs = cast(PythonFunctionWorkflow, self._wf).execute(**kwargs) if isinstance(function_outputs, VoidPromise) or function_outputs is None: return VoidPromise(self.name) - if len(self._wf.python_interface.outputs) == 0: + if len(cast(PythonFunctionWorkflow, self._wf).python_interface.outputs) == 0: raise FlyteValueException(function_outputs, "Interface output should've been VoidPromise or None.") # TODO: This will need to be cleaned up when we revisit top-level tuple support. diff --git a/flytekit/core/reference_entity.py b/flytekit/core/reference_entity.py index 7247457d86..de386fa159 100644 --- a/flytekit/core/reference_entity.py +++ b/flytekit/core/reference_entity.py @@ -21,7 +21,7 @@ from flytekit.models.core import workflow as _workflow_model -@dataclass +@dataclass # type: ignore class Reference(ABC): project: str domain: str @@ -72,7 +72,7 @@ class ReferenceEntity(object): def __init__( self, reference: Union[WorkflowReference, TaskReference, LaunchPlanReference], - inputs: Optional[Dict[str, Union[Type[Any], Tuple[Type[Any], Any]]]], + inputs: Dict[str, Type], outputs: Dict[str, Type], ): if ( diff --git a/flytekit/core/schedule.py b/flytekit/core/schedule.py index 7addc89197..93116d0720 100644 --- a/flytekit/core/schedule.py +++ b/flytekit/core/schedule.py @@ -6,6 +6,7 @@ import datetime import re as _re +from typing import Optional import croniter as _croniter @@ -52,7 +53,11 @@ class CronSchedule(_schedule_models.Schedule): _OFFSET_PATTERN = _re.compile("([-+]?)P([-+0-9YMWD]+)?(T([-+0-9HMS.,]+)?)?") def __init__( - self, cron_expression: str = None, schedule: str = None, offset: str = None, kickoff_time_input_arg: str = None + self, + cron_expression: Optional[str] = None, + schedule: Optional[str] = None, + offset: Optional[str] = None, + kickoff_time_input_arg: Optional[str] = None, ): """ :param str cron_expression: This should be a cron expression in AWS style.Shouldn't be used in case of native scheduler. @@ -161,7 +166,7 @@ class FixedRate(_schedule_models.Schedule): See the :std:ref:`fixed rate intervals` chapter in the cookbook for additional usage examples. 
""" - def __init__(self, duration: datetime.timedelta, kickoff_time_input_arg: str = None): + def __init__(self, duration: datetime.timedelta, kickoff_time_input_arg: Optional[str] = None): """ :param datetime.timedelta duration: :param str kickoff_time_input_arg: diff --git a/flytekit/core/shim_task.py b/flytekit/core/shim_task.py index d8d18293c5..f96db3e49c 100644 --- a/flytekit/core/shim_task.py +++ b/flytekit/core/shim_task.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import Any, Generic, Type, TypeVar, Union +from typing import Any, Generic, Optional, Type, TypeVar, Union, cast -from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager +from flytekit.core.context_manager import ExecutionParameters, ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.tracker import TrackedInstance from flytekit.core.type_engine import TypeEngine from flytekit.loggers import logger @@ -47,7 +47,7 @@ def name(self) -> str: if self._task_template is not None: return self._task_template.id.name # if not access the subclass's name - return self._name + return self._name # type: ignore @property def task_template(self) -> _task_model.TaskTemplate: @@ -67,13 +67,13 @@ def execute(self, **kwargs) -> Any: """ return self.executor.execute_from_model(self.task_template, **kwargs) - def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: + def pre_execute(self, user_params: Optional[ExecutionParameters]) -> Optional[ExecutionParameters]: """ This function is a stub, just here to keep dispatch_execute compatibility between this class and PythonTask. """ return user_params - def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: + def post_execute(self, _: Optional[ExecutionParameters], rval: Any) -> Any: """ This function is a stub, just here to keep dispatch_execute compatibility between this class and PythonTask. """ @@ -92,7 +92,9 @@ def dispatch_execute( # Create another execution context with the new user params, but let's keep the same working dir with FlyteContextManager.with_context( - ctx.with_execution_state(ctx.execution_state.with_params(user_space_params=new_user_params)) + ctx.with_execution_state( + cast(ExecutionState, ctx.execution_state).with_params(user_space_params=new_user_params) + ) ) as exec_ctx: # Added: Have to reverse the Python interface from the task template Flyte interface # See docstring for more details. 
diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 28c5b5def7..562099c641 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -1,6 +1,6 @@ import datetime as _datetime from functools import update_wrapper -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, overload from flytekit.core.base_task import TaskMetadata, TaskResolverMixin from flytekit.core.interface import transform_function_to_interface @@ -8,6 +8,7 @@ from flytekit.core.python_function_task import PythonFunctionTask from flytekit.core.reference_entity import ReferenceEntity, TaskReference from flytekit.core.resources import Resources +from flytekit.image_spec.image_spec import ImageSpec from flytekit.models.documentation import Documentation from flytekit.models.security import Secret @@ -74,9 +75,64 @@ def find_pythontask_plugin(cls, plugin_config_type: type) -> Type[PythonFunction return PythonFunctionTask +T = TypeVar("T") + + +@overload +def task( + _task_function: None = ..., + task_config: Optional[T] = ..., + cache: bool = ..., + cache_serialize: bool = ..., + cache_version: str = ..., + retries: int = ..., + interruptible: Optional[bool] = ..., + deprecated: str = ..., + timeout: Union[_datetime.timedelta, int] = ..., + container_image: Optional[Union[str, ImageSpec]] = ..., + environment: Optional[Dict[str, str]] = ..., + requests: Optional[Resources] = ..., + limits: Optional[Resources] = ..., + secret_requests: Optional[List[Secret]] = ..., + execution_mode: PythonFunctionTask.ExecutionBehavior = ..., + task_resolver: Optional[TaskResolverMixin] = ..., + docs: Optional[Documentation] = ..., + disable_deck: bool = ..., + pod_template: Optional["PodTemplate"] = ..., + pod_template_name: Optional[str] = ..., +) -> Callable[[Callable[..., Any]], PythonFunctionTask[T]]: + ... + + +@overload +def task( + _task_function: Callable[..., Any], + task_config: Optional[T] = ..., + cache: bool = ..., + cache_serialize: bool = ..., + cache_version: str = ..., + retries: int = ..., + interruptible: Optional[bool] = ..., + deprecated: str = ..., + timeout: Union[_datetime.timedelta, int] = ..., + container_image: Optional[Union[str, ImageSpec]] = ..., + environment: Optional[Dict[str, str]] = ..., + requests: Optional[Resources] = ..., + limits: Optional[Resources] = ..., + secret_requests: Optional[List[Secret]] = ..., + execution_mode: PythonFunctionTask.ExecutionBehavior = ..., + task_resolver: Optional[TaskResolverMixin] = ..., + docs: Optional[Documentation] = ..., + disable_deck: bool = ..., + pod_template: Optional["PodTemplate"] = ..., + pod_template_name: Optional[str] = ..., +) -> PythonFunctionTask[T]: + ... 
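# Why two overloads are needed (sketch; the cache settings are illustrative): the decorator must
# type-check both the bare form, which receives the function directly, and the parametrized form,
# which first returns a decorator.
from flytekit import task

@task                                    # bare form -> PythonFunctionTask
def add(a: int, b: int) -> int:
    return a + b

@task(cache=True, cache_version="1.0")   # parametrized form -> Callable[..., PythonFunctionTask]
def mul(a: int, b: int) -> int:
    return a * b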
+ + def task( - _task_function: Optional[Callable] = None, - task_config: Optional[Any] = None, + _task_function: Optional[Callable[..., Any]] = None, + task_config: Optional[T] = None, cache: bool = False, cache_serialize: bool = False, cache_version: str = "", @@ -84,18 +140,18 @@ def task( interruptible: Optional[bool] = None, deprecated: str = "", timeout: Union[_datetime.timedelta, int] = 0, - container_image: Optional[str] = None, + container_image: Optional[Union[str, ImageSpec]] = None, environment: Optional[Dict[str, str]] = None, requests: Optional[Resources] = None, limits: Optional[Resources] = None, secret_requests: Optional[List[Secret]] = None, - execution_mode: Optional[PythonFunctionTask.ExecutionBehavior] = PythonFunctionTask.ExecutionBehavior.DEFAULT, + execution_mode: PythonFunctionTask.ExecutionBehavior = PythonFunctionTask.ExecutionBehavior.DEFAULT, task_resolver: Optional[TaskResolverMixin] = None, docs: Optional[Documentation] = None, disable_deck: bool = True, - pod_template: Optional[PodTemplate] = None, + pod_template: Optional["PodTemplate"] = None, pod_template_name: Optional[str] = None, -) -> Union[Callable, PythonFunctionTask]: +) -> Union[Callable[[Callable[..., Any]], PythonFunctionTask[T]], PythonFunctionTask[T]]: """ This is the core decorator to use for any task type in flytekit. @@ -189,7 +245,7 @@ def foo2(): :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task. """ - def wrapper(fn) -> PythonFunctionTask: + def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]: _metadata = TaskMetadata( cache=cache, cache_serialize=cache_serialize, @@ -225,7 +281,7 @@ def wrapper(fn) -> PythonFunctionTask: return wrapper -class ReferenceTask(ReferenceEntity, PythonFunctionTask): +class ReferenceTask(ReferenceEntity, PythonFunctionTask): # type: ignore """ This is a reference task, the body of the function passed in through the constructor will never be used, only the signature of the function will be. The signature should also match the signature of the task you're referencing, @@ -233,7 +289,7 @@ class ReferenceTask(ReferenceEntity, PythonFunctionTask): """ def __init__( - self, project: str, domain: str, name: str, version: str, inputs: Dict[str, Type], outputs: Dict[str, Type] + self, project: str, domain: str, name: str, version: str, inputs: Dict[str, type], outputs: Dict[str, Type] ): super().__init__(TaskReference(project, domain, name, version), inputs, outputs) diff --git a/flytekit/core/testing.py b/flytekit/core/testing.py index 772a4b6df6..f1a0fec7de 100644 --- a/flytekit/core/testing.py +++ b/flytekit/core/testing.py @@ -1,3 +1,4 @@ +import typing from contextlib import contextmanager from typing import Union from unittest.mock import MagicMock @@ -9,7 +10,7 @@ @contextmanager -def task_mock(t: PythonTask) -> MagicMock: +def task_mock(t: PythonTask) -> typing.Generator[MagicMock, None, None]: """ Use this method to mock a task declaration. It can mock any Task in Flytekit as long as it has a python native interface associated with it. 
@@ -41,9 +42,9 @@ def _log(*args, **kwargs): return m(*args, **kwargs) _captured_fn = t.execute - t.execute = _log + t.execute = _log # type: ignore yield m - t.execute = _captured_fn + t.execute = _captured_fn # type: ignore def patch(target: Union[PythonTask, WorkflowBase, ReferenceEntity]): diff --git a/flytekit/core/tracked_abc.py b/flytekit/core/tracked_abc.py index bad4f8c555..3c39d3725c 100644 --- a/flytekit/core/tracked_abc.py +++ b/flytekit/core/tracked_abc.py @@ -3,7 +3,7 @@ from flytekit.core.tracker import TrackedInstance -class FlyteTrackedABC(type(TrackedInstance), type(ABC)): +class FlyteTrackedABC(type(TrackedInstance), type(ABC)): # type: ignore """ This class exists because if you try to inherit from abc.ABC and TrackedInstance by itself, you'll get the well-known ``TypeError: metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py index 2a203d4861..1123c57d25 100644 --- a/flytekit/core/tracker.py +++ b/flytekit/core/tracker.py @@ -91,7 +91,6 @@ def find_lhs(self) -> str: # Since dataframes aren't registrable entities to begin with we swallow any errors they raise and # continue looping through m. logger.warning("Caught ValueError {} while attempting to auto-assign name".format(err)) - pass logger.error(f"Could not find LHS for {self} in {self._instantiated_in}") raise _system_exceptions.FlyteSystemException(f"Error looking for LHS in {self._instantiated_in}") @@ -179,7 +178,7 @@ class _ModuleSanitizer(object): def __init__(self): self._module_cache = {} - def _resolve_abs_module_name(self, path: str, package_root: str) -> str: + def _resolve_abs_module_name(self, path: str, package_root: typing.Optional[str] = None) -> str: """ Recursively finds the root python package under-which basename exists """ diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 7bfc85d1ef..5d4e0731b8 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -22,14 +22,15 @@ from google.protobuf.json_format import ParseDict as _ParseDict from google.protobuf.struct_pb2 import Struct from marshmallow_enum import EnumField, LoadDumpOptions -from marshmallow_jsonschema import JSONSchema from typing_extensions import Annotated, get_args, get_origin from flytekit.core.annotation import FlyteAnnotation from flytekit.core.context_manager import FlyteContext from flytekit.core.hash import HashMethod from flytekit.core.type_helpers import load_type_from_tag +from flytekit.core.utils import timeit from flytekit.exceptions import user as user_exceptions +from flytekit.lazy_import.lazy_module import is_imported from flytekit.loggers import logger from flytekit.models import interface as _interface_models from flytekit.models import types as _type_models @@ -88,7 +89,7 @@ def type_assertions_enabled(self) -> bool: def assert_type(self, t: Type[T], v: T): if not hasattr(t, "__origin__") and not isinstance(v, t): - raise TypeTransformerFailedError(f"Type of Val '{v}' is not an instance of {t}") + raise TypeTransformerFailedError(f"Expected value of type {t} but got '{v}' of type {type(v)}") @abstractmethod def get_literal_type(self, t: Type[T]) -> LiteralType: @@ -118,7 +119,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp raise NotImplementedError(f"Conversion to Literal for python type {python_type} not implemented") @abstractmethod - def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: + def 
to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> Optional[T]: """ Converts the given Literal to a Python Type. If the conversion cannot be done an AssertionError should be raised :param ctx: FlyteContext @@ -129,7 +130,6 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: f"Conversion to python value expected type {expected_python_type} from literal not implemented" ) - @abstractmethod def to_html(self, ctx: FlyteContext, python_val: T, expected_python_type: Type[T]) -> str: """ Converts any python val (dataframe, int, float) to a html string, and it will be wrapped in the HTML div @@ -162,12 +162,14 @@ def __init__( self._to_literal_transformer = to_literal_transformer self._from_literal_transformer = from_literal_transformer - def get_literal_type(self, t: Type[T] = None) -> LiteralType: + def get_literal_type(self, t: Optional[Type[T]] = None) -> LiteralType: return LiteralType.from_flyte_idl(self._lt.to_flyte_idl()) def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: if type(python_val) != self._type: - raise TypeTransformerFailedError(f"Expected value of type {self._type} but got type {type(python_val)}") + raise TypeTransformerFailedError( + f"Expected value of type {self._type} but got '{python_val}' of type {type(python_val)}" + ) return self._to_literal_transformer(python_val) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: @@ -186,7 +188,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: return res except AttributeError: # Assume that this is because a property on `lv` was None - raise TypeTransformerFailedError(f"Cannot convert literal {lv}") + raise TypeTransformerFailedError(f"Cannot convert literal {lv} to {self._type}") def guess_python_type(self, literal_type: LiteralType) -> Type[T]: if literal_type.simple is not None and literal_type.simple == self._lt.simple: @@ -207,7 +209,7 @@ class RestrictedTypeTransformer(TypeTransformer[T], ABC): def __init__(self, name: str, t: Type[T]): super().__init__(name, t) - def get_literal_type(self, t: Type[T] = None) -> LiteralType: + def get_literal_type(self, t: Optional[Type[T]] = None) -> LiteralType: raise RestrictedTypeError(f"Transformer for type {self.python_type} is restricted currently") def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: @@ -327,6 +329,8 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: # https://github.com/fuhrysteve/marshmallow-jsonschema/blob/81eada1a0c42ff67de216923968af0a6b54e5dcb/marshmallow_jsonschema/base.py#L228 if isinstance(v, EnumField): v.load_by = LoadDumpOptions.name + from marshmallow_jsonschema import JSONSchema + schema = JSONSchema().dump(s) except Exception as e: # https://github.com/lovasoa/marshmallow_dataclass/issues/13 @@ -374,7 +378,7 @@ def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]: def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.Any) -> T: # In python 3.7, 3.8, DataclassJson will deserialize Annotated[StructuredDataset, kwtypes(..)] to a dict, # so here we convert it back to the Structured Dataset. 
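# For context, the shape of dataclass the DataclassTransformer handles via dataclasses_json and
# the (now lazily imported) marshmallow JSON schema (illustrative sketch, not part of this patch):
from dataclasses import dataclass
from dataclasses_json import dataclass_json

@dataclass_json
@dataclass
class TrainingConfig:
    epochs: int = 10
    learning_rate: float = 1e-3
# Used as a task input or output, this is carried as a generic (JSON struct) literal.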
- from flytekit import StructuredDataset + from flytekit.types.structured import StructuredDataset if python_type == StructuredDataset and type(python_val) == dict: return StructuredDataset(**python_val) @@ -408,11 +412,13 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A return None return self._serialize_flyte_type(python_val, get_args(python_type)[0]) - if hasattr(python_type, "__origin__") and python_type.__origin__ is list: - return [self._serialize_flyte_type(v, python_type.__args__[0]) for v in python_val] + if hasattr(python_type, "__origin__") and get_origin(python_type) is list: + return [self._serialize_flyte_type(v, get_args(python_type)[0]) for v in cast(list, python_val)] - if hasattr(python_type, "__origin__") and python_type.__origin__ is dict: - return {k: self._serialize_flyte_type(v, python_type.__args__[1]) for k, v in python_val.items()} + if hasattr(python_type, "__origin__") and get_origin(python_type) is dict: + return { + k: self._serialize_flyte_type(v, get_args(python_type)[1]) for k, v in cast(dict, python_val).items() + } if not dataclasses.is_dataclass(python_type): return python_val @@ -472,7 +478,13 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> t = FlyteSchemaTransformer() return t.to_python_value( FlyteContext.current_context(), - Literal(scalar=Scalar(schema=Schema(python_val.remote_path, t._get_schema_type(expected_python_type)))), + Literal( + scalar=Scalar( + schema=Schema( + cast(FlyteSchema, python_val).remote_path, t._get_schema_type(expected_python_type) + ) + ) + ), expected_python_type, ) elif issubclass(expected_python_type, FlyteFile): @@ -486,7 +498,7 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE ) ), - uri=python_val.path, + uri=cast(FlyteFile, python_val).path, ) ) ), @@ -503,7 +515,7 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART ) ), - uri=python_val.path, + uri=cast(FlyteDirectory, python_val).path, ) ) ), @@ -516,9 +528,11 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> scalar=Scalar( structured_dataset=StructuredDataset( metadata=StructuredDatasetMetadata( - structured_dataset_type=StructuredDatasetType(format=python_val.file_format) + structured_dataset_type=StructuredDatasetType( + format=cast(StructuredDataset, python_val).file_format + ) ), - uri=python_val.uri, + uri=cast(StructuredDataset, python_val).uri, ) ) ), @@ -557,7 +571,9 @@ def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any: if isinstance(val, dict): ktype, vtype = DictTransformer.get_dict_types(t) # Handle nested Dict. e.g. {1: {2: 3}, 4: {5: 6}}) - return {self._fix_val_int(ktype, k): self._fix_val_int(vtype, v) for k, v in val.items()} + return { + self._fix_val_int(cast(type, ktype), k): self._fix_val_int(cast(type, vtype), v) for k, v in val.items() + } if dataclasses.is_dataclass(t): return self._fix_dataclass_int(t, val) # type: ignore @@ -598,7 +614,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: # calls to guess_python_type would result in a logically equivalent (but new) dataclass, which # TypeEngine.assert_type would not be happy about. 
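# The Flyte-native types special-cased by _serialize_flyte_type/_deserialize_flyte_type can also
# appear as dataclass fields, e.g. (illustrative sketch, not part of this patch):
from dataclasses import dataclass
from dataclasses_json import dataclass_json
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile

@dataclass_json
@dataclass
class Artifacts:
    model: FlyteFile
    checkpoints: FlyteDirectory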
@lru_cache(typed=True) - def guess_python_type(self, literal_type: LiteralType) -> Type[T]: + def guess_python_type(self, literal_type: LiteralType) -> Type[T]: # type: ignore if literal_type.simple == SimpleType.STRUCT: if literal_type.metadata is not None and DEFINITIONS in literal_type.metadata: schema_name = literal_type.metadata["$ref"].split("/")[-1] @@ -623,7 +639,7 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: struct = Struct() try: - struct.update(_MessageToDict(python_val)) + struct.update(_MessageToDict(cast(Message, python_val))) except Exception: raise TypeTransformerFailedError("Failed to convert to generic protobuf struct") return Literal(scalar=Scalar(generic=struct)) @@ -634,7 +650,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: pb_obj = expected_python_type() dictionary = _MessageToDict(lv.scalar.generic) - pb_obj = _ParseDict(dictionary, pb_obj) + pb_obj = _ParseDict(dictionary, pb_obj) # type: ignore return pb_obj def guess_python_type(self, literal_type: LiteralType) -> Type[T]: @@ -657,7 +673,8 @@ class TypeEngine(typing.Generic[T]): _REGISTRY: typing.Dict[type, TypeTransformer[T]] = {} _RESTRICTED_TYPES: typing.List[type] = [] - _DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer() + _DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer() # type: ignore + has_lazy_import = False @classmethod def register( @@ -682,10 +699,10 @@ def register( def register_restricted_type( cls, name: str, - type: Type, + type: Type[T], ): cls._RESTRICTED_TYPES.append(type) - cls.register(RestrictedTypeTransformer(name, type)) + cls.register(RestrictedTypeTransformer(name, type)) # type: ignore @classmethod def register_additional_type(cls, transformer: TypeTransformer, additional_type: Type, override=False): @@ -701,24 +718,32 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]: d = dictionary of registered transformers, where is a python `type` v = lookup type Step 1: - find a transformer that matches v exactly + If the type is annotated with a TypeTransformer instance, use that. Step 2: - find a transformer that matches the generic type of v. e.g List[int], Dict[str, int] etc + find a transformer that matches v exactly Step 3: + find a transformer that matches the generic type of v. e.g List[int], Dict[str, int] etc + + Step 4: Walk the inheritance hierarchy of v and find a transformer that matches the first base class. This is potentially non-deterministic - will depend on the registration pattern. TODO lets make this deterministic by using an ordered dict - Step 4: + Step 5: if v is of type data class, use the dataclass transformer """ - + cls.lazy_import_transformers() # Step 1 if get_origin(python_type) is Annotated: - python_type = get_args(python_type)[0] + args = get_args(python_type) + for annotation in args: + if isinstance(annotation, TypeTransformer): + return annotation + + python_type = args[0] if python_type in cls._REGISTRY: return cls._REGISTRY[python_type] @@ -757,6 +782,39 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]: raise ValueError(f"Type {python_type} not supported currently in Flytekit. Please register a new transformer") + @classmethod + def lazy_import_transformers(cls): + """ + Only load the transformers if needed. 
+ """ + if cls.has_lazy_import: + return + cls.has_lazy_import = True + from flytekit.types.structured import ( + register_arrow_handlers, + register_bigquery_handlers, + register_pandas_handlers, + ) + + if is_imported("tensorflow"): + from flytekit.extras import tensorflow # noqa: F401 + if is_imported("torch"): + from flytekit.extras import pytorch # noqa: F401 + if is_imported("sklearn"): + from flytekit.extras import sklearn # noqa: F401 + if is_imported("pandas"): + try: + from flytekit.types import schema # noqa: F401 + except ValueError: + logger.debug("Transformer for pandas is already registered.") + register_pandas_handlers() + if is_imported("pyarrow"): + register_arrow_handlers() + if is_imported("google.cloud.bigquery"): + register_bigquery_handlers() + if is_imported("numpy"): + from flytekit.types import numpy # noqa: F401 + @classmethod def to_literal_type(cls, python_type: Type) -> LiteralType: """ @@ -820,7 +878,7 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T return transformer.to_python_value(ctx, lv, expected_python_type) @classmethod - def to_html(cls, ctx: FlyteContext, python_val: typing.Any, expected_python_type: Type[T]) -> str: + def to_html(cls, ctx: FlyteContext, python_val: typing.Any, expected_python_type: Type[typing.Any]) -> str: transformer = cls.get_transformer(expected_python_type) if get_origin(expected_python_type) is Annotated: expected_python_type, *annotate_args = get_args(expected_python_type) @@ -843,6 +901,7 @@ def named_tuple_to_variable_map(cls, t: typing.NamedTuple) -> _interface_models. return _interface_models.VariableMap(variables=variables) @classmethod + @timeit("Translate literal to python value") def literal_map_to_kwargs( cls, ctx: FlyteContext, lm: LiteralMap, python_types: typing.Dict[str, type] ) -> typing.Dict[str, typing.Any]: @@ -853,7 +912,13 @@ def literal_map_to_kwargs( raise ValueError( f"Received more input values {len(lm.literals)}" f" than allowed by the input spec {len(python_types)}" ) - return {k: TypeEngine.to_python_value(ctx, lm.literals[k], python_types[k]) for k, v in lm.literals.items()} + kwargs = {} + for i, k in enumerate(lm.literals): + try: + kwargs[k] = TypeEngine.to_python_value(ctx, lm.literals[k], python_types[k]) + except TypeTransformerFailedError as exc: + raise TypeTransformerFailedError(f"Error converting input '{k}' at position {i}:\n {exc}") from exc + return kwargs @classmethod def dict_to_literal_map( @@ -942,8 +1007,8 @@ def get_sub_type(t: Type[T]) -> Type[T]: if get_origin(t) is Annotated: return ListTransformer.get_sub_type(get_args(t)[0]) - if t.__origin__ is list and hasattr(t, "__args__"): - return t.__args__[0] + if getattr(t, "__origin__") is list and hasattr(t, "__args__"): + return getattr(t, "__args__")[0] raise ValueError("Only generic univariate typing.List[T] type is supported.") @@ -957,27 +1022,67 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]: except Exception as e: raise ValueError(f"Type of Generic List type is not supported, {e}") + @staticmethod + def is_batchable(t: Type): + """ + This function evaluates whether the provided type is batchable or not. + It returns True only if the type is either List or Annotated(List) and the List subtype is FlytePickle. 
+ """ + from flytekit.types.pickle import FlytePickle + + if get_origin(t) is Annotated: + return ListTransformer.is_batchable(get_args(t)[0]) + if get_origin(t) is list: + subtype = get_args(t)[0] + if subtype == FlytePickle or (hasattr(subtype, "__origin__") and subtype.__origin__ == FlytePickle): + return True + return False + def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: if type(python_val) != list: raise TypeTransformerFailedError("Expected a list") - t = self.get_sub_type(python_type) - lit_list = [TypeEngine.to_literal(ctx, x, t, expected.collection_type) for x in python_val] # type: ignore + if ListTransformer.is_batchable(python_type): + from flytekit.types.pickle.pickle import BatchSize, FlytePickle + + batch_size = len(python_val) # default batch size + # parse annotated to get the number of items saved in a pickle file. + if get_origin(python_type) is Annotated: + for annotation in get_args(python_type)[1:]: + if isinstance(annotation, BatchSize): + batch_size = annotation.val + break + if batch_size > 0: + lit_list = [TypeEngine.to_literal(ctx, python_val[i : i + batch_size], FlytePickle, expected.collection_type) for i in range(0, len(python_val), batch_size)] # type: ignore + else: + lit_list = [] + else: + t = self.get_sub_type(python_type) + lit_list = [TypeEngine.to_literal(ctx, x, t, expected.collection_type) for x in python_val] # type: ignore return Literal(collection=LiteralCollection(literals=lit_list)) - def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[T]: + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[typing.Any]: # type: ignore try: lits = lv.collection.literals except AttributeError: raise TypeTransformerFailedError() + if self.is_batchable(expected_python_type): + from flytekit.types.pickle import FlytePickle + + batch_list = [TypeEngine.to_python_value(ctx, batch, FlytePickle) for batch in lits] + if len(batch_list) > 0 and type(batch_list[0]) is list: + # Make it have backward compatibility. The upstream task may use old version of Flytekit that + # won't merge the elements in the list. Therefore, we should check if the batch_list[0] is the list first. 
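# How the batchable path is driven from user code (sketch; the names and batch size are
# illustrative): annotating a list of FlytePickle with a BatchSize groups that many items into
# each pickled literal, which is what to_literal/to_python_value above split and re-merge.
from typing import List
from typing_extensions import Annotated
from flytekit import task
from flytekit.types.pickle.pickle import BatchSize, FlytePickle

@task
def emit(n: int) -> Annotated[List[FlytePickle], BatchSize(10)]:
    return [{"index": i} for i in range(n)]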
+ return [item for batch in batch_list for item in batch] + return batch_list + else: + st = self.get_sub_type(expected_python_type) + return [TypeEngine.to_python_value(ctx, x, st) for x in lits] - st = self.get_sub_type(expected_python_type) - return [TypeEngine.to_python_value(ctx, x, st) for x in lits] - - def guess_python_type(self, literal_type: LiteralType) -> Type[list]: + def guess_python_type(self, literal_type: LiteralType) -> list: # type: ignore if literal_type.collection_type: - ct = TypeEngine.guess_python_type(literal_type.collection_type) - return typing.List[ct] + ct: Type = TypeEngine.guess_python_type(literal_type.collection_type) + return typing.List[ct] # type: ignore raise ValueError(f"List transformer cannot reverse {literal_type}") @@ -1033,7 +1138,7 @@ def _are_types_castable(upstream: LiteralType, downstream: LiteralType) -> bool: if len(ucols) != len(dcols): return False - for (u, d) in zip(ucols, dcols): + for u, d in zip(ucols, dcols): if u.name != d.name: return False @@ -1090,7 +1195,9 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]: t = get_args(t)[0] try: - trans = [(TypeEngine.get_transformer(x), x) for x in get_args(t)] + trans: typing.List[typing.Tuple[TypeTransformer, typing.Any]] = [ + (TypeEngine.get_transformer(x), x) for x in get_args(t) + ] # must go through TypeEngine.to_literal_type instead of trans.get_literal_type # to handle Annotated variants = [_add_tag_to_type(TypeEngine.to_literal_type(x), t.name) for (t, x) in trans] @@ -1107,7 +1214,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp res_type = None for t in get_args(python_type): try: - trans = TypeEngine.get_transformer(t) + trans: TypeTransformer[T] = TypeEngine.get_transformer(t) res = trans.to_literal(ctx, python_val, t, expected) res_type = _add_tag_to_type(trans.get_literal_type(t), trans.name) @@ -1140,7 +1247,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: res_tag = None for v in get_args(expected_python_type): try: - trans = TypeEngine.get_transformer(v) + trans: TypeTransformer[T] = TypeEngine.get_transformer(v) if union_tag is not None: if trans.name != union_tag: continue @@ -1179,7 +1286,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: def guess_python_type(self, literal_type: LiteralType) -> type: if literal_type.union_type is not None: - return typing.Union[tuple(TypeEngine.guess_python_type(v) for v in literal_type.union_type.variants)] + return typing.Union[tuple(TypeEngine.guess_python_type(v) for v in literal_type.union_type.variants)] # type: ignore raise ValueError(f"Union transformer cannot reverse {literal_type}") @@ -1226,7 +1333,7 @@ def get_literal_type(self, t: Type[dict]) -> LiteralType: if tp: if tp[0] == str: try: - sub_type = TypeEngine.to_literal_type(tp[1]) + sub_type = TypeEngine.to_literal_type(cast(type, tp[1])) return _type_models.LiteralType(map_value_type=sub_type) except Exception as e: raise ValueError(f"Type of Generic List type is not supported, {e}") @@ -1247,7 +1354,7 @@ def to_literal( raise ValueError("Flyte MapType expects all keys to be strings") # TODO: log a warning for Annotated objects that contain HashMethod k_type, v_type = self.get_dict_types(python_type) - lit_map[k] = TypeEngine.to_literal(ctx, v, v_type, expected.map_value_type) + lit_map[k] = TypeEngine.to_literal(ctx, v, cast(type, v_type), expected.map_value_type) return Literal(map=LiteralMap(literals=lit_map)) def to_python_value(self, ctx: 
FlyteContext, lv: Literal, expected_python_type: Type[dict]) -> dict: @@ -1263,7 +1370,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: raise TypeError("TypeMismatch. Destination dictionary does not accept 'str' key") py_map = {} for k, v in lv.map.literals.items(): - py_map[k] = TypeEngine.to_python_value(ctx, v, tp[1]) + py_map[k] = TypeEngine.to_python_value(ctx, v, cast(Type, tp[1])) return py_map # for empty generic we have to explicitly test for lv.scalar.generic is not None as empty dict @@ -1301,10 +1408,8 @@ def _blob_type(self) -> _core_types.BlobType: dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, ) - def get_literal_type(self, t: typing.TextIO) -> LiteralType: - return _type_models.LiteralType( - blob=self._blob_type(), - ) + def get_literal_type(self, t: typing.TextIO) -> LiteralType: # type: ignore + return _type_models.LiteralType(blob=self._blob_type()) def to_literal( self, ctx: FlyteContext, python_val: typing.TextIO, python_type: Type[typing.TextIO], expected: LiteralType @@ -1375,7 +1480,9 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: raise TypeTransformerFailedError("Only EnumTypes with value of string are supported") return LiteralType(enum_type=_core_types.EnumType(values=values)) - def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: + def to_literal( + self, ctx: FlyteContext, python_val: enum.Enum, python_type: Type[T], expected: LiteralType + ) -> Literal: if type(python_val).__class__ != enum.EnumMeta: raise TypeTransformerFailedError("Expected an enum") if type(python_val.value) != str: @@ -1384,11 +1491,12 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp return Literal(scalar=Scalar(primitive=Primitive(string_value=python_val.value))) # type: ignore def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: - return expected_python_type(lv.scalar.primitive.string_value) + return expected_python_type(lv.scalar.primitive.string_value) # type: ignore -def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[dataclasses.dataclass()]: - """Generate a model class based on the provided JSON Schema +def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[dataclasses.dataclass()]: # type: ignore + """ + Generate a model class based on the provided JSON Schema :param schema: dict representing valid JSON schema :param schema_name: dataclass name of return type """ @@ -1397,7 +1505,7 @@ def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[datac property_type = property_val["type"] # Handle list if property_val["type"] == "array": - attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"])])) + attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"])])) # type: ignore # Handle dataclass and dict elif property_type == "object": if property_val.get("$ref"): @@ -1405,13 +1513,13 @@ def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[datac attribute_list.append((property_key, convert_json_schema_to_python_class(schema, name))) elif property_val.get("additionalProperties"): attribute_list.append( - (property_key, typing.Dict[str, _get_element_type(property_val["additionalProperties"])]) + (property_key, typing.Dict[str, _get_element_type(property_val["additionalProperties"])]) # type: ignore ) else: - attribute_list.append((property_key, 
typing.Dict[str, _get_element_type(property_val)])) + attribute_list.append((property_key, typing.Dict[str, _get_element_type(property_val)])) # type: ignore # Handle int, float, bool or str else: - attribute_list.append([property_key, _get_element_type(property_val)]) + attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list)) @@ -1585,8 +1693,8 @@ def __init__( raise ValueError("Cannot instantiate LiteralsResolver without a map of Literals.") self._literals = literals self._variable_map = variable_map - self._native_values = {} - self._type_hints = {} + self._native_values: Dict[str, type] = {} + self._type_hints: Dict[str, type] = {} self._ctx = ctx def __str__(self) -> str: @@ -1639,7 +1747,7 @@ def __getitem__(self, key: str): return self.get(key) - def get(self, attr: str, as_type: Optional[typing.Type] = None) -> typing.Any: + def get(self, attr: str, as_type: Optional[typing.Type] = None) -> typing.Any: # type: ignore """ This will get the ``attr`` value from the Literal map, and invoke the TypeEngine to convert it into a Python native value. A Python type can optionally be supplied. If successful, the native value will be cached and @@ -1666,7 +1774,9 @@ def get(self, attr: str, as_type: Optional[typing.Type] = None) -> typing.Any: raise e else: ValueError("as_type argument not supplied and Variable map not specified in LiteralsResolver") - val = TypeEngine.to_python_value(self._ctx or FlyteContext.current_context(), self._literals[attr], as_type) + val = TypeEngine.to_python_value( + self._ctx or FlyteContext.current_context(), self._literals[attr], cast(Type, as_type) + ) self._native_values[attr] = val return val diff --git a/flytekit/core/utils.py b/flytekit/core/utils.py index d23aae3fbb..95ec4d33d9 100644 --- a/flytekit/core/utils.py +++ b/flytekit/core/utils.py @@ -1,13 +1,20 @@ +import datetime import os as _os import shutil as _shutil import tempfile as _tempfile import time as _time +from functools import wraps from hashlib import sha224 as _sha224 from pathlib import Path -from typing import Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, cast +from flyteidl.core import tasks_pb2 as _core_task + +from flytekit.core.pod_template import PodTemplate from flytekit.loggers import logger -from flytekit.models import task as _task_models + +if TYPE_CHECKING: + from flytekit.models import task as task_models def _dnsify(value: str) -> str: @@ -51,8 +58,8 @@ def _dnsify(value: str) -> str: def _get_container_definition( image: str, command: List[str], - args: List[str], - data_loading_config: Optional[_task_models.DataLoadingConfig] = None, + args: Optional[List[str]] = None, + data_loading_config: Optional["task_models.DataLoadingConfig"] = None, storage_request: Optional[str] = None, ephemeral_storage_request: Optional[str] = None, cpu_request: Optional[str] = None, @@ -64,7 +71,7 @@ def _get_container_definition( gpu_limit: Optional[str] = None, memory_limit: Optional[str] = None, environment: Optional[Dict[str, str]] = None, -) -> _task_models.Container: +) -> "task_models.Container": storage_limit = storage_limit storage_request = storage_request ephemeral_storage_limit = ephemeral_storage_limit @@ -76,6 +83,9 @@ def _get_container_definition( memory_limit = memory_limit memory_request = memory_request + from flytekit.models import task as task_models + + # TODO: Use convert_resources_to_resource_model instead of manually 
fixing the resources. requests = [] if storage_request: requests.append( @@ -126,6 +136,56 @@ def _get_container_definition( ) +def _sanitize_resource_name(resource: "task_models.Resources.ResourceEntry") -> str: + return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-") + + +def _serialize_pod_spec(pod_template: "PodTemplate", primary_container: "task_models.Container") -> Dict[str, Any]: + from kubernetes.client import ApiClient, V1PodSpec + from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements + + if pod_template.pod_spec is None: + return {} + containers = cast(V1PodSpec, pod_template.pod_spec).containers + primary_exists = False + + for container in containers: + if container.name == cast(PodTemplate, pod_template).primary_container_name: + primary_exists = True + break + + if not primary_exists: + # insert a placeholder primary container if it is not defined in the pod spec. + containers.append(V1Container(name=cast(PodTemplate, pod_template).primary_container_name)) + final_containers = [] + for container in containers: + # In the case of the primary container, we overwrite specific container attributes + # with the values given to ContainerTask. + # The attributes include: image, command, args, resource, and env (env is unioned) + if container.name == cast(PodTemplate, pod_template).primary_container_name: + container.image = primary_container.image + container.command = primary_container.command + container.args = primary_container.args + + limits, requests = {}, {} + for resource in primary_container.resources.limits: + limits[_sanitize_resource_name(resource)] = resource.value + for resource in primary_container.resources.requests: + requests[_sanitize_resource_name(resource)] = resource.value + resource_requirements = V1ResourceRequirements(limits=limits, requests=requests) + if len(limits) > 0 or len(requests) > 0: + # Important! Only copy over resource requirements if they are non-empty. + container.resources = resource_requirements + if primary_container.env is not None: + container.env = [V1EnvVar(name=key, value=val) for key, val in primary_container.env.items()] + ( + container.env or [] + ) + final_containers.append(container) + cast(V1PodSpec, pod_template.pod_spec).containers = final_containers + + return ApiClient().sanitize_for_serialization(cast(PodTemplate, pod_template).pod_spec) + + def load_proto_from_file(pb2_type, path): with open(path, "rb") as reader: out = pb2_type() @@ -209,26 +269,66 @@ def __str__(self): return self.__repr__() -class PerformanceTimer(object): - def __init__(self, context_statement): +class timeit: + """ + A context manager and a decorator that measures the execution time of the wrapped code block or functions. + It will append a timing information to TimeLineDeck. For instance: + + @timeit("Function description") + def function() + + with timeit("Wrapped code block description"): + # your code + """ + + def __init__(self, name: str = ""): """ - :param Text context_statement: the statement to log + :param name: A string that describes the wrapped code block or function being executed. 
""" - self._context_statement = context_statement + self._name = name + self.start_time = None self._start_wall_time = None self._start_process_time = None + def __call__(self, func: Callable): + @wraps(func) + def wrapper(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return wrapper + def __enter__(self): - logger.info("Entering timed context: {}".format(self._context_statement)) + self.start_time = datetime.datetime.utcnow() self._start_wall_time = _time.perf_counter() self._start_process_time = _time.process_time() + return self def __exit__(self, exc_type, exc_val, exc_tb): + """ + The exception, if any, will propagate outside the context manager, as the purpose of this context manager + is solely to measure the execution time of the wrapped code block. + """ + from flytekit.core.context_manager import FlyteContextManager + + end_time = datetime.datetime.utcnow() end_wall_time = _time.perf_counter() end_process_time = _time.process_time() + + timeline_deck = FlyteContextManager.current_context().user_space_params.timeline_deck + timeline_deck.append_time_info( + dict( + Name=self._name, + Start=self.start_time, + Finish=end_time, + WallTime=end_wall_time - self._start_wall_time, + ProcessTime=end_process_time - self._start_process_time, + ) + ) + logger.info( - "Exiting timed context: {} [Wall Time: {}s, Process Time: {}s]".format( - self._context_statement, + "{}. [Wall Time: {}s, Process Time: {}s]".format( + self._name, end_wall_time - self._start_wall_time, end_process_time - self._start_process_time, ) diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 8ba307b767..ab24a642e4 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -1,9 +1,12 @@ from __future__ import annotations +import typing from dataclasses import dataclass from enum import Enum from functools import update_wrapper -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast, overload + +from typing_extensions import get_args from flytekit.core import constants as _common_constants from flytekit.core.base_task import PythonTask @@ -32,14 +35,16 @@ from flytekit.core.python_auto_container import PythonAutoContainerTask from flytekit.core.reference_entity import ReferenceEntity, WorkflowReference from flytekit.core.tracker import extract_task_module -from flytekit.core.type_engine import TypeEngine +from flytekit.core.type_engine import TypeEngine, TypeTransformerFailedError, UnionTransformer from flytekit.exceptions import scopes as exception_scopes from flytekit.exceptions.user import FlyteValidationException, FlyteValueException from flytekit.loggers import logger from flytekit.models import interface as _interface_models from flytekit.models import literals as _literal_models +from flytekit.models import types as type_models from flytekit.models.core import workflow as _workflow_model from flytekit.models.documentation import Description, Documentation +from flytekit.models.types import TypeStructure GLOBAL_START_NODE = Node( id=_common_constants.GLOBAL_INPUT_NODE_ID, @@ -49,6 +54,8 @@ flyte_entity=None, ) +T = typing.TypeVar("T") + class WorkflowFailurePolicy(Enum): """ @@ -177,9 +184,9 @@ def __init__( self._workflow_metadata_defaults = workflow_metadata_defaults self._python_interface = python_interface self._interface = transform_interface_to_typed_interface(python_interface) - self._inputs = {} - self._unbound_inputs = set() - self._nodes = [] + 
self._inputs: Dict[str, Promise] = {} + self._unbound_inputs: set = set() + self._nodes: List[Node] = [] self._output_bindings: List[_literal_models.Binding] = [] self._docs = docs @@ -191,7 +198,9 @@ def __init__( ) else: if self._python_interface.docstring.short_description: - self._docs.short_description = self._python_interface.docstring.short_description + cast( + Documentation, self._docs + ).short_description = self._python_interface.docstring.short_description if self._python_interface.docstring.long_description: self._docs = Description(value=self._python_interface.docstring.long_description) @@ -211,11 +220,11 @@ def short_name(self) -> str: return extract_obj_name(self._name) @property - def workflow_metadata(self) -> Optional[WorkflowMetadata]: + def workflow_metadata(self) -> WorkflowMetadata: return self._workflow_metadata @property - def workflow_metadata_defaults(self): + def workflow_metadata_defaults(self) -> WorkflowMetadataDefaults: return self._workflow_metadata_defaults @property @@ -250,7 +259,7 @@ def construct_node_metadata(self) -> _workflow_model.NodeMetadata: interruptible=self.workflow_metadata_defaults.interruptible, ) - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]: """ Workflow needs to fill in default arguments before invoking the call handler. """ @@ -258,7 +267,11 @@ def __call__(self, *args, **kwargs): input_kwargs = self.python_interface.default_inputs_as_kwargs input_kwargs.update(kwargs) self.compile() - return flyte_entity_call_handler(self, *args, **input_kwargs) + try: + return flyte_entity_call_handler(self, *args, **input_kwargs) + except Exception as exc: + exc.args = (f"Encountered error while executing workflow '{self.name}':\n {exc}", *exc.args[1:]) + raise exc def execute(self, **kwargs): raise Exception("Should not be called") @@ -266,19 +279,63 @@ def execute(self, **kwargs): def compile(self, **kwargs): pass + def ensure_literal( + self, ctx, py_type: Type[T], input_type: type_models.LiteralType, python_value: Any + ) -> _literal_models.Literal: + """ + This function will attempt to convert a python value to a literal. If the python value is a promise, it will + return the promise's value. 
+ """ + if input_type.union_type is not None: + if python_value is None and UnionTransformer.is_optional_type(py_type): + return _literal_models.Literal(scalar=_literal_models.Scalar(none_type=_literal_models.Void())) + for i in range(len(input_type.union_type.variants)): + lt_type = input_type.union_type.variants[i] + python_type = get_args(py_type)[i] + try: + final_lt = self.ensure_literal(ctx, python_type, lt_type, python_value) + lt_type._structure = TypeStructure(tag=TypeEngine.get_transformer(python_type).name) + return _literal_models.Literal( + scalar=_literal_models.Scalar(union=_literal_models.Union(value=final_lt, stored_type=lt_type)) + ) + except Exception as e: + logger.debug(f"Failed to convert {python_value} to {lt_type} with error {e}") + raise TypeError(f"Failed to convert {python_value} to {input_type}") + if isinstance(python_value, list) and input_type.collection_type: + collection_lit_type = input_type.collection_type + collection_py_type = get_args(py_type)[0] + xx = [self.ensure_literal(ctx, collection_py_type, collection_lit_type, pv) for pv in python_value] + return _literal_models.Literal(collection=_literal_models.LiteralCollection(literals=xx)) + elif isinstance(python_value, dict) and input_type.map_value_type: + mapped_lit_type = input_type.map_value_type + mapped_py_type = get_args(py_type)[1] + xx = {k: self.ensure_literal(ctx, mapped_py_type, mapped_lit_type, v) for k, v in python_value.items()} # type: ignore + return _literal_models.Literal(map=_literal_models.LiteralMap(literals=xx)) + # It is a scalar, convert to Promise if necessary. + else: + if isinstance(python_value, Promise): + return python_value.val + if not isinstance(python_value, Promise): + try: + res = TypeEngine.to_literal(ctx, python_value, py_type, input_type) + return res + except TypeTransformerFailedError as exc: + raise TypeError( + f"Failed to convert input '{python_value}' of workflow '{self.name}':\n {exc}" + ) from exc + def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]: # This is done to support the invariant that Workflow local executions always work with Promise objects # holding Flyte literal values. Even in a wf, a user can call a sub-workflow with a Python native value. for k, v in kwargs.items(): - if not isinstance(v, Promise): - t = self.python_interface.inputs[k] - kwargs[k] = Promise(var=k, val=TypeEngine.to_literal(ctx, v, t, self.interface.inputs[k].type)) + py_type = self.python_interface.inputs[k] + lit_type = self.interface.inputs[k].type + kwargs[k] = Promise(var=k, val=self.ensure_literal(ctx, py_type, lit_type, v)) - # The output of this will always be a combination of Python native values and Promises containing Flyte - # Literals. + # The output of this will always be a combination of Python native values and Promises containing Flyte + # Literals. self.compile() function_outputs = self.execute(**kwargs) - # First handle the empty return case. # A workflow function may return a task that doesn't return anything # def wf(): @@ -419,7 +476,7 @@ def execute(self, **kwargs): raise FlyteValidationException(f"Workflow not ready, wf is currently {self}") # Create a map that holds the outputs of each node. - intermediate_node_outputs = {GLOBAL_START_NODE: {}} # type: Dict[Node, Dict[str, Promise]] + intermediate_node_outputs: Dict[Node, Dict[str, Promise]] = {GLOBAL_START_NODE: {}} # Start things off with the outputs of the global input node, i.e. the inputs to the workflow. 
# local_execute should've already ensured that all the values in kwargs are Promise objects @@ -516,7 +573,7 @@ def get_input_values(input_value): self._unbound_inputs.remove(input_value) return n # type: ignore - def add_workflow_input(self, input_name: str, python_type: Type) -> Interface: + def add_workflow_input(self, input_name: str, python_type: Type) -> Promise: """ Adds an input to the workflow. """ @@ -543,7 +600,8 @@ def add_workflow_output( f"If specifying a list or dict of Promises, you must specify the python_type type for {output_name}" f" starting with the container type (e.g. List[int]" ) - python_type = p.ref.node.flyte_entity.python_interface.outputs[p.var] + promise = cast(Promise, p) + python_type = promise.ref.node.flyte_entity.python_interface.outputs[promise.var] logger.debug(f"Inferring python type for wf output {output_name} from Promise provided {python_type}") flyte_type = TypeEngine.to_literal_type(python_type=python_type) @@ -595,9 +653,9 @@ class PythonFunctionWorkflow(WorkflowBase, ClassStorageTaskResolver): def __init__( self, - workflow_function: Callable, - metadata: Optional[WorkflowMetadata], - default_metadata: Optional[WorkflowMetadataDefaults], + workflow_function: Callable[..., Any], + metadata: WorkflowMetadata, + default_metadata: WorkflowMetadataDefaults, docstring: Optional[Docstring] = None, docs: Optional[Documentation] = None, ): @@ -622,7 +680,7 @@ def __init__( def function(self): return self._workflow_function - def task_name(self, t: PythonAutoContainerTask) -> str: + def task_name(self, t: PythonAutoContainerTask) -> str: # type: ignore return f"{self.name}.{t.__module__}.{t.name}" def compile(self, **kwargs): @@ -719,12 +777,32 @@ def execute(self, **kwargs): return exception_scopes.user_entry_point(self._workflow_function)(**kwargs) +@overload +def workflow( + _workflow_function: None = ..., + failure_policy: Optional[WorkflowFailurePolicy] = ..., + interruptible: bool = ..., + docs: Optional[Documentation] = ..., +) -> Callable[[Callable[..., Any]], PythonFunctionWorkflow]: + ... + + +@overload +def workflow( + _workflow_function: Callable[..., Any], + failure_policy: Optional[WorkflowFailurePolicy] = ..., + interruptible: bool = ..., + docs: Optional[Documentation] = ..., +) -> PythonFunctionWorkflow: + ... + + def workflow( - _workflow_function=None, + _workflow_function: Optional[Callable[..., Any]] = None, failure_policy: Optional[WorkflowFailurePolicy] = None, interruptible: bool = False, docs: Optional[Documentation] = None, -) -> WorkflowBase: +) -> Union[Callable[[Callable[..., Any]], PythonFunctionWorkflow], PythonFunctionWorkflow]: """ This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG of tasks using the data flow between tasks. 
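The @overload signatures describe the two ways the decorator is applied; a brief sketch of both forms (the failure policy chosen here is only an example).

from flytekit import task, workflow
from flytekit.core.workflow import WorkflowFailurePolicy


@task
def add_one(x: int) -> int:
    return x + 1


@workflow
def wf_plain(x: int) -> int:
    # bare form: the function is passed directly and a PythonFunctionWorkflow comes back
    return add_one(x=x)


@workflow(failure_policy=WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE, interruptible=True)
def wf_configured(x: int) -> int:
    # parameterized form: workflow(...) returns a wrapper that is then applied to the function
    return add_one(x=x)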
@@ -755,7 +833,7 @@ def workflow( :param docs: Description entity for the workflow """ - def wrapper(fn): + def wrapper(fn: Callable[..., Any]) -> PythonFunctionWorkflow: workflow_metadata = WorkflowMetadata(on_failure=failure_policy or WorkflowFailurePolicy.FAIL_IMMEDIATELY) workflow_metadata_defaults = WorkflowMetadataDefaults(interruptible) @@ -770,13 +848,13 @@ def wrapper(fn): update_wrapper(workflow_instance, fn) return workflow_instance - if _workflow_function: + if _workflow_function is not None: return wrapper(_workflow_function) else: - return wrapper + return wrapper # type: ignore -class ReferenceWorkflow(ReferenceEntity, PythonFunctionWorkflow): +class ReferenceWorkflow(ReferenceEntity, PythonFunctionWorkflow): # type: ignore """ A reference workflow is a pointer to a workflow that already exists on your Flyte installation. This object will not initiate a network call to Admin, which is why the user is asked to provide the expected interface. diff --git a/flytekit/deck/deck.py b/flytekit/deck/deck.py index cec59e7318..0d53ec18d6 100644 --- a/flytekit/deck/deck.py +++ b/flytekit/deck/deck.py @@ -2,8 +2,6 @@ import typing from typing import Optional -from jinja2 import Environment, FileSystemLoader, select_autoescape - from flytekit.core.context_manager import ExecutionParameters, ExecutionState, FlyteContext, FlyteContextManager from flytekit.loggers import logger @@ -74,6 +72,58 @@ def html(self) -> str: return self._html +class TimeLineDeck(Deck): + """ + The TimeLineDeck class is designed to render the execution time of each part of a task. + Unlike deck class, the conversion of data to HTML is delayed until the html property is accessed. + This approach is taken because rendering a timeline graph with partial data would not provide meaningful insights. + Instead, the complete data set is used to create a comprehensive visualization of the execution time of each part of the task. + """ + + def __init__(self, name: str, html: Optional[str] = ""): + super().__init__(name, html) + self.time_info = [] + + def append_time_info(self, info: dict): + assert isinstance(info, dict) + self.time_info.append(info) + + @property + def html(self) -> str: + try: + from flytekitplugins.deck.renderer import GanttChartRenderer, TableRenderer + except ImportError: + warning_info = "Plugin 'flytekit-deck-standard' is not installed. To display time line, install the plugin in the image." + logger.warning(warning_info) + return warning_info + + if len(self.time_info) == 0: + return "" + + import pandas + + df = pandas.DataFrame(self.time_info) + note = """ +

+            <p><strong>Note:</strong></p>
+            <ol>
+                <li>if the time duration is too small(< 1ms), it may be difficult to see on the time line graph.</li>
+                <li>For accurate execution time measurements, users should refer to wall time and process time.</li>
+            </ol>
+ """ + # set the accuracy to microsecond + df["ProcessTime"] = df["ProcessTime"].apply(lambda time: "{:.6f}".format(time)) + df["WallTime"] = df["WallTime"].apply(lambda time: "{:.6f}".format(time)) + + width = 1400 + gantt_chart_html = GanttChartRenderer().to_html(df, chart_width=width) + time_table_html = TableRenderer().to_html( + df[["Name", "WallTime", "ProcessTime"]], + header_labels=["Name", "Wall Time(s)", "Process Time(s)"], + table_width=width, + ) + return gantt_chart_html + time_table_html + note + + def _ipython_check() -> bool: """ Check if interface is launching from iPython (not colab) @@ -98,10 +148,12 @@ def _get_deck( If ignore_jupyter is set to True, then it will return a str even in a jupyter environment. """ deck_map = {deck.name: deck.html for deck in new_user_params.decks} - raw_html = template.render(metadata=deck_map) + raw_html = get_deck_template().render(metadata=deck_map) if not ignore_jupyter and _ipython_check(): - from IPython.core.display import HTML - + try: + from IPython.core.display import HTML + except ImportError: + ... return HTML(raw_html) return raw_html @@ -118,15 +170,18 @@ def _output_deck(task_name: str, new_user_params: ExecutionParameters): logger.info(f"{task_name} task creates flyte deck html to file://{deck_path}") -root = os.path.dirname(os.path.abspath(__file__)) -templates_dir = os.path.join(root, "html") -env = Environment( - loader=FileSystemLoader(templates_dir), - # 🔥 include autoescaping for security purposes - # sources: - # - https://jinja.palletsprojects.com/en/3.0.x/api/#autoescaping - # - https://stackoverflow.com/a/38642558/8474894 (see in comments) - # - https://stackoverflow.com/a/68826578/8474894 - autoescape=select_autoescape(enabled_extensions=("html",)), -) -template = env.get_template("template.html") +def get_deck_template() -> "Template": + from jinja2 import Environment, FileSystemLoader, select_autoescape + + root = os.path.dirname(os.path.abspath(__file__)) + templates_dir = os.path.join(root, "html") + env = Environment( + loader=FileSystemLoader(templates_dir), + # 🔥 include autoescaping for security purposes + # sources: + # - https://jinja.palletsprojects.com/en/3.0.x/api/#autoescaping + # - https://stackoverflow.com/a/38642558/8474894 (see in comments) + # - https://stackoverflow.com/a/68826578/8474894 + autoescape=select_autoescape(enabled_extensions=("html",)), + ) + return env.get_template("template.html") diff --git a/flytekit/deck/html/template.html b/flytekit/deck/html/template.html index 6bec37effe..19e0256880 100644 --- a/flytekit/deck/html/template.html +++ b/flytekit/deck/html/template.html @@ -53,17 +53,19 @@ } #flyte-frame-container { - width: 100%; + width: auto; } #flyte-frame-container > div { - display: none; + display: None; } #flyte-frame-container > div.active { - display: block; + display: Block; padding: 2rem 4rem; + width: 100%; } + diff --git a/flytekit/deck/renderer.py b/flytekit/deck/renderer.py index dddb88e420..cfea92ec4e 100644 --- a/flytekit/deck/renderer.py +++ b/flytekit/deck/renderer.py @@ -1,14 +1,22 @@ -from typing import Any +from typing import TYPE_CHECKING, Any -import pandas -import pyarrow from typing_extensions import Protocol, runtime_checkable +from flytekit import lazy_module + +if TYPE_CHECKING: + # Always import these modules in type-checking mode or when running pytest + import pandas + import pyarrow +else: + pandas = lazy_module("pandas") + pyarrow = lazy_module("pyarrow") + @runtime_checkable class Renderable(Protocol): def to_html(self, python_value: Any) -> 
str: - """Convert a object(markdown, pandas.dataframe) to HTML and return HTML as a unicode string. + """Convert an object(markdown, pandas.dataframe) to HTML and return HTML as a unicode string. Returns: An HTML document as a string. """ raise NotImplementedError @@ -27,16 +35,16 @@ def __init__(self, max_rows: int = DEFAULT_MAX_ROWS, max_cols: int = DEFAULT_MAX self._max_rows = max_rows self._max_cols = max_cols - def to_html(self, df: pandas.DataFrame) -> str: + def to_html(self, df: "pandas.DataFrame") -> str: assert isinstance(df, pandas.DataFrame) return df.to_html(max_rows=self._max_rows, max_cols=self._max_cols) class ArrowRenderer: """ - Render a Arrow dataframe as an HTML table. + Render an Arrow dataframe as an HTML table. """ - def to_html(self, df: pyarrow.Table) -> str: + def to_html(self, df: "pyarrow.Table") -> str: assert isinstance(df, pyarrow.Table) return df.to_string() diff --git a/flytekit/exceptions/scopes.py b/flytekit/exceptions/scopes.py index 60a4afa97e..bdfb2ba182 100644 --- a/flytekit/exceptions/scopes.py +++ b/flytekit/exceptions/scopes.py @@ -194,10 +194,13 @@ def user_entry_point(wrapped, instance, args, kwargs): _CONTEXT_STACK.append(_USER_CONTEXT) if _is_base_context(): # See comment at this location for system_entry_point + fn_name = wrapped.__name__ try: return wrapped(*args, **kwargs) - except FlyteScopedException as ex: - raise ex.value + except FlyteScopedException as exc: + raise exc.type(f"Error encountered while executing '{fn_name}':\n {exc.value}") from exc + except Exception as exc: + raise type(exc)(f"Error encountered while executing '{fn_name}':\n {exc}") from exc else: try: return wrapped(*args, **kwargs) diff --git a/flytekit/extend/__init__.py b/flytekit/extend/__init__.py index f6635a4a57..7223d13523 100644 --- a/flytekit/extend/__init__.py +++ b/flytekit/extend/__init__.py @@ -29,8 +29,6 @@ PythonCustomizedContainerTask ExecutableTemplateShimTask ShimTaskExecutor - DataPersistence - DataPersistencePlugins """ from flytekit.configuration import Image, ImageConfig, SerializationSettings @@ -39,7 +37,7 @@ from flytekit.core.base_task import IgnoreOutputs, PythonTask, TaskResolverMixin from flytekit.core.class_based_resolver import ClassStorageTaskResolver from flytekit.core.context_manager import ExecutionState, SecretsManager -from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins +from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.interface import Interface from flytekit.core.promise import Promise from flytekit.core.python_customized_container_task import PythonCustomizedContainerTask diff --git a/flytekit/extend/backend/__init__.py b/flytekit/extend/backend/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flytekit/extend/backend/base_plugin.py b/flytekit/extend/backend/base_plugin.py new file mode 100644 index 0000000000..9fc1bc206b --- /dev/null +++ b/flytekit/extend/backend/base_plugin.py @@ -0,0 +1,107 @@ +import typing +from abc import ABC, abstractmethod + +import grpc +from flyteidl.core.tasks_pb2 import TaskTemplate +from flyteidl.service.external_plugin_service_pb2 import ( + RETRYABLE_FAILURE, + RUNNING, + SUCCEEDED, + State, + TaskCreateResponse, + TaskDeleteResponse, + TaskGetResponse, +) + +from flytekit import logger +from flytekit.models.literals import LiteralMap + + +class BackendPluginBase(ABC): + """ + This is the base class for all backend plugins. It defines the interface that all plugins must implement. 
+ The external plugins service will be run either locally or in a pod, and will be responsible for + invoking backend plugins. The propeller will communicate with the external plugins service + to create tasks, get the status of tasks, and delete tasks. + + All the backend plugins should be registered in the BackendPluginRegistry. External plugins service + will look up the plugin based on the task type. Every task type can only have one plugin. + """ + + def __init__(self, task_type: str): + self._task_type = task_type + + @property + def task_type(self) -> str: + """ + task_type is the name of the task type that this plugin supports. + """ + return self._task_type + + @abstractmethod + def create( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: typing.Optional[LiteralMap] = None, + ) -> TaskCreateResponse: + """ + Return a Unique ID for the task that was created. It should return error code if the task creation failed. + """ + pass + + @abstractmethod + def get(self, context: grpc.ServicerContext, job_id: str) -> TaskGetResponse: + """ + Return the status of the task, and return the outputs in some cases. For example, bigquery job + can't write the structured dataset to the output location, so it returns the output literals to the propeller, + and the propeller will write the structured dataset to the blob store. + """ + pass + + @abstractmethod + def delete(self, context: grpc.ServicerContext, job_id: str) -> TaskDeleteResponse: + """ + Delete the task. This call should be idempotent. + """ + pass + + +class BackendPluginRegistry(object): + """ + This is the registry for all backend plugins. The external plugins service will look up the plugin + based on the task type. + """ + + _REGISTRY: typing.Dict[str, BackendPluginBase] = {} + + @staticmethod + def register(plugin: BackendPluginBase): + if plugin.task_type in BackendPluginRegistry._REGISTRY: + raise ValueError(f"Duplicate plugin for task type {plugin.task_type}") + BackendPluginRegistry._REGISTRY[plugin.task_type] = plugin + logger.info(f"Registering backend plugin for task type {plugin.task_type}") + + @staticmethod + def get_plugin(context: grpc.ServicerContext, task_type: str) -> typing.Optional[BackendPluginBase]: + if task_type not in BackendPluginRegistry._REGISTRY: + logger.error(f"Cannot find backend plugin for task type [{task_type}]") + context.set_code(grpc.StatusCode.NOT_FOUND) + context.set_details(f"Cannot find backend plugin for task type [{task_type}]") + return None + return BackendPluginRegistry._REGISTRY[task_type] + + +def convert_to_flyte_state(state: str) -> State: + """ + Convert the state from the backend plugin to the state in flyte. 
+ """ + state = state.lower() + if state in ["failed"]: + return RETRYABLE_FAILURE + elif state in ["done", "succeeded"]: + return SUCCEEDED + elif state in ["running"]: + return RUNNING + raise ValueError(f"Unrecognized state: {state}") diff --git a/flytekit/extend/backend/external_plugin_service.py b/flytekit/extend/backend/external_plugin_service.py new file mode 100644 index 0000000000..e820a320b1 --- /dev/null +++ b/flytekit/extend/backend/external_plugin_service.py @@ -0,0 +1,53 @@ +import grpc +from flyteidl.service.external_plugin_service_pb2 import ( + PERMANENT_FAILURE, + TaskCreateRequest, + TaskCreateResponse, + TaskDeleteRequest, + TaskDeleteResponse, + TaskGetRequest, + TaskGetResponse, +) +from flyteidl.service.external_plugin_service_pb2_grpc import ExternalPluginServiceServicer + +from flytekit import logger +from flytekit.extend.backend.base_plugin import BackendPluginRegistry +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate + + +class BackendPluginServer(ExternalPluginServiceServicer): + def CreateTask(self, request: TaskCreateRequest, context: grpc.ServicerContext) -> TaskCreateResponse: + try: + tmp = TaskTemplate.from_flyte_idl(request.template) + inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None + plugin = BackendPluginRegistry.get_plugin(context, tmp.type) + if plugin is None: + return TaskCreateResponse() + return plugin.create(context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp) + except Exception as e: + logger.error(f"failed to create task with error {e}") + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(f"failed to create task with error {e}") + + def GetTask(self, request: TaskGetRequest, context: grpc.ServicerContext) -> TaskGetResponse: + try: + plugin = BackendPluginRegistry.get_plugin(context, request.task_type) + if plugin is None: + return TaskGetResponse(state=PERMANENT_FAILURE) + return plugin.get(context=context, job_id=request.job_id) + except Exception as e: + logger.error(f"failed to get task with error {e}") + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(f"failed to get task with error {e}") + + def DeleteTask(self, request: TaskDeleteRequest, context: grpc.ServicerContext) -> TaskDeleteResponse: + try: + plugin = BackendPluginRegistry.get_plugin(context, request.task_type) + if plugin is None: + return TaskDeleteResponse() + return plugin.delete(context=context, job_id=request.job_id) + except Exception as e: + logger.error(f"failed to delete task with error {e}") + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(f"failed to delete task with error {e}") diff --git a/flytekit/extras/persistence/__init__.py b/flytekit/extras/persistence/__init__.py deleted file mode 100644 index a677632fd8..0000000000 --- a/flytekit/extras/persistence/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -""" -======================= -DataPersistence Extras -======================= - -.. currentmodule:: flytekit.extras.persistence - -This module provides some default implementations of :py:class:`flytekit.DataPersistence`. These implementations -use command-line clients to download and upload data. The actual binaries need to be installed for these extras to work. -The binaries are not bundled with flytekit to keep it lightweight. - -Persistence Extras -=================== - -.. 
autosummary:: - :template: custom.rst - :toctree: generated/ - - GCSPersistence - HttpPersistence - S3Persistence -""" - -from flytekit.extras.persistence.gcs_gsutil import GCSPersistence -from flytekit.extras.persistence.http import HttpPersistence -from flytekit.extras.persistence.s3_awscli import S3Persistence diff --git a/flytekit/extras/persistence/gcs_gsutil.py b/flytekit/extras/persistence/gcs_gsutil.py deleted file mode 100644 index 0ddb600024..0000000000 --- a/flytekit/extras/persistence/gcs_gsutil.py +++ /dev/null @@ -1,115 +0,0 @@ -import os -import posixpath -import typing -from shutil import which as shell_which - -from flytekit.configuration import DataConfig, GCSConfig -from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins -from flytekit.exceptions.user import FlyteUserException -from flytekit.tools import subprocess - - -def _update_cmd_config_and_execute(cmd): - env = os.environ.copy() - return subprocess.check_call(cmd, env=env) - - -def _amend_path(path): - return posixpath.join(path, "*") if not path.endswith("*") else path - - -class GCSPersistence(DataPersistence): - """ - This DataPersistence plugin uses a preinstalled GSUtil binary in the container to download and upload data. - - The binary can be installed in multiple ways including simply, - - .. prompt:: - - pip install gsutil - - """ - - _GS_UTIL_CLI = "gsutil" - PROTOCOL = "gs://" - - def __init__(self, default_prefix: typing.Optional[str] = None, data_config: typing.Optional[DataConfig] = None): - super(GCSPersistence, self).__init__(name="gcs-gsutil", default_prefix=default_prefix) - self.gcs_cfg = data_config.gcs if data_config else GCSConfig.auto() - - @staticmethod - def _check_binary(): - """ - Make sure that the `gsutil` cli is present - """ - if not shell_which(GCSPersistence._GS_UTIL_CLI): - raise FlyteUserException("gsutil (gcloud cli) not found! Please install using `pip install gsutil`.") - - def _maybe_with_gsutil_parallelism(self, *gsutil_args): - """ - Check if we should run `gsutil` with the `-m` flag that enables - parallelism via multiple threads/processes. Additional tweaking of - this behavior can be achieved via the .boto configuration file. See: - https://cloud.google.com/storage/docs/boto-gsutil - """ - cmd = [GCSPersistence._GS_UTIL_CLI] - if self.gcs_cfg.gsutil_parallelism: - cmd.append("-m") - cmd.extend(gsutil_args) - - return cmd - - def exists(self, remote_path): - """ - :param Text remote_path: remote gs:// path - :rtype bool: whether the gs file exists or not - """ - GCSPersistence._check_binary() - - if not remote_path.startswith("gs://"): - raise ValueError("Not an GS Key. Please use FQN (GS ARN) of the format gs://...") - - cmd = [GCSPersistence._GS_UTIL_CLI, "-q", "stat", remote_path] - try: - _update_cmd_config_and_execute(cmd) - return True - except Exception: - return False - - def get(self, from_path: str, to_path: str, recursive: bool = False): - if not from_path.startswith("gs://"): - raise ValueError("Not an GS Key. 
Please use FQN (GS ARN) of the format gs://...") - - GCSPersistence._check_binary() - if recursive: - cmd = self._maybe_with_gsutil_parallelism("cp", "-r", _amend_path(from_path), to_path) - else: - cmd = self._maybe_with_gsutil_parallelism("cp", from_path, to_path) - - return _update_cmd_config_and_execute(cmd) - - def put(self, from_path: str, to_path: str, recursive: bool = False): - GCSPersistence._check_binary() - - if recursive: - cmd = self._maybe_with_gsutil_parallelism( - "cp", - "-r", - _amend_path(from_path), - to_path if to_path.endswith("/") else to_path + "/", - ) - else: - cmd = self._maybe_with_gsutil_parallelism("cp", from_path, to_path) - return _update_cmd_config_and_execute(cmd) - - def construct_path(self, add_protocol: bool, add_prefix: bool, *paths) -> str: - paths = list(paths) # make type check happy - if add_prefix: - paths.insert(0, self.default_prefix) - path = "/".join(paths) - if add_protocol: - return f"{self.PROTOCOL}{path}" - return path - - -DataPersistencePlugins.register_plugin(GCSPersistence.PROTOCOL, GCSPersistence) diff --git a/flytekit/extras/persistence/http.py b/flytekit/extras/persistence/http.py deleted file mode 100644 index ce6079300d..0000000000 --- a/flytekit/extras/persistence/http.py +++ /dev/null @@ -1,84 +0,0 @@ -import base64 -import os -import pathlib - -import requests - -from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins -from flytekit.exceptions import user -from flytekit.loggers import logger -from flytekit.tools import script_mode - - -class HttpPersistence(DataPersistence): - """ - DataPersistence implementation for the HTTP protocol. only supports downloading from an http source. Uploads are - not supported currently. - """ - - PROTOCOL_HTTP = "http" - PROTOCOL_HTTPS = "https" - _HTTP_OK = 200 - _HTTP_FORBIDDEN = 403 - _HTTP_NOT_FOUND = 404 - ALLOWED_CODES = { - _HTTP_OK, - _HTTP_NOT_FOUND, - _HTTP_FORBIDDEN, - } - - def __init__(self, *args, **kwargs): - super(HttpPersistence, self).__init__(name="http/https", *args, **kwargs) - - def exists(self, path: str): - rsp = requests.head(path) - if rsp.status_code not in self.ALLOWED_CODES: - raise user.FlyteValueException( - rsp.status_code, - f"Data at {path} could not be checked for existence. Expected one of: {self.ALLOWED_CODES}", - ) - return rsp.status_code == self._HTTP_OK - - def get(self, from_path: str, to_path: str, recursive: bool = False): - if recursive: - raise user.FlyteAssertion("Reading data recursively from HTTP endpoint is not currently supported.") - rsp = requests.get(from_path) - if rsp.status_code != self._HTTP_OK: - raise user.FlyteValueException( - rsp.status_code, - "Request for data @ {} failed. 
Expected status code {}".format(from_path, type(self)._HTTP_OK), - ) - head, _ = os.path.split(to_path) - if head and head.startswith("/"): - logger.debug(f"HttpPersistence creating {head} so that parent dirs exist") - pathlib.Path(head).mkdir(parents=True, exist_ok=True) - with open(to_path, "wb") as writer: - writer.write(rsp.content) - - def put(self, from_path: str, to_path: str, recursive: bool = False): - if recursive: - raise user.FlyteAssertion("Recursive writing data to HTTP endpoint is not currently supported.") - - md5, _ = script_mode.hash_file(from_path) - encoded_md5 = base64.b64encode(md5) - with open(from_path, "+rb") as local_file: - content = local_file.read() - content_length = len(content) - rsp = requests.put( - to_path, data=content, headers={"Content-Length": str(content_length), "Content-MD5": encoded_md5} - ) - - if rsp.status_code != self._HTTP_OK: - raise user.FlyteValueException( - rsp.status_code, - f"Request to send data {to_path} failed.", - ) - - def construct_path(self, add_protocol: bool, add_prefix: bool, *paths) -> str: - raise user.FlyteAssertion( - "There are multiple ways of creating http links / paths, this is not supported by the persistence layer" - ) - - -DataPersistencePlugins.register_plugin("http://", HttpPersistence) -DataPersistencePlugins.register_plugin("https://", HttpPersistence) diff --git a/flytekit/extras/persistence/s3_awscli.py b/flytekit/extras/persistence/s3_awscli.py deleted file mode 100644 index 0b00227ca0..0000000000 --- a/flytekit/extras/persistence/s3_awscli.py +++ /dev/null @@ -1,181 +0,0 @@ -import os -import os as _os -import re as _re -import string as _string -import time -import typing -from shutil import which as shell_which -from typing import Dict, List, Optional - -from flytekit.configuration import DataConfig, S3Config -from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins -from flytekit.exceptions.user import FlyteUserException -from flytekit.loggers import logger -from flytekit.tools import subprocess - -S3_ANONYMOUS_FLAG = "--no-sign-request" -S3_ACCESS_KEY_ID_ENV_NAME = "AWS_ACCESS_KEY_ID" -S3_SECRET_ACCESS_KEY_ENV_NAME = "AWS_SECRET_ACCESS_KEY" - - -def _update_cmd_config_and_execute(s3_cfg: S3Config, cmd: List[str]): - env = _os.environ.copy() - - if s3_cfg.enable_debug: - cmd.insert(1, "--debug") - - if s3_cfg.endpoint is not None: - cmd.insert(1, s3_cfg.endpoint) - cmd.insert(1, "--endpoint-url") - - if S3_ACCESS_KEY_ID_ENV_NAME not in os.environ: - if s3_cfg.access_key_id: - env[S3_ACCESS_KEY_ID_ENV_NAME] = s3_cfg.access_key_id - - if S3_SECRET_ACCESS_KEY_ENV_NAME not in os.environ: - if s3_cfg.secret_access_key: - env[S3_SECRET_ACCESS_KEY_ENV_NAME] = s3_cfg.secret_access_key - - retry = 0 - while True: - try: - try: - return subprocess.check_call(cmd, env=env) - except Exception as e: - if retry > 0: - logger.info(f"AWS command failed with error {e}, command: {cmd}, retry {retry}") - - logger.debug(f"Appending anonymous flag and retrying command {cmd}") - anonymous_cmd = cmd[:] # strings only, so this is deep enough - anonymous_cmd.insert(1, S3_ANONYMOUS_FLAG) - return subprocess.check_call(anonymous_cmd, env=env) - - except Exception as e: - logger.error(f"Exception when trying to execute {cmd}, reason: {str(e)}") - retry += 1 - if retry > s3_cfg.retries: - raise - secs = s3_cfg.backoff - logger.info(f"Sleeping before retrying again, after {secs.total_seconds()} seconds") - time.sleep(secs.total_seconds()) - logger.info("Retrying again") - - -def _extra_args(extra_args: 
Dict[str, str]) -> List[str]: - cmd = [] - if "ContentType" in extra_args: - cmd += ["--content-type", extra_args["ContentType"]] - if "ContentEncoding" in extra_args: - cmd += ["--content-encoding", extra_args["ContentEncoding"]] - if "ACL" in extra_args: - cmd += ["--acl", extra_args["ACL"]] - return cmd - - -class S3Persistence(DataPersistence): - """ - DataPersistence plugin for AWS S3 (and Minio). Use aws cli to manage the transfer. The binary needs to be installed - separately - - .. prompt:: - - pip install awscli - - """ - - PROTOCOL = "s3://" - _AWS_CLI = "aws" - _SHARD_CHARACTERS = [str(x) for x in range(10)] + list(_string.ascii_lowercase) - - def __init__(self, default_prefix: Optional[str] = None, data_config: typing.Optional[DataConfig] = None): - super().__init__(name="awscli-s3", default_prefix=default_prefix) - self.s3_cfg = data_config.s3 if data_config else S3Config.auto() - - @staticmethod - def _check_binary(): - """ - Make sure that the AWS cli is present - """ - if not shell_which(S3Persistence._AWS_CLI): - raise FlyteUserException("AWS CLI not found! Please install it with `pip install awscli`.") - - @staticmethod - def _split_s3_path_to_bucket_and_key(path: str) -> typing.Tuple[str, str]: - """ - splits a valid s3 uri into bucket and key - """ - path = path[len("s3://") :] - first_slash = path.index("/") - return path[:first_slash], path[first_slash + 1 :] - - def exists(self, remote_path): - """ - Given a remote path of the format s3://, checks if the remote file exists - """ - S3Persistence._check_binary() - - if not remote_path.startswith("s3://"): - raise ValueError("Not an S3 ARN. Please use FQN (S3 ARN) of the format s3://...") - - bucket, file_path = self._split_s3_path_to_bucket_and_key(remote_path) - cmd = [ - S3Persistence._AWS_CLI, - "s3api", - "head-object", - "--bucket", - bucket, - "--key", - file_path, - ] - try: - _update_cmd_config_and_execute(cmd=cmd, s3_cfg=self.s3_cfg) - return True - except Exception as ex: - # The s3api command returns an error if the object does not exist. The error message contains - # the http status code: "An error occurred (404) when calling the HeadObject operation: Not Found" - # This is a best effort for returning if the object does not exist by searching - # for existence of (404) in the error message. This should not be needed when we get off the cli and use lib - if _re.search("(404)", str(ex)): - return False - else: - raise ex - - def get(self, from_path: str, to_path: str, recursive: bool = False): - S3Persistence._check_binary() - - if not from_path.startswith("s3://"): - raise ValueError("Not an S3 ARN. Please use FQN (S3 ARN) of the format s3://...") - - if recursive: - cmd = [S3Persistence._AWS_CLI, "s3", "cp", "--recursive", from_path, to_path] - else: - cmd = [S3Persistence._AWS_CLI, "s3", "cp", from_path, to_path] - return _update_cmd_config_and_execute(cmd=cmd, s3_cfg=self.s3_cfg) - - def put(self, from_path: str, to_path: str, recursive: bool = False): - extra_args = { - "ACL": "bucket-owner-full-control", - } - - if not to_path.startswith("s3://"): - raise ValueError("Not an S3 ARN. 
Please use FQN (S3 ARN) of the format s3://...") - - S3Persistence._check_binary() - cmd = [S3Persistence._AWS_CLI, "s3", "cp"] - if recursive: - cmd += ["--recursive"] - cmd.extend(_extra_args(extra_args)) - cmd += [from_path, to_path] - return _update_cmd_config_and_execute(cmd=cmd, s3_cfg=self.s3_cfg) - - def construct_path(self, add_protocol: bool, add_prefix: bool, *paths: str) -> str: - paths = list(paths) # make type check happy - if add_prefix: - paths.insert(0, self.default_prefix) - path = "/".join(paths) - if add_protocol: - return f"{self.PROTOCOL}{path}" - return path - - -DataPersistencePlugins.register_plugin(S3Persistence.PROTOCOL, S3Persistence) diff --git a/flytekit/extras/sqlite3/task.py b/flytekit/extras/sqlite3/task.py index 8e7d8b3b29..ef8013a5da 100644 --- a/flytekit/extras/sqlite3/task.py +++ b/flytekit/extras/sqlite3/task.py @@ -14,7 +14,6 @@ from flytekit.core.python_customized_container_task import PythonCustomizedContainerTask from flytekit.core.shim_task import ShimTaskExecutor from flytekit.models import task as task_models -from flytekit.types.schema import FlyteSchema def unarchive_file(local_path: str, to_dir: str): @@ -78,12 +77,14 @@ def __init__( query_template: str, inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, task_config: typing.Optional[SQLite3Config] = None, - output_schema_type: typing.Optional[typing.Type[FlyteSchema]] = None, + output_schema_type: typing.Optional[typing.Type["FlyteSchema"]] = None, # type: ignore container_image: typing.Optional[str] = None, **kwargs, ): if task_config is None or task_config.uri is None: raise ValueError("SQLite DB uri is required.") + from flytekit.types.schema import FlyteSchema + outputs = kwtypes(results=output_schema_type if output_schema_type else FlyteSchema) super().__init__( name=name, diff --git a/flytekit/extras/tasks/shell.py b/flytekit/extras/tasks/shell.py index 12ef36af3e..87b60126d6 100644 --- a/flytekit/extras/tasks/shell.py +++ b/flytekit/extras/tasks/shell.py @@ -1,5 +1,6 @@ import datetime import os +import platform import string import subprocess import typing @@ -213,6 +214,9 @@ def execute(self, **kwargs) -> typing.Any: print("\n==============================================\n") try: + if platform.system() == "Windows" and os.environ.get("ComSpec") is None: + # https://github.com/python/cpython/issues/101283 + os.environ["ComSpec"] = "C:\\Windows\\System32\\cmd.exe" subprocess.check_call(gen_script, shell=True) except subprocess.CalledProcessError as e: files = os.listdir(".") @@ -356,7 +360,6 @@ def execute(self, **kwargs) -> typing.Any: # This utility function allows for the specification of env variables, arguments, and the actual script within the # workflow definition rather than at `RawShellTask` instantiation def get_raw_shell_task(name: str) -> RawShellTask: - return RawShellTask( name=name, debug=True, diff --git a/flytekit/extras/tensorflow/__init__.py b/flytekit/extras/tensorflow/__init__.py index b5699906fb..2447ffa826 100644 --- a/flytekit/extras/tensorflow/__init__.py +++ b/flytekit/extras/tensorflow/__init__.py @@ -23,9 +23,10 @@ if _tensorflow_installed: + from .model import TensorFlowModelTransformer from .record import TensorFlowRecordFileTransformer, TensorFlowRecordsDirTransformer else: logger.info( - "We won't register TensorFlowRecordFileTransformer and TensorFlowRecordsDirTransformer " + "We won't register TensorFlowRecordFileTransformer, TensorFlowRecordsDirTransformer and TensorFlowModelTransformer" "because tensorflow is not installed." 
) diff --git a/flytekit/extras/tensorflow/model.py b/flytekit/extras/tensorflow/model.py new file mode 100644 index 0000000000..857ec2c984 --- /dev/null +++ b/flytekit/extras/tensorflow/model.py @@ -0,0 +1,76 @@ +import pathlib +from typing import Type + +import tensorflow as tf + +from flytekit.core.context_manager import FlyteContext +from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.models.core import types as _core_types +from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar +from flytekit.models.types import LiteralType + + +class TensorFlowModelTransformer(TypeTransformer[tf.keras.Model]): + TENSORFLOW_FORMAT = "TensorFlowModel" + + def __init__(self): + super().__init__(name="TensorFlow Model", t=tf.keras.Model) + + def get_literal_type(self, t: Type[tf.keras.Model]) -> LiteralType: + return LiteralType( + blob=_core_types.BlobType( + format=self.TENSORFLOW_FORMAT, + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + ) + ) + + def to_literal( + self, + ctx: FlyteContext, + python_val: tf.keras.Model, + python_type: Type[tf.keras.Model], + expected: LiteralType, + ) -> Literal: + meta = BlobMetadata( + type=_core_types.BlobType( + format=self.TENSORFLOW_FORMAT, + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + ) + ) + + local_path = ctx.file_access.get_random_local_path() + pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True) + + # save model in SavedModel format + tf.keras.models.save_model(python_val, local_path) + + remote_path = ctx.file_access.get_random_remote_path() + ctx.file_access.put_data(local_path, remote_path, is_multipart=True) + return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) + + def to_python_value( + self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[tf.keras.Model] + ) -> tf.keras.Model: + try: + uri = lv.scalar.blob.uri + except AttributeError: + TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") + + local_path = ctx.file_access.get_random_local_path() + ctx.file_access.get_data(uri, local_path, is_multipart=True) + + # load model + return tf.keras.models.load_model(local_path) + + def guess_python_type(self, literal_type: LiteralType) -> Type[tf.keras.Model]: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.MULTIPART + and literal_type.blob.format == self.TENSORFLOW_FORMAT + ): + return tf.keras.Model + + raise ValueError(f"Transformer {self} cannot reverse {literal_type}") + + +TypeEngine.register(TensorFlowModelTransformer()) diff --git a/flytekit/extras/tensorflow/record.py b/flytekit/extras/tensorflow/record.py index d5d750b521..17e7c37ddd 100644 --- a/flytekit/extras/tensorflow/record.py +++ b/flytekit/extras/tensorflow/record.py @@ -159,7 +159,6 @@ def to_literal( def to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[TFRecordsDirectory] ) -> TFRecordDatasetV2: - uri, metadata = extract_metadata_and_uri(lv, expected_python_type) local_dir = ctx.file_access.get_random_local_directory() ctx.file_access.get_data(uri, local_dir, is_multipart=True) diff --git a/flytekit/image_spec/__init__.py b/flytekit/image_spec/__init__.py new file mode 100644 index 0000000000..ca1bdedee6 --- /dev/null +++ b/flytekit/image_spec/__init__.py @@ -0,0 +1 @@ +from .image_spec import ImageSpec diff --git a/flytekit/image_spec/image_spec.py b/flytekit/image_spec/image_spec.py new file 
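With TensorFlowModelTransformer registered, tf.keras.Model values can be passed between tasks directly; flytekit stores them as a multipart blob in SavedModel format. A small sketch:

import tensorflow as tf

from flytekit import task, workflow


@task
def build_model() -> tf.keras.Model:
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    model.compile(optimizer="adam", loss="mse")
    return model


@task
def count_params(model: tf.keras.Model) -> int:
    return model.count_params()


@workflow
def wf() -> int:
    return count_params(model=build_model())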
mode 100644 index 0000000000..1cddb2a913 --- /dev/null +++ b/flytekit/image_spec/image_spec.py @@ -0,0 +1,170 @@ +import base64 +import hashlib +import os +import typing +from abc import abstractmethod +from copy import copy +from dataclasses import dataclass +from functools import lru_cache +from typing import List, Optional + +import click +import requests +from dataclasses_json import dataclass_json + +DOCKER_HUB = "docker.io" +_F_IMG_ID = "_F_IMG_ID" + + +@dataclass_json +@dataclass +class ImageSpec: + """ + This class is used to specify the docker image that will be used to run the task. + + Args: + name: name of the image. + python_version: python version of the image. Use default python in the base image if None. + builder: Type of plugin to build the image. Use envd by default. + source_root: source root of the image. + env: environment variables of the image. + registry: registry of the image. + packages: list of python packages to install. + apt_packages: list of apt packages to install. + base_image: base image of the image. + """ + + name: str = "flytekit" + python_version: str = None # Use default python in the base image if None. + builder: str = "envd" + source_root: Optional[str] = None + env: Optional[typing.Dict[str, str]] = None + registry: Optional[str] = None + packages: Optional[List[str]] = None + apt_packages: Optional[List[str]] = None + base_image: Optional[str] = None + + def image_name(self) -> str: + """ + return full image name with tag. + """ + tag = calculate_hash_from_image_spec(self) + container_image = f"{self.name}:{tag}" + if self.registry: + container_image = f"{self.registry}/{container_image}" + return container_image + + def is_container(self) -> bool: + from flytekit.core.context_manager import ExecutionState, FlyteContextManager + + state = FlyteContextManager.current_context().execution_state + if state and state.mode and state.mode != ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION: + return os.environ.get(_F_IMG_ID) == self.image_name() + return True + + @lru_cache + def exist(self) -> bool: + """ + Check if the image exists in the registry. + """ + import docker + from docker.errors import APIError, ImageNotFound + + try: + client = docker.from_env() + if self.registry: + client.images.get_registry_data(self.image_name()) + else: + client.images.get(self.image_name()) + return True + except APIError as e: + if e.response.status_code == 404: + return False + except ImageNotFound: + return False + except Exception as e: + tag = calculate_hash_from_image_spec(self) + # if docker engine is not running locally + container_registry = DOCKER_HUB + if "/" in self.registry: + container_registry = self.registry.split("/")[0] + if container_registry == DOCKER_HUB: + url = f"https://hub.docker.com/v2/repositories/{self.registry}/{self.name}/tags/{tag}" + response = requests.get(url) + if response.status_code == 200: + return True + + if response.status_code == 404: + return False + + click.secho(f"Failed to check if the image exists with error : {e}", fg="red") + click.secho("Flytekit assumes that the image already exists.", fg="blue") + return True + + def __hash__(self): + return hash(self.to_json()) + + +class ImageSpecBuilder: + @abstractmethod + def build_image(self, image_spec: ImageSpec): + """ + Build the docker image and push it to the registry. + + Args: + image_spec: image spec of the task. 
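A short sketch of declaring an ImageSpec and inspecting the image name it resolves to. The registry value is hypothetical; wiring the spec into a task's container image is the intended end use but sits outside this excerpt.

from flytekit.image_spec import ImageSpec

spec = ImageSpec(
    packages=["scikit-learn", "pandas"],
    apt_packages=["git"],
    python_version="3.10",
    registry="ghcr.io/my-org",  # hypothetical registry
)

# The tag is a hash of the spec itself, so identical specs resolve to the same image.
print(spec.image_name())  # e.g. ghcr.io/my-org/flytekit:<hash>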
+ """ + raise NotImplementedError("This method is not implemented in the base class.") + + +class ImageBuildEngine: + """ + ImageBuildEngine contains a list of builders that can be used to build an ImageSpec. + """ + + _REGISTRY: typing.Dict[str, ImageSpecBuilder] = {} + + @classmethod + def register(cls, builder_type: str, image_spec_builder: ImageSpecBuilder): + cls._REGISTRY[builder_type] = image_spec_builder + + @classmethod + def build(cls, image_spec: ImageSpec): + if image_spec.builder not in cls._REGISTRY: + raise Exception(f"Builder {image_spec.builder} is not registered.") + if not image_spec.exist(): + click.secho(f"Image {image_spec.image_name()} not found. Building...", fg="blue") + cls._REGISTRY[image_spec.builder].build_image(image_spec) + else: + click.secho(f"Image {image_spec.image_name()} found. Skip building.", fg="blue") + + +@lru_cache +def calculate_hash_from_image_spec(image_spec: ImageSpec): + """ + Calculate the hash from the image spec. + """ + # copy the image spec to avoid modifying the original image spec. otherwise, the hash will be different. + spec = copy(image_spec) + spec.source_root = hash_directory(image_spec.source_root) if image_spec.source_root else b"" + image_spec_bytes = bytes(image_spec.to_json(), "utf-8") + tag = base64.urlsafe_b64encode(hashlib.md5(image_spec_bytes).digest()).decode("ascii") + # replace "=" with "." to make it a valid tag + return tag.replace("=", ".") + + +def hash_directory(path): + """ + Return the SHA-256 hash of the directory at the given path. + """ + hasher = hashlib.sha256() + for root, dirs, files in os.walk(path): + for file in files: + with open(os.path.join(root, file), "rb") as f: + while True: + # Read file in small chunks to avoid loading large files into memory all at once + chunk = f.read(4096) + if not chunk: + break + hasher.update(chunk) + return bytes(hasher.hexdigest(), "utf-8") diff --git a/flytekit/lazy_import/__init__.py b/flytekit/lazy_import/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flytekit/lazy_import/lazy_module.py b/flytekit/lazy_import/lazy_module.py new file mode 100644 index 0000000000..553386eb52 --- /dev/null +++ b/flytekit/lazy_import/lazy_module.py @@ -0,0 +1,33 @@ +import importlib.util +import sys + +LAZY_MODULES = [] + + +def is_imported(module_name): + """ + This function is used to check if a module has been imported by the regular import. + """ + return module_name in sys.modules and module_name not in LAZY_MODULES + + +def lazy_module(fullname): + """ + This function is used to lazily import modules. It is used in the following way: + .. 
code-block:: python + from flytekit.lazy_import import lazy_module + sklearn = lazy_module("sklearn") + sklearn.svm.SVC() + :param Text fullname: The full name of the module to import + """ + if fullname in sys.modules: + return sys.modules[fullname] + # https://docs.python.org/3/library/importlib.html#implementing-lazy-imports + spec = importlib.util.find_spec(fullname) + loader = importlib.util.LazyLoader(spec.loader) + spec.loader = loader + module = importlib.util.module_from_spec(spec) + sys.modules[fullname] = module + LAZY_MODULES.append(module) + loader.exec_module(module) + return module diff --git a/flytekit/loggers.py b/flytekit/loggers.py index f047348de0..fdc3c75d3a 100644 --- a/flytekit/loggers.py +++ b/flytekit/loggers.py @@ -2,6 +2,8 @@ import os from pythonjsonlogger import jsonlogger +from rich.console import Console +from rich.logging import RichHandler # Note: # The environment variable controls exposed to affect the individual loggers should be considered to be beta. @@ -10,6 +12,7 @@ # For now, assume this is the environment variable whose usage will remain unchanged and controls output for all # loggers defined in this file. LOGGING_ENV_VAR = "FLYTE_SDK_LOGGING_LEVEL" +LOGGING_FMT_ENV_VAR = "FLYTE_SDK_LOGGING_FORMAT" # By default, the root flytekit logger to debug so everything is logged, but enable fine-tuning logger = logging.getLogger("flytekit") @@ -33,8 +36,18 @@ user_space_logger = child_loggers["user_space"] # create console handler -ch = logging.StreamHandler() -ch.setLevel(logging.DEBUG) +try: + handler = RichHandler( + rich_tracebacks=True, + omit_repeated_times=False, + keywords=["[flytekit]"], + log_time_format="%Y-%m-%d %H:%M:%S,%f", + console=Console(width=os.get_terminal_size().columns), + ) +except OSError: + handler = logging.StreamHandler() + +handler.setLevel(logging.DEBUG) # Root logger control # Don't want to import the configuration library since that will cause all sorts of circular imports, let's @@ -63,10 +76,14 @@ child_logger.setLevel(logging.WARNING) # create formatter -formatter = jsonlogger.JsonFormatter(fmt="%(asctime)s %(name)s %(levelname)s %(message)s") +logging_fmt = os.environ.get(LOGGING_FMT_ENV_VAR, "json") +if logging_fmt == "json": + formatter = jsonlogger.JsonFormatter(fmt="%(asctime)s %(name)s %(levelname)s %(message)s") +else: + formatter = logging.Formatter(fmt="[%(name)s] %(message)s") -# add formatter to ch -ch.setFormatter(formatter) +# add formatter to the handler +handler.setFormatter(formatter) # add ch to logger -logger.addHandler(ch) +logger.addHandler(handler) diff --git a/flytekit/models/common.py b/flytekit/models/common.py index 62018c1eef..4f030e25a4 100644 --- a/flytekit/models/common.py +++ b/flytekit/models/common.py @@ -1,5 +1,6 @@ import abc as _abc import json as _json +import re from flyteidl.admin import common_pb2 as _common_pb2 from google.protobuf import json_format as _json_format @@ -57,7 +58,8 @@ def short_string(self): """ :rtype: Text """ - return str(self.to_flyte_idl()) + literal_str = re.sub(r"\s+", " ", str(self.to_flyte_idl())).strip() + return f"" def verbose_string(self): """ diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py index 4f06c3d3c6..e0a864e31e 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/literals.py @@ -628,7 +628,7 @@ def uri(self) -> str: return self._uri @property - def metadata(self) -> StructuredDatasetMetadata: + def metadata(self) -> Optional[StructuredDatasetMetadata]: return self._metadata def to_flyte_idl(self) -> 
_literals_pb2.StructuredDataset: diff --git a/flytekit/models/security.py b/flytekit/models/security.py index 7babb859e4..9af90a4b8a 100644 --- a/flytekit/models/security.py +++ b/flytekit/models/security.py @@ -36,15 +36,13 @@ class MountType(Enum): """ group: str - key: str + key: Optional[str] = None group_version: Optional[str] = None mount_requirement: MountType = MountType.ANY def __post_init__(self): if self.group is None: raise ValueError("Group is a required parameter") - if self.key is None: - raise ValueError("Key is also a required parameter") def to_flyte_idl(self) -> _sec.Secret: return _sec.Secret( @@ -59,7 +57,7 @@ def from_flyte_idl(cls, pb2_object: _sec.Secret) -> "Secret": return cls( group=pb2_object.group, group_version=pb2_object.group_version if pb2_object.group_version else None, - key=pb2_object.key, + key=pb2_object.key if pb2_object.key else None, mount_requirement=Secret.MountType(pb2_object.mount_requirement), ) diff --git a/flytekit/models/task.py b/flytekit/models/task.py index fc79c87a2d..f7f1d710c9 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -868,12 +868,18 @@ def from_flyte_idl(cls, pb2_object: _core_task.K8sObjectMetadata): class K8sPod(_common.FlyteIdlEntity): - def __init__(self, metadata: K8sObjectMetadata = None, pod_spec: typing.Dict[str, typing.Any] = None): + def __init__( + self, + metadata: K8sObjectMetadata = None, + pod_spec: typing.Dict[str, typing.Any] = None, + data_config: typing.Optional[DataLoadingConfig] = None, + ): """ This defines a kubernetes pod target. It will build the pod target during task execution """ self._metadata = metadata self._pod_spec = pod_spec + self._data_config = data_config @property def metadata(self) -> K8sObjectMetadata: @@ -883,10 +889,15 @@ def metadata(self) -> K8sObjectMetadata: def pod_spec(self) -> typing.Dict[str, typing.Any]: return self._pod_spec + @property + def data_config(self) -> typing.Optional[DataLoadingConfig]: + return self._data_config + def to_flyte_idl(self) -> _core_task.K8sPod: return _core_task.K8sPod( metadata=self._metadata.to_flyte_idl(), pod_spec=_json_format.Parse(_json.dumps(self.pod_spec), _struct.Struct()) if self.pod_spec else None, + data_config=self.data_config.to_flyte_idl() if self.data_config else None, ) @classmethod @@ -894,6 +905,9 @@ def from_flyte_idl(cls, pb2_object: _core_task.K8sPod): return cls( metadata=K8sObjectMetadata.from_flyte_idl(pb2_object.metadata), pod_spec=_json_format.MessageToDict(pb2_object.pod_spec) if pb2_object.HasField("pod_spec") else None, + data_config=DataLoadingConfig.from_flyte_idl(pb2_object.data_config) + if pb2_object.HasField("data_config") + else None, ) diff --git a/flytekit/models/types.py b/flytekit/models/types.py index 4358d7229e..3e3c778d6b 100644 --- a/flytekit/models/types.py +++ b/flytekit/models/types.py @@ -259,7 +259,7 @@ def __init__( :param flytekit.models.core.types.StructuredDatasetType structured_dataset_type: structured dataset :param dict[Text, T] metadata: Additional data describing the type :param flytekit.models.annotation.TypeAnnotation annotation: Additional data - describing the type _intended to be saturated by the client_ + describing the type intended to be saturated by the client """ self._simple = simple self._schema = schema diff --git a/flytekit/remote/backfill.py b/flytekit/remote/backfill.py index 154bf4d1b4..2f31889060 100644 --- a/flytekit/remote/backfill.py +++ b/flytekit/remote/backfill.py @@ -68,6 +68,8 @@ def create_backfill_workflow( logging.info(f"Generating backfill 
from {start_date} -> {end_date}. Parallel?[{parallel}]") wf = ImperativeWorkflow(name=f"backfill-{for_lp.name}") + + input_name = schedule.kickoff_time_input_arg date_iter = croniter(cron_schedule.schedule, start_time=start_date, ret_type=datetime) prev_node = None actual_start = None @@ -79,7 +81,10 @@ def create_backfill_workflow( if next_start_date >= end_date: break actual_end = next_start_date - next_node = wf.add_launch_plan(for_lp, t=next_start_date) + inputs = {} + if input_name: + inputs[input_name] = next_start_date + next_node = wf.add_launch_plan(for_lp, **inputs) next_node = next_node.with_overrides( name=f"b-{next_start_date}", retries=per_node_retries, timeout=per_node_timeout ) diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 03cc9a66e9..91189ede74 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -6,17 +6,19 @@ from __future__ import annotations import base64 -import functools import hashlib import os import pathlib +import tempfile import time import typing import uuid +from base64 import b64encode from collections import OrderedDict from dataclasses import asdict, dataclass from datetime import datetime, timedelta +import requests from flyteidl.admin.signal_pb2 import Signal, SignalListRequest, SignalSetRequest from flyteidl.core import literals_pb2 as literals_pb2 @@ -34,7 +36,11 @@ from flytekit.core.type_engine import LiteralsResolver, TypeEngine from flytekit.core.workflow import WorkflowBase from flytekit.exceptions import user as user_exceptions -from flytekit.exceptions.user import FlyteEntityAlreadyExistsException, FlyteEntityNotExistException +from flytekit.exceptions.user import ( + FlyteEntityAlreadyExistsException, + FlyteEntityNotExistException, + FlyteValueException, +) from flytekit.loggers import remote_logger from flytekit.models import common as common_models from flytekit.models import filters as filter_models @@ -62,7 +68,7 @@ from flytekit.remote.lazy_entity import LazyEntity from flytekit.remote.remote_callable import RemoteEntity from flytekit.tools.fast_registration import fast_package -from flytekit.tools.script_mode import fast_register_single_script, hash_file +from flytekit.tools.script_mode import compress_scripts, hash_file from flytekit.tools.translator import ( FlyteControlPlaneEntity, FlyteLocalEntity, @@ -615,6 +621,10 @@ def _serialize_and_register( version=version, ) is_dummy_serialization_setting = True + + if serialization_settings.version is None: + serialization_settings.version = version + _ = get_serializable(m, settings=serialization_settings, entity=entity, options=options) ident = None @@ -704,9 +714,9 @@ def fast_package(self, root: os.PathLike, deref_symlinks: bool = True, output: s md5_bytes, _ = hash_file(pathlib.Path(zip_file)) # Upload zip file to Admin using FlyteRemote. 
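As an aside, the upload helper made public in this hunk can also be driven directly from a FlyteRemote; a minimal sketch (the archive path, project and domain names are hypothetical), assuming the packaged tarball already exists locally:

    import pathlib

    from flytekit.configuration import Config
    from flytekit.remote import FlyteRemote

    remote = FlyteRemote(Config.auto(), default_project="flytesnacks", default_domain="development")
    # fast_package tars the project root and delegates to upload_file under the hood;
    # upload_file PUTs the archive to a signed URL and returns (md5_bytes, native_url).
    md5_bytes, native_url = remote.upload_file(pathlib.Path("/tmp/flyte-package.tgz"))
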
- return self._upload_file(pathlib.Path(zip_file)) + return self.upload_file(pathlib.Path(zip_file)) - def _upload_file( + def upload_file( self, to_upload: pathlib.Path, project: typing.Optional[str] = None, domain: typing.Optional[str] = None ) -> typing.Tuple[bytes, str]: """ @@ -728,7 +738,23 @@ def _upload_file( content_md5=md5_bytes, filename=to_upload.name, ) - self._ctx.file_access.put_data(str(to_upload), upload_location.signed_url) + + encoded_md5 = b64encode(md5_bytes) + with open(str(to_upload), "+rb") as local_file: + content = local_file.read() + content_length = len(content) + rsp = requests.put( + upload_location.signed_url, + data=content, + headers={"Content-Length": str(content_length), "Content-MD5": encoded_md5}, + ) + + if rsp.status_code != requests.codes["OK"]: + raise FlyteValueException( + rsp.status_code, + f"Request to send data {upload_location.signed_url} failed.", + ) + remote_logger.debug( f"Uploading {to_upload} to {upload_location.signed_url} native url {upload_location.native_url}" ) @@ -773,17 +799,19 @@ def register_script( project: typing.Optional[str] = None, domain: typing.Optional[str] = None, destination_dir: str = ".", - default_launch_plan: typing.Optional[bool] = True, + copy_all: bool = False, + default_launch_plan: bool = True, options: typing.Optional[Options] = None, source_path: typing.Optional[str] = None, module_name: typing.Optional[str] = None, ) -> typing.Union[FlyteWorkflow, FlyteTask]: """ Use this method to register a workflow via script mode. - :param destination_dir: - :param domain: - :param project: - :param image_config: + :param destination_dir: The destination directory where the workflow will be copied to. + :param copy_all: If true, the entire source directory will be copied over to the destination directory. + :param domain: The domain to register the workflow in. + :param project: The project to register the workflow in. + :param image_config: The image config to use for the workflow. 
:param version: version for the entity to be registered as :param entity: The workflow to be registered or the task to be registered :param default_launch_plan: This should be true if a default launch plan should be created for the workflow @@ -795,16 +823,16 @@ def register_script( if image_config is None: image_config = ImageConfig.auto_default_image() - upload_location, md5_bytes = fast_register_single_script( - source_path, - module_name, - functools.partial( - self.client.get_upload_signed_url, - project=project or self.default_project, - domain=domain or self.default_domain, - filename="scriptmode.tar.gz", - ), - ) + with tempfile.TemporaryDirectory() as tmp_dir: + if copy_all: + md5_bytes, upload_native_url = self.fast_package(pathlib.Path(source_path), False, tmp_dir) + else: + archive_fname = pathlib.Path(os.path.join(tmp_dir, "script_mode.tar.gz")) + compress_scripts(source_path, str(archive_fname), module_name) + md5_bytes, upload_native_url = self.upload_file( + archive_fname, project or self.default_project, domain or self.default_domain + ) + serialization_settings = SerializationSettings( project=project, domain=domain, @@ -813,7 +841,7 @@ def register_script( fast_serialization_settings=FastSerializationSettings( enabled=True, destination_dir=destination_dir, - distribution_location=upload_location.native_url, + distribution_location=upload_native_url, ), ) diff --git a/flytekit/tools/fast_registration.py b/flytekit/tools/fast_registration.py index b2f7efcc65..f115480112 100644 --- a/flytekit/tools/fast_registration.py +++ b/flytekit/tools/fast_registration.py @@ -12,6 +12,7 @@ import click from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.utils import timeit from flytekit.tools.ignore import DockerIgnore, GitIgnore, IgnoreGroup, StandardIgnore from flytekit.tools.script_mode import tar_strip_file_attributes @@ -97,6 +98,7 @@ def get_additional_distribution_loc(remote_location: str, identifier: str) -> st return posixpath.join(remote_location, "{}.{}".format(identifier, "tar.gz")) +@timeit("Download distribution") def download_distribution(additional_distribution: str, destination: str): """ Downloads a remote code distribution and overwrites any local files. diff --git a/flytekit/tools/repo.py b/flytekit/tools/repo.py index 3c9fe64068..82d4c2c226 100644 --- a/flytekit/tools/repo.py +++ b/flytekit/tools/repo.py @@ -37,7 +37,7 @@ def serialize( :param pkgs: Dot-delimited Python packages/subpackages to look into for serialization. :param local_source_root: Where to start looking for the code. """ - + settings.source_root = local_source_root ctx = FlyteContextManager.current_context().with_serialization_settings(settings) with FlyteContextManager.with_context(ctx) as ctx: # Scan all modules. the act of loading populates the global singleton that contains all objects @@ -60,6 +60,8 @@ def serialize_to_folder( """ Serialize the given set of python packages to a folder """ + if folder is None: + folder = "." 
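For the script-mode path itself, the reworked register_script shown above can be called straight off a FlyteRemote; a minimal sketch, assuming a @workflow object wf defined in a hypothetical module my_pkg.workflows under /path/to/project:

    from flytekit.configuration import Config
    from flytekit.remote import FlyteRemote

    from my_pkg.workflows import wf  # hypothetical workflow module

    remote = FlyteRemote(Config.auto(), default_project="flytesnacks", default_domain="development")
    flyte_wf = remote.register_script(
        wf,
        source_path="/path/to/project",
        module_name="my_pkg.workflows",
        copy_all=False,  # True fast-packages the entire source tree instead of only the imported modules
    )
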
loaded_entities = serialize(pkgs, settings, local_source_root, options=options) persist_registrable_entities(loaded_entities, folder) diff --git a/flytekit/tools/script_mode.py b/flytekit/tools/script_mode.py index 29b617824c..ecc71a2398 100644 --- a/flytekit/tools/script_mode.py +++ b/flytekit/tools/script_mode.py @@ -8,13 +8,12 @@ import typing from pathlib import Path -from flyteidl.service import dataproxy_pb2 as _data_proxy_pb2 - -from flytekit.core import context_manager +from flytekit import PythonFunctionTask from flytekit.core.tracker import get_full_module_path +from flytekit.core.workflow import ImperativeWorkflow, WorkflowBase -def compress_single_script(source_path: str, destination: str, full_module_name: str): +def compress_scripts(source_path: str, destination: str, module_name: str): """ Compresses the single script while maintaining the folder structure for that file. @@ -39,33 +38,14 @@ def compress_single_script(source_path: str, destination: str, full_module_name: │   ├── example.py │   └── __init__.py - Note how `another_example.py` and `yet_another_example.py` were not copied to the destination. + Note: If `example.py` didn't import tasks or workflows from `another_example.py` and `yet_another_example.py`, these files were not copied to the destination.. + """ with tempfile.TemporaryDirectory() as tmp_dir: destination_path = os.path.join(tmp_dir, "code") - # This is the script relative path to the root of the project - script_relative_path = Path() - # For each package in pkgs, create a directory and copy the __init__.py in it. - # Skip the last package as that is the script file. - pkgs = full_module_name.split(".") - for p in pkgs[:-1]: - os.makedirs(os.path.join(destination_path, p)) - source_path = os.path.join(source_path, p) - destination_path = os.path.join(destination_path, p) - script_relative_path = Path(script_relative_path, p) - init_file = Path(os.path.join(source_path, "__init__.py")) - if init_file.exists(): - shutil.copy(init_file, Path(os.path.join(tmp_dir, "code", script_relative_path, "__init__.py"))) - - # Ensure destination path exists to cover the case of a single file and no modules. - os.makedirs(destination_path, exist_ok=True) - script_file = Path(source_path, f"{pkgs[-1]}.py") - script_file_destination = Path(destination_path, f"{pkgs[-1]}.py") - # Build the final script relative path and copy it to a known place. - shutil.copy( - script_file, - script_file_destination, - ) + + visited: typing.List[str] = [] + copy_module_to_destination(source_path, destination_path, module_name, visited) tar_path = os.path.join(tmp_dir, "tmp.tar") with tarfile.open(tar_path, "w") as tar: tar.add(os.path.join(tmp_dir, "code"), arcname="", filter=tar_strip_file_attributes) @@ -74,6 +54,54 @@ def compress_single_script(source_path: str, destination: str, full_module_name: gzipped.write(tar_file.read()) +def copy_module_to_destination( + original_source_path: str, original_destination_path: str, module_name: str, visited: typing.List[str] +): + """ + Copy the module (file) to the destination directory. If the module relative imports other modules, flytekit will + recursively copy them as well. 
+ """ + mod = importlib.import_module(module_name) + full_module_name = get_full_module_path(mod, mod.__name__) + if full_module_name in visited: + return + visited.append(full_module_name) + + source_path = original_source_path + destination_path = original_destination_path + pkgs = full_module_name.split(".") + + for p in pkgs[:-1]: + os.makedirs(os.path.join(destination_path, p), exist_ok=True) + destination_path = os.path.join(destination_path, p) + source_path = os.path.join(source_path, p) + init_file = Path(os.path.join(source_path, "__init__.py")) + if init_file.exists(): + shutil.copy(init_file, Path(os.path.join(destination_path, "__init__.py"))) + + # Ensure destination path exists to cover the case of a single file and no modules. + os.makedirs(destination_path, exist_ok=True) + script_file = Path(source_path, f"{pkgs[-1]}.py") + script_file_destination = Path(destination_path, f"{pkgs[-1]}.py") + # Build the final script relative path and copy it to a known place. + shutil.copy( + script_file, + script_file_destination, + ) + + # Try to copy other files to destination if tasks or workflows aren't in the same file + for flyte_entity_name in mod.__dict__: + flyte_entity = mod.__dict__[flyte_entity_name] + if ( + isinstance(flyte_entity, (PythonFunctionTask, WorkflowBase)) + and not isinstance(flyte_entity, ImperativeWorkflow) + and flyte_entity.instantiated_in + ): + copy_module_to_destination( + original_source_path, original_destination_path, flyte_entity.instantiated_in, visited + ) + + # Takes in a TarInfo and returns the modified TarInfo: # https://docs.python.org/3/library/tarfile.html#tarinfo-objects # intented to be passed as a filter to tarfile.add @@ -96,24 +124,6 @@ def tar_strip_file_attributes(tar_info: tarfile.TarInfo) -> tarfile.TarInfo: return tar_info -def fast_register_single_script( - source_path: str, module_name: str, create_upload_location_fn: typing.Callable -) -> (_data_proxy_pb2.CreateUploadLocationResponse, bytes): - - # Open a temp directory and dump the contents of the digest. - with tempfile.TemporaryDirectory() as tmp_dir: - archive_fname = os.path.join(tmp_dir, "script_mode.tar.gz") - mod = importlib.import_module(module_name) - compress_single_script(source_path, archive_fname, get_full_module_path(mod, mod.__name__)) - - flyte_ctx = context_manager.FlyteContextManager.current_context() - md5, _ = hash_file(archive_fname) - upload_location = create_upload_location_fn(content_md5=md5) - flyte_ctx.file_access.put_data(archive_fname, upload_location.signed_url) - - return upload_location, md5 - - def hash_file(file_path: typing.Union[os.PathLike, str]) -> (bytes, str): """ Hash a file and produce a digest to be used as a version @@ -131,7 +141,7 @@ def hash_file(file_path: typing.Union[os.PathLike, str]) -> (bytes, str): return h.digest(), h.hexdigest() -def _find_project_root(source_path) -> Path: +def _find_project_root(source_path) -> str: """ Find the root of the project. 
The root of the project is considered to be the first ancestor from source_path that does @@ -143,4 +153,4 @@ def _find_project_root(source_path) -> Path: path = Path(source_path).parent.resolve() while os.path.exists(os.path.join(path, "__init__.py")): path = path.parent - return path + return str(path) diff --git a/flytekit/tools/serialize_helpers.py b/flytekit/tools/serialize_helpers.py index 69af2b96b4..86a029d411 100644 --- a/flytekit/tools/serialize_helpers.py +++ b/flytekit/tools/serialize_helpers.py @@ -10,12 +10,10 @@ from flytekit.core import context_manager as flyte_context from flytekit.core.base_task import PythonTask from flytekit.core.workflow import WorkflowBase -from flytekit.exceptions.user import FlyteValidationException from flytekit.models import launch_plan as _launch_plan_models from flytekit.models import task as task_models from flytekit.models.admin import workflow as admin_workflow_models from flytekit.models.admin.workflow import WorkflowSpec -from flytekit.models.core import identifier as _identifier from flytekit.models.task import TaskSpec from flytekit.remote.remote_callable import RemoteEntity from flytekit.tools.translator import FlyteControlPlaneEntity, Options, get_serializable @@ -44,20 +42,6 @@ def _should_register_with_admin(entity) -> bool: ) and not isinstance(entity, RemoteEntity) -def _find_duplicate_tasks(tasks: typing.List[task_models.TaskSpec]) -> typing.Set[task_models.TaskSpec]: - """ - Given a list of `TaskSpec`, this function returns a set containing the duplicated `TaskSpec` if any exists. - """ - seen: typing.Set[_identifier.Identifier] = set() - duplicate_tasks: typing.Set[task_models.TaskSpec] = set() - for task in tasks: - if task.template.id not in seen: - seen.add(task.template.id) - else: - duplicate_tasks.add(task) - return duplicate_tasks - - def get_registrable_entities( ctx: flyte_context.FlyteContext, options: typing.Optional[Options] = None ) -> typing.List[FlyteControlPlaneEntity]: @@ -78,19 +62,6 @@ def get_registrable_entities( new_api_model_values = list(new_api_serializable_entities.values()) entities_to_be_serialized = list(filter(_should_register_with_admin, new_api_model_values)) - serializable_tasks: typing.List[task_models.TaskSpec] = [ - entity for entity in entities_to_be_serialized if isinstance(entity, task_models.TaskSpec) - ] - # Detect if any of the tasks is duplicated. Duplicate tasks are defined as having the same - # metadata identifiers (see :py:class:`flytekit.common.core.identifier.Identifier`). Duplicate - # tasks are considered invalid at registration - # time and usually indicate user error, so we catch this common mistake at serialization time. 
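The recursive copy above means a script-mode archive now follows imports between entity modules rather than shipping a single file; a minimal sketch of the intended behaviour, assuming a hypothetical layout where example.py calls a task defined in helpers.py:

    # my_project/
    # └── workflows/
    #     ├── __init__.py
    #     ├── example.py   # @workflow that calls a task imported from helpers.py
    #     └── helpers.py   # defines the @task used by example.py
    from flytekit.tools.script_mode import compress_scripts

    compress_scripts(
        source_path="/path/to/my_project",
        destination="/tmp/script_mode.tar.gz",
        module_name="workflows.example",
    )
    # Because the workflow in example.py was instantiated against a task from workflows.helpers,
    # both modules (plus their __init__.py files) are expected to land in the archive.
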
- duplicate_tasks = _find_duplicate_tasks(serializable_tasks) - if len(duplicate_tasks) > 0: - duplicate_task_names = [task.template.id.name for task in duplicate_tasks] - raise FlyteValidationException( - f"Multiple definitions of the following tasks were found: {duplicate_task_names}" - ) return entities_to_be_serialized diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index 5ec249fa4b..b2835dca10 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -9,6 +9,7 @@ from flytekit.core import constants as _common_constants from flytekit.core.base_task import PythonTask from flytekit.core.condition import BranchNode +from flytekit.core.container_task import ContainerTask from flytekit.core.gate import Gate from flytekit.core.launch_plan import LaunchPlan, ReferenceLaunchPlan from flytekit.core.map_task import MapPythonTask @@ -189,7 +190,7 @@ def get_serializable_task( # If the pod spec is not None, we have to get it again, because the one we retrieved above will be incorrect. # The reason we have to call get_k8s_pod again, instead of just modifying the command in this file, is because # the pod spec is a K8s library object, and we shouldn't be messing around with it in this file. - elif pod: + elif pod and not isinstance(entity, ContainerTask): if isinstance(entity, MapPythonTask): entity.set_command_prefix(get_command_prefix_for_fast_execute(settings)) pod = entity.get_k8s_pod(settings) @@ -662,11 +663,6 @@ def get_serializable( elif isinstance(entity, BranchNode): cp_entity = get_serializable_branch_node(entity_mapping, settings, entity, options) - elif isinstance(entity, GateNode): - import ipdb - - ipdb.set_trace() - elif isinstance(entity, FlyteTask) or isinstance(entity, FlyteWorkflow): if entity.should_register: if isinstance(entity, FlyteTask): diff --git a/flytekit/types/directory/__init__.py b/flytekit/types/directory/__init__.py index c2ab8fd438..87b494d0ae 100644 --- a/flytekit/types/directory/__init__.py +++ b/flytekit/types/directory/__init__.py @@ -28,7 +28,7 @@ TensorBoard. """ -tfrecords_dir = typing.TypeVar("tfrecord") +tfrecords_dir = typing.TypeVar("tfrecords_dir") TFRecordsDirectory = FlyteDirectory[tfrecords_dir] """ This type can be used to denote that the output is a folder that contains tensorflow record files. diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index afb59d58d0..f4f23eb72f 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -2,20 +2,25 @@ import os import pathlib +import random import typing from dataclasses import dataclass, field from pathlib import Path +from typing import Any, Generator, Tuple +from uuid import UUID +import fsspec from dataclasses_json import config, dataclass_json +from fsspec.utils import get_protocol from marshmallow import fields -from flytekit.core.context_manager import FlyteContext +from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import TypeEngine, TypeTransformer from flytekit.models import types as _type_models from flytekit.models.core import types as _core_types from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType -from flytekit.types.file import FileExt +from flytekit.types.file import FileExt, FlyteFile T = typing.TypeVar("T") PathType = typing.Union[str, os.PathLike] @@ -115,7 +120,12 @@ def t1(in1: FlyteDirectory["svg"]): field in the ``BlobType``. 
""" - def __init__(self, path: typing.Union[str, os.PathLike], downloader: typing.Callable = None, remote_directory=None): + def __init__( + self, + path: typing.Union[str, os.PathLike], + downloader: typing.Optional[typing.Callable] = None, + remote_directory: typing.Optional[str] = None, + ): """ :param path: The source path that users are expected to call open() on :param downloader: Optional function that can be passed that used to delay downloading of the actual fil @@ -143,6 +153,18 @@ def __fspath__(self): def extension(cls) -> str: return "" + @classmethod + def new_remote(cls) -> FlyteDirectory: + """ + Create a new FlyteDirectory object using the currently configured default remote in the context (i.e. + the raw_output_prefix configured in the current FileAccessProvider object in the context). + This is used if you explicitly have a folder somewhere that you want to create files under. + If you want to write a whole folder, you can let your task return a FlyteDirectory object, + and let flytekit handle the uploading. + """ + d = FlyteContext.current_context().file_access.get_random_remote_directory() + return FlyteDirectory(path=d) + def __class_getitem__(cls, item: typing.Union[typing.Type, str]) -> typing.Type[FlyteDirectory]: if item is None: return cls @@ -171,6 +193,12 @@ def downloaded(self) -> bool: def remote_directory(self) -> typing.Optional[str]: return self._remote_directory + @property + def sep(self) -> str: + if os.name == "nt" and get_protocol(self.path or self.remote_source or self.remote_directory) == "file": + return "\\" + return "/" + @property def remote_source(self) -> str: """ @@ -179,9 +207,67 @@ def remote_source(self) -> str: """ return typing.cast(str, self._remote_source) + def new_file(self, name: typing.Optional[str] = None) -> FlyteFile: + """ + This will create a new file under the current folder. + If given a name, it will use the name given, otherwise it'll pick a random string. + Collisions are not checked. + """ + # TODO we may want to use - https://github.com/fsspec/universal_pathlib + if not name: + name = UUID(int=random.getrandbits(128)).hex + new_path = self.sep.join([str(self.path).rstrip(self.sep), name]) # trim trailing sep if any and join + return FlyteFile(path=new_path) + + def new_dir(self, name: typing.Optional[str] = None) -> FlyteDirectory: + """ + This will create a new folder under the current folder. + If given a name, it will use the name given, otherwise it'll pick a random string. + Collisions are not checked. + """ + if not name: + name = UUID(int=random.getrandbits(128)).hex + + new_path = self.sep.join([str(self.path).rstrip(self.sep), name]) # trim trailing sep if any and join + return FlyteDirectory(path=new_path) + def download(self) -> str: return self.__fspath__() + def crawl( + self, maxdepth: typing.Optional[int] = None, topdown: bool = True, **kwargs + ) -> Generator[Tuple[typing.Union[str, os.PathLike[Any]], typing.Dict[Any, Any]], None, None]: + """ + Crawl returns a generator of all files prefixed by any sub-folders under the given "FlyteDirectory". + if details=True is passed, then it will return a dictionary as specified by fsspec. 
+ + Example: + + >>> list(fd.crawl()) + [("/base", "file1"), ("/base", "dir1/file1"), ("/base", "dir2/file1"), ("/base", "dir1/dir/file1")] + + >>> list(x.crawl(detail=True)) + [('/tmp/test', {'my-dir/ab.py': {'name': '/tmp/test/my-dir/ab.py', 'size': 0, 'type': 'file', + 'created': 1677720780.2318847, 'islink': False, 'mode': 33188, 'uid': 501, 'gid': 0, + 'mtime': 1677720780.2317934, 'ino': 1694329, 'nlink': 1}})] + """ + final_path = self.path + if self.remote_source: + final_path = self.remote_source + elif self.remote_directory: + final_path = self.remote_directory + ctx = FlyteContextManager.current_context() + fs = ctx.file_access.get_filesystem_for_path(final_path) + base_path_len = len(fsspec.core.strip_protocol(final_path)) + 1 # Add additional `/` at the end + for base, _, files in fs.walk(final_path, maxdepth, topdown, **kwargs): + current_base = base[base_path_len:] + if isinstance(files, dict): + for f, v in files.items(): + yield final_path, {os.path.join(current_base, f): v} + else: + for f in files: + yield final_path, os.path.join(current_base, f) + def __repr__(self): return self.path diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index 9fc55f76ce..bb8feb3d9c 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -3,12 +3,14 @@ import os import pathlib import typing +from contextlib import contextmanager from dataclasses import dataclass, field from dataclasses_json import config, dataclass_json from marshmallow import fields +from typing_extensions import Annotated, get_args, get_origin -from flytekit.core.context_manager import FlyteContext +from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError from flytekit.loggers import logger from flytekit.models.core.types import BlobType @@ -27,7 +29,9 @@ def noop(): @dataclass_json @dataclass class FlyteFile(os.PathLike, typing.Generic[T]): - path: typing.Union[str, os.PathLike] = field(default=None, metadata=config(mm_field=fields.String())) # type: ignore + path: typing.Union[str, os.PathLike] = field( + default=None, metadata=config(mm_field=fields.String()) + ) # type: ignore """ Since there is no native Python implementation of files and directories for the Flyte Blob type, (like how int exists for Flyte's Integer type) we need to create one so that users can express that their tasks take @@ -148,6 +152,15 @@ def t2() -> flytekit_typing.FlyteFile["csv"]: def extension(cls) -> str: return "" + @classmethod + def new_remote_file(cls, name: typing.Optional[str] = None) -> FlyteFile: + """ + Create a new FlyteFile object with a remote path. + """ + ctx = FlyteContextManager.current_context() + remote_path = ctx.file_access.get_random_remote_path(name) + return cls(path=remote_path) + def __class_getitem__(cls, item: typing.Union[str, typing.Type]) -> typing.Type[FlyteFile]: from . import FileExt @@ -226,6 +239,57 @@ def remote_source(self) -> str: def download(self) -> str: return self.__fspath__() + @contextmanager + def open( + self, + mode: str, + cache_type: typing.Optional[str] = None, + cache_options: typing.Optional[typing.Dict[str, typing.Any]] = None, + ): + """ + Returns a streaming File handle + + .. 
code-block:: python + + @task + def copy_file(ff: FlyteFile) -> FlyteFile: + new_file = FlyteFile.new_remote_file(ff.name) + with ff.open("rb", cache_type="readahead", cache={}) as r: + with new_file.open("wb") as w: + w.write(r.read()) + return new_file + + Alternatively + + .. code-block:: python + + @task + def copy_file(ff: FlyteFile) -> FlyteFile: + new_file = FlyteFile.new_remote_file(ff.name) + with fsspec.open(f"readahead::{ff.remote_path}", "rb", readahead={}) as r: + with new_file.open("wb") as w: + w.write(r.read()) + return new_file + + + :param mode: str Open mode like 'rb', 'rt', 'wb', ... + :param cache_type: optional str Specify if caching is to be used. Cache protocol can be ones supported by + fsspec https://filesystem-spec.readthedocs.io/en/latest/api.html#readbuffering, + especially useful for large file reads + :param cache_options: optional Dict[str, Any] Refer to fsspec caching options. This is strongly coupled to the + cache_protocol + """ + ctx = FlyteContextManager.current_context() + final_path = self.path + if self.remote_source: + final_path = self.remote_source + elif self.remote_path: + final_path = self.remote_path + fs = ctx.file_access.get_filesystem_for_path(final_path) + f = fs.open(final_path, mode, cache_type=cache_type, cache_options=cache_options) + yield f + f.close() + def __repr__(self): return self.path @@ -272,6 +336,10 @@ def to_literal( if python_val is None: raise TypeTransformerFailedError("None value cannot be converted to a file.") + # Correctly handle `Annotated[FlyteFile, ...]` by extracting the origin type + if get_origin(python_type) is Annotated: + python_type = get_args(python_type)[0] + if not (python_type is os.PathLike or issubclass(python_type, FlyteFile)): raise ValueError(f"Incorrect type {python_type}, must be either a FlyteFile or os.PathLike") @@ -346,13 +414,13 @@ def to_python_value( return FlyteFile(uri) # The rest of the logic is only for FlyteFile types. - if not issubclass(expected_python_type, FlyteFile): + if not issubclass(expected_python_type, FlyteFile): # type: ignore raise TypeError(f"Neither os.PathLike nor FlyteFile specified {expected_python_type}") # This is a local file path, like /usr/local/my_file, don't mess with it. Certainly, downloading it doesn't # make any sense. 
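Taken together, the FlyteDirectory and FlyteFile helpers introduced above can stage outputs without a local temporary directory; a minimal sketch, assuming a remote raw output prefix (for example s3) is configured in the current context:

    from flytekit import task
    from flytekit.types.directory import FlyteDirectory

    @task
    def write_report() -> FlyteDirectory:
        out = FlyteDirectory.new_remote()    # rooted at the configured raw_output_prefix
        daily = out.new_dir("2023-01-01")    # name is optional; a random hex name is used if omitted
        summary = daily.new_file("summary.txt")
        with summary.open("w") as fh:        # streaming handle via the new FlyteFile.open()
            fh.write("hello")
        return out
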
if not ctx.file_access.is_remote(uri): - return expected_python_type(uri) + return expected_python_type(uri) # type: ignore # For the remote case, return an FlyteFile object that can download local_path = ctx.file_access.get_random_local_path(uri) diff --git a/flytekit/types/numpy/ndarray.py b/flytekit/types/numpy/ndarray.py index 38fedfacca..d766818bfd 100644 --- a/flytekit/types/numpy/ndarray.py +++ b/flytekit/types/numpy/ndarray.py @@ -77,7 +77,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: return np.load( file=local_path, allow_pickle=metadata.get("allow_pickle", False), - mmap_mode=metadata.get("mmap_mode"), + mmap_mode=metadata.get("mmap_mode"), # type: ignore ) def guess_python_type(self, literal_type: LiteralType) -> typing.Type[np.ndarray]: diff --git a/flytekit/types/pickle/__init__.py b/flytekit/types/pickle/__init__.py index 65604e67bb..e5bd1c056d 100644 --- a/flytekit/types/pickle/__init__.py +++ b/flytekit/types/pickle/__init__.py @@ -9,4 +9,4 @@ FlytePickle """ -from .pickle import FlytePickle +from .pickle import BatchSize, FlytePickle diff --git a/flytekit/types/pickle/pickle.py b/flytekit/types/pickle/pickle.py index 3472dec7e6..3de75b765b 100644 --- a/flytekit/types/pickle/pickle.py +++ b/flytekit/types/pickle/pickle.py @@ -13,6 +13,19 @@ T = typing.TypeVar("T") +class BatchSize: + """ + Flyte-specific object used to wrap the hash function for a specific type + """ + + def __init__(self, val: int): + self._val = val + + @property + def val(self) -> int: + return self._val + + class FlytePickle(typing.Generic[T]): """ This type is only used by flytekit internally. User should not use this type. diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index 8a8d832b58..ac6b71ba38 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -100,32 +100,38 @@ def write(self, *dfs, **kwargs): class LocalIOSchemaReader(SchemaReader[T]): - def __init__(self, from_path: os.PathLike, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): - super().__init__(str(from_path), cols, fmt) + def __init__(self, from_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): + super().__init__(from_path, cols, fmt) @abstractmethod def _read(self, *path: os.PathLike, **kwargs) -> T: pass def iter(self, **kwargs) -> typing.Generator[T, None, None]: - with os.scandir(self._from_path) as it: + with os.scandir(self._from_path) as it: # type: ignore for entry in it: - if not entry.name.startswith(".") and entry.is_file(): - yield self._read(Path(entry.path), **kwargs) + if ( + not typing.cast(os.DirEntry, entry).name.startswith(".") + and typing.cast(os.DirEntry, entry).is_file() + ): + yield self._read(Path(typing.cast(os.DirEntry, entry).path), **kwargs) def all(self, **kwargs) -> T: files: typing.List[os.PathLike] = [] - with os.scandir(self._from_path) as it: + with os.scandir(self._from_path) as it: # type: ignore for entry in it: - if not entry.name.startswith(".") and entry.is_file(): - files.append(Path(entry.path)) + if ( + not typing.cast(os.DirEntry, entry).name.startswith(".") + and typing.cast(os.DirEntry, entry).is_file() + ): + files.append(Path(typing.cast(os.DirEntry, entry).path)) return self._read(*files, **kwargs) class LocalIOSchemaWriter(SchemaWriter[T]): - def __init__(self, to_local_path: os.PathLike, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): - super().__init__(str(to_local_path), cols, fmt) + def __init__(self, to_local_path: str, 
cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): + super().__init__(to_local_path, cols, fmt) @abstractmethod def _write(self, df: T, path: os.PathLike, **kwargs): @@ -176,11 +182,10 @@ def get_handler(cls, t: Type) -> SchemaHandler: @dataclass_json @dataclass class FlyteSchema(object): - remote_path: typing.Optional[os.PathLike] = field(default=None, metadata=config(mm_field=fields.String())) + remote_path: typing.Optional[str] = field(default=None, metadata=config(mm_field=fields.String())) """ This is the main schema class that users should use. """ - logger.warning("FlyteSchema is deprecated, use Structured Dataset instead.") @classmethod def columns(cls) -> typing.Dict[str, typing.Type]: @@ -197,6 +202,7 @@ def format(cls) -> SchemaFormat: def __class_getitem__( cls, columns: typing.Dict[str, typing.Type], fmt: SchemaFormat = SchemaFormat.PARQUET ) -> Type[FlyteSchema]: + logger.warning("FlyteSchema is deprecated, use Structured Dataset instead.") if columns is None: return FlyteSchema @@ -229,11 +235,12 @@ def format(cls) -> SchemaFormat: def __init__( self, - local_path: os.PathLike = None, - remote_path: os.PathLike = None, + local_path: typing.Optional[str] = None, + remote_path: typing.Optional[str] = None, supported_mode: SchemaOpenMode = SchemaOpenMode.WRITE, - downloader: typing.Callable[[str, os.PathLike], None] = None, + downloader: typing.Optional[typing.Callable] = None, ): + logger.warning("FlyteSchema is deprecated, use Structured Dataset instead.") if supported_mode == SchemaOpenMode.READ and remote_path is None: raise ValueError("To create a FlyteSchema in read mode, remote_path is required") if ( @@ -254,7 +261,7 @@ def __init__( self._downloader = downloader @property - def local_path(self) -> os.PathLike: + def local_path(self) -> str: return self._local_path @property @@ -262,7 +269,7 @@ def supported_mode(self) -> SchemaOpenMode: return self._supported_mode def open( - self, dataframe_fmt: type = pandas.DataFrame, override_mode: SchemaOpenMode = None + self, dataframe_fmt: type = pandas.DataFrame, override_mode: typing.Optional[SchemaOpenMode] = None ) -> typing.Union[SchemaReader, SchemaWriter]: """ Returns a reader or writer depending on the mode of the object when created. This mode can be @@ -290,13 +297,13 @@ def open( self._downloader(self.remote_path, self.local_path) self._downloaded = True if mode == SchemaOpenMode.WRITE: - return h.writer(typing.cast(str, self.local_path), self.columns(), self.format()) - return h.reader(typing.cast(str, self.local_path), self.columns(), self.format()) + return h.writer(self.local_path, self.columns(), self.format()) + return h.reader(self.local_path, self.columns(), self.format()) # Remote IO is handled. 
So we will just pass the remote reference to the object if mode == SchemaOpenMode.WRITE: - return h.writer(self.remote_path, self.columns(), self.format()) - return h.reader(self.remote_path, self.columns(), self.format()) + return h.writer(typing.cast(str, self.remote_path), self.columns(), self.format()) + return h.reader(typing.cast(str, self.remote_path), self.columns(), self.format()) def as_readonly(self) -> FlyteSchema: if self._supported_mode == SchemaOpenMode.READ: @@ -304,7 +311,7 @@ def as_readonly(self) -> FlyteSchema: s = FlyteSchema.__class_getitem__(self.columns(), self.format())( local_path=self.local_path, # Dummy path is ok, as we will assume data is already downloaded and will not download again - remote_path=self.remote_path if self.remote_path else "", + remote_path=typing.cast(str, self.remote_path) if self.remote_path else "", supported_mode=SchemaOpenMode.READ, ) s._downloaded = True diff --git a/flytekit/types/schema/types_pandas.py b/flytekit/types/schema/types_pandas.py index e4c6078e94..ca6cab8030 100644 --- a/flytekit/types/schema/types_pandas.py +++ b/flytekit/types/schema/types_pandas.py @@ -17,7 +17,9 @@ class ParquetIO(object): def _read(self, chunk: os.PathLike, columns: typing.Optional[typing.List[str]], **kwargs) -> pandas.DataFrame: return pandas.read_parquet(chunk, columns=columns, engine=self.PARQUET_ENGINE, **kwargs) - def read(self, *files: os.PathLike, columns: typing.List[str] = None, **kwargs) -> pandas.DataFrame: + def read( + self, *files: os.PathLike, columns: typing.Optional[typing.List[str]] = None, **kwargs + ) -> pandas.DataFrame: frames = [self._read(chunk=f, columns=columns, **kwargs) for f in files if os.path.getsize(f) > 0] if len(frames) == 1: return frames[0] @@ -56,7 +58,7 @@ def write( class PandasSchemaReader(LocalIOSchemaReader[pandas.DataFrame]): - def __init__(self, local_dir: os.PathLike, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): + def __init__(self, local_dir: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): super().__init__(local_dir, cols, fmt) self._parquet_engine = ParquetIO() @@ -65,7 +67,7 @@ def _read(self, *path: os.PathLike, **kwargs) -> pandas.DataFrame: class PandasSchemaWriter(LocalIOSchemaWriter[pandas.DataFrame]): - def __init__(self, local_dir: os.PathLike, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): + def __init__(self, local_dir: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): super().__init__(local_dir, cols, fmt) self._parquet_engine = ParquetIO() diff --git a/flytekit/types/structured/__init__.py b/flytekit/types/structured/__init__.py index 52577a650d..86fa19f4f0 100644 --- a/flytekit/types/structured/__init__.py +++ b/flytekit/types/structured/__init__.py @@ -13,15 +13,9 @@ """ -from flytekit.configuration.internal import LocalSDK +from flytekit.deck.renderer import ArrowRenderer, TopFrameRenderer from flytekit.loggers import logger -from .basic_dfs import ( - ArrowToParquetEncodingHandler, - PandasToParquetEncodingHandler, - ParquetToArrowDecodingHandler, - ParquetToPandasDecodingHandler, -) from .structured_dataset import ( StructuredDataset, StructuredDatasetDecoder, @@ -29,15 +23,42 @@ StructuredDatasetTransformerEngine, ) -try: - from .bigquery import ( - ArrowToBQEncodingHandlers, - BQToArrowDecodingHandler, - BQToPandasDecodingHandler, - PandasToBQEncodingHandlers, - ) -except ImportError: - logger.info( - "We won't register bigquery handler for structured dataset because " - "we can't find the packages 
google-cloud-bigquery-storage and google-cloud-bigquery" - ) + +def register_pandas_handlers(): + import pandas as pd + + from .basic_dfs import PandasToParquetEncodingHandler, ParquetToPandasDecodingHandler + + StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(), default_format_for_type=True) + StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(), default_format_for_type=True) + StructuredDatasetTransformerEngine.register_renderer(pd.DataFrame, TopFrameRenderer()) + + +def register_arrow_handlers(): + import pyarrow as pa + + from .basic_dfs import ArrowToParquetEncodingHandler, ParquetToArrowDecodingHandler + + StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(), default_format_for_type=True) + StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(), default_format_for_type=True) + StructuredDatasetTransformerEngine.register_renderer(pa.Table, ArrowRenderer()) + + +def register_bigquery_handlers(): + try: + from .bigquery import ( + ArrowToBQEncodingHandlers, + BQToArrowDecodingHandler, + BQToPandasDecodingHandler, + PandasToBQEncodingHandlers, + ) + + StructuredDatasetTransformerEngine.register(PandasToBQEncodingHandlers()) + StructuredDatasetTransformerEngine.register(BQToPandasDecodingHandler()) + StructuredDatasetTransformerEngine.register(ArrowToBQEncodingHandlers()) + StructuredDatasetTransformerEngine.register(BQToArrowDecodingHandler()) + except ImportError: + logger.info( + "We won't register bigquery handler for structured dataset because " + "we can't find the packages google-cloud-bigquery-storage and google-cloud-bigquery" + ) diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py index 39f8d11e24..8004867271 100644 --- a/flytekit/types/structured/basic_dfs.py +++ b/flytekit/types/structured/basic_dfs.py @@ -1,14 +1,18 @@ import os import typing +from pathlib import Path from typing import TypeVar import pandas as pd import pyarrow as pa import pyarrow.parquet as pq +from botocore.exceptions import NoCredentialsError +from fsspec.core import split_protocol, strip_protocol +from fsspec.utils import get_protocol -from flytekit import FlyteContext -from flytekit.deck import TopFrameRenderer -from flytekit.deck.renderer import ArrowRenderer +from flytekit import FlyteContext, logger +from flytekit.configuration import DataConfig +from flytekit.core.data_persistence import s3_setup_args from flytekit.models import literals from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import StructuredDatasetType @@ -17,12 +21,20 @@ StructuredDataset, StructuredDatasetDecoder, StructuredDatasetEncoder, - StructuredDatasetTransformerEngine, ) T = TypeVar("T") +def get_storage_options(cfg: DataConfig, uri: str, anon: bool = False) -> typing.Optional[typing.Dict]: + protocol = get_protocol(uri) + if protocol == "s3": + kwargs = s3_setup_args(cfg.s3, anon) + if kwargs: + return kwargs + return None + + class PandasToParquetEncodingHandler(StructuredDatasetEncoder): def __init__(self): super().__init__(pd.DataFrame, None, PARQUET) @@ -33,15 +45,19 @@ def encode( structured_dataset: StructuredDataset, structured_dataset_type: StructuredDatasetType, ) -> literals.StructuredDataset: - - path = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory() + uri = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory() + if not ctx.file_access.is_remote(uri): + 
Path(uri).mkdir(parents=True, exist_ok=True) + path = os.path.join(uri, f"{0:05}") df = typing.cast(pd.DataFrame, structured_dataset.dataframe) - local_dir = ctx.file_access.get_random_local_directory() - local_path = os.path.join(local_dir, f"{0:05}") - df.to_parquet(local_path, coerce_timestamps="us", allow_truncated_timestamps=False) - ctx.file_access.upload_directory(local_dir, path) + df.to_parquet( + path, + coerce_timestamps="us", + allow_truncated_timestamps=False, + storage_options=get_storage_options(ctx.file_access.data_config, path), + ) structured_dataset_type.format = PARQUET - return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type)) + return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type)) class ParquetToPandasDecodingHandler(StructuredDatasetDecoder): @@ -54,13 +70,17 @@ def decode( flyte_value: literals.StructuredDataset, current_task_metadata: StructuredDatasetMetadata, ) -> pd.DataFrame: - path = flyte_value.uri - local_dir = ctx.file_access.get_random_local_directory() - ctx.file_access.get_data(path, local_dir, is_multipart=True) + uri = flyte_value.uri + columns = None + kwargs = get_storage_options(ctx.file_access.data_config, uri) if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] - return pd.read_parquet(local_dir, columns=columns) - return pd.read_parquet(local_dir) + try: + return pd.read_parquet(uri, columns=columns, storage_options=kwargs) + except NoCredentialsError: + logger.debug("S3 source detected, attempting anonymous S3 access") + kwargs = get_storage_options(ctx.file_access.data_config, uri, anon=True) + return pd.read_parquet(uri, columns=columns, storage_options=kwargs) class ArrowToParquetEncodingHandler(StructuredDatasetEncoder): @@ -73,13 +93,13 @@ def encode( structured_dataset: StructuredDataset, structured_dataset_type: StructuredDatasetType, ) -> literals.StructuredDataset: - path = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_path() - df = structured_dataset.dataframe - local_dir = ctx.file_access.get_random_local_directory() - local_path = os.path.join(local_dir, f"{0:05}") - pq.write_table(df, local_path) - ctx.file_access.upload_directory(local_dir, path) - return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type)) + uri = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory() + if not ctx.file_access.is_remote(uri): + Path(uri).mkdir(parents=True, exist_ok=True) + path = os.path.join(uri, f"{0:05}") + filesystem = ctx.file_access.get_filesystem_for_path(path) + pq.write_table(structured_dataset.dataframe, strip_protocol(path), filesystem=filesystem) + return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type)) class ParquetToArrowDecodingHandler(StructuredDatasetDecoder): @@ -92,19 +112,20 @@ def decode( flyte_value: literals.StructuredDataset, current_task_metadata: StructuredDatasetMetadata, ) -> pa.Table: - path = flyte_value.uri - local_dir = ctx.file_access.get_random_local_directory() - ctx.file_access.get_data(path, local_dir, is_multipart=True) + uri = flyte_value.uri + if not ctx.file_access.is_remote(uri): + Path(uri).parent.mkdir(parents=True, exist_ok=True) + _, path = split_protocol(uri) + + columns = None if 
current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] - return pq.read_table(local_dir, columns=columns) - return pq.read_table(local_dir) - - -StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(), default_format_for_type=True) -StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(), default_format_for_type=True) -StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(), default_format_for_type=True) -StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(), default_format_for_type=True) - -StructuredDatasetTransformerEngine.register_renderer(pd.DataFrame, TopFrameRenderer()) -StructuredDatasetTransformerEngine.register_renderer(pa.Table, ArrowRenderer()) + try: + fs = ctx.file_access.get_filesystem_for_path(uri) + return pq.read_table(path, filesystem=fs, columns=columns) + except NoCredentialsError as e: + logger.debug("S3 source detected, attempting anonymous S3 access") + fs = ctx.file_access.get_filesystem_for_path(uri, anonymous=True) + if fs is not None: + return pq.read_table(path, filesystem=fs, columns=columns) + raise e diff --git a/flytekit/types/structured/bigquery.py b/flytekit/types/structured/bigquery.py index 85cede1544..049a21c07e 100644 --- a/flytekit/types/structured/bigquery.py +++ b/flytekit/types/structured/bigquery.py @@ -14,7 +14,6 @@ StructuredDatasetDecoder, StructuredDatasetEncoder, StructuredDatasetMetadata, - StructuredDatasetTransformerEngine, ) BIGQUERY = "bq" @@ -110,9 +109,3 @@ def decode( current_task_metadata: StructuredDatasetMetadata, ) -> pa.Table: return pa.Table.from_pandas(_read_from_bq(flyte_value, current_task_metadata)) - - -StructuredDatasetTransformerEngine.register(PandasToBQEncodingHandlers()) -StructuredDatasetTransformerEngine.register(BQToPandasDecodingHandler()) -StructuredDatasetTransformerEngine.register(ArrowToBQEncodingHandlers()) -StructuredDatasetTransformerEngine.register(BQToArrowDecodingHandler()) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 0e4649203a..05df91776c 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -1,7 +1,6 @@ from __future__ import annotations import collections -import os import types import typing from abc import ABC, abstractmethod @@ -9,15 +8,13 @@ from typing import Dict, Generator, Optional, Type, Union import _datetime -import numpy as _np -import pandas as pd -import pyarrow as pa from dataclasses_json import config, dataclass_json +from fsspec.utils import get_protocol from marshmallow import fields from typing_extensions import Annotated, TypeAlias, get_args, get_origin +from flytekit import lazy_module from flytekit.core.context_manager import FlyteContext, FlyteContextManager -from flytekit.core.data_persistence import DataPersistencePlugins, DiskPersistence from flytekit.core.type_engine import TypeEngine, TypeTransformer from flytekit.deck.renderer import Renderable from flytekit.loggers import logger @@ -26,6 +23,13 @@ from flytekit.models.literals import Literal, Scalar, StructuredDatasetMetadata from flytekit.models.types import LiteralType, SchemaType, StructuredDatasetType +if typing.TYPE_CHECKING: + import pandas as pd + import pyarrow as pa +else: + pd = lazy_module("pandas") + pa = lazy_module("pyarrow") + T = typing.TypeVar("T") # 
StructuredDataset type or a dataframe type DF = typing.TypeVar("DF") # Dataframe type @@ -35,6 +39,7 @@ # Storage formats PARQUET: StructuredDatasetFormat = "parquet" GENERIC_FORMAT: StructuredDatasetFormat = "" +GENERIC_PROTOCOL: str = "generic protocol" @dataclass_json @@ -45,7 +50,7 @@ class StructuredDataset(object): class (that is just a model, a Python class representation of the protobuf). """ - uri: typing.Optional[os.PathLike] = field(default=None, metadata=config(mm_field=fields.String())) + uri: typing.Optional[str] = field(default=None, metadata=config(mm_field=fields.String())) file_format: typing.Optional[str] = field(default=GENERIC_FORMAT, metadata=config(mm_field=fields.String())) @classmethod @@ -59,7 +64,7 @@ def column_names(cls) -> typing.List[str]: def __init__( self, dataframe: typing.Optional[typing.Any] = None, - uri: Optional[str, os.PathLike] = None, + uri: typing.Optional[str] = None, metadata: typing.Optional[literals.StructuredDatasetMetadata] = None, **kwargs, ): @@ -74,10 +79,11 @@ def __init__( # This is not for users to set, the transformer will set this. self._literal_sd: Optional[literals.StructuredDataset] = None # Not meant for users to set, will be set by an open() call - self._dataframe_type: Optional[Type[DF]] = None + self._dataframe_type: Optional[DF] = None # type: ignore + self._already_uploaded = False @property - def dataframe(self) -> Optional[Type[DF]]: + def dataframe(self) -> Optional[DF]: return self._dataframe @property @@ -92,7 +98,7 @@ def open(self, dataframe_type: Type[DF]): self._dataframe_type = dataframe_type return self - def all(self) -> DF: + def all(self) -> DF: # type: ignore if self._dataframe_type is None: raise ValueError("No dataframe type set. Use open() to set the local dataframe type you want to use.") ctx = FlyteContextManager.current_context() @@ -109,7 +115,7 @@ def iter(self) -> Generator[DF, None, None]: def extract_cols_and_format( t: typing.Any, -) -> typing.Tuple[Type[T], Optional[typing.OrderedDict[str, Type]], Optional[str], Optional[pa.lib.Schema]]: +) -> typing.Tuple[Type[T], Optional[typing.OrderedDict[str, Type]], Optional[str], Optional["pa.lib.Schema"]]: """ Helper function, just used to iterate through Annotations and extract out the following information: - base type, if not Annotated, it will just be the type that was passed in. @@ -143,7 +149,7 @@ def extract_cols_and_format( if ordered_dict_cols is not None: raise ValueError(f"Column information was already found {ordered_dict_cols}, cannot use {aa}") ordered_dict_cols = aa - elif isinstance(aa, pa.Schema): + elif isinstance(aa, pa.lib.Schema): if pa_schema is not None: raise ValueError(f"Arrow schema was already found {pa_schema}, cannot use {aa}") pa_schema = aa @@ -255,7 +261,7 @@ def decode( ctx: FlyteContext, flyte_value: literals.StructuredDataset, current_task_metadata: StructuredDatasetMetadata, - ) -> Union[DF, Generator[DF, None, None]]: + ) -> Union[DF, typing.Iterator[DF]]: """ This is code that will be called by the dataset transformer engine to ultimately translate from a Flyte Literal value into a Python instance. 
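With pandas and pyarrow now lazily imported, third-party dataframe handlers plug into the same engine; a minimal sketch of a custom encoder for a hypothetical MyFrame wrapper (the names and the storage layout are illustrative, not part of this change):

    import pandas as pd

    from flytekit.models import literals
    from flytekit.models.literals import StructuredDatasetMetadata
    from flytekit.types.structured.structured_dataset import (
        PARQUET,
        StructuredDatasetEncoder,
        StructuredDatasetTransformerEngine,
    )

    class MyFrame:  # hypothetical dataframe wrapper
        def __init__(self, df: pd.DataFrame):
            self.df = df

    class MyFrameToParquetEncoder(StructuredDatasetEncoder):
        def __init__(self):
            # protocol=None: with this change the handler is registered once under the generic
            # "fsspec" protocol instead of once per persistence plugin.
            super().__init__(MyFrame, None, PARQUET)

        def encode(self, ctx, structured_dataset, structured_dataset_type) -> literals.StructuredDataset:
            uri = structured_dataset.uri or ctx.file_access.get_random_remote_directory()
            structured_dataset.dataframe.df.to_parquet(uri)
            return literals.StructuredDataset(
                uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type)
            )

    StructuredDatasetTransformerEngine.register(MyFrameToParquetEncoder())
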
@@ -271,11 +277,6 @@ def decode( raise NotImplementedError -def protocol_prefix(uri: str) -> str: - p = DataPersistencePlugins.get_protocol(uri) - return p - - def convert_schema_type_to_structured_dataset_type( column_type: int, ) -> int: @@ -295,16 +296,8 @@ def convert_schema_type_to_structured_dataset_type( raise AssertionError(f"Unrecognized SchemaColumnType: {column_type}") -class DuplicateHandlerError(ValueError): - ... - - -class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]): - """ - Think of this transformer as a higher-level meta transformer that is used for all the dataframe types. - If you are bringing a custom data frame type, or any data frame type, to flytekit, instead of - registering with the main type engine, you should register with this transformer instead. - """ +def get_supported_types(): + import numpy as _np _SUPPORTED_TYPES: typing.Dict[Type, LiteralType] = { _np.int32: type_models.LiteralType(simple=type_models.SimpleType.INTEGER), @@ -326,6 +319,19 @@ class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]): _np.object_: type_models.LiteralType(simple=type_models.SimpleType.STRING), str: type_models.LiteralType(simple=type_models.SimpleType.STRING), } + return _SUPPORTED_TYPES + + +class DuplicateHandlerError(ValueError): + ... + + +class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]): + """ + Think of this transformer as a higher-level meta transformer that is used for all the dataframe types. + If you are bringing a custom data frame type, or any data frame type, to flytekit, instead of + registering with the main type engine, you should register with this transformer instead. + """ ENCODERS: Dict[Type, Dict[str, Dict[str, StructuredDatasetEncoder]]] = {} DECODERS: Dict[Type, Dict[str, Dict[str, StructuredDatasetDecoder]]] = {} @@ -337,42 +343,54 @@ class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]): @classmethod def _finder(cls, handler_map, df_type: Type, protocol: str, format: str): - # If the incoming format requested is a specific format (e.g. "avro"), then look for that specific handler - # if missing, see if there's a generic format handler. Error if missing. - # If the incoming format requested is the generic format (""), then see if it's present, - # if not, look to see if there is a default format for the df_type and a handler for that format. - # if still missing, look to see if there's only _one_ handler for that type, if so then use that. - if format != GENERIC_FORMAT: - try: - return handler_map[df_type][protocol][format] - except KeyError: - try: - return handler_map[df_type][protocol][GENERIC_FORMAT] - except KeyError: - ... - else: - try: - return handler_map[df_type][protocol][GENERIC_FORMAT] - except KeyError: - if df_type in cls.DEFAULT_FORMATS and cls.DEFAULT_FORMATS[df_type] in handler_map[df_type][protocol]: - hh = handler_map[df_type][protocol][cls.DEFAULT_FORMATS[df_type]] - logger.debug( - f"Didn't find format specific handler {type(handler_map)} for protocol {protocol}" - f" using the generic handler {hh} instead." - ) - return hh - if len(handler_map[df_type][protocol]) == 1: - hh = list(handler_map[df_type][protocol].values())[0] - logger.debug( - f"Using {hh} with format {hh.supported_format} as it's the only one available for {df_type}" - ) - return hh + # If there's an exact match, then we should use it. + try: + return handler_map[df_type][protocol][format] + except KeyError: + ... 
+ + fsspec_handler = None + protocol_specific_handler = None + single_handler = None + default_format = cls.DEFAULT_FORMATS.get(df_type, None) + + try: + fss_handlers = handler_map[df_type]["fsspec"] + if format in fss_handlers: + fsspec_handler = fss_handlers[format] + elif GENERIC_FORMAT in fss_handlers: + fsspec_handler = fss_handlers[GENERIC_FORMAT] + else: + if default_format and default_format in fss_handlers and format == GENERIC_FORMAT: + fsspec_handler = fss_handlers[default_format] else: - logger.warning( - f"Did not automatically pick a handler for {df_type}," - f" more than one detected {handler_map[df_type][protocol].keys()}" - ) - raise ValueError(f"Failed to find a handler for {df_type}, protocol {protocol}, fmt |{format}|") + if len(fss_handlers) == 1 and format == GENERIC_FORMAT: + single_handler = list(fss_handlers.values())[0] + else: + ... + except KeyError: + ... + + try: + protocol_handlers = handler_map[df_type][protocol] + if GENERIC_FORMAT in protocol_handlers: + protocol_specific_handler = protocol_handlers[GENERIC_FORMAT] + else: + if default_format and default_format in protocol_handlers: + protocol_specific_handler = protocol_handlers[default_format] + else: + if len(protocol_handlers) == 1: + single_handler = list(protocol_handlers.values())[0] + else: + ... + + except KeyError: + ... + + if protocol_specific_handler or fsspec_handler or single_handler: + return protocol_specific_handler or fsspec_handler or single_handler + else: + raise ValueError(f"Failed to find a handler for {df_type}, protocol {protocol}, fmt |{format}|") @classmethod def get_encoder(cls, df_type: Type, protocol: str, format: str): @@ -437,18 +455,12 @@ def register( if h.protocol is None: if default_for_type: raise ValueError(f"Registering SD handler {h} with all protocols should never have default specified.") - for persistence_protocol in DataPersistencePlugins.supported_protocols(): - # TODO: Clean this up when we get to replacing the persistence layer. - # The behavior of the protocols given in the supported_protocols and is_supported_protocol - # is not actually the same as the one returned in get_protocol. - stripped = DataPersistencePlugins.get_protocol(persistence_protocol) - logger.debug(f"Automatically registering {persistence_protocol} as {stripped} with {h}") - try: - cls.register_for_protocol( - h, stripped, False, override, default_format_for_type, default_storage_for_type - ) - except DuplicateHandlerError: - logger.debug(f"Skipping {persistence_protocol}/{stripped} for {h} because duplicate") + try: + cls.register_for_protocol( + h, "fsspec", False, override, default_format_for_type, default_storage_for_type + ) + except DuplicateHandlerError: + logger.debug(f"Skipping generic fsspec protocol for handler {h} because duplicate") elif h.protocol == "": raise ValueError(f"Use None instead of empty string for registering handler {h}") @@ -471,8 +483,7 @@ def register_for_protocol( See the main register function instead. 
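The rewritten `_finder` above resolves a handler in a fixed order: an exact `(df_type, protocol, format)` match wins, then a handler registered for the concrete protocol (generic format, the type's default format, or a lone registration), then a handler registered under the catch-all `fsspec` protocol. A simplified, standalone sketch of that precedence (illustrative only, not the flytekit implementation):

```python
# handler_map has the shape {df_type: {protocol: {format: handler}}}.
GENERIC_FORMAT = ""


def find_handler(handler_map, df_type, protocol, fmt, default_format=None):
    by_protocol = handler_map.get(df_type, {}).get(protocol, {})
    by_fsspec = handler_map.get(df_type, {}).get("fsspec", {})

    # 1. An exact match on (df_type, protocol, fmt) always wins.
    if fmt in by_protocol:
        return by_protocol[fmt]

    # 2. Protocol-specific fallbacks: the generic format, the type's default
    #    format, or the only handler registered for this protocol.
    protocol_specific = (
        by_protocol.get(GENERIC_FORMAT)
        or (by_protocol.get(default_format) if default_format else None)
        or (next(iter(by_protocol.values())) if len(by_protocol) == 1 else None)
    )

    # 3. Handlers registered under the catch-all "fsspec" protocol, requested
    #    format first, then the generic format.
    fsspec_handler = by_fsspec.get(fmt) or by_fsspec.get(GENERIC_FORMAT)

    chosen = protocol_specific or fsspec_handler
    if chosen is None:
        raise ValueError(f"no handler for {df_type}, protocol {protocol}, fmt |{fmt}|")
    return chosen
```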
""" if protocol == "/": - # TODO: Special fix again, because get_protocol returns file, instead of file:// - protocol = DataPersistencePlugins.get_protocol(DiskPersistence.PROTOCOL) + protocol = "file" lowest_level = cls._handler_finder(h, protocol) if h.supported_format in lowest_level and override is False: raise DuplicateHandlerError( @@ -543,13 +554,15 @@ def to_literal( # def t1(dataset: Annotated[StructuredDataset, my_cols]) -> Annotated[StructuredDataset, my_cols]: # return dataset if python_val._literal_sd is not None: + if python_val._already_uploaded: + return Literal(scalar=Scalar(structured_dataset=python_val._literal_sd)) if python_val.dataframe is not None: raise ValueError( f"Shouldn't have specified both literal {python_val._literal_sd} and dataframe {python_val.dataframe}" ) return Literal(scalar=Scalar(structured_dataset=python_val._literal_sd)) - # 2. A task returns a python StructuredDataset with a uri. + # 2. A task returns a python StructuredDataset with an uri. # Note: this case is also what happens we start a local execution of a task with a python StructuredDataset. # It gets converted into a literal first, then back into a python StructuredDataset. # @@ -594,7 +607,7 @@ def _protocol_from_type_or_prefix(self, ctx: FlyteContext, df_type: Type, uri: O if df_type in self.DEFAULT_PROTOCOLS: return self.DEFAULT_PROTOCOLS[df_type] else: - protocol = protocol_prefix(uri or ctx.file_access.raw_output_prefix) + protocol = get_protocol(uri or ctx.file_access.raw_output_prefix) logger.debug( f"No default protocol for type {df_type} found, using {protocol} from output prefix {ctx.file_access.raw_output_prefix}" ) @@ -617,13 +630,16 @@ def encode( # least as good as the type of the interface. if sd_model.metadata is None: sd_model._metadata = StructuredDatasetMetadata(structured_literal_type) - if sd_model.metadata.structured_dataset_type is None: + if sd_model.metadata and sd_model.metadata.structured_dataset_type is None: sd_model.metadata._structured_dataset_type = structured_literal_type # Always set the format here to the format of the handler. # Note that this will always be the same as the incoming format except for when the fallback handler # with a format of "" is used. sd_model.metadata._structured_dataset_type.format = handler.supported_format - return Literal(scalar=Scalar(structured_dataset=sd_model)) + lit = Literal(scalar=Scalar(structured_dataset=sd_model)) + sd._literal_sd = sd_model + sd._already_uploaded = True + return lit def to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T] | StructuredDataset @@ -747,7 +763,7 @@ def to_html(self, ctx: FlyteContext, python_val: typing.Any, expected_python_typ # Here we only render column information by default instead of opening the structured dataset. col = typing.cast(StructuredDataset, python_val).columns() df = pd.DataFrame(col, ["column type"]) - return df.to_html() + return df.to_html() # type: ignore else: df = python_val @@ -770,7 +786,7 @@ def open_as( :param updated_metadata: New metadata type, since it might be different from the metadata in the literal. :return: dataframe. It could be pandas dataframe or arrow table, etc. 
""" - protocol = protocol_prefix(sd.uri) + protocol = get_protocol(sd.uri) decoder = self.get_decoder(df_type, protocol, sd.metadata.structured_dataset_type.format) result = decoder.decode(ctx, sd, updated_metadata) if isinstance(result, types.GeneratorType): @@ -783,17 +799,17 @@ def iter_as( sd: literals.StructuredDataset, df_type: Type[DF], updated_metadata: StructuredDatasetMetadata, - ) -> Generator[DF, None, None]: - protocol = protocol_prefix(sd.uri) + ) -> typing.Iterator[DF]: + protocol = get_protocol(sd.uri) decoder = self.DECODERS[df_type][protocol][sd.metadata.structured_dataset_type.format] - result = decoder.decode(ctx, sd, updated_metadata) + result: Union[DF, typing.Iterator[DF]] = decoder.decode(ctx, sd, updated_metadata) if not isinstance(result, types.GeneratorType): raise ValueError(f"Decoder {decoder} didn't return iterator {result} but should have from {sd}") return result def _get_dataset_column_literal_type(self, t: Type) -> type_models.LiteralType: - if t in self._SUPPORTED_TYPES: - return self._SUPPORTED_TYPES[t] + if t in get_supported_types(): + return get_supported_types()[t] if hasattr(t, "__origin__") and t.__origin__ == list: return type_models.LiteralType(collection_type=self._get_dataset_column_literal_type(t.__args__[0])) if hasattr(t, "__origin__") and t.__origin__ == dict: @@ -801,7 +817,7 @@ def _get_dataset_column_literal_type(self, t: Type) -> type_models.LiteralType: raise AssertionError(f"type {t} is currently not supported by StructuredDataset") def _convert_ordered_dict_of_columns_to_list( - self, column_map: typing.OrderedDict[str, Type] + self, column_map: typing.Optional[typing.OrderedDict[str, Type]] ) -> typing.List[StructuredDatasetType.DatasetColumn]: converted_cols: typing.List[StructuredDatasetType.DatasetColumn] = [] if column_map is None or len(column_map) == 0: @@ -812,10 +828,13 @@ def _convert_ordered_dict_of_columns_to_list( return converted_cols def _get_dataset_type(self, t: typing.Union[Type[StructuredDataset], typing.Any]) -> StructuredDatasetType: - original_python_type, column_map, storage_format, pa_schema = extract_cols_and_format(t) + original_python_type, column_map, storage_format, pa_schema = extract_cols_and_format(t) # type: ignore # Get the column information - converted_cols = self._convert_ordered_dict_of_columns_to_list(column_map) + converted_cols: typing.List[ + StructuredDatasetType.DatasetColumn + ] = self._convert_ordered_dict_of_columns_to_list(column_map) + return StructuredDatasetType( columns=converted_cols, format=storage_format, diff --git a/plugins/flytekit-aws-sagemaker/requirements.txt b/plugins/flytekit-aws-sagemaker/requirements.txt index 64d03c18f2..37771df5ce 100644 --- a/plugins/flytekit-aws-sagemaker/requirements.txt +++ b/plugins/flytekit-aws-sagemaker/requirements.txt @@ -46,7 +46,8 @@ cryptography==39.0.1 dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 - # via retry + # via + # retry2 deprecated==1.2.13 # via flytekit diskcache==5.4.0 @@ -152,8 +153,6 @@ protoc-gen-swagger==0.1.0 # via flyteidl psutil==5.9.4 # via sagemaker-training -py==1.11.0 - # via retry pyarrow==10.0.1 # via flytekit pycparser==2.21 @@ -193,8 +192,8 @@ requests==2.28.2 # responses responses==0.22.0 # via flytekit -retry==0.9.2 - # via flytekit +retry2==0.9.5 + # via flytekitplugins-awssagemaker retrying==1.3.4 # via sagemaker-training s3transfer==0.6.0 diff --git a/plugins/flytekit-aws-sagemaker/setup.py b/plugins/flytekit-aws-sagemaker/setup.py index 9be6800e49..ade15df0e8 100644 --- 
a/plugins/flytekit-aws-sagemaker/setup.py +++ b/plugins/flytekit-aws-sagemaker/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.3.0,<2.0.0", "sagemaker-training>=3.6.2,<4.0.0"] +plugin_requires = ["flytekit>=1.1.0b0,<1.3.0,<2.0.0", "sagemaker-training>=3.6.2,<4.0.0", "retry2==0.9.5"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/__init__.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/__init__.py index cba899669b..416a021516 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/__init__.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/__init__.py @@ -11,4 +11,5 @@ BigQueryTask """ +from .backend_plugin import BigQueryPlugin from .task import BigQueryConfig, BigQueryTask diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/backend_plugin.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/backend_plugin.py new file mode 100644 index 0000000000..acd5ece430 --- /dev/null +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/backend_plugin.py @@ -0,0 +1,94 @@ +import datetime +from typing import Dict, Optional + +import grpc +from flyteidl.service.external_plugin_service_pb2 import ( + SUCCEEDED, + TaskCreateResponse, + TaskDeleteResponse, + TaskGetResponse, +) +from google.cloud import bigquery + +from flytekit import FlyteContextManager, StructuredDataset, logger +from flytekit.core.type_engine import TypeEngine +from flytekit.extend.backend.base_plugin import BackendPluginBase, BackendPluginRegistry, convert_to_flyte_state +from flytekit.models import literals +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate +from flytekit.models.types import LiteralType, StructuredDatasetType + +pythonTypeToBigQueryType: Dict[type, str] = { + # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#data_type_sizes + list: "ARRAY", + bool: "BOOL", + bytes: "BYTES", + datetime.datetime: "DATETIME", + float: "FLOAT64", + int: "INT64", + str: "STRING", +} + + +class BigQueryPlugin(BackendPluginBase): + def __init__(self): + super().__init__(task_type="bigquery_query_job_task") + + def create( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + ) -> TaskCreateResponse: + job_config = None + if inputs: + ctx = FlyteContextManager.current_context() + python_interface_inputs = { + name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items() + } + native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs) + + logger.info(f"Create BigQuery job config with inputs: {native_inputs}") + job_config = bigquery.QueryJobConfig( + query_parameters=[ + bigquery.ScalarQueryParameter(name, pythonTypeToBigQueryType[python_interface_inputs[name]], val) + for name, val in native_inputs.items() + ] + ) + + custom = task_template.custom + client = bigquery.Client(project=custom["ProjectID"], location=custom["Location"]) + query_job = client.query(task_template.sql.statement, job_config=job_config) + + return TaskCreateResponse(job_id=str(query_job.job_id)) + + def get(self, context: grpc.ServicerContext, job_id: str) -> TaskGetResponse: + client = bigquery.Client() + job = client.get_job(job_id) + cur_state = convert_to_flyte_state(str(job.state)) + res = None + + if cur_state == SUCCEEDED: + ctx = FlyteContextManager.current_context() + output_location = 
f"bq://{job.destination.project}:{job.destination.dataset_id}.{job.destination.table_id}" + res = literals.LiteralMap( + { + "results": TypeEngine.to_literal( + ctx, + StructuredDataset(uri=output_location), + StructuredDataset, + LiteralType(structured_dataset_type=StructuredDatasetType(format="")), + ) + } + ) + + return TaskGetResponse(state=cur_state, outputs=res.to_flyte_idl()) + + def delete(self, context: grpc.ServicerContext, job_id: str) -> TaskDeleteResponse: + client = bigquery.Client() + client.cancel_job(job_id) + return TaskDeleteResponse() + + +BackendPluginRegistry.register(BigQueryPlugin()) diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py index 1d4a7f0dbd..7c24e9e3e9 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py @@ -5,10 +5,10 @@ from google.protobuf import json_format from google.protobuf.struct_pb2 import Struct -from flytekit import StructuredDataset from flytekit.configuration import SerializationSettings from flytekit.extend import SQLTask from flytekit.models import task as _task_model +from flytekit.types.structured import StructuredDataset @dataclass @@ -81,3 +81,6 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: def get_sql(self, settings: SerializationSettings) -> Optional[_task_model.Sql]: sql = _task_model.Sql(statement=self.query_template, dialect=_task_model.Sql.Dialect.ANSI) return sql + + def execute(self, **kwargs) -> Any: + raise Exception("Cannot run a SQL Task natively, please mock.") diff --git a/plugins/flytekit-bigquery/setup.py b/plugins/flytekit-bigquery/setup.py index 88f77429a2..b97b00ae1a 100644 --- a/plugins/flytekit-bigquery/setup.py +++ b/plugins/flytekit-bigquery/setup.py @@ -33,4 +33,5 @@ "Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries :: Python Modules", ], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, ) diff --git a/plugins/flytekit-bigquery/tests/test_backend_plugin.py b/plugins/flytekit-bigquery/tests/test_backend_plugin.py new file mode 100644 index 0000000000..c95cf308a7 --- /dev/null +++ b/plugins/flytekit-bigquery/tests/test_backend_plugin.py @@ -0,0 +1,94 @@ +from datetime import timedelta +from unittest import mock +from unittest.mock import MagicMock + +import grpc +from flyteidl.service.external_plugin_service_pb2 import SUCCEEDED + +import flytekit.models.interface as interface_models +from flytekit.extend.backend.base_plugin import BackendPluginRegistry +from flytekit.interfaces.cli_identifiers import Identifier +from flytekit.models import literals, task, types +from flytekit.models.core.identifier import ResourceType +from flytekit.models.task import Sql, TaskTemplate + + +@mock.patch("google.cloud.bigquery.job.QueryJob") +@mock.patch("google.cloud.bigquery.Client") +def test_bigquery_plugin(mock_client, mock_query_job): + job_id = "dummy_id" + mock_instance = mock_client.return_value + mock_query_job_instance = mock_query_job.return_value + mock_query_job_instance.state.return_value = "SUCCEEDED" + mock_query_job_instance.job_id.return_value = job_id + + class MockDestination: + def __init__(self): + self.project = "dummy_project" + self.dataset_id = "dummy_dataset" + self.table_id = "dummy_table" + + class MockJob: + def __init__(self): + self.state = "SUCCEEDED" + self.job_id = job_id + self.destination = MockDestination() + + 
mock_instance.get_job.return_value = MockJob() + mock_instance.query.return_value = MockJob() + mock_instance.cancel_job.return_value = MockJob() + + ctx = MagicMock(spec=grpc.ServicerContext) + p = BackendPluginRegistry.get_plugin(ctx, "bigquery_query_job_task") + + task_id = Identifier( + resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" + ) + task_metadata = task.TaskMetadata( + True, + task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + "A", + ) + task_config = { + "Location": "us-central1", + "ProjectID": "dummy_project", + } + + int_type = types.LiteralType(types.SimpleType.INTEGER) + interfaces = interface_models.TypedInterface( + { + "a": interface_models.Variable(int_type, "description1"), + "b": interface_models.Variable(int_type, "description2"), + }, + {}, + ) + task_inputs = literals.LiteralMap( + { + "a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))), + "b": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))), + }, + ) + + dummy_template = TaskTemplate( + id=task_id, + custom=task_config, + metadata=task_metadata, + interface=interfaces, + type="bigquery_query_job_task", + sql=Sql("SELECT 1"), + ) + + assert p.create(ctx, "/tmp", dummy_template, task_inputs).job_id == job_id + res = p.get(ctx, job_id) + assert res.state == SUCCEEDED + assert ( + res.outputs.literals["results"].scalar.structured_dataset.uri == "bq://dummy_project:dummy_dataset.dummy_table" + ) + p.delete(ctx, job_id) + mock_instance.cancel_job.assert_called() diff --git a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py index 68ee456ed6..e69de29bb2 100644 --- a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py +++ b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py @@ -1,53 +0,0 @@ -""" -.. currentmodule:: flytekitplugins.fsspec - -This package contains things that are useful when extending Flytekit. - -.. 
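Taken together, the `create`, `get`, and `delete` methods of the BigQuery backend plugin above give the external plugin service a poll-based job lifecycle. A rough sketch of how a caller might drive a plugin through that lifecycle (illustrative only; the real service loop, the helper name, and the 5-second interval are assumptions):

```python
import time

from flyteidl.service.external_plugin_service_pb2 import SUCCEEDED


def run_to_completion(plugin, ctx, output_prefix, task_template, inputs=None):
    # Submit the job, then poll until the backend reports success.
    job_id = plugin.create(ctx, output_prefix, task_template, inputs).job_id
    try:
        while True:
            res = plugin.get(ctx, job_id)
            if res.state == SUCCEEDED:
                return res.outputs  # LiteralMap proto carrying the "results" dataset
            time.sleep(5)
    except KeyboardInterrupt:
        plugin.delete(ctx, job_id)  # cancel the BigQuery job if we give up
        raise
```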
autosummary:: - :template: custom.rst - :toctree: generated/ - - ArrowToParquetEncodingHandler - FSSpecPersistence - PandasToParquetEncodingHandler - ParquetToArrowDecodingHandler - ParquetToPandasDecodingHandler -""" - -__all__ = [ - "ArrowToParquetEncodingHandler", - "FSSpecPersistence", - "PandasToParquetEncodingHandler", - "ParquetToArrowDecodingHandler", - "ParquetToPandasDecodingHandler", -] - -import importlib - -from flytekit import StructuredDatasetTransformerEngine, logger - -from .arrow import ArrowToParquetEncodingHandler, ParquetToArrowDecodingHandler -from .pandas import PandasToParquetEncodingHandler, ParquetToPandasDecodingHandler -from .persist import FSSpecPersistence - -S3 = "s3" -ABFS = "abfs" -GCS = "gs" - - -def _register(protocol: str): - logger.info(f"Registering fsspec {protocol} implementations and overriding default structured encoder/decoder.") - StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(protocol), True, True) - StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(protocol), True, True) - StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(protocol), True, True) - StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(protocol), True, True) - - -if importlib.util.find_spec("adlfs"): - _register(ABFS) - -if importlib.util.find_spec("s3fs"): - _register(S3) - -if importlib.util.find_spec("gcsfs"): - _register(GCS) diff --git a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/arrow.py b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/arrow.py deleted file mode 100644 index ec8d5f975e..0000000000 --- a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/arrow.py +++ /dev/null @@ -1,70 +0,0 @@ -import os -import typing -from pathlib import Path - -import pyarrow as pa -import pyarrow.parquet as pq -from botocore.exceptions import NoCredentialsError -from flytekitplugins.fsspec.persist import FSSpecPersistence -from fsspec.core import split_protocol, strip_protocol - -from flytekit import FlyteContext, logger -from flytekit.models import literals -from flytekit.models.literals import StructuredDatasetMetadata -from flytekit.models.types import StructuredDatasetType -from flytekit.types.structured.structured_dataset import ( - PARQUET, - StructuredDataset, - StructuredDatasetDecoder, - StructuredDatasetEncoder, -) - - -class ArrowToParquetEncodingHandler(StructuredDatasetEncoder): - def __init__(self, protocol: str): - super().__init__(pa.Table, protocol, PARQUET) - - def encode( - self, - ctx: FlyteContext, - structured_dataset: StructuredDataset, - structured_dataset_type: StructuredDatasetType, - ) -> literals.StructuredDataset: - uri = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory() - if not ctx.file_access.is_remote(uri): - Path(uri).mkdir(parents=True, exist_ok=True) - path = os.path.join(uri, f"{0:05}") - fp = FSSpecPersistence(data_config=ctx.file_access.data_config) - filesystem = fp.get_filesystem(path) - pq.write_table(structured_dataset.dataframe, strip_protocol(path), filesystem=filesystem) - return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type)) - - -class ParquetToArrowDecodingHandler(StructuredDatasetDecoder): - def __init__(self, protocol: str): - super().__init__(pa.Table, protocol, PARQUET) - - def decode( - self, - ctx: FlyteContext, - flyte_value: literals.StructuredDataset, - current_task_metadata: StructuredDatasetMetadata, - ) -> pa.Table: - uri = 
flyte_value.uri - if not ctx.file_access.is_remote(uri): - Path(uri).parent.mkdir(parents=True, exist_ok=True) - _, path = split_protocol(uri) - - columns = None - if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: - columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] - try: - fp = FSSpecPersistence(data_config=ctx.file_access.data_config) - fs = fp.get_filesystem(uri) - return pq.read_table(path, filesystem=fs, columns=columns) - except NoCredentialsError as e: - logger.debug("S3 source detected, attempting anonymous S3 access") - fs = FSSpecPersistence.get_anonymous_filesystem(uri) - if fs is not None: - return pq.read_table(path, filesystem=fs, columns=columns) - raise e diff --git a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/pandas.py b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/pandas.py deleted file mode 100644 index e4986ed9f6..0000000000 --- a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/pandas.py +++ /dev/null @@ -1,76 +0,0 @@ -import os -import typing -from pathlib import Path - -import pandas as pd -from botocore.exceptions import NoCredentialsError -from flytekitplugins.fsspec.persist import FSSpecPersistence, s3_setup_args - -from flytekit import FlyteContext, logger -from flytekit.configuration import DataConfig -from flytekit.models import literals -from flytekit.models.literals import StructuredDatasetMetadata -from flytekit.models.types import StructuredDatasetType -from flytekit.types.structured.structured_dataset import ( - PARQUET, - StructuredDataset, - StructuredDatasetDecoder, - StructuredDatasetEncoder, -) - - -def get_storage_options(cfg: DataConfig, uri: str) -> typing.Optional[typing.Dict]: - protocol = FSSpecPersistence.get_protocol(uri) - if protocol == "s3": - kwargs = s3_setup_args(cfg.s3) - if kwargs: - return kwargs - return None - - -class PandasToParquetEncodingHandler(StructuredDatasetEncoder): - def __init__(self, protocol: str): - super().__init__(pd.DataFrame, protocol, PARQUET) - - def encode( - self, - ctx: FlyteContext, - structured_dataset: StructuredDataset, - structured_dataset_type: StructuredDatasetType, - ) -> literals.StructuredDataset: - uri = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory() - if not ctx.file_access.is_remote(uri): - Path(uri).mkdir(parents=True, exist_ok=True) - path = os.path.join(uri, f"{0:05}") - df = typing.cast(pd.DataFrame, structured_dataset.dataframe) - df.to_parquet( - path, - coerce_timestamps="us", - allow_truncated_timestamps=False, - storage_options=get_storage_options(ctx.file_access.data_config, path), - ) - structured_dataset_type.format = PARQUET - return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type)) - - -class ParquetToPandasDecodingHandler(StructuredDatasetDecoder): - def __init__(self, protocol: str): - super().__init__(pd.DataFrame, protocol, PARQUET) - - def decode( - self, - ctx: FlyteContext, - flyte_value: literals.StructuredDataset, - current_task_metadata: StructuredDatasetMetadata, - ) -> pd.DataFrame: - uri = flyte_value.uri - columns = None - kwargs = get_storage_options(ctx.file_access.data_config, uri) - if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: - columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] - try: - return pd.read_parquet(uri, columns=columns, storage_options=kwargs) - except 
NoCredentialsError: - logger.debug("S3 source detected, attempting anonymous S3 access") - kwargs["anon"] = True - return pd.read_parquet(uri, columns=columns, storage_options=kwargs) diff --git a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/persist.py b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/persist.py deleted file mode 100644 index b890b3cc6c..0000000000 --- a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/persist.py +++ /dev/null @@ -1,144 +0,0 @@ -import os -import typing - -import fsspec -from fsspec.registry import known_implementations - -from flytekit.configuration import DataConfig, S3Config -from flytekit.extend import DataPersistence, DataPersistencePlugins -from flytekit.loggers import logger - -# Refer to https://github.com/fsspec/s3fs/blob/50bafe4d8766c3b2a4e1fc09669cf02fb2d71454/s3fs/core.py#L198 -# for key and secret -_FSSPEC_S3_KEY_ID = "key" -_FSSPEC_S3_SECRET = "secret" - - -def s3_setup_args(s3_cfg: S3Config): - kwargs = {} - if s3_cfg.access_key_id: - kwargs[_FSSPEC_S3_KEY_ID] = s3_cfg.access_key_id - - if s3_cfg.secret_access_key: - kwargs[_FSSPEC_S3_SECRET] = s3_cfg.secret_access_key - - # S3fs takes this as a special arg - if s3_cfg.endpoint is not None: - kwargs["client_kwargs"] = {"endpoint_url": s3_cfg.endpoint} - - return kwargs - - -class FSSpecPersistence(DataPersistence): - """ - This DataPersistence plugin uses fsspec to perform the IO. - NOTE: The put is not as performant as it can be for multiple files because of - - https://github.com/intake/filesystem_spec/issues/724. Once this bug is fixed, we can remove the `HACK` in the put - method - """ - - def __init__(self, default_prefix=None, data_config: typing.Optional[DataConfig] = None): - super(FSSpecPersistence, self).__init__(name="fsspec-persistence", default_prefix=default_prefix) - self.default_protocol = self.get_protocol(default_prefix) - self._data_cfg = data_config if data_config else DataConfig.auto() - - @staticmethod - def get_protocol(path: typing.Optional[str] = None): - if path: - return DataPersistencePlugins.get_protocol(path) - logger.info("Setting protocol to file") - return "file" - - def get_filesystem(self, path: str) -> fsspec.AbstractFileSystem: - protocol = FSSpecPersistence.get_protocol(path) - kwargs = {} - if protocol == "file": - kwargs = {"auto_mkdir": True} - elif protocol == "s3": - kwargs = s3_setup_args(self._data_cfg.s3) - return fsspec.filesystem(protocol, **kwargs) # type: ignore - - def get_anonymous_filesystem(self, path: str) -> typing.Optional[fsspec.AbstractFileSystem]: - protocol = FSSpecPersistence.get_protocol(path) - if protocol == "s3": - kwargs = s3_setup_args(self._data_cfg.s3) - anonymous_fs = fsspec.filesystem(protocol, anon=True, **kwargs) # type: ignore - return anonymous_fs - return None - - @staticmethod - def recursive_paths(f: str, t: str) -> typing.Tuple[str, str]: - if not f.endswith("*"): - f = os.path.join(f, "*") - if not t.endswith("/"): - t += "/" - return f, t - - def exists(self, path: str) -> bool: - try: - fs = self.get_filesystem(path) - return fs.exists(path) - except OSError as oe: - logger.debug(f"Error in exists checking {path} {oe}") - fs = self.get_anonymous_filesystem(path) - if fs is not None: - logger.debug("S3 source detected, attempting anonymous S3 exists check") - return fs.exists(path) - raise oe - - def get(self, from_path: str, to_path: str, recursive: bool = False): - fs = self.get_filesystem(from_path) - if recursive: - from_path, to_path = self.recursive_paths(from_path, to_path) - try: - return 
fs.get(from_path, to_path, recursive=recursive) - except OSError as oe: - logger.debug(f"Error in getting {from_path} to {to_path} rec {recursive} {oe}") - fs = self.get_anonymous_filesystem(from_path) - if fs is not None: - logger.debug("S3 source detected, attempting anonymous S3 access") - return fs.get(from_path, to_path, recursive=recursive) - raise oe - - def put(self, from_path: str, to_path: str, recursive: bool = False): - fs = self.get_filesystem(to_path) - if recursive: - from_path, to_path = self.recursive_paths(from_path, to_path) - # BEGIN HACK! - # Once https://github.com/intake/filesystem_spec/issues/724 is fixed, delete the special recursive handling - from fsspec.implementations.local import LocalFileSystem - from fsspec.utils import other_paths - - lfs = LocalFileSystem() - try: - lpaths = lfs.expand_path(from_path, recursive=recursive) - except FileNotFoundError: - # In some cases, there is no file in the original directory, so we just skip copying the file to the remote path - logger.debug(f"there is no file in the {from_path}") - return - rpaths = other_paths(lpaths, to_path) - for l, r in zip(lpaths, rpaths): - fs.put_file(l, r) - return - # END OF HACK!! - return fs.put(from_path, to_path, recursive=recursive) - - def construct_path(self, add_protocol: bool, add_prefix: bool, *paths) -> str: - path_list = list(paths) # make type check happy - if add_prefix: - path_list.insert(0, self.default_prefix) # type: ignore - path = "/".join(path_list) - if add_protocol: - return f"{self.default_protocol}://{path}" - return typing.cast(str, path) - - -def _register(): - logger.info("Registering fsspec known implementations and overriding all default implementations for persistence.") - DataPersistencePlugins.register_plugin("/", FSSpecPersistence, force=True) - for k, v in known_implementations.items(): - DataPersistencePlugins.register_plugin(f"{k}://", FSSpecPersistence, force=True) - - -# Registering all plugins -_register() diff --git a/plugins/flytekit-data-fsspec/setup.py b/plugins/flytekit-data-fsspec/setup.py index 5e6712f396..7102e4fb5f 100644 --- a/plugins/flytekit-data-fsspec/setup.py +++ b/plugins/flytekit-data-fsspec/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-data-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.3.0,<2.0.0", "fsspec<=2023.1", "botocore>=1.7.48", "pandas>=1.2.0"] +plugin_requires = [] __version__ = "0.0.0+develop" @@ -13,7 +13,7 @@ version=__version__, author="flyteorg", author_email="admin@flyte.org", - description="This package data-plugins for flytekit, that are powered by fsspec", + description="This is a deprecated plugin as of flytekit 1.5", url="https://github.com/flyteorg/flytekit/tree/master/plugins/flytekit-data-fsspec", long_description=open("README.md").read(), long_description_content_type="text/markdown", @@ -22,9 +22,9 @@ install_requires=plugin_requires, extras_require={ # https://github.com/fsspec/filesystem_spec/blob/master/setup.py#L36 - "abfs": ["adlfs>=2022.2.0"], - "aws": ["s3fs>=2021.7.0"], - "gcp": ["gcsfs>=2021.7.0"], + "abfs": [], + "aws": [], + "gcp": [], }, license="apache2", python_requires=">=3.7", @@ -42,5 +42,4 @@ "Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries :: Python Modules", ], - entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, ) diff --git a/plugins/flytekit-data-fsspec/tests/test_basic_dfs.py b/plugins/flytekit-data-fsspec/tests/test_basic_dfs.py deleted file mode 100644 index 434a763a93..0000000000 --- 
a/plugins/flytekit-data-fsspec/tests/test_basic_dfs.py +++ /dev/null @@ -1,44 +0,0 @@ -import pandas as pd -import pyarrow as pa -from flytekitplugins.fsspec.pandas import get_storage_options - -from flytekit import kwtypes, task -from flytekit.configuration import DataConfig, S3Config - -try: - from typing import Annotated -except ImportError: - from typing_extensions import Annotated - - -def test_get_storage_options(): - endpoint = "https://s3.amazonaws.com" - - options = get_storage_options(DataConfig(s3=S3Config(endpoint=endpoint)), "s3://bucket/somewhere") - assert options == {"client_kwargs": {"endpoint_url": endpoint}} - - options = get_storage_options(DataConfig(), "/tmp/file") - assert options is None - - -cols = kwtypes(Name=str, Age=int) -subset_cols = kwtypes(Name=str) - - -@task -def t1( - df1: Annotated[pd.DataFrame, cols], df2: Annotated[pa.Table, cols] -) -> (Annotated[pd.DataFrame, subset_cols], Annotated[pa.Table, subset_cols]): - return df1, df2 - - -def test_structured_dataset_wf(): - pd_df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) - pa_df = pa.Table.from_pandas(pd_df) - - subset_pd_df = pd.DataFrame({"Name": ["Tom", "Joseph"]}) - subset_pa_df = pa.Table.from_pandas(subset_pd_df) - - df1, df2 = t1(df1=pd_df, df2=pa_df) - assert df1.equals(subset_pd_df) - assert df2.equals(subset_pa_df) diff --git a/plugins/flytekit-data-fsspec/tests/test_persist.py b/plugins/flytekit-data-fsspec/tests/test_persist.py deleted file mode 100644 index 8e87c9c5eb..0000000000 --- a/plugins/flytekit-data-fsspec/tests/test_persist.py +++ /dev/null @@ -1,183 +0,0 @@ -import os -import pathlib -import tempfile - -import mock -from flytekitplugins.fsspec.persist import FSSpecPersistence, s3_setup_args -from fsspec.implementations.local import LocalFileSystem - -from flytekit.configuration import S3Config - - -def test_s3_setup_args(): - kwargs = s3_setup_args(S3Config()) - assert kwargs == {} - - kwargs = s3_setup_args(S3Config(endpoint="http://localhost:30084")) - assert kwargs == {"client_kwargs": {"endpoint_url": "http://localhost:30084"}} - - kwargs = s3_setup_args(S3Config(access_key_id="access")) - assert kwargs == {"key": "access"} - - -@mock.patch.dict(os.environ, {}, clear=True) -def test_s3_setup_args_env_empty(): - kwargs = s3_setup_args(S3Config.auto()) - assert kwargs == {} - - -@mock.patch.dict( - os.environ, - { - "AWS_ACCESS_KEY_ID": "ignore-user", - "AWS_SECRET_ACCESS_KEY": "ignore-secret", - "FLYTE_AWS_ACCESS_KEY_ID": "flyte", - "FLYTE_AWS_SECRET_ACCESS_KEY": "flyte-secret", - }, - clear=True, -) -def test_s3_setup_args_env_both(): - kwargs = s3_setup_args(S3Config.auto()) - assert kwargs == {"key": "flyte", "secret": "flyte-secret"} - - -@mock.patch.dict( - os.environ, - { - "FLYTE_AWS_ACCESS_KEY_ID": "flyte", - "FLYTE_AWS_SECRET_ACCESS_KEY": "flyte-secret", - }, - clear=True, -) -def test_s3_setup_args_env_flyte(): - kwargs = s3_setup_args(S3Config.auto()) - assert kwargs == {"key": "flyte", "secret": "flyte-secret"} - - -@mock.patch.dict( - os.environ, - { - "AWS_ACCESS_KEY_ID": "ignore-user", - "AWS_SECRET_ACCESS_KEY": "ignore-secret", - }, - clear=True, -) -def test_s3_setup_args_env_aws(): - kwargs = s3_setup_args(S3Config.auto()) - # not explicitly in kwargs, since fsspec/boto3 will use these env vars by default - assert kwargs == {} - - -def test_get_protocol(): - assert FSSpecPersistence.get_protocol("s3://abc") == "s3" - assert FSSpecPersistence.get_protocol("/abc") == "file" - assert FSSpecPersistence.get_protocol("file://abc") == "file" - assert 
FSSpecPersistence.get_protocol("gs://abc") == "gs" - assert FSSpecPersistence.get_protocol("sftp://abc") == "sftp" - assert FSSpecPersistence.get_protocol("abfs://abc") == "abfs" - - -def test_get_anonymous_filesystem(): - fp = FSSpecPersistence() - fs = fp.get_anonymous_filesystem("/abc") - assert fs is None - fs = fp.get_anonymous_filesystem("s3://abc") - assert fs is not None - assert fs.protocol == ["s3", "s3a"] - - -def test_get_filesystem(): - fp = FSSpecPersistence() - fs = fp.get_filesystem("/abc") - assert fs is not None - assert isinstance(fs, LocalFileSystem) - - -def test_recursive_paths(): - f, t = FSSpecPersistence.recursive_paths("/tmp", "/tmp") - assert (f, t) == ("/tmp/*", "/tmp/") - f, t = FSSpecPersistence.recursive_paths("/tmp/", "/tmp/") - assert (f, t) == ("/tmp/*", "/tmp/") - f, t = FSSpecPersistence.recursive_paths("/tmp/*", "/tmp") - assert (f, t) == ("/tmp/*", "/tmp/") - - -def test_exists(): - fs = FSSpecPersistence() - assert not fs.exists("/tmp/non-existent") - - with tempfile.TemporaryDirectory() as tdir: - f = os.path.join(tdir, "f.txt") - with open(f, "w") as fp: - fp.write("hello") - - assert fs.exists(f) - - -def test_get(): - fs = FSSpecPersistence() - with tempfile.TemporaryDirectory() as tdir: - f = os.path.join(tdir, "f.txt") - with open(f, "w") as fp: - fp.write("hello") - - t = os.path.join(tdir, "t.txt") - - fs.get(f, t) - with open(t, "r") as fp: - assert fp.read() == "hello" - - -def test_get_recursive(): - fs = FSSpecPersistence() - with tempfile.TemporaryDirectory() as tdir: - p = pathlib.Path(tdir) - d = p.joinpath("d") - d.mkdir() - f = d.joinpath(d, "f.txt") - with open(f, "w") as fp: - fp.write("hello") - - o = p.joinpath("o") - - t = o.joinpath(o, "f.txt") - fs.get(str(d), str(o), recursive=True) - with open(t, "r") as fp: - assert fp.read() == "hello" - - -def test_put(): - fs = FSSpecPersistence() - with tempfile.TemporaryDirectory() as tdir: - f = os.path.join(tdir, "f.txt") - with open(f, "w") as fp: - fp.write("hello") - - t = os.path.join(tdir, "t.txt") - - fs.put(f, t) - with open(t, "r") as fp: - assert fp.read() == "hello" - - -def test_put_recursive(): - fs = FSSpecPersistence() - with tempfile.TemporaryDirectory() as tdir: - p = pathlib.Path(tdir) - d = p.joinpath("d") - d.mkdir() - f = d.joinpath(d, "f.txt") - with open(f, "w") as fp: - fp.write("hello") - - o = p.joinpath("o") - - t = o.joinpath(o, "f.txt") - fs.put(str(d), str(o), recursive=True) - with open(t, "r") as fp: - assert fp.read() == "hello" - - -def test_construct_path(): - fs = FSSpecPersistence() - assert fs.construct_path(True, False, "abc") == "file://abc" diff --git a/plugins/flytekit-data-fsspec/tests/test_placeholder.py b/plugins/flytekit-data-fsspec/tests/test_placeholder.py new file mode 100644 index 0000000000..eb6dc82a34 --- /dev/null +++ b/plugins/flytekit-data-fsspec/tests/test_placeholder.py @@ -0,0 +1,3 @@ +# This test is here to give pytest something to run, otherwise it returns a non-zero return code. 
+def test_dummy(): + assert 1 + 1 == 2 diff --git a/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py b/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py index d2b44e0b65..c090ea6a46 100644 --- a/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py +++ b/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py @@ -1,7 +1,18 @@ -import markdown -import pandas -import plotly.express as px -from pandas_profiling import ProfileReport +from typing import TYPE_CHECKING, List, Optional, Union + +from flytekit import lazy_module +from flytekit.types.file import FlyteFile + +if TYPE_CHECKING: + import markdown + import pandas as pd + import PIL + import plotly.express as px +else: + pd = lazy_module("pandas") + markdown = lazy_module("markdown") + px = lazy_module("plotly.express") + PIL = lazy_module("PIL") class FrameProfilingRenderer: @@ -12,9 +23,11 @@ class FrameProfilingRenderer: def __init__(self, title: str = "Pandas Profiling Report"): self._title = title - def to_html(self, df: pandas.DataFrame) -> str: - assert isinstance(df, pandas.DataFrame) - profile = ProfileReport(df, title=self._title) + def to_html(self, df: "pd.DataFrame") -> str: + assert isinstance(df, pd.DataFrame) + import ydata_profiling + + profile = ydata_profiling.ProfileReport(df, title=self._title) return profile.to_html() @@ -37,7 +50,7 @@ class BoxRenderer: Each box spans from quartile 1 (Q1) to quartile 3 (Q3). The second quartile (Q2) is marked by a line inside the box. By default, the - whiskers correspond to the box' edges +/- 1.5 times the interquartile + whiskers correspond to the box edges +/- 1.5 times the interquartile range (IQR: Q3-Q1), see "points" for other options. """ @@ -45,6 +58,116 @@ class BoxRenderer: def __init__(self, column_name): self._column_name = column_name - def to_html(self, df: pandas.DataFrame) -> str: + def to_html(self, df: "pd.DataFrame") -> str: fig = px.box(df, y=self._column_name) return fig.to_html() + + +class ImageRenderer: + """Converts a FlyteFile or PIL.Image.Image object to an HTML string with the image data + represented as a base64-encoded string. + """ + + def to_html(self, image_src: Union[FlyteFile, "PIL.Image.Image"]) -> str: + img = self._get_image_object(image_src) + return self._image_to_html_string(img) + + @staticmethod + def _get_image_object(image_src: Union[FlyteFile, "PIL.Image.Image"]) -> "PIL.Image.Image": + if isinstance(image_src, FlyteFile): + local_path = image_src.download() + return PIL.Image.open(local_path) + elif isinstance(image_src, PIL.Image.Image): + return image_src + else: + raise ValueError("Unsupported image source type") + + @staticmethod + def _image_to_html_string(img: "PIL.Image.Image") -> str: + import base64 + from io import BytesIO + + buffered = BytesIO() + img.save(buffered, format="PNG") + img_base64 = base64.b64encode(buffered.getvalue()).decode() + return f'Rendered Image' + + +class TableRenderer: + """ + Convert a pandas DataFrame into an HTML table. + """ + + def to_html(self, df: pd.DataFrame, header_labels: Optional[List] = None, table_width: Optional[int] = None) -> str: + # Check if custom labels are provided and have the correct length + if header_labels is not None and len(header_labels) == len(df.columns): + df = df.copy() + df.columns = header_labels + + style = f""" + + """ + return style + df.to_html(classes="table-class", index=False) + + +class GanttChartRenderer: + """ + This renderer is primarily used by the timeline deck. 
The input DataFrame should + have at least the following columns: + - "Start": datetime.datetime (represents the start time) + - "Finish": datetime.datetime (represents the end time) + - "Name": string (the name of the task or event) + """ + + def to_html(self, df: pd.DataFrame, chart_width: Optional[int] = None) -> str: + fig = px.timeline(df, x_start="Start", x_end="Finish", y="Name", color="Name", width=chart_width) + + fig.update_xaxes( + tickangle=90, + rangeslider_visible=True, + tickformatstops=[ + dict(dtickrange=[None, 1], value="%3f ms"), + dict(dtickrange=[1, 60], value="%S:%3f s"), + dict(dtickrange=[60, 3600], value="%M:%S m"), + dict(dtickrange=[3600, None], value="%H:%M h"), + ], + ) + + # Remove y-axis tick labels and title since the time line deck space is limited. + fig.update_yaxes(showticklabels=False, title="") + + fig.update_layout( + autosize=True, + # Set the orientation of the legend to horizontal and move the legend anchor 2% beyond the top of the timeline graph's vertical axis + legend=dict(orientation="h", y=1.02), + ) + + return fig.to_html() diff --git a/plugins/flytekit-deck-standard/tests/test_renderer.py b/plugins/flytekit-deck-standard/tests/test_renderer.py index 79eb7e877d..1878193733 100644 --- a/plugins/flytekit-deck-standard/tests/test_renderer.py +++ b/plugins/flytekit-deck-standard/tests/test_renderer.py @@ -1,8 +1,33 @@ +import datetime +import tempfile + import markdown import pandas as pd -from flytekitplugins.deck.renderer import BoxRenderer, FrameProfilingRenderer, MarkdownRenderer +import pytest +from flytekitplugins.deck.renderer import ( + BoxRenderer, + FrameProfilingRenderer, + GanttChartRenderer, + ImageRenderer, + MarkdownRenderer, + TableRenderer, +) +from PIL import Image + +from flytekit.types.file import FlyteFile, JPEGImageFile, PNGImageFile df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [1, 22]}) +time_info_df = pd.DataFrame( + [ + dict( + Name="foo", + Start=datetime.datetime.utcnow(), + Finish=datetime.datetime.utcnow() + datetime.timedelta(microseconds=1000), + WallTime=1.0, + ProcessTime=1.0, + ) + ] +) def test_frame_profiling_renderer(): @@ -19,3 +44,39 @@ def test_markdown_renderer(): def test_box_renderer(): renderer = BoxRenderer("Name") assert "Plotlyconfig = {Mathjaxconfig: 'Local'}" in renderer.to_html(df).title() + + +def create_simple_image(fmt: str): + """Create a simple PNG image using PIL""" + img = Image.new("RGB", (100, 100), color="black") + tmp = tempfile.mktemp() + img.save(tmp, fmt) + return tmp + + +png_image = create_simple_image(fmt="png") +jpeg_image = create_simple_image(fmt="jpeg") + + +@pytest.mark.parametrize( + "image_src", + [ + FlyteFile(path=png_image), + JPEGImageFile(path=jpeg_image), + PNGImageFile(path=png_image), + Image.open(png_image), + ], +) +def test_image_renderer(image_src): + renderer = ImageRenderer() + assert " str: +# return "hello" +``` diff --git a/plugins/flytekit-envd/flytekitplugins/envd/__init__.py b/plugins/flytekit-envd/flytekitplugins/envd/__init__.py new file mode 100644 index 0000000000..d3dec806a1 --- /dev/null +++ b/plugins/flytekit-envd/flytekitplugins/envd/__init__.py @@ -0,0 +1,13 @@ +""" +.. currentmodule:: flytekitplugins.envd + +This plugin enables seamless integration between Flyte and envd. + +.. 
autosummary:: + :template: custom.rst + :toctree: generated/ + + EnvdImageSpecBuilder +""" + +from .image_builder import EnvdImageSpecBuilder diff --git a/plugins/flytekit-envd/flytekitplugins/envd/image_builder.py b/plugins/flytekit-envd/flytekitplugins/envd/image_builder.py new file mode 100644 index 0000000000..fec6647443 --- /dev/null +++ b/plugins/flytekit-envd/flytekitplugins/envd/image_builder.py @@ -0,0 +1,73 @@ +import pathlib +import shutil +import subprocess + +import click + +from flytekit.configuration import DefaultImages +from flytekit.core import context_manager +from flytekit.image_spec.image_spec import _F_IMG_ID, ImageBuildEngine, ImageSpec, ImageSpecBuilder + + +class EnvdImageSpecBuilder(ImageSpecBuilder): + """ + This class is used to build a docker image using envd. + """ + + def build_image(self, image_spec: ImageSpec): + cfg_path = create_envd_config(image_spec) + command = f"envd build --path {pathlib.Path(cfg_path).parent}" + if image_spec.registry: + command += f" --output type=image,name={image_spec.image_name()},push=true" + click.secho(f"Run command: {command} ", fg="blue") + p = subprocess.Popen(command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE) + for line in iter(p.stdout.readline, ""): + if p.poll() is not None: + break + if line.decode().strip() != "": + click.secho(line.decode().strip(), fg="blue") + + if p.returncode != 0: + _, stderr = p.communicate() + raise Exception( + f"failed to build the imageSpec at {cfg_path} with error {stderr}", + ) + + +def create_envd_config(image_spec: ImageSpec) -> str: + base_image = DefaultImages.default_image() if image_spec.base_image is None else image_spec.base_image + packages = [] if image_spec.packages is None else image_spec.packages + apt_packages = [] if image_spec.apt_packages is None else image_spec.apt_packages + env = {"PYTHONPATH": "/root", _F_IMG_ID: image_spec.image_name()} + if image_spec.env: + env.update(image_spec.env) + + envd_config = f"""# syntax=v1 + +def build(): + base(image="{base_image}", dev=False) + install.python_packages(name = [{', '.join(map(str, map(lambda x: f'"{x}"', packages)))}]) + install.apt_packages(name = [{', '.join(map(str, map(lambda x: f'"{x}"', apt_packages)))}]) + runtime.environ(env={env}) +""" + + if image_spec.python_version: + # Indentation is required by envd + envd_config += f' install.python(version="{image_spec.python_version}")\n' + + ctx = context_manager.FlyteContextManager.current_context() + cfg_path = ctx.file_access.get_random_local_path("build.envd") + pathlib.Path(cfg_path).parent.mkdir(parents=True, exist_ok=True) + + if image_spec.source_root: + shutil.copytree(image_spec.source_root, pathlib.Path(cfg_path).parent, dirs_exist_ok=True) + # Indentation is required by envd + envd_config += ' io.copy(host_path="./", envd_path="/root")' + + with open(cfg_path, "w+") as f: + f.write(envd_config) + + return cfg_path + + +ImageBuildEngine.register("envd", EnvdImageSpecBuilder()) diff --git a/plugins/flytekit-envd/requirements.in b/plugins/flytekit-envd/requirements.in new file mode 100644 index 0000000000..16b527ba7e --- /dev/null +++ b/plugins/flytekit-envd/requirements.in @@ -0,0 +1,2 @@ +. 
+-e file:.#egg=flytekitplugins-envd diff --git a/plugins/flytekit-envd/requirements.txt b/plugins/flytekit-envd/requirements.txt new file mode 100644 index 0000000000..78e9a287ca --- /dev/null +++ b/plugins/flytekit-envd/requirements.txt @@ -0,0 +1,229 @@ +# +# This file is autogenerated by pip-compile with Python 3.9 +# by the following command: +# +# pip-compile requirements.in +# +-e file:.#egg=flytekitplugins-envd + # via -r requirements.in +arrow==1.2.3 + # via jinja2-time +binaryornot==0.4.4 + # via cookiecutter +cachetools==5.3.0 + # via google-auth +certifi==2022.12.7 + # via + # kubernetes + # requests +cffi==1.15.1 + # via cryptography +chardet==5.1.0 + # via binaryornot +charset-normalizer==3.1.0 + # via requests +click==8.1.3 + # via + # cookiecutter + # flytekit +cloudpickle==2.2.1 + # via flytekit +cookiecutter==2.1.1 + # via flytekit +croniter==1.3.8 + # via flytekit +cryptography==40.0.1 + # via pyopenssl +dataclasses-json==0.5.7 + # via flytekit +decorator==5.1.1 + # via retry +deprecated==1.2.13 + # via flytekit +diskcache==5.4.0 + # via flytekit +docker==6.0.1 + # via flytekit +docker-image-py==0.1.12 + # via flytekit +docstring-parser==0.15 + # via flytekit +envd==0.3.16 + # via flytekitplugins-envd +flyteidl==1.3.15 + # via flytekit +flytekit==1.5.0 + # via flytekitplugins-envd +gitdb==4.0.10 + # via gitpython +gitpython==3.1.31 + # via flytekit +google-auth==2.17.1 + # via kubernetes +googleapis-common-protos==1.59.0 + # via + # flyteidl + # flytekit + # grpcio-status +grpcio==1.53.0 + # via + # flytekit + # grpcio-status +grpcio-status==1.53.0 + # via flytekit +idna==3.4 + # via requests +importlib-metadata==6.1.0 + # via + # flytekit + # keyring +jaraco-classes==3.2.3 + # via keyring +jinja2==3.1.2 + # via + # cookiecutter + # jinja2-time +jinja2-time==0.2.0 + # via cookiecutter +joblib==1.2.0 + # via flytekit +keyring==23.13.1 + # via flytekit +kubernetes==26.1.0 + # via flytekit +markupsafe==2.1.2 + # via jinja2 +marshmallow==3.19.0 + # via + # dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.13.0 + # via flytekit +more-itertools==9.1.0 + # via jaraco-classes +mypy-extensions==1.0.0 + # via typing-inspect +natsort==8.3.1 + # via flytekit +numpy==1.23.5 + # via + # flytekit + # pandas + # pyarrow +oauthlib==3.2.2 + # via requests-oauthlib +packaging==23.0 + # via + # docker + # marshmallow +pandas==1.5.3 + # via flytekit +protobuf==4.22.1 + # via + # flyteidl + # googleapis-common-protos + # grpcio-status + # protoc-gen-swagger +protoc-gen-swagger==0.1.0 + # via flyteidl +py==1.11.0 + # via retry +pyarrow==10.0.1 + # via flytekit +pyasn1==0.4.8 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.2.8 + # via google-auth +pycparser==2.21 + # via cffi +pyopenssl==23.1.1 + # via flytekit +python-dateutil==2.8.2 + # via + # arrow + # croniter + # flytekit + # kubernetes + # pandas +python-json-logger==2.0.7 + # via flytekit +python-slugify==8.0.1 + # via cookiecutter +pytimeparse==1.1.8 + # via flytekit +pytz==2023.3 + # via + # flytekit + # pandas +pyyaml==6.0 + # via + # cookiecutter + # flytekit + # kubernetes + # responses +regex==2023.3.23 + # via docker-image-py +requests==2.28.2 + # via + # cookiecutter + # docker + # flytekit + # kubernetes + # requests-oauthlib + # responses +requests-oauthlib==1.3.1 + # via kubernetes +responses==0.23.1 + # via flytekit +retry==0.9.2 + # via flytekit +rsa==4.9 + # via google-auth +six==1.16.0 + # via + # google-auth + # kubernetes + 
# python-dateutil +smmap==5.0.0 + # via gitdb +sortedcontainers==2.4.0 + # via flytekit +statsd==3.3.0 + # via flytekit +text-unidecode==1.3 + # via python-slugify +types-pyyaml==6.0.12.9 + # via responses +typing-extensions==4.5.0 + # via + # flytekit + # typing-inspect +typing-inspect==0.8.0 + # via dataclasses-json +urllib3==1.26.15 + # via + # docker + # flytekit + # kubernetes + # requests + # responses +websocket-client==1.5.1 + # via + # docker + # kubernetes +wheel==0.40.0 + # via flytekit +wrapt==1.15.0 + # via + # deprecated + # flytekit +zipp==3.15.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-envd/setup.py b/plugins/flytekit-envd/setup.py new file mode 100644 index 0000000000..d95a260958 --- /dev/null +++ b/plugins/flytekit-envd/setup.py @@ -0,0 +1,37 @@ +from setuptools import setup + +PLUGIN_NAME = "envd" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit", "envd"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="This package enables users to easily build a Docker image for tasks or workflows.", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.8", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, +) diff --git a/plugins/flytekit-envd/tests/__init__.py b/plugins/flytekit-envd/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-envd/tests/test_image_spec.py b/plugins/flytekit-envd/tests/test_image_spec.py new file mode 100644 index 0000000000..7c7ccd2151 --- /dev/null +++ b/plugins/flytekit-envd/tests/test_image_spec.py @@ -0,0 +1,31 @@ +from pathlib import Path + +from flytekitplugins.envd.image_builder import EnvdImageSpecBuilder, create_envd_config + +from flytekit.image_spec.image_spec import ImageSpec + + +def test_image_spec(): + image_spec = ImageSpec( + packages=["pandas"], + apt_packages=["git"], + python_version="3.8", + registry="", + base_image="cr.flyte.org/flyteorg/flytekit:py3.8-latest", + ) + + EnvdImageSpecBuilder().build_image(image_spec) + config_path = create_envd_config(image_spec) + contents = Path(config_path).read_text() + assert ( + contents + == """# syntax=v1 + +def build(): + base(image="cr.flyte.org/flyteorg/flytekit:py3.8-latest", dev=False) + install.python_packages(name = ["pandas"]) + install.apt_packages(name = ["git"]) + runtime.environ(env={'PYTHONPATH': '/root', '_F_IMG_ID': 'flytekit:yZ8jICcDTLoDArmNHbWNwg..'}) + install.python(version="3.8") +""" + ) diff --git a/plugins/flytekit-k8s-pod/tests/test_pod.py b/plugins/flytekit-k8s-pod/tests/test_pod.py index 0d6788ac92..014b88f4f3 100644 --- 
a/plugins/flytekit-k8s-pod/tests/test_pod.py +++ b/plugins/flytekit-k8s-pod/tests/test_pod.py @@ -355,8 +355,12 @@ def simple_pod_task(i: int): "--prev-checkpoint", "{{.prevCheckpointPrefix}}", "--resolver", - "flytekit.core.python_auto_container.default_task_resolver", + "MapTaskResolver", "--", + "vars", + "", + "resolver", + "flytekit.core.python_auto_container.default_task_resolver", "task-module", "tests.test_pod", "task-name", diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py index 6920c34e84..df5c74288e 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py @@ -10,4 +10,4 @@ MPIJob """ -from .task import MPIJob +from .task import HorovodJob, MPIJob diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py index 6f207b421d..e1c1be0a03 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py @@ -133,5 +133,62 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: return MessageToDict(job.to_flyte_idl()) +@dataclass +class HorovodJob(object): + slots: int + num_launcher_replicas: int = 1 + num_workers: int = 1 + + +class HorovodFunctionTask(MPIFunctionTask): + """ + For more info, check out https://github.com/horovod/horovod + """ + + # Customize your setup here. Please ensure the cmd, path, volume, etc are available in the pod. + ssh_command = "/usr/sbin/sshd -De -f /home/jobuser/.sshd_config" + discovery_script_path = "/etc/mpi/discover_hosts.sh" + + def __init__(self, task_config: HorovodJob, task_function: Callable, **kwargs): + + super().__init__( + task_config=task_config, + task_function=task_function, + **kwargs, + ) + + def get_command(self, settings: SerializationSettings) -> List[str]: + cmd = super().get_command(settings) + mpi_cmd = self._get_horovod_prefix() + cmd + return mpi_cmd + + def get_config(self, settings: SerializationSettings) -> Dict[str, str]: + config = super().get_config(settings) + return {**config, "worker_spec_command": self.ssh_command} + + def _get_horovod_prefix(self) -> List[str]: + np = self.task_config.num_workers * self.task_config.slots + base_cmd = [ + "horovodrun", + "-np", + f"{np}", + "--verbose", + "--log-level", + "INFO", + "--network-interface", + "eth0", + "--min-np", + f"{np}", + "--max-np", + f"{np}", + "--slots-per-host", + f"{self.task_config.slots}", + "--host-discovery-script", + self.discovery_script_path, + ] + return base_cmd + + # Register the MPI Plugin into the flytekit core plugin system TaskPlugins.register_pythontask_plugin(MPIJob, MPIFunctionTask) +TaskPlugins.register_pythontask_plugin(HorovodJob, HorovodFunctionTask) diff --git a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py index ebb0c49b58..7732d520c2 100644 --- a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py +++ b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py @@ -1,4 +1,4 @@ -from flytekitplugins.kfmpi.task import MPIJob, MPIJobModel +from flytekitplugins.kfmpi.task import HorovodJob, MPIJob, MPIJobModel from flytekit import Resources, task from flytekit.configuration import Image, ImageConfig, SerializationSettings @@ -41,3 +41,26 @@ def my_mpi_task(x: int, y: str) -> int: assert my_mpi_task.get_custom(settings) == {"numLauncherReplicas": 10, "numWorkers": 10, "slots": 1} assert my_mpi_task.task_type == 
"mpi" + + +def test_horovod_task(): + @task( + task_config=HorovodJob(num_workers=5, num_launcher_replicas=5, slots=1), + ) + def my_horovod_task(): + ... + + default_img = Image(name="default", fqn="test", tag="tag") + settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={"FOO": "baz"}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + cmd = my_horovod_task.get_command(settings) + assert "horovodrun" in cmd + config = my_horovod_task.get_config(settings) + assert "/usr/sbin/sshd" in config["worker_spec_command"] + custom = my_horovod_task.get_custom(settings) + assert isinstance(custom, dict) is True diff --git a/plugins/flytekit-kf-pytorch/README.md b/plugins/flytekit-kf-pytorch/README.md index 280fe687b6..7de27502bf 100644 --- a/plugins/flytekit-kf-pytorch/README.md +++ b/plugins/flytekit-kf-pytorch/README.md @@ -2,6 +2,9 @@ This plugin uses the Kubeflow PyTorch Operator and provides an extremely simplified interface for executing distributed training using various PyTorch backends. +This plugin can execute torch elastic training, which is equivalent to run `torchrun`. Elastic training can be executed +in a single Pod (without requiring the PyTorch operator, see below) as well as in a distributed multi-node manner. + To install the plugin, run the following command: ```bash diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py index aedb0b192f..cb9add7302 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py @@ -8,6 +8,7 @@ :toctree: generated/ PyTorch + Elastic """ -from .task import PyTorch +from .task import Elastic, PyTorch diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/models.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/models.py deleted file mode 100644 index 517f4a9eb6..0000000000 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/models.py +++ /dev/null @@ -1,23 +0,0 @@ -from flyteidl.plugins import pytorch_pb2 as _pytorch_task - -from flytekit.models import common as _common - - -class PyTorchJob(_common.FlyteIdlEntity): - def __init__(self, workers_count): - self._workers_count = workers_count - - @property - def workers_count(self): - return self._workers_count - - def to_flyte_idl(self): - return _pytorch_task.DistributedPyTorchTrainingTask( - workers=self.workers_count, - ) - - @classmethod - def from_flyte_idl(cls, pb2_object): - return cls( - workers_count=pb2_object.workers, - ) diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index 4b0bde78b0..aea2c9a2e6 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -2,16 +2,20 @@ This Plugin adds the capability of running distributed pytorch training to Flyte using backend plugins, natively on Kubernetes. It leverages `Pytorch Job `_ Plugin from kubeflow. 
""" +import os from dataclasses import dataclass -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, Optional, Union +import cloudpickle +from flyteidl.plugins.pytorch_pb2 import DistributedPyTorchTrainingTask from google.protobuf.json_format import MessageToDict +import flytekit from flytekit import PythonFunctionTask from flytekit.configuration import SerializationSettings -from flytekit.extend import TaskPlugins +from flytekit.extend import IgnoreOutputs, TaskPlugins -from .models import PyTorchJob +TORCH_IMPORT_ERROR_MESSAGE = "PyTorch is not installed. Please install `flytekitplugins-kfpytorch['elastic']`." @dataclass @@ -29,6 +33,31 @@ class PyTorch(object): num_workers: int +@dataclass +class Elastic(object): + """ + Configuration for `torch elastic training `_. + + Use this to run single- or multi-node distributed pytorch elastic training on k8s. + + Single-node elastic training is executed in a k8s pod when `nnodes` is set to 1. + Multi-node training is executed otherwise using a `Pytorch Job `_. + + Args: + nnodes (Union[int, str]): Number of nodes, or the range of nodes in form :. + nproc_per_node (Union[int, str]): Number of workers per node. Supported values are [auto, cpu, gpu, int]. + start_method (str): Multiprocessing start method to use when creating workers. + monitor_interval (int): Interval, in seconds, to monitor the state of workers. + max_restarts (int): Maximum number of worker group restarts before failing. + """ + + nnodes: Union[int, str] = 1 + nproc_per_node: Union[int, str] = "auto" + start_method: str = "spawn" + monitor_interval: int = 5 + max_restarts: int = 0 + + class PyTorchFunctionTask(PythonFunctionTask[PyTorch]): """ Plugin that submits a PyTorchJob (see https://github.com/kubeflow/pytorch-operator) @@ -46,9 +75,173 @@ def __init__(self, task_config: PyTorch, task_function: Callable, **kwargs): ) def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - job = PyTorchJob(workers_count=self.task_config.num_workers) - return MessageToDict(job.to_flyte_idl()) + job = DistributedPyTorchTrainingTask(workers=self.task_config.num_workers) + return MessageToDict(job) # Register the Pytorch Plugin into the flytekit core plugin system TaskPlugins.register_pythontask_plugin(PyTorch, PyTorchFunctionTask) + + +def spawn_helper(fn: bytes, kwargs) -> Any: + """Help to spawn worker processes. + + The purpose of this function is to 1) be pickleable so that it can be used with + the multiprocessing start method `spawn` and 2) to call a cloudpickle-serialized + function passed to it. This function itself doesn't have to be pickleable. Without + such a helper task functions, which are not pickleable, couldn't be used with the + start method `spawn`. + + Args: + fn (bytes): Cloudpickle-serialized target function to be executed in the worker process. + + Returns: + The return value of the received target function. + """ + fn = cloudpickle.loads(fn) + return_val = fn(**kwargs) + return return_val + + +class PytorchElasticFunctionTask(PythonFunctionTask[Elastic]): + """ + Plugin for distributed training with torch elastic/torchrun (see + https://pytorch.org/docs/stable/elastic/run.html). 
+ """ + + _ELASTIC_TASK_TYPE = "pytorch" + _ELASTIC_TASK_TYPE_STANDALONE = "python-task" + + def __init__(self, task_config: Elastic, task_function: Callable, **kwargs): + task_type = self._ELASTIC_TASK_TYPE_STANDALONE if task_config.nnodes == 1 else self._ELASTIC_TASK_TYPE + + super(PytorchElasticFunctionTask, self).__init__( + task_config=task_config, + task_type=task_type, + task_function=task_function, + **kwargs, + ) + try: + from torch.distributed import run + except ImportError: + raise ImportError(TORCH_IMPORT_ERROR_MESSAGE) + self.min_nodes, self.max_nodes = run.parse_min_max_nnodes(str(self.task_config.nnodes)) + + """ + c10d is the backend recommended by torch elastic. + https://pytorch.org/docs/stable/elastic/run.html#note-on-rendezvous-backend + + For c10d, no backend server has to be deployed. + https://pytorch.org/docs/stable/elastic/run.html#deployment + Instead, the workers will use the master's address as the rendezvous point. + """ + self.rdzv_backend = "c10d" + + def _execute(self, **kwargs) -> Any: + """ + This helper method will be invoked to execute the task. + + + Returns: + The result of rank zero. + """ + try: + from torch.distributed import run + from torch.distributed.launcher.api import LaunchConfig, elastic_launch + except ImportError: + raise ImportError(TORCH_IMPORT_ERROR_MESSAGE) + + if isinstance(self.task_config.nproc_per_node, str): + nproc = run.determine_local_world_size(self.task_config.nproc_per_node) + else: + nproc = self.task_config.nproc_per_node + + config = LaunchConfig( + run_id=flytekit.current_context().execution_id.name, + min_nodes=self.min_nodes, + max_nodes=self.max_nodes, + nproc_per_node=nproc, + rdzv_backend=self.rdzv_backend, # rdzv settings + rdzv_endpoint=os.environ.get("PET_RDZV_ENDPOINT", "localhost:0"), + max_restarts=self.task_config.max_restarts, + monitor_interval=self.task_config.monitor_interval, + start_method=self.task_config.start_method, + ) + + if self.task_config.start_method == "spawn": + """ + We use cloudpickle to serialize the non-pickleable task function. + The torch elastic launcher then launches the spawn_helper function (which is pickleable) + instead of the task function. This helper function, in the child-process, then deserializes + the task function, again with cloudpickle, and executes it. + """ + launcher_target_func = spawn_helper + + dumped_target_function = cloudpickle.dumps(self._task_function) + launcher_args = (dumped_target_function, kwargs) + elif self.task_config.start_method == "fork": + """ + The torch elastic launcher doesn't support passing kwargs to the target function, + only args. Flyte only works with kwargs. Thus, we create a closure which already has + the task kwargs bound. We tell the torch elastic launcher to start this function in + the child processes. + """ + + def fn_partial(): + """Closure of the task function with kwargs already bound.""" + return self._task_function(**kwargs) + + launcher_target_func = fn_partial + launcher_args = () + + else: + raise Exception("Bad start method") + + out = elastic_launch( + config=config, + entrypoint=launcher_target_func, + )(*launcher_args) + + # `out` is a dictionary of rank (not local rank) -> result + # Rank 0 returns the result of the task function + if 0 in out: + return out[0] + else: + raise IgnoreOutputs() + + def execute(self, **kwargs) -> Any: + """ + This method will be invoked to execute the task. + + Handles the exception scope for the `_execute` method. 
+ """ + from flytekit.exceptions import scopes as exception_scopes + + return exception_scopes.user_entry_point(self._execute)(**kwargs) + + def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]: + if self.task_config.nnodes == 1: + """ + Torch elastic distributed training is executed in a normal k8s pod so that this + works without the kubeflow train operator. + """ + return super().get_custom(settings) + else: + from flyteidl.plugins.pytorch_pb2 import ElasticConfig + + elastic_config = ElasticConfig( + rdzv_backend=self.rdzv_backend, + min_replicas=self.min_nodes, + max_replicas=self.max_nodes, + nproc_per_node=self.task_config.nproc_per_node, + max_restarts=self.task_config.max_restarts, + ) + job = DistributedPyTorchTrainingTask( + workers=self.max_nodes, + elastic_config=elastic_config, + ) + return MessageToDict(job) + + +# Register the PytorchElastic Plugin into the flytekit core plugin system +TaskPlugins.register_pythontask_plugin(Elastic, PytorchElasticFunctionTask) diff --git a/plugins/flytekit-kf-pytorch/requirements.txt b/plugins/flytekit-kf-pytorch/requirements.txt index 96fa577a3e..ac3c2c174a 100644 --- a/plugins/flytekit-kf-pytorch/requirements.txt +++ b/plugins/flytekit-kf-pytorch/requirements.txt @@ -1,44 +1,91 @@ # -# This file is autogenerated by pip-compile with Python 3.7 +# This file is autogenerated by pip-compile with Python 3.9 # by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-kfpytorch # via -r requirements.in +adal==1.2.7 + # via azure-datalake-store +adlfs==2023.1.0 + # via flytekit +aiobotocore==2.5.0 + # via s3fs +aiohttp==3.8.4 + # via + # adlfs + # aiobotocore + # gcsfs + # s3fs +aioitertools==0.11.0 + # via aiobotocore +aiosignal==1.3.1 + # via aiohttp arrow==1.2.3 # via jinja2-time +async-timeout==4.0.2 + # via aiohttp +attrs==23.1.0 + # via aiohttp +azure-core==1.26.4 + # via + # adlfs + # azure-identity + # azure-storage-blob +azure-datalake-store==0.0.52 + # via adlfs +azure-identity==1.12.0 + # via adlfs +azure-storage-blob==12.16.0 + # via adlfs binaryornot==0.4.4 # via cookiecutter +botocore==1.29.76 + # via aiobotocore +cachetools==5.3.0 + # via google-auth certifi==2022.12.7 - # via requests + # via + # kubernetes + # requests cffi==1.15.1 - # via cryptography + # via + # azure-datalake-store + # cryptography chardet==5.1.0 # via binaryornot -charset-normalizer==3.0.1 - # via requests +charset-normalizer==3.1.0 + # via + # aiohttp + # requests click==8.1.3 # via # cookiecutter # flytekit cloudpickle==2.2.1 - # via flytekit + # via + # flytekit + # flytekitplugins-kfpytorch cookiecutter==2.1.1 # via flytekit -croniter==1.3.8 +croniter==1.3.14 # via flytekit -cryptography==39.0.1 +cryptography==40.0.2 # via + # adal + # azure-identity + # azure-storage-blob + # msal + # pyjwt # pyopenssl - # secretstorage dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 - # via retry + # via gcsfs deprecated==1.2.13 # via flytekit -diskcache==5.4.0 +diskcache==5.6.1 # via flytekit docker==6.0.1 # via flytekit @@ -46,13 +93,55 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 # via flytekit -flyteidl==1.2.9 - # via flytekit -flytekit==1.2.7 +flyteidl==1.2.10 + # via + # flytekit + # flytekitplugins-kfpytorch +flytekit==1.2.9 # via flytekitplugins-kfpytorch -googleapis-common-protos==1.58.0 +frozenlist==1.3.3 + # via + # aiohttp + # aiosignal +fsspec==2023.4.0 + # via + # adlfs + # flytekit + # gcsfs + # s3fs +gcsfs==2023.4.0 + # via flytekit +gitdb==4.0.10 + # via gitpython 
+gitpython==3.1.31 + # via flytekit +google-api-core==2.11.0 + # via + # google-cloud-core + # google-cloud-storage +google-auth==2.17.3 + # via + # gcsfs + # google-api-core + # google-auth-oauthlib + # google-cloud-core + # google-cloud-storage + # kubernetes +google-auth-oauthlib==1.0.0 + # via gcsfs +google-cloud-core==2.3.2 + # via google-cloud-storage +google-cloud-storage==2.8.0 + # via gcsfs +google-crc32c==1.5.0 + # via google-resumable-media +google-resumable-media==2.5.0 + # via google-cloud-storage +googleapis-common-protos==1.59.0 # via # flyteidl + # flytekit + # google-api-core # grpcio-status grpcio==1.48.2 # via @@ -61,14 +150,16 @@ grpcio==1.48.2 grpcio-status==1.48.2 # via flytekit idna==3.4 - # via requests -importlib-metadata==6.0.0 + # via + # requests + # yarl +importlib-metadata==6.6.0 # via # click # flytekit # keyring -importlib-resources==5.12.0 - # via keyring +isodate==0.6.1 + # via azure-storage-blob jaraco-classes==3.2.3 # via keyring jeepney==0.8.0 @@ -81,10 +172,14 @@ jinja2==3.1.2 # jinja2-time jinja2-time==0.2.0 # via cookiecutter +jmespath==1.0.1 + # via botocore joblib==1.2.0 # via flytekit keyring==23.13.1 # via flytekit +kubernetes==26.1.0 + # via flytekit markupsafe==2.1.2 # via jinja2 marshmallow==3.19.0 @@ -98,43 +193,68 @@ marshmallow-jsonschema==0.13.0 # via flytekit more-itertools==9.0.0 # via jaraco-classes +msal==1.22.0 + # via + # azure-identity + # msal-extensions +msal-extensions==1.0.0 + # via azure-identity +multidict==6.0.4 + # via + # aiohttp + # yarl mypy-extensions==1.0.0 # via typing-inspect natsort==8.2.0 # via flytekit -numpy==1.21.6 +numpy==1.24.3 # via # flytekit # pandas # pyarrow -packaging==23.0 +oauthlib==3.2.2 + # via requests-oauthlib +packaging==23.1 # via # docker # marshmallow pandas==1.3.5 # via flytekit +portalocker==2.7.0 + # via msal-extensions protobuf==3.20.3 # via # flyteidl - # flytekit + # google-api-core # googleapis-common-protos # grpcio-status # protoc-gen-swagger protoc-gen-swagger==0.1.0 # via flyteidl -py==1.11.0 - # via retry pyarrow==10.0.1 # via flytekit +pyasn1==0.5.0 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.3.0 + # via google-auth pycparser==2.21 # via cffi -pyopenssl==23.0.0 +pyjwt[crypto]==2.6.0 + # via + # adal + # msal +pyopenssl==23.1.1 # via flytekit python-dateutil==2.8.2 # via + # adal # arrow + # botocore # croniter # flytekit + # kubernetes # pandas python-json-logger==2.0.7 # via flytekit @@ -142,7 +262,7 @@ python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.7.1 +pytz==2023.3 # via # flytekit # pandas @@ -150,17 +270,34 @@ pyyaml==6.0 # via # cookiecutter # flytekit -regex==2022.10.31 + # kubernetes + # responses +regex==2023.3.23 # via docker-image-py requests==2.28.2 # via + # adal + # azure-core + # azure-datalake-store # cookiecutter # docker # flytekit + # gcsfs + # google-api-core + # google-cloud-storage + # kubernetes + # msal + # requests-oauthlib # responses -responses==0.22.0 +requests-oauthlib==1.3.1 + # via + # google-auth-oauthlib + # kubernetes +responses==0.23.1 # via flytekit -retry==0.9.2 +rsa==4.9 + # via google-auth +s3fs==2023.4.0 # via flytekit secretstorage==3.3.3 # via keyring @@ -168,21 +305,27 @@ singledispatchmethod==1.0 # via flytekit six==1.16.0 # via - # grpcio + # azure-core + # azure-identity + # google-auth + # isodate + # kubernetes # python-dateutil +smmap==5.0.0 + # via gitdb sortedcontainers==2.4.0 # via flytekit statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -toml==0.10.2 - # via 
responses -types-toml==0.10.8.5 +types-pyyaml==6.0.12.9 # via responses typing-extensions==4.5.0 # via - # arrow + # aioitertools + # azure-core + # azure-storage-blob # flytekit # importlib-metadata # responses @@ -191,19 +334,27 @@ typing-inspect==0.8.0 # via dataclasses-json urllib3==1.26.14 # via + # botocore # docker # flytekit + # kubernetes # requests # responses websocket-client==1.5.1 - # via docker -wheel==0.38.4 + # via + # docker + # kubernetes +wheel==0.40.0 # via flytekit wrapt==1.14.1 # via + # aiobotocore # deprecated # flytekit -zipp==3.14.0 - # via - # importlib-metadata - # importlib-resources +yarl==1.9.2 + # via aiohttp +zipp==3.15.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-kf-pytorch/setup.py b/plugins/flytekit-kf-pytorch/setup.py index c45e409567..a207b9381e 100644 --- a/plugins/flytekit-kf-pytorch/setup.py +++ b/plugins/flytekit-kf-pytorch/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.3.0,<2.0.0"] +plugin_requires = ["cloudpickle", "flytekit>=1.1.0b0,<1.3.0,<2.0.0", "flyteidl>=1.2.10,<1.3.0"] __version__ = "0.0.0+develop" @@ -17,6 +17,9 @@ namespace_packages=["flytekitplugins"], packages=[f"flytekitplugins.{PLUGIN_NAME}"], install_requires=plugin_requires, + extras_require={ + "elastic": ["torch>=1.9.0"], + }, license="apache2", python_requires=">=3.7", classifiers=[ diff --git a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py new file mode 100644 index 0000000000..2ca6c9cc65 --- /dev/null +++ b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py @@ -0,0 +1,67 @@ +import os +import typing +from dataclasses import dataclass + +import pytest +import torch +import torch.distributed as dist +from dataclasses_json import dataclass_json +from flytekitplugins.kfpytorch.task import Elastic + +from flytekit import task, workflow + + +@dataclass_json +@dataclass +class Config: + lr: float = 1e-5 + bs: int = 64 + name: str = "foo" + + +def dist_communicate() -> int: + """Communicate between distributed workers.""" + rank = torch.distributed.get_rank() + world_size = dist.get_world_size() + tensor = torch.tensor([5], dtype=torch.int64) + 2 * rank + world_size + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + + return tensor.item() + + +def train(config: Config) -> typing.Tuple[str, Config, torch.nn.Module, int]: + """Mock training a model using torch-elastic for test purposes.""" + dist.init_process_group(backend="gloo") + + local_rank = os.environ["LOCAL_RANK"] + + out_model = torch.nn.Linear(1000, int(local_rank) + 1) + config.name = "elastic-test" + + distributed_result = dist_communicate() + + return f"result from local rank {local_rank}", config, out_model, distributed_result + + +@pytest.mark.parametrize("start_method", ["spawn", "fork"]) +def test_end_to_end(start_method: str) -> None: + """Test that the workflow with elastic task runs end to end.""" + world_size = 2 + + train_task = task(train, task_config=Elastic(nnodes=1, nproc_per_node=world_size, start_method=start_method)) + + @workflow + def wf(config: Config = Config()) -> typing.Tuple[str, Config, torch.nn.Module, int]: + return train_task(config=config) + + r, cfg, m, distributed_result = wf() + assert "result from local rank 0" in r + assert cfg.name == "elastic-test" + assert m.in_features == 1000 + assert m.out_features == 1 + """ + The distributed result is 
calculated by the workers of the elastic train + task by performing a `dist.all_reduce` operation. The correct result can + only be obtained if the distributed process group is initialized correctly. + """ + assert distributed_result == sum([5 + 2 * rank + world_size for rank in range(world_size)]) diff --git a/plugins/flytekit-mlflow/dev-requirements.txt b/plugins/flytekit-mlflow/dev-requirements.txt index 6ad9be49bb..5788aeb7d2 100644 --- a/plugins/flytekit-mlflow/dev-requirements.txt +++ b/plugins/flytekit-mlflow/dev-requirements.txt @@ -36,11 +36,7 @@ h5py==3.7.0 # via tensorflow idna==3.4 # via requests -importlib-metadata==5.0.0 - # via markdown -keras==2.10.0 - # via tensorflow -keras-preprocessing==1.1.2 +keras==2.11.0 # via tensorflow libclang==14.0.6 # via tensorflow @@ -51,7 +47,6 @@ markupsafe==2.1.1 numpy==1.23.4 # via # h5py - # keras-preprocessing # opt-einsum # tensorboard # tensorflow @@ -87,17 +82,16 @@ six==1.16.0 # google-auth # google-pasta # grpcio - # keras-preprocessing # tensorflow -tensorboard==2.10.1 +tensorboard==2.11.2 # via tensorflow tensorboard-data-server==0.6.1 # via tensorboard tensorboard-plugin-wit==1.8.1 # via tensorboard -tensorflow==2.10.0 +tensorflow==2.11.1 # via -r dev-requirements.in -tensorflow-estimator==2.10.0 +tensorflow-estimator==2.11.0 # via tensorflow tensorflow-io-gcs-filesystem==0.27.0 # via tensorflow @@ -115,8 +109,6 @@ wheel==0.38.3 # tensorboard wrapt==1.14.1 # via tensorflow -zipp==3.10.0 - # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/plugins/flytekit-mlflow/flytekitplugins/mlflow/tracking.py b/plugins/flytekit-mlflow/flytekitplugins/mlflow/tracking.py index b58aa4a120..9fa897f90e 100644 --- a/plugins/flytekit-mlflow/flytekitplugins/mlflow/tracking.py +++ b/plugins/flytekit-mlflow/flytekitplugins/mlflow/tracking.py @@ -13,7 +13,7 @@ from flytekit import FlyteContextManager from flytekit.bin.entrypoint import get_one_of from flytekit.core.context_manager import ExecutionState -from flytekit.deck import TopFrameRenderer +from flytekit.deck.renderer import TopFrameRenderer def metric_to_df(metrics: typing.List[Metric]) -> pd.DataFrame: diff --git a/plugins/flytekit-mlflow/tests/test_mlflow_tracking.py b/plugins/flytekit-mlflow/tests/test_mlflow_tracking.py index b196327d8d..613cbfcd76 100644 --- a/plugins/flytekit-mlflow/tests/test_mlflow_tracking.py +++ b/plugins/flytekit-mlflow/tests/test_mlflow_tracking.py @@ -29,4 +29,4 @@ def train_model(epochs: int): def test_local_exec(): train_model(epochs=1) - assert len(flytekit.current_context().decks) == 4 # mlflow metrics, params, input, and output + assert len(flytekit.current_context().decks) == 5 # mlflow metrics, params, timeline, input, and output diff --git a/plugins/flytekit-pandera/tests/test_plugin.py b/plugins/flytekit-pandera/tests/test_plugin.py index cc9b26c4fa..7e73aac932 100644 --- a/plugins/flytekit-pandera/tests/test_plugin.py +++ b/plugins/flytekit-pandera/tests/test_plugin.py @@ -55,7 +55,13 @@ def invalid_wf() -> pandera.typing.DataFrame[OutSchema]: def wf_with_df_input(df: pandera.typing.DataFrame[InSchema]) -> pandera.typing.DataFrame[OutSchema]: return transform2(df=transform1(df=df)) - with pytest.raises(pandera.errors.SchemaError, match="^expected series 'col2' to have type float64, got object"): + with pytest.raises( + pandera.errors.SchemaError, + match=( + "^Encountered error while executing workflow 'test_plugin.wf_with_df_input':\n" + " expected series 'col2' to have type 
float64, got object" + ), + ): wf_with_df_input(df=invalid_df) # raise error when executing workflow with invalid output @@ -67,7 +73,14 @@ def transform2_noop(df: pandera.typing.DataFrame[IntermediateSchema]) -> pandera def wf_invalid_output(df: pandera.typing.DataFrame[InSchema]) -> pandera.typing.DataFrame[OutSchema]: return transform2_noop(df=transform1(df=df)) - with pytest.raises(TypeError, match="^Failed to convert return value"): + with pytest.raises( + TypeError, + match=( + "^Encountered error while executing workflow 'test_plugin.wf_invalid_output':\n" + " Error encountered while executing 'wf_invalid_output':\n" + " Failed to convert outputs of task" + ), + ): wf_invalid_output(df=valid_df) diff --git a/plugins/flytekit-papermill/flytekitplugins/papermill/__init__.py b/plugins/flytekit-papermill/flytekitplugins/papermill/__init__.py index bce2ef2653..648ba1c6e0 100644 --- a/plugins/flytekit-papermill/flytekitplugins/papermill/__init__.py +++ b/plugins/flytekit-papermill/flytekitplugins/papermill/__init__.py @@ -11,4 +11,4 @@ record_outputs """ -from .task import NotebookTask, record_outputs +from .task import NotebookTask, load_flytedirectory, load_flytefile, load_structureddataset, record_outputs diff --git a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py index 6c160c2690..b1f472e99a 100644 --- a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py +++ b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py @@ -1,23 +1,29 @@ import json +import logging import os +import sys +import tempfile import typing from typing import Any import nbformat import papermill as pm +from flyteidl.core.literals_pb2 import Literal as _pb2_Literal from flyteidl.core.literals_pb2 import LiteralMap as _pb2_LiteralMap from google.protobuf import text_format as _text_format from nbconvert import HTMLExporter -from flytekit import FlyteContext, PythonInstanceTask +from flytekit import FlyteContext, PythonInstanceTask, StructuredDataset from flytekit.configuration import SerializationSettings +from flytekit.core import utils from flytekit.core.context_manager import ExecutionParameters from flytekit.deck.deck import Deck from flytekit.extend import Interface, TaskPlugins, TypeEngine from flytekit.loggers import logger from flytekit.models import task as task_models -from flytekit.models.literals import LiteralMap -from flytekit.types.file import HTMLPage, PythonNotebook +from flytekit.models.literals import Literal, LiteralMap +from flytekit.types.directory import FlyteDirectory +from flytekit.types.file import FlyteFile, HTMLPage, PythonNotebook T = typing.TypeVar("T") @@ -26,6 +32,8 @@ def _dummy_task_func(): return None +SAVE_AS_LITERAL = (FlyteFile, FlyteDirectory, StructuredDataset) + PAPERMILL_TASK_PREFIX = "pm.nb" @@ -86,6 +94,13 @@ class NotebookTask(PythonInstanceTask[T]): Users can access these notebooks after execution of the task locally or from remote servers. + .. note: + + By default, print statements in your notebook won't be transmitted to the pod logs/stdout. If you would + like to have logs forwarded as the notebook executes, pass the stream_logs argument. Note that notebook + logs can be quite verbose, so ensure you are prepared for any downstream log ingestion costs + (e.g., cloudwatch) + .. todo: Implicit extraction of SparkConfiguration from the notebook is not supported. 
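The note above introduces `stream_logs`. A minimal sketch of a `NotebookTask` that forwards papermill's cell output to stdout; the notebook path and the input/output names are placeholders, not part of this diff.

```python
# Sketch only; the notebook path and the input/output names are placeholders.
from flytekitplugins.papermill import NotebookTask

from flytekit import kwtypes

nb_train = NotebookTask(
    name="train-notebook",
    notebook_path="./notebooks/train.ipynb",  # hypothetical path
    stream_logs=True,  # forward papermill cell output to stdout / pod logs
    inputs=kwtypes(epochs=int),
    outputs=kwtypes(accuracy=float),
)
```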
@@ -114,6 +129,7 @@ def __init__( name: str, notebook_path: str, render_deck: bool = False, + stream_logs: bool = False, task_config: T = None, inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, outputs: typing.Optional[typing.Dict[str, typing.Type]] = None, @@ -135,6 +151,16 @@ def __init__( self._notebook_path = os.path.abspath(notebook_path) self._render_deck = render_deck + self._stream_logs = stream_logs + + # Send the papermill logger to stdout so that it appears in pod logs. Note that papermill doesn't allow + # injecting a logger, so we cannot redirect logs to the flyte child loggers (e.g., the userspace logger) + # and inherit their settings, but we instead must send logs to stdout directly + if self._stream_logs: + papermill_logger = logging.getLogger("papermill") + papermill_logger.addHandler(logging.StreamHandler(sys.stdout)) + # Papermill leaves the default level of DEBUG. We increase it here. + papermill_logger.setLevel(logging.INFO) if not os.path.exists(self._notebook_path): raise ValueError(f"Illegal notebook path passed in {self._notebook_path}") @@ -235,8 +261,12 @@ def execute(self, **kwargs) -> Any: singleton """ logger.info(f"Hijacking the call for task-type {self.task_type}, to call notebook.") + for k, v in kwargs.items(): + if isinstance(v, SAVE_AS_LITERAL): + kwargs[k] = save_python_val_to_file(v) + # Execute Notebook via Papermill. - pm.execute_notebook(self._notebook_path, self.output_notebook_path, parameters=kwargs) # type: ignore + pm.execute_notebook(self._notebook_path, self.output_notebook_path, parameters=kwargs, log_output=self._stream_logs) # type: ignore outputs = self.extract_outputs(self.output_notebook_path) self.render_nb_html(self.output_notebook_path, self.rendered_output_path) @@ -245,6 +275,7 @@ def execute(self, **kwargs) -> Any: if outputs: m = outputs.literals output_list = [] + for k, type_v in self.python_interface.outputs.items(): if k == self._IMPLICIT_OP_NOTEBOOK: output_list.append(self.output_notebook_path) @@ -254,7 +285,7 @@ def execute(self, **kwargs) -> Any: v = TypeEngine.to_python_value(ctx=FlyteContext.current_context(), lv=m[k], expected_python_type=type_v) output_list.append(v) else: - raise RuntimeError(f"Expected output {k} of type {v} not found in the notebook outputs") + raise TypeError(f"Expected output {k} of type {type_v} not found in the notebook outputs") return tuple(output_list) @@ -287,3 +318,80 @@ def record_outputs(**kwargs) -> str: lit = TypeEngine.to_literal(ctx, python_type=type(v), python_val=v, expected=expected) m[k] = lit return LiteralMap(literals=m).to_flyte_idl() + + +def save_python_val_to_file(input: Any) -> str: + """Save a python value to a local file as a Flyte literal. + + Args: + input (Any): the python value + + Returns: + str: the path to the file + """ + ctx = FlyteContext.current_context() + expected = TypeEngine.to_literal_type(type(input)) + lit = TypeEngine.to_literal(ctx, python_type=type(input), python_val=input, expected=expected) + + tmp_file = tempfile.mktemp(suffix="bin") + utils.write_proto_to_file(lit.to_flyte_idl(), tmp_file) + return tmp_file + + +def load_python_val_from_file(path: str, dtype: T) -> T: + """Loads a python value from a Flyte literal saved to a local file. + + If the path matches the type, it is returned as is. This enables + reusing the parameters cell for local development. 
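Because `load_python_val_from_file` returns the value unchanged when it already has the requested type, the same parameters cell works both when papermill injects a literal file path and when a user runs the notebook by hand. A minimal sketch of such a cell, using the typed wrappers defined next and mirroring the `nb-types.ipynb` fixture; `ff`, `fd`, and `sd` are the papermill-injected parameters.

```python
# Sketch of a notebook "parameters" cell (mirrors nb-types.ipynb below);
# ff, fd, and sd are injected by papermill when the NotebookTask runs.
from flytekitplugins.papermill import (
    load_flytedirectory,
    load_flytefile,
    load_structureddataset,
)

ff = load_flytefile(ff)            # FlyteFile literal written by execute(), or a FlyteFile as-is
fd = load_flytedirectory(fd)       # FlyteDirectory
sd = load_structureddataset(sd)    # StructuredDataset
```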
+ + Args: + path (str): path to the file + dtype (T): the type of the literal + + Returns: + T: the python value of the literal + """ + if isinstance(path, dtype): + return path + + proto = utils.load_proto_from_file(_pb2_Literal, path) + lit = Literal.from_flyte_idl(proto) + ctx = FlyteContext.current_context() + python_value = TypeEngine.to_python_value(ctx, lit, dtype) + return python_value + + +def load_flytefile(path: str) -> T: + """Loads a FlyteFile from a file. + + Args: + path (str): path to the file + + Returns: + T: the python value of the literal + """ + return load_python_val_from_file(path=path, dtype=FlyteFile) + + +def load_flytedirectory(path: str) -> T: + """Loads a FlyteDirectory from a file. + + Args: + path (str): path to the file + + Returns: + T: the python value of the literal + """ + return load_python_val_from_file(path=path, dtype=FlyteDirectory) + + +def load_structureddataset(path: str) -> T: + """Loads a StructuredDataset from a file. + + Args: + path (str): path to the file + + Returns: + T: the python value of the literal + """ + return load_python_val_from_file(path=path, dtype=StructuredDataset) diff --git a/plugins/flytekit-papermill/tests/test_task.py b/plugins/flytekit-papermill/tests/test_task.py index 1947d09445..0e54e7082e 100644 --- a/plugins/flytekit-papermill/tests/test_task.py +++ b/plugins/flytekit-papermill/tests/test_task.py @@ -1,14 +1,17 @@ import datetime import os +import tempfile +import pandas as pd from flytekitplugins.papermill import NotebookTask from flytekitplugins.pod import Pod from kubernetes.client import V1Container, V1PodSpec import flytekit -from flytekit import kwtypes +from flytekit import StructuredDataset, kwtypes, task from flytekit.configuration import Image, ImageConfig -from flytekit.types.file import PythonNotebook +from flytekit.types.directory import FlyteDirectory +from flytekit.types.file import FlyteFile, PythonNotebook from .testdata.datatype import X @@ -134,3 +137,38 @@ def test_notebook_pod_task(): nb.get_command(serialization_settings) == nb.get_k8s_pod(serialization_settings).pod_spec["containers"][0]["args"] ) + + +def test_flyte_types(): + @task + def create_file() -> FlyteFile: + tmp_file = tempfile.mktemp() + with open(tmp_file, "w") as f: + f.write("abc") + return FlyteFile(path=tmp_file) + + @task + def create_dir() -> FlyteDirectory: + tmp_dir = tempfile.mkdtemp() + with open(os.path.join(tmp_dir, "file.txt"), "w") as f: + f.write("abc") + return FlyteDirectory(path=tmp_dir) + + @task + def create_sd() -> StructuredDataset: + df = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + return StructuredDataset(dataframe=df) + + ff = create_file() + fd = create_dir() + sd = create_sd() + + nb_name = "nb-types" + nb_types = NotebookTask( + name="test", + notebook_path=_get_nb_path(nb_name, abs=False), + inputs=kwtypes(ff=FlyteFile, fd=FlyteDirectory, sd=StructuredDataset), + outputs=kwtypes(success=bool), + ) + success, out, render = nb_types.execute(ff=ff, fd=fd, sd=sd) + assert success is True, "Notebook execution failed" diff --git a/plugins/flytekit-papermill/tests/testdata/nb-simple.ipynb b/plugins/flytekit-papermill/tests/testdata/nb-simple.ipynb index ebdf9a3c71..1ad7aaed4a 100644 --- a/plugins/flytekit-papermill/tests/testdata/nb-simple.ipynb +++ b/plugins/flytekit-papermill/tests/testdata/nb-simple.ipynb @@ -34,7 +34,6 @@ "outputs": [], "source": [ "from flytekitplugins.papermill import record_outputs\n", - "\n", "record_outputs(square=out)" ] }, @@ -49,7 +48,7 @@ "metadata": { "celltoolbar": "Tags", 
"kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -63,9 +62,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.3" + "version": "3.10.10" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/plugins/flytekit-papermill/tests/testdata/nb-types.ipynb b/plugins/flytekit-papermill/tests/testdata/nb-types.ipynb new file mode 100644 index 0000000000..824b1d39ae --- /dev/null +++ b/plugins/flytekit-papermill/tests/testdata/nb-types.ipynb @@ -0,0 +1,100 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "tags": [ + "parameters" + ] + }, + "outputs": [], + "source": [ + "ff = None\n", + "fd = None\n", + "sd = None" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import pandas as pd\n", + "from flytekitplugins.papermill import (\n", + " load_flytefile, load_flytedirectory, load_structureddataset,\n", + " record_outputs\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ff = load_flytefile(ff)\n", + "fd = load_flytedirectory(fd)\n", + "sd = load_structureddataset(sd)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# read file\n", + "with open(ff.download(), 'r') as f:\n", + " text = f.read()\n", + " assert text == \"abc\", \"Text does not match\"\n", + "\n", + "# check file inside directory\n", + "with open(os.path.join(fd.download(),\"file.txt\"), 'r') as f:\n", + " text = f.read()\n", + " assert text == \"abc\", \"Text does not match\"\n", + "\n", + "# check dataset\n", + "df = sd.open(pd.DataFrame).all()\n", + "expected = pd.DataFrame({\"a\": [1, 2], \"b\": [3, 4]})\n", + "assert df.equals(expected), \"Dataframes do not match\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "outputs" + ] + }, + "outputs": [], + "source": [ + "record_outputs(success=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py index 0b5bf8e577..4290c88ae4 100644 --- a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py +++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py @@ -7,6 +7,7 @@ from flytekit.models import literals from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import StructuredDatasetType +from flytekit.types.structured.basic_dfs import get_storage_options from flytekit.types.structured.structured_dataset import ( PARQUET, StructuredDataset, @@ -62,12 +63,12 @@ def decode( flyte_value: literals.StructuredDataset, current_task_metadata: StructuredDatasetMetadata, ) -> pl.DataFrame: - local_dir = ctx.file_access.get_random_local_directory() - ctx.file_access.get_data(flyte_value.uri, local_dir, is_multipart=True) + uri = flyte_value.uri + kwargs = 
get_storage_options(ctx.file_access.data_config, uri) if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] - return pl.read_parquet(local_dir, columns=columns, use_pyarrow=True) - return pl.read_parquet(local_dir, use_pyarrow=True) + return pl.read_parquet(uri, columns=columns, use_pyarrow=True, storage_options=kwargs) + return pl.read_parquet(uri, use_pyarrow=True, storage_options=kwargs) StructuredDatasetTransformerEngine.register(PolarsDataFrameToParquetEncodingHandler()) diff --git a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py index 15a195e5d5..23fbf6d441 100644 --- a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py +++ b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py @@ -1,3 +1,5 @@ +import tempfile + import pandas as pd import polars as pl from flytekitplugins.polars.sd_transformers import PolarsDataFrameRenderer @@ -79,3 +81,13 @@ def create_sd() -> StructuredDataset: sd = create_sd() polars_df = sd.open(pl.DataFrame).all() assert pl.DataFrame(data).frame_equal(polars_df) + + tmp = tempfile.mktemp() + pl.DataFrame(data).write_parquet(tmp) + + @task + def t1(sd: StructuredDataset) -> pl.DataFrame: + return sd.open(pd.DataFrame).all() + + sd = StructuredDataset(uri=tmp) + t1(sd=sd).frame_equal(polars_df) diff --git a/plugins/flytekit-spark/Dockerfile b/plugins/flytekit-spark/Dockerfile new file mode 100644 index 0000000000..0789df45b7 --- /dev/null +++ b/plugins/flytekit-spark/Dockerfile @@ -0,0 +1,14 @@ +# https://github.com/apache/spark/blob/master/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile +FROM apache/spark-py:3.3.1 +LABEL org.opencontainers.image.source=https://github.com/flyteorg/flytekit + +USER 0 +RUN ln -s /usr/bin/python3 /usr/bin/python + +ARG VERSION +RUN pip install flytekitplugins-spark==$VERSION +RUN pip install flytekit==$VERSION + +RUN chown -R ${spark_uid}:${spark_uid} /root +WORKDIR /root +USER ${spark_uid} diff --git a/plugins/flytekit-spark/flytekitplugins/spark/pyspark_transformers.py b/plugins/flytekit-spark/flytekitplugins/spark/pyspark_transformers.py index e48778ad70..4afb257f9d 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/pyspark_transformers.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/pyspark_transformers.py @@ -1,4 +1,3 @@ -import pathlib from typing import Type from pyspark.ml import PipelineModel @@ -24,22 +23,17 @@ def to_literal( python_type: Type[PipelineModel], expected: LiteralType, ) -> Literal: - local_path = ctx.file_access.get_random_local_path() - pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True) - python_val.save(local_path) - + # Must write to remote directory remote_dir = ctx.file_access.get_random_remote_directory() - ctx.file_access.upload_directory(local_path, remote_dir) + python_val.write().overwrite().save(remote_dir) return Literal(scalar=Scalar(blob=Blob(uri=remote_dir, metadata=BlobMetadata(type=self._TYPE_INFO)))) def to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[PipelineModel] ) -> PipelineModel: - local_dir = ctx.file_access.get_random_local_directory() - ctx.file_access.download_directory(lv.scalar.blob.uri, local_dir) - - return PipelineModel.load(local_dir) + remote_dir = lv.scalar.blob.uri + return PipelineModel.load(remote_dir) TypeEngine.register(PySparkPipelineModelTransformer()) diff 
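The polars decoder above no longer downloads the parquet data to a local directory; it passes the `StructuredDataset` URI plus fsspec storage options straight to `pl.read_parquet`. A standalone sketch of the same call pattern, with a hypothetical URI and options.

```python
# Sketch of the same read path outside the transformer; the URI and storage
# options are hypothetical placeholders (get_storage_options resolves the real ones).
import polars as pl

uri = "s3://my-bucket/data/part-0.parquet"  # hypothetical remote dataset
storage_options = {"anon": False}           # placeholder fsspec options

df = pl.read_parquet(uri, use_pyarrow=True, storage_options=storage_options)
print(df.shape)
```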
--git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 7b32e9f28b..564e55778f 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -1,15 +1,15 @@ import os -import typing from dataclasses import dataclass -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, Union, cast from google.protobuf.json_format import MessageToDict from pyspark.sql import SparkSession from flytekit import FlyteContextManager, PythonFunctionTask -from flytekit.configuration import SerializationSettings +from flytekit.configuration import DefaultImages, SerializationSettings from flytekit.core.context_manager import ExecutionParameters from flytekit.extend import ExecutionState, TaskPlugins +from flytekit.image_spec import ImageSpec from .models import SparkJob, SparkType @@ -48,7 +48,7 @@ class Databricks(Spark): databricks_instance: Domain name of your deployment. Use the form .cloud.databricks.com. """ - databricks_conf: typing.Optional[Dict[str, typing.Union[str, dict]]] = None + databricks_conf: Optional[Dict[str, Union[str, dict]]] = None databricks_token: Optional[str] = None databricks_instance: Optional[str] = None @@ -56,7 +56,7 @@ class Databricks(Spark): # This method does not reset the SparkSession since it's a bit hard to handle multiple # Spark sessions in a single application as it's described in: # https://stackoverflow.com/questions/41491972/how-can-i-tear-down-a-sparksession-and-create-a-new-one-within-one-application. -def new_spark_session(name: str, conf: typing.Dict[str, str] = None): +def new_spark_session(name: str, conf: Dict[str, str] = None): """ Optionally creates a new spark session and returns it. 
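`Databricks` extends `Spark` with `databricks_conf`, `databricks_token`, and `databricks_instance`, which `get_custom` below serializes into the SparkJob. A minimal sketch of a task using it; the instance name and cluster settings are placeholders.

```python
# Sketch only; the instance name and cluster settings are placeholders.
from flytekitplugins.spark.task import Databricks

from flytekit import task


@task(
    task_config=Databricks(
        spark_conf={"spark.executor.memory": "2g"},
        databricks_conf={"run_name": "flyte-example", "new_cluster": {"num_workers": 2}},
        databricks_instance="my-workspace.cloud.databricks.com",  # hypothetical
    )
)
def run_on_databricks() -> None:
    # Spark code goes here; the Databricks settings above are attached to the
    # SparkJob custom config at serialization time.
    ...
```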
In cluster mode (running in hosted flyte, this will disregard the spark conf passed in) @@ -99,26 +99,43 @@ class PysparkFunctionTask(PythonFunctionTask[Spark]): _SPARK_TASK_TYPE = "spark" - def __init__(self, task_config: Spark, task_function: Callable, **kwargs): + def __init__( + self, + task_config: Spark, + task_function: Callable, + container_image: Optional[Union[str, ImageSpec]] = None, + **kwargs, + ): + self.sess: Optional[SparkSession] = None + self._default_executor_path: Optional[str] = None + self._default_applications_path: Optional[str] = None + + if isinstance(container_image, ImageSpec): + if container_image.base_image is None: + img = f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}" + container_image.base_image = img + # default executor path and applications path in apache/spark-py:3.3.1 + self._default_executor_path = "/usr/bin/python3" + self._default_applications_path = "local:///usr/local/bin/entrypoint.py" super(PysparkFunctionTask, self).__init__( task_config=task_config, task_type=self._SPARK_TASK_TYPE, task_function=task_function, + container_image=container_image, **kwargs, ) - self.sess: Optional[SparkSession] = None def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: job = SparkJob( spark_conf=self.task_config.spark_conf, hadoop_conf=self.task_config.hadoop_conf, - application_file="local://" + settings.entrypoint_settings.path, - executor_path=settings.python_interpreter, + application_file=self._default_applications_path or "local://" + settings.entrypoint_settings.path, + executor_path=self._default_executor_path or settings.python_interpreter, main_class="", spark_type=SparkType.PYTHON, ) if isinstance(self.task_config, Databricks): - cfg = typing.cast(Databricks, self.task_config) + cfg = cast(Databricks, self.task_config) job._databricks_conf = cfg.databricks_conf job._databricks_token = cfg.databricks_token job._databricks_instance = cfg.databricks_instance diff --git a/plugins/flytekit-spark/tests/test_pyspark_transformers.py b/plugins/flytekit-spark/tests/test_pyspark_transformers.py index cb527e16ef..212af454dd 100644 --- a/plugins/flytekit-spark/tests/test_pyspark_transformers.py +++ b/plugins/flytekit-spark/tests/test_pyspark_transformers.py @@ -6,13 +6,24 @@ import flytekit from flytekit import task, workflow +from flytekit.core.context_manager import FlyteContextManager from flytekit.core.type_engine import TypeEngine +from flytekit.types.structured.structured_dataset import StructuredDatasetTransformerEngine def test_type_resolution(): assert type(TypeEngine.get_transformer(PipelineModel)) == PySparkPipelineModelTransformer +def test_basic_get(): + + ctx = FlyteContextManager.current_context() + e = StructuredDatasetTransformerEngine() + prot = e._protocol_from_type_or_prefix(ctx, pyspark.sql.DataFrame, uri="/tmp/blah") + en = e.get_encoder(pyspark.sql.DataFrame, prot, "") + assert en is not None + + def test_pipeline_model_compatibility(): @task(task_config=Spark()) def my_dataset() -> pyspark.sql.DataFrame: diff --git a/plugins/flytekit-sqlalchemy/Dockerfile b/plugins/flytekit-sqlalchemy/Dockerfile new file mode 100644 index 0000000000..ed1a644d8f --- /dev/null +++ b/plugins/flytekit-sqlalchemy/Dockerfile @@ -0,0 +1,19 @@ +ARG PYTHON_VERSION +FROM python:${PYTHON_VERSION}-slim-buster + +WORKDIR /root +ENV LANG C.UTF-8 +ENV LC_ALL C.UTF-8 +ENV PYTHONPATH /root + +ARG VERSION + +RUN pip install sqlalchemy \ + psycopg2-binary \ + pymysql \ + flytekitplugins-sqlalchemy==$VERSION \ + 
flytekit==$VERSION + +RUN useradd -u 1000 flytekit +RUN chown flytekit: /root +USER flytekit diff --git a/plugins/flytekit-sqlalchemy/Dockerfile.py3.10 b/plugins/flytekit-sqlalchemy/Dockerfile.py3.10 deleted file mode 100644 index 791b13fa53..0000000000 --- a/plugins/flytekit-sqlalchemy/Dockerfile.py3.10 +++ /dev/null @@ -1,25 +0,0 @@ -FROM python:3.10-slim-buster - -WORKDIR /app -ENV VENV /opt/venv -ENV LANG C.UTF-8 -ENV LC_ALL C.UTF-8 -ENV PYTHONPATH /app - -RUN pip install awscli -RUN pip install gsutil - -ARG VERSION - -# Virtual environment -RUN python3.10 -m venv ${VENV} -RUN ${VENV}/bin/pip install wheel - -RUN ${VENV}/bin/pip install sqlalchemy psycopg2-binary pymysql flytekitplugins-sqlalchemy==$VERSION flytekit==$VERSION - -# Copy over the helper script that the SDK relies on -RUN cp ${VENV}/bin/flytekit_venv /usr/local/bin -RUN chmod a+x /usr/local/bin/flytekit_venv - -# Enable the virtualenv for this image. Note this relies on the VENV variable we've set in this image. -ENTRYPOINT ["/usr/local/bin/flytekit_venv"] diff --git a/plugins/flytekit-sqlalchemy/Dockerfile.py3.7 b/plugins/flytekit-sqlalchemy/Dockerfile.py3.7 deleted file mode 100644 index 879656adb5..0000000000 --- a/plugins/flytekit-sqlalchemy/Dockerfile.py3.7 +++ /dev/null @@ -1,25 +0,0 @@ -FROM python:3.7-slim-buster - -WORKDIR /app -ENV VENV /opt/venv -ENV LANG C.UTF-8 -ENV LC_ALL C.UTF-8 -ENV PYTHONPATH /app - -RUN pip install awscli -RUN pip install gsutil - -ARG VERSION - -# Virtual environment -RUN python3.7 -m venv ${VENV} -RUN ${VENV}/bin/pip install wheel - -RUN ${VENV}/bin/pip install sqlalchemy psycopg2-binary pymysql flytekitplugins-sqlalchemy==$VERSION flytekit==$VERSION - -# Copy over the helper script that the SDK relies on -RUN cp ${VENV}/bin/flytekit_venv /usr/local/bin -RUN chmod a+x /usr/local/bin/flytekit_venv - -# Enable the virtualenv for this image. Note this relies on the VENV variable we've set in this image. -ENTRYPOINT ["/usr/local/bin/flytekit_venv"] diff --git a/plugins/flytekit-sqlalchemy/Dockerfile.py3.8 b/plugins/flytekit-sqlalchemy/Dockerfile.py3.8 deleted file mode 100644 index 93b7048e1b..0000000000 --- a/plugins/flytekit-sqlalchemy/Dockerfile.py3.8 +++ /dev/null @@ -1,25 +0,0 @@ -FROM python:3.8-slim-buster - -WORKDIR /app -ENV VENV /opt/venv -ENV LANG C.UTF-8 -ENV LC_ALL C.UTF-8 -ENV PYTHONPATH /app - -RUN pip install awscli -RUN pip install gsutil - -ARG VERSION - -# Virtual environment -RUN python3.8 -m venv ${VENV} -RUN ${VENV}/bin/pip install wheel - -RUN ${VENV}/bin/pip install sqlalchemy psycopg2-binary pymysql flytekitplugins-sqlalchemy==$VERSION flytekit==$VERSION - -# Copy over the helper script that the SDK relies on -RUN cp ${VENV}/bin/flytekit_venv /usr/local/bin -RUN chmod a+x /usr/local/bin/flytekit_venv - -# Enable the virtualenv for this image. Note this relies on the VENV variable we've set in this image. 
-ENTRYPOINT ["/usr/local/bin/flytekit_venv"] diff --git a/plugins/flytekit-sqlalchemy/Dockerfile.py3.9 b/plugins/flytekit-sqlalchemy/Dockerfile.py3.9 deleted file mode 100644 index 039956dcd1..0000000000 --- a/plugins/flytekit-sqlalchemy/Dockerfile.py3.9 +++ /dev/null @@ -1,25 +0,0 @@ -FROM python:3.9-slim-buster - -WORKDIR /app -ENV VENV /opt/venv -ENV LANG C.UTF-8 -ENV LC_ALL C.UTF-8 -ENV PYTHONPATH /app - -RUN pip install awscli -RUN pip install gsutil - -ARG VERSION - -# Virtual environment -RUN python3.9 -m venv ${VENV} -RUN ${VENV}/bin/pip install wheel - -RUN ${VENV}/bin/pip install sqlalchemy psycopg2-binary pymysql flytekitplugins-sqlalchemy==$VERSION flytekit==$VERSION - -# Copy over the helper script that the SDK relies on -RUN cp ${VENV}/bin/flytekit_venv /usr/local/bin -RUN chmod a+x /usr/local/bin/flytekit_venv - -# Enable the virtualenv for this image. Note this relies on the VENV variable we've set in this image. -ENTRYPOINT ["/usr/local/bin/flytekit_venv"] diff --git a/plugins/setup.py b/plugins/setup.py index 1b47cc58e0..d607468081 100644 --- a/plugins/setup.py +++ b/plugins/setup.py @@ -18,6 +18,8 @@ "flytekitplugins-dbt": "flytekit-dbt", "flytekitplugins-dolt": "flytekit-dolt", "flytekitplugins-duckdb": "flytekit-duckdb", + "flytekitplugins-data-fsspec": "flytekit-data-fsspec", + "flytekitplugins-envd": "flytekit-envd", "flytekitplugins-great_expectations": "flytekit-greatexpectations", "flytekitplugins-hive": "flytekit-hive", "flytekitplugins-pod": "flytekit-k8s-pod", @@ -30,6 +32,7 @@ "flytekitplugins-onnxpytorch": "flytekit-onnx-pytorch", "flytekitplugins-pandera": "flytekit-pandera", "flytekitplugins-papermill": "flytekit-papermill", + "flytekitplugins-polars": "flytekit-polars", "flytekitplugins-ray": "flytekit-ray", "flytekitplugins-snowflake": "flytekit-snowflake", "flytekitplugins-spark": "flytekit-spark", diff --git a/setup.py b/setup.py index 6519fbccd1..5500f9f737 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,8 @@ ] }, install_requires=[ - "flyteidl>=1.2.9,<1.3.0", + "googleapis-common-protos>=1.57", + "flyteidl>=1.2.10,<1.3.0", "wheel>=0.30.0,<1.0.0", "pandas>=1.0.0,<2.0.0", "pyarrow>=4.0.0,<11.0.0", @@ -42,6 +43,10 @@ "grpcio>=1.43.0,!=1.45.0,<1.49.1,<2.0", "grpcio-status>=1.43,!=1.45.0,<1.49.1", "importlib-metadata", + "fsspec>=2023.3.0", + "adlfs", + "s3fs>=0.6.0", + "gcsfs", "pyopenssl", "joblib", "protobuf>=3.6.1,<4", @@ -56,7 +61,6 @@ "statsd>=3.0.0,<4.0.0", "urllib3>=1.22,<2.0.0", "wrapt>=1.0.0,<2.0.0", - "retry==0.9.2", "dataclasses-json>=0.5.2", "marshmallow-jsonschema>=0.12.0", "natsort>=7.0.1", @@ -73,6 +77,8 @@ "numpy<1.24.0", "gitpython", "kubernetes>=12.0.1", + "rich", + "rich_click", ], extras_require=extras_require, scripts=[ diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index 45d50a2fc5..1a24cccb61 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -2,6 +2,7 @@ import typing from collections import OrderedDict +import fsspec import mock import pytest from flyteidl.core.errors_pb2 import ErrorDocument @@ -10,15 +11,12 @@ from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core import context_manager from flytekit.core.base_task import IgnoreOutputs -from flytekit.core.data_persistence import DiskPersistence from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.promise import VoidPromise from flytekit.core.task import task from 
flytekit.core.type_engine import TypeEngine from flytekit.exceptions import user as user_exceptions from flytekit.exceptions.scopes import system_entry_point -from flytekit.extras.persistence.gcs_gsutil import GCSPersistence -from flytekit.extras.persistence.s3_awscli import S3Persistence from flytekit.models import literals as _literal_models from flytekit.models.core import errors as error_models from flytekit.models.core import execution as execution_models @@ -311,7 +309,22 @@ def test_dispatch_execute_system_error(mock_write_to_file, mock_upload_dir, mock assert ed.error.origin == execution_models.ExecutionError.ErrorKind.SYSTEM -def test_persist_ss(): +def test_setup_disk_prefix(): + with setup_execution("qwerty") as ctx: + assert isinstance(ctx.file_access._default_remote, fsspec.AbstractFileSystem) + assert ctx.file_access._default_remote.protocol == "file" + + +def test_setup_cloud_prefix(): + with setup_execution("s3://", checkpoint_path=None, prev_checkpoint=None) as ctx: + assert ctx.file_access._default_remote.protocol[0] == "s3" + + with setup_execution("gs://", checkpoint_path=None, prev_checkpoint=None) as ctx: + assert "gs" in ctx.file_access._default_remote.protocol + + +@mock.patch("google.auth.compute_engine._metadata") # to prevent network calls +def test_persist_ss(mock_gcs): default_img = Image(name="default", fqn="test", tag="tag") ss = SerializationSettings( project="proj1", @@ -327,19 +340,6 @@ def test_persist_ss(): assert ctx.serialization_settings.domain == "dom" -def test_setup_disk_prefix(): - with setup_execution("qwerty") as ctx: - assert isinstance(ctx.file_access._default_remote, DiskPersistence) - - -def test_setup_cloud_prefix(): - with setup_execution("s3://", checkpoint_path=None, prev_checkpoint=None) as ctx: - assert isinstance(ctx.file_access._default_remote, S3Persistence) - - with setup_execution("gs://", checkpoint_path=None, prev_checkpoint=None) as ctx: - assert isinstance(ctx.file_access._default_remote, GCSPersistence) - - def test_normalize_inputs(): assert normalize_inputs("{{.rawOutputDataPrefix}}", "{{.checkpointOutputPrefix}}", "{{.prevCheckpointPrefix}}") == ( None, diff --git a/tests/flytekit/unit/cli/pyflyte/imageSpec.yaml b/tests/flytekit/unit/cli/pyflyte/imageSpec.yaml new file mode 100644 index 0000000000..ba67ab4b91 --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/imageSpec.yaml @@ -0,0 +1,2 @@ +python_version: 3.8 +builder: test diff --git a/tests/flytekit/unit/cli/pyflyte/image_spec_wf.py b/tests/flytekit/unit/cli/pyflyte/image_spec_wf.py new file mode 100644 index 0000000000..9d5d74ff1b --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/image_spec_wf.py @@ -0,0 +1,20 @@ +from flytekit import task, workflow +from flytekit.image_spec import ImageSpec + +image_spec = ImageSpec(packages=["numpy", "pandas"], apt_packages=["git"], registry="", builder="test") + + +@task(container_image=image_spec) +def t2() -> str: + return "flyte" + + +@task(container_image=image_spec) +def t1() -> str: + return "flyte" + + +@workflow +def wf(): + t1() + t2() diff --git a/tests/flytekit/unit/cli/pyflyte/test_backfill.py b/tests/flytekit/unit/cli/pyflyte/test_backfill.py index 8389295af2..0fd328e638 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_backfill.py +++ b/tests/flytekit/unit/cli/pyflyte/test_backfill.py @@ -39,7 +39,6 @@ def test_pyflyte_backfill(mock_remote): "--backfill-window", "5 day", "daily", - "--dry-run", ], ) assert result.exit_code == 0 diff --git a/tests/flytekit/unit/cli/pyflyte/test_build.py 
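The rewritten entrypoint tests assert that the default remote is now an fsspec filesystem selected by the output-prefix protocol, replacing the removed `DiskPersistence` / `S3Persistence` / `GCSPersistence` classes. A standalone sketch of protocol-based resolution with fsspec; the URLs are placeholders and remote protocols need the matching packages (s3fs, gcsfs) installed.

```python
# Standalone sketch of fsspec protocol resolution; this illustrates the mechanism
# the tests above exercise, not flytekit's own wiring. URLs are placeholders.
import fsspec

fs, path = fsspec.core.url_to_fs("file:///tmp/outputs")
print(fs.protocol, path)  # local filesystem for plain paths and file:// URLs

# Remote prefixes resolve to the matching implementation when it is installed:
# fsspec.core.url_to_fs("s3://my-bucket/prefix")  -> s3fs (needs credentials)
# fsspec.core.url_to_fs("gs://my-bucket/prefix")  -> gcsfs
```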
b/tests/flytekit/unit/cli/pyflyte/test_build.py new file mode 100644 index 0000000000..7b4b26fb69 --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/test_build.py @@ -0,0 +1,31 @@ +import os + +from click.testing import CliRunner + +from flytekit.clis.sdk_in_container import pyflyte +from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpecBuilder + +WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_spec_wf.py") + + +def test_build(): + class TestImageSpecBuilder(ImageSpecBuilder): + def build_image(self, img): + ... + + ImageBuildEngine.register("test", TestImageSpecBuilder()) + runner = CliRunner() + result = runner.invoke(pyflyte.main, ["build", "--fast", WORKFLOW_FILE, "wf"]) + assert result.exit_code == 0 + + result = runner.invoke(pyflyte.main, ["build", WORKFLOW_FILE, "wf"]) + assert result.exit_code == 0 + + result = runner.invoke(pyflyte.main, ["build", WORKFLOW_FILE, "wf"]) + assert result.exit_code == 0 + + result = runner.invoke(pyflyte.main, ["build", "--help"]) + assert result.exit_code == 0 + + result = runner.invoke(pyflyte.main, ["build", "../", "wf"]) + assert result.exit_code == 1 diff --git a/tests/flytekit/unit/cli/pyflyte/test_launchplan.py b/tests/flytekit/unit/cli/pyflyte/test_launchplan.py new file mode 100644 index 0000000000..1a461bfd35 --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/test_launchplan.py @@ -0,0 +1,34 @@ +import pytest +from click.testing import CliRunner +from mock import mock + +from flytekit.clis.sdk_in_container import pyflyte +from flytekit.remote import FlyteRemote + + +@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote) +@pytest.mark.parametrize( + ("action", "expected_state"), + [ + ("activate", "ACTIVE"), + ("deactivate", "INACTIVE"), + ], +) +def test_pyflyte_launchplan(mock_remote, action, expected_state): + mock_remote.generate_console_url.return_value = "ex" + runner = CliRunner() + with runner.isolated_filesystem(): + result = runner.invoke( + pyflyte.main, + [ + "launchplan", + f"--{action}", + "-p", + "flytesnacks", + "-d", + "development", + "daily", + ], + ) + assert result.exit_code == 0 + assert f"Launchplan was set to {expected_state}: " in result.output diff --git a/tests/flytekit/unit/cli/pyflyte/test_package.py b/tests/flytekit/unit/cli/pyflyte/test_package.py index e3ccb1d803..4d8251fc57 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_package.py +++ b/tests/flytekit/unit/cli/pyflyte/test_package.py @@ -1,7 +1,6 @@ import os import shutil -import pytest from click.testing import CliRunner import flytekit @@ -10,7 +9,6 @@ from flytekit import TaskMetadata from flytekit.clis.sdk_in_container import pyflyte from flytekit.core import context_manager -from flytekit.exceptions.user import FlyteValidationException from flytekit.models.admin.workflow import WorkflowSpec from flytekit.models.core.identifier import Identifier, ResourceType from flytekit.models.launch_plan import LaunchPlan @@ -104,56 +102,6 @@ def test_package_with_fast_registration(): shutil.rmtree("core") -def test_duplicate_registrable_entities(): - @flytekit.task - def t_1(): - pass - - # Keep a reference to a task named `t_1` that's going to be duplicated below - reference_1 = t_1 - - @flytekit.workflow - def wf_1(): - return t_1() - - # Duplicate definition of `t_1` - @flytekit.task - def t_1() -> str: - pass - - # Keep a second reference to the duplicate task named `t_1` so that we can use it later - reference_2 = t_1 - - @flytekit.task - def non_duplicate_task(): - pass - - 
@flytekit.workflow - def wf_2(): - non_duplicate_task() - # refers to the second definition of `t_1` - return t_1() - - ctx = context_manager.FlyteContextManager.current_context().with_serialization_settings( - flytekit.configuration.SerializationSettings( - project="p", - domain="d", - version="v", - image_config=flytekit.configuration.ImageConfig( - default_image=flytekit.configuration.Image("def", "docker.io/def", "latest") - ), - ) - ) - - context_manager.FlyteEntities.entities = [reference_1, wf_1, "str", reference_2, non_duplicate_task, wf_2, "str"] - - with pytest.raises( - FlyteValidationException, - match=r"Multiple definitions of the following tasks were found: \['pyflyte.test_package.t_1'\]", - ): - flytekit.tools.serialize_helpers.get_registrable_entities(ctx) - - def test_package(): runner = CliRunner() with runner.isolated_filesystem(): diff --git a/tests/flytekit/unit/cli/pyflyte/test_register.py b/tests/flytekit/unit/cli/pyflyte/test_register.py index a6c0bb91d8..0a371b76d1 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_register.py +++ b/tests/flytekit/unit/cli/pyflyte/test_register.py @@ -92,7 +92,7 @@ def test_non_fast_register(mock_client, mock_remote): def test_non_fast_register_require_version(mock_client, mock_remote): mock_remote._client = mock_client mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash" - mock_remote.return_value._upload_file.return_value = "dummy_md5_bytes", "dummy_native_url" + mock_remote.return_value.upload_file.return_value = "dummy_md5_bytes", "dummy_native_url" runner = CliRunner() context_manager.FlyteEntities.entities.clear() with runner.isolated_filesystem(): diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index e963f3dfc6..735df4af2c 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -1,6 +1,8 @@ import functools +import json import os import pathlib +import tempfile import typing from datetime import datetime, timedelta from enum import Enum @@ -8,6 +10,7 @@ import click import mock import pytest +import yaml from click.testing import CliRunner from flytekit import FlyteContextManager @@ -21,12 +24,14 @@ DurationParamType, FileParamType, FlyteLiteralConverter, + JsonParamType, get_entities_in_file, run_command, ) from flytekit.configuration import Config, Image, ImageConfig from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine +from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpecBuilder from flytekit.models.types import SimpleType from flytekit.remote import FlyteRemote @@ -62,8 +67,19 @@ def test_imperative_wf(): assert result.exit_code == 0 +def test_copy_all_files(): + runner = CliRunner() + result = runner.invoke( + pyflyte.main, + ["run", "--copy-all", IMPERATIVE_WORKFLOW_FILE, "wf", "--in1", "hello", "--in2", "world"], + catch_exceptions=False, + ) + assert result.exit_code == 0 + + def test_pyflyte_run_cli(): runner = CliRunner() + parquet_file = os.path.join(DIR_NAME, "testdata/df.parquet") result = runner.invoke( pyflyte.main, [ @@ -83,7 +99,7 @@ def test_pyflyte_run_cli(): "--f", '{"x":1.0, "y":2.0}', "--g", - os.path.join(DIR_NAME, "testdata/df.parquet"), + parquet_file, "--i", "2020-05-01", "--j", @@ -97,6 +113,10 @@ def test_pyflyte_run_cli(): "--image", os.path.join(DIR_NAME, "testdata"), "--h", + "--n", + json.dumps([{"x": parquet_file}]), + "--o", + json.dumps({"x": [parquet_file]}), ], catch_exceptions=False, ) @@ -148,19 
+168,19 @@ def test_union_type2(input): def test_union_type_with_invalid_input(): runner = CliRunner() - with pytest.raises(ValueError, match="Failed to convert python type typing.Union"): - runner.invoke( - pyflyte.main, - [ - "--verbose", - "run", - os.path.join(DIR_NAME, "workflow.py"), - "test_union2", - "--a", - "hello", - ], - catch_exceptions=False, - ) + result = runner.invoke( + pyflyte.main, + [ + "--verbose", + "run", + os.path.join(DIR_NAME, "workflow.py"), + "test_union2", + "--a", + "hello", + ], + catch_exceptions=False, + ) + assert result.exit_code == 2 def test_get_entities_in_file(): @@ -216,6 +236,7 @@ def test_list_default_arguments(wf_path): ], catch_exceptions=False, ) + print(result.stdout) assert result.exit_code == 0 @@ -239,6 +260,17 @@ def test_list_default_arguments(wf_path): images=[Image(name="xyz", fqn="ghcr.io/asdf/asdf", tag="latest"), Image(name="abc", fqn="docker.io/abc", tag=None)], ) +ic_result_4 = ImageConfig( + default_image=Image(name="default", fqn="flytekit", tag="4VC-c-UDrUvfySJ0aS3qCw.."), + images=[ + Image(name="default", fqn="flytekit", tag="4VC-c-UDrUvfySJ0aS3qCw.."), + Image(name="xyz", fqn="docker.io/xyz", tag="latest"), + Image(name="abc", fqn="docker.io/abc", tag=None), + ], +) + +IMAGE_SPEC = os.path.join(os.path.dirname(os.path.realpath(__file__)), "imageSpec.yaml") + @pytest.mark.parametrize( "image_string, leaf_configuration_file_name, final_image_config", @@ -246,9 +278,16 @@ def test_list_default_arguments(wf_path): ("ghcr.io/flyteorg/mydefault:py3.9-latest", "no_images.yaml", ic_result_1), ("asdf=ghcr.io/asdf/asdf:latest", "sample.yaml", ic_result_2), ("xyz=ghcr.io/asdf/asdf:latest", "sample.yaml", ic_result_3), + (IMAGE_SPEC, "sample.yaml", ic_result_4), ], ) def test_pyflyte_run_run(image_string, leaf_configuration_file_name, final_image_config): + class TestImageSpecBuilder(ImageSpecBuilder): + def build_image(self, img): + ... + + ImageBuildEngine.register("test", TestImageSpecBuilder()) + @task def a(): ... 
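Note on the ImageSpec test scaffolding above: both test_build.py and test_pyflyte_run_run register a no-op builder under the key "test" so that ImageSpec-backed container images can be resolved without actually building anything. A minimal standalone sketch of the same pattern, using only the ImageSpecBuilder/ImageBuildEngine API that appears in this diff (the "noop" key and the package list are illustrative):

    from flytekit.image_spec import ImageSpec
    from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpecBuilder


    class NoOpImageSpecBuilder(ImageSpecBuilder):
        # A real builder (e.g. the envd plugin) would build and push the image here.
        def build_image(self, image_spec: ImageSpec):
            ...


    # Register the builder under a custom key, then reference it from an ImageSpec.
    ImageBuildEngine.register("noop", NoOpImageSpecBuilder())
    spec = ImageSpec(packages=["numpy"], registry="", builder="noop")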
@@ -362,3 +401,34 @@ def test_datetime_type(): v = t.convert("now", None, None) assert v.day == now.day assert v.month == now.month + + +def test_json_type(): + t = JsonParamType() + assert t.convert(value='{"a": "b"}', param=None, ctx=None) == {"a": "b"} + + with pytest.raises(click.BadParameter): + t.convert(None, None, None) + + # test that it loads a json file + with tempfile.NamedTemporaryFile("w", delete=False) as f: + json.dump({"a": "b"}, f) + f.flush() + assert t.convert(value=f.name, param=None, ctx=None) == {"a": "b"} + + # test that if the file is not a valid json, it raises an error + with tempfile.NamedTemporaryFile("w", delete=False) as f: + f.write("asdf") + f.flush() + with pytest.raises(click.BadParameter): + t.convert(value=f.name, param="asdf", ctx=None) + + # test if the file does not exist + with pytest.raises(click.BadParameter): + t.convert(value="asdf", param=None, ctx=None) + + # test if the file is yaml and ends with .yaml it works correctly + with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f: + yaml.dump({"a": "b"}, f) + f.flush() + assert t.convert(value=f.name, param=None, ctx=None) == {"a": "b"} diff --git a/tests/flytekit/unit/cli/pyflyte/test_serve.py b/tests/flytekit/unit/cli/pyflyte/test_serve.py new file mode 100644 index 0000000000..f3ecbef547 --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/test_serve.py @@ -0,0 +1,9 @@ +from click.testing import CliRunner + +from flytekit.clis.sdk_in_container import pyflyte + + +def test_pyflyte_serve(): + runner = CliRunner() + result = runner.invoke(pyflyte.main, ["serve", "--port", "0", "--timeout", "1"], catch_exceptions=False) + assert result.exit_code == 0 diff --git a/tests/flytekit/unit/cli/pyflyte/workflow.py b/tests/flytekit/unit/cli/pyflyte/workflow.py index 85438eb00d..01621a6a01 100644 --- a/tests/flytekit/unit/cli/pyflyte/workflow.py +++ b/tests/flytekit/unit/cli/pyflyte/workflow.py @@ -56,8 +56,10 @@ def print_all( k: Color, l: dict, m: dict, + n: typing.List[typing.Dict[str, FlyteFile]], + o: typing.Dict[str, typing.List[FlyteFile]], ): - print(f"{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}, {j}, {k}, {l}, {m}") + print(f"{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}, {j}, {k}, {l}, {m}, {n}, {o}") @task @@ -84,6 +86,8 @@ def my_wf( j: datetime.timedelta, k: Color, l: dict, + n: typing.List[typing.Dict[str, FlyteFile]], + o: typing.Dict[str, typing.List[FlyteFile]], remote: pd.DataFrame, image: StructuredDataset, m: dict = {"hello": "world"}, @@ -91,5 +95,5 @@ def my_wf( x = get_subset_df(df=remote) # noqa: shown for demonstration; users should use the same types between tasks show_sd(in_sd=x) show_sd(in_sd=image) - print_all(a=a, b=b, c=c, d=d, e=e, f=f, g=g, h=h, i=i, j=j, k=k, l=l, m=m) + print_all(a=a, b=b, c=c, d=d, e=e, f=f, g=g, h=h, i=i, j=j, k=k, l=l, m=m, n=n, o=o) return x diff --git a/tests/flytekit/unit/clients/auth/test_authenticator.py b/tests/flytekit/unit/clients/auth/test_authenticator.py index 4c968cf0bd..fdbddb2ebe 100644 --- a/tests/flytekit/unit/clients/auth/test_authenticator.py +++ b/tests/flytekit/unit/clients/auth/test_authenticator.py @@ -8,10 +8,12 @@ ClientConfig, ClientCredentialsAuthenticator, CommandAuthenticator, + DeviceCodeAuthenticator, PKCEAuthenticator, StaticClientConfigStore, ) from flytekit.clients.auth.exceptions import AuthenticationError +from flytekit.clients.auth.token_client import DeviceCodeResponse ENDPOINT = "example.com" @@ -65,31 +67,69 @@ def test_command_authenticator(mock_subprocess: MagicMock): authn.refresh_credentials() 
-def test_get_basic_authorization_header(): - header = ClientCredentialsAuthenticator.get_basic_authorization_header("client_id", "abc") - assert header == "Basic Y2xpZW50X2lkOmFiYw==" - +@patch("flytekit.clients.auth.token_client.requests") +def test_client_creds_authenticator(mock_requests): + authn = ClientCredentialsAuthenticator( + ENDPOINT, client_id="client", client_secret="secret", cfg_store=static_cfg_store + ) -@patch("flytekit.clients.auth.authenticator.requests") -def test_get_token(mock_requests): response = MagicMock() response.status_code = 200 response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") mock_requests.post.return_value = response - access, expiration = ClientCredentialsAuthenticator.get_token("https://corp.idp.net", "abc123", ["my_scope"]) - assert access == "abc" - assert expiration == 60 + authn.refresh_credentials() + expected_scopes = static_cfg_store.get_client_config().scopes + assert authn._creds + assert authn._scopes == expected_scopes -@patch("flytekit.clients.auth.authenticator.requests") -def test_client_creds_authenticator(mock_requests): - authn = ClientCredentialsAuthenticator( - ENDPOINT, client_id="client", client_secret="secret", cfg_store=static_cfg_store +@patch("flytekit.clients.auth.authenticator.KeyringStore") +@patch("flytekit.clients.auth.token_client.get_device_code") +@patch("flytekit.clients.auth.token_client.poll_token_endpoint") +def test_device_flow_authenticator(poll_mock: MagicMock, device_mock: MagicMock, mock_keyring: MagicMock): + with pytest.raises(AuthenticationError): + DeviceCodeAuthenticator( + ENDPOINT, + static_cfg_store, + audience="x", + ) + + cfg_store = StaticClientConfigStore( + ClientConfig( + token_endpoint="token_endpoint", + authorization_endpoint="auth_endpoint", + redirect_uri="redirect_uri", + client_id="client", + device_authorization_endpoint="dev", + ) + ) + authn = DeviceCodeAuthenticator( + ENDPOINT, + cfg_store, + audience="x", ) + device_mock.return_value = DeviceCodeResponse("x", "y", "s", "m", 1000, 0) + poll_mock.return_value = ("access", 100) + authn.refresh_credentials() + assert authn._creds + + +@patch("flytekit.clients.auth.token_client.requests") +def test_client_creds_authenticator_with_custom_scopes(mock_requests): + expected_scopes = ["foo", "baz"] + authn = ClientCredentialsAuthenticator( + ENDPOINT, + client_id="client", + client_secret="secret", + cfg_store=static_cfg_store, + scopes=expected_scopes, + ) response = MagicMock() response.status_code = 200 response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") mock_requests.post.return_value = response authn.refresh_credentials() + assert authn._creds + assert authn._scopes == expected_scopes diff --git a/tests/flytekit/unit/clients/auth/test_token_client.py b/tests/flytekit/unit/clients/auth/test_token_client.py new file mode 100644 index 0000000000..c22284cd38 --- /dev/null +++ b/tests/flytekit/unit/clients/auth/test_token_client.py @@ -0,0 +1,81 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest + +from flytekit.clients.auth.exceptions import AuthenticationError +from flytekit.clients.auth.token_client import ( + DeviceCodeResponse, + error_auth_pending, + get_basic_authorization_header, + get_device_code, + get_token, + poll_token_endpoint, +) + + +def test_get_basic_authorization_header(): + header = get_basic_authorization_header("client_id", "abc") + assert header == "Basic Y2xpZW50X2lkOmFiYw==" + + header = 
get_basic_authorization_header("client_id", "abc%%$?\\/\\/") + assert header == "Basic Y2xpZW50X2lkOmFiYyUyNSUyNSUyNCUzRiU1QyUyRiU1QyUyRg==" + + +@patch("flytekit.clients.auth.token_client.requests") +def test_get_token(mock_requests): + response = MagicMock() + response.status_code = 200 + response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") + mock_requests.post.return_value = response + access, expiration = get_token("https://corp.idp.net", client_id="abc123", scopes=["my_scope"]) + assert access == "abc" + assert expiration == 60 + + +@patch("flytekit.clients.auth.token_client.requests") +def test_get_device_code(mock_requests): + response = MagicMock() + response.ok = False + mock_requests.post.return_value = response + with pytest.raises(AuthenticationError): + get_device_code("test.com", "test") + + response.ok = True + response.json.return_value = { + "device_code": "code", + "user_code": "BNDJJFXL", + "verification_uri": "url", + "verification_uri_complete": "url", + "expires_in": 600, + "interval": 5, + } + mock_requests.post.return_value = response + c = get_device_code("test.com", "test") + assert c + assert c.device_code == "code" + + +@patch("flytekit.clients.auth.token_client.requests") +def test_poll_token_endpoint(mock_requests): + response = MagicMock() + response.ok = False + response.json.return_value = {"error": error_auth_pending} + mock_requests.post.return_value = response + + r = DeviceCodeResponse( + device_code="x", user_code="y", verification_uri="v", verification_uri_complete="v1", expires_in=1, interval=1 + ) + with pytest.raises(AuthenticationError): + poll_token_endpoint(r, "test.com", "test") + + response = MagicMock() + response.ok = True + response.json.return_value = {"access_token": "abc", "expires_in": 60} + mock_requests.post.return_value = response + r = DeviceCodeResponse( + device_code="x", user_code="y", verification_uri="v", verification_uri_complete="v1", expires_in=1, interval=0 + ) + t, e = poll_token_endpoint(r, "test.com", "test") + assert t + assert e diff --git a/tests/flytekit/unit/clients/test_auth_helper.py b/tests/flytekit/unit/clients/test_auth_helper.py index 8f14de730e..3bd57918f4 100644 --- a/tests/flytekit/unit/clients/test_auth_helper.py +++ b/tests/flytekit/unit/clients/test_auth_helper.py @@ -9,6 +9,7 @@ ClientConfigStore, ClientCredentialsAuthenticator, CommandAuthenticator, + DeviceCodeAuthenticator, PKCEAuthenticator, ) from flytekit.clients.auth.exceptions import AuthenticationError @@ -31,6 +32,8 @@ OAUTH_AUTHORIZE = "https://your.domain.io/oauth2/authorize" +DEVICE_AUTH_ENDPOINT = "https://your.domain.io/..." 
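The token_client tests above exercise the OAuth 2.0 device authorization flow (RFC 8628): the client obtains a device code, then polls the token endpoint until it stops answering with the "pending" error or the code expires. A self-contained sketch of that polling loop, independent of flytekit's implementation (fake_token_endpoint and the helper names are illustrative):

    import time

    AUTH_PENDING = "authorization_pending"  # standard RFC 8628 "keep polling" error code


    def poll_for_token(fetch, interval: int, expires_in: int) -> str:
        # Poll `fetch` every `interval` seconds until a token arrives or the code expires.
        deadline = time.monotonic() + expires_in
        while time.monotonic() < deadline:
            resp = fetch()
            if "access_token" in resp:
                return resp["access_token"]
            if resp.get("error") != AUTH_PENDING:
                raise RuntimeError(f"device flow failed: {resp.get('error')}")
            time.sleep(interval)
        raise TimeoutError("device code expired before the user approved it")


    # Usage: a fake endpoint that approves on the third poll.
    calls = {"n": 0}

    def fake_token_endpoint():
        calls["n"] += 1
        return {"access_token": "abc"} if calls["n"] >= 3 else {"error": AUTH_PENDING}

    print(poll_for_token(fake_token_endpoint, interval=0, expires_in=5))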
+ def get_auth_service_mock() -> MagicMock: auth_stub_mock = MagicMock() @@ -66,13 +69,14 @@ def test_remote_client_config_store(mock_auth_service: MagicMock): assert ccfg.authorization_endpoint == OAUTH_AUTHORIZE -def get_client_config() -> ClientConfigStore: +def get_client_config(**kwargs) -> ClientConfigStore: cfg_store = MagicMock() cfg_store.get_client_config.return_value = ClientConfig( token_endpoint=TOKEN_ENDPOINT, authorization_endpoint=OAUTH_AUTHORIZE, redirect_uri=REDIRECT_URI, client_id=CLIENT_ID, + **kwargs ) return cfg_store @@ -135,6 +139,15 @@ def test_get_authenticator_cmd(): assert authn._cmd == ["echo"] +def test_get_authenticator_deviceflow(): + cfg = PlatformConfig(auth_mode=AuthType.DEVICEFLOW) + with pytest.raises(AuthenticationError): + get_authenticator(cfg, get_client_config()) + + authn = get_authenticator(cfg, get_client_config(device_authorization_endpoint=DEVICE_AUTH_ENDPOINT)) + assert isinstance(authn, DeviceCodeAuthenticator) + + def test_wrap_exceptions_channel(): ch = MagicMock() out_ch = wrap_exceptions_channel(PlatformConfig(), ch) diff --git a/tests/flytekit/unit/configuration/test_image_config.py b/tests/flytekit/unit/configuration/test_image_config.py index 84c767f8fb..c14832df3c 100644 --- a/tests/flytekit/unit/configuration/test_image_config.py +++ b/tests/flytekit/unit/configuration/test_image_config.py @@ -60,3 +60,7 @@ def test_image_create(): ic = ImageConfig.from_images("cr.flyte.org/im/g:latest") assert ic.default_image.fqn == "cr.flyte.org/im/g" + + +def test_get_version_suffix(): + assert DefaultImages.get_version_suffix() == "latest" diff --git a/tests/flytekit/unit/core/flyte_functools/decorator_source.py b/tests/flytekit/unit/core/flyte_functools/decorator_source.py index 9c92364649..5790d5d358 100644 --- a/tests/flytekit/unit/core/flyte_functools/decorator_source.py +++ b/tests/flytekit/unit/core/flyte_functools/decorator_source.py @@ -1,10 +1,11 @@ """Script used for testing local execution of functool.wraps-wrapped tasks for stacked decorators""" - +import functools +import typing from functools import wraps from typing import List -def task_setup(function: callable = None, *, integration_requests: List = None) -> None: +def task_setup(function: typing.Callable, *, integration_requests: typing.Optional[List] = None) -> typing.Callable: integration_requests = integration_requests or [] @wraps(function) diff --git a/tests/flytekit/unit/core/flyte_functools/nested_function.py b/tests/flytekit/unit/core/flyte_functools/nested_function.py index 6a3ccfd9e1..98a39e497a 100644 --- a/tests/flytekit/unit/core/flyte_functools/nested_function.py +++ b/tests/flytekit/unit/core/flyte_functools/nested_function.py @@ -32,4 +32,4 @@ def my_workflow(x: int) -> int: if __name__ == "__main__": - print(my_workflow(x=int(os.getenv("SCRIPT_INPUT")))) + print(my_workflow(x=int(os.getenv("SCRIPT_INPUT", 0)))) diff --git a/tests/flytekit/unit/core/flyte_functools/simple_decorator.py b/tests/flytekit/unit/core/flyte_functools/simple_decorator.py index a51a283be5..3278af1bb0 100644 --- a/tests/flytekit/unit/core/flyte_functools/simple_decorator.py +++ b/tests/flytekit/unit/core/flyte_functools/simple_decorator.py @@ -38,4 +38,4 @@ def my_workflow(x: int) -> int: if __name__ == "__main__": - print(my_workflow(x=int(os.getenv("SCRIPT_INPUT")))) + print(my_workflow(x=int(os.getenv("SCRIPT_INPUT", 0)))) diff --git a/tests/flytekit/unit/core/flyte_functools/stacked_decorators.py b/tests/flytekit/unit/core/flyte_functools/stacked_decorators.py index 
07c46cd46a..dd445a6fb3 100644 --- a/tests/flytekit/unit/core/flyte_functools/stacked_decorators.py +++ b/tests/flytekit/unit/core/flyte_functools/stacked_decorators.py @@ -48,4 +48,4 @@ def my_workflow(x: int) -> int: if __name__ == "__main__": - print(my_workflow(x=int(os.getenv("SCRIPT_INPUT")))) + print(my_workflow(x=int(os.getenv("SCRIPT_INPUT", 0)))) diff --git a/tests/flytekit/unit/core/flyte_functools/test_decorators.py b/tests/flytekit/unit/core/flyte_functools/test_decorators.py index 3edd547c8a..20e55e9d3c 100644 --- a/tests/flytekit/unit/core/flyte_functools/test_decorators.py +++ b/tests/flytekit/unit/core/flyte_functools/test_decorators.py @@ -39,7 +39,7 @@ def test_wrapped_tasks_error(capfd): ) out = capfd.readouterr().out - assert out.replace("\r", "").strip().split("\n") == [ + assert out.replace("\r", "").strip().split("\n")[:5] == [ "before running my_task", "try running my_task", "error running my_task: my_task failed with input: 0", @@ -74,11 +74,11 @@ def test_unwrapped_task(): capture_output=True, ) error = completed_process.stderr - error_str = error.strip().split("\n")[-1] - assert ( - "TaskFunction cannot be a nested/inner or local function." - " It should be accessible at a module level for Flyte to execute it." in error_str - ) + error_str = "" + for line in error.strip().split("\n"): + if line.startswith("ValueError"): + error_str += line + assert error_str.startswith("ValueError: TaskFunction cannot be a nested/inner or local function.") @pytest.mark.parametrize("script", ["nested_function.py", "nested_wrapped_function.py"]) @@ -90,5 +90,8 @@ def test_nested_function(script): capture_output=True, ) error = completed_process.stderr - error_str = error.strip().split("\n")[-1] + error_str = "" + for line in error.strip().split("\n"): + if line.startswith("ValueError"): + error_str += line assert error_str.startswith("ValueError: TaskFunction cannot be a nested/inner or local function.") diff --git a/tests/flytekit/unit/core/flyte_functools/unwrapped_decorator.py b/tests/flytekit/unit/core/flyte_functools/unwrapped_decorator.py index 9f7e6599c6..6e22ca9840 100644 --- a/tests/flytekit/unit/core/flyte_functools/unwrapped_decorator.py +++ b/tests/flytekit/unit/core/flyte_functools/unwrapped_decorator.py @@ -26,4 +26,4 @@ def my_workflow(x: int) -> int: if __name__ == "__main__": - print(my_workflow(x=int(os.getenv("SCRIPT_INPUT")))) + print(my_workflow(x=int(os.getenv("SCRIPT_INPUT", 0)))) diff --git a/tests/flytekit/unit/core/image_spec/__init__.py b/tests/flytekit/unit/core/image_spec/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/core/image_spec/test_image_spec.py b/tests/flytekit/unit/core/image_spec/test_image_spec.py new file mode 100644 index 0000000000..be8ea61427 --- /dev/null +++ b/tests/flytekit/unit/core/image_spec/test_image_spec.py @@ -0,0 +1,50 @@ +import os + +import pytest + +from flytekit.core import context_manager +from flytekit.core.context_manager import ExecutionState +from flytekit.image_spec import ImageSpec +from flytekit.image_spec.image_spec import _F_IMG_ID, ImageBuildEngine, ImageSpecBuilder, calculate_hash_from_image_spec + + +def test_image_spec(): + image_spec = ImageSpec( + packages=["pandas"], + apt_packages=["git"], + python_version="3.8", + registry="", + base_image="cr.flyte.org/flyteorg/flytekit:py3.8-latest", + ) + + assert image_spec.python_version == "3.8" + assert image_spec.base_image == "cr.flyte.org/flyteorg/flytekit:py3.8-latest" + assert image_spec.packages == ["pandas"] + 
assert image_spec.apt_packages == ["git"] + assert image_spec.registry == "" + assert image_spec.name == "flytekit" + assert image_spec.builder == "envd" + assert image_spec.source_root is None + assert image_spec.env is None + assert image_spec.is_container() is True + assert image_spec.image_name() == "flytekit:yZ8jICcDTLoDArmNHbWNwg.." + ctx = context_manager.FlyteContext.current_context() + with context_manager.FlyteContextManager.with_context( + ctx.with_execution_state(ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)) + ): + os.environ[_F_IMG_ID] = "flytekit:123" + assert image_spec.is_container() is False + + class DummyImageSpecBuilder(ImageSpecBuilder): + def build_image(self, img): + ... + + ImageBuildEngine.register("dummy", DummyImageSpecBuilder()) + ImageBuildEngine._REGISTRY["dummy"].build_image(image_spec) + assert "dummy" in ImageBuildEngine._REGISTRY + assert calculate_hash_from_image_spec(image_spec) == "yZ8jICcDTLoDArmNHbWNwg.." + assert image_spec.exist() is False + + with pytest.raises(Exception): + image_spec.builder = "flyte" + ImageBuildEngine.build(image_spec) diff --git a/tests/flytekit/unit/core/test_checkpoint.py b/tests/flytekit/unit/core/test_checkpoint.py index 2add1b9e7d..b5fa46fe54 100644 --- a/tests/flytekit/unit/core/test_checkpoint.py +++ b/tests/flytekit/unit/core/test_checkpoint.py @@ -1,3 +1,4 @@ +import os from pathlib import Path import pytest @@ -36,11 +37,14 @@ def test_sync_checkpoint_save_file(tmpdir): def test_sync_checkpoint_save_filepath(tmpdir): - td_path = Path(tmpdir) - cp = SyncCheckpoint(checkpoint_dest=tmpdir) - dst_path = td_path.joinpath("test") + src_path = Path(os.path.join(tmpdir, "src")) + src_path.mkdir(parents=True, exist_ok=True) + chkpnt_path = Path(os.path.join(tmpdir, "dest")) + chkpnt_path.mkdir() + cp = SyncCheckpoint(checkpoint_dest=str(chkpnt_path)) + dst_path = chkpnt_path.joinpath("test") assert not dst_path.exists() - inp = td_path.joinpath("test") + inp = src_path.joinpath("test") with inp.open("wb") as f: f.write(b"blah") cp.save(inp) diff --git a/tests/flytekit/unit/core/test_composition.py b/tests/flytekit/unit/core/test_composition.py index 09140d8cb7..6fe2b01e61 100644 --- a/tests/flytekit/unit/core/test_composition.py +++ b/tests/flytekit/unit/core/test_composition.py @@ -35,14 +35,12 @@ def my_wf(a: int, b: str) -> (int, str, str): def test_single_named_output_subwf(): - nt = NamedTuple("SubWfOutput", sub_int=int) + nt = NamedTuple("SubWfOutput", [("sub_int", int)]) @task def t1(a: int) -> nt: a = a + 2 - return nt( - a, - ) # returns a named tuple + return nt(a) @task def t2(a: int, b: int) -> nt: diff --git a/tests/flytekit/unit/core/test_conditions.py b/tests/flytekit/unit/core/test_conditions.py index 24f051fbf7..7b0b292baa 100644 --- a/tests/flytekit/unit/core/test_conditions.py +++ b/tests/flytekit/unit/core/test_conditions.py @@ -305,7 +305,7 @@ def branching(x: int): def test_subworkflow_condition_named_tuple(): - nt = typing.NamedTuple("SampleNamedTuple", b=int, c=str) + nt = typing.NamedTuple("SampleNamedTuple", [("b", int), ("c", str)]) @task def t() -> nt: @@ -324,13 +324,11 @@ def branching(x: int) -> nt: def test_subworkflow_condition_single_named_tuple(): - nt = typing.NamedTuple("SampleNamedTuple", b=int) + nt = typing.NamedTuple("SampleNamedTuple", [("b", int)]) @task def t() -> nt: - return nt( - 5, - ) + return nt(5) @workflow def wf1() -> nt: diff --git a/tests/flytekit/unit/core/test_container_task.py b/tests/flytekit/unit/core/test_container_task.py new file mode 
100644 index 0000000000..599061d403 --- /dev/null +++ b/tests/flytekit/unit/core/test_container_task.py @@ -0,0 +1,80 @@ +from kubernetes.client.models import ( + V1Affinity, + V1NodeAffinity, + V1NodeSelectorRequirement, + V1NodeSelectorTerm, + V1PodSpec, + V1PreferredSchedulingTerm, + V1Toleration, +) + +from flytekit import kwtypes +from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.core.container_task import ContainerTask +from flytekit.core.pod_template import PodTemplate +from flytekit.tools.translator import get_serializable_task + + +def test_pod_template(): + ps = V1PodSpec( + containers=[], tolerations=[V1Toleration(effect="NoSchedule", key="nvidia.com/gpu", operator="Exists")] + ) + ps.runtime_class_name = "nvidia" + nsr = V1NodeSelectorRequirement(key="nvidia.com/gpu.memory", operator="Gt", values=["10000"]) + pref_sched = V1PreferredSchedulingTerm(preference=V1NodeSelectorTerm(match_expressions=[nsr]), weight=1) + ps.affinity = V1Affinity( + node_affinity=V1NodeAffinity(preferred_during_scheduling_ignored_during_execution=[pref_sched]) + ) + pt = PodTemplate(pod_spec=ps, labels={"somelabel": "foobar"}) + + image = "ghcr.io/flyteorg/rawcontainers-shell:v2" + cmd = [ + "./calculate-ellipse-area.sh", + "{{.inputs.a}}", + "{{.inputs.b}}", + "/var/outputs", + ] + ct = ContainerTask( + name="ellipse-area-metadata-shell", + input_data_dir="/var/inputs", + output_data_dir="/var/outputs", + inputs=kwtypes(a=float, b=float), + outputs=kwtypes(area=float, metadata=str), + image=image, + command=cmd, + pod_template=pt, + pod_template_name="my-base-template", + ) + + assert ct.metadata.pod_template_name == "my-base-template" + + default_image = Image(name="default", fqn="docker.io/xyz", tag="some-git-hash") + default_image_config = ImageConfig(default_image=default_image) + default_serialization_settings = SerializationSettings( + project="p", domain="d", version="v", image_config=default_image_config + ) + + container = ct.get_container(default_serialization_settings) + assert container is None + + k8s_pod = ct.get_k8s_pod(default_serialization_settings) + assert k8s_pod.metadata.labels == {"somelabel": "foobar"} + + primary_container = k8s_pod.pod_spec["containers"][0] + + assert primary_container["image"] == image + assert primary_container["command"] == cmd + + ################# + # Test Serialization + ################# + ts = get_serializable_task(default_serialization_settings, ct) + assert ts.template.metadata.pod_template_name == "my-base-template" + assert ts.template.container is None + assert ts.template.k8s_pod is not None + serialized_pod_spec = ts.template.k8s_pod.pod_spec + assert serialized_pod_spec["affinity"]["nodeAffinity"] is not None + assert serialized_pod_spec["tolerations"] == [ + {"effect": "NoSchedule", "key": "nvidia.com/gpu", "operator": "Exists"} + ] + assert serialized_pod_spec["runtimeClassName"] == "nvidia" diff --git a/tests/flytekit/unit/core/test_context_manager.py b/tests/flytekit/unit/core/test_context_manager.py index 6e68c9d4be..fe535b761c 100644 --- a/tests/flytekit/unit/core/test_context_manager.py +++ b/tests/flytekit/unit/core/test_context_manager.py @@ -115,18 +115,17 @@ def test_secrets_manager_default(): def test_secrets_manager_get_envvar(): sec = SecretsManager() - with pytest.raises(ValueError): - sec.get_secrets_env_var("test", "") with pytest.raises(ValueError): sec.get_secrets_env_var("", "x") cfg = SecretsConfig.auto() assert sec.get_secrets_env_var("group", "test") == 
f"{cfg.env_prefix}GROUP_TEST" + assert sec.get_secrets_env_var("group", "test", "v1") == f"{cfg.env_prefix}GROUP_V1_TEST" + assert sec.get_secrets_env_var("group", group_version="v1") == f"{cfg.env_prefix}GROUP_V1" + assert sec.get_secrets_env_var("group") == f"{cfg.env_prefix}GROUP" def test_secrets_manager_get_file(): sec = SecretsManager() - with pytest.raises(ValueError): - sec.get_secrets_file("test", "") with pytest.raises(ValueError): sec.get_secrets_file("", "x") cfg = SecretsConfig.auto() @@ -135,6 +134,12 @@ def test_secrets_manager_get_file(): "group", f"{cfg.file_prefix}test", ) + assert sec.get_secrets_file("group", "test", "v1") == os.path.join( + cfg.default_dir, + "group", + "v1", + f"{cfg.file_prefix}test", + ) def test_secrets_manager_file(tmpdir: py.path.local): @@ -145,8 +150,6 @@ def test_secrets_manager_file(tmpdir: py.path.local): with open(f, "w+") as w: w.write("my-password") - with pytest.raises(ValueError): - sec.get("test", "") with pytest.raises(ValueError): sec.get("", "x") # Group dir not exists @@ -207,7 +210,7 @@ def test_serialization_settings_transport(): ss = SerializationSettings.from_transport(tp) assert ss is not None assert ss == serialization_settings - assert len(tp) == 388 + assert len(tp) == 400 def test_exec_params(): diff --git a/tests/flytekit/unit/core/test_data.py b/tests/flytekit/unit/core/test_data.py new file mode 100644 index 0000000000..2d61b58d8c --- /dev/null +++ b/tests/flytekit/unit/core/test_data.py @@ -0,0 +1,330 @@ +import os +import random +import shutil +import tempfile +from uuid import UUID + +import fsspec +import mock +import pytest + +from flytekit.configuration import Config, S3Config +from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.data_persistence import FileAccessProvider, default_local_file_access_provider, s3_setup_args +from flytekit.types.directory.types import FlyteDirectory + +local = fsspec.filesystem("file") +root = os.path.abspath(os.sep) + + +@mock.patch("google.auth.compute_engine._metadata") # to prevent network calls +@mock.patch("flytekit.core.data_persistence.UUID") +def test_path_getting(mock_uuid_class, mock_gcs): + mock_uuid_class.return_value.hex = "abcdef123" + + # Testing with raw output prefix pointing to a local path + loc_sandbox = os.path.join(root, "tmp", "unittest") + loc_data = os.path.join(root, "tmp", "unittestdata") + local_raw_fp = FileAccessProvider(local_sandbox_dir=loc_sandbox, raw_output_prefix=loc_data) + assert local_raw_fp.get_random_remote_path() == os.path.join(root, "tmp", "unittestdata", "abcdef123") + assert local_raw_fp.get_random_remote_path("/fsa/blah.csv") == os.path.join( + root, "tmp", "unittestdata", "abcdef123", "blah.csv" + ) + assert local_raw_fp.get_random_remote_directory() == os.path.join(root, "tmp", "unittestdata", "abcdef123") + + # Test local path and directory + assert local_raw_fp.get_random_local_path() == os.path.join(root, "tmp", "unittest", "local_flytekit", "abcdef123") + assert local_raw_fp.get_random_local_path("xjiosa/blah.txt") == os.path.join( + root, "tmp", "unittest", "local_flytekit", "abcdef123", "blah.txt" + ) + assert local_raw_fp.get_random_local_directory() == os.path.join( + root, "tmp", "unittest", "local_flytekit", "abcdef123" + ) + + # Recursive paths + assert "file:///abc/happy/", "s3://my-s3-bucket/bucket1/" == local_raw_fp.recursive_paths( + "file:///abc/happy/", "s3://my-s3-bucket/bucket1/" + ) + assert "file:///abc/happy/", "s3://my-s3-bucket/bucket1/" == local_raw_fp.recursive_paths( + 
"file:///abc/happy", "s3://my-s3-bucket/bucket1" + ) + + # Test with remote pointed to s3. + s3_fa = FileAccessProvider(local_sandbox_dir=loc_sandbox, raw_output_prefix="s3://my-s3-bucket") + assert s3_fa.get_random_remote_path() == "s3://my-s3-bucket/abcdef123" + assert s3_fa.get_random_remote_directory() == "s3://my-s3-bucket/abcdef123" + # trailing slash should make no difference + s3_fa = FileAccessProvider(local_sandbox_dir=loc_sandbox, raw_output_prefix="s3://my-s3-bucket/") + assert s3_fa.get_random_remote_path() == "s3://my-s3-bucket/abcdef123" + assert s3_fa.get_random_remote_directory() == "s3://my-s3-bucket/abcdef123" + + # Testing with raw output prefix pointing to file:// + # Skip tests for windows + if os.name != "nt": + file_raw_fp = FileAccessProvider(local_sandbox_dir=loc_sandbox, raw_output_prefix="file:///tmp/unittestdata") + assert file_raw_fp.get_random_remote_path() == os.path.join(root, "tmp", "unittestdata", "abcdef123") + assert file_raw_fp.get_random_remote_path("/fsa/blah.csv") == os.path.join( + root, "tmp", "unittestdata", "abcdef123", "blah.csv" + ) + assert file_raw_fp.get_random_remote_directory() == os.path.join(root, "tmp", "unittestdata", "abcdef123") + + g_fa = FileAccessProvider(local_sandbox_dir=loc_sandbox, raw_output_prefix="gs://my-s3-bucket/") + assert g_fa.get_random_remote_path() == "gs://my-s3-bucket/abcdef123" + + +@mock.patch("flytekit.core.data_persistence.UUID") +def test_default_file_access_instance(mock_uuid_class): + mock_uuid_class.return_value.hex = "abcdef123" + + assert default_local_file_access_provider.get_random_local_path().endswith( + os.path.join("sandbox", "local_flytekit", "abcdef123") + ) + assert default_local_file_access_provider.get_random_local_path("bob.txt").endswith( + os.path.join("abcdef123", "bob.txt") + ) + + assert default_local_file_access_provider.get_random_local_directory().endswith( + os.path.join("sandbox", "local_flytekit", "abcdef123") + ) + + x = default_local_file_access_provider.get_random_remote_path() + assert x.endswith(os.path.join("raw", "abcdef123")) + x = default_local_file_access_provider.get_random_remote_path("eve.txt") + assert x.endswith(os.path.join("raw", "abcdef123", "eve.txt")) + x = default_local_file_access_provider.get_random_remote_directory() + assert x.endswith(os.path.join("raw", "abcdef123")) + + +@pytest.fixture +def source_folder(): + # Set up source directory for testing + parent_temp = tempfile.mkdtemp() + src_dir = os.path.join(parent_temp, "source", "") + nested_dir = os.path.join(src_dir, "nested") + local.mkdir(nested_dir) + local.touch(os.path.join(src_dir, "original.txt")) + with open(os.path.join(src_dir, "original.txt"), "w") as fh: + fh.write("hello original") + local.touch(os.path.join(nested_dir, "more.txt")) + yield src_dir + shutil.rmtree(parent_temp) + + +def test_local_raw_fsspec(source_folder): + # Test copying using raw fsspec local filesystem, should not create a nested folder + with tempfile.TemporaryDirectory() as dest_tmpdir: + local.put(source_folder, dest_tmpdir, recursive=True) + + new_temp_dir_2 = tempfile.mkdtemp() + new_temp_dir_2 = os.path.join(new_temp_dir_2, "doesnotexist") + local.put(source_folder, new_temp_dir_2, recursive=True) + files = local.find(new_temp_dir_2) + assert len(files) == 2 + + +def test_local_provider(source_folder): + # Test that behavior putting from a local dir to a local remote dir is the same whether or not the local + # dest folder exists. 
+ dc = Config.for_sandbox().data_config + with tempfile.TemporaryDirectory() as dest_tmpdir: + provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=dest_tmpdir, data_config=dc) + doesnotexist = provider.get_random_remote_directory() + provider.put_data(source_folder, doesnotexist, is_multipart=True) + files = provider._default_remote.find(doesnotexist) + assert len(files) == 2 + + exists = provider.get_random_remote_directory() + provider._default_remote.mkdir(exists) + provider.put_data(source_folder, exists, is_multipart=True) + files = provider._default_remote.find(exists) + assert len(files) == 2 + + +@pytest.mark.sandbox_test +def test_s3_provider(source_folder): + # Running mkdir on s3 filesystem doesn't do anything so leaving out for now + dc = Config.for_sandbox().data_config + provider = FileAccessProvider( + local_sandbox_dir="/tmp/unittest", raw_output_prefix="s3://my-s3-bucket/testdata/", data_config=dc + ) + doesnotexist = provider.get_random_remote_directory() + provider.put_data(source_folder, doesnotexist, is_multipart=True) + fs = provider.get_filesystem_for_path(doesnotexist) + files = fs.find(doesnotexist) + assert len(files) == 2 + + +def test_local_provider_get_empty(): + dc = Config.for_sandbox().data_config + with tempfile.TemporaryDirectory() as empty_source: + with tempfile.TemporaryDirectory() as dest_folder: + provider = FileAccessProvider( + local_sandbox_dir="/tmp/unittest", raw_output_prefix=empty_source, data_config=dc + ) + provider.get_data(empty_source, dest_folder, is_multipart=True) + loc = provider.get_filesystem_for_path(dest_folder) + src_files = loc.find(empty_source) + assert len(src_files) == 0 + dest_files = loc.find(dest_folder) + assert len(dest_files) == 0 + + +@mock.patch("flytekit.configuration.get_config_file") +@mock.patch("os.environ") +def test_s3_setup_args_env_empty(mock_os, mock_get_config_file): + mock_get_config_file.return_value = None + mock_os.get.return_value = None + s3c = S3Config.auto() + kwargs = s3_setup_args(s3c) + assert kwargs == {"cache_regions": True} + + +@mock.patch("flytekit.configuration.get_config_file") +@mock.patch("os.environ") +def test_s3_setup_args_env_both(mock_os, mock_get_config_file): + mock_get_config_file.return_value = None + ee = { + "AWS_ACCESS_KEY_ID": "ignore-user", + "AWS_SECRET_ACCESS_KEY": "ignore-secret", + "FLYTE_AWS_ACCESS_KEY_ID": "flyte", + "FLYTE_AWS_SECRET_ACCESS_KEY": "flyte-secret", + } + mock_os.get.side_effect = lambda x, y: ee.get(x) + kwargs = s3_setup_args(S3Config.auto()) + assert kwargs == {"key": "flyte", "secret": "flyte-secret", "cache_regions": True} + + +@mock.patch("flytekit.configuration.get_config_file") +@mock.patch("os.environ") +def test_s3_setup_args_env_flyte(mock_os, mock_get_config_file): + mock_get_config_file.return_value = None + ee = { + "FLYTE_AWS_ACCESS_KEY_ID": "flyte", + "FLYTE_AWS_SECRET_ACCESS_KEY": "flyte-secret", + } + mock_os.get.side_effect = lambda x, y: ee.get(x) + kwargs = s3_setup_args(S3Config.auto()) + assert kwargs == {"key": "flyte", "secret": "flyte-secret", "cache_regions": True} + + +@mock.patch("flytekit.configuration.get_config_file") +@mock.patch("os.environ") +def test_s3_setup_args_env_aws(mock_os, mock_get_config_file): + mock_get_config_file.return_value = None + ee = { + "AWS_ACCESS_KEY_ID": "ignore-user", + "AWS_SECRET_ACCESS_KEY": "ignore-secret", + } + mock_os.get.side_effect = lambda x, y: ee.get(x) + kwargs = s3_setup_args(S3Config.auto()) + # not explicitly in kwargs, since fsspec/boto3 will use 
these env vars by default + assert kwargs == {"cache_regions": True} + + +def test_crawl_local_nt(source_folder): + """ + running this to see what it prints + """ + if os.name != "nt": # don't + return + source_folder = os.path.join(source_folder, "") # ensure there's a trailing / or \ + fd = FlyteDirectory(path=source_folder) + res = fd.crawl() + split = [(x, y) for x, y in res] + print(f"NT split {split}") + + # Test crawling a directory without trailing / or \ + source_folder = source_folder[:-1] + fd = FlyteDirectory(path=source_folder) + res = fd.crawl() + files = [os.path.join(x, y) for x, y in res] + print(f"NT files joined {files}") + + +def test_crawl_local_non_nt(source_folder): + """ + crawl on the source folder fixture should return for example + ('/var/folders/jx/54tww2ls58n8qtlp9k31nbd80000gp/T/tmpp14arygf/source/', 'original.txt') + ('/var/folders/jx/54tww2ls58n8qtlp9k31nbd80000gp/T/tmpp14arygf/source/', 'nested/more.txt') + """ + if os.name == "nt": # don't + return + source_folder = os.path.join(source_folder, "") # ensure there's a trailing / or \ + fd = FlyteDirectory(path=source_folder) + res = fd.crawl() + split = [(x, y) for x, y in res] + files = [os.path.join(x, y) for x, y in split] + assert set(split) == {(source_folder, "original.txt"), (source_folder, os.path.join("nested", "more.txt"))} + expected = {os.path.join(source_folder, "original.txt"), os.path.join(source_folder, "nested", "more.txt")} + assert set(files) == expected + + # Test crawling a directory without trailing / or \ + source_folder = source_folder[:-1] + fd = FlyteDirectory(path=source_folder) + res = fd.crawl() + files = [os.path.join(x, y) for x, y in res] + assert set(files) == expected + + # Test crawling a single file + fd = FlyteDirectory(path=os.path.join(source_folder, "original.txt")) + res = fd.crawl() + files = [os.path.join(x, y) for x, y in res] + assert len(files) == 0 + + +@pytest.mark.sandbox_test +def test_crawl_s3(source_folder): + """ + ('s3://my-s3-bucket/testdata/5b31492c032893b515650f8c76008cf7', 'original.txt') + ('s3://my-s3-bucket/testdata/5b31492c032893b515650f8c76008cf7', 'nested/more.txt') + """ + # Running mkdir on s3 filesystem doesn't do anything so leaving out for now + dc = Config.for_sandbox().data_config + provider = FileAccessProvider( + local_sandbox_dir="/tmp/unittest", raw_output_prefix="s3://my-s3-bucket/testdata/", data_config=dc + ) + s3_random_target = provider.get_random_remote_directory() + provider.put_data(source_folder, s3_random_target, is_multipart=True) + ctx = FlyteContextManager.current_context() + expected = {f"{s3_random_target}/original.txt", f"{s3_random_target}/nested/more.txt"} + + with FlyteContextManager.with_context(ctx.with_file_access(provider)): + fd = FlyteDirectory(path=s3_random_target) + res = fd.crawl() + res = [(x, y) for x, y in res] + files = [os.path.join(x, y) for x, y in res] + assert set(files) == expected + assert set(res) == {(s3_random_target, "original.txt"), (s3_random_target, os.path.join("nested", "more.txt"))} + + fd_file = FlyteDirectory(path=f"{s3_random_target}/original.txt") + res = fd_file.crawl() + files = [r for r in res] + assert len(files) == 1 + + +@pytest.mark.sandbox_test +def test_walk_local_copy_to_s3(source_folder): + dc = Config.for_sandbox().data_config + explicit_empty_folder = UUID(int=random.getrandbits(128)).hex + raw_output_path = f"s3://my-s3-bucket/testdata/{explicit_empty_folder}" + provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output_path, 
data_config=dc) + + ctx = FlyteContextManager.current_context() + local_fd = FlyteDirectory(path=source_folder) + local_fd_crawl = local_fd.crawl() + local_fd_crawl = [x for x in local_fd_crawl] + with FlyteContextManager.with_context(ctx.with_file_access(provider)): + fd = FlyteDirectory.new_remote() + assert raw_output_path in fd.path + + # Write source folder files to new remote path + for root_path, suffix in local_fd_crawl: + new_file = fd.new_file(suffix) # noqa + with open(os.path.join(root_path, suffix), "rb") as r: # noqa + with new_file.open("w") as w: + print(f"Writing, t {type(w)} p {new_file.path} |{suffix}|") + w.write(str(r.read())) + + new_crawl = fd.crawl() + new_suffixes = [y for x, y in new_crawl] + assert len(new_suffixes) == 2 # should have written two files diff --git a/tests/flytekit/unit/core/test_data_persistence.py b/tests/flytekit/unit/core/test_data_persistence.py index af39e9e852..27b407c1ce 100644 --- a/tests/flytekit/unit/core/test_data_persistence.py +++ b/tests/flytekit/unit/core/test_data_persistence.py @@ -1,11 +1,11 @@ -from flytekit.core.data_persistence import DataPersistencePlugins, FileAccessProvider +from flytekit.core.data_persistence import FileAccessProvider def test_get_random_remote_path(): fp = FileAccessProvider("/tmp", "s3://my-bucket") path = fp.get_random_remote_path() assert path.startswith("s3://my-bucket") - assert fp.raw_output_prefix == "s3://my-bucket" + assert fp.raw_output_prefix == "s3://my-bucket/" def test_is_remote(): @@ -14,10 +14,3 @@ def test_is_remote(): assert fp.is_remote("/tmp/foo/bar") is False assert fp.is_remote("file://foo/bar") is False assert fp.is_remote("s3://my-bucket/foo/bar") is True - - -def test_lister(): - x = DataPersistencePlugins.supported_protocols() - main_protocols = {"file", "/", "gs", "http", "https", "s3"} - all_protocols = set([y.replace("://", "") for y in x]) - assert main_protocols.issubset(all_protocols) diff --git a/tests/flytekit/unit/core/test_flyte_directory.py b/tests/flytekit/unit/core/test_flyte_directory.py index 0cb4f524f9..bd20c39c53 100644 --- a/tests/flytekit/unit/core/test_flyte_directory.py +++ b/tests/flytekit/unit/core/test_flyte_directory.py @@ -49,7 +49,6 @@ def test_engine(): def test_transformer_to_literal_local(): - random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "raw")) ctx = context_manager.FlyteContext.current_context() @@ -86,6 +85,15 @@ def test_transformer_to_literal_local(): with pytest.raises(TypeError, match="No automatic conversion from "): TypeEngine.to_literal(ctx, 3, FlyteDirectory, lt) + +def test_transformer_to_literal_localss(): + random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() + fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "raw")) + ctx = context_manager.FlyteContext.current_context() + with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)) as ctx: + + tf = FlyteDirToMultipartBlobTransformer() + lt = tf.get_literal_type(FlyteDirectory) # Can't use if it's not a directory with pytest.raises(FlyteAssertion): p = "/tmp/flyte/xyz" diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py index e2123222e0..b7f0a1aeee 100644 --- a/tests/flytekit/unit/core/test_flyte_file.py +++ b/tests/flytekit/unit/core/test_flyte_file.py @@ -5,13 +5,14 @@ 
from unittest.mock import MagicMock import pytest +from typing_extensions import Annotated import flytekit.configuration -from flytekit.configuration import Image, ImageConfig -from flytekit.core import context_manager -from flytekit.core.context_manager import ExecutionState +from flytekit.configuration import Config, Image, ImageConfig +from flytekit.core.context_manager import ExecutionState, FlyteContextManager from flytekit.core.data_persistence import FileAccessProvider, flyte_tmp_dir from flytekit.core.dynamic_workflow_task import dynamic +from flytekit.core.hash import HashMethod from flytekit.core.launch_plan import LaunchPlan from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine @@ -81,11 +82,10 @@ def t1() -> FlyteFile: def my_wf() -> FlyteFile: return t1() - random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() - # print(f"Random: {random_dir}") + random_dir = FlyteContextManager.current_context().file_access.get_random_local_directory() fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "mock_remote")) - ctx = context_manager.FlyteContext.current_context() - with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)): + ctx = FlyteContextManager.current_context() + with FlyteContextManager.with_context(ctx.with_file_access(fs)): top_level_files = os.listdir(random_dir) assert len(top_level_files) == 1 # the flytekit_local folder @@ -108,10 +108,10 @@ def t1() -> FlyteFile: def my_wf() -> FlyteFile: return t1() - random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() + random_dir = FlyteContextManager.current_context().file_access.get_random_local_directory() fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "mock_remote")) - ctx = context_manager.FlyteContext.current_context() - with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)): + ctx = FlyteContextManager.current_context() + with FlyteContextManager.with_context(ctx.with_file_access(fs)): top_level_files = os.listdir(random_dir) assert len(top_level_files) == 1 # the flytekit_local folder @@ -137,12 +137,12 @@ def my_wf() -> FlyteFile: return t1() # This creates a random directory that we know is empty. - random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() + random_dir = FlyteContextManager.current_context().file_access.get_random_local_directory() # Creating a new FileAccessProvider will add two folderst to the random dir print(f"Random {random_dir}") fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "mock_remote")) - ctx = context_manager.FlyteContext.current_context() - with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)): + ctx = FlyteContextManager.current_context() + with FlyteContextManager.with_context(ctx.with_file_access(fs)): working_dir = os.listdir(random_dir) assert len(working_dir) == 1 # the local_flytekit folder @@ -189,11 +189,11 @@ def my_wf() -> FlyteFile: return t1() # This creates a random directory that we know is empty. 
- random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() + random_dir = FlyteContextManager.current_context().file_access.get_random_local_directory() # Creating a new FileAccessProvider will add two folderst to the random dir fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "mock_remote")) - ctx = context_manager.FlyteContext.current_context() - with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)): + ctx = FlyteContextManager.current_context() + with FlyteContextManager.with_context(ctx.with_file_access(fs)): working_dir = os.listdir(random_dir) assert len(working_dir) == 1 # the local_flytekit dir @@ -243,8 +243,8 @@ def dyn(in1: FlyteFile): fd = FlyteFile("s3://anything") - with context_manager.FlyteContextManager.with_context( - context_manager.FlyteContextManager.current_context().with_serialization_settings( + with FlyteContextManager.with_context( + FlyteContextManager.current_context().with_serialization_settings( flytekit.configuration.SerializationSettings( project="test_proj", domain="test_domain", @@ -254,8 +254,8 @@ def dyn(in1: FlyteFile): ) ) ): - ctx = context_manager.FlyteContextManager.current_context() - with context_manager.FlyteContextManager.with_context( + ctx = FlyteContextManager.current_context() + with FlyteContextManager.with_context( ctx.with_execution_state(ctx.new_execution_state().with_params(mode=ExecutionState.Mode.TASK_EXECUTION)) ) as ctx: lit = TypeEngine.to_literal( @@ -433,3 +433,59 @@ def wf(path: str) -> os.PathLike: return t2(ff=n1) assert flyte_tmp_dir in wf(path="s3://somewhere").path + + +def test_flyte_file_annotated_hashmethod(local_dummy_file): + def calc_hash(ff: FlyteFile) -> str: + return str(ff.path) + + @task + def t1(path: str) -> Annotated[FlyteFile, HashMethod(calc_hash)]: + return FlyteFile(path) + + @workflow + def wf(path: str) -> None: + t1(path=path) + + wf(path=local_dummy_file) + + +@pytest.mark.sandbox_test +def test_file_open_things(): + @task + def write_this_file_to_s3() -> FlyteFile: + ctx = FlyteContextManager.current_context() + dest = ctx.file_access.get_random_remote_path() + ctx.file_access.put(__file__, dest) + return FlyteFile(path=dest) + + @task + def copy_file(ff: FlyteFile) -> FlyteFile: + new_file = FlyteFile.new_remote_file(ff.remote_path) + with ff.open("r") as r: + with new_file.open("w") as w: + w.write(r.read()) + return new_file + + @task + def print_file(ff: FlyteFile): + with open(ff, "r") as fh: + print(len(fh.readlines())) + + dc = Config.for_sandbox().data_config + with tempfile.TemporaryDirectory() as new_sandbox: + provider = FileAccessProvider( + local_sandbox_dir=new_sandbox, raw_output_prefix="s3://my-s3-bucket/testdata/", data_config=dc + ) + ctx = FlyteContextManager.current_context() + local = ctx.file_access.get_filesystem("file") # get a local file system. 
+ with FlyteContextManager.with_context(ctx.with_file_access(provider)): + f = write_this_file_to_s3() + copy_file(ff=f) + files = local.find(new_sandbox) + # copy_file was done via streaming so no files should have been written + assert len(files) == 0 + print_file(ff=f) + # print_file uses traditional download semantics so now a file should have been created + files = local.find(new_sandbox) + assert len(files) == 1 diff --git a/tests/flytekit/unit/core/test_flyte_pickle.py b/tests/flytekit/unit/core/test_flyte_pickle.py index 7ceec809b1..c45e200f95 100644 --- a/tests/flytekit/unit/core/test_flyte_pickle.py +++ b/tests/flytekit/unit/core/test_flyte_pickle.py @@ -14,7 +14,7 @@ from flytekit.models.literals import BlobMetadata from flytekit.models.types import LiteralType from flytekit.tools.translator import get_serializable -from flytekit.types.pickle.pickle import FlytePickle, FlytePickleTransformer +from flytekit.types.pickle.pickle import BatchSize, FlytePickle, FlytePickleTransformer default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = flytekit.configuration.SerializationSettings( @@ -55,6 +55,11 @@ def test_get_literal_type(): ) +def test_batch_size(): + bs = BatchSize(5) + assert bs.val == 5 + + def test_nested(): class Foo(object): def __init__(self, number: int): diff --git a/tests/flytekit/unit/core/test_gate.py b/tests/flytekit/unit/core/test_gate.py index c92e1c9e19..bb245ad594 100644 --- a/tests/flytekit/unit/core/test_gate.py +++ b/tests/flytekit/unit/core/test_gate.py @@ -219,7 +219,7 @@ def wf_dyn(a: int) -> typing.Tuple[int, int]: def test_subwf(): - nt = typing.NamedTuple("Multi", named1=int, named2=int) + nt = typing.NamedTuple("Multi", [("named1", int), ("named2", int)]) @task def nt1(a: int) -> nt: diff --git a/tests/flytekit/unit/core/test_imperative.py b/tests/flytekit/unit/core/test_imperative.py index db4b32f6a9..ead5358316 100644 --- a/tests/flytekit/unit/core/test_imperative.py +++ b/tests/flytekit/unit/core/test_imperative.py @@ -67,15 +67,13 @@ def t2(): assert len(wf_spec.template.interface.outputs) == 1 # docs_equivalent_start - nt = typing.NamedTuple("wf_output", from_n0t1=str) + nt = typing.NamedTuple("wf_output", [("from_n0t1", str)]) @workflow def my_workflow(in1: str) -> nt: x = t1(a=in1) t2() - return nt( - x, - ) + return nt(x) # docs_equivalent_end diff --git a/tests/flytekit/unit/core/test_interface.py b/tests/flytekit/unit/core/test_interface.py index db05de0ddb..26b43f2ef5 100644 --- a/tests/flytekit/unit/core/test_interface.py +++ b/tests/flytekit/unit/core/test_interface.py @@ -102,7 +102,7 @@ def x(a: int, b: str) -> typing.NamedTuple("NT1", x_str=str, y_int=int): return ("hello world", 5) def y(a: int, b: str) -> nt1: - return nt1("hello world", 5) + return nt1("hello world", 5) # type: ignore result = transform_variable_map(extract_return_annotation(typing.get_type_hints(x).get("return", None))) assert result["x_str"].type.simple == 3 diff --git a/tests/flytekit/unit/core/test_launch_plan.py b/tests/flytekit/unit/core/test_launch_plan.py index ffaff8daad..3addd13e42 100644 --- a/tests/flytekit/unit/core/test_launch_plan.py +++ b/tests/flytekit/unit/core/test_launch_plan.py @@ -292,7 +292,7 @@ def wf(a: int, c: str) -> (int, str): def test_lp_all_parameters(): - nt = typing.NamedTuple("OutputsBC", t1_int_output=int, c=str) + nt = typing.NamedTuple("OutputsBC", [("t1_int_output", int), ("c", str)]) @task def t1(a: int) -> nt: diff --git a/tests/flytekit/unit/core/test_map_task.py 
b/tests/flytekit/unit/core/test_map_task.py index 95927873d0..d032aca2d1 100644 --- a/tests/flytekit/unit/core/test_map_task.py +++ b/tests/flytekit/unit/core/test_map_task.py @@ -1,3 +1,4 @@ +import functools import typing from collections import OrderedDict @@ -6,7 +7,7 @@ import flytekit.configuration from flytekit import LaunchPlan, map_task from flytekit.configuration import Image, ImageConfig -from flytekit.core.map_task import MapPythonTask +from flytekit.core.map_task import MapPythonTask, MapTaskResolver from flytekit.core.task import TaskMetadata, task from flytekit.core.workflow import workflow from flytekit.tools.translator import get_serializable @@ -36,6 +37,11 @@ def t2(a: int) -> str: return str(b) +@task(cache=True, cache_version="1") +def t3(a: int, b: str, c: float) -> str: + pass + + # This test is for documentation. def test_map_docs(): # test_map_task_start @@ -87,8 +93,12 @@ def test_serialization(serialization_settings): "--prev-checkpoint", "{{.prevCheckpointPrefix}}", "--resolver", - "flytekit.core.python_auto_container.default_task_resolver", + "MapTaskResolver", "--", + "vars", + "", + "resolver", + "flytekit.core.python_auto_container.default_task_resolver", "task-module", "tests.flytekit.unit.core.test_map_task", "task-name", @@ -177,15 +187,42 @@ def test_inputs_outputs_length(): def many_inputs(a: int, b: str, c: float) -> str: return f"{a} - {b} - {c}" - with pytest.raises(ValueError): - _ = map_task(many_inputs) + m = map_task(many_inputs) + assert m.python_interface.inputs == {"a": typing.List[int], "b": typing.List[str], "c": typing.List[float]} + assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_24c08b3a2f9c2e389ad9fc6a03482cf9" + r_m = MapPythonTask(many_inputs) + assert str(r_m.python_interface) == str(m.python_interface) + + p1 = functools.partial(many_inputs, c=1.0) + m = map_task(p1) + assert m.python_interface.inputs == {"a": typing.List[int], "b": typing.List[str], "c": float} + assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_697aa7389996041183cf6cfd102be4f7" + r_m = MapPythonTask(many_inputs, bound_inputs=set("c")) + assert str(r_m.python_interface) == str(m.python_interface) + + p2 = functools.partial(p1, b="hello") + m = map_task(p2) + assert m.python_interface.inputs == {"a": typing.List[int], "b": str, "c": float} + assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_cc18607da7494024a402a5fa4b3ea5c6" + r_m = MapPythonTask(many_inputs, bound_inputs={"c", "b"}) + assert str(r_m.python_interface) == str(m.python_interface) + + p3 = functools.partial(p2, a=1) + m = map_task(p3) + assert m.python_interface.inputs == {"a": int, "b": str, "c": float} + assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_52fe80b04781ea77ef6f025f4b49abef" + r_m = MapPythonTask(many_inputs, bound_inputs={"a", "c", "b"}) + assert str(r_m.python_interface) == str(m.python_interface) + + with pytest.raises(TypeError): + m(a=[1, 2, 3]) @task def many_outputs(a: int) -> (int, str): return a, f"{a}" with pytest.raises(ValueError): - _ = map_task(many_inputs) + _ = map_task(many_outputs) def test_map_task_metadata(): @@ -194,3 +231,34 @@ def test_map_task_metadata(): assert mapped_1.metadata is map_meta mapped_2 = map_task(t2) assert mapped_2.metadata is t2.metadata + + +def test_map_task_resolver(serialization_settings): + list_outputs = {"o0": typing.List[str]} + mt = map_task(t3) + assert mt.python_interface.inputs == {"a": typing.List[int], "b": typing.List[str], "c": 
typing.List[float]} + assert mt.python_interface.outputs == list_outputs + mtr = MapTaskResolver() + assert mtr.name() == "MapTaskResolver" + args = mtr.loader_args(serialization_settings, mt) + t = mtr.load_task(loader_args=args) + assert t.python_interface.inputs == mt.python_interface.inputs + assert t.python_interface.outputs == mt.python_interface.outputs + + mt = map_task(functools.partial(t3, b="hello", c=1.0)) + assert mt.python_interface.inputs == {"a": typing.List[int], "b": str, "c": float} + assert mt.python_interface.outputs == list_outputs + mtr = MapTaskResolver() + args = mtr.loader_args(serialization_settings, mt) + t = mtr.load_task(loader_args=args) + assert t.python_interface.inputs == mt.python_interface.inputs + assert t.python_interface.outputs == mt.python_interface.outputs + + mt = map_task(functools.partial(t3, b="hello")) + assert mt.python_interface.inputs == {"a": typing.List[int], "b": str, "c": typing.List[float]} + assert mt.python_interface.outputs == list_outputs + mtr = MapTaskResolver() + args = mtr.loader_args(serialization_settings, mt) + t = mtr.load_task(loader_args=args) + assert t.python_interface.inputs == mt.python_interface.inputs + assert t.python_interface.outputs == mt.python_interface.outputs diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index 48d3020e88..da708a8571 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -105,14 +105,12 @@ def test_more_normal_task(): @task def t1(a: int) -> nt: # This one returns a regular tuple - return nt( - f"{a + 2}", - ) + return nt(f"{a + 2}") # type: ignore @task def t1_nt(a: int) -> nt: # This one returns an instance of the named tuple. - return nt(f"{a + 2}") + return nt(f"{a + 2}") # type: ignore @task def t2(a: typing.List[str]) -> str: @@ -135,9 +133,7 @@ def test_reserved_keyword(): @task def t1(a: int) -> nt: # This one returns a regular tuple - return nt( - f"{a + 2}", - ) + return nt(f"{a + 2}") # type: ignore # Test that you can't name an output "outputs" with pytest.raises(FlyteAssertion): @@ -419,7 +415,7 @@ def t1(a: str) -> str: @workflow def my_wf(a: str) -> str: - return t1(a=a).with_overrides(name="foo") + return t1(a=a).with_overrides(name="foo", node_name="t_1") serialization_settings = flytekit.configuration.SerializationSettings( project="test_proj", @@ -431,6 +427,7 @@ def my_wf(a: str) -> str: wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf) assert len(wf_spec.template.nodes) == 1 assert wf_spec.template.nodes[0].metadata.name == "foo" + assert wf_spec.template.nodes[0].id == "t-1" def test_config_override(): diff --git a/tests/flytekit/unit/core/test_partials.py b/tests/flytekit/unit/core/test_partials.py new file mode 100644 index 0000000000..24e3908d1d --- /dev/null +++ b/tests/flytekit/unit/core/test_partials.py @@ -0,0 +1,219 @@ +import typing +from collections import OrderedDict +from functools import partial + +import pandas as pd +import pytest + +import flytekit.configuration +from flytekit.configuration import Image, ImageConfig +from flytekit.core.dynamic_workflow_task import dynamic +from flytekit.core.map_task import MapTaskResolver, map_task +from flytekit.core.task import TaskMetadata, task +from flytekit.core.workflow import workflow +from flytekit.tools.translator import gather_dependent_entities, get_serializable + +default_img = Image(name="default", fqn="test", tag="tag") +serialization_settings = 
flytekit.configuration.SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + image_config=ImageConfig(default_image=default_img, images=[default_img]), +) + + +df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) + + +def test_basics_1(): + @task + def t1(a: int, b: str, c: float) -> int: + return a + len(b) + int(c) + + outside_p = partial(t1, b="hello", c=3.14) + + @workflow + def my_wf_1(a: int) -> typing.Tuple[int, int]: + inner_partial = partial(t1, b="world", c=2.7) + out = outside_p(a=a) + inside = inner_partial(a=a) + return out, inside + + with pytest.raises(Exception): + get_serializable(OrderedDict(), serialization_settings, outside_p) + + # check the od todo + od = OrderedDict() + wf_1_spec = get_serializable(od, serialization_settings, my_wf_1) + tts, wspecs, lps = gather_dependent_entities(od) + tts = [t for t in tts.values()] + assert len(tts) == 1 + assert len(wf_1_spec.template.nodes) == 2 + assert wf_1_spec.template.nodes[0].task_node.reference_id.name == tts[0].id.name + assert wf_1_spec.template.nodes[1].task_node.reference_id.name == tts[0].id.name + assert wf_1_spec.template.nodes[0].inputs[0].binding.promise.var == "a" + assert wf_1_spec.template.nodes[0].inputs[1].binding.scalar is not None + assert wf_1_spec.template.nodes[0].inputs[2].binding.scalar is not None + + @task + def get_str() -> str: + return "got str" + + bind_c = partial(t1, c=2.7) + + @workflow + def my_wf_2(a: int) -> int: + s = get_str() + inner_partial = partial(bind_c, b=s) + inside = inner_partial(a=a) + return inside + + wf_2_spec = get_serializable(OrderedDict(), serialization_settings, my_wf_2) + assert len(wf_2_spec.template.nodes) == 2 + + +def test_map_task_types(): + @task(cache=True, cache_version="1") + def t3(a: int, b: str, c: float) -> str: + return str(a) + b + str(c) + + t3_bind_b1 = partial(t3, b="hello") + t3_bind_b2 = partial(t3, b="world") + t3_bind_c1 = partial(t3_bind_b1, c=3.14) + t3_bind_c2 = partial(t3_bind_b2, c=2.78) + + mt1 = map_task(t3_bind_c1, metadata=TaskMetadata(cache=True, cache_version="1")) + mt2 = map_task(t3_bind_c2, metadata=TaskMetadata(cache=True, cache_version="1")) + + @task + def print_lists(i: typing.List[str], j: typing.List[str]): + print(f"First: {i}") + print(f"Second: {j}") + + @workflow + def wf_out(a: typing.List[int]): + i = mt1(a=a) + j = mt2(a=[3, 4, 5]) + print_lists(i=i, j=j) + + wf_out(a=[1, 2]) + + @workflow + def wf_in(a: typing.List[int]): + mt_in1 = map_task(t3_bind_c1, metadata=TaskMetadata(cache=True, cache_version="1")) + mt_in2 = map_task(t3_bind_c2, metadata=TaskMetadata(cache=True, cache_version="1")) + i = mt_in1(a=a) + j = mt_in2(a=[3, 4, 5]) + print_lists(i=i, j=j) + + wf_in(a=[1, 2]) + + od = OrderedDict() + wf_spec = get_serializable(od, serialization_settings, wf_in) + tts, _, _ = gather_dependent_entities(od) + assert len(tts) == 2 # one map task + the print task + assert ( + wf_spec.template.nodes[0].task_node.reference_id.name == wf_spec.template.nodes[1].task_node.reference_id.name + ) + assert wf_spec.template.nodes[0].inputs[0].binding.promise is not None # comes from wf input + assert wf_spec.template.nodes[1].inputs[0].binding.collection is not None # bound to static list + assert wf_spec.template.nodes[1].inputs[1].binding.scalar is not None # these are bound + assert wf_spec.template.nodes[1].inputs[2].binding.scalar is not None + + +def test_lists_cannot_be_used_in_partials(): + @task + def t(a: int, b: typing.List[str]) -> str: + return str(a) + str(b) + + 
with pytest.raises(ValueError): + map_task(partial(t, b=["hello", "world"]))(a=[1, 2, 3]) + + @task + def t_multilist(a: int, b: typing.List[float], c: typing.List[int]) -> str: + return str(a) + str(b) + str(c) + + with pytest.raises(ValueError): + map_task(partial(t_multilist, b=[3.14, 12.34, 9876.5432], c=[42, 99]))(a=[1, 2, 3, 4]) + + @task + def t_list_of_lists(a: typing.List[typing.List[float]], b: int) -> str: + return str(a) + str(b) + + with pytest.raises(ValueError): + map_task(partial(t_list_of_lists, a=[[3.14]]))(b=[1, 2, 3, 4]) + + +def test_everything(): + @task + def get_static_list() -> typing.List[float]: + return [3.14, 2.718] + + @task + def get_list_of_pd(s: int) -> typing.List[pd.DataFrame]: + df1 = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) + df2 = pd.DataFrame({"Name": ["Rachel", "Eve", "Mary"], "Age": [22, 23, 24]}) + if s == 2: + return [df1, df2] + else: + return [df1, df2, df1] + + @task + def t3(a: int, b: str, c: typing.List[float], d: typing.List[float], a2: pd.DataFrame) -> str: + return str(a) + f"pdsize{len(a2)}" + b + str(c) + "&&" + str(d) + + t3_bind_b2 = partial(t3, b="world") + # TODO: partial lists are not supported yet. + # t3_bind_b1 = partial(t3, b="hello") + # t3_bind_c1 = partial(t3_bind_b1, c=[6.674, 1.618, 6.626], d=[1.0]) + # mt1 = map_task(t3_bind_c1) + + mt1 = map_task(t3_bind_b2) + + mr = MapTaskResolver() + aa = mr.loader_args(serialization_settings, mt1) + # Check bound vars + aa = aa[1].split(",") + aa.sort() + assert aa == ["b"] + + @task + def print_lists(i: typing.List[str], j: typing.List[str], k: typing.List[str]) -> str: + print(f"First: {i}") + print(f"Second: {j}") + print(f"Third: {k}") + return f"{i}-{j}-{k}" + + @dynamic + def dt1(a: typing.List[int], a2: typing.List[pd.DataFrame], sl: typing.List[float]) -> str: + i = mt1(a=a, a2=a2, c=[[1.1, 2.0, 3.0], [1.1, 2.0, 3.0]], d=[sl, sl]) + mt_in2 = map_task(t3_bind_b2) + dfs = get_list_of_pd(s=3) + j = mt_in2(a=[3, 4, 5], a2=dfs, c=[[1.0], [2.0], [3.0]], d=[sl, sl, sl]) + + # Test a2 bound to a fixed dataframe + t3_bind_a2 = partial(t3_bind_b2, a2=a2[0]) + + mt_in3 = map_task(t3_bind_a2) + + aa = mr.loader_args(serialization_settings, mt_in3) + # Check bound vars + aa = aa[1].split(",") + aa.sort() + assert aa == ["a2", "b"] + + k = mt_in3(a=[3, 4, 5], c=[[1.0], [2.0], [3.0]], d=[sl, sl, sl]) + return print_lists(i=i, j=j, k=k) + + @workflow + def wf_dt(a: typing.List[int]) -> str: + sl = get_static_list() + dfs = get_list_of_pd(s=2) + return dt1(a=a, a2=dfs, sl=sl) + + print(wf_dt(a=[1, 2])) + assert ( + wf_dt(a=[1, 2]) + == "['1pdsize2world[1.1, 2.0, 3.0]&&[3.14, 2.718]', '2pdsize3world[1.1, 2.0, 3.0]&&[3.14, 2.718]']-['3pdsize2world[1.0]&&[3.14, 2.718]', '4pdsize3world[2.0]&&[3.14, 2.718]', '5pdsize2world[3.0]&&[3.14, 2.718]']-['3pdsize2world[1.0]&&[3.14, 2.718]', '4pdsize2world[2.0]&&[3.14, 2.718]', '5pdsize2world[3.0]&&[3.14, 2.718]']" + ) diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index d8b043116e..9478cc33ba 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -3,6 +3,7 @@ import pytest from dataclasses_json import dataclass_json +from typing_extensions import Annotated from flytekit import LaunchPlan, task, workflow from flytekit.core import context_manager @@ -14,6 +15,8 @@ translate_inputs_to_literals, ) from flytekit.exceptions.user import FlyteAssertion +from flytekit.types.pickle import FlytePickle +from flytekit.types.pickle.pickle import BatchSize 
def test_create_and_link_node(): @@ -74,7 +77,7 @@ def wf(i: int, j: int): # without providing the _inputs_not_allowed or _ignorable_inputs, all inputs to lp become required, # which is incorrect - with pytest.raises(FlyteAssertion, match="Missing input `i` type `simple: INTEGER"): + with pytest.raises(FlyteAssertion, match="Missing input `i` type ``"): create_and_link_node_from_remote(ctx, lp) # Even if j is not provided it will default @@ -92,7 +95,7 @@ def wf(i: int, j: int): @pytest.mark.parametrize( "input", - [2.0, {"i": 1, "a": ["h", "e"]}, [1, 2, 3]], + [2.0, {"i": 1, "a": ["h", "e"]}, [1, 2, 3], ["foo"] * 5], ) def test_translate_inputs_to_literals(input): @dataclass_json @@ -102,7 +105,7 @@ class MyDataclass(object): a: typing.List[str] @task - def t1(a: typing.Union[float, typing.List[int], MyDataclass]): + def t1(a: typing.Union[float, typing.List[int], MyDataclass, Annotated[typing.List[FlytePickle], BatchSize(2)]]): print(a) ctx = context_manager.FlyteContext.current_context() @@ -111,7 +114,7 @@ def t1(a: typing.Union[float, typing.List[int], MyDataclass]): def test_translate_inputs_to_literals_with_wrong_types(): ctx = context_manager.FlyteContext.current_context() - with pytest.raises(TypeError, match="Not a map type union_type"): + with pytest.raises(TypeError, match="Not a map type FlyteSchema[CLASSES_COLUMNS]: diff --git a/tests/flytekit/unit/core/test_references.py b/tests/flytekit/unit/core/test_references.py index df6e093b55..7486422fd9 100644 --- a/tests/flytekit/unit/core/test_references.py +++ b/tests/flytekit/unit/core/test_references.py @@ -160,7 +160,7 @@ def inner_test(ref_mock): @task def t1(a: int) -> nt1: a = a + 2 - return nt1(a, "world-" + str(a)) + return nt1(a, "world-" + str(a)) # type: ignore @workflow def wf2(a: int): diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index 862c469460..d47d57969c 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -406,17 +406,17 @@ def wf() -> typing.NamedTuple("OP", a=str): def test_named_outputs_nested(): - nm = typing.NamedTuple("OP", greet=str) + nm = typing.NamedTuple("OP", [("greet", str)]) @task def say_hello() -> nm: return nm("hello world") - wf_outputs = typing.NamedTuple("OP2", greet1=str, greet2=str) + wf_outputs = typing.NamedTuple("OP2", [("greet1", str), ("greet2", str)]) @workflow def my_wf() -> wf_outputs: - # Note only Namedtuples can be created like this + # Note only Namedtuple can be created like this return wf_outputs(say_hello().greet, say_hello().greet) x, y = my_wf() @@ -425,19 +425,19 @@ def my_wf() -> wf_outputs: def test_named_outputs_nested_fail(): - nm = typing.NamedTuple("OP", greet=str) + nm = typing.NamedTuple("OP", [("greet", str)]) @task def say_hello() -> nm: return nm("hello world") - wf_outputs = typing.NamedTuple("OP2", greet1=str, greet2=str) + wf_outputs = typing.NamedTuple("OP2", [("greet1", str), ("greet2", str)]) with pytest.raises(AssertionError): # this should fail because say_hello returns a tuple, but we do not de-reference it @workflow def my_wf() -> wf_outputs: - # Note only Namedtuples can be created like this + # Note only Namedtuple can be created like this return wf_outputs(say_hello(), say_hello()) my_wf() diff --git a/tests/flytekit/unit/core/test_structured_dataset.py b/tests/flytekit/unit/core/test_structured_dataset.py index bfb41d0fef..eaba8b6343 100644 --- a/tests/flytekit/unit/core/test_structured_dataset.py +++ 
b/tests/flytekit/unit/core/test_structured_dataset.py @@ -1,9 +1,11 @@ +import os import tempfile import typing import pandas as pd import pyarrow as pa import pytest +from fsspec.utils import get_protocol from typing_extensions import Annotated import flytekit.configuration @@ -25,7 +27,6 @@ StructuredDatasetTransformerEngine, convert_schema_type_to_structured_dataset_type, extract_cols_and_format, - protocol_prefix, ) my_cols = kwtypes(w=typing.Dict[str, typing.Dict[str, int]], x=typing.List[typing.List[int]], y=int, z=str) @@ -44,8 +45,8 @@ def test_protocol(): - assert protocol_prefix("s3://my-s3-bucket/file") == "s3" - assert protocol_prefix("/file") == "file" + assert get_protocol("s3://my-s3-bucket/file") == "s3" + assert get_protocol("/file") == "file" def generate_pandas() -> pd.DataFrame: @@ -74,7 +75,6 @@ def t1(a: pd.DataFrame) -> pd.DataFrame: def test_setting_of_unset_formats(): - custom = Annotated[StructuredDataset, "parquet"] example = custom(dataframe=df, uri="/path") # It's okay that the annotation is not used here yet. @@ -89,7 +89,9 @@ def t2(path: str) -> StructuredDataset: def wf(path: str) -> StructuredDataset: return t2(path=path) - res = wf(path="/tmp/somewhere") + with tempfile.TemporaryDirectory() as tmp_dir: + fname = os.path.join(tmp_dir, "somewhere") + res = wf(path=fname) # Now that it's passed through an encoder however, it should be set. assert res.file_format == "parquet" @@ -281,7 +283,10 @@ def encode( # Check that registering with a / triggers the file protocol instead. StructuredDatasetTransformerEngine.register(TempEncoder("/")) - assert StructuredDatasetTransformerEngine.ENCODERS[MyDF].get("file") is not None + res = StructuredDatasetTransformerEngine.get_encoder(MyDF, "file", "/") + # Test that the one we got was registered under fsspec + assert res is StructuredDatasetTransformerEngine.ENCODERS[MyDF].get("fsspec")["/"] + assert res is not None def test_sd(): diff --git a/tests/flytekit/unit/core/test_structured_dataset_handlers.py b/tests/flytekit/unit/core/test_structured_dataset_handlers.py index c7aa5563f9..cef124ffd0 100644 --- a/tests/flytekit/unit/core/test_structured_dataset_handlers.py +++ b/tests/flytekit/unit/core/test_structured_dataset_handlers.py @@ -50,5 +50,5 @@ def test_arrow(): assert encoder.protocol is None assert decoder.protocol is None assert encoder.python_type is decoder.python_type - d = StructuredDatasetTransformerEngine.DECODERS[encoder.python_type]["s3"]["parquet"] + d = StructuredDatasetTransformerEngine.DECODERS[encoder.python_type]["fsspec"]["parquet"] assert d is not None diff --git a/tests/flytekit/unit/core/test_type_conversion_errors.py b/tests/flytekit/unit/core/test_type_conversion_errors.py new file mode 100644 index 0000000000..dda19dd126 --- /dev/null +++ b/tests/flytekit/unit/core/test_type_conversion_errors.py @@ -0,0 +1,129 @@ +"""Unit tests for type conversion errors.""" + +from string import ascii_lowercase +from typing import Tuple + +import pytest +from hypothesis import given +from hypothesis import strategies as st + +from flytekit import task, workflow + + +@task +def int_to_float(n: int) -> float: + return float(n) + + +@task +def task_incorrect_output(a: float) -> int: + return str(a) # type: ignore [return-value] + + +@task +def task_correct_output(a: float) -> str: + return str(a) + + +@workflow +def wf_with_task_error(a: int) -> str: + return task_incorrect_output(a=int_to_float(n=a)) + + +@workflow +def wf_with_output_error(a: int) -> int: + return task_correct_output(a=int_to_float(n=a)) + + 
+@workflow +def wf_with_multioutput_error0(a: int, b: int) -> Tuple[int, str]: + out_a = task_correct_output(a=int_to_float(n=a)) + out_b = task_correct_output(a=int_to_float(n=b)) + return out_a, out_b + + +@workflow +def wf_with_multioutput_error1(a: int, b: int) -> Tuple[str, int]: + out_a = task_correct_output(a=int_to_float(n=a)) + out_b = task_correct_output(a=int_to_float(n=b)) + return out_a, out_b + + +@given(st.booleans() | st.integers() | st.text(ascii_lowercase)) +def test_task_input_error(incorrect_input): + with pytest.raises( + TypeError, + match=( + r"Failed to convert inputs of task '{}':\n" + r" Failed argument 'a': Expected value of type \<class 'float'\> but got .+ of type .+" + ).format(task_correct_output.name), + ): + task_correct_output(a=incorrect_input) + + +@given(st.floats()) +def test_task_output_error(correct_input): + with pytest.raises( + TypeError, + match=( + r"Failed to convert outputs of task '{}' at position 0:\n" + r" Expected value of type \<class 'int'\> but got .+ of type .+" + ).format(task_incorrect_output.name), + ): + task_incorrect_output(a=correct_input) + + +@given(st.integers()) +def test_workflow_with_task_error(correct_input): + with pytest.raises( + TypeError, + match=( + r"Encountered error while executing workflow '{}':\n" + r" Error encountered while executing 'wf_with_task_error':\n" + r" Failed to convert outputs of task '.+' at position 0:\n" + r" Expected value of type \<class 'int'\> but got .+ of type .+" + ).format(wf_with_task_error.name), + ): + wf_with_task_error(a=correct_input) + + +@given(st.booleans() | st.floats() | st.text(ascii_lowercase)) +def test_workflow_with_input_error(incorrect_input): + with pytest.raises( + TypeError, + match=(r"Encountered error while executing workflow '{}':\n" r" Failed to convert input").format( + wf_with_output_error.name + ), + ): + wf_with_output_error(a=incorrect_input) + + +@given(st.integers()) +def test_workflow_with_output_error(correct_input): + with pytest.raises( + TypeError, + match=( + r"Encountered error while executing workflow '{}':\n" + r" Failed to convert output in position 0 of value .+, expected type \<class 'int'\>" + ).format(wf_with_output_error.name), + ): + wf_with_output_error(a=correct_input) + + +@pytest.mark.parametrize( + "workflow, position", + [ + (wf_with_multioutput_error0, 0), + (wf_with_multioutput_error1, 1), + ], +) +@given(st.integers()) +def test_workflow_with_multioutput_error(workflow, position, correct_input): + with pytest.raises( + TypeError, + match=( + r"Encountered error while executing workflow '{}':\n " + r"Failed to convert output in position {} of value .+, expected type \<class 'int'\>" + ).format(workflow.name, position), + ): + workflow(a=correct_input, b=correct_input) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index bd270fd360..8b1379a228 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -1,10 +1,12 @@ import datetime +import json import os import tempfile import typing from dataclasses import asdict, dataclass from datetime import timedelta from enum import Enum +from typing import Optional, Type import mock import pandas as pd @@ -18,7 +20,7 @@ from marshmallow_enum import LoadDumpOptions from marshmallow_jsonschema import JSONSchema from pandas._testing import assert_frame_equal -from typing_extensions import Annotated +from typing_extensions import Annotated, get_args, get_origin from flytekit import kwtypes from flytekit.core.annotation import FlyteAnnotation @@ -51,7 +53,7 @@ 
from flytekit.types.file import FileExt, JPEGImageFile from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer, noop from flytekit.types.pickle import FlytePickle -from flytekit.types.pickle.pickle import FlytePickleTransformer +from flytekit.types.pickle.pickle import BatchSize, FlytePickleTransformer from flytekit.types.schema import FlyteSchema from flytekit.types.schema.types_pandas import PandasDataFrameTransformer from flytekit.types.structured.structured_dataset import StructuredDataset @@ -170,6 +172,51 @@ class Foo(object): assert pv[0].b == Bar(v=[1, 2, 99], w=[3.1415, 2.7182]) +def test_annotated_type(): + class JsonTypeTransformer(TypeTransformer[T]): + LiteralType = LiteralType( + simple=SimpleType.STRING, annotation=TypeAnnotation(annotations=dict(protocol="json")) + ) + + def get_literal_type(self, t: Type[T]) -> LiteralType: + return self.LiteralType + + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> Optional[T]: + return json.loads(lv.scalar.primitive.string_value) + + def to_literal( + self, ctx: FlyteContext, python_val: T, python_type: typing.Type[T], expected: LiteralType + ) -> Literal: + return Literal(scalar=Scalar(primitive=Primitive(string_value=json.dumps(python_val)))) + + class JSONSerialized: + def __class_getitem__(cls, item: Type[T]): + return Annotated[item, JsonTypeTransformer(name=f"json[{item}]", t=item)] + + MyJsonDict = JSONSerialized[typing.Dict[str, int]] + _, test_transformer = get_args(MyJsonDict) + + assert TypeEngine.get_transformer(MyJsonDict) is test_transformer + assert TypeEngine.to_literal_type(MyJsonDict) == JsonTypeTransformer.LiteralType + + test_dict = {"foo": 1} + test_literal = Literal(scalar=Scalar(primitive=Primitive(string_value=json.dumps(test_dict)))) + + assert ( + TypeEngine.to_python_value( + FlyteContext.current_context(), + test_literal, + MyJsonDict, + ) + == test_dict + ) + + assert ( + TypeEngine.to_literal(FlyteContext.current_context(), test_dict, MyJsonDict, JsonTypeTransformer.LiteralType) + == test_literal + ) + + def test_list_of_dataclass_getting_python_value(): @dataclass_json @dataclass() @@ -1472,21 +1519,21 @@ def test_multiple_annotations(): TypeEngine.to_literal_type(t) -TestSchema = FlyteSchema[kwtypes(some_str=str)] +TestSchema = FlyteSchema[kwtypes(some_str=str)] # type: ignore @dataclass_json @dataclass class InnerResult: number: int - schema: TestSchema + schema: TestSchema # type: ignore @dataclass_json @dataclass class Result: result: InnerResult - schema: TestSchema + schema: TestSchema # type: ignore def test_schema_in_dataclass(): @@ -1574,3 +1621,67 @@ def test_file_ext_with_flyte_file_wrong_type(): with pytest.raises(ValueError) as e: FlyteFile[WRONG_TYPE] assert str(e.value) == "Underlying type of File Extension must be of type " + + +def test_is_batchable(): + assert ListTransformer.is_batchable(typing.List[int]) is False + assert ListTransformer.is_batchable(typing.List[str]) is False + assert ListTransformer.is_batchable(typing.List[typing.Dict]) is False + assert ListTransformer.is_batchable(typing.List[typing.Dict[str, FlytePickle]]) is False + assert ListTransformer.is_batchable(typing.List[typing.List[FlytePickle]]) is False + + assert ListTransformer.is_batchable(typing.List[FlytePickle]) is True + assert ListTransformer.is_batchable(Annotated[typing.List[FlytePickle], BatchSize(3)]) is True + assert ( + ListTransformer.is_batchable(Annotated[typing.List[FlytePickle], HashMethod(function=str), BatchSize(3)]) + is True + ) + + 
+@pytest.mark.parametrize( + "python_val, python_type, expected_list_length", + [ + # Case 1: List of FlytePickle objects with default batch size. + # (By default, the batch_size is set to the length of the whole list.) + # After converting to literal, the result will be [batched_FlytePickle(5 items)]. + # Therefore, the expected list length is [1]. + ([{"foo"}] * 5, typing.List[FlytePickle], [1]), + # Case 2: List of FlytePickle objects with batch size 2. + # After converting to literal, the result will be + # [batched_FlytePickle(2 items), batched_FlytePickle(2 items), batched_FlytePickle(1 item)]. + # Therefore, the expected list length is [3]. + (["foo"] * 5, Annotated[typing.List[FlytePickle], HashMethod(function=str), BatchSize(2)], [3]), + # Case 3: Nested list of FlytePickle objects with batch size 2. + # After converting to literal, the result will be + # [[batched_FlytePickle(3 items)], [batched_FlytePickle(3 items)]] + # Therefore, the expected list length is [2, 1] (the length of the outer list remains the same, the inner list is batched). + ([["foo", "foo", "foo"]] * 2, typing.List[Annotated[typing.List[FlytePickle], BatchSize(3)]], [2, 1]), + # Case 4: Empty list + ([[], typing.List[FlytePickle], []]), + ], +) +def test_batch_pickle_list(python_val, python_type, expected_list_length): + ctx = FlyteContext.current_context() + expected = TypeEngine.to_literal_type(python_type) + lv = TypeEngine.to_literal(ctx, python_val, python_type, expected) + + tmp_lv = lv + for length in expected_list_length: + # Check that after converting to literal, the length of the literal list is equal to: + # - the length of the original list divided by the batch size if not nested + # - the length of the original list if it contains a nested list + assert len(tmp_lv.collection.literals) == length + tmp_lv = tmp_lv.collection.literals[0] + + pv = TypeEngine.to_python_value(ctx, lv, python_type) + # Check that after converting literal to Python value, the result is equal to the original python values. + assert pv == python_val + if get_origin(python_type) is Annotated: + pv = TypeEngine.to_python_value(ctx, lv, get_args(python_type)[0]) + # Remove the annotation and check that after converting to Python value, the result is equal + # to the original input values. 
This is used to simulate the following case: + # @workflow + # def wf(): + # data = task0() # task0() -> Annotated[typing.List[FlytePickle], BatchSize(2)] + # task1(data=data) # task1(data: typing.List[FlytePickle]) + assert pv == python_val diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index 9da416c1e8..412ec23cc1 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -3,12 +3,12 @@ import functools import os import random +import re import tempfile import typing from collections import OrderedDict from dataclasses import dataclass from enum import Enum -from textwrap import dedent import pandas import pandas as pd @@ -184,13 +184,11 @@ def my_wf(a: int, b: str) -> (int, str): assert my_wf._output_bindings[0].var == "o0" assert my_wf._output_bindings[0].binding.promise.var == "t1_int_output" - nt = typing.NamedTuple("SingleNT", t1_int_output=float) + nt = typing.NamedTuple("SingleNT", [("t1_int_output", float)]) @task def t3(a: int) -> nt: - return nt( - a + 2, - ) + return nt(a + 2) assert t3.python_interface.output_tuple_name == "SingleNT" assert t3.interface.outputs["t1_int_output"] is not None @@ -492,11 +490,13 @@ def t1(path: str) -> DatasetStruct: def wf(path: str) -> DatasetStruct: return t1(path=path) - res = wf(path="/tmp/somewhere") - assert "parquet" == res.a.file_format - assert "parquet" == res.b.a.file_format - assert_frame_equal(df, res.a.open(pd.DataFrame).all()) - assert_frame_equal(df, res.b.a.open(pd.DataFrame).all()) + with tempfile.TemporaryDirectory() as tmp_dir: + fname = os.path.join(tmp_dir, "df_file") + res = wf(path=fname) + assert "parquet" == res.a.file_format + assert "parquet" == res.b.a.file_format + assert_frame_equal(df, res.a.open(pd.DataFrame).all()) + assert_frame_equal(df, res.b.a.open(pd.DataFrame).all()) def test_wf1_with_map(): @@ -890,7 +890,7 @@ def t2(a: str, b: str) -> str: return b + a @workflow - def my_subwf(a: int) -> (str, str): + def my_subwf(a: int) -> typing.Tuple[str, str]: x, y = t1(a=a) u, v = t1(a=x) return y, v @@ -1414,7 +1414,7 @@ def t2(a: str, b: str) -> str: return b + a @workflow - def my_wf(a: int, b: str) -> (str, typing.List[str]): + def my_wf(a: int, b: str) -> typing.Tuple[str, typing.List[str]]: @dynamic def my_subwf(a: int) -> typing.List[str]: s = [] @@ -1454,7 +1454,7 @@ def t1() -> str: return "Hello" @workflow - def wf() -> typing.NamedTuple("OP", a=str, b=str): + def wf() -> typing.NamedTuple("OP", [("a", str), ("b", str)]): # type: ignore return t1(), t1() assert wf() == ("Hello", "Hello") @@ -1627,16 +1627,28 @@ def foo2(a: int, b: str) -> typing.Tuple[int, str]: def foo3(a: typing.Dict) -> typing.Dict: return a - with pytest.raises(TypeError, match="Type of Val 'hello' is not an instance of "): + with pytest.raises( + TypeError, + match=( + "Failed to convert inputs of task 'tests.flytekit.unit.core.test_type_hints.foo':\n" + " Failed argument 'a': Expected value of type but got 'hello' of type " + ), + ): foo(a="hello", b=10) # type: ignore with pytest.raises( TypeError, - match="Failed to convert return value for var o0 for " "function tests.flytekit.unit.core.test_type_hints.foo2", + match=( + "Failed to convert outputs of task 'tests.flytekit.unit.core.test_type_hints.foo2' at position 0:\n" + " Expected value of type but got 'hello' of type " + ), ): foo2(a=10, b="hello") - with pytest.raises(TypeError, match="Not a collection type simple: STRUCT\n but got a list \\[{'hello': 2}\\]"): + with 
pytest.raises( + TypeError, + match="Not a collection type but got a list \\[{'hello': 2}\\]", + ): foo3(a=[{"hello": 2}]) # type: ignore @@ -1672,28 +1684,12 @@ def wf2(a: typing.Union[int, str]) -> typing.Union[int, str]: with pytest.raises( TypeError, - match=dedent( - r""" - Cannot convert from scalar { - union { - value { - scalar { - primitive { - string_value: "2" - } - } - } - type { - simple: STRING - structure { - tag: "str" - } - } - } - } - to typing.Union\[float, dict\] \(using tag str\) - """ - )[1:-1], + match=re.escape( + "Error encountered while executing 'wf2':\n" + " Failed to convert inputs of task 'tests.flytekit.unit.core.test_type_hints.t2':\n" + ' Cannot convert from to typing.Union[float, dict] (using tag str)' + ), ): assert wf2(a="2") == "2" diff --git a/tests/flytekit/unit/core/test_typing_annotation.py b/tests/flytekit/unit/core/test_typing_annotation.py index 9c2d09c145..2937d9f978 100644 --- a/tests/flytekit/unit/core/test_typing_annotation.py +++ b/tests/flytekit/unit/core/test_typing_annotation.py @@ -18,7 +18,7 @@ env=None, image_config=ImageConfig(default_image=default_img, images=[default_img]), ) -entity_mapping = OrderedDict() +entity_mapping: OrderedDict = OrderedDict() @task diff --git a/tests/flytekit/unit/core/test_utils.py b/tests/flytekit/unit/core/test_utils.py index 112a864b30..5c191b31ee 100644 --- a/tests/flytekit/unit/core/test_utils.py +++ b/tests/flytekit/unit/core/test_utils.py @@ -1,6 +1,8 @@ import pytest -from flytekit.core.utils import _dnsify +import flytekit +from flytekit import FlyteContextManager, task +from flytekit.core.utils import _dnsify, timeit @pytest.mark.parametrize( @@ -20,3 +22,38 @@ ) def test_dnsify(input, expected): assert _dnsify(input) == expected + + +def test_timeit(): + ctx = FlyteContextManager.current_context() + ctx.user_space_params._decks = [] + + with timeit("Set disable_deck to False"): + kwargs = {} + kwargs["disable_deck"] = False + + ctx = FlyteContextManager.current_context() + time_info_list = ctx.user_space_params.timeline_deck.time_info + names = [time_info["Name"] for time_info in time_info_list] + # check if timeit works for flytekit level code + assert "Set disable_deck to False" in names + + @task(**kwargs) + def t1() -> int: + @timeit("Download data") + def download_data(): + return "1" + + data = download_data() + + with timeit("Convert string to int"): + return int(data) + + t1() + + time_info_list = flytekit.current_context().timeline_deck.time_info + names = [time_info["Name"] for time_info in time_info_list] + + # check if timeit works for user level code + assert "Download data" in names + assert "Convert string to int" in names diff --git a/tests/flytekit/unit/core/test_workflows.py b/tests/flytekit/unit/core/test_workflows.py index 4f1082df63..7bcbcb8ea3 100644 --- a/tests/flytekit/unit/core/test_workflows.py +++ b/tests/flytekit/unit/core/test_workflows.py @@ -45,12 +45,12 @@ def test_default_metadata_values(): def test_workflow_values(): @task - def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str): + def t1(a: int) -> typing.NamedTuple("OutputsBC", [("t1_int_output", int), ("c", str)]): a = a + 2 return a, "world-" + str(a) @workflow(interruptible=True, failure_policy=WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE) - def wf(a: int) -> (str, str): + def wf(a: int) -> typing.Tuple[str, str]: x, y = t1(a=a) u, v = t1(a=x) return y, v @@ -95,7 +95,7 @@ def list_output_wf() -> typing.List[int]: def test_sub_wf_single_named_tuple(): - nt = 
typing.NamedTuple("SingleNamedOutput", named1=int) + nt = typing.NamedTuple("SingleNamedOutput", [("named1", int)]) @task def t1(a: int) -> nt: @@ -116,7 +116,7 @@ def wf(b: int) -> nt: def test_sub_wf_multi_named_tuple(): - nt = typing.NamedTuple("Multi", named1=int, named2=int) + nt = typing.NamedTuple("Multi", [("named1", int), ("named2", int)]) @task def t1(a: int) -> nt: @@ -136,6 +136,89 @@ def wf(b: int) -> nt: assert x == (7, 7) +def test_sub_wf_varying_types(): + @task + def t1l( + a: typing.List[typing.Dict[str, typing.List[int]]], + b: typing.Dict[str, typing.List[int]], + c: typing.Union[typing.List[typing.Dict[str, typing.List[int]]], typing.Dict[str, typing.List[int]], int], + d: int, + ) -> str: + xx = ",".join([f"{k}:{v}" for d in a for k, v in d.items()]) + yy = ",".join([f"{k}: {i}" for k, v in b.items() for i in v]) + if isinstance(c, list): + zz = ",".join([f"{k}:{v}" for d in c for k, v in d.items()]) + elif isinstance(c, dict): + zz = ",".join([f"{k}: {i}" for k, v in c.items() for i in v]) + else: + zz = str(c) + return f"First: {xx} Second: {yy} Third: {zz} Int: {d}" + + @task + def get_int() -> int: + return 1 + + @workflow + def subwf( + a: typing.List[typing.Dict[str, typing.List[int]]], + b: typing.Dict[str, typing.List[int]], + c: typing.Union[typing.List[typing.Dict[str, typing.List[int]]], typing.Dict[str, typing.List[int]]], + d: int, + ) -> str: + return t1l(a=a, b=b, c=c, d=d) + + @workflow + def wf() -> str: + ds = [ + {"first_map_a": [42], "first_map_b": [get_int(), 2]}, + { + "second_map_c": [33], + "second_map_d": [9, 99], + }, + ] + ll = { + "ll_1": [get_int(), get_int(), get_int()], + "ll_2": [4, 5, 6], + } + out = subwf(a=ds, b=ll, c=ds, d=get_int()) + return out + + wf.compile() + x = wf() + expected = ( + "First: first_map_a:[42],first_map_b:[1, 2],second_map_c:[33],second_map_d:[9, 99] " + "Second: ll_1: 1,ll_1: 1,ll_1: 1,ll_2: 4,ll_2: 5,ll_2: 6 " + "Third: first_map_a:[42],first_map_b:[1, 2],second_map_c:[33],second_map_d:[9, 99] " + "Int: 1" + ) + assert x == expected + + @workflow + def wf() -> str: + ds = [ + {"first_map_a": [42], "first_map_b": [get_int(), 2]}, + { + "second_map_c": [33], + "second_map_d": [9, 99], + }, + ] + ll = { + "ll_1": [get_int(), get_int(), get_int()], + "ll_2": [4, 5, 6], + } + out = subwf(a=ds, b=ll, c=ll, d=get_int()) + return out + + x = wf() + expected = ( + "First: first_map_a:[42],first_map_b:[1, 2],second_map_c:[33],second_map_d:[9, 99] " + "Second: ll_1: 1,ll_1: 1,ll_1: 1,ll_2: 4,ll_2: 5,ll_2: 6 " + "Third: ll_1: 1,ll_1: 1,ll_1: 1,ll_2: 4,ll_2: 5,ll_2: 6 " + "Int: 1" + ) + assert x == expected + + def test_unexpected_outputs(): @task def t1(a: int) -> int: @@ -154,7 +237,7 @@ def no_outputs_wf(): with pytest.raises(AssertionError): @workflow - def one_output_wf() -> int: # noqa + def one_output_wf() -> int: # type: ignore t1(a=3) one_output_wf() @@ -312,10 +395,10 @@ def sd_to_schema_wf() -> pd.DataFrame: @workflow -def schema_to_sd_wf() -> (pd.DataFrame, pd.DataFrame): +def schema_to_sd_wf() -> typing.Tuple[pd.DataFrame, pd.DataFrame]: # schema -> StructuredDataset df = t4() - return t2(df=df), t5(sd=df) + return t2(df=df), t5(sd=df) # type: ignore def test_structured_dataset_wf(): diff --git a/tests/flytekit/unit/core/tracker/d.py b/tests/flytekit/unit/core/tracker/d.py index 9385b0f08d..c84e36fe59 100644 --- a/tests/flytekit/unit/core/tracker/d.py +++ b/tests/flytekit/unit/core/tracker/d.py @@ -9,3 +9,7 @@ def tasks(): @task def foo(): pass + + +def inner_function(a: str) -> str: + return "hello" diff 
--git a/tests/flytekit/unit/core/tracker/test_arrow_data.py b/tests/flytekit/unit/core/tracker/test_arrow_data.py new file mode 100644 index 0000000000..747e7f1651 --- /dev/null +++ b/tests/flytekit/unit/core/tracker/test_arrow_data.py @@ -0,0 +1,29 @@ +import typing + +import pandas as pd +import pyarrow as pa +from typing_extensions import Annotated + +from flytekit import kwtypes, task + +cols = kwtypes(Name=str, Age=int) +subset_cols = kwtypes(Name=str) + + +@task +def t1( + df1: Annotated[pd.DataFrame, cols], df2: Annotated[pa.Table, cols] +) -> typing.Tuple[Annotated[pd.DataFrame, subset_cols], Annotated[pa.Table, subset_cols]]: + return df1, df2 + + +def test_structured_dataset_wf(): + pd_df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) + pa_df = pa.Table.from_pandas(pd_df) + + subset_pd_df = pd.DataFrame({"Name": ["Tom", "Joseph"]}) + subset_pa_df = pa.Table.from_pandas(subset_pd_df) + + df1, df2 = t1(df1=pd_df, df2=pa_df) + assert df1.equals(subset_pd_df) + assert df2.equals(subset_pa_df) diff --git a/tests/flytekit/unit/core/tracker/test_tracking.py b/tests/flytekit/unit/core/tracker/test_tracking.py index 33ae18acd5..b33725436d 100644 --- a/tests/flytekit/unit/core/tracker/test_tracking.py +++ b/tests/flytekit/unit/core/tracker/test_tracking.py @@ -79,3 +79,10 @@ def test_extract_task_module(test_input, expected): except Exception: FeatureFlags.FLYTE_PYTHON_PACKAGE_ROOT = old raise + + +local_task = task(d.inner_function) + + +def test_local_task_wrap(): + assert local_task.instantiated_in == "tests.flytekit.unit.core.tracker.test_tracking" diff --git a/tests/flytekit/unit/deck/test_deck.py b/tests/flytekit/unit/deck/test_deck.py index a6b00e79e2..f65c94b877 100644 --- a/tests/flytekit/unit/deck/test_deck.py +++ b/tests/flytekit/unit/deck/test_deck.py @@ -1,3 +1,5 @@ +import datetime + import pandas as pd import pytest from mock import mock @@ -23,12 +25,30 @@ def test_deck(): _output_deck("test_task", ctx.user_space_params) +def test_timeline_deck(): + time_info = dict( + Name="foo", + Start=datetime.datetime.utcnow(), + Finish=datetime.datetime.utcnow() + datetime.timedelta(microseconds=1000), + WallTime=1.0, + ProcessTime=1.0, + ) + ctx = FlyteContextManager.current_context() + ctx.user_space_params._decks = [] + timeline_deck = ctx.user_space_params.timeline_deck + timeline_deck.append_time_info(time_info) + assert timeline_deck.name == "Timeline" + assert len(timeline_deck.time_info) == 1 + assert timeline_deck.time_info[0] == time_info + assert len(ctx.user_space_params.decks) == 1 + + @pytest.mark.parametrize( "disable_deck,expected_decks", [ - (None, 0), - (False, 2), # input and output decks - (True, 0), + (None, 1), # time line deck + (False, 3), # time line deck + input and output decks + (True, 1), # time line deck ], ) def test_deck_for_task(disable_deck, expected_decks): @@ -49,9 +69,9 @@ def t1(a: int) -> str: @pytest.mark.parametrize( "disable_deck, expected_decks", [ - (None, 1), - (False, 1 + 2), # input and output decks - (True, 1), + (None, 2), # default deck and time line deck + (False, 4), # default deck and time line deck + input and output decks + (True, 2), # default deck and time line deck ], ) def test_deck_pandas_dataframe(disable_deck, expected_decks): diff --git a/tests/flytekit/unit/extend/test_backend_plugin.py b/tests/flytekit/unit/extend/test_backend_plugin.py new file mode 100644 index 0000000000..9dfd20d99e --- /dev/null +++ b/tests/flytekit/unit/extend/test_backend_plugin.py @@ -0,0 +1,105 @@ +import typing +from datetime 
import timedelta +from unittest.mock import MagicMock + +import grpc +from flyteidl.service.external_plugin_service_pb2 import ( + PERMANENT_FAILURE, + SUCCEEDED, + TaskCreateRequest, + TaskCreateResponse, + TaskDeleteRequest, + TaskDeleteResponse, + TaskGetRequest, + TaskGetResponse, +) + +import flytekit.models.interface as interface_models +from flytekit.extend.backend.base_plugin import BackendPluginBase, BackendPluginRegistry +from flytekit.extend.backend.external_plugin_service import BackendPluginServer +from flytekit.models import literals, task, types +from flytekit.models.core.identifier import Identifier, ResourceType +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate + +dummy_id = "dummy_id" + + +class DummyPlugin(BackendPluginBase): + def __init__(self): + super().__init__(task_type="dummy") + + def create( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: typing.Optional[LiteralMap] = None, + ) -> TaskCreateResponse: + return TaskCreateResponse(job_id=dummy_id) + + def get(self, context: grpc.ServicerContext, job_id: str) -> TaskGetResponse: + return TaskGetResponse(state=SUCCEEDED) + + def delete(self, context: grpc.ServicerContext, job_id) -> TaskDeleteResponse: + return TaskDeleteResponse() + + +BackendPluginRegistry.register(DummyPlugin()) + +task_id = Identifier(resource_type=ResourceType.TASK, project="project", domain="domain", name="t1", version="version") +task_metadata = task.TaskMetadata( + True, + task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + "A", +) + +int_type = types.LiteralType(types.SimpleType.INTEGER) +interfaces = interface_models.TypedInterface( + { + "a": interface_models.Variable(int_type, "description1"), + }, + {}, +) +task_inputs = literals.LiteralMap( + { + "a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))), + }, +) + +dummy_template = TaskTemplate( + id=task_id, + metadata=task_metadata, + interface=interfaces, + type="dummy", + custom={}, +) + + +def test_dummy_plugin(): + ctx = MagicMock(spec=grpc.ServicerContext) + p = BackendPluginRegistry.get_plugin(ctx, "dummy") + assert p.create(ctx, "/tmp", dummy_template, task_inputs).job_id == dummy_id + assert p.get(ctx, dummy_id).state == SUCCEEDED + assert p.delete(ctx, dummy_id) == TaskDeleteResponse() + + +def test_backend_plugin_server(): + server = BackendPluginServer() + ctx = MagicMock(spec=grpc.ServicerContext) + request = TaskCreateRequest( + inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=dummy_template.to_flyte_idl() + ) + + assert server.CreateTask(request, ctx).job_id == dummy_id + assert server.GetTask(TaskGetRequest(task_type="dummy", job_id=dummy_id), ctx).state == SUCCEEDED + assert server.DeleteTask(TaskDeleteRequest(task_type="dummy", job_id=dummy_id), ctx) == TaskDeleteResponse() + + res = server.GetTask(TaskGetRequest(task_type="fake", job_id=dummy_id), ctx) + assert res.state == PERMANENT_FAILURE diff --git a/tests/flytekit/unit/extras/persistence/test_gcs_gsutil.py b/tests/flytekit/unit/extras/persistence/test_gcs_gsutil.py deleted file mode 100644 index d2c50cc4a9..0000000000 --- a/tests/flytekit/unit/extras/persistence/test_gcs_gsutil.py +++ /dev/null @@ -1,35 +0,0 @@ -import mock - -from flytekit import GCSPersistence - - 
-@mock.patch("flytekit.extras.persistence.gcs_gsutil._update_cmd_config_and_execute") -@mock.patch("flytekit.extras.persistence.gcs_gsutil.GCSPersistence._check_binary") -def test_put(mock_check, mock_exec): - proxy = GCSPersistence() - proxy.put("/test", "gs://my-bucket/k1") - mock_exec.assert_called_with(["gsutil", "cp", "/test", "gs://my-bucket/k1"]) - - -@mock.patch("flytekit.extras.persistence.gcs_gsutil._update_cmd_config_and_execute") -@mock.patch("flytekit.extras.persistence.gcs_gsutil.GCSPersistence._check_binary") -def test_put_recursive(mock_check, mock_exec): - proxy = GCSPersistence() - proxy.put("/test", "gs://my-bucket/k1", True) - mock_exec.assert_called_with(["gsutil", "cp", "-r", "/test/*", "gs://my-bucket/k1/"]) - - -@mock.patch("flytekit.extras.persistence.gcs_gsutil._update_cmd_config_and_execute") -@mock.patch("flytekit.extras.persistence.gcs_gsutil.GCSPersistence._check_binary") -def test_get(mock_check, mock_exec): - proxy = GCSPersistence() - proxy.get("gs://my-bucket/k1", "/test") - mock_exec.assert_called_with(["gsutil", "cp", "gs://my-bucket/k1", "/test"]) - - -@mock.patch("flytekit.extras.persistence.gcs_gsutil._update_cmd_config_and_execute") -@mock.patch("flytekit.extras.persistence.gcs_gsutil.GCSPersistence._check_binary") -def test_get_recursive(mock_check, mock_exec): - proxy = GCSPersistence() - proxy.get("gs://my-bucket/k1", "/test", True) - mock_exec.assert_called_with(["gsutil", "cp", "-r", "gs://my-bucket/k1/*", "/test"]) diff --git a/tests/flytekit/unit/extras/persistence/test_http.py b/tests/flytekit/unit/extras/persistence/test_http.py deleted file mode 100644 index 893b43f364..0000000000 --- a/tests/flytekit/unit/extras/persistence/test_http.py +++ /dev/null @@ -1,20 +0,0 @@ -import pytest - -from flytekit import HttpPersistence - - -def test_put(): - proxy = HttpPersistence() - with pytest.raises(AssertionError): - proxy.put("", "", recursive=True) - - -def test_construct_path(): - proxy = HttpPersistence() - with pytest.raises(AssertionError): - proxy.construct_path(True, False, "", "") - - -def test_exists(): - proxy = HttpPersistence() - assert proxy.exists("https://flyte.org") diff --git a/tests/flytekit/unit/extras/persistence/test_s3_awscli.py b/tests/flytekit/unit/extras/persistence/test_s3_awscli.py deleted file mode 100644 index a6f29f36d6..0000000000 --- a/tests/flytekit/unit/extras/persistence/test_s3_awscli.py +++ /dev/null @@ -1,80 +0,0 @@ -from datetime import timedelta - -import mock - -from flytekit import S3Persistence -from flytekit.configuration import DataConfig, S3Config -from flytekit.extras.persistence import s3_awscli - - -def test_property(): - aws = S3Persistence("s3://raw-output") - assert aws.default_prefix == "s3://raw-output" - - -def test_construct_path(): - aws = S3Persistence() - p = aws.construct_path(True, False, "xyz") - assert p == "s3://xyz" - - -@mock.patch("flytekit.extras.persistence.s3_awscli.S3Persistence._check_binary") -@mock.patch("flytekit.extras.persistence.s3_awscli.subprocess") -def test_retries(mock_subprocess, mock_check): - mock_subprocess.check_call.side_effect = Exception("test exception (404)") - mock_check.return_value = True - - proxy = S3Persistence(data_config=DataConfig(s3=S3Config(backoff=timedelta(seconds=0)))) - assert proxy.exists("s3://test/fdsa/fdsa") is False - assert mock_subprocess.check_call.call_count == 8 - - -def test_extra_args(): - assert s3_awscli._extra_args({}) == [] - assert s3_awscli._extra_args({"ContentType": "ct"}) == ["--content-type", "ct"] - assert 
s3_awscli._extra_args({"ContentEncoding": "ec"}) == ["--content-encoding", "ec"] - assert s3_awscli._extra_args({"ACL": "acl"}) == ["--acl", "acl"] - assert s3_awscli._extra_args({"ContentType": "ct", "ContentEncoding": "ec", "ACL": "acl"}) == [ - "--content-type", - "ct", - "--content-encoding", - "ec", - "--acl", - "acl", - ] - - -@mock.patch("flytekit.extras.persistence.s3_awscli._update_cmd_config_and_execute") -def test_put(mock_exec): - proxy = S3Persistence() - proxy.put("/test", "s3://my-bucket/k1") - mock_exec.assert_called_with( - cmd=["aws", "s3", "cp", "--acl", "bucket-owner-full-control", "/test", "s3://my-bucket/k1"], - s3_cfg=S3Config.auto(), - ) - - -@mock.patch("flytekit.extras.persistence.s3_awscli._update_cmd_config_and_execute") -def test_put_recursive(mock_exec): - proxy = S3Persistence() - proxy.put("/test", "s3://my-bucket/k1", True) - mock_exec.assert_called_with( - cmd=["aws", "s3", "cp", "--recursive", "--acl", "bucket-owner-full-control", "/test", "s3://my-bucket/k1"], - s3_cfg=S3Config.auto(), - ) - - -@mock.patch("flytekit.extras.persistence.s3_awscli._update_cmd_config_and_execute") -def test_get(mock_exec): - proxy = S3Persistence() - proxy.get("s3://my-bucket/k1", "/test") - mock_exec.assert_called_with(cmd=["aws", "s3", "cp", "s3://my-bucket/k1", "/test"], s3_cfg=S3Config.auto()) - - -@mock.patch("flytekit.extras.persistence.s3_awscli._update_cmd_config_and_execute") -def test_get_recursive(mock_exec): - proxy = S3Persistence() - proxy.get("s3://my-bucket/k1", "/test", True) - mock_exec.assert_called_with( - cmd=["aws", "s3", "cp", "--recursive", "s3://my-bucket/k1", "/test"], s3_cfg=S3Config.auto() - ) diff --git a/tests/flytekit/unit/extras/sqlite3/chinook.zip b/tests/flytekit/unit/extras/sqlite3/chinook.zip new file mode 100644 index 0000000000..6dd568fa61 Binary files /dev/null and b/tests/flytekit/unit/extras/sqlite3/chinook.zip differ diff --git a/tests/flytekit/unit/extras/sqlite3/test_task.py b/tests/flytekit/unit/extras/sqlite3/test_task.py index 40fc94a3d2..f8014f244b 100644 --- a/tests/flytekit/unit/extras/sqlite3/test_task.py +++ b/tests/flytekit/unit/extras/sqlite3/test_task.py @@ -1,3 +1,5 @@ +import os + import pandas import pytest @@ -10,8 +12,7 @@ from flytekit.types.schema import FlyteSchema ctx = context_manager.FlyteContextManager.current_context() -EXAMPLE_DB = ctx.file_access.get_random_local_path("chinook.zip") -ctx.file_access.get_data("https://www.sqlitetutorial.net/wp-content/uploads/2018/03/chinook.zip", EXAMPLE_DB) +EXAMPLE_DB = os.path.join(os.path.dirname(os.path.realpath(__file__)), "chinook.zip") # This task belongs to test_task_static but is intentionally here to help test tracking tk = SQLite3Task( diff --git a/tests/flytekit/unit/extras/tensorflow/model/__init__.py b/tests/flytekit/unit/extras/tensorflow/model/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/extras/tensorflow/model/test_model.py b/tests/flytekit/unit/extras/tensorflow/model/test_model.py new file mode 100644 index 0000000000..2464345986 --- /dev/null +++ b/tests/flytekit/unit/extras/tensorflow/model/test_model.py @@ -0,0 +1,54 @@ +import tensorflow as tf + +from flytekit import task, workflow + + +@task +def generate_model() -> tf.keras.Model: + inputs = tf.keras.Input(shape=(32,)) + outputs = tf.keras.layers.Dense(1)(inputs) + model = tf.keras.Model(inputs, outputs) + model.compile( + optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), + loss=tf.keras.losses.BinaryCrossentropy(), + metrics=[ + 
tf.keras.metrics.BinaryAccuracy(), + ], + ) + return model + + +@task +def generate_sequential_model() -> tf.keras.Sequential: + model = tf.keras.Sequential( + [ + tf.keras.layers.Input(shape=(32,)), + tf.keras.layers.Dense(1), + ] + ) + model.compile( + optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), + loss=tf.keras.losses.BinaryCrossentropy(), + metrics=[ + tf.keras.metrics.BinaryAccuracy(), + ], + ) + return model + + +@task +def model_forward_pass(model: tf.keras.Model) -> tf.Tensor: + x: tf.Tensor = tf.ones((1, 32)) + return model(x) + + +@workflow +def wf(): + model1 = generate_model() + model2 = generate_sequential_model() + model_forward_pass(model=model1) + model_forward_pass(model=model2) + + +def test_wf(): + wf() diff --git a/tests/flytekit/unit/extras/tensorflow/model/test_transformations.py b/tests/flytekit/unit/extras/tensorflow/model/test_transformations.py new file mode 100644 index 0000000000..392ab695c5 --- /dev/null +++ b/tests/flytekit/unit/extras/tensorflow/model/test_transformations.py @@ -0,0 +1,75 @@ +from collections import OrderedDict + +import numpy as np +import pytest +import tensorflow as tf + +import flytekit +from flytekit import task +from flytekit.configuration import Image, ImageConfig +from flytekit.core import context_manager +from flytekit.extras.tensorflow import TensorFlowModelTransformer +from flytekit.models.core.types import BlobType +from flytekit.models.literals import BlobMetadata +from flytekit.models.types import LiteralType +from flytekit.tools.translator import get_serializable + +default_img = Image(name="default", fqn="test", tag="tag") +serialization_settings = flytekit.configuration.SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + image_config=ImageConfig(default_image=default_img, images=[default_img]), +) + + +def get_tf_model(): + inputs = tf.keras.Input(shape=(32,)) + outputs = tf.keras.layers.Dense(1)(inputs) + tf_model = tf.keras.Model(inputs, outputs) + return tf_model + + +@pytest.mark.parametrize( + "transformer,python_type,format", + [ + (TensorFlowModelTransformer(), tf.keras.Model, TensorFlowModelTransformer.TENSORFLOW_FORMAT), + ], +) +def test_get_literal_type(transformer, python_type, format): + lt = transformer.get_literal_type(python_type) + assert lt == LiteralType(blob=BlobType(format=format, dimensionality=BlobType.BlobDimensionality.MULTIPART)) + + +@pytest.mark.parametrize( + "transformer,python_type,format,python_val", + [ + (TensorFlowModelTransformer(), tf.keras.Model, TensorFlowModelTransformer.TENSORFLOW_FORMAT, get_tf_model()), + ], +) +def test_to_python_value_and_literal(transformer, python_type, format, python_val): + ctx = context_manager.FlyteContext.current_context() + lt = transformer.get_literal_type(python_type) + + lv = transformer.to_literal(ctx, python_val, type(python_val), lt) # type: ignore + output = transformer.to_python_value(ctx, lv, python_type) + + assert lv.scalar.blob.metadata == BlobMetadata( + type=BlobType( + format=format, + dimensionality=BlobType.BlobDimensionality.MULTIPART, + ) + ) + assert lv.scalar.blob.uri is not None + for w1, w2 in zip(output.weights, python_val.weights): + np.testing.assert_allclose(w1.numpy(), w2.numpy()) + + +def test_example_model(): + @task + def t1() -> tf.keras.Model: + return get_tf_model() + + task_spec = get_serializable(OrderedDict(), serialization_settings, t1) + assert task_spec.template.interface.outputs["o0"].type.blob.format is TensorFlowModelTransformer.TENSORFLOW_FORMAT diff --git 
a/tests/flytekit/unit/models/core/test_security.py b/tests/flytekit/unit/models/core/test_security.py new file mode 100644 index 0000000000..c2933f9353 --- /dev/null +++ b/tests/flytekit/unit/models/core/test_security.py @@ -0,0 +1,13 @@ +from flytekit.models.security import Secret + + +def test_secret(): + obj = Secret("grp", "key") + obj2 = Secret.from_flyte_idl(obj.to_flyte_idl()) + assert obj2.key == "key" + assert obj2.group_version is None + + obj = Secret("grp", group_version="v1") + obj2 = Secret.from_flyte_idl(obj.to_flyte_idl()) + assert obj2.key is None + assert obj2.group_version == "v1" diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index 4b8f82fb7e..5bfd7e4bf6 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -175,15 +175,7 @@ def test_more_stuff(mock_client): # Can't upload a folder with pytest.raises(ValueError): with tempfile.TemporaryDirectory() as tmp_dir: - r._upload_file(pathlib.Path(tmp_dir)) - - # Test that this copies the file. - with tempfile.TemporaryDirectory() as tmp_dir: - mm = MagicMock() - mm.signed_url = os.path.join(tmp_dir, "tmp_file") - mock_client.return_value.get_upload_signed_url.return_value = mm - - r._upload_file(pathlib.Path(__file__)) + r.upload_file(pathlib.Path(tmp_dir)) serialization_settings = flytekit.configuration.SerializationSettings( project="project", diff --git a/tests/flytekit/unit/tools/test_script_mode.py b/tests/flytekit/unit/tools/test_script_mode.py index a433769075..aba4e0ab17 100644 --- a/tests/flytekit/unit/tools/test_script_mode.py +++ b/tests/flytekit/unit/tools/test_script_mode.py @@ -1,13 +1,51 @@ import os +import subprocess +import sys -from flytekit.tools.script_mode import compress_single_script, hash_file +from flytekit.tools.script_mode import compress_scripts, hash_file + +MAIN_WORKFLOW = """ +from flytekit import task, workflow +from wf1.test import t1 -WORKFLOW = """ @workflow def my_wf() -> str: return "hello world" """ +IMPERATIVE_WORKFLOW = """ +from flytekit import Workflow, task + +@task +def t1(a: int): + print(a) + + +wf = Workflow(name="my.imperative.workflow.example") +wf.add_workflow_input("a", int) +node_t1 = wf.add_entity(t1, a=wf.inputs["a"]) +""" + +T1_TASK = """ +from flytekit import task +from wf2.test import t2 + + +@task() +def t1() -> str: + print("hello") + return "hello" +""" + +T2_TASK = """ +from flytekit import task + +@task() +def t2() -> str: + print("hello") + return "hello" +""" + def test_deterministic_hash(tmp_path): workflows_dir = tmp_path / "workflows" @@ -17,19 +55,46 @@ def test_deterministic_hash(tmp_path): open(workflows_dir / "__init__.py", "a").close() # Write a dummy workflow workflow_file = workflows_dir / "hello_world.py" - workflow_file.write_text(WORKFLOW) + workflow_file.write_text(MAIN_WORKFLOW) + + imperative_workflow_file = workflows_dir / "imperative_wf.py" + imperative_workflow_file.write_text(IMPERATIVE_WORKFLOW) + + t1_dir = tmp_path / "wf1" + t1_dir.mkdir() + open(t1_dir / "__init__.py", "a").close() + t1_file = t1_dir / "test.py" + t1_file.write_text(T1_TASK) + + t2_dir = tmp_path / "wf2" + t2_dir.mkdir() + open(t2_dir / "__init__.py", "a").close() + t2_file = t2_dir / "test.py" + t2_file.write_text(T2_TASK) destination = tmp_path / "destination" - compress_single_script(workflows_dir, destination, "hello_world") - print(f"{os.listdir(tmp_path)}") + sys.path.append(str(workflows_dir.parent)) + compress_scripts(str(workflows_dir.parent), str(destination), 
"workflows.hello_world") digest, hex_digest = hash_file(destination) # Try again to assert digest determinism destination2 = tmp_path / "destination2" - compress_single_script(workflows_dir, destination2, "hello_world") + compress_scripts(str(workflows_dir.parent), str(destination2), "workflows.hello_world") digest2, hex_digest2 = hash_file(destination) assert digest == digest2 assert hex_digest == hex_digest2 + + test_dir = tmp_path / "test" + test_dir.mkdir() + + result = subprocess.run( + ["tar", "-xvf", destination, "-C", test_dir], + stdout=subprocess.PIPE, + ) + result.check_returncode() + assert len(next(os.walk(test_dir))[1]) == 3 + + compress_scripts(str(workflows_dir.parent), str(destination), "workflows.imperative_wf") diff --git a/tests/flytekit/unit/types/directory/__init__.py b/tests/flytekit/unit/types/directory/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/types/directory/test_types.py b/tests/flytekit/unit/types/directory/test_types.py new file mode 100644 index 0000000000..199b788733 --- /dev/null +++ b/tests/flytekit/unit/types/directory/test_types.py @@ -0,0 +1,31 @@ +import mock + +from flytekit import FlyteContext +from flytekit.types.directory import FlyteDirectory +from flytekit.types.file import FlyteFile + + +def test_new_file_dir(): + fd = FlyteDirectory(path="s3://my-bucket") + assert fd.sep == "/" + inner_dir = fd.new_dir("test") + assert inner_dir.path == "s3://my-bucket/test" + fd = FlyteDirectory(path="s3://my-bucket/") + inner_dir = fd.new_dir("test") + assert inner_dir.path == "s3://my-bucket/test" + f = inner_dir.new_file("test") + assert isinstance(f, FlyteFile) + assert f.path == "s3://my-bucket/test/test" + + +def test_new_remote_dir(): + fd = FlyteDirectory.new_remote() + assert FlyteContext.current_context().file_access.raw_output_prefix in fd.path + + +@mock.patch("flytekit.types.directory.types.os.name", "nt") +def test_sep_nt(): + fd = FlyteDirectory(path="file://mypath") + assert fd.sep == "\\" + fd = FlyteDirectory(path="s3://mypath") + assert fd.sep == "/" diff --git a/tests/flytekit/unit/types/file/__init__.py b/tests/flytekit/unit/types/file/__init__.py new file mode 100644 index 0000000000..e69de29bb2