diff --git a/.github/workflows/docs_build.yml b/.github/workflows/docs_build.yml deleted file mode 100644 index 4fd71ce3b0..0000000000 --- a/.github/workflows/docs_build.yml +++ /dev/null @@ -1,26 +0,0 @@ -name: Docs Build - -on: - push: - branches: - - master - pull_request: - branches: - - master -jobs: - docs_warnings: - name: Docs Warnings - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: "0" - - uses: actions/setup-python@v4 - with: - python-version: '3.9' - - name: Report Sphinx Warnings - id: sphinx-warnings - run: | - sudo apt-get install python3-sphinx - pip install -r doc-requirements.txt - SPHINXOPTS="-W" cd docs && make html diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index 75e356ab0a..e33346afe5 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -65,6 +65,7 @@ jobs: - flytekit-deck-standard - flytekit-dolt - flytekit-duckdb + - flytekit-envd - flytekit-greatexpectations - flytekit-hive - flytekit-k8s-pod @@ -156,19 +157,3 @@ jobs: uses: ludeeus/action-shellcheck@master with: ignore_paths: boilerplate - - docs: - runs-on: ubuntu-latest - steps: - - name: Fetch the code - uses: actions/checkout@v3 - - name: Set up Python 3.9 - uses: actions/setup-python@v4 - with: - python-version: 3.9 - - name: Install dependencies - run: | - python -m pip install --upgrade pip==21.2.4 setuptools wheel - pip install -r doc-requirements.txt - - name: Build the documentation - run: make -C docs html diff --git a/.github/workflows/pythonpublish.yml b/.github/workflows/pythonpublish.yml index 097d82323e..169e0e58b7 100644 --- a/.github/workflows/pythonpublish.yml +++ b/.github/workflows/pythonpublish.yml @@ -47,11 +47,28 @@ jobs: run: | make -C plugins build_all_plugins make -C plugins publish_all_plugins - # Added sleep because PYPI take some time in publish - - name: Sleep for 180 seconds - uses: jakejarvis/wait-action@master - with: - time: '180s' + - name: 
Sleep until pypi is available + id: pypiwait + run: | + # from refs/tags/v1.2.3 get 1.2.3 and make sure it's not an empty string + VERSION=$(echo $GITHUB_REF | sed 's#.*/v##') + if [ -z "$VERSION" ] + then + echo "No tagged version found, exiting" + exit 1 + fi + LINK="https://pypi.org/project/flytekit/${VERSION}" + for i in {1..60}; do + if curl -L -I -s -f ${LINK} >/dev/null; then + echo "Found pypi" + exit 0 + else + echo "Did not find - Retrying in 10 seconds..." + sleep 10 + fi + done + exit 1 + shell: bash outputs: version: ${{ steps.bump.outputs.version }} @@ -120,6 +137,91 @@ jobs: tags: ${{ steps.sqlalchemy-names.outputs.tags }} build-args: | VERSION=${{ needs.deploy.outputs.version }} - file: ./plugins/flytekit-sqlalchemy/Dockerfile.py${{ matrix.python-version }} + PYTHON_VERSION=${{ matrix.python-version }} + file: ./plugins/flytekit-sqlalchemy/Dockerfile + cache-from: type=gha + cache-to: type=gha,mode=max + + build-and-push-external-plugin-service-images: + runs-on: ubuntu-latest + needs: deploy + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: "0" + - name: Set up QEMU + uses: docker/setup-qemu-action@v1 + - name: Set up Docker Buildx + id: buildx + uses: docker/setup-buildx-action@v1 + - name: Login to GitHub Container Registry + if: ${{ github.event_name == 'release' }} + uses: docker/login-action@v1 + with: + registry: ghcr.io + username: "${{ secrets.FLYTE_BOT_USERNAME }}" + password: "${{ secrets.FLYTE_BOT_PAT }}" + - name: Prepare External Plugin Service Image Names + id: external-plugin-service-names + uses: docker/metadata-action@v3 + with: + images: | + ghcr.io/${{ github.repository_owner }}/external-plugin-service + tags: | + latest + ${{ github.sha }} + ${{ needs.deploy.outputs.version }} + - name: Push External Plugin Service Image to GitHub Registry + uses: docker/build-push-action@v2 + with: + context: "." 
+ platforms: linux/arm64, linux/amd64 + push: ${{ github.event_name == 'release' }} + tags: ${{ steps.external-plugin-service-names.outputs.tags }} + build-args: | + VERSION=${{ needs.deploy.outputs.version }} + file: ./Dockerfile.external-plugin-service + cache-from: type=gha + cache-to: type=gha,mode=max + + build-and-push-spark-images: + runs-on: ubuntu-latest + needs: deploy + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: "0" + - name: Set up QEMU + uses: docker/setup-qemu-action@v1 + - name: Set up Docker Buildx + id: buildx + uses: docker/setup-buildx-action@v1 + - name: Login to GitHub Container Registry + if: ${{ github.event_name == 'release' }} + uses: docker/login-action@v1 + with: + registry: ghcr.io + username: "${{ secrets.FLYTE_BOT_USERNAME }}" + password: "${{ secrets.FLYTE_BOT_PAT }}" + - name: Prepare Spark Image Names + id: spark-names + uses: docker/metadata-action@v3 + with: + images: | + ghcr.io/${{ github.repository_owner }}/flytekit + tags: | + spark-latest + spark-${{ github.sha }} + spark-${{ needs.deploy.outputs.version }} + - name: Push Spark Image to GitHub Registry + uses: docker/build-push-action@v2 + with: + context: "./plugins/flytekit-spark/" + platforms: linux/arm64, linux/amd64 + push: ${{ github.event_name == 'release' }} + tags: ${{ steps.spark-names.outputs.tags }} + build-args: | + VERSION=${{ needs.deploy.outputs.version }} + file: ./plugins/flytekit-spark/Dockerfile cache-from: type=gha cache-to: type=gha,mode=max diff --git a/.gitignore b/.gitignore index a4fe02503e..b2e20249a8 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,4 @@ htmlcov *.ipynb *dat docs/source/_tags/ +.hypothesis diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1fd6e6b648..3007f6e64d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,3 +22,7 @@ repos: rev: v0.8.0.4 hooks: - id: shellcheck +- repo: https://github.com/conorfalvey/check_pdb_hook + rev: 0.0.9 + hooks: + - id: check_pdb_hook diff 
--git a/.readthedocs.yml b/.readthedocs.yml index 19b1898e94..18f4292317 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -9,6 +9,8 @@ build: os: ubuntu-20.04 tools: python: "3.9" + apt_packages: + - graphviz # Build documentation in the docs/ directory with Sphinx sphinx: diff --git a/Dockerfile b/Dockerfile index 9aa462781c..257fcb5143 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,8 +16,7 @@ RUN apt-get update && apt-get install build-essential -y RUN pip install -U flytekit==$VERSION \ flytekitplugins-pod==$VERSION \ flytekitplugins-deck-standard==$VERSION \ - flytekitplugins-data-fsspec[aws]==$VERSION \ - flytekitplugins-data-fsspec[gcp]==$VERSION \ + flytekitplugins-envd==$VERSION \ scikit-learn RUN useradd -u 1000 flytekit diff --git a/Dockerfile.dev b/Dockerfile.dev new file mode 100644 index 0000000000..b7c5104bbc --- /dev/null +++ b/Dockerfile.dev @@ -0,0 +1,32 @@ +# This Dockerfile is here to help with end-to-end testing +# From flytekit +# $ docker build -f Dockerfile.dev --build-arg PYTHON_VERSION=3.10 -t localhost:30000/flytekittest:someversion . +# $ docker push localhost:30000/flytekittest:someversion +# From your test user code +# $ pyflyte run --image localhost:30000/flytekittest:someversion + +ARG PYTHON_VERSION +FROM python:${PYTHON_VERSION}-slim-buster + +MAINTAINER Flyte Team +LABEL org.opencontainers.image.source https://github.com/flyteorg/flytekit + +WORKDIR /root + +ARG VERSION + +RUN apt-get update && apt-get install build-essential vim -y + +COPY . 
/flytekit + +# Pod tasks should be exposed in the default image +RUN pip install -e /flytekit +RUN pip install -e /flytekit/plugins/flytekit-k8s-pod +RUN pip install -e /flytekit/plugins/flytekit-deck-standard +RUN pip install scikit-learn + +ENV PYTHONPATH "/flytekit:/flytekit/plugins/flytekit-k8s-pod:/flytekit/plugins/flytekit-deck-standard:" + +RUN useradd -u 1000 flytekit +RUN chown flytekit: /root +USER flytekit diff --git a/Dockerfile.external-plugin-service b/Dockerfile.external-plugin-service new file mode 100644 index 0000000000..2194f5de23 --- /dev/null +++ b/Dockerfile.external-plugin-service @@ -0,0 +1,9 @@ +FROM python:3.9-slim-buster + +MAINTAINER Flyte Team +LABEL org.opencontainers.image.source=https://github.com/flyteorg/flytekit + +ARG VERSION +RUN pip install -U flytekit==$VERSION flytekitplugins-bigquery==$VERSION + +CMD pyflyte serve --port 8000 diff --git a/Makefile b/Makefile index eb9f43cdb6..1f50a8d95d 100644 --- a/Makefile +++ b/Makefile @@ -26,6 +26,7 @@ setup: install-piptools ## Install requirements .PHONY: fmt fmt: ## Format code with black and isort + autoflake --remove-all-unused-imports --ignore-init-module-imports --ignore-pass-after-docstring --in-place -r flytekit plugins tests pre-commit run black --all-files || true pre-commit run isort --all-files || true diff --git a/dev-requirements.in b/dev-requirements.in index 8655b92bae..a912d8c9d9 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -2,6 +2,7 @@ git+https://github.com/flyteorg/pytest-flyte@main#egg=pytest-flyte coverage[toml] +hypothesis joblib mock pytest @@ -21,3 +22,7 @@ tensorflow==2.8.1; platform_machine!='arm64' or platform_system!='Darwin' # we put this constraint while we do not have per-environment requirements files torch<=1.12.1 scikit-learn +types-protobuf<4 +types-croniter +types-mock +autoflake diff --git a/doc-requirements.in b/doc-requirements.in index e17b05ec5b..b495e7616f 100644 --- a/doc-requirements.in +++ b/doc-requirements.in @@ -1,6 +1,7 
@@ . -e file:.#egg=flytekit +grpcio<=1.49.1 git+https://github.com/flyteorg/furo@main sphinx sphinx-gallery @@ -11,7 +12,7 @@ sphinx-autoapi sphinx-copybutton sphinx_fontawesome sphinx-panels -sphinxcontrib-yt +sphinxcontrib-youtube cryptography google-api-core[grpc] scikit-learn diff --git a/doc-requirements.txt b/doc-requirements.txt index 2616e043df..7da1b07b74 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -10,8 +10,24 @@ absl-py==1.4.0 # via # tensorboard # tensorflow +adal==1.2.7 + # via azure-datalake-store +adlfs==2023.1.0 + # via flytekit +aiobotocore==2.5.0 + # via s3fs +aiohttp==3.8.4 + # via + # adlfs + # aiobotocore + # gcsfs + # s3fs +aioitertools==0.11.0 + # via aiobotocore aiosignal==1.3.1 - # via ray + # via + # aiohttp + # ray alabaster==0.7.13 # via sphinx alembic==1.9.2 @@ -46,11 +62,25 @@ asttokens==2.2.1 # via stack-data astunparse==1.6.3 # via tensorflow +async-timeout==4.0.2 + # via aiohttp attrs==22.2.0 # via + # aiohttp # jsonschema # ray # visions +azure-core==1.26.4 + # via + # adlfs + # azure-identity + # azure-storage-blob +azure-datalake-store==0.0.52 + # via adlfs +azure-identity==1.12.0 + # via adlfs +azure-storage-blob==12.16.0 + # via adlfs babel==2.11.0 # via sphinx backcall==0.2.0 @@ -67,8 +97,10 @@ blake3==0.3.3 # via vaex-core bleach==6.0.0 # via nbconvert -botocore==1.29.61 - # via -r doc-requirements.in +botocore==1.29.76 + # via + # -r doc-requirements.in + # aiobotocore bqplot==0.12.36 # via # ipyvolume @@ -86,31 +118,33 @@ certifi==2022.12.7 cffi==1.15.1 # via # argon2-cffi-bindings + # azure-datalake-store # cryptography cfgv==3.3.1 # via pre-commit chardet==5.0.0 # via binaryornot charset-normalizer==3.0.1 - # via requests + # via + # aiohttp + # requests click==8.1.3 # via # cookiecutter # dask # databricks-cli - # distributed # flask # flytekit # great-expectations # mlflow # papermill # ray + # rich-click # sphinx-click # uvicorn cloudpickle==2.2.1 # via # dask - # distributed # flytekit # mlflow # 
shap @@ -128,18 +162,20 @@ croniter==1.3.7 cryptography==39.0.0 # via # -r doc-requirements.in + # adal + # azure-identity + # azure-storage-blob # great-expectations + # msal + # pyjwt # pyopenssl # secretstorage css-html-js-minify==2.5.5 # via sphinx-material cycler==0.11.0 # via matplotlib -dask[distributed]==2023.1.1 - # via - # -r doc-requirements.in - # distributed - # vaex-core +dask==2023.1.1 + # via vaex-core databricks-cli==0.17.4 # via mlflow dataclasses-json==0.5.7 @@ -150,8 +186,8 @@ debugpy==1.6.6 # via ipykernel decorator==5.1.1 # via + # gcsfs # ipython - # retry defusedxml==0.7.1 # via nbconvert deprecated==1.2.13 @@ -160,8 +196,6 @@ diskcache==5.4.0 # via flytekit distlib==0.3.6 # via virtualenv -distributed==2023.1.1 - # via dask docker==6.0.1 # via # flytekit @@ -201,7 +235,7 @@ flatbuffers==2.0.7 # via # tensorflow # tf2onnx -flyteidl==1.2.9 +flyteidl==1.2.10 # via flytekit fonttools==4.38.0 # via matplotlib @@ -211,19 +245,26 @@ frozendict==2.3.4 # via vaex-core frozenlist==1.3.3 # via + # aiohttp # aiosignal # ray -fsspec==2023.1.0 +fsspec==2023.4.0 # via # -r doc-requirements.in + # adlfs # dask + # flytekit + # gcsfs # modin + # s3fs furo @ git+https://github.com/flyteorg/furo@main # via -r doc-requirements.in future==0.18.3 # via vaex-core gast==0.5.3 # via tensorflow +gcsfs==2023.4.0 + # via flytekit gitdb==4.0.10 # via gitpython gitpython==3.1.30 @@ -235,30 +276,42 @@ google-api-core[grpc]==2.11.0 # -r doc-requirements.in # google-cloud-bigquery # google-cloud-core + # google-cloud-storage google-auth==2.16.0 # via + # gcsfs # google-api-core # google-auth-oauthlib # google-cloud-core + # google-cloud-storage # kubernetes # tensorboard google-auth-oauthlib==0.4.6 - # via tensorboard + # via + # gcsfs + # tensorboard google-cloud==0.34.0 # via -r doc-requirements.in google-cloud-bigquery==3.5.0 # via -r doc-requirements.in google-cloud-core==2.3.2 - # via google-cloud-bigquery + # via + # google-cloud-bigquery + # google-cloud-storage 
+google-cloud-storage==2.8.0 + # via gcsfs google-crc32c==1.5.0 # via google-resumable-media google-pasta==0.2.0 # via tensorflow google-resumable-media==2.4.1 - # via google-cloud-bigquery + # via + # google-cloud-bigquery + # google-cloud-storage googleapis-common-protos==1.58.0 # via # flyteidl + # flytekit # google-api-core # grpcio-status great-expectations==0.15.46 @@ -267,6 +320,7 @@ greenlet==2.0.2 # via sqlalchemy grpcio==1.48.2 # via + # -r doc-requirements.in # flytekit # google-api-core # google-cloud-bigquery @@ -286,8 +340,6 @@ h5py==3.8.0 # via # tensorflow # vaex-hdf5 -heapdict==1.0.1 - # via zict htmlmin==0.1.12 # via ydata-profiling httptools==0.5.0 @@ -299,6 +351,7 @@ idna==3.4 # anyio # jsonschema # requests + # yarl imagehash==4.3.1 # via visions imagesize==1.4.1 @@ -364,6 +417,8 @@ ipywidgets==8.0.4 # ipyvue # jupyter # pythreejs +isodate==0.6.1 + # via azure-storage-blob isoduration==20.11.0 # via jsonschema itsdangerous==2.1.2 @@ -381,7 +436,6 @@ jinja2==3.1.2 # altair # branca # cookiecutter - # distributed # flask # great-expectations # jinja2-time @@ -469,9 +523,7 @@ libclang==15.0.6.1 llvmlite==0.39.1 # via numba locket==1.0.0 - # via - # distributed - # partd + # via partd lxml==4.9.2 # via sphinx-material makefun==1.15.0 @@ -526,10 +578,18 @@ modin==0.18.1 # via -r doc-requirements.in more-itertools==9.0.0 # via jaraco-classes +msal==1.22.0 + # via + # azure-identity + # msal-extensions +msal-extensions==1.0.0 + # via azure-identity msgpack==1.0.4 + # via ray +multidict==6.0.4 # via - # distributed - # ray + # aiohttp + # yarl multimethod==1.9.1 # via # visions @@ -650,7 +710,6 @@ packaging==22.0 # via # astropy # dask - # distributed # docker # google-cloud-bigquery # great-expectations @@ -720,7 +779,9 @@ platformdirs==2.6.2 # virtualenv plotly==5.13.0 # via -r doc-requirements.in -pre-commit==3.0.2 +portalocker==2.7.0 + # via msal-extensions +pre-commit==3.0.4 # via sphinx-tags progressbar2==4.2.0 # via vaex-core @@ -757,7 +818,6 @@ 
protoc-gen-swagger==0.1.0 # via flyteidl psutil==5.9.4 # via - # distributed # ipykernel # modin ptyprocess==0.7.0 @@ -766,8 +826,6 @@ ptyprocess==0.7.0 # terminado pure-eval==0.2.2 # via stack-data -py==1.11.0 - # via retry py4j==0.10.9.5 # via pyspark pyarrow==6.0.1 @@ -802,8 +860,11 @@ pygments==2.14.0 # rich # sphinx # sphinx-prompt -pyjwt==2.6.0 - # via databricks-cli +pyjwt[crypto]==2.6.0 + # via + # adal + # databricks-cli + # msal pyopenssl==23.0.0 # via flytekit pyparsing==3.0.9 @@ -816,6 +877,7 @@ pyspark==3.3.1 # via -r doc-requirements.in python-dateutil==2.8.2 # via + # adal # arrow # botocore # croniter @@ -859,7 +921,6 @@ pyyaml==6.0 # astropy # cookiecutter # dask - # distributed # flytekit # jupyter-events # kubernetes @@ -891,21 +952,28 @@ regex==2022.10.31 # via docker-image-py requests==2.28.2 # via + # adal + # azure-core + # azure-datalake-store # cookiecutter # databricks-cli # docker # flytekit + # gcsfs # google-api-core # google-cloud-bigquery + # google-cloud-storage # great-expectations # ipyvolume # kubernetes # mlflow + # msal # papermill # ray # requests-oauthlib # responses # sphinx + # sphinxcontrib-youtube # tensorboard # tf2onnx # vaex-core @@ -916,8 +984,6 @@ requests-oauthlib==1.3.1 # kubernetes responses==0.22.0 # via flytekit -retry==0.9.2 - # via flytekit rfc3339-validator==0.1.4 # via # jsonschema @@ -927,14 +993,21 @@ rfc3986-validator==0.1.1 # jsonschema # jupyter-events rich==13.3.1 - # via vaex-core + # via + # flytekit + # rich-click + # vaex-core +rich-click==1.6.1 + # via flytekit rsa==4.9 # via google-auth ruamel-yaml==0.17.17 # via great-expectations ruamel-yaml-clib==0.2.7 # via ruamel-yaml -scikit-learn==1.1.1 +s3fs==2023.4.0 + # via flytekit +scikit-learn==1.2.1 # via # -r doc-requirements.in # mlflow @@ -966,11 +1039,14 @@ six==1.16.0 # via # asttokens # astunparse + # azure-core + # azure-identity # bleach # databricks-cli # google-auth # google-pasta # grpcio + # isodate # keras-preprocessing # kubernetes # 
patsy @@ -992,9 +1068,7 @@ sniffio==1.3.0 snowballstemmer==2.2.0 # via sphinx sortedcontainers==2.4.0 - # via - # distributed - # flytekit + # via flytekit soupsieve==2.3.2.post1 # via beautifulsoup4 sphinx==4.5.0 @@ -1012,7 +1086,7 @@ sphinx==4.5.0 # sphinx-panels # sphinx-prompt # sphinx-tags - # sphinxcontrib-yt + # sphinxcontrib-youtube sphinx-autoapi==2.0.1 # via -r doc-requirements.in sphinx-basic-ng==1.0.0b1 @@ -1047,7 +1121,7 @@ sphinxcontrib-qthelp==1.0.3 # via sphinx sphinxcontrib-serializinghtml==1.1.5 # via sphinx -sphinxcontrib-yt==0.2.2 +sphinxcontrib-youtube==1.2.0 # via -r doc-requirements.in sqlalchemy==1.4.46 # via @@ -1070,8 +1144,6 @@ tabulate==0.9.0 # vaex-core tangled-up-in-unicode==0.2.0 # via visions -tblib==1.7.0 - # via distributed tenacity==8.1.0 # via # papermill @@ -1112,13 +1184,11 @@ toolz==0.12.0 # via # altair # dask - # distributed # partd torch==1.13.1 # via -r doc-requirements.in tornado==6.2 # via - # distributed # ipykernel # jupyter-client # jupyter-server @@ -1167,7 +1237,10 @@ types-toml==0.10.8.1 # via responses typing-extensions==4.4.0 # via + # aioitertools # astroid + # azure-core + # azure-storage-blob # flytekit # great-expectations # onnx @@ -1194,7 +1267,6 @@ uri-template==1.2.0 urllib3==1.26.14 # via # botocore - # distributed # docker # flytekit # great-expectations @@ -1275,6 +1347,7 @@ widgetsnbextension==4.0.5 # via ipywidgets wrapt==1.14.1 # via + # aiobotocore # astroid # deprecated # flytekit @@ -1284,10 +1357,10 @@ xarray==2023.1.0 # via vaex-jupyter xyzservices==2022.9.0 # via ipyleaflet +yarl==1.8.2 + # via aiohttp ydata-profiling==4.0.0 # via pandas-profiling -zict==2.2.0 - # via distributed zipp==3.12.0 # via importlib-metadata diff --git a/docs/source/conf.py b/docs/source/conf.py index 205bcb8838..6c0663f6b5 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -57,7 +57,7 @@ "sphinx-prompt", "sphinx_copybutton", "sphinx_panels", - "sphinxcontrib.yt", + "sphinxcontrib.youtube", "sphinx_tags", 
"sphinx_click", ] diff --git a/docs/source/design/authoring.rst b/docs/source/design/authoring.rst index 46a094399d..3118a9c71b 100644 --- a/docs/source/design/authoring.rst +++ b/docs/source/design/authoring.rst @@ -1,40 +1,42 @@ .. _design-authoring: -####################### +################### Authoring Structure -####################### +################### .. tags:: Design, Basic -One of the core features of Flytekit is to enable users to write tasks and workflows. In this section, we will understand how it works internally. - -.. note:: - - Please refer to the `design doc `__. +Flytekit's main focus is to provide users with the ability to create their own tasks and workflows. +In this section, we'll take a closer look at how it works under the hood. ********************* Types and Type Engine ********************* -Flyte has its own type system, which is codified `in the IDL `__. Python has its own type system despite being a dynamic language, which is primarily explained in `PEP 484 `_. Flytekit needs to build a medium to bridge the gap between these two type systems. -Type Engine -============= -This primariliy happens through the :py:class:`flytekit.extend.TypeEngine`. This engine works by invoking a series of :py:class:`TypeTransformers `. Each transformer is responsible for providing the functionality that the engine requires for a given native Python type. +Flyte uses its own type system, which is defined in the `IDL `__. +Despite being a dynamic language, Python also has its own type system which is primarily explained in `PEP 484 `__. +Therefore, Flytekit needs to establish a means of bridging the gap between these two type systems. +This is primarily accomplished through the use of :py:class:`flytekit.extend.TypeEngine`. +The ``TypeEngine`` works by invoking a series of :py:class:`TypeTransformers `. +Each transformer is responsible for providing the functionality that the engine requires for a given native Python type.
***************** Callable Entities ***************** -:ref:`Tasks `, :ref:`workflows `, and :ref:`launch plans ` form the core of the Flyte user experience. Each of these concepts is backed by one or more Python classes. These classes in turn, are instantiated by decorators (in the case of tasks and workflow) or a regular Python call (in the case of launch plans). + +The Flyte user experience is built around three main concepts: :ref:`Tasks `, :ref:`workflows `, and :ref:`launch plans `. +Each of these concepts is supported by one or more Python classes, which are instantiated by decorators (in the case of tasks and workflows) or a regular Python call (in the case of launch plans). Tasks ===== -This is the current task class hierarchy: + +Here is the existing hierarchy of task classes: .. inheritance-diagram:: flytekit.core.python_function_task.PythonFunctionTask flytekit.core.python_function_task.PythonInstanceTask flytekit.extras.sqlite3.task.SQLite3Task - :parts: 1 :top-classes: flytekit.core.base_task.Task + :parts: 1 -Please see the documentation on each of the classes for details. +For more information on each of the classes, please refer to the corresponding documentation. .. autoclass:: flytekit.core.base_task.Task :noindex: @@ -48,10 +50,10 @@ Please see the documentation on each of the classes for details. .. autoclass:: flytekit.core.python_function_task.PythonFunctionTask :noindex: - Workflows ========== -There are two workflow classes, and both inherit from the :py:class:`WorkflowBase ` class. + +There exist two workflow classes, both of which derive from the ``WorkflowBase`` class. .. autoclass:: flytekit.core.workflow.PythonFunctionWorkflow :noindex: @@ -59,10 +61,10 @@ There are two workflow classes, and both inherit from the :py:class:`WorkflowBas .. autoclass:: flytekit.core.workflow.ImperativeWorkflow :noindex: +Launch Plans +============ -Launch Plan -=========== -There is only one :py:class:`LaunchPlan ` class. 
+There exists one :py:class:`LaunchPlan ` class. .. autoclass:: flytekit.core.launch_plan.LaunchPlan :noindex: @@ -72,12 +74,13 @@ There is only one :py:class:`LaunchPlan ` ****************** Exception Handling ****************** -Exception handling takes place along two dimensions: -* System vs. User: We try to differentiate between user exceptions and Flytekit/system-level exceptions. For instance, if Flytekit fails to upload its outputs, that's a system exception. If the user raises a ``ValueError`` because of an unexpected input in the task code, that's a user exception. -* Recoverable vs. Non-recoverable: Recoverable errors will be retried and counted against the task's retry count. Non-recoverable errors will simply fail. System exceptions are by default recoverable (since there's a good chance it was just a blip). +Exception handling occurs along two dimensions: + +* System vs. User: We distinguish between Flytekit/system-level exceptions and user exceptions. For instance, if Flytekit encounters an issue while uploading outputs, it is considered a system exception. On the other hand, if a user raises a ``ValueError`` due to an unexpected input in the task code, it is classified as a user exception. +* Recoverable vs. Non-recoverable: Recoverable errors are retried and counted towards the task's retry count, while non-recoverable errors simply fail. System exceptions are recoverable by default since they are usually temporary. -Here's the user exception tree. Feel free to raise any of these exception classes. Note that the ``FlyteRecoverableException`` is the only recoverable exception. All others, along with all the non-Flytekit defined exceptions, are non-recoverable. +The following is the user exception tree, which users can raise as needed. It is important to note that only ``FlyteRecoverableException`` is a recoverable exception. All other exceptions, including non-Flytekit defined exceptions, are non-recoverable. .. 
inheritance-diagram:: flytekit.exceptions.user.FlyteValidationException flytekit.exceptions.user.FlyteEntityAlreadyExistsException flytekit.exceptions.user.FlyteValueException flytekit.exceptions.user.FlyteTimeout flytekit.exceptions.user.FlyteAuthenticationException flytekit.exceptions.user.FlyteRecoverableException :parts: 1 @@ -85,36 +88,42 @@ Here's the user exception tree. Feel free to raise any of these exception classe Implementation ============== -For those who want to dig deeper, take a look at the :py:class:`flytekit.common.exceptions.scopes.FlyteScopedException` classes. -There are two decorators that are interspersed throughout the codebase. + +If you wish to delve deeper, you can explore the ``FlyteScopedException`` classes. + +There are two decorators that are used throughout the codebase. .. autofunction:: flytekit.exceptions.scopes.system_entry_point .. autofunction:: flytekit.exceptions.scopes.user_entry_point -************** +************* Call Patterns -************** -The above-mentioned entities (tasks, workflows, and launch plan) are callable. They can be invoked to yield a unit (or units) of work in Flyte. +************* -In Pythonic terms, when you add ``()`` to the end of one of the entities, it invokes the ``__call__`` method on the object. +The entities mentioned above (tasks, workflows, and launch plans) are callable and can be invoked to generate one or more units of work in Flyte. -What happens when a callable entity is called depends on the current context, specifically the current :py:class:`flytekit.FlyteContext` +In Pythonic terminology, adding ``()`` to the end of an entity invokes the ``__call__`` method on the object. -Raw Task Execution -=================== -This is what happens when a task is just run as part of a unit test. 
The ``@task`` decorator actually turns the decorated function into an instance of the ``PythonFunctionTask`` object, but when a user calls the ``task()`` outside of a workflow, the original function is called without any interference by Flytekit. +The behavior that occurs when a callable entity is invoked is dependent on the current context, specifically the current :py:class:`flytekit.FlyteContext`. -Task Execution Inside Workflow -=============================== -When a workflow is run locally (say as a part of a unit test), certain changes occur in the ``task``. +Raw task execution +================== -Before going further, there is a special object that's worth mentioning, the :py:class:`flytekit.extend.Promise`. +When a task is executed as part of a unit test, the ``@task`` decorator transforms the decorated function into an instance of the ``PythonFunctionTask`` object. +However, when a user invokes the ``task()`` function outside of a workflow, the original function is called without any intervention from Flytekit. + +Task execution inside a workflow +================================ + +When a workflow is executed locally (for instance, as part of a unit test), some modifications are made to the task. + +Before proceeding, it is worth noting a special object, the :py:class:`flytekit.extend.Promise`. .. autoclass:: flytekit.core.promise.Promise :noindex: -Let's assume we have a workflow like :: +Consider the following workflow: :: @task def t1(a: int) -> Tuple[int, str]: @@ -130,19 +139,23 @@ Let's assume we have a workflow like :: d = t2(a=y, b=b) return x, d -As discussed in the Promise object's documentation, when a task is called from inside a workflow, the Python native values returned by the raw underlying functions are first converted into Flyte IDL literals and then wrapped inside ``Promise`` objects. One ``Promise`` is created for every return variable. 
+As stated in the documentation for the Promise object, when a task is invoked within a workflow, the Python native values returned by the underlying functions are first converted into Flyte IDL literals and then encapsulated inside Promise objects. +One Promise object is created for each return variable. -When the next task is called, the logic is triggered to unwrap these Promises. +When the next task is invoked, the values are extracted from these Promises. Compilation =========== -When a workflow is compiled, instead of producing Promise objects that wrap literal values, they wrap a :py:class:`flytekit.core.promise.NodeOutput` instead. This helps track data dependency between tasks. + +During the workflow compilation process, instead of generating Promise objects that encapsulate literal values, the workflow encapsulates a :py:class:`flytekit.core.promise.NodeOutput`. +This approach aids in tracking the data dependencies between tasks. Branch Skip =========== -If a :py:func:`flytekit.conditional` is determined to be false, then Flytekit will skip calling the task. This avoids running the unintended task. +If the condition specified in a :py:func:`flytekit.conditional` evaluates to ``False``, Flytekit will avoid invoking the corresponding task. +This prevents the unintended execution of the task. .. note:: - We discussed about a task's execution pattern above. The same pattern can be applied to workflows and launch plans too! + The execution pattern that we discussed for tasks can be applied to workflows and launch plans as well! 
diff --git a/flytekit/__init__.py b/flytekit/__init__.py index 20433c2ad1..4bbac570f3 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -26,6 +26,7 @@ map_task ~core.workflow.ImperativeWorkflow ~core.node_creation.create_node + ~core.promise.NodeOutput FlyteContextManager Running Locally @@ -195,6 +196,10 @@ import sys from typing import Generator +from rich import traceback + +from flytekit.lazy_import.lazy_module import lazy_module + if sys.version_info < (3, 10): from importlib_metadata import entry_points else: @@ -206,7 +211,6 @@ from flytekit.core.condition import conditional from flytekit.core.container_task import ContainerTask from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager -from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.gate import approve, sleep, wait_for_input from flytekit.core.hash import HashMethod @@ -223,8 +227,7 @@ from flytekit.core.workflow import ImperativeWorkflow as Workflow from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow from flytekit.deck import Deck -from flytekit.extras import pytorch, sklearn, tensorflow -from flytekit.extras.persistence import GCSPersistence, HttpPersistence, S3Persistence +from flytekit.image_spec import ImageSpec from flytekit.loggers import logger from flytekit.models.common import Annotations, AuthRole, Labels from flytekit.models.core.execution import WorkflowExecutionPhase @@ -232,7 +235,7 @@ from flytekit.models.documentation import Description, Documentation, SourceCode from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType -from flytekit.types import directory, file, numpy, schema +from flytekit.types import directory, file from flytekit.types.structured.structured_dataset import ( StructuredDataset, StructuredDatasetFormat, @@ -299,3 
+302,6 @@ def load_implicit_plugins(): # Load all implicit plugins load_implicit_plugins() + +# Pretty-print exception messages +traceback.install(width=None, extra_lines=0) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index ca7a6cf20d..a9b7c313f0 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -10,7 +10,6 @@ import click as _click from flyteidl.core import literals_pb2 as _literals_pb2 -from flytekit import PythonFunctionTask from flytekit.configuration import ( SERIALIZED_CONTEXT_ENV_VAR, FastSerializationSettings, @@ -23,7 +22,7 @@ from flytekit.core.checkpointer import SyncCheckpoint from flytekit.core.context_manager import ExecutionParameters, ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.data_persistence import FileAccessProvider -from flytekit.core.map_task import MapPythonTask +from flytekit.core.map_task import MapTaskResolver from flytekit.core.promise import VoidPromise from flytekit.exceptions import scopes as _scoped_exceptions from flytekit.exceptions import scopes as _scopes @@ -391,12 +390,8 @@ def _execute_map_task( with setup_execution( raw_output_data_prefix, checkpoint_path, prev_checkpoint, dynamic_addl_distro, dynamic_dest_dir ) as ctx: - resolver_obj = load_object_from_module(resolver) - # Use the resolver to load the actual task object - _task_def = resolver_obj.load_task(loader_args=resolver_args) - if not isinstance(_task_def, PythonFunctionTask): - raise Exception("Map tasks cannot be run with instance tasks.") - map_task = MapPythonTask(_task_def, max_concurrency) + mtr = MapTaskResolver() + map_task = mtr.load_task(loader_args=resolver_args, max_concurrency=max_concurrency) task_index = _compute_array_job_index() output_prefix = os.path.join(output_prefix, str(task_index)) diff --git a/flytekit/clients/auth/auth_client.py b/flytekit/clients/auth/auth_client.py index 94afa13612..ec1fd4d3e1 100644 --- a/flytekit/clients/auth/auth_client.py +++ 
b/flytekit/clients/auth/auth_client.py @@ -269,9 +269,11 @@ def _credentials_from_response(self, auth_token_resp) -> Credentials: raise ValueError('Expected "access_token" in response from oauth server') if "refresh_token" in response_body: refresh_token = response_body["refresh_token"] + if "expires_in" in response_body: + expires_in = response_body["expires_in"] access_token = response_body["access_token"] - return Credentials(access_token, refresh_token, self._endpoint) + return Credentials(access_token, refresh_token, self._endpoint, expires_in=expires_in) def _request_access_token(self, auth_code) -> Credentials: if self._state != auth_code.state: diff --git a/flytekit/clients/auth/authenticator.py b/flytekit/clients/auth/authenticator.py index 183c1787cd..9582c901d8 100644 --- a/flytekit/clients/auth/authenticator.py +++ b/flytekit/clients/auth/authenticator.py @@ -1,12 +1,10 @@ -import base64 import logging import subprocess import typing from abc import abstractmethod from dataclasses import dataclass -import requests - +from . 
import token_client from .auth_client import AuthorizationClient from .exceptions import AccessTokenNotFoundError, AuthenticationError from .keyring import Credentials, KeyringStore @@ -22,6 +20,7 @@ class ClientConfig: authorization_endpoint: str redirect_uri: str client_id: str + device_authorization_endpoint: typing.Optional[str] = None scopes: typing.List[str] = None header_key: str = "authorization" @@ -155,67 +154,25 @@ class ClientCredentialsAuthenticator(Authenticator): This Authenticator uses ClientId and ClientSecret to authenticate """ - _utf_8 = "utf-8" - def __init__( self, endpoint: str, client_id: str, client_secret: str, cfg_store: ClientConfigStore, - header_key: str = None, + header_key: typing.Optional[str] = None, + scopes: typing.Optional[typing.List[str]] = None, ): if not client_id or not client_secret: raise ValueError("Client ID and Client SECRET both are required.") cfg = cfg_store.get_client_config() self._token_endpoint = cfg.token_endpoint - self._scopes = cfg.scopes + # Use scopes from `flytekit.configuration.PlatformConfig` if passed + self._scopes = scopes or cfg.scopes self._client_id = client_id self._client_secret = client_secret super().__init__(endpoint, cfg.header_key or header_key) - @staticmethod - def get_token(token_endpoint: str, authorization_header: str, scopes: typing.List[str]) -> typing.Tuple[str, int]: - """ - :rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration - in seconds - """ - headers = { - "Authorization": authorization_header, - "Cache-Control": "no-cache", - "Accept": "application/json", - "Content-Type": "application/x-www-form-urlencoded", - } - body = { - "grant_type": "client_credentials", - } - if scopes is not None: - body["scope"] = ",".join(scopes) - response = requests.post(token_endpoint, data=body, headers=headers) - if response.status_code != 200: - logging.error("Non-200 ({}) received from IDP: {}".format(response.status_code, 
response.text)) - raise AuthenticationError("Non-200 received from IDP") - - response = response.json() - return response["access_token"], response["expires_in"] - - @staticmethod - def get_basic_authorization_header(client_id: str, client_secret: str) -> str: - """ - This function transforms the client id and the client secret into a header that conforms with http basic auth. - It joins the id and the secret with a : then base64 encodes it, then adds the appropriate text - - :param client_id: str - :param client_secret: str - :rtype: str - """ - concated = "{}:{}".format(client_id, client_secret) - return "Basic {}".format( - base64.b64encode(concated.encode(ClientCredentialsAuthenticator._utf_8)).decode( - ClientCredentialsAuthenticator._utf_8 - ) - ) - def refresh_credentials(self): """ This function is used by the _handle_rpc_error() decorator, depending on the AUTH_MODE config object. This handler @@ -229,7 +186,56 @@ def refresh_credentials(self): # Note that unlike the Pkce flow, the client ID does not come from Admin. logging.debug(f"Basic authorization flow with client id {self._client_id} scope {scopes}") - authorization_header = self.get_basic_authorization_header(self._client_id, self._client_secret) - token, expires_in = self.get_token(token_endpoint, authorization_header, scopes) + authorization_header = token_client.get_basic_authorization_header(self._client_id, self._client_secret) + token, expires_in = token_client.get_token(token_endpoint, scopes, authorization_header) logging.info("Retrieved new token, expires in {}".format(expires_in)) self._creds = Credentials(token) + + +class DeviceCodeAuthenticator(Authenticator): + """ + This Authenticator implements the Device Code authorization flow useful for headless user authentication. 
+ + Examples described + - https://developer.okta.com/docs/guides/device-authorization-grant/main/ + - https://auth0.com/docs/get-started/authentication-and-authorization-flow/device-authorization-flow#device-flow + """ + + def __init__( + self, + endpoint: str, + cfg_store: ClientConfigStore, + header_key: typing.Optional[str] = None, + audience: typing.Optional[str] = None, + ): + self._audience = audience + cfg = cfg_store.get_client_config() + self._client_id = cfg.client_id + self._device_auth_endpoint = cfg.device_authorization_endpoint + self._scope = cfg.scopes + self._token_endpoint = cfg.token_endpoint + if self._device_auth_endpoint is None: + raise AuthenticationError( + "Device Authentication is not available on the Flyte backend / authentication server" + ) + super().__init__( + endpoint=endpoint, header_key=header_key or cfg.header_key, credentials=KeyringStore.retrieve(endpoint) + ) + + def refresh_credentials(self): + resp = token_client.get_device_code(self._device_auth_endpoint, self._client_id, self._audience, self._scope) + print( + f""" +To Authenticate navigate in a browser to the following URL: {resp.verification_uri} and enter code: {resp.user_code} +OR copy paste the following URL: {resp.verification_uri_complete} + """ + ) + try: + # Currently the refresh token is not retreived. 
We may want to add support for refreshTokens so that + # access tokens can be refreshed for once authenticated machines + token, expires_in = token_client.poll_token_endpoint(resp, self._token_endpoint, client_id=self._client_id) + self._creds = Credentials(access_token=token, expires_in=expires_in, for_endpoint=self._endpoint) + KeyringStore.store(self._creds) + except Exception: + KeyringStore.delete(self._endpoint) + raise diff --git a/flytekit/clients/auth/exceptions.py b/flytekit/clients/auth/exceptions.py index 6e790e47a4..5086c5b6e1 100644 --- a/flytekit/clients/auth/exceptions.py +++ b/flytekit/clients/auth/exceptions.py @@ -12,3 +12,11 @@ class AuthenticationError(RuntimeError): """ pass + + +class AuthenticationPending(RuntimeError): + """ + This is raised if the token endpoint returns authentication pending + """ + + pass diff --git a/flytekit/clients/auth/keyring.py b/flytekit/clients/auth/keyring.py index c2b19c46b6..79f5e86c68 100644 --- a/flytekit/clients/auth/keyring.py +++ b/flytekit/clients/auth/keyring.py @@ -15,6 +15,7 @@ class Credentials(object): access_token: str refresh_token: str = "na" for_endpoint: str = "flyte-default" + expires_in: typing.Optional[int] = None class KeyringStore: @@ -39,7 +40,7 @@ def store(credentials: Credentials) -> Credentials: credentials.access_token, ) except NoKeyringError as e: - logging.warning(f"KeyRing not available, tokens will not be cached. Error: {e}") + logging.debug(f"KeyRing not available, tokens will not be cached. Error: {e}") return credentials @staticmethod @@ -48,7 +49,7 @@ def retrieve(for_endpoint: str) -> typing.Optional[Credentials]: refresh_token = _keyring.get_password(for_endpoint, KeyringStore._refresh_token_key) access_token = _keyring.get_password(for_endpoint, KeyringStore._access_token_key) except NoKeyringError as e: - logging.warning(f"KeyRing not available, tokens will not be cached. Error: {e}") + logging.debug(f"KeyRing not available, tokens will not be cached. 
Error: {e}") return None if not access_token: @@ -61,4 +62,4 @@ def delete(for_endpoint: str): _keyring.delete_password(for_endpoint, KeyringStore._access_token_key) _keyring.delete_password(for_endpoint, KeyringStore._refresh_token_key) except NoKeyringError as e: - logging.warning(f"KeyRing not available, tokens will not be cached. Error: {e}") + logging.debug(f"KeyRing not available, tokens will not be cached. Error: {e}") diff --git a/flytekit/clients/auth/token_client.py b/flytekit/clients/auth/token_client.py new file mode 100644 index 0000000000..7cbb42a13e --- /dev/null +++ b/flytekit/clients/auth/token_client.py @@ -0,0 +1,154 @@ +import base64 +import enum +import logging +import time +import typing +import urllib.parse +from dataclasses import dataclass +from datetime import datetime, timedelta + +import requests + +from flytekit.clients.auth.exceptions import AuthenticationError, AuthenticationPending + +utf_8 = "utf-8" + +# Errors that Token endpoint will return +error_slow_down = "slow_down" +error_auth_pending = "authorization_pending" + + +# Grant Types +class GrantType(str, enum.Enum): + CLIENT_CREDS = "client_credentials" + DEVICE_CODE = "urn:ietf:params:oauth:grant-type:device_code" + + +@dataclass +class DeviceCodeResponse: + """ + Response from device auth flow endpoint + {'device_code': 'code', + 'user_code': 'BNDJJFXL', + 'verification_uri': 'url', + 'verification_uri_complete': 'url', + 'expires_in': 600, + 'interval': 5} + """ + + device_code: str + user_code: str + verification_uri: str + verification_uri_complete: str + expires_in: int + interval: int + + @classmethod + def from_json_response(cls, j: typing.Dict) -> "DeviceCodeResponse": + return cls( + device_code=j["device_code"], + user_code=j["user_code"], + verification_uri=j["verification_uri"], + verification_uri_complete=j["verification_uri_complete"], + expires_in=j["expires_in"], + interval=j["interval"], + ) + + +def get_basic_authorization_header(client_id: str, client_secret: 
str) -> str: + """ + This function transforms the client id and the client secret into a header that conforms with http basic auth. + It joins the id and the secret with a : then base64 encodes it, then adds the appropriate text. Secrets are + first URL encoded to escape illegal characters. + + :param client_id: str + :param client_secret: str + :rtype: str + """ + encoded = urllib.parse.quote_plus(client_secret) + concatenated = "{}:{}".format(client_id, encoded) + return "Basic {}".format(base64.b64encode(concatenated.encode(utf_8)).decode(utf_8)) + + +def get_token( + token_endpoint: str, + scopes: typing.Optional[typing.List[str]] = None, + authorization_header: typing.Optional[str] = None, + client_id: typing.Optional[str] = None, + device_code: typing.Optional[str] = None, + grant_type: GrantType = GrantType.CLIENT_CREDS, +) -> typing.Tuple[str, int]: + """ + :rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration + in seconds + """ + headers = { + "Cache-Control": "no-cache", + "Accept": "application/json", + "Content-Type": "application/x-www-form-urlencoded", + } + if authorization_header: + headers["Authorization"] = authorization_header + body = { + "grant_type": grant_type.value, + } + if client_id: + body["client_id"] = client_id + if device_code: + body["device_code"] = device_code + if scopes is not None: + body["scope"] = ",".join(scopes) + + response = requests.post(token_endpoint, data=body, headers=headers) + if not response.ok: + j = response.json() + if "error" in j: + err = j["error"] + if err == error_auth_pending or err == error_slow_down: + raise AuthenticationPending(f"Token not yet available, try again in some time {err}") + logging.error("Status Code ({}) received from IDP: {}".format(response.status_code, response.text)) + raise AuthenticationError("Status Code ({}) received from IDP: {}".format(response.status_code, response.text)) + + j = response.json() + return j["access_token"], 
j["expires_in"] + + +def get_device_code( + device_auth_endpoint: str, + client_id: str, + audience: typing.Optional[str] = None, + scope: typing.Optional[typing.List[str]] = None, +) -> DeviceCodeResponse: + """ + Retrieves the device Authentication code that can be done to authenticate the request using a browser on a + separate device + """ + payload = {"client_id": client_id, "scope": scope, "audience": audience} + resp = requests.post(device_auth_endpoint, payload) + if not resp.ok: + raise AuthenticationError(f"Unable to retrieve Device Authentication Code for {payload}, Reason {resp.reason}") + return DeviceCodeResponse.from_json_response(resp.json()) + + +def poll_token_endpoint(resp: DeviceCodeResponse, token_endpoint: str, client_id: str) -> typing.Tuple[str, int]: + tick = datetime.now() + interval = timedelta(seconds=resp.interval) + end_time = tick + timedelta(seconds=resp.expires_in) + while tick < end_time: + try: + access_token, expires_in = get_token( + token_endpoint, + grant_type=GrantType.DEVICE_CODE, + client_id=client_id, + device_code=resp.device_code, + ) + print("Authentication successful!") + return access_token, expires_in + except AuthenticationPending: + ... 
+ except Exception: + raise + print("Authentication Pending...") + time.sleep(interval.total_seconds()) + tick = tick + interval + raise AuthenticationError("Authentication failed!") diff --git a/flytekit/clients/auth_helper.py b/flytekit/clients/auth_helper.py index 41fc5c025f..93bd883324 100644 --- a/flytekit/clients/auth_helper.py +++ b/flytekit/clients/auth_helper.py @@ -12,6 +12,7 @@ ClientConfigStore, ClientCredentialsAuthenticator, CommandAuthenticator, + DeviceCodeAuthenticator, PKCEAuthenticator, ) from flytekit.clients.grpc_utils.auth_interceptor import AuthUnaryInterceptor @@ -41,6 +42,7 @@ def get_client_config(self) -> ClientConfig: client_id=public_client_config.client_id, scopes=public_client_config.scopes, header_key=public_client_config.authorization_metadata_key or None, + device_authorization_endpoint=oauth2_metadata.device_authorization_endpoint, ) @@ -69,6 +71,7 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth client_id=cfg.client_id, client_secret=cfg.client_credentials_secret, cfg_store=cfg_store, + scopes=cfg.scopes, ) elif cfg_auth == AuthType.EXTERNAL_PROCESS or cfg_auth == AuthType.EXTERNALCOMMAND: client_cfg = None @@ -78,6 +81,8 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth command=cfg.command, header_key=client_cfg.header_key if client_cfg else None, ) + elif cfg_auth == AuthType.DEVICEFLOW: + return DeviceCodeAuthenticator(endpoint=cfg.endpoint, cfg_store=cfg_store, audience=cfg.audience) else: raise ValueError( f"Invalid auth mode [{cfg_auth}] specified." 
f"Please update the creds config to use a valid value" diff --git a/flytekit/clients/friendly.py b/flytekit/clients/friendly.py index d542af5f7e..2b15dfbd50 100644 --- a/flytekit/clients/friendly.py +++ b/flytekit/clients/friendly.py @@ -1007,7 +1007,7 @@ def get_upload_signed_url( def get_download_signed_url( self, native_url: str, expires_in: datetime.timedelta = None - ) -> _data_proxy_pb2.CreateUploadLocationResponse: + ) -> _data_proxy_pb2.CreateDownloadLocationRequest: expires_in_pb = None if expires_in: expires_in_pb = Duration() diff --git a/flytekit/clis/flyte_cli/main.py b/flytekit/clis/flyte_cli/main.py index 21aec1c4ad..47d5c8c641 100644 --- a/flytekit/clis/flyte_cli/main.py +++ b/flytekit/clis/flyte_cli/main.py @@ -1167,7 +1167,6 @@ def terminate_execution(host, insecure, cause, urn=None): raise _click.UsageError('Missing option "-u" / "--urn" or missing pipe inputs.') except KeyboardInterrupt: _sys.stdout.flush() - pass else: _terminate_one_execution(client, urn, cause) diff --git a/flytekit/clis/sdk_in_container/backfill.py b/flytekit/clis/sdk_in_container/backfill.py index 234b03499f..49c2667d5b 100644 --- a/flytekit/clis/sdk_in_container/backfill.py +++ b/flytekit/clis/sdk_in_container/backfill.py @@ -1,7 +1,7 @@ import typing from datetime import datetime, timedelta -import click +import rich_click as click from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context from flytekit.clis.sdk_in_container.run import DateTimeType, DurationParamType @@ -10,8 +10,8 @@ The backfill command generates and registers a new workflow based on the input launchplan to run an automated backfill. The workflow can be managed using the Flyte UI and can be canceled, relaunched, and recovered. -- ``launchplan`` refers to the name of the Launchplan -- ``launchplan_version`` is optional and should be a valid version for a Launchplan version. 
+ - ``launchplan`` refers to the name of the Launchplan + - ``launchplan_version`` is optional and should be a valid version for a Launchplan version. """ @@ -168,11 +168,12 @@ def backfill( execute=execute, parallel=parallel, ) - if entity: - console_url = remote.generate_console_url(entity) - if execute: - click.secho(f"\n Execution launched {console_url} to see execution in the console.", fg="green") - return - click.secho(f"\n Workflow registered at {console_url}", fg="green") + if dry_run: + return + console_url = remote.generate_console_url(entity) + if execute: + click.secho(f"\n Execution launched {console_url} to see execution in the console.", fg="green") + return + click.secho(f"\n Workflow registered at {console_url}", fg="green") except StopIteration as e: click.secho(f"{e.value}", fg="red") diff --git a/flytekit/clis/sdk_in_container/build.py b/flytekit/clis/sdk_in_container/build.py new file mode 100644 index 0000000000..3e18535268 --- /dev/null +++ b/flytekit/clis/sdk_in_container/build.py @@ -0,0 +1,127 @@ +import os +import pathlib +import typing + +import rich_click as click +from typing_extensions import OrderedDict + +from flytekit.clis.sdk_in_container.constants import CTX_MODULE, CTX_PROJECT_ROOT +from flytekit.clis.sdk_in_container.run import RUN_LEVEL_PARAMS_KEY, get_entities_in_file, load_naive_entity +from flytekit.configuration import ImageConfig, SerializationSettings +from flytekit.core.base_task import PythonTask +from flytekit.core.workflow import PythonFunctionWorkflow +from flytekit.tools.script_mode import _find_project_root +from flytekit.tools.translator import get_serializable + + +def get_workflow_command_base_params() -> typing.List[click.Option]: + """ + Return the set of base parameters added to every pyflyte build workflow subcommand. + """ + return [ + click.Option( + param_decls=["--fast"], + required=False, + is_flag=True, + default=False, + help="Use fast serialization. The image won't contain the source code. 
The value is false by default.", + ), + ] + + +def build_command(ctx: click.Context, entity: typing.Union[PythonFunctionWorkflow, PythonTask]): + """ + Returns a function that is used to implement WorkflowCommand and build an image for flyte workflows. + """ + + def _build(*args, **kwargs): + m = OrderedDict() + options = None + run_level_params = ctx.obj[RUN_LEVEL_PARAMS_KEY] + + project, domain = run_level_params.get("project"), run_level_params.get("domain") + serialization_settings = SerializationSettings( + project=project, + domain=domain, + image_config=ImageConfig.auto_default_image(), + ) + if not run_level_params.get("fast"): + serialization_settings.source_root = ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_PROJECT_ROOT) + + _ = get_serializable(m, settings=serialization_settings, entity=entity, options=options) + + return _build + + +class WorkflowCommand(click.MultiCommand): + """ + click multicommand at the python file layer, subcommands should be all the workflows in the file. + """ + + def __init__(self, filename: str, *args, **kwargs): + super().__init__(*args, **kwargs) + self._filename = pathlib.Path(filename).resolve() + + def list_commands(self, ctx): + entities = get_entities_in_file(self._filename.__str__()) + return entities.all() + + def get_command(self, ctx, exe_entity): + """ + This command uses the filename with which this command was created, and the string name of the entity passed + after the Python filename on the command line, to load the Python object, and then return the Command that + click should run. + :param ctx: The click Context object. + :param exe_entity: string of the flyte entity provided by the user. Should be the name of a workflow, or task + function. 
+ :return: + """ + rel_path = os.path.relpath(self._filename) + if rel_path.startswith(".."): + raise ValueError( + f"You must call pyflyte from the same or parent dir, {self._filename} not under {os.getcwd()}" + ) + + project_root = _find_project_root(self._filename) + rel_path = self._filename.relative_to(project_root) + module = os.path.splitext(rel_path)[0].replace(os.path.sep, ".") + + ctx.obj[RUN_LEVEL_PARAMS_KEY][CTX_PROJECT_ROOT] = project_root + ctx.obj[RUN_LEVEL_PARAMS_KEY][CTX_MODULE] = module + + entity = load_naive_entity(module, exe_entity, project_root) + + cmd = click.Command( + name=exe_entity, + callback=build_command(ctx, entity), + help=f"Build an image for {module}.{exe_entity}.", + ) + return cmd + + +class BuildCommand(click.MultiCommand): + """ + A click command group for building a image for flyte workflows & tasks in a file. + """ + + def __init__(self, *args, **kwargs): + params = get_workflow_command_base_params() + super().__init__(*args, params=params, **kwargs) + + def list_commands(self, ctx): + return [str(p) for p in pathlib.Path(".").glob("*.py") if str(p) != "__init__.py"] + + def get_command(self, ctx, filename): + if ctx.obj: + ctx.obj[RUN_LEVEL_PARAMS_KEY] = ctx.params + return WorkflowCommand(filename, name=filename, help="Build an image for [workflow|task]") + + +_build_help = """ +This command can build an image for a workflow or a task from the command line, for fully self-contained scripts. 
+""" + +build = BuildCommand( + name="build", + help=_build_help, +) diff --git a/flytekit/clis/sdk_in_container/constants.py b/flytekit/clis/sdk_in_container/constants.py index d228babf43..67391abb4d 100644 --- a/flytekit/clis/sdk_in_container/constants.py +++ b/flytekit/clis/sdk_in_container/constants.py @@ -1,4 +1,4 @@ -import click as _click +import rich_click as _click CTX_PROJECT = "project" CTX_DOMAIN = "domain" @@ -10,6 +10,7 @@ CTX_PROJECT_ROOT = "project_root" CTX_MODULE = "module" CTX_VERBOSE = "verbose" +CTX_COPY_ALL = "copy_all" project_option = _click.option( diff --git a/flytekit/clis/sdk_in_container/helpers.py b/flytekit/clis/sdk_in_container/helpers.py index 6ac451be92..4c66a2046c 100644 --- a/flytekit/clis/sdk_in_container/helpers.py +++ b/flytekit/clis/sdk_in_container/helpers.py @@ -1,7 +1,7 @@ from dataclasses import replace from typing import Optional -import click +import rich_click as click from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE from flytekit.configuration import Config, ImageConfig, get_config_file diff --git a/flytekit/clis/sdk_in_container/init.py b/flytekit/clis/sdk_in_container/init.py index 1ec2f57c32..627b393578 100644 --- a/flytekit/clis/sdk_in_container/init.py +++ b/flytekit/clis/sdk_in_container/init.py @@ -1,4 +1,4 @@ -import click +import rich_click as click from cookiecutter.main import cookiecutter diff --git a/flytekit/clis/sdk_in_container/launchplan.py b/flytekit/clis/sdk_in_container/launchplan.py new file mode 100644 index 0000000000..2d33e2e3d7 --- /dev/null +++ b/flytekit/clis/sdk_in_container/launchplan.py @@ -0,0 +1,74 @@ +import rich_click as click + +from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context +from flytekit.models.launch_plan import LaunchPlanState + +_launchplan_help = """ +The launchplan command activates or deactivates a specified or the latest version of the launchplan. 
+If ``--activate`` is chosen then the previous version of the launchplan will be deactivated. + +- ``launchplan`` refers to the name of the Launchplan +- ``launchplan_version`` is optional and should be a valid version for a Launchplan version. If not specified the latest will be used. +""" + + +@click.command("launchplan", help=_launchplan_help) +@click.option( + "-p", + "--project", + required=False, + type=str, + default="flytesnacks", + help="Fecth launchplan from this project", +) +@click.option( + "-d", + "--domain", + required=False, + type=str, + default="development", + help="Fetch launchplan from this domain", +) +@click.option( + "--activate/--deactivate", + required=True, + type=bool, + is_flag=True, + help="Activate or Deactivate the launchplan", +) +@click.argument( + "launchplan", + required=True, + type=str, +) +@click.argument( + "launchplan-version", + required=False, + type=str, + default=None, +) +@click.pass_context +def launchplan( + ctx: click.Context, + project: str, + domain: str, + activate: bool, + launchplan: str, + launchplan_version: str, +): + remote = get_and_save_remote_with_click_context(ctx, project, domain) + try: + launchplan = remote.fetch_launch_plan( + project=project, + domain=domain, + name=launchplan, + version=launchplan_version, + ) + state = LaunchPlanState.ACTIVE if activate else LaunchPlanState.INACTIVE + remote.client.update_launch_plan(id=launchplan.id, state=state) + click.secho( + f"\n Launchplan was set to {LaunchPlanState.enum_to_string(state)}: {launchplan.name}:{launchplan.id.version}", + fg="green", + ) + except StopIteration as e: + click.secho(f"{e.value}", fg="red") diff --git a/flytekit/clis/sdk_in_container/local_cache.py b/flytekit/clis/sdk_in_container/local_cache.py index 0dbdc9c621..b0923b842a 100644 --- a/flytekit/clis/sdk_in_container/local_cache.py +++ b/flytekit/clis/sdk_in_container/local_cache.py @@ -1,4 +1,4 @@ -import click +import rich_click as click from flytekit.core.local_cache import 
LocalTaskCache diff --git a/flytekit/clis/sdk_in_container/package.py b/flytekit/clis/sdk_in_container/package.py index e457b3d649..f1c6f526e2 100644 --- a/flytekit/clis/sdk_in_container/package.py +++ b/flytekit/clis/sdk_in_container/package.py @@ -1,6 +1,6 @@ import os -import click +import rich_click as click from flytekit.clis.helpers import display_help_with_error from flytekit.clis.sdk_in_container import constants diff --git a/flytekit/clis/sdk_in_container/pyflyte.py b/flytekit/clis/sdk_in_container/pyflyte.py index 5e1136d14c..d9a8fb0c2a 100644 --- a/flytekit/clis/sdk_in_container/pyflyte.py +++ b/flytekit/clis/sdk_in_container/pyflyte.py @@ -1,18 +1,21 @@ import typing -import click import grpc +import rich_click as click from google.protobuf.json_format import MessageToJson from flytekit import configuration from flytekit.clis.sdk_in_container.backfill import backfill +from flytekit.clis.sdk_in_container.build import build from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE, CTX_PACKAGES, CTX_VERBOSE from flytekit.clis.sdk_in_container.init import init +from flytekit.clis.sdk_in_container.launchplan import launchplan from flytekit.clis.sdk_in_container.local_cache import local_cache from flytekit.clis.sdk_in_container.package import package from flytekit.clis.sdk_in_container.register import register from flytekit.clis.sdk_in_container.run import run from flytekit.clis.sdk_in_container.serialize import serialize +from flytekit.clis.sdk_in_container.serve import serve from flytekit.configuration.internal import LocalSDK from flytekit.exceptions.base import FlyteException from flytekit.exceptions.user import FlyteInvalidInputException @@ -36,8 +39,8 @@ def validate_package(ctx, param, values): def pretty_print_grpc_error(e: grpc.RpcError): if isinstance(e, grpc._channel._InactiveRpcError): # noqa - click.secho(f"RPC Failed, with Status: {e.code()}", fg="red") - click.secho(f"\tdetails: {e.details()}", fg="magenta") + click.secho(f"RPC 
Failed, with Status: {e.code()}", fg="red", bold=True) + click.secho(f"\tdetails: {e.details()}", fg="magenta", bold=True) click.secho(f"\tDebug string {e.debug_error_string()}", dim=True) return @@ -51,12 +54,11 @@ def pretty_print_exception(e: Exception): raise e if isinstance(e, FlyteException): + click.secho(f"Failed with Exception Code: {e._ERROR_CODE}", fg="red") # noqa if isinstance(e, FlyteInvalidInputException): click.secho("Request rejected by the API, due to Invalid input.", fg="red") - click.secho(f"\tReason: {str(e)}", dim=True) click.secho(f"\tInput Request: {MessageToJson(e.request)}", dim=True) - return - click.secho(f"Failed with Exception: Reason: {e._ERROR_CODE}", fg="red") # noqa + cause = e.__cause__ if cause: if isinstance(cause, grpc.RpcError): @@ -72,7 +74,7 @@ def pretty_print_exception(e: Exception): click.secho(f"Failed with Unknown Exception {type(e)} Reason: {e}", fg="red") # noqa -class ErrorHandlingCommand(click.Group): +class ErrorHandlingCommand(click.RichGroup): def invoke(self, ctx: click.Context) -> typing.Any: try: return super().invoke(ctx) @@ -132,6 +134,9 @@ def main(ctx, pkgs: typing.List[str], config: str, verbose: bool): main.add_command(run) main.add_command(register) main.add_command(backfill) +main.add_command(serve) +main.add_command(build) +main.add_command(launchplan) main.epilog if __name__ == "__main__": diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py index 30c955e351..afc7aeb99e 100644 --- a/flytekit/clis/sdk_in_container/register.py +++ b/flytekit/clis/sdk_in_container/register.py @@ -1,7 +1,7 @@ import os import typing -import click +import rich_click as click from flytekit.clis.helpers import display_help_with_error from flytekit.clis.sdk_in_container import constants diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 136831c0bc..9c7228ec46 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ 
b/flytekit/clis/sdk_in_container/run.py @@ -9,7 +9,8 @@ from dataclasses import dataclass from typing import cast -import click +import rich_click as click +import yaml from dataclasses_json import DataClassJsonMixin from pytimeparse import parse from typing_extensions import get_args @@ -17,6 +18,7 @@ from flytekit import BlobType, Literal, Scalar from flytekit.clis.sdk_in_container.constants import ( CTX_CONFIG_FILE, + CTX_COPY_ALL, CTX_DOMAIN, CTX_MODULE, CTX_PROJECT, @@ -37,7 +39,7 @@ from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase from flytekit.models import literals from flytekit.models.interface import Variable -from flytekit.models.literals import Blob, BlobMetadata, Primitive, Union +from flytekit.models.literals import Blob, BlobMetadata, LiteralCollection, LiteralMap, Primitive, Union from flytekit.models.types import LiteralType, SimpleType from flytekit.remote.executions import FlyteWorkflowExecution from flytekit.tools import module_loader, script_mode @@ -55,10 +57,6 @@ def remove_prefix(text, prefix): return text -class JsonParamType(click.ParamType): - name = "json object" - - @dataclass class Directory(object): dir_path: str @@ -81,7 +79,7 @@ def convert( raise ValueError( f"Currently only directories containing one file are supported, found [{len(files)}] files found in {p.resolve()}" ) - return Directory(dir_path=value, local_file=files[0].resolve()) + return Directory(dir_path=str(p), local_file=files[0].resolve()) raise click.BadParameter(f"parameter should be a valid directory path, {value}") @@ -134,6 +132,33 @@ def convert( return datetime.timedelta(seconds=parse(value)) +class JsonParamType(click.ParamType): + name = "json object OR json/yaml file path" + + def convert( + self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] + ) -> typing.Any: + if value is None: + raise click.BadParameter("None value cannot be converted to a Json type.") + if type(value) == dict or 
type(value) == list: + return value + try: + return json.loads(value) + except Exception: # noqa + try: + # We failed to load the json, so we'll try to load it as a file + if os.path.exists(value): + # if the value is a yaml file, we'll try to load it as yaml + if value.endswith(".yaml") or value.endswith(".yml"): + with open(value, "r") as f: + return yaml.safe_load(f) + with open(value, "r") as f: + return json.load(f) + raise + except json.JSONDecodeError as e: + raise click.BadParameter(f"parameter {param} should be a valid json object, {value}, error: {e}") + + @dataclass class DefaultConverter(object): click_type: click.ParamType @@ -215,16 +240,18 @@ def is_bool(self) -> bool: return self._literal_type.simple == SimpleType.BOOLEAN return False - def get_uri_for_dir(self, value: Directory, remote_filename: typing.Optional[str] = None): + def get_uri_for_dir( + self, ctx: typing.Optional[click.Context], value: Directory, remote_filename: typing.Optional[str] = None + ): uri = value.dir_path if self._remote and value.local: md5, _ = script_mode.hash_file(value.local_file) if not remote_filename: remote_filename = value.local_file.name - df_remote_location = self._create_upload_fn(filename=remote_filename, content_md5=md5) - self._flyte_ctx.file_access.put_data(value.local_file, df_remote_location.signed_url) - uri = df_remote_location.native_url[: -len(remote_filename)] + remote = ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] + _, native_url = remote.upload_file(value.local_file) + uri = native_url[: -len(remote_filename)] return uri @@ -232,7 +259,7 @@ def convert_to_structured_dataset( self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: Directory ) -> Literal: - uri = self.get_uri_for_dir(value, "00000.parquet") + uri = self.get_uri_for_dir(ctx, value, "00000.parquet") lit = Literal( scalar=Scalar( @@ -254,15 +281,13 @@ def convert_to_blob( value: typing.Union[Directory, FileParam], ) -> Literal: if isinstance(value, Directory): - 
uri = self.get_uri_for_dir(value) + uri = self.get_uri_for_dir(ctx, value) else: uri = value.filepath if self._remote and value.local: fp = pathlib.Path(value.filepath) - md5, _ = script_mode.hash_file(value.filepath) - df_remote_location = self._create_upload_fn(filename=fp.name, content_md5=md5) - self._flyte_ctx.file_access.put_data(fp, df_remote_location.signed_url) - uri = df_remote_location.native_url + remote = ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] + _, uri = remote.upload_file(fp) lit = Literal( scalar=Scalar( @@ -299,6 +324,68 @@ def convert_to_union( logging.debug(f"Failed to convert python type {python_type} to literal type {variant}", e) raise ValueError(f"Failed to convert python type {self._python_type} to literal type {lt}") + def convert_to_list( + self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: list + ) -> Literal: + """ + Convert a python list into a Flyte Literal + """ + if not value: + raise click.BadParameter("Expected non-empty list") + if not isinstance(value, list): + raise click.BadParameter(f"Expected json list '[...]', parsed value is {type(value)}") + converter = FlyteLiteralConverter( + ctx, + self._flyte_ctx, + self._literal_type.collection_type, + type(value[0]), + self._create_upload_fn, + ) + lt = Literal(collection=LiteralCollection([])) + for v in value: + click_val = converter._click_type.convert(v, param, ctx) + lt.collection.literals.append(converter.convert_to_literal(ctx, param, click_val)) + return lt + + def convert_to_map( + self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: dict + ) -> Literal: + """ + Convert a python dict into a Flyte Literal. + It is assumed that the click parameter type is a JsonParamType. The map is also assumed to be univariate. 
+ """ + if not value: + raise click.BadParameter("Expected non-empty dict") + if not isinstance(value, dict): + raise click.BadParameter(f"Expected json dict '{{...}}', parsed value is {type(value)}") + converter = FlyteLiteralConverter( + ctx, + self._flyte_ctx, + self._literal_type.map_value_type, + type(value[list(value.keys())[0]]), + self._create_upload_fn, + ) + lt = Literal(map=LiteralMap({})) + for k, v in value.items(): + click_val = converter._click_type.convert(v, param, ctx) + lt.map.literals[k] = converter.convert_to_literal(ctx, param, click_val) + return lt + + def convert_to_struct( + self, + ctx: typing.Optional[click.Context], + param: typing.Optional[click.Parameter], + value: typing.Union[dict, typing.Any], + ) -> Literal: + """ + Convert the loaded json object to a Flyte Literal struct type. + """ + if type(value) != self._python_type: + o = cast(DataClassJsonMixin, self._python_type).from_json(json.dumps(value)) + else: + o = value + return TypeEngine.to_literal(self._flyte_ctx, o, self._python_type, self._literal_type) + def convert_to_literal( self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: typing.Any ) -> Literal: @@ -308,30 +395,18 @@ def convert_to_literal( if self._literal_type.blob: return self.convert_to_blob(ctx, param, value) - if self._literal_type.collection_type or self._literal_type.map_value_type: - # TODO Does not support nested flytefile, flyteschema types - v = json.loads(value) if isinstance(value, str) else value - if self._literal_type.collection_type and not isinstance(v, list): - raise click.BadParameter(f"Expected json list '[...]', parsed value is {type(v)}") - if self._literal_type.map_value_type and not isinstance(v, dict): - raise click.BadParameter("Expected json map '{}', parsed value is {%s}" % type(v)) - return TypeEngine.to_literal(self._flyte_ctx, v, self._python_type, self._literal_type) + if self._literal_type.collection_type: + return self.convert_to_list(ctx, 
param, value) + + if self._literal_type.map_value_type: + return self.convert_to_map(ctx, param, value) if self._literal_type.union_type: return self.convert_to_union(ctx, param, value) if self._literal_type.simple or self._literal_type.enum_type: if self._literal_type.simple and self._literal_type.simple == SimpleType.STRUCT: - if self._python_type == dict: - if type(value) != str: - # The type of default value is dict, so we have to convert it to json string - value = json.dumps(value) - o = json.loads(value) - elif type(value) != self._python_type: - o = cast(DataClassJsonMixin, self._python_type).from_json(value) - else: - o = value - return TypeEngine.to_literal(self._flyte_ctx, o, self._python_type, self._literal_type) + return self.convert_to_struct(ctx, param, value) return Literal(scalar=self._converter.convert(value, self._python_type)) if self._literal_type.schema: @@ -342,10 +417,15 @@ def convert_to_literal( ) def convert(self, ctx, param, value) -> typing.Union[Literal, typing.Any]: - lit = self.convert_to_literal(ctx, param, value) - if not self._remote: - return TypeEngine.to_python_value(self._flyte_ctx, lit, self._python_type) - return lit + try: + lit = self.convert_to_literal(ctx, param, value) + if not self._remote: + return TypeEngine.to_python_value(self._flyte_ctx, lit, self._python_type) + return lit + except click.BadParameter: + raise + except Exception as e: + raise click.BadParameter(f"Failed to convert param {param}, {value} to {self._python_type}") from e def to_click_option( @@ -368,6 +448,13 @@ def to_click_option( if literal_converter.is_bool() and not default_val: default_val = False + if literal_var.type.simple == SimpleType.STRUCT: + if default_val: + if type(default_val) == dict or type(default_val) == list: + default_val = json.dumps(default_val) + else: + default_val = cast(DataClassJsonMixin, default_val).to_json() + return click.Option( param_decls=[f"--{input_name}"], type=literal_converter.click_type, @@ -426,6 +513,13 @@ 
def get_workflow_command_base_params() -> typing.List[click.Option]: default="/root", help="Directory inside the image where the tar file containing the code will be copied to", ), + click.Option( + param_decls=["--copy-all", "copy_all"], + required=False, + is_flag=True, + default=False, + help="Copy all files in the source root directory to the destination directory", + ), click.Option( param_decls=["-i", "--image", "image_config"], required=False, @@ -557,6 +651,7 @@ def _run(*args, **kwargs): destination_dir=run_level_params.get("destination_dir"), source_path=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_PROJECT_ROOT), module_name=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_MODULE), + copy_all=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_COPY_ALL), ) options = None @@ -586,7 +681,7 @@ def _run(*args, **kwargs): return _run -class WorkflowCommand(click.MultiCommand): +class WorkflowCommand(click.RichGroup): """ click multicommand at the python file layer, subcommands should be all the workflows in the file. """ @@ -654,7 +749,7 @@ def get_command(self, ctx, exe_entity): return cmd -class RunCommand(click.MultiCommand): +class RunCommand(click.RichGroup): """ A click command group for registering and executing flyte workflows & tasks in a file. 
""" diff --git a/flytekit/clis/sdk_in_container/serialize.py b/flytekit/clis/sdk_in_container/serialize.py index eef055ad3c..0c328248e5 100644 --- a/flytekit/clis/sdk_in_container/serialize.py +++ b/flytekit/clis/sdk_in_container/serialize.py @@ -3,7 +3,7 @@ import typing from enum import Enum as _Enum -import click +import rich_click as click from flytekit.clis.sdk_in_container import constants from flytekit.clis.sdk_in_container.constants import CTX_PACKAGES @@ -69,7 +69,7 @@ def serialize_all( serialize_to_folder(pkgs, serialization_settings, local_source_root, folder) -@click.group("serialize") +@click.group("serialize", cls=click.RichGroup) @click.option( "--image", required=False, @@ -124,7 +124,7 @@ def serialize(ctx, image, local_source_root, in_container_config_path, in_contai ctx.obj[CTX_PYTHON_INTERPRETER] = sys.executable -@click.command("workflows") +@click.command("workflows", cls=click.RichCommand) # For now let's just assume that the directory needs to exist. If you're docker run -v'ing, docker will create the # directory for you so it shouldn't be a problem. 
@click.option("-f", "--folder", type=click.Path(exists=True)) @@ -148,13 +148,13 @@ def workflows(ctx, folder=None): ) -@click.group("fast") +@click.group("fast", cls=click.RichGroup) @click.pass_context def fast(ctx): pass -@click.command("workflows") +@click.command("workflows", cls=click.RichCommand) @click.option( "--deref-symlinks", default=False, diff --git a/flytekit/clis/sdk_in_container/serve.py b/flytekit/clis/sdk_in_container/serve.py new file mode 100644 index 0000000000..71b539d36c --- /dev/null +++ b/flytekit/clis/sdk_in_container/serve.py @@ -0,0 +1,46 @@ +from concurrent import futures + +import click +import grpc +from flyteidl.service.external_plugin_service_pb2_grpc import add_ExternalPluginServiceServicer_to_server + +from flytekit.extend.backend.external_plugin_service import BackendPluginServer + +_serve_help = """Start a grpc server for the external plugin service.""" + + +@click.command("serve", help=_serve_help) +@click.option( + "--port", + default="8000", + is_flag=False, + type=int, + help="Grpc port for the external plugin service", +) +@click.option( + "--worker", + default="10", + is_flag=False, + type=int, + help="Number of workers for the grpc server", +) +@click.option( + "--timeout", + default=None, + is_flag=False, + type=int, + help="It will wait for the specified number of seconds before shutting down grpc server. It should only be used " + "for testing.", +) +@click.pass_context +def serve(_: click.Context, port, worker, timeout): + """ + Start a grpc server for the external plugin service. 
+ """ + click.secho("Starting the external plugin service...", fg="blue") + server = grpc.server(futures.ThreadPoolExecutor(max_workers=worker)) + add_ExternalPluginServiceServicer_to_server(BackendPluginServer(), server) + + server.add_insecure_port(f"[::]:{port}") + server.start() + server.wait_for_termination(timeout=timeout) diff --git a/tests/flytekit/unit/extras/persistence/__init__.py b/flytekit/clis/sdk_in_container/utils.py similarity index 100% rename from tests/flytekit/unit/extras/persistence/__init__.py rename to flytekit/clis/sdk_in_container/utils.py diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index afed857a26..75b80409cc 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -143,12 +143,14 @@ from io import BytesIO from typing import Dict, List, Optional +import yaml from dataclasses_json import dataclass_json -from docker_image import reference from flytekit.configuration import internal as _internal from flytekit.configuration.default_images import DefaultImages from flytekit.configuration.file import ConfigEntry, ConfigFile, get_config_file, read_file_if_exists, set_if_exists +from flytekit.image_spec import ImageSpec +from flytekit.image_spec.image_spec import ImageBuildEngine from flytekit.loggers import logger PROJECT_PLACEHOLDER = "{{ registration.project }}" @@ -205,6 +207,15 @@ def look_up_image_info(name: str, tag: str, optional_tag: bool = False) -> Image :param Text tag: e.g. 
somedocker.com/myimage:someversion123 :rtype: Text """ + from docker_image import reference + + if os.path.isfile(tag): + with open(tag, "r") as f: + image_spec_dict = yaml.safe_load(f) + image_spec = ImageSpec(**image_spec_dict) + ImageBuildEngine.build(image_spec) + tag = image_spec.image_name() + ref = reference.Reference.parse(tag) if not optional_tag and ref["tag"] is None: raise AssertionError(f"Incorrectly formatted image {tag}, missing tag value") @@ -344,6 +355,7 @@ class AuthType(enum.Enum): CLIENTSECRET = "ClientSecret" PKCE = "Pkce" EXTERNALCOMMAND = "ExternalCommand" + DEVICEFLOW = "DeviceFlow" @dataclass(init=True, repr=True, eq=True, frozen=True) @@ -352,7 +364,7 @@ class PlatformConfig(object): This object contains the settings to talk to a Flyte backend (the DNS location of your Admin server basically). :param endpoint: DNS for Flyte backend - :param insecure: Whether to use SSL + :param insecure: Whether or not to use SSL :param insecure_skip_verify: Whether to skip SSL certificate verification :param console_endpoint: endpoint for console if different from Flyte backend :param command: This command is executed to return a token using an external process @@ -376,6 +388,7 @@ class PlatformConfig(object): client_credentials_secret: typing.Optional[str] = None scopes: List[str] = field(default_factory=list) auth_mode: AuthType = AuthType.STANDARD + audience: typing.Optional[str] = None rpc_retries: int = 3 @classmethod @@ -697,6 +710,7 @@ class SerializationSettings(object): fast_serialization_settings (Optional[FastSerializationSettings]): If the code is being serialized so that it can be fast registered (and thus omit building a Docker image) this object contains additional parameters for serialization. + source_root (Optional[str]): The root directory of the source code. 
""" image_config: ImageConfig @@ -708,6 +722,7 @@ class SerializationSettings(object): python_interpreter: str = DEFAULT_RUNTIME_PYTHON_INTERPRETER flytekit_virtualenv_root: Optional[str] = None fast_serialization_settings: Optional[FastSerializationSettings] = None + source_root: Optional[str] = None def __post_init__(self): if self.flytekit_virtualenv_root is None: @@ -781,6 +796,7 @@ def new_builder(self) -> Builder: flytekit_virtualenv_root=self.flytekit_virtualenv_root, python_interpreter=self.python_interpreter, fast_serialization_settings=self.fast_serialization_settings, + source_root=self.source_root, ) def should_fast_serialize(self) -> bool: @@ -831,6 +847,7 @@ class Builder(object): flytekit_virtualenv_root: Optional[str] = None python_interpreter: Optional[str] = None fast_serialization_settings: Optional[FastSerializationSettings] = None + source_root: Optional[str] = None def with_fast_serialization_settings(self, fss: fast_serialization_settings) -> SerializationSettings.Builder: self.fast_serialization_settings = fss @@ -847,4 +864,5 @@ def build(self) -> SerializationSettings: flytekit_virtualenv_root=self.flytekit_virtualenv_root, python_interpreter=self.python_interpreter, fast_serialization_settings=self.fast_serialization_settings, + source_root=self.source_root, ) diff --git a/flytekit/configuration/default_images.py b/flytekit/configuration/default_images.py index 8c01041eed..625e69d9ae 100644 --- a/flytekit/configuration/default_images.py +++ b/flytekit/configuration/default_images.py @@ -30,14 +30,19 @@ def default_image(cls) -> str: def find_image_for( cls, python_version: typing.Optional[PythonVersion] = None, flytekit_version: typing.Optional[str] = None ) -> str: + if python_version is None: + python_version = PythonVersion((sys.version_info.major, sys.version_info.minor)) + + return cls._DEFAULT_IMAGE_PREFIXES[python_version] + ( + flytekit_version.replace("v", "") if flytekit_version else cls.get_version_suffix() + ) + + @classmethod 
+ def get_version_suffix(cls) -> str: from flytekit import __version__ if not __version__ or __version__ == "0.0.0+develop": version_suffix = "latest" else: version_suffix = __version__ - if python_version is None: - python_version = PythonVersion((sys.version_info.major, sys.version_info.minor)) - return cls._DEFAULT_IMAGE_PREFIXES[python_version] + ( - flytekit_version.replace("v", "") if flytekit_version else version_suffix - ) + return version_suffix diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 2cf8032a6f..8bbe636227 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -37,8 +37,8 @@ translate_inputs_to_literals, ) from flytekit.core.tracker import TrackedInstance -from flytekit.core.type_engine import TypeEngine -from flytekit.deck.deck import Deck +from flytekit.core.type_engine import TypeEngine, TypeTransformerFailedError +from flytekit.core.utils import timeit from flytekit.loggers import logger from flytekit.models import dynamic_job as _dynamic_job from flytekit.models import interface as _interface_models @@ -239,12 +239,17 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr # Promises as essentially inputs from previous task executions # native constants are just bound to this specific task (default values for a task input) # Also along with promises and constants, there could be dictionary or list of promises or constants - kwargs = translate_inputs_to_literals( - ctx, - incoming_values=kwargs, - flyte_interface_types=self.interface.inputs, # type: ignore - native_types=self.get_input_types(), - ) + try: + kwargs = translate_inputs_to_literals( + ctx, + incoming_values=kwargs, + flyte_interface_types=self.interface.inputs, + native_types=self.get_input_types(), # type: ignore + ) + except TypeTransformerFailedError as exc: + msg = f"Failed to convert inputs of task '{self.name}':\n {exc}" + logger.error(msg) + raise TypeError(msg) from exc input_literal_map = 
_literal_models.LiteralMap(literals=kwargs) # if metadata.cache is set, check memoized version @@ -503,13 +508,21 @@ def dispatch_execute( ) as exec_ctx: # TODO We could support default values here too - but not part of the plan right now # Translate the input literals to Python native - native_inputs = TypeEngine.literal_map_to_kwargs(exec_ctx, input_literal_map, self.python_interface.inputs) + try: + native_inputs = TypeEngine.literal_map_to_kwargs( + exec_ctx, input_literal_map, self.python_interface.inputs + ) + except Exception as exc: + msg = f"Failed to convert inputs of task '{self.name}':\n {exc}" + logger.error(msg) + raise type(exc)(msg) from exc # TODO: Logger should auto inject the current context information to indicate if the task is running within # a workflow or a subworkflow etc logger.info(f"Invoking {self.name} with inputs: {native_inputs}") try: - native_outputs = self.execute(**native_inputs) + with timeit("Execute user level code"): + native_outputs = self.execute(**native_inputs) except Exception as e: logger.exception(f"Exception when executing {e}") raise e @@ -546,22 +559,26 @@ def dispatch_execute( # We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption # built into the IDL that all the values of a literal map are of the same type. 
- literals = {} - for k, v in native_outputs_as_map.items(): - literal_type = self._outputs_interface[k].type - py_type = self.get_type_for_output_var(k, v) - - if isinstance(v, tuple): - raise TypeError(f"Output({k}) in task{self.name} received a tuple {v}, instead of {py_type}") - try: - literals[k] = TypeEngine.to_literal(exec_ctx, v, py_type, literal_type) - except Exception as e: - logger.error(f"Failed to convert return value for var {k} with error {type(e)}: {e}") - raise TypeError( - f"Failed to convert return value for var {k} for function {self.name} with error {type(e)}: {e}" - ) from e + with timeit("Translate the output to literals"): + literals = {} + for i, (k, v) in enumerate(native_outputs_as_map.items()): + literal_type = self._outputs_interface[k].type + py_type = self.get_type_for_output_var(k, v) + + if isinstance(v, tuple): + raise TypeError(f"Output({k}) in task '{self.name}' received a tuple {v}, instead of {py_type}") + try: + literals[k] = TypeEngine.to_literal(exec_ctx, v, py_type, literal_type) + except Exception as e: + # only show the name of output key if it's user-defined (by default Flyte names these as "o") + key = k if k != f"o{i}" else i + msg = f"Failed to convert outputs of task '{self.name}' at position {key}:\n {e}" + logger.error(msg) + raise TypeError(msg) from e if self._disable_deck is False: + from flytekit.deck.deck import Deck + INPUT = "input" OUTPUT = "output" diff --git a/flytekit/core/checkpointer.py b/flytekit/core/checkpointer.py index c1eb933ec6..4b4cfd16f3 100644 --- a/flytekit/core/checkpointer.py +++ b/flytekit/core/checkpointer.py @@ -126,7 +126,7 @@ def save(self, cp: typing.Union[Path, str, io.BufferedReader]): fa.upload_directory(str(cp), self._checkpoint_dest) else: fname = cp.stem + cp.suffix - rpath = fa._default_remote.construct_path(False, False, self._checkpoint_dest, fname) + rpath = fa._default_remote.sep.join([str(self._checkpoint_dest), fname]) fa.upload(str(cp), rpath) return @@ -138,7 +138,7 
@@ def save(self, cp: typing.Union[Path, str, io.BufferedReader]): with dest_cp.open("wb") as f: f.write(cp.read()) - rpath = fa._default_remote.construct_path(False, False, self._checkpoint_dest, self.TMP_DST_PATH) + rpath = fa._default_remote.sep.join([str(self._checkpoint_dest), self.TMP_DST_PATH]) fa.upload(str(dest_cp), rpath) def read(self) -> typing.Optional[bytes]: diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index d470fb54fe..bec3915430 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -4,13 +4,15 @@ from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask, TaskMetadata from flytekit.core.interface import Interface +from flytekit.core.pod_template import PodTemplate from flytekit.core.resources import Resources, ResourceSpec -from flytekit.core.utils import _get_container_definition +from flytekit.core.utils import _get_container_definition, _serialize_pod_spec from flytekit.models import task as _task_model from flytekit.models.security import Secret, SecurityContext +_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name" + -# TODO: do we need pod_template here? Seems that it is a raw container not running in pods class ContainerTask(PythonTask): """ This is an intermediate class that represents Flyte Tasks that run a container at execution time. 
This is the vast @@ -47,6 +49,8 @@ def __init__( metadata_format: MetadataFormat = MetadataFormat.JSON, io_strategy: IOStrategy = None, secret_requests: Optional[List[Secret]] = None, + pod_template: Optional["PodTemplate"] = None, + pod_template_name: Optional[str] = None, **kwargs, ): sec_ctx = None @@ -55,6 +59,11 @@ def __init__( if not isinstance(s, Secret): raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}") sec_ctx = SecurityContext(secrets=secret_requests) + + # pod_template_name overwrites the metadata.pod_template_name + metadata = metadata or TaskMetadata() + metadata.pod_template_name = pod_template_name + super().__init__( task_type="raw-container", name=name, @@ -74,6 +83,7 @@ def __init__( self._resources = ResourceSpec( requests=requests if requests else Resources(), limits=limits if limits else Resources() ) + self.pod_template = pod_template @property def resources(self) -> ResourceSpec: @@ -91,19 +101,29 @@ def execute(self, **kwargs) -> Any: return None def get_container(self, settings: SerializationSettings) -> _task_model.Container: + # if pod_template is specified, return None here but in get_k8s_pod, return pod_template merged with container + if self.pod_template is not None: + return None + + return self._get_container(settings) + + def _get_data_loading_config(self) -> _task_model.DataLoadingConfig: + return _task_model.DataLoadingConfig( + input_path=self._input_data_dir, + output_path=self._output_data_dir, + format=self._md_format.value, + enabled=True, + io_strategy=self._io_strategy.value if self._io_strategy else None, + ) + + def _get_container(self, settings: SerializationSettings) -> _task_model.Container: env = settings.env or {} env = {**env, **self.environment} if self.environment else env return _get_container_definition( image=self._image, command=self._cmd, args=self._args, - data_loading_config=_task_model.DataLoadingConfig( - input_path=self._input_data_dir, - 
output_path=self._output_data_dir, - format=self._md_format.value, - enabled=True, - io_strategy=self._io_strategy.value if self._io_strategy else None, - ), + data_loading_config=self._get_data_loading_config(), environment=env, storage_request=self.resources.requests.storage, ephemeral_storage_request=self.resources.requests.ephemeral_storage, @@ -116,3 +136,20 @@ def get_container(self, settings: SerializationSettings) -> _task_model.Containe gpu_limit=self.resources.limits.gpu, memory_limit=self.resources.limits.mem, ) + + def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod: + if self.pod_template is None: + return None + return _task_model.K8sPod( + pod_spec=_serialize_pod_spec(self.pod_template, self._get_container(settings)), + metadata=_task_model.K8sObjectMetadata( + labels=self.pod_template.labels, + annotations=self.pod_template.annotations, + ), + data_config=self._get_data_loading_config(), + ) + + def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]]: + if self.pod_template is None: + return {} + return {_PRIMARY_CONTAINER_NAME_FIELD: self.pod_template.primary_container_name} diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index 7e4600b3bb..3bc776008d 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -27,7 +27,6 @@ from enum import Enum from typing import Generator, List, Optional, Union -from flytekit.clients import friendly as friendly_client # noqa from flytekit.configuration import Config, SecretsConfig, SerializationSettings from flytekit.core import mock_stats, utils from flytekit.core.checkpointer import Checkpoint, SyncCheckpoint @@ -39,7 +38,8 @@ from flytekit.models.core import identifier as _identifier if typing.TYPE_CHECKING: - from flytekit.deck.deck import Deck + from flytekit import Deck + from flytekit.clients import friendly as friendly_client # noqa # TODO: resolve circular import from 
flytekit.core.python_auto_container import TaskResolverMixin @@ -84,7 +84,7 @@ class Builder(object): decks: List[Deck] raw_output_prefix: Optional[str] = None execution_id: typing.Optional[_identifier.WorkflowExecutionIdentifier] = None - working_dir: typing.Optional[utils.AutoDeletingTempDir] = None + working_dir: typing.Optional[str] = None checkpoint: typing.Optional[Checkpoint] = None execution_date: typing.Optional[datetime] = None logging: Optional[_logging.Logger] = None @@ -202,12 +202,10 @@ def raw_output_prefix(self) -> str: return self._raw_output_prefix @property - def working_directory(self) -> utils.AutoDeletingTempDir: + def working_directory(self) -> str: """ A handle to a special working directory for easily producing temporary files. - TODO: Usage examples - TODO: This does not always return a AutoDeletingTempDir """ return self._working_directory @@ -264,10 +262,24 @@ def decks(self) -> typing.List: @property def default_deck(self) -> Deck: - from flytekit.deck.deck import Deck + from flytekit import Deck return Deck("default") + @property + def timeline_deck(self) -> "TimeLineDeck": # type: ignore + from flytekit.deck.deck import TimeLineDeck + + time_line_deck = None + for deck in self.decks: + if isinstance(deck, TimeLineDeck): + time_line_deck = deck + break + if time_line_deck is None: + time_line_deck = TimeLineDeck("Timeline") + + return time_line_deck + def __getattr__(self, attr_name: str) -> typing.Any: """ This houses certain task specific context. For example in Spark, it houses the SparkSession, etc @@ -331,13 +343,13 @@ def __getattr__(self, item: str) -> _GroupSecrets: """ return self._GroupSecrets(item, self) - def get(self, group: str, key: str) -> str: + def get(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str: """ Retrieves a secret using the resolution order -> Env followed by file. 
If not found raises a ValueError """ - self.check_group_key(group, key) - env_var = self.get_secrets_env_var(group, key) - fpath = self.get_secrets_file(group, key) + self.check_group_key(group) + env_var = self.get_secrets_env_var(group, key, group_version) + fpath = self.get_secrets_file(group, key, group_version) v = os.environ.get(env_var) if v is not None: return v @@ -348,26 +360,27 @@ def get(self, group: str, key: str) -> str: f"Unable to find secret for key {key} in group {group} " f"in Env Var:{env_var} and FilePath: {fpath}" ) - def get_secrets_env_var(self, group: str, key: str) -> str: + def get_secrets_env_var(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str: """ Returns a string that matches the ENV Variable to look for the secrets """ - self.check_group_key(group, key) - return f"{self._env_prefix}{group.upper()}_{key.upper()}" + self.check_group_key(group) + l = [k.upper() for k in filter(None, (group, group_version, key))] + return f"{self._env_prefix}{'_'.join(l)}" - def get_secrets_file(self, group: str, key: str) -> str: + def get_secrets_file(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str: """ Returns a path that matches the file to look for the secrets """ - self.check_group_key(group, key) - return os.path.join(self._base_dir, group.lower(), f"{self._file_prefix}{key.lower()}") + self.check_group_key(group) + l = [k.lower() for k in filter(None, (group, group_version, key))] + l[-1] = f"{self._file_prefix}{l[-1]}" + return os.path.join(self._base_dir, *l) @staticmethod - def check_group_key(group: str, key: str): + def check_group_key(group: str): if group is None or group == "": raise ValueError("secrets group is a mandatory field.") - if key is None or key == "": - raise ValueError("secrets key is a mandatory field.") @dataclass(frozen=True) @@ -538,7 +551,7 @@ class FlyteContext(object): file_access: FileAccessProvider level: int = 0 - flyte_client: 
Optional[friendly_client.SynchronousFlyteClient] = None + flyte_client: Optional["friendly_client.SynchronousFlyteClient"] = None compilation_state: Optional[CompilationState] = None execution_state: Optional[ExecutionState] = None serialization_settings: Optional[SerializationSettings] = None @@ -647,7 +660,7 @@ class Builder(object): level: int = 0 compilation_state: Optional[CompilationState] = None execution_state: Optional[ExecutionState] = None - flyte_client: Optional[friendly_client.SynchronousFlyteClient] = None + flyte_client: Optional["friendly_client.SynchronousFlyteClient"] = None serialization_settings: Optional[SerializationSettings] = None in_a_condition: bool = False @@ -726,7 +739,7 @@ class FlyteContextManager(object): FlyteContextManager manages the execution context within Flytekit. It holds global state of either compilation or Execution. It is not thread-safe and can only be run as a single threaded application currently. Context's within Flytekit is useful to manage compilation state and execution state. Refer to ``CompilationState`` - and ``ExecutionState`` for for information. FlyteContextManager provides a singleton stack to manage these contexts. + and ``ExecutionState`` for more information. FlyteContextManager provides a singleton stack to manage these contexts. 
Typical usage is diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index d407b3528b..2080f73b9f 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -14,303 +14,52 @@ :template: custom.rst :nosignatures: - DataPersistence - DataPersistencePlugins - DiskPersistence FileAccessProvider - UnsupportedPersistenceOp """ - import os import pathlib -import re import shutil -import sys import tempfile import typing -from abc import abstractmethod -from shutil import copyfile -from typing import Dict, Union +from typing import Any, Dict, Union, cast from uuid import UUID +import fsspec +from fsspec.utils import get_protocol + +from flytekit import configuration from flytekit.configuration import DataConfig -from flytekit.core.utils import PerformanceTimer -from flytekit.exceptions.user import FlyteAssertion, FlyteValueException +from flytekit.core.utils import timeit +from flytekit.exceptions.user import FlyteAssertion from flytekit.interfaces.random import random from flytekit.loggers import logger -CURRENT_PYTHON = sys.version_info[:2] -THREE_SEVEN = (3, 7) - - -class UnsupportedPersistenceOp(Exception): - """ - This exception is raised for all methods when a method is not supported by the data persistence layer - """ - - def __init__(self, message: str): - super(UnsupportedPersistenceOp, self).__init__(message) - - -class DataPersistence(object): - """ - Base abstract type for all DataPersistence operations. 
This can be extended using the flytekitplugins architecture - """ - - def __init__(self, name: str, default_prefix: typing.Optional[str] = None, **kwargs): - self._name = name - self._default_prefix = default_prefix - - @property - def name(self) -> str: - return self._name - - @property - def default_prefix(self) -> typing.Optional[str]: - return self._default_prefix - - def listdir(self, path: str, recursive: bool = False) -> typing.Generator[str, None, None]: - """ - Returns true if the given path exists, else false - """ - raise UnsupportedPersistenceOp(f"Listing a directory is not supported by the persistence plugin {self.name}") - - @abstractmethod - def exists(self, path: str) -> bool: - """ - Returns true if the given path exists, else false - """ - pass - - @abstractmethod - def get(self, from_path: str, to_path: str, recursive: bool = False): - """ - Retrieves data from from_path and writes to the given to_path (to_path is locally accessible) - """ - pass - - @abstractmethod - def put(self, from_path: str, to_path: str, recursive: bool = False): - """ - Stores data from from_path and writes to the given to_path (from_path is locally accessible) - """ - pass - - @abstractmethod - def construct_path(self, add_protocol: bool, add_prefix: bool, *paths: str) -> str: - """ - if add_protocol is true then is prefixed else - Constructs a path in the format *args - delim is dependent on the storage medium. - each of the args is joined with the delim - """ - pass - - -class DataPersistencePlugins(object): - """ - DataPersistencePlugins is the core plugin registry that stores all DataPersistence plugins. To add a new plugin use - - .. code-block:: python - - DataPersistencePlugins.register_plugin("s3:/", DataPersistence(), force=True|False) - - These plugins should always be registered. Follow the plugin registration guidelines to auto-discover your plugins. 
- """ - - _PLUGINS: Dict[str, typing.Type[DataPersistence]] = {} - - @classmethod - def register_plugin(cls, protocol: str, plugin: typing.Type[DataPersistence], force: bool = False): - """ - Registers the supplied plugin for the specified protocol if one does not already exist. - If one exists and force is default or False, then a TypeError is raised. - If one does not exist then it is registered - If one exists, but force == True then the existing plugin is overridden - """ - if protocol in cls._PLUGINS: - p = cls._PLUGINS[protocol] - if p == plugin: - return - if not force: - raise TypeError( - f"Cannot register plugin {plugin.name} for protocol {protocol} as plugin {p.name} is already" - f" registered for the same protocol. You can force register the new plugin by passing force=True" - ) - - cls._PLUGINS[protocol] = plugin - - @staticmethod - def get_protocol(url: str): - # copy from fsspec https://github.com/fsspec/filesystem_spec/blob/fe09da6942ad043622212927df7442c104fe7932/fsspec/utils.py#L387-L391 - parts = re.split(r"(\:\:|\://)", url, 1) - if len(parts) > 1: - return parts[0] - logger.info("Setting protocol to file") - return "file" - - @classmethod - def find_plugin(cls, path: str) -> typing.Type[DataPersistence]: - """ - Returns a plugin for the given protocol, else raise a TypeError - """ - for k, p in cls._PLUGINS.items(): - if cls.get_protocol(path) == k.replace("://", "") or path.startswith(k): - return p - raise TypeError(f"No plugin found for matching protocol of path {path}") +# Refer to https://github.com/fsspec/s3fs/blob/50bafe4d8766c3b2a4e1fc09669cf02fb2d71454/s3fs/core.py#L198 +# for key and secret +_FSSPEC_S3_KEY_ID = "key" +_FSSPEC_S3_SECRET = "secret" +_ANON = "anon" - @classmethod - def print_all_plugins(cls): - """ - Prints all the plugins and their associated protocoles - """ - for k, p in cls._PLUGINS.items(): - print(f"Plugin {p.name} registered for protocol {k}") - @classmethod - def is_supported_protocol(cls, protocol: str) -> 
bool: - """ - Returns true if the given protocol is has a registered plugin for it - """ - return protocol in cls._PLUGINS +def s3_setup_args(s3_cfg: configuration.S3Config, anonymous: bool = False): + kwargs: Dict[str, Any] = { + "cache_regions": True, + } + if s3_cfg.access_key_id: + kwargs[_FSSPEC_S3_KEY_ID] = s3_cfg.access_key_id - @classmethod - def supported_protocols(cls) -> typing.List[str]: - return [k for k in cls._PLUGINS.keys()] + if s3_cfg.secret_access_key: + kwargs[_FSSPEC_S3_SECRET] = s3_cfg.secret_access_key + # S3fs takes this as a special arg + if s3_cfg.endpoint is not None: + kwargs["client_kwargs"] = {"endpoint_url": s3_cfg.endpoint} -class DiskPersistence(DataPersistence): - """ - The simplest form of persistence that is available with default flytekit - Disk-based persistence. - This will store all data locally and retrieve the data from local. This is helpful for local execution and simulating - runs. - """ + if anonymous: + kwargs[_ANON] = True - PROTOCOL = "file://" - - def __init__(self, default_prefix: typing.Optional[str] = None, **kwargs): - super().__init__(name="local", default_prefix=default_prefix, **kwargs) - - @staticmethod - def _make_local_path(path): - if not os.path.exists(path): - try: - pathlib.Path(path).mkdir(parents=True, exist_ok=True) - except OSError: # Guard against race condition - if not os.path.isdir(path): - raise - - @staticmethod - def strip_file_header(path: str) -> str: - """ - Drops file:// if it exists from the file - """ - if path.startswith("file://"): - return path.replace("file://", "", 1) - return path - - def listdir(self, path: str, recursive: bool = False) -> typing.Generator[str, None, None]: - if not recursive: - files = os.listdir(self.strip_file_header(path)) - for f in files: - yield f - return - - for root, subdirs, files in os.walk(self.strip_file_header(path)): - for f in files: - yield os.path.join(root, f) - return - - def exists(self, path: str): - return 
os.path.exists(self.strip_file_header(path)) - - def copy_tree(self, from_path: str, to_path: str): - # TODO: Remove this code after support for 3.7 is dropped and inline this function back - # 3.7 doesn't have dirs_exist_ok - if CURRENT_PYTHON == THREE_SEVEN: - tp = pathlib.Path(self.strip_file_header(to_path)) - if tp.exists(): - if not tp.is_dir(): - raise FlyteValueException(tp, f"Target {tp} exists but is not a dir") - files = os.listdir(tp) - if len(files) != 0: - logger.debug(f"Deleting existing target dir {tp} with files {files}") - shutil.rmtree(tp) - shutil.copytree(self.strip_file_header(from_path), self.strip_file_header(to_path)) - else: - # copytree will overwrite existing files in the to_path - shutil.copytree(self.strip_file_header(from_path), self.strip_file_header(to_path), dirs_exist_ok=True) - - def get(self, from_path: str, to_path: str, recursive: bool = False): - if from_path != to_path: - if recursive: - self.copy_tree(from_path, to_path) - else: - copyfile(self.strip_file_header(from_path), self.strip_file_header(to_path)) - - def put(self, from_path: str, to_path: str, recursive: bool = False): - if from_path != to_path: - if recursive: - self.copy_tree(from_path, to_path) - else: - # Emulate s3's flat storage by automatically creating directory path - self._make_local_path(os.path.dirname(self.strip_file_header(to_path))) - # Write the object to a local file in the temp local folder - copyfile(self.strip_file_header(from_path), self.strip_file_header(to_path)) - - def construct_path(self, _: bool, add_prefix: bool, *args: str) -> str: - # Ignore add_protocol for now. Only complicates things - if add_prefix: - prefix = self.default_prefix if self.default_prefix else "" - return os.path.join(prefix, *args) - return os.path.join(*args) - - -def stringify_path(filepath): - """ - Copied from `filesystem_spec `__ - - Attempt to convert a path-like object to a string. 
- Parameters - ---------- - filepath: object to be converted - Returns - ------- - filepath_str: maybe a string version of the object - Notes - ----- - Objects supporting the fspath protocol (Python 3.6+) are coerced - according to its __fspath__ method. - For backwards compatibility with older Python version, pathlib.Path - objects are specially coerced. - Any other object is passed through unchanged, which includes bytes, - strings, buffers, or anything else that's not even path-like. - """ - if isinstance(filepath, str): - return filepath - elif hasattr(filepath, "__fspath__"): - return filepath.__fspath__() - elif isinstance(filepath, pathlib.Path): - return str(filepath) - elif hasattr(filepath, "path"): - return filepath.path - else: - return filepath - - -def split_protocol(urlpath): - """ - Copied from `filesystem_spec `__ - Return protocol, path pair - """ - urlpath = stringify_path(urlpath) - if "://" in urlpath: - protocol, path = urlpath.split("://", 1) - if len(protocol) > 1: - # excludes Windows paths - return protocol, path - return None, urlpath + return kwargs class FileAccessProvider(object): @@ -335,13 +84,18 @@ def __init__( local_sandbox_dir_appended = os.path.join(local_sandbox_dir, "local_flytekit") self._local_sandbox_dir = pathlib.Path(local_sandbox_dir_appended) self._local_sandbox_dir.mkdir(parents=True, exist_ok=True) - self._local = DiskPersistence(default_prefix=local_sandbox_dir_appended) + self._local = fsspec.filesystem(None) - self._default_remote = DataPersistencePlugins.find_plugin(raw_output_prefix)( - default_prefix=raw_output_prefix, data_config=data_config - ) - self._raw_output_prefix = raw_output_prefix self._data_config = data_config if data_config else DataConfig.auto() + self._default_protocol = get_protocol(raw_output_prefix) + self._default_remote = cast(fsspec.AbstractFileSystem, self.get_filesystem(self._default_protocol)) + if os.name == "nt" and raw_output_prefix.startswith("file://"): + raise 
FlyteAssertion("Cannot use the file:// prefix on Windows.") + self._raw_output_prefix = ( + raw_output_prefix + if raw_output_prefix.endswith(self.sep(self._default_remote)) + else raw_output_prefix + self.sep(self._default_remote) + ) @property def raw_output_prefix(self) -> str: @@ -351,38 +105,112 @@ def raw_output_prefix(self) -> str: def data_config(self) -> DataConfig: return self._data_config + def get_filesystem( + self, protocol: typing.Optional[str] = None, anonymous: bool = False, **kwargs + ) -> typing.Optional[fsspec.AbstractFileSystem]: + if not protocol: + return self._default_remote + if protocol == "file": + kwargs["auto_mkdir"] = True + elif protocol == "s3": + s3kwargs = s3_setup_args(self._data_config.s3, anonymous=anonymous) + s3kwargs.update(kwargs) + return fsspec.filesystem(protocol, **s3kwargs) # type: ignore + elif protocol == "gs": + if anonymous: + kwargs["token"] = _ANON + return fsspec.filesystem(protocol, **kwargs) # type: ignore + + # Preserve old behavior of returning None for file systems that don't have an explicit anonymous option. + if anonymous: + return None + + return fsspec.filesystem(protocol, **kwargs) # type: ignore + + def get_filesystem_for_path(self, path: str = "", anonymous: bool = False, **kwargs) -> fsspec.AbstractFileSystem: + protocol = get_protocol(path) + return self.get_filesystem(protocol, anonymous=anonymous, **kwargs) + @staticmethod def is_remote(path: Union[str, os.PathLike]) -> bool: """ - Deprecated. Lets find a replacement + Deprecated. Let's find a replacement """ - protocol, _ = split_protocol(path) + protocol = get_protocol(path) if protocol is None: return False return protocol != "file" @property def local_sandbox_dir(self) -> os.PathLike: + """ + This is a context based temp dir. 
+ """ return self._local_sandbox_dir @property - def local_access(self) -> DiskPersistence: + def local_access(self) -> fsspec.AbstractFileSystem: return self._local - def construct_random_path( - self, persist: DataPersistence, file_path_or_file_name: typing.Optional[str] = None - ) -> str: + @staticmethod + def strip_file_header(path: str, trim_trailing_sep: bool = False) -> str: """ - Use file_path_or_file_name, when you want a random directory, but want to preserve the leaf file name + Drops file:// if it exists from the file """ - key = UUID(int=random.getrandbits(128)).hex - if file_path_or_file_name: - _, tail = os.path.split(file_path_or_file_name) - if tail: - return persist.construct_path(False, True, key, tail) - else: - logger.warning(f"No filename detected in {file_path_or_file_name}, generating random path") - return persist.construct_path(False, True, key) + if path.startswith("file://"): + return path.replace("file://", "", 1) + return path + + @staticmethod + def recursive_paths(f: str, t: str) -> typing.Tuple[str, str]: + f = os.path.join(f, "") + t = os.path.join(t, "") + return f, t + + def sep(self, file_system: typing.Optional[fsspec.AbstractFileSystem]) -> str: + if file_system is None or file_system.protocol == "file": + return os.sep + return file_system.sep + + def exists(self, path: str) -> bool: + try: + file_system = self.get_filesystem_for_path(path) + return file_system.exists(path) + except OSError as oe: + logger.debug(f"Error in exists checking {path} {oe}") + anon_fs = self.get_filesystem(get_protocol(path), anonymous=True) + if anon_fs is not None: + logger.debug(f"Attempting anonymous exists with {anon_fs}") + return anon_fs.exists(path) + raise oe + + def get(self, from_path: str, to_path: str, recursive: bool = False): + file_system = self.get_filesystem_for_path(from_path) + if recursive: + from_path, to_path = self.recursive_paths(from_path, to_path) + try: + if os.name == "nt" and file_system.protocol == "file" and 
recursive: + return _copytree(self.strip_file_header(from_path), self.strip_file_header(to_path)) + return file_system.get(from_path, to_path, recursive=recursive) + except OSError as oe: + logger.debug(f"Error in getting {from_path} to {to_path} rec {recursive} {oe}") + file_system = self.get_filesystem(get_protocol(from_path), anonymous=True) + if file_system is not None: + logger.debug(f"Attempting anonymous get with {file_system}") + return file_system.get(from_path, to_path, recursive=recursive) + raise oe + + def put(self, from_path: str, to_path: str, recursive: bool = False): + file_system = self.get_filesystem_for_path(to_path) + from_path = self.strip_file_header(from_path) + if recursive: + # Only check this for the local filesystem + if file_system.protocol == "file" and not file_system.isdir(from_path): + raise FlyteAssertion(f"Source path {from_path} is not a directory") + if os.name == "nt" and file_system.protocol == "file": + return _copytree(self.strip_file_header(from_path), self.strip_file_header(to_path)) + from_path, to_path = self.recursive_paths(from_path, to_path) + return file_system.put(from_path, to_path, recursive=recursive) def get_random_remote_path(self, file_path_or_file_name: typing.Optional[str] = None) -> str: """ @@ -391,7 +219,20 @@ def get_random_remote_path(self, file_path_or_file_name: typing.Optional[str] = Use file_path_or_file_name, when you want a random directory, but want to preserve the leaf file name """ - return self.construct_random_path(self._default_remote, file_path_or_file_name) + default_protocol = self._default_remote.protocol + if type(default_protocol) == list: + default_protocol = default_protocol[0] + key = UUID(int=random.getrandbits(128)).hex + tail = "" + if file_path_or_file_name: + _, tail = os.path.split(file_path_or_file_name) + sep = self.sep(self._default_remote) + tail = sep + tail if tail else tail + if default_protocol == "file": + # Special case the local case, users will not expect to see a 
file:// prefix + return self.strip_file_header(self.raw_output_prefix) + key + tail + + return self._default_remote.unstrip_protocol(self.raw_output_prefix + key + tail) def get_random_remote_directory(self): return self.get_random_remote_path(None) @@ -400,19 +241,19 @@ def get_random_local_path(self, file_path_or_file_name: typing.Optional[str] = N """ Use file_path_or_file_name, when you want a random directory, but want to preserve the leaf file name """ - return self.construct_random_path(self._local, file_path_or_file_name) + key = UUID(int=random.getrandbits(128)).hex + tail = "" + if file_path_or_file_name: + _, tail = os.path.split(file_path_or_file_name) + if tail: + return os.path.join(self._local_sandbox_dir, key, tail) + return os.path.join(self._local_sandbox_dir, key) def get_random_local_directory(self) -> str: _dir = self.get_random_local_path(None) pathlib.Path(_dir).mkdir(parents=True, exist_ok=True) return _dir - def exists(self, path: str) -> bool: - """ - checks if the given path exists - """ - return DataPersistencePlugins.find_plugin(path)().exists(path) - def download_directory(self, remote_path: str, local_path: str): """ Downloads directory from given remote to local path @@ -439,39 +280,36 @@ def upload_directory(self, local_path: str, remote_path: str): """ return self.put_data(local_path, remote_path, is_multipart=True) - def get_data(self, remote_path: str, local_path: str, is_multipart=False): + @timeit("Download data to local from remote") + def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False): """ - :param Text remote_path: - :param Text local_path: - :param bool is_multipart: + :param remote_path: + :param local_path: + :param is_multipart: """ try: - with PerformanceTimer(f"Copying ({remote_path} -> {local_path})"): - pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True) - data_persistence_plugin = DataPersistencePlugins.find_plugin(remote_path) - 
data_persistence_plugin(data_config=self.data_config).get( - remote_path, local_path, recursive=is_multipart - ) + pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True) + self.get(remote_path, to_path=local_path, recursive=is_multipart) except Exception as ex: raise FlyteAssertion( f"Failed to get data from {remote_path} to {local_path} (recursive={is_multipart}).\n\n" f"Original exception: {str(ex)}" ) - def put_data(self, local_path: Union[str, os.PathLike], remote_path: str, is_multipart=False): + @timeit("Upload data to remote") + def put_data(self, local_path: Union[str, os.PathLike], remote_path: str, is_multipart: bool = False): """ The implication here is that we're always going to put data to the remote location, so we .remote to ensure we don't use the true local proxy if the remote path is a file:// - :param Text local_path: - :param Text remote_path: - :param bool is_multipart: + :param local_path: + :param remote_path: + :param is_multipart: """ try: - with PerformanceTimer(f"Writing ({local_path} -> {remote_path})"): - DataPersistencePlugins.find_plugin(remote_path)(data_config=self.data_config).put( - local_path, remote_path, recursive=is_multipart - ) + local_path = str(local_path) + + self.put(cast(str, local_path), remote_path, recursive=is_multipart) except Exception as ex: raise FlyteAssertion( f"Failed to put data from {local_path} to {remote_path} (recursive={is_multipart}).\n\n" @@ -479,8 +317,18 @@ def put_data(self, local_path: Union[str, os.PathLike], remote_path: str, is_mul ) from ex -DataPersistencePlugins.register_plugin("file://", DiskPersistence) -DataPersistencePlugins.register_plugin("/", DiskPersistence) +def _copytree(source, destination): + if not os.path.exists(destination): + os.makedirs(destination) + for item in os.listdir(source): + s = os.path.join(source, item) + d = os.path.join(destination, item) + if os.path.isdir(s): + _copytree(s, d) + else: + shutil.copy2(s, d) + return destination + flyte_tmp_dir = 
tempfile.mkdtemp(prefix="flyte-") default_local_file_access_provider = FileAccessProvider( diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index 954c1ae409..b7d1ee997c 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -21,6 +21,28 @@ T = typing.TypeVar("T") +def repr_kv(k: str, v: Union[Type, Tuple[Type, Any]]) -> str: + if isinstance(v, tuple): + if v[1]: + return f"{k}: {v[0]}={v[1]}" + return f"{k}: {v[0]}" + return f"{k}: {v}" + + +def repr_type_signature(io: Union[Dict[str, Tuple[Type, Any]], Dict[str, Type]]) -> str: + """ + Converts an inputs and outputs to a type signature + """ + s = "(" + i = 0 + for k, v in io.items(): + if i > 0: + s += ", " + s += repr_kv(k, v) + i = i + 1 + return s + ")" + + class Interface(object): """ A Python native interface object, like inspect.signature but simpler. @@ -57,7 +79,9 @@ def __init__( variables = [k for k in outputs.keys()] # TODO: This class is a duplicate of the one in create_task_outputs. Over time, we should move to this one. - class Output(collections.namedtuple(output_tuple_name or "DefaultNamedTupleOutput", variables)): + class Output( # type: ignore + collections.namedtuple(output_tuple_name or "DefaultNamedTupleOutput", variables) # type: ignore + ): # type: ignore """ This class can be used in two different places. For multivariate-return entities this class is used to rewrap the outputs so that our with_overrides function can work. 
@@ -167,6 +191,12 @@ def with_outputs(self, extra_outputs: Dict[str, Type]) -> Interface: new_outputs[k] = v return Interface(self._inputs, new_outputs) + def __str__(self): + return f"{repr_type_signature(self._inputs)} -> {repr_type_signature(self._outputs)}" + + def __repr__(self): + return str(self) + def transform_inputs_to_parameters( ctx: context_manager.FlyteContext, interface: Interface @@ -220,7 +250,7 @@ def transform_interface_to_typed_interface( return _interface_models.TypedInterface(inputs_map, outputs_map) -def transform_types_to_list_of_type(m: Dict[str, type]) -> Dict[str, type]: +def transform_types_to_list_of_type(m: Dict[str, type], bound_inputs: typing.Set[str]) -> Dict[str, type]: """ Converts a given variables to be collections of their type. This is useful for array jobs / map style code. It will create a collection of types even if any one these types is not a collection type @@ -230,6 +260,10 @@ def transform_types_to_list_of_type(m: Dict[str, type]) -> Dict[str, type]: all_types_are_collection = True for k, v in m.items(): + if k in bound_inputs: + # Skip the inputs that are bound. 
If they are bound, it does not matter if they are collection or + # singletons + continue v_type = type(v) if v_type != typing.List and v_type != list: all_types_are_collection = False @@ -240,17 +274,22 @@ def transform_types_to_list_of_type(m: Dict[str, type]) -> Dict[str, type]: om = {} for k, v in m.items(): - om[k] = typing.List[v] + if k in bound_inputs: + om[k] = v + else: + om[k] = typing.List[v] # type: ignore return om # type: ignore -def transform_interface_to_list_interface(interface: Interface) -> Interface: +def transform_interface_to_list_interface(interface: Interface, bound_inputs: typing.Set[str]) -> Interface: """ Takes a single task interface and interpolates it to an array interface - to allow performing distributed python map like functions + :param interface: Interface to be upgraded toa list interface + :param bound_inputs: fixed inputs that should not upgraded to a list and will be maintained as scalars. """ - map_inputs = transform_types_to_list_of_type(interface.inputs) - map_outputs = transform_types_to_list_of_type(interface.outputs) + map_inputs = transform_types_to_list_of_type(interface.inputs, bound_inputs) + map_outputs = transform_types_to_list_of_type(interface.outputs, set()) return Interface(inputs=map_inputs, outputs=map_outputs) @@ -286,7 +325,6 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Doc For now the fancy object, maybe in the future a dumb object. 
""" - type_hints = get_type_hints(fn, include_extras=True) signature = inspect.signature(fn) return_annotation = type_hints.get("return", None) diff --git a/flytekit/core/local_cache.py b/flytekit/core/local_cache.py index 11cb3b926c..e0b205ca5b 100644 --- a/flytekit/core/local_cache.py +++ b/flytekit/core/local_cache.py @@ -1,10 +1,12 @@ from typing import Optional -import joblib from diskcache import Cache +from flytekit import lazy_module from flytekit.models.literals import Literal, LiteralCollection, LiteralMap +joblib = lazy_module("joblib") + # Location on the filesystem where serialized objects will be stored # TODO: read from config CACHE_LOCATION = "~/.flyte/local-cache" diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index 3b5c0a09ca..44f18de2a3 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -2,71 +2,92 @@ Flytekit map tasks specify how to run a single task across a list of inputs. Map tasks themselves are constructed with a reference task as well as run-time parameters that limit execution concurrency and failure tolerations. 
""" - +import functools +import hashlib +import logging import os import typing from contextlib import contextmanager -from itertools import count -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Set from flytekit.configuration import SerializationSettings from flytekit.core import tracker -from flytekit.core.base_task import PythonTask +from flytekit.core.base_task import PythonTask, Task, TaskResolverMixin from flytekit.core.constants import SdkTaskType from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.interface import transform_interface_to_list_interface from flytekit.core.python_function_task import PythonFunctionTask +from flytekit.core.tracker import TrackedInstance +from flytekit.core.utils import timeit from flytekit.exceptions import scopes as exception_scopes from flytekit.models.array_job import ArrayJob from flytekit.models.interface import Variable from flytekit.models.task import Container, K8sPod, Sql +from flytekit.tools.module_loader import load_object_from_module class MapPythonTask(PythonTask): """ A MapPythonTask defines a :py:class:`flytekit.PythonTask` which specifies how to run an inner :py:class:`flytekit.PythonFunctionTask` across a range of inputs in parallel. - TODO: support lambda functions """ - # To support multiple map tasks declared around identical python function tasks, we keep a global count of - # MapPythonTask instances to uniquely differentiate map task names for each declared instance. 
- _ids = count(0) - def __init__( self, - python_function_task: PythonFunctionTask, - concurrency: int = None, - min_success_ratio: float = None, + python_function_task: typing.Union[PythonFunctionTask, functools.partial], + concurrency: Optional[int] = None, + min_success_ratio: Optional[float] = None, + bound_inputs: Optional[Set[str]] = None, **kwargs, ): """ + Wrapper that creates a MapPythonTask + :param python_function_task: This argument is implicitly passed and represents the repeatable function :param concurrency: If specified, this limits the number of mapped tasks than can run in parallel to the given - batch size + batch size :param min_success_ratio: If specified, this determines the minimum fraction of total jobs which can complete - successfully before terminating this task and marking it successful. + successfully before terminating this task and marking it successful + :param bound_inputs: List[str] specifies a list of variable names within the interface of python_function_task, + that are already bound and should not be considered as list inputs, but scalar values. This is mostly + useful at runtime and is passed in by MapTaskResolver. This field is not required when a `partial` method + is specified. The bound_vars will be auto-deduced from the `partial.keywords`. """ - if len(python_function_task.python_interface.inputs.keys()) > 1: - raise ValueError("Map tasks only accept python function tasks with 0 or 1 inputs") - - if len(python_function_task.python_interface.outputs.keys()) > 1: + self._partial = None + if isinstance(python_function_task, functools.partial): + # TODO: We should be able to support partial tasks with lists as inputs + for arg in python_function_task.keywords.values(): + if isinstance(arg, list): + raise ValueError("Map tasks do not support partial tasks with lists as inputs. 
") + self._partial = python_function_task + actual_task = self._partial.func + else: + actual_task = python_function_task + + if not isinstance(actual_task, PythonFunctionTask): + raise ValueError("Map tasks can only compose of Python Functon Tasks currently") + + if len(actual_task.python_interface.outputs.keys()) > 1: raise ValueError("Map tasks only accept python function tasks with 0 or 1 outputs") - collection_interface = transform_interface_to_list_interface(python_function_task.python_interface) - instance = next(self._ids) - _, mod, f, _ = tracker.extract_task_module(python_function_task.task_function) - name = f"{mod}.mapper_{f}_{instance}" - - self._cmd_prefix = None - self._run_task = python_function_task - self._max_concurrency = concurrency - self._min_success_ratio = min_success_ratio - self._array_task_interface = python_function_task.python_interface - if "metadata" not in kwargs and python_function_task.metadata: - kwargs["metadata"] = python_function_task.metadata - if "security_ctx" not in kwargs and python_function_task.security_context: - kwargs["security_ctx"] = python_function_task.security_context + self._bound_inputs: typing.Set[str] = set(bound_inputs) if bound_inputs else set() + if self._partial: + self._bound_inputs = set(self._partial.keywords.keys()) + + collection_interface = transform_interface_to_list_interface(actual_task.python_interface, self._bound_inputs) + self._run_task: PythonFunctionTask = actual_task + _, mod, f, _ = tracker.extract_task_module(actual_task.task_function) + h = hashlib.md5(collection_interface.__str__().encode("utf-8")).hexdigest() + name = f"{mod}.map_{f}_{h}" + + self._cmd_prefix: typing.Optional[typing.List[str]] = None + self._max_concurrency: typing.Optional[int] = concurrency + self._min_success_ratio: typing.Optional[float] = min_success_ratio + self._array_task_interface = actual_task.python_interface + if "metadata" not in kwargs and actual_task.metadata: + kwargs["metadata"] = 
actual_task.metadata + if "security_ctx" not in kwargs and actual_task.security_context: + kwargs["security_ctx"] = actual_task.security_context super().__init__( name=name, interface=collection_interface, @@ -76,7 +97,15 @@ def __init__( **kwargs, ) + @property + def bound_inputs(self) -> Set[str]: + return self._bound_inputs + def get_command(self, settings: SerializationSettings) -> List[str]: + """ + TODO ADD bound variables to the resolver. Maybe we need a different resolver? + """ + mt = MapTaskResolver() container_args = [ "pyflyte-map-execute", "--inputs", @@ -90,9 +119,9 @@ def get_command(self, settings: SerializationSettings) -> List[str]: "--prev-checkpoint", "{{.prevCheckpointPrefix}}", "--resolver", - self._run_task.task_resolver.location, + mt.name(), "--", - *self._run_task.task_resolver.loader_args(settings, self._run_task), + *mt.loader_args(settings, self), ] if self._cmd_prefix: @@ -100,7 +129,7 @@ def get_command(self, settings: SerializationSettings) -> List[str]: return container_args def set_command_prefix(self, cmd: typing.Optional[typing.List[str]]): - self._cmd_prefix = cmd # type: ignore + self._cmd_prefix = cmd @contextmanager def prepare_target(self): @@ -135,6 +164,18 @@ def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str] def run_task(self) -> PythonFunctionTask: return self._run_task + def __call__(self, *args, **kwargs): + """ + This call method modifies the kwargs and adds kwargs from partial. + This is mostly done in the local_execute and compilation only. + At runtime, the map_task is created with all the inputs filled in. to support this, we have modified + the map_task interface in the constructor. 
+ """ + if self._partial: + """If partial exists, then mix-in all partial values""" + kwargs = {**self._partial.keywords, **kwargs} + return super().__call__(*args, **kwargs) + def execute(self, **kwargs) -> Any: ctx = FlyteContextManager.current_context() if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION: @@ -168,7 +209,7 @@ def _outputs_interface(self) -> Dict[Any, Variable]: return self.interface.outputs return self._run_task.interface.outputs - def get_type_for_output_var(self, k: str, v: Any) -> Optional[Type[Any]]: + def get_type_for_output_var(self, k: str, v: Any) -> type: """ We override this method from flytekit.core.base_task Task because the dispatch_execute method uses this interface to construct outputs. Each instance of an container_array task will however produce outputs @@ -191,7 +232,11 @@ def _execute_map_task(self, ctx: FlyteContext, **kwargs) -> Any: task_index = self._compute_array_job_index() map_task_inputs = {} for k in self.interface.inputs.keys(): - map_task_inputs[k] = kwargs[k][task_index] + v = kwargs[k] + if isinstance(v, list) and k not in self.bound_inputs: + map_task_inputs[k] = v[task_index] + else: + map_task_inputs[k] = v return exception_scopes.user_entry_point(self._run_task.execute)(**map_task_inputs) def _raw_execute(self, **kwargs) -> Any: @@ -213,7 +258,11 @@ def _raw_execute(self, **kwargs) -> Any: for i in range(len(kwargs[any_input_key])): single_instance_inputs = {} for k in self.interface.inputs.keys(): - single_instance_inputs[k] = kwargs[k][i] + v = kwargs[k] + if isinstance(v, list) and k not in self.bound_inputs: + single_instance_inputs[k] = kwargs[k][i] + else: + single_instance_inputs[k] = kwargs[k] o = exception_scopes.user_entry_point(self._run_task.execute)(**single_instance_inputs) if outputs_expected: outputs.append(o) @@ -221,7 +270,12 @@ def _raw_execute(self, **kwargs) -> Any: return outputs -def map_task(task_function: PythonFunctionTask, concurrency: int = 0, 
min_success_ratio: float = 1.0, **kwargs): +def map_task( + task_function: typing.Union[PythonFunctionTask, functools.partial], + concurrency: int = 0, + min_success_ratio: float = 1.0, + **kwargs, +): """ Use a map task for parallelizable tasks that run across a list of an input type. A map task can be composed of any individual :py:class:`flytekit.PythonFunctionTask`. @@ -267,8 +321,64 @@ def map_task(task_function: PythonFunctionTask, concurrency: int = 0, min_succes successfully before terminating this task and marking it successful. """ - if not isinstance(task_function, PythonFunctionTask): - raise ValueError( - f"Only Flyte python task types are supported in map tasks currently, received {type(task_function)}" - ) return MapPythonTask(task_function, concurrency=concurrency, min_success_ratio=min_success_ratio, **kwargs) + + +class MapTaskResolver(TrackedInstance, TaskResolverMixin): + """ + Special resolver that is used for MapTasks. + This exists because it is possible that MapTasks are created using nested "partial" subtasks. + When a maptask is created its interface is interpolated from the interface of the subtask - the interpolation, + simply converts every input into a list/collection input. + + For example: + interface -> (i: int, j: str) -> str => map_task interface -> (i: List[int], j: List[str]) -> List[str] + + But in cases in which `j` is bound to a fixed value by using `functools.partial` we need a way to ensure that + the interface is not simply interpolated, but only the unbound inputs are interpolated. + + .. code-block:: python + + def foo((i: int, j: str) -> str: + ... + + mt = map_task(functools.partial(foo, j=10)) + + print(mt.interface) + + output: + + (i: List[int], j: str) -> List[str] + + But, at runtime this information is lost. 
To reconstruct this, we use MapTaskResolver that records the "bound vars" + and then at runtime reconstructs the interface with this knowledge + """ + + def name(self) -> str: + return "MapTaskResolver" + + @timeit("Load map task") + def load_task(self, loader_args: List[str], max_concurrency: int = 0) -> MapPythonTask: + """ + Loader args should be of the form + vars "var1,var2,.." resolver "resolver" [resolver_args] + """ + _, bound_vars, _, resolver, *resolver_args = loader_args + logging.info(f"MapTask found task resolver {resolver} and arguments {resolver_args}") + resolver_obj = load_object_from_module(resolver) + # Use the resolver to load the actual task object + _task_def = resolver_obj.load_task(loader_args=resolver_args) + bound_inputs = set(bound_vars.split(",")) + return MapPythonTask(python_function_task=_task_def, max_concurrency=max_concurrency, bound_inputs=bound_inputs) + + def loader_args(self, settings: SerializationSettings, t: MapPythonTask) -> List[str]: # type:ignore + return [ + "vars", + f'{",".join(t.bound_inputs)}', + "resolver", + t.run_task.task_resolver.location, + *t.run_task.task_resolver.loader_args(settings, t.run_task), + ] + + def get_all_tasks(self) -> List[Task]: + raise NotImplementedError("MapTask resolver cannot return every instance of the map task") diff --git a/flytekit/core/node.py b/flytekit/core/node.py index 617790746f..9caf00adad 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -84,7 +84,9 @@ def metadata(self) -> _workflow_model.NodeMetadata: def with_overrides(self, *args, **kwargs): if "node_name" in kwargs: - self._id = kwargs["node_name"] + # Convert the node name into a DNS-compliant. 
+ # https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#dns-subdomain-names + self._id = _dnsify(kwargs["node_name"]) if "aliases" in kwargs: alias_dict = kwargs["aliases"] if not isinstance(alias_dict, dict): diff --git a/flytekit/core/pod_template.py b/flytekit/core/pod_template.py index 5e9c746911..98ba92af36 100644 --- a/flytekit/core/pod_template.py +++ b/flytekit/core/pod_template.py @@ -1,22 +1,27 @@ from dataclasses import dataclass -from typing import Dict, Optional - -from kubernetes.client.models import V1PodSpec +from typing import TYPE_CHECKING, Dict, Optional from flytekit.exceptions import user as _user_exceptions +if TYPE_CHECKING: + from kubernetes.client import V1PodSpec + PRIMARY_CONTAINER_DEFAULT_NAME = "primary" -@dataclass +@dataclass(init=True, repr=True, eq=True, frozen=False) class PodTemplate(object): """Custom PodTemplate specification for a Task.""" - pod_spec: V1PodSpec = V1PodSpec(containers=[]) + pod_spec: Optional["V1PodSpec"] = None primary_container_name: str = PRIMARY_CONTAINER_DEFAULT_NAME labels: Optional[Dict[str, str]] = None annotations: Optional[Dict[str, str]] = None def __post_init__(self): + if self.pod_spec is None: + from kubernetes.client import V1PodSpec + + self.pod_spec = V1PodSpec(containers=[]) if not self.primary_container_name: raise _user_exceptions.FlyteValidationException("A primary container name cannot be undefined") diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 3a851a50ea..24628cd52e 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -13,8 +13,9 @@ from flytekit.core.context_manager import BranchEvalMode, ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.interface import Interface from flytekit.core.node import Node -from flytekit.core.type_engine import DictTransformer, ListTransformer, TypeEngine +from flytekit.core.type_engine import DictTransformer, ListTransformer, TypeEngine, TypeTransformerFailedError from 
flytekit.exceptions import user as _user_exceptions +from flytekit.loggers import logger from flytekit.models import interface as _interface_models from flytekit.models import literals as _literal_models from flytekit.models import literals as _literals_models @@ -86,6 +87,12 @@ def extract_value( if len(input_val) == 0: raise sub_type = type(input_val[0]) + # To maintain consistency between translate_inputs_to_literals and ListTransformer.to_literal for batchable types, + # directly call ListTransformer.to_literal to batch process the list items. This is necessary because processing + # each list item separately could lead to errors since ListTransformer.to_python_value may treat the literal + # as it is batched for batchable types. + if ListTransformer.is_batchable(python_type): + return TypeEngine.to_literal(ctx, input_val, python_type, lt) literal_list = [extract_value(ctx, v, sub_type, lt.collection_type) for v in input_val] return _literal_models.Literal(collection=_literal_models.LiteralCollection(literals=literal_list)) elif isinstance(input_val, dict): @@ -135,7 +142,10 @@ def extract_value( raise ValueError(f"Received unexpected keyword argument {k}") var = flyte_interface_types[k] t = native_types[k] - result[k] = extract_value(ctx, v, t, var.type) + try: + result[k] = extract_value(ctx, v, t, var.type) + except TypeTransformerFailedError as exc: + raise TypeTransformerFailedError(f"Failed argument '{k}': {exc}") from exc return result @@ -471,10 +481,14 @@ def create_native_named_tuple( if isinstance(promises, Promise): k, v = [(k, v) for k, v in entity_interface.outputs.items()][0] # get output native type + # only show the name of output key if it's user-defined (by default Flyte names these as "o") + key = k if k != "o0" else 0 try: return TypeEngine.to_python_value(ctx, promises.val, v) except Exception as e: - raise AssertionError(f"Failed to convert value of output {k}, expected type {v}.") from e + raise TypeError( + f"Failed to convert output in 
position {key} of value {promises.val}, expected type {v}." + ) from e if len(promises) == 0: return None @@ -484,7 +498,7 @@ def create_native_named_tuple( named_tuple_name = entity_interface.output_tuple_name outputs = {} - for p in promises: + for i, p in enumerate(cast(Tuple[Promise], promises)): if not isinstance(p, Promise): raise AssertionError( "Workflow outputs can only be promises that are returned by tasks. Found a value of" @@ -494,7 +508,9 @@ def create_native_named_tuple( try: outputs[p.var] = TypeEngine.to_python_value(ctx, p.val, t) except Exception as e: - raise AssertionError(f"Failed to convert value of output {p.var}, expected type {t}.") from e + # only show the name of output key if it's user-defined (by default Flyte names these as "o") + key = p.var if p.var != f"o{i}" else i + raise TypeError(f"Failed to convert output in position {key} of value {p.val}, expected type {t}.") from e # Should this class be part of the Interface? t = collections.namedtuple(named_tuple_name, list(outputs.keys())) @@ -597,11 +613,22 @@ def binding_data_from_python_std( f"Cannot pass output from task {t_value.task_name} that produces no outputs to a downstream task" ) - elif isinstance(t_value, list): - if expected_literal_type.collection_type is None: - raise AssertionError(f"this should be a list and it is not: {type(t_value)} vs {expected_literal_type}") + elif expected_literal_type.union_type is not None: + for i in range(len(expected_literal_type.union_type.variants)): + try: + lt_type = expected_literal_type.union_type.variants[i] + python_type = get_args(t_value_type)[i] if t_value_type else None + return binding_data_from_python_std(ctx, lt_type, t_value, python_type) + except Exception: + logger.debug( + f"failed to bind data {t_value} with literal type {expected_literal_type.union_type.variants[i]}." + ) + raise AssertionError( + f"Failed to bind data {t_value} with literal type {expected_literal_type.union_type.variants}." 
+ ) - sub_type = ListTransformer.get_sub_type(t_value_type) if t_value_type else None + elif isinstance(t_value, list): + sub_type: Optional[type] = ListTransformer.get_sub_type(t_value_type) if t_value_type else None collection = _literals_models.BindingDataCollection( bindings=[ binding_data_from_python_std(ctx, expected_literal_type.collection_type, t, sub_type) for t in t_value @@ -1049,7 +1076,7 @@ def flyte_entity_call_handler(entity: SupportsNodeCreation, *args, **kwargs): for k, v in kwargs.items(): if k not in cast(SupportsNodeCreation, entity).python_interface.inputs: raise ValueError( - f"Received unexpected keyword argument {k} in function {cast(SupportsNodeCreation, entity).name}" + f"Received unexpected keyword argument '{k}' in function '{cast(SupportsNodeCreation, entity).name}'" ) ctx = FlyteContextManager.current_context() diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 2d05df3c3d..66a49a819e 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -3,12 +3,7 @@ import importlib import re from abc import ABC -from types import ModuleType -from typing import Any, Callable, Dict, List, Optional, TypeVar, Union - -from flyteidl.core import tasks_pb2 as _core_task -from kubernetes.client import ApiClient -from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements +from typing import Callable, Dict, List, Optional, TypeVar, Union from flytekit.configuration import ImageConfig, SerializationSettings from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin @@ -17,7 +12,8 @@ from flytekit.core.resources import Resources, ResourceSpec from flytekit.core.tracked_abc import FlyteTrackedABC from flytekit.core.tracker import TrackedInstance, extract_task_module -from flytekit.core.utils import _get_container_definition +from flytekit.core.utils import _get_container_definition, _serialize_pod_spec, timeit +from 
flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec from flytekit.loggers import logger from flytekit.models import task as _task_model from flytekit.models.security import Secret, SecurityContext @@ -26,10 +22,6 @@ _PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name" -def _sanitize_resource_name(resource: _task_model.Resources.ResourceEntry) -> str: - return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-") - - class PythonAutoContainerTask(PythonTask[T], ABC, metaclass=FlyteTrackedABC): """ A Python AutoContainer task should be used as the base for all extensions that want the user's code to be in the @@ -44,7 +36,7 @@ def __init__( name: str, task_config: T, task_type="python-task", - container_image: Optional[str] = None, + container_image: Optional[Union[str, ImageSpec]] = None, requests: Optional[Resources] = None, limits: Optional[Resources] = None, environment: Optional[Dict[str, str]] = None, @@ -86,7 +78,7 @@ def __init__( raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}") sec_ctx = SecurityContext(secrets=secret_requests) - # pod_template_name overwrites the metedata.pod_template_name + # pod_template_name overwrites the metadata.pod_template_name kwargs["metadata"] = kwargs["metadata"] if "metadata" in kwargs else TaskMetadata() kwargs["metadata"].pod_template_name = pod_template_name @@ -124,7 +116,7 @@ def task_resolver(self) -> Optional[TaskResolverMixin]: return self._task_resolver @property - def container_image(self) -> Optional[str]: + def container_image(self) -> Optional[Union[str, ImageSpec]]: return self._container_image @property @@ -189,6 +181,9 @@ def _get_container(self, settings: SerializationSettings) -> _task_model.Contain for elem in (settings.env, self.environment): if elem: env.update(elem) + if settings.fast_serialization_settings is None or not settings.fast_serialization_settings.enabled: + if isinstance(self.container_image, ImageSpec): + 
self.container_image.source_root = settings.source_root return _get_container_definition( image=get_registerable_container_image(self.container_image, settings.image_config), command=[], @@ -207,52 +202,11 @@ def _get_container(self, settings: SerializationSettings) -> _task_model.Contain memory_limit=self.resources.limits.mem, ) - def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any]: - containers = self.pod_template.pod_spec.containers - primary_exists = False - - for container in containers: - if container.name == self.pod_template.primary_container_name: - primary_exists = True - break - - if not primary_exists: - # insert a placeholder primary container if it is not defined in the pod spec. - containers.append(V1Container(name=self.pod_template.primary_container_name)) - final_containers = [] - for container in containers: - # In the case of the primary container, we overwrite specific container attributes - # with the default values used in the regular Python task. - # The attributes include: image, command, args, resource, and env (env is unioned) - if container.name == self.pod_template.primary_container_name: - sdk_default_container = self._get_container(settings) - container.image = sdk_default_container.image - # clear existing commands - container.command = sdk_default_container.command - # also clear existing args - container.args = sdk_default_container.args - limits, requests = {}, {} - for resource in sdk_default_container.resources.limits: - limits[_sanitize_resource_name(resource)] = resource.value - for resource in sdk_default_container.resources.requests: - requests[_sanitize_resource_name(resource)] = resource.value - resource_requirements = V1ResourceRequirements(limits=limits, requests=requests) - if len(limits) > 0 or len(requests) > 0: - # Important! Only copy over resource requirements if they are non-empty. 
- container.resources = resource_requirements - container.env = [V1EnvVar(name=key, value=val) for key, val in sdk_default_container.env.items()] + ( - container.env or [] - ) - final_containers.append(container) - self.pod_template.pod_spec.containers = final_containers - - return ApiClient().sanitize_for_serialization(self.pod_template.pod_spec) - def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod: if self.pod_template is None: return None return _task_model.K8sPod( - pod_spec=self._serialize_pod_spec(settings), + pod_spec=_serialize_pod_spec(self.pod_template, self._get_container(settings)), metadata=_task_model.K8sObjectMetadata( labels=self.pod_template.labels, annotations=self.pod_template.annotations, @@ -274,7 +228,8 @@ class DefaultTaskResolver(TrackedInstance, TaskResolverMixin): def name(self) -> str: return "DefaultTaskResolver" - def load_task(self, loader_args: List[Union[T, ModuleType]]) -> PythonAutoContainerTask: + @timeit("Load task") + def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask: _, task_module, _, task_name, *_ = loader_args task_module = importlib.import_module(task_module) @@ -298,12 +253,16 @@ def get_all_tasks(self) -> List[PythonAutoContainerTask]: default_task_resolver = DefaultTaskResolver() -def get_registerable_container_image(img: Optional[str], cfg: ImageConfig) -> str: +def get_registerable_container_image(img: Optional[Union[str, ImageSpec]], cfg: ImageConfig) -> str: """ - :param img: Configured image + :param img: Configured image or image spec :param cfg: Registration configuration :return: """ + if isinstance(img, ImageSpec): + ImageBuildEngine.build(img) + return img.image_name() + if img is not None and img != "": matches = _IMAGE_REPLACE_REGEX.findall(img) if matches is None or len(matches) == 0: diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 28c5b5def7..5b08bb6fc8 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -1,6 +1,6 @@ import 
datetime as _datetime from functools import update_wrapper -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, overload from flytekit.core.base_task import TaskMetadata, TaskResolverMixin from flytekit.core.interface import transform_function_to_interface @@ -8,6 +8,7 @@ from flytekit.core.python_function_task import PythonFunctionTask from flytekit.core.reference_entity import ReferenceEntity, TaskReference from flytekit.core.resources import Resources +from flytekit.image_spec.image_spec import ImageSpec from flytekit.models.documentation import Documentation from flytekit.models.security import Secret @@ -74,9 +75,64 @@ def find_pythontask_plugin(cls, plugin_config_type: type) -> Type[PythonFunction return PythonFunctionTask +T = TypeVar("T") + + +@overload +def task( + _task_function: None = ..., + task_config: Optional[T] = ..., + cache: bool = ..., + cache_serialize: bool = ..., + cache_version: str = ..., + retries: int = ..., + interruptible: Optional[bool] = ..., + deprecated: str = ..., + timeout: Union[_datetime.timedelta, int] = ..., + container_image: Optional[Union[str, ImageSpec]] = ..., + environment: Optional[Dict[str, str]] = ..., + requests: Optional[Resources] = ..., + limits: Optional[Resources] = ..., + secret_requests: Optional[List[Secret]] = ..., + execution_mode: PythonFunctionTask.ExecutionBehavior = ..., + task_resolver: Optional[TaskResolverMixin] = ..., + docs: Optional[Documentation] = ..., + disable_deck: bool = ..., + pod_template: Optional["PodTemplate"] = ..., + pod_template_name: Optional[str] = ..., +) -> Callable[[Callable[..., Any]], PythonFunctionTask[T]]: + ... 
+ + +@overload +def task( + _task_function: Callable[..., Any], + task_config: Optional[T] = ..., + cache: bool = ..., + cache_serialize: bool = ..., + cache_version: str = ..., + retries: int = ..., + interruptible: Optional[bool] = ..., + deprecated: str = ..., + timeout: Union[_datetime.timedelta, int] = ..., + container_image: Optional[Union[str, ImageSpec]] = ..., + environment: Optional[Dict[str, str]] = ..., + requests: Optional[Resources] = ..., + limits: Optional[Resources] = ..., + secret_requests: Optional[List[Secret]] = ..., + execution_mode: PythonFunctionTask.ExecutionBehavior = ..., + task_resolver: Optional[TaskResolverMixin] = ..., + docs: Optional[Documentation] = ..., + disable_deck: bool = ..., + pod_template: Optional["PodTemplate"] = ..., + pod_template_name: Optional[str] = ..., +) -> PythonFunctionTask[T]: + ... + + def task( - _task_function: Optional[Callable] = None, - task_config: Optional[Any] = None, + _task_function: Optional[Callable[..., Any]] = None, + task_config: Optional[T] = None, cache: bool = False, cache_serialize: bool = False, cache_version: str = "", @@ -84,7 +140,7 @@ def task( interruptible: Optional[bool] = None, deprecated: str = "", timeout: Union[_datetime.timedelta, int] = 0, - container_image: Optional[str] = None, + container_image: Optional[Union[str, ImageSpec]] = None, environment: Optional[Dict[str, str]] = None, requests: Optional[Resources] = None, limits: Optional[Resources] = None, @@ -93,9 +149,9 @@ def task( task_resolver: Optional[TaskResolverMixin] = None, docs: Optional[Documentation] = None, disable_deck: bool = True, - pod_template: Optional[PodTemplate] = None, + pod_template: Optional["PodTemplate"] = None, pod_template_name: Optional[str] = None, -) -> Union[Callable, PythonFunctionTask]: +) -> Union[Callable[[Callable[..., Any]], PythonFunctionTask[T]], PythonFunctionTask[T]]: """ This is the core decorator to use for any task type in flytekit. 
@@ -189,7 +245,7 @@ def foo2(): :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task. """ - def wrapper(fn) -> PythonFunctionTask: + def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]: _metadata = TaskMetadata( cache=cache, cache_serialize=cache_serialize, diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py index 2a203d4861..93fbc99a4f 100644 --- a/flytekit/core/tracker.py +++ b/flytekit/core/tracker.py @@ -91,7 +91,6 @@ def find_lhs(self) -> str: # Since dataframes aren't registrable entities to begin with we swallow any errors they raise and # continue looping through m. logger.warning("Caught ValueError {} while attempting to auto-assign name".format(err)) - pass logger.error(f"Could not find LHS for {self} in {self._instantiated_in}") raise _system_exceptions.FlyteSystemException(f"Error looking for LHS in {self._instantiated_in}") diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 7bfc85d1ef..da1f614ba1 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -22,14 +22,15 @@ from google.protobuf.json_format import ParseDict as _ParseDict from google.protobuf.struct_pb2 import Struct from marshmallow_enum import EnumField, LoadDumpOptions -from marshmallow_jsonschema import JSONSchema from typing_extensions import Annotated, get_args, get_origin from flytekit.core.annotation import FlyteAnnotation from flytekit.core.context_manager import FlyteContext from flytekit.core.hash import HashMethod from flytekit.core.type_helpers import load_type_from_tag +from flytekit.core.utils import timeit from flytekit.exceptions import user as user_exceptions +from flytekit.lazy_import.lazy_module import is_imported from flytekit.loggers import logger from flytekit.models import interface as _interface_models from flytekit.models import types as _type_models @@ -88,7 +89,7 @@ def type_assertions_enabled(self) -> bool: def assert_type(self, t: Type[T], 
v: T): if not hasattr(t, "__origin__") and not isinstance(v, t): - raise TypeTransformerFailedError(f"Type of Val '{v}' is not an instance of {t}") + raise TypeTransformerFailedError(f"Expected value of type {t} but got '{v}' of type {type(v)}") @abstractmethod def get_literal_type(self, t: Type[T]) -> LiteralType: @@ -129,7 +130,6 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: f"Conversion to python value expected type {expected_python_type} from literal not implemented" ) - @abstractmethod def to_html(self, ctx: FlyteContext, python_val: T, expected_python_type: Type[T]) -> str: """ Converts any python val (dataframe, int, float) to a html string, and it will be wrapped in the HTML div @@ -167,7 +167,9 @@ def get_literal_type(self, t: Type[T] = None) -> LiteralType: def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: if type(python_val) != self._type: - raise TypeTransformerFailedError(f"Expected value of type {self._type} but got type {type(python_val)}") + raise TypeTransformerFailedError( + f"Expected value of type {self._type} but got '{python_val}' of type {type(python_val)}" + ) return self._to_literal_transformer(python_val) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: @@ -186,7 +188,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: return res except AttributeError: # Assume that this is because a property on `lv` was None - raise TypeTransformerFailedError(f"Cannot convert literal {lv}") + raise TypeTransformerFailedError(f"Cannot convert literal {lv} to {self._type}") def guess_python_type(self, literal_type: LiteralType) -> Type[T]: if literal_type.simple is not None and literal_type.simple == self._lt.simple: @@ -327,6 +329,8 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: # 
https://github.com/fuhrysteve/marshmallow-jsonschema/blob/81eada1a0c42ff67de216923968af0a6b54e5dcb/marshmallow_jsonschema/base.py#L228 if isinstance(v, EnumField): v.load_by = LoadDumpOptions.name + from marshmallow_jsonschema import JSONSchema + schema = JSONSchema().dump(s) except Exception as e: # https://github.com/lovasoa/marshmallow_dataclass/issues/13 @@ -374,7 +378,7 @@ def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]: def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.Any) -> T: # In python 3.7, 3.8, DataclassJson will deserialize Annotated[StructuredDataset, kwtypes(..)] to a dict, # so here we convert it back to the Structured Dataset. - from flytekit import StructuredDataset + from flytekit.types.structured import StructuredDataset if python_type == StructuredDataset and type(python_val) == dict: return StructuredDataset(**python_val) @@ -657,7 +661,8 @@ class TypeEngine(typing.Generic[T]): _REGISTRY: typing.Dict[type, TypeTransformer[T]] = {} _RESTRICTED_TYPES: typing.List[type] = [] - _DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer() + _DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer() # type: ignore + has_lazy_import = False @classmethod def register( @@ -701,24 +706,32 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]: d = dictionary of registered transformers, where is a python `type` v = lookup type Step 1: - find a transformer that matches v exactly + If the type is annotated with a TypeTransformer instance, use that. Step 2: - find a transformer that matches the generic type of v. e.g List[int], Dict[str, int] etc + find a transformer that matches v exactly Step 3: + find a transformer that matches the generic type of v. e.g List[int], Dict[str, int] etc + + Step 4: Walk the inheritance hierarchy of v and find a transformer that matches the first base class. This is potentially non-deterministic - will depend on the registration pattern. 
TODO lets make this deterministic by using an ordered dict - Step 4: + Step 5: if v is of type data class, use the dataclass transformer """ - + cls.lazy_import_transformers() # Step 1 if get_origin(python_type) is Annotated: - python_type = get_args(python_type)[0] + args = get_args(python_type) + for annotation in args: + if isinstance(annotation, TypeTransformer): + return annotation + + python_type = args[0] if python_type in cls._REGISTRY: return cls._REGISTRY[python_type] @@ -757,6 +770,39 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]: raise ValueError(f"Type {python_type} not supported currently in Flytekit. Please register a new transformer") + @classmethod + def lazy_import_transformers(cls): + """ + Only load the transformers if needed. + """ + if cls.has_lazy_import: + return + cls.has_lazy_import = True + from flytekit.types.structured import ( + register_arrow_handlers, + register_bigquery_handlers, + register_pandas_handlers, + ) + + if is_imported("tensorflow"): + from flytekit.extras import tensorflow # noqa: F401 + if is_imported("torch"): + from flytekit.extras import pytorch # noqa: F401 + if is_imported("sklearn"): + from flytekit.extras import sklearn # noqa: F401 + if is_imported("pandas"): + try: + from flytekit.types import schema # noqa: F401 + except ValueError: + logger.debug("Transformer for pandas is already registered.") + register_pandas_handlers() + if is_imported("pyarrow"): + register_arrow_handlers() + if is_imported("google.cloud.bigquery"): + register_bigquery_handlers() + if is_imported("numpy"): + from flytekit.types import numpy # noqa: F401 + @classmethod def to_literal_type(cls, python_type: Type) -> LiteralType: """ @@ -820,7 +866,7 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T return transformer.to_python_value(ctx, lv, expected_python_type) @classmethod - def to_html(cls, ctx: FlyteContext, python_val: typing.Any, expected_python_type: Type[T]) -> str: + def 
to_html(cls, ctx: FlyteContext, python_val: typing.Any, expected_python_type: Type[typing.Any]) -> str: transformer = cls.get_transformer(expected_python_type) if get_origin(expected_python_type) is Annotated: expected_python_type, *annotate_args = get_args(expected_python_type) @@ -843,6 +889,7 @@ def named_tuple_to_variable_map(cls, t: typing.NamedTuple) -> _interface_models. return _interface_models.VariableMap(variables=variables) @classmethod + @timeit("Translate literal to python value") def literal_map_to_kwargs( cls, ctx: FlyteContext, lm: LiteralMap, python_types: typing.Dict[str, type] ) -> typing.Dict[str, typing.Any]: @@ -853,7 +900,13 @@ def literal_map_to_kwargs( raise ValueError( f"Received more input values {len(lm.literals)}" f" than allowed by the input spec {len(python_types)}" ) - return {k: TypeEngine.to_python_value(ctx, lm.literals[k], python_types[k]) for k, v in lm.literals.items()} + kwargs = {} + for i, k in enumerate(lm.literals): + try: + kwargs[k] = TypeEngine.to_python_value(ctx, lm.literals[k], python_types[k]) + except TypeTransformerFailedError as exc: + raise TypeTransformerFailedError(f"Error converting input '{k}' at position {i}:\n {exc}") from exc + return kwargs @classmethod def dict_to_literal_map( @@ -957,12 +1010,43 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]: except Exception as e: raise ValueError(f"Type of Generic List type is not supported, {e}") + @staticmethod + def is_batchable(t: Type): + """ + This function evaluates whether the provided type is batchable or not. + It returns True only if the type is either List or Annotated(List) and the List subtype is FlytePickle. 
+ """ + from flytekit.types.pickle import FlytePickle + + if get_origin(t) is Annotated: + return ListTransformer.is_batchable(get_args(t)[0]) + if get_origin(t) is list: + subtype = get_args(t)[0] + if subtype == FlytePickle or (hasattr(subtype, "__origin__") and subtype.__origin__ == FlytePickle): + return True + return False + def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: if type(python_val) != list: raise TypeTransformerFailedError("Expected a list") - t = self.get_sub_type(python_type) - lit_list = [TypeEngine.to_literal(ctx, x, t, expected.collection_type) for x in python_val] # type: ignore + if ListTransformer.is_batchable(python_type): + from flytekit.types.pickle.pickle import BatchSize, FlytePickle + + batch_size = len(python_val) # default batch size + # parse annotated to get the number of items saved in a pickle file. + if get_origin(python_type) is Annotated: + for annotation in get_args(python_type)[1:]: + if isinstance(annotation, BatchSize): + batch_size = annotation.val + break + if batch_size > 0: + lit_list = [TypeEngine.to_literal(ctx, python_val[i : i + batch_size], FlytePickle, expected.collection_type) for i in range(0, len(python_val), batch_size)] # type: ignore + else: + lit_list = [] + else: + t = self.get_sub_type(python_type) + lit_list = [TypeEngine.to_literal(ctx, x, t, expected.collection_type) for x in python_val] # type: ignore return Literal(collection=LiteralCollection(literals=lit_list)) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[T]: @@ -970,9 +1054,18 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: lits = lv.collection.literals except AttributeError: raise TypeTransformerFailedError() - - st = self.get_sub_type(expected_python_type) - return [TypeEngine.to_python_value(ctx, x, st) for x in lits] + if self.is_batchable(expected_python_type): + from 
flytekit.types.pickle import FlytePickle + + batch_list = [TypeEngine.to_python_value(ctx, batch, FlytePickle) for batch in lits] + if len(batch_list) > 0 and type(batch_list[0]) is list: + # Make it have backward compatibility. The upstream task may use old version of Flytekit that + # won't merge the elements in the list. Therefore, we should check if the batch_list[0] is the list first. + return [item for batch in batch_list for item in batch] + return batch_list + else: + st = self.get_sub_type(expected_python_type) + return [TypeEngine.to_python_value(ctx, x, st) for x in lits] def guess_python_type(self, literal_type: LiteralType) -> Type[list]: if literal_type.collection_type: @@ -1033,7 +1126,7 @@ def _are_types_castable(upstream: LiteralType, downstream: LiteralType) -> bool: if len(ucols) != len(dcols): return False - for (u, d) in zip(ucols, dcols): + for u, d in zip(ucols, dcols): if u.name != d.name: return False diff --git a/flytekit/core/utils.py b/flytekit/core/utils.py index d23aae3fbb..437d2b71a4 100644 --- a/flytekit/core/utils.py +++ b/flytekit/core/utils.py @@ -1,13 +1,20 @@ +import datetime import os as _os import shutil as _shutil import tempfile as _tempfile import time as _time +from functools import wraps from hashlib import sha224 as _sha224 from pathlib import Path -from typing import Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, cast +from flyteidl.core import tasks_pb2 as _core_task + +from flytekit.core.pod_template import PodTemplate from flytekit.loggers import logger -from flytekit.models import task as _task_models + +if TYPE_CHECKING: + from flytekit.models import task as task_models def _dnsify(value: str) -> str: @@ -51,8 +58,8 @@ def _dnsify(value: str) -> str: def _get_container_definition( image: str, command: List[str], - args: List[str], - data_loading_config: Optional[_task_models.DataLoadingConfig] = None, + args: Optional[List[str]] = None, + data_loading_config: 
Optional["task_models.DataLoadingConfig"] = None, storage_request: Optional[str] = None, ephemeral_storage_request: Optional[str] = None, cpu_request: Optional[str] = None, @@ -64,7 +71,7 @@ def _get_container_definition( gpu_limit: Optional[str] = None, memory_limit: Optional[str] = None, environment: Optional[Dict[str, str]] = None, -) -> _task_models.Container: +) -> "task_models.Container": storage_limit = storage_limit storage_request = storage_request ephemeral_storage_limit = ephemeral_storage_limit @@ -76,56 +83,107 @@ def _get_container_definition( memory_limit = memory_limit memory_request = memory_request + from flytekit.models import task as task_models + + # TODO: Use convert_resources_to_resource_model instead of manually fixing the resources. requests = [] if storage_request: requests.append( - _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.STORAGE, storage_request) + task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.STORAGE, storage_request) ) if ephemeral_storage_request: requests.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.EPHEMERAL_STORAGE, ephemeral_storage_request + task_models.Resources.ResourceEntry( + task_models.Resources.ResourceName.EPHEMERAL_STORAGE, ephemeral_storage_request ) ) if cpu_request: - requests.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.CPU, cpu_request)) + requests.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.CPU, cpu_request)) if gpu_request: - requests.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.GPU, gpu_request)) + requests.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.GPU, gpu_request)) if memory_request: - requests.append( - _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.MEMORY, memory_request) - ) + 
requests.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.MEMORY, memory_request)) limits = [] if storage_limit: - limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.STORAGE, storage_limit)) + limits.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.STORAGE, storage_limit)) if ephemeral_storage_limit: limits.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.EPHEMERAL_STORAGE, ephemeral_storage_limit + task_models.Resources.ResourceEntry( + task_models.Resources.ResourceName.EPHEMERAL_STORAGE, ephemeral_storage_limit ) ) if cpu_limit: - limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.CPU, cpu_limit)) + limits.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.CPU, cpu_limit)) if gpu_limit: - limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.GPU, gpu_limit)) + limits.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.GPU, gpu_limit)) if memory_limit: - limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.MEMORY, memory_limit)) + limits.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.MEMORY, memory_limit)) if environment is None: environment = {} - return _task_models.Container( + return task_models.Container( image=image, command=command, args=args, - resources=_task_models.Resources(limits=limits, requests=requests), + resources=task_models.Resources(limits=limits, requests=requests), env=environment, config={}, data_loading_config=data_loading_config, ) +def _sanitize_resource_name(resource: "task_models.Resources.ResourceEntry") -> str: + return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-") + + +def _serialize_pod_spec(pod_template: "PodTemplate", primary_container: "task_models.Container") -> Dict[str, Any]: + 
from kubernetes.client import ApiClient, V1PodSpec + from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements + + if pod_template.pod_spec is None: + return {} + containers = cast(V1PodSpec, pod_template.pod_spec).containers + primary_exists = False + + for container in containers: + if container.name == cast(PodTemplate, pod_template).primary_container_name: + primary_exists = True + break + + if not primary_exists: + # insert a placeholder primary container if it is not defined in the pod spec. + containers.append(V1Container(name=cast(PodTemplate, pod_template).primary_container_name)) + final_containers = [] + for container in containers: + # In the case of the primary container, we overwrite specific container attributes + # with the values given to ContainerTask. + # The attributes include: image, command, args, resource, and env (env is unioned) + if container.name == cast(PodTemplate, pod_template).primary_container_name: + container.image = primary_container.image + container.command = primary_container.command + container.args = primary_container.args + + limits, requests = {}, {} + for resource in primary_container.resources.limits: + limits[_sanitize_resource_name(resource)] = resource.value + for resource in primary_container.resources.requests: + requests[_sanitize_resource_name(resource)] = resource.value + resource_requirements = V1ResourceRequirements(limits=limits, requests=requests) + if len(limits) > 0 or len(requests) > 0: + # Important! Only copy over resource requirements if they are non-empty. 
+ container.resources = resource_requirements + if primary_container.env is not None: + container.env = [V1EnvVar(name=key, value=val) for key, val in primary_container.env.items()] + ( + container.env or [] + ) + final_containers.append(container) + cast(V1PodSpec, pod_template.pod_spec).containers = final_containers + + return ApiClient().sanitize_for_serialization(cast(PodTemplate, pod_template).pod_spec) + + def load_proto_from_file(pb2_type, path): with open(path, "rb") as reader: out = pb2_type() @@ -209,26 +267,66 @@ def __str__(self): return self.__repr__() -class PerformanceTimer(object): - def __init__(self, context_statement): +class timeit: + """ + A context manager and a decorator that measures the execution time of the wrapped code block or functions. + It will append a timing information to TimeLineDeck. For instance: + + @timeit("Function description") + def function() + + with timeit("Wrapped code block description"): + # your code + """ + + def __init__(self, name: str = ""): """ - :param Text context_statement: the statement to log + :param name: A string that describes the wrapped code block or function being executed. """ - self._context_statement = context_statement + self._name = name + self.start_time = None self._start_wall_time = None self._start_process_time = None + def __call__(self, func: Callable): + @wraps(func) + def wrapper(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return wrapper + def __enter__(self): - logger.info("Entering timed context: {}".format(self._context_statement)) + self.start_time = datetime.datetime.utcnow() self._start_wall_time = _time.perf_counter() self._start_process_time = _time.process_time() + return self def __exit__(self, exc_type, exc_val, exc_tb): + """ + The exception, if any, will propagate outside the context manager, as the purpose of this context manager + is solely to measure the execution time of the wrapped code block. 
+ """ + from flytekit.core.context_manager import FlyteContextManager + + end_time = datetime.datetime.utcnow() end_wall_time = _time.perf_counter() end_process_time = _time.process_time() + + timeline_deck = FlyteContextManager.current_context().user_space_params.timeline_deck + timeline_deck.append_time_info( + dict( + Name=self._name, + Start=self.start_time, + Finish=end_time, + WallTime=end_wall_time - self._start_wall_time, + ProcessTime=end_process_time - self._start_process_time, + ) + ) + logger.info( - "Exiting timed context: {} [Wall Time: {}s, Process Time: {}s]".format( - self._context_statement, + "{}. [Wall Time: {}s, Process Time: {}s]".format( + self._name, end_wall_time - self._start_wall_time, end_process_time - self._start_process_time, ) diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 8ba307b767..93b14f9528 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -1,9 +1,12 @@ from __future__ import annotations +import typing from dataclasses import dataclass from enum import Enum from functools import update_wrapper -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, overload + +from typing_extensions import get_args from flytekit.core import constants as _common_constants from flytekit.core.base_task import PythonTask @@ -32,14 +35,16 @@ from flytekit.core.python_auto_container import PythonAutoContainerTask from flytekit.core.reference_entity import ReferenceEntity, WorkflowReference from flytekit.core.tracker import extract_task_module -from flytekit.core.type_engine import TypeEngine +from flytekit.core.type_engine import TypeEngine, TypeTransformerFailedError, UnionTransformer from flytekit.exceptions import scopes as exception_scopes from flytekit.exceptions.user import FlyteValidationException, FlyteValueException from flytekit.loggers import logger from flytekit.models import interface as 
_interface_models from flytekit.models import literals as _literal_models +from flytekit.models import types as type_models from flytekit.models.core import workflow as _workflow_model from flytekit.models.documentation import Description, Documentation +from flytekit.models.types import TypeStructure GLOBAL_START_NODE = Node( id=_common_constants.GLOBAL_INPUT_NODE_ID, @@ -49,6 +54,8 @@ flyte_entity=None, ) +T = typing.TypeVar("T") + class WorkflowFailurePolicy(Enum): """ @@ -258,7 +265,11 @@ def __call__(self, *args, **kwargs): input_kwargs = self.python_interface.default_inputs_as_kwargs input_kwargs.update(kwargs) self.compile() - return flyte_entity_call_handler(self, *args, **input_kwargs) + try: + return flyte_entity_call_handler(self, *args, **input_kwargs) + except Exception as exc: + exc.args = (f"Encountered error while executing workflow '{self.name}':\n {exc}", *exc.args[1:]) + raise exc def execute(self, **kwargs): raise Exception("Should not be called") @@ -266,19 +277,63 @@ def execute(self, **kwargs): def compile(self, **kwargs): pass + def ensure_literal( + self, ctx, py_type: Type[T], input_type: type_models.LiteralType, python_value: Any + ) -> _literal_models.Literal: + """ + This function will attempt to convert a python value to a literal. If the python value is a promise, it will + return the promise's value. 
+ """ + if input_type.union_type is not None: + if python_value is None and UnionTransformer.is_optional_type(py_type): + return _literal_models.Literal(scalar=_literal_models.Scalar(none_type=_literal_models.Void())) + for i in range(len(input_type.union_type.variants)): + lt_type = input_type.union_type.variants[i] + python_type = get_args(py_type)[i] + try: + final_lt = self.ensure_literal(ctx, python_type, lt_type, python_value) + lt_type._structure = TypeStructure(tag=TypeEngine.get_transformer(python_type).name) + return _literal_models.Literal( + scalar=_literal_models.Scalar(union=_literal_models.Union(value=final_lt, stored_type=lt_type)) + ) + except Exception as e: + logger.debug(f"Failed to convert {python_value} to {lt_type} with error {e}") + raise TypeError(f"Failed to convert {python_value} to {input_type}") + if isinstance(python_value, list) and input_type.collection_type: + collection_lit_type = input_type.collection_type + collection_py_type = get_args(py_type)[0] + xx = [self.ensure_literal(ctx, collection_py_type, collection_lit_type, pv) for pv in python_value] + return _literal_models.Literal(collection=_literal_models.LiteralCollection(literals=xx)) + elif isinstance(python_value, dict) and input_type.map_value_type: + mapped_lit_type = input_type.map_value_type + mapped_py_type = get_args(py_type)[1] + xx = {k: self.ensure_literal(ctx, mapped_py_type, mapped_lit_type, v) for k, v in python_value.items()} # type: ignore + return _literal_models.Literal(map=_literal_models.LiteralMap(literals=xx)) + # It is a scalar, convert to Promise if necessary. 
+ else: + if isinstance(python_value, Promise): + return python_value.val + if not isinstance(python_value, Promise): + try: + res = TypeEngine.to_literal(ctx, python_value, py_type, input_type) + return res + except TypeTransformerFailedError as exc: + raise TypeError( + f"Failed to convert input '{python_value}' of workflow '{self.name}':\n {exc}" + ) from exc + def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]: # This is done to support the invariant that Workflow local executions always work with Promise objects # holding Flyte literal values. Even in a wf, a user can call a sub-workflow with a Python native value. for k, v in kwargs.items(): - if not isinstance(v, Promise): - t = self.python_interface.inputs[k] - kwargs[k] = Promise(var=k, val=TypeEngine.to_literal(ctx, v, t, self.interface.inputs[k].type)) + py_type = self.python_interface.inputs[k] + lit_type = self.interface.inputs[k].type + kwargs[k] = Promise(var=k, val=self.ensure_literal(ctx, py_type, lit_type, v)) - # The output of this will always be a combination of Python native values and Promises containing Flyte - # Literals. + # The output of this will always be a combination of Python native values and Promises containing Flyte + # Literals. self.compile() function_outputs = self.execute(**kwargs) - # First handle the empty return case. 
# A workflow function may return a task that doesn't return anything # def wf(): @@ -595,9 +650,9 @@ class PythonFunctionWorkflow(WorkflowBase, ClassStorageTaskResolver): def __init__( self, - workflow_function: Callable, - metadata: Optional[WorkflowMetadata], - default_metadata: Optional[WorkflowMetadataDefaults], + workflow_function: Callable[..., Any], + metadata: WorkflowMetadata, + default_metadata: WorkflowMetadataDefaults, docstring: Optional[Docstring] = None, docs: Optional[Documentation] = None, ): @@ -719,12 +774,32 @@ def execute(self, **kwargs): return exception_scopes.user_entry_point(self._workflow_function)(**kwargs) +@overload +def workflow( + _workflow_function: None = ..., + failure_policy: Optional[WorkflowFailurePolicy] = ..., + interruptible: bool = ..., + docs: Optional[Documentation] = ..., +) -> Callable[[Callable[..., Any]], PythonFunctionWorkflow]: + ... + + +@overload +def workflow( + _workflow_function: Callable[..., Any], + failure_policy: Optional[WorkflowFailurePolicy] = ..., + interruptible: bool = ..., + docs: Optional[Documentation] = ..., +) -> PythonFunctionWorkflow: + ... + + def workflow( - _workflow_function=None, + _workflow_function: Optional[Callable[..., Any]] = None, failure_policy: Optional[WorkflowFailurePolicy] = None, interruptible: bool = False, docs: Optional[Documentation] = None, -) -> WorkflowBase: +) -> Union[Callable[[Callable[..., Any]], PythonFunctionWorkflow], PythonFunctionWorkflow]: """ This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG of tasks using the data flow between tasks. 
@@ -755,7 +830,7 @@ def workflow( :param docs: Description entity for the workflow """ - def wrapper(fn): + def wrapper(fn: Callable[..., Any]) -> PythonFunctionWorkflow: workflow_metadata = WorkflowMetadata(on_failure=failure_policy or WorkflowFailurePolicy.FAIL_IMMEDIATELY) workflow_metadata_defaults = WorkflowMetadataDefaults(interruptible) @@ -770,7 +845,7 @@ def wrapper(fn): update_wrapper(workflow_instance, fn) return workflow_instance - if _workflow_function: + if _workflow_function is not None: return wrapper(_workflow_function) else: return wrapper diff --git a/flytekit/deck/deck.py b/flytekit/deck/deck.py index cec59e7318..0d53ec18d6 100644 --- a/flytekit/deck/deck.py +++ b/flytekit/deck/deck.py @@ -2,8 +2,6 @@ import typing from typing import Optional -from jinja2 import Environment, FileSystemLoader, select_autoescape - from flytekit.core.context_manager import ExecutionParameters, ExecutionState, FlyteContext, FlyteContextManager from flytekit.loggers import logger @@ -74,6 +72,58 @@ def html(self) -> str: return self._html +class TimeLineDeck(Deck): + """ + The TimeLineDeck class is designed to render the execution time of each part of a task. + Unlike deck class, the conversion of data to HTML is delayed until the html property is accessed. + This approach is taken because rendering a timeline graph with partial data would not provide meaningful insights. + Instead, the complete data set is used to create a comprehensive visualization of the execution time of each part of the task. + """ + + def __init__(self, name: str, html: Optional[str] = ""): + super().__init__(name, html) + self.time_info = [] + + def append_time_info(self, info: dict): + assert isinstance(info, dict) + self.time_info.append(info) + + @property + def html(self) -> str: + try: + from flytekitplugins.deck.renderer import GanttChartRenderer, TableRenderer + except ImportError: + warning_info = "Plugin 'flytekit-deck-standard' is not installed. 
To display time line, install the plugin in the image." + logger.warning(warning_info) + return warning_info + + if len(self.time_info) == 0: + return "" + + import pandas + + df = pandas.DataFrame(self.time_info) + note = """ +

<p><strong>Note:</strong></p>
+ <ol>
+ <li>if the time duration is too small(< 1ms), it may be difficult to see on the time line graph.</li>
+ <li>For accurate execution time measurements, users should refer to wall time and process time.</li>
+ </ol>
+ """ + # set the accuracy to microsecond + df["ProcessTime"] = df["ProcessTime"].apply(lambda time: "{:.6f}".format(time)) + df["WallTime"] = df["WallTime"].apply(lambda time: "{:.6f}".format(time)) + + width = 1400 + gantt_chart_html = GanttChartRenderer().to_html(df, chart_width=width) + time_table_html = TableRenderer().to_html( + df[["Name", "WallTime", "ProcessTime"]], + header_labels=["Name", "Wall Time(s)", "Process Time(s)"], + table_width=width, + ) + return gantt_chart_html + time_table_html + note + + def _ipython_check() -> bool: """ Check if interface is launching from iPython (not colab) @@ -98,10 +148,12 @@ def _get_deck( If ignore_jupyter is set to True, then it will return a str even in a jupyter environment. """ deck_map = {deck.name: deck.html for deck in new_user_params.decks} - raw_html = template.render(metadata=deck_map) + raw_html = get_deck_template().render(metadata=deck_map) if not ignore_jupyter and _ipython_check(): - from IPython.core.display import HTML - + try: + from IPython.core.display import HTML + except ImportError: + ... 
return HTML(raw_html) return raw_html @@ -118,15 +170,18 @@ def _output_deck(task_name: str, new_user_params: ExecutionParameters): logger.info(f"{task_name} task creates flyte deck html to file://{deck_path}") -root = os.path.dirname(os.path.abspath(__file__)) -templates_dir = os.path.join(root, "html") -env = Environment( - loader=FileSystemLoader(templates_dir), - # 🔥 include autoescaping for security purposes - # sources: - # - https://jinja.palletsprojects.com/en/3.0.x/api/#autoescaping - # - https://stackoverflow.com/a/38642558/8474894 (see in comments) - # - https://stackoverflow.com/a/68826578/8474894 - autoescape=select_autoescape(enabled_extensions=("html",)), -) -template = env.get_template("template.html") +def get_deck_template() -> "Template": + from jinja2 import Environment, FileSystemLoader, select_autoescape + + root = os.path.dirname(os.path.abspath(__file__)) + templates_dir = os.path.join(root, "html") + env = Environment( + loader=FileSystemLoader(templates_dir), + # 🔥 include autoescaping for security purposes + # sources: + # - https://jinja.palletsprojects.com/en/3.0.x/api/#autoescaping + # - https://stackoverflow.com/a/38642558/8474894 (see in comments) + # - https://stackoverflow.com/a/68826578/8474894 + autoescape=select_autoescape(enabled_extensions=("html",)), + ) + return env.get_template("template.html") diff --git a/flytekit/deck/html/template.html b/flytekit/deck/html/template.html index 6bec37effe..19e0256880 100644 --- a/flytekit/deck/html/template.html +++ b/flytekit/deck/html/template.html @@ -53,17 +53,19 @@ } #flyte-frame-container { - width: 100%; + width: auto; } #flyte-frame-container > div { - display: none; + display: None; } #flyte-frame-container > div.active { - display: block; + display: Block; padding: 2rem 4rem; + width: 100%; } + diff --git a/flytekit/deck/renderer.py b/flytekit/deck/renderer.py index dddb88e420..cfea92ec4e 100644 --- a/flytekit/deck/renderer.py +++ b/flytekit/deck/renderer.py @@ -1,14 +1,22 @@ 
-from typing import Any +from typing import TYPE_CHECKING, Any -import pandas -import pyarrow from typing_extensions import Protocol, runtime_checkable +from flytekit import lazy_module + +if TYPE_CHECKING: + # Always import these modules in type-checking mode or when running pytest + import pandas + import pyarrow +else: + pandas = lazy_module("pandas") + pyarrow = lazy_module("pyarrow") + @runtime_checkable class Renderable(Protocol): def to_html(self, python_value: Any) -> str: - """Convert a object(markdown, pandas.dataframe) to HTML and return HTML as a unicode string. + """Convert an object(markdown, pandas.dataframe) to HTML and return HTML as a unicode string. Returns: An HTML document as a string. """ raise NotImplementedError @@ -27,16 +35,16 @@ def __init__(self, max_rows: int = DEFAULT_MAX_ROWS, max_cols: int = DEFAULT_MAX self._max_rows = max_rows self._max_cols = max_cols - def to_html(self, df: pandas.DataFrame) -> str: + def to_html(self, df: "pandas.DataFrame") -> str: assert isinstance(df, pandas.DataFrame) return df.to_html(max_rows=self._max_rows, max_cols=self._max_cols) class ArrowRenderer: """ - Render a Arrow dataframe as an HTML table. + Render an Arrow dataframe as an HTML table. 
""" - def to_html(self, df: pyarrow.Table) -> str: + def to_html(self, df: "pyarrow.Table") -> str: assert isinstance(df, pyarrow.Table) return df.to_string() diff --git a/flytekit/exceptions/scopes.py b/flytekit/exceptions/scopes.py index 60a4afa97e..bdfb2ba182 100644 --- a/flytekit/exceptions/scopes.py +++ b/flytekit/exceptions/scopes.py @@ -194,10 +194,13 @@ def user_entry_point(wrapped, instance, args, kwargs): _CONTEXT_STACK.append(_USER_CONTEXT) if _is_base_context(): # See comment at this location for system_entry_point + fn_name = wrapped.__name__ try: return wrapped(*args, **kwargs) - except FlyteScopedException as ex: - raise ex.value + except FlyteScopedException as exc: + raise exc.type(f"Error encountered while executing '{fn_name}':\n {exc.value}") from exc + except Exception as exc: + raise type(exc)(f"Error encountered while executing '{fn_name}':\n {exc}") from exc else: try: return wrapped(*args, **kwargs) diff --git a/flytekit/extend/__init__.py b/flytekit/extend/__init__.py index f6635a4a57..7223d13523 100644 --- a/flytekit/extend/__init__.py +++ b/flytekit/extend/__init__.py @@ -29,8 +29,6 @@ PythonCustomizedContainerTask ExecutableTemplateShimTask ShimTaskExecutor - DataPersistence - DataPersistencePlugins """ from flytekit.configuration import Image, ImageConfig, SerializationSettings @@ -39,7 +37,7 @@ from flytekit.core.base_task import IgnoreOutputs, PythonTask, TaskResolverMixin from flytekit.core.class_based_resolver import ClassStorageTaskResolver from flytekit.core.context_manager import ExecutionState, SecretsManager -from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins +from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.interface import Interface from flytekit.core.promise import Promise from flytekit.core.python_customized_container_task import PythonCustomizedContainerTask diff --git a/flytekit/extend/backend/__init__.py b/flytekit/extend/backend/__init__.py new file mode 
100644 index 0000000000..e69de29bb2 diff --git a/flytekit/extend/backend/base_plugin.py b/flytekit/extend/backend/base_plugin.py new file mode 100644 index 0000000000..9fc1bc206b --- /dev/null +++ b/flytekit/extend/backend/base_plugin.py @@ -0,0 +1,107 @@ +import typing +from abc import ABC, abstractmethod + +import grpc +from flyteidl.core.tasks_pb2 import TaskTemplate +from flyteidl.service.external_plugin_service_pb2 import ( + RETRYABLE_FAILURE, + RUNNING, + SUCCEEDED, + State, + TaskCreateResponse, + TaskDeleteResponse, + TaskGetResponse, +) + +from flytekit import logger +from flytekit.models.literals import LiteralMap + + +class BackendPluginBase(ABC): + """ + This is the base class for all backend plugins. It defines the interface that all plugins must implement. + The external plugins service will be run either locally or in a pod, and will be responsible for + invoking backend plugins. The propeller will communicate with the external plugins service + to create tasks, get the status of tasks, and delete tasks. + + All the backend plugins should be registered in the BackendPluginRegistry. External plugins service + will look up the plugin based on the task type. Every task type can only have one plugin. + """ + + def __init__(self, task_type: str): + self._task_type = task_type + + @property + def task_type(self) -> str: + """ + task_type is the name of the task type that this plugin supports. + """ + return self._task_type + + @abstractmethod + def create( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: typing.Optional[LiteralMap] = None, + ) -> TaskCreateResponse: + """ + Return a Unique ID for the task that was created. It should return error code if the task creation failed. + """ + pass + + @abstractmethod + def get(self, context: grpc.ServicerContext, job_id: str) -> TaskGetResponse: + """ + Return the status of the task, and return the outputs in some cases. 
For example, bigquery job + can't write the structured dataset to the output location, so it returns the output literals to the propeller, + and the propeller will write the structured dataset to the blob store. + """ + pass + + @abstractmethod + def delete(self, context: grpc.ServicerContext, job_id: str) -> TaskDeleteResponse: + """ + Delete the task. This call should be idempotent. + """ + pass + + +class BackendPluginRegistry(object): + """ + This is the registry for all backend plugins. The external plugins service will look up the plugin + based on the task type. + """ + + _REGISTRY: typing.Dict[str, BackendPluginBase] = {} + + @staticmethod + def register(plugin: BackendPluginBase): + if plugin.task_type in BackendPluginRegistry._REGISTRY: + raise ValueError(f"Duplicate plugin for task type {plugin.task_type}") + BackendPluginRegistry._REGISTRY[plugin.task_type] = plugin + logger.info(f"Registering backend plugin for task type {plugin.task_type}") + + @staticmethod + def get_plugin(context: grpc.ServicerContext, task_type: str) -> typing.Optional[BackendPluginBase]: + if task_type not in BackendPluginRegistry._REGISTRY: + logger.error(f"Cannot find backend plugin for task type [{task_type}]") + context.set_code(grpc.StatusCode.NOT_FOUND) + context.set_details(f"Cannot find backend plugin for task type [{task_type}]") + return None + return BackendPluginRegistry._REGISTRY[task_type] + + +def convert_to_flyte_state(state: str) -> State: + """ + Convert the state from the backend plugin to the state in flyte. 
+ """ + state = state.lower() + if state in ["failed"]: + return RETRYABLE_FAILURE + elif state in ["done", "succeeded"]: + return SUCCEEDED + elif state in ["running"]: + return RUNNING + raise ValueError(f"Unrecognized state: {state}") diff --git a/flytekit/extend/backend/external_plugin_service.py b/flytekit/extend/backend/external_plugin_service.py new file mode 100644 index 0000000000..e820a320b1 --- /dev/null +++ b/flytekit/extend/backend/external_plugin_service.py @@ -0,0 +1,53 @@ +import grpc +from flyteidl.service.external_plugin_service_pb2 import ( + PERMANENT_FAILURE, + TaskCreateRequest, + TaskCreateResponse, + TaskDeleteRequest, + TaskDeleteResponse, + TaskGetRequest, + TaskGetResponse, +) +from flyteidl.service.external_plugin_service_pb2_grpc import ExternalPluginServiceServicer + +from flytekit import logger +from flytekit.extend.backend.base_plugin import BackendPluginRegistry +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate + + +class BackendPluginServer(ExternalPluginServiceServicer): + def CreateTask(self, request: TaskCreateRequest, context: grpc.ServicerContext) -> TaskCreateResponse: + try: + tmp = TaskTemplate.from_flyte_idl(request.template) + inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None + plugin = BackendPluginRegistry.get_plugin(context, tmp.type) + if plugin is None: + return TaskCreateResponse() + return plugin.create(context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp) + except Exception as e: + logger.error(f"failed to create task with error {e}") + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(f"failed to create task with error {e}") + + def GetTask(self, request: TaskGetRequest, context: grpc.ServicerContext) -> TaskGetResponse: + try: + plugin = BackendPluginRegistry.get_plugin(context, request.task_type) + if plugin is None: + return TaskGetResponse(state=PERMANENT_FAILURE) + return 
plugin.get(context=context, job_id=request.job_id) + except Exception as e: + logger.error(f"failed to get task with error {e}") + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(f"failed to get task with error {e}") + + def DeleteTask(self, request: TaskDeleteRequest, context: grpc.ServicerContext) -> TaskDeleteResponse: + try: + plugin = BackendPluginRegistry.get_plugin(context, request.task_type) + if plugin is None: + return TaskDeleteResponse() + return plugin.delete(context=context, job_id=request.job_id) + except Exception as e: + logger.error(f"failed to delete task with error {e}") + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(f"failed to delete task with error {e}") diff --git a/flytekit/extras/persistence/__init__.py b/flytekit/extras/persistence/__init__.py deleted file mode 100644 index a677632fd8..0000000000 --- a/flytekit/extras/persistence/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -""" -======================= -DataPersistence Extras -======================= - -.. currentmodule:: flytekit.extras.persistence - -This module provides some default implementations of :py:class:`flytekit.DataPersistence`. These implementations -use command-line clients to download and upload data. The actual binaries need to be installed for these extras to work. -The binaries are not bundled with flytekit to keep it lightweight. - -Persistence Extras -=================== - -.. 
autosummary:: - :template: custom.rst - :toctree: generated/ - - GCSPersistence - HttpPersistence - S3Persistence -""" - -from flytekit.extras.persistence.gcs_gsutil import GCSPersistence -from flytekit.extras.persistence.http import HttpPersistence -from flytekit.extras.persistence.s3_awscli import S3Persistence diff --git a/flytekit/extras/persistence/gcs_gsutil.py b/flytekit/extras/persistence/gcs_gsutil.py deleted file mode 100644 index 0ddb600024..0000000000 --- a/flytekit/extras/persistence/gcs_gsutil.py +++ /dev/null @@ -1,115 +0,0 @@ -import os -import posixpath -import typing -from shutil import which as shell_which - -from flytekit.configuration import DataConfig, GCSConfig -from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins -from flytekit.exceptions.user import FlyteUserException -from flytekit.tools import subprocess - - -def _update_cmd_config_and_execute(cmd): - env = os.environ.copy() - return subprocess.check_call(cmd, env=env) - - -def _amend_path(path): - return posixpath.join(path, "*") if not path.endswith("*") else path - - -class GCSPersistence(DataPersistence): - """ - This DataPersistence plugin uses a preinstalled GSUtil binary in the container to download and upload data. - - The binary can be installed in multiple ways including simply, - - .. prompt:: - - pip install gsutil - - """ - - _GS_UTIL_CLI = "gsutil" - PROTOCOL = "gs://" - - def __init__(self, default_prefix: typing.Optional[str] = None, data_config: typing.Optional[DataConfig] = None): - super(GCSPersistence, self).__init__(name="gcs-gsutil", default_prefix=default_prefix) - self.gcs_cfg = data_config.gcs if data_config else GCSConfig.auto() - - @staticmethod - def _check_binary(): - """ - Make sure that the `gsutil` cli is present - """ - if not shell_which(GCSPersistence._GS_UTIL_CLI): - raise FlyteUserException("gsutil (gcloud cli) not found! 
Please install using `pip install gsutil`.") - - def _maybe_with_gsutil_parallelism(self, *gsutil_args): - """ - Check if we should run `gsutil` with the `-m` flag that enables - parallelism via multiple threads/processes. Additional tweaking of - this behavior can be achieved via the .boto configuration file. See: - https://cloud.google.com/storage/docs/boto-gsutil - """ - cmd = [GCSPersistence._GS_UTIL_CLI] - if self.gcs_cfg.gsutil_parallelism: - cmd.append("-m") - cmd.extend(gsutil_args) - - return cmd - - def exists(self, remote_path): - """ - :param Text remote_path: remote gs:// path - :rtype bool: whether the gs file exists or not - """ - GCSPersistence._check_binary() - - if not remote_path.startswith("gs://"): - raise ValueError("Not an GS Key. Please use FQN (GS ARN) of the format gs://...") - - cmd = [GCSPersistence._GS_UTIL_CLI, "-q", "stat", remote_path] - try: - _update_cmd_config_and_execute(cmd) - return True - except Exception: - return False - - def get(self, from_path: str, to_path: str, recursive: bool = False): - if not from_path.startswith("gs://"): - raise ValueError("Not an GS Key. 
Please use FQN (GS ARN) of the format gs://...") - - GCSPersistence._check_binary() - if recursive: - cmd = self._maybe_with_gsutil_parallelism("cp", "-r", _amend_path(from_path), to_path) - else: - cmd = self._maybe_with_gsutil_parallelism("cp", from_path, to_path) - - return _update_cmd_config_and_execute(cmd) - - def put(self, from_path: str, to_path: str, recursive: bool = False): - GCSPersistence._check_binary() - - if recursive: - cmd = self._maybe_with_gsutil_parallelism( - "cp", - "-r", - _amend_path(from_path), - to_path if to_path.endswith("/") else to_path + "/", - ) - else: - cmd = self._maybe_with_gsutil_parallelism("cp", from_path, to_path) - return _update_cmd_config_and_execute(cmd) - - def construct_path(self, add_protocol: bool, add_prefix: bool, *paths) -> str: - paths = list(paths) # make type check happy - if add_prefix: - paths.insert(0, self.default_prefix) - path = "/".join(paths) - if add_protocol: - return f"{self.PROTOCOL}{path}" - return path - - -DataPersistencePlugins.register_plugin(GCSPersistence.PROTOCOL, GCSPersistence) diff --git a/flytekit/extras/persistence/http.py b/flytekit/extras/persistence/http.py deleted file mode 100644 index ce6079300d..0000000000 --- a/flytekit/extras/persistence/http.py +++ /dev/null @@ -1,84 +0,0 @@ -import base64 -import os -import pathlib - -import requests - -from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins -from flytekit.exceptions import user -from flytekit.loggers import logger -from flytekit.tools import script_mode - - -class HttpPersistence(DataPersistence): - """ - DataPersistence implementation for the HTTP protocol. only supports downloading from an http source. Uploads are - not supported currently. 
- """ - - PROTOCOL_HTTP = "http" - PROTOCOL_HTTPS = "https" - _HTTP_OK = 200 - _HTTP_FORBIDDEN = 403 - _HTTP_NOT_FOUND = 404 - ALLOWED_CODES = { - _HTTP_OK, - _HTTP_NOT_FOUND, - _HTTP_FORBIDDEN, - } - - def __init__(self, *args, **kwargs): - super(HttpPersistence, self).__init__(name="http/https", *args, **kwargs) - - def exists(self, path: str): - rsp = requests.head(path) - if rsp.status_code not in self.ALLOWED_CODES: - raise user.FlyteValueException( - rsp.status_code, - f"Data at {path} could not be checked for existence. Expected one of: {self.ALLOWED_CODES}", - ) - return rsp.status_code == self._HTTP_OK - - def get(self, from_path: str, to_path: str, recursive: bool = False): - if recursive: - raise user.FlyteAssertion("Reading data recursively from HTTP endpoint is not currently supported.") - rsp = requests.get(from_path) - if rsp.status_code != self._HTTP_OK: - raise user.FlyteValueException( - rsp.status_code, - "Request for data @ {} failed. Expected status code {}".format(from_path, type(self)._HTTP_OK), - ) - head, _ = os.path.split(to_path) - if head and head.startswith("/"): - logger.debug(f"HttpPersistence creating {head} so that parent dirs exist") - pathlib.Path(head).mkdir(parents=True, exist_ok=True) - with open(to_path, "wb") as writer: - writer.write(rsp.content) - - def put(self, from_path: str, to_path: str, recursive: bool = False): - if recursive: - raise user.FlyteAssertion("Recursive writing data to HTTP endpoint is not currently supported.") - - md5, _ = script_mode.hash_file(from_path) - encoded_md5 = base64.b64encode(md5) - with open(from_path, "+rb") as local_file: - content = local_file.read() - content_length = len(content) - rsp = requests.put( - to_path, data=content, headers={"Content-Length": str(content_length), "Content-MD5": encoded_md5} - ) - - if rsp.status_code != self._HTTP_OK: - raise user.FlyteValueException( - rsp.status_code, - f"Request to send data {to_path} failed.", - ) - - def construct_path(self, 
add_protocol: bool, add_prefix: bool, *paths) -> str: - raise user.FlyteAssertion( - "There are multiple ways of creating http links / paths, this is not supported by the persistence layer" - ) - - -DataPersistencePlugins.register_plugin("http://", HttpPersistence) -DataPersistencePlugins.register_plugin("https://", HttpPersistence) diff --git a/flytekit/extras/persistence/s3_awscli.py b/flytekit/extras/persistence/s3_awscli.py deleted file mode 100644 index 0b00227ca0..0000000000 --- a/flytekit/extras/persistence/s3_awscli.py +++ /dev/null @@ -1,181 +0,0 @@ -import os -import os as _os -import re as _re -import string as _string -import time -import typing -from shutil import which as shell_which -from typing import Dict, List, Optional - -from flytekit.configuration import DataConfig, S3Config -from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins -from flytekit.exceptions.user import FlyteUserException -from flytekit.loggers import logger -from flytekit.tools import subprocess - -S3_ANONYMOUS_FLAG = "--no-sign-request" -S3_ACCESS_KEY_ID_ENV_NAME = "AWS_ACCESS_KEY_ID" -S3_SECRET_ACCESS_KEY_ENV_NAME = "AWS_SECRET_ACCESS_KEY" - - -def _update_cmd_config_and_execute(s3_cfg: S3Config, cmd: List[str]): - env = _os.environ.copy() - - if s3_cfg.enable_debug: - cmd.insert(1, "--debug") - - if s3_cfg.endpoint is not None: - cmd.insert(1, s3_cfg.endpoint) - cmd.insert(1, "--endpoint-url") - - if S3_ACCESS_KEY_ID_ENV_NAME not in os.environ: - if s3_cfg.access_key_id: - env[S3_ACCESS_KEY_ID_ENV_NAME] = s3_cfg.access_key_id - - if S3_SECRET_ACCESS_KEY_ENV_NAME not in os.environ: - if s3_cfg.secret_access_key: - env[S3_SECRET_ACCESS_KEY_ENV_NAME] = s3_cfg.secret_access_key - - retry = 0 - while True: - try: - try: - return subprocess.check_call(cmd, env=env) - except Exception as e: - if retry > 0: - logger.info(f"AWS command failed with error {e}, command: {cmd}, retry {retry}") - - logger.debug(f"Appending anonymous flag and retrying command 
{cmd}") - anonymous_cmd = cmd[:] # strings only, so this is deep enough - anonymous_cmd.insert(1, S3_ANONYMOUS_FLAG) - return subprocess.check_call(anonymous_cmd, env=env) - - except Exception as e: - logger.error(f"Exception when trying to execute {cmd}, reason: {str(e)}") - retry += 1 - if retry > s3_cfg.retries: - raise - secs = s3_cfg.backoff - logger.info(f"Sleeping before retrying again, after {secs.total_seconds()} seconds") - time.sleep(secs.total_seconds()) - logger.info("Retrying again") - - -def _extra_args(extra_args: Dict[str, str]) -> List[str]: - cmd = [] - if "ContentType" in extra_args: - cmd += ["--content-type", extra_args["ContentType"]] - if "ContentEncoding" in extra_args: - cmd += ["--content-encoding", extra_args["ContentEncoding"]] - if "ACL" in extra_args: - cmd += ["--acl", extra_args["ACL"]] - return cmd - - -class S3Persistence(DataPersistence): - """ - DataPersistence plugin for AWS S3 (and Minio). Use aws cli to manage the transfer. The binary needs to be installed - separately - - .. prompt:: - - pip install awscli - - """ - - PROTOCOL = "s3://" - _AWS_CLI = "aws" - _SHARD_CHARACTERS = [str(x) for x in range(10)] + list(_string.ascii_lowercase) - - def __init__(self, default_prefix: Optional[str] = None, data_config: typing.Optional[DataConfig] = None): - super().__init__(name="awscli-s3", default_prefix=default_prefix) - self.s3_cfg = data_config.s3 if data_config else S3Config.auto() - - @staticmethod - def _check_binary(): - """ - Make sure that the AWS cli is present - """ - if not shell_which(S3Persistence._AWS_CLI): - raise FlyteUserException("AWS CLI not found! 
Please install it with `pip install awscli`.") - - @staticmethod - def _split_s3_path_to_bucket_and_key(path: str) -> typing.Tuple[str, str]: - """ - splits a valid s3 uri into bucket and key - """ - path = path[len("s3://") :] - first_slash = path.index("/") - return path[:first_slash], path[first_slash + 1 :] - - def exists(self, remote_path): - """ - Given a remote path of the format s3://, checks if the remote file exists - """ - S3Persistence._check_binary() - - if not remote_path.startswith("s3://"): - raise ValueError("Not an S3 ARN. Please use FQN (S3 ARN) of the format s3://...") - - bucket, file_path = self._split_s3_path_to_bucket_and_key(remote_path) - cmd = [ - S3Persistence._AWS_CLI, - "s3api", - "head-object", - "--bucket", - bucket, - "--key", - file_path, - ] - try: - _update_cmd_config_and_execute(cmd=cmd, s3_cfg=self.s3_cfg) - return True - except Exception as ex: - # The s3api command returns an error if the object does not exist. The error message contains - # the http status code: "An error occurred (404) when calling the HeadObject operation: Not Found" - # This is a best effort for returning if the object does not exist by searching - # for existence of (404) in the error message. This should not be needed when we get off the cli and use lib - if _re.search("(404)", str(ex)): - return False - else: - raise ex - - def get(self, from_path: str, to_path: str, recursive: bool = False): - S3Persistence._check_binary() - - if not from_path.startswith("s3://"): - raise ValueError("Not an S3 ARN. 
Please use FQN (S3 ARN) of the format s3://...") - - if recursive: - cmd = [S3Persistence._AWS_CLI, "s3", "cp", "--recursive", from_path, to_path] - else: - cmd = [S3Persistence._AWS_CLI, "s3", "cp", from_path, to_path] - return _update_cmd_config_and_execute(cmd=cmd, s3_cfg=self.s3_cfg) - - def put(self, from_path: str, to_path: str, recursive: bool = False): - extra_args = { - "ACL": "bucket-owner-full-control", - } - - if not to_path.startswith("s3://"): - raise ValueError("Not an S3 ARN. Please use FQN (S3 ARN) of the format s3://...") - - S3Persistence._check_binary() - cmd = [S3Persistence._AWS_CLI, "s3", "cp"] - if recursive: - cmd += ["--recursive"] - cmd.extend(_extra_args(extra_args)) - cmd += [from_path, to_path] - return _update_cmd_config_and_execute(cmd=cmd, s3_cfg=self.s3_cfg) - - def construct_path(self, add_protocol: bool, add_prefix: bool, *paths: str) -> str: - paths = list(paths) # make type check happy - if add_prefix: - paths.insert(0, self.default_prefix) - path = "/".join(paths) - if add_protocol: - return f"{self.PROTOCOL}{path}" - return path - - -DataPersistencePlugins.register_plugin(S3Persistence.PROTOCOL, S3Persistence) diff --git a/flytekit/extras/sqlite3/task.py b/flytekit/extras/sqlite3/task.py index 8e7d8b3b29..ef8013a5da 100644 --- a/flytekit/extras/sqlite3/task.py +++ b/flytekit/extras/sqlite3/task.py @@ -14,7 +14,6 @@ from flytekit.core.python_customized_container_task import PythonCustomizedContainerTask from flytekit.core.shim_task import ShimTaskExecutor from flytekit.models import task as task_models -from flytekit.types.schema import FlyteSchema def unarchive_file(local_path: str, to_dir: str): @@ -78,12 +77,14 @@ def __init__( query_template: str, inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, task_config: typing.Optional[SQLite3Config] = None, - output_schema_type: typing.Optional[typing.Type[FlyteSchema]] = None, + output_schema_type: typing.Optional[typing.Type["FlyteSchema"]] = None, # type: ignore 
container_image: typing.Optional[str] = None, **kwargs, ): if task_config is None or task_config.uri is None: raise ValueError("SQLite DB uri is required.") + from flytekit.types.schema import FlyteSchema + outputs = kwtypes(results=output_schema_type if output_schema_type else FlyteSchema) super().__init__( name=name, diff --git a/flytekit/extras/tasks/shell.py b/flytekit/extras/tasks/shell.py index 12ef36af3e..87b60126d6 100644 --- a/flytekit/extras/tasks/shell.py +++ b/flytekit/extras/tasks/shell.py @@ -1,5 +1,6 @@ import datetime import os +import platform import string import subprocess import typing @@ -213,6 +214,9 @@ def execute(self, **kwargs) -> typing.Any: print("\n==============================================\n") try: + if platform.system() == "Windows" and os.environ.get("ComSpec") is None: + # https://github.com/python/cpython/issues/101283 + os.environ["ComSpec"] = "C:\\Windows\\System32\\cmd.exe" subprocess.check_call(gen_script, shell=True) except subprocess.CalledProcessError as e: files = os.listdir(".") @@ -356,7 +360,6 @@ def execute(self, **kwargs) -> typing.Any: # This utility function allows for the specification of env variables, arguments, and the actual script within the # workflow definition rather than at `RawShellTask` instantiation def get_raw_shell_task(name: str) -> RawShellTask: - return RawShellTask( name=name, debug=True, diff --git a/flytekit/extras/tensorflow/__init__.py b/flytekit/extras/tensorflow/__init__.py index b5699906fb..2447ffa826 100644 --- a/flytekit/extras/tensorflow/__init__.py +++ b/flytekit/extras/tensorflow/__init__.py @@ -23,9 +23,10 @@ if _tensorflow_installed: + from .model import TensorFlowModelTransformer from .record import TensorFlowRecordFileTransformer, TensorFlowRecordsDirTransformer else: logger.info( - "We won't register TensorFlowRecordFileTransformer and TensorFlowRecordsDirTransformer " + "We won't register TensorFlowRecordFileTransformer, TensorFlowRecordsDirTransformer and 
TensorFlowModelTransformer" "because tensorflow is not installed." ) diff --git a/flytekit/extras/tensorflow/model.py b/flytekit/extras/tensorflow/model.py new file mode 100644 index 0000000000..857ec2c984 --- /dev/null +++ b/flytekit/extras/tensorflow/model.py @@ -0,0 +1,76 @@ +import pathlib +from typing import Type + +import tensorflow as tf + +from flytekit.core.context_manager import FlyteContext +from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.models.core import types as _core_types +from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar +from flytekit.models.types import LiteralType + + +class TensorFlowModelTransformer(TypeTransformer[tf.keras.Model]): + TENSORFLOW_FORMAT = "TensorFlowModel" + + def __init__(self): + super().__init__(name="TensorFlow Model", t=tf.keras.Model) + + def get_literal_type(self, t: Type[tf.keras.Model]) -> LiteralType: + return LiteralType( + blob=_core_types.BlobType( + format=self.TENSORFLOW_FORMAT, + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + ) + ) + + def to_literal( + self, + ctx: FlyteContext, + python_val: tf.keras.Model, + python_type: Type[tf.keras.Model], + expected: LiteralType, + ) -> Literal: + meta = BlobMetadata( + type=_core_types.BlobType( + format=self.TENSORFLOW_FORMAT, + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + ) + ) + + local_path = ctx.file_access.get_random_local_path() + pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True) + + # save model in SavedModel format + tf.keras.models.save_model(python_val, local_path) + + remote_path = ctx.file_access.get_random_remote_path() + ctx.file_access.put_data(local_path, remote_path, is_multipart=True) + return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) + + def to_python_value( + self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[tf.keras.Model] + ) -> tf.keras.Model: + try: + uri = 
lv.scalar.blob.uri + except AttributeError: + TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") + + local_path = ctx.file_access.get_random_local_path() + ctx.file_access.get_data(uri, local_path, is_multipart=True) + + # load model + return tf.keras.models.load_model(local_path) + + def guess_python_type(self, literal_type: LiteralType) -> Type[tf.keras.Model]: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.MULTIPART + and literal_type.blob.format == self.TENSORFLOW_FORMAT + ): + return tf.keras.Model + + raise ValueError(f"Transformer {self} cannot reverse {literal_type}") + + +TypeEngine.register(TensorFlowModelTransformer()) diff --git a/flytekit/extras/tensorflow/record.py b/flytekit/extras/tensorflow/record.py index d5d750b521..17e7c37ddd 100644 --- a/flytekit/extras/tensorflow/record.py +++ b/flytekit/extras/tensorflow/record.py @@ -159,7 +159,6 @@ def to_literal( def to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[TFRecordsDirectory] ) -> TFRecordDatasetV2: - uri, metadata = extract_metadata_and_uri(lv, expected_python_type) local_dir = ctx.file_access.get_random_local_directory() ctx.file_access.get_data(uri, local_dir, is_multipart=True) diff --git a/flytekit/image_spec/__init__.py b/flytekit/image_spec/__init__.py new file mode 100644 index 0000000000..ca1bdedee6 --- /dev/null +++ b/flytekit/image_spec/__init__.py @@ -0,0 +1 @@ +from .image_spec import ImageSpec diff --git a/flytekit/image_spec/image_spec.py b/flytekit/image_spec/image_spec.py new file mode 100644 index 0000000000..a0ddbc1977 --- /dev/null +++ b/flytekit/image_spec/image_spec.py @@ -0,0 +1,170 @@ +import base64 +import hashlib +import os +import typing +from abc import abstractmethod +from copy import copy +from dataclasses import dataclass +from functools import lru_cache +from typing import List, Optional + +import click +import requests +from 
dataclasses_json import dataclass_json + +DOCKER_HUB = "docker.io" +_F_IMG_ID = "_F_IMG_ID" + + +@dataclass_json +@dataclass +class ImageSpec: + """ + This class is used to specify the docker image that will be used to run the task. + + Args: + name: name of the image. + python_version: python version of the image. Use default python in the base image if None. + builder: Type of plugin to build the image. Use envd by default. + source_root: source root of the image. + env: environment variables of the image. + registry: registry of the image. + packages: list of python packages to install. + apt_packages: list of apt packages to install. + base_image: base image of the image. + """ + + name: str = "flytekit" + python_version: str = None # Use default python in the base image if None. + builder: str = "envd" + source_root: Optional[str] = None + env: Optional[typing.Dict[str, str]] = None + registry: Optional[str] = None + packages: Optional[List[str]] = None + apt_packages: Optional[List[str]] = None + base_image: Optional[str] = None + + def image_name(self) -> str: + """ + return full image name with tag. + """ + tag = calculate_hash_from_image_spec(self) + container_image = f"{self.name}:{tag}" + if self.registry: + container_image = f"{self.registry}/{container_image}" + return container_image + + def is_container(self) -> bool: + from flytekit.core.context_manager import ExecutionState, FlyteContextManager + + state = FlyteContextManager.current_context().execution_state + if state and state.mode and state.mode != ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION: + return os.environ.get(_F_IMG_ID) == self.image_name() + return True + + @lru_cache(maxsize=128) + def exist(self) -> bool: + """ + Check if the image exists in the registry. 
+ """ + import docker + from docker.errors import APIError, ImageNotFound + + try: + client = docker.from_env() + if self.registry: + client.images.get_registry_data(self.image_name()) + else: + client.images.get(self.image_name()) + return True + except APIError as e: + if e.response.status_code == 404: + return False + except ImageNotFound: + return False + except Exception as e: + tag = calculate_hash_from_image_spec(self) + # if docker engine is not running locally + container_registry = DOCKER_HUB + if "/" in self.registry: + container_registry = self.registry.split("/")[0] + if container_registry == DOCKER_HUB: + url = f"https://hub.docker.com/v2/repositories/{self.registry}/{self.name}/tags/{tag}" + response = requests.get(url) + if response.status_code == 200: + return True + + if response.status_code == 404: + return False + + click.secho(f"Failed to check if the image exists with error : {e}", fg="red") + click.secho("Flytekit assumes that the image already exists.", fg="blue") + return True + + def __hash__(self): + return hash(self.to_json()) + + +class ImageSpecBuilder: + @abstractmethod + def build_image(self, image_spec: ImageSpec): + """ + Build the docker image and push it to the registry. + + Args: + image_spec: image spec of the task. + """ + raise NotImplementedError("This method is not implemented in the base class.") + + +class ImageBuildEngine: + """ + ImageBuildEngine contains a list of builders that can be used to build an ImageSpec. + """ + + _REGISTRY: typing.Dict[str, ImageSpecBuilder] = {} + + @classmethod + def register(cls, builder_type: str, image_spec_builder: ImageSpecBuilder): + cls._REGISTRY[builder_type] = image_spec_builder + + @classmethod + def build(cls, image_spec: ImageSpec): + if image_spec.builder not in cls._REGISTRY: + raise Exception(f"Builder {image_spec.builder} is not registered.") + if not image_spec.exist(): + click.secho(f"Image {image_spec.image_name()} not found. 
Building...", fg="blue") + cls._REGISTRY[image_spec.builder].build_image(image_spec) + else: + click.secho(f"Image {image_spec.image_name()} found. Skip building.", fg="blue") + + +@lru_cache(maxsize=128) +def calculate_hash_from_image_spec(image_spec: ImageSpec): + """ + Calculate the hash from the image spec. + """ + # copy the image spec to avoid modifying the original image spec. otherwise, the hash will be different. + spec = copy(image_spec) + spec.source_root = hash_directory(image_spec.source_root) if image_spec.source_root else b"" + image_spec_bytes = bytes(image_spec.to_json(), "utf-8") + tag = base64.urlsafe_b64encode(hashlib.md5(image_spec_bytes).digest()).decode("ascii") + # replace "=" with "." to make it a valid tag + return tag.replace("=", ".") + + +def hash_directory(path): + """ + Return the SHA-256 hash of the directory at the given path. + """ + hasher = hashlib.sha256() + for root, dirs, files in os.walk(path): + for file in files: + with open(os.path.join(root, file), "rb") as f: + while True: + # Read file in small chunks to avoid loading large files into memory all at once + chunk = f.read(4096) + if not chunk: + break + hasher.update(chunk) + return bytes(hasher.hexdigest(), "utf-8") diff --git a/flytekit/lazy_import/__init__.py b/flytekit/lazy_import/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flytekit/lazy_import/lazy_module.py b/flytekit/lazy_import/lazy_module.py new file mode 100644 index 0000000000..553386eb52 --- /dev/null +++ b/flytekit/lazy_import/lazy_module.py @@ -0,0 +1,33 @@ +import importlib.util +import sys + +LAZY_MODULES = [] + + +def is_imported(module_name): + """ + This function is used to check if a module has been imported by the regular import. + """ + return module_name in sys.modules and module_name not in LAZY_MODULES + + +def lazy_module(fullname): + """ + This function is used to lazily import modules. It is used in the following way: + .. 
code-block:: python + from flytekit.lazy_import import lazy_module + sklearn = lazy_module("sklearn") + sklearn.svm.SVC() + :param Text fullname: The full name of the module to import + """ + if fullname in sys.modules: + return sys.modules[fullname] + # https://docs.python.org/3/library/importlib.html#implementing-lazy-imports + spec = importlib.util.find_spec(fullname) + loader = importlib.util.LazyLoader(spec.loader) + spec.loader = loader + module = importlib.util.module_from_spec(spec) + sys.modules[fullname] = module + LAZY_MODULES.append(module) + loader.exec_module(module) + return module diff --git a/flytekit/loggers.py b/flytekit/loggers.py index f047348de0..fdc3c75d3a 100644 --- a/flytekit/loggers.py +++ b/flytekit/loggers.py @@ -2,6 +2,8 @@ import os from pythonjsonlogger import jsonlogger +from rich.console import Console +from rich.logging import RichHandler # Note: # The environment variable controls exposed to affect the individual loggers should be considered to be beta. @@ -10,6 +12,7 @@ # For now, assume this is the environment variable whose usage will remain unchanged and controls output for all # loggers defined in this file. 
LOGGING_ENV_VAR = "FLYTE_SDK_LOGGING_LEVEL" +LOGGING_FMT_ENV_VAR = "FLYTE_SDK_LOGGING_FORMAT" # By default, the root flytekit logger to debug so everything is logged, but enable fine-tuning logger = logging.getLogger("flytekit") @@ -33,8 +36,18 @@ user_space_logger = child_loggers["user_space"] # create console handler -ch = logging.StreamHandler() -ch.setLevel(logging.DEBUG) +try: + handler = RichHandler( + rich_tracebacks=True, + omit_repeated_times=False, + keywords=["[flytekit]"], + log_time_format="%Y-%m-%d %H:%M:%S,%f", + console=Console(width=os.get_terminal_size().columns), + ) +except OSError: + handler = logging.StreamHandler() + +handler.setLevel(logging.DEBUG) # Root logger control # Don't want to import the configuration library since that will cause all sorts of circular imports, let's @@ -63,10 +76,14 @@ child_logger.setLevel(logging.WARNING) # create formatter -formatter = jsonlogger.JsonFormatter(fmt="%(asctime)s %(name)s %(levelname)s %(message)s") +logging_fmt = os.environ.get(LOGGING_FMT_ENV_VAR, "json") +if logging_fmt == "json": + formatter = jsonlogger.JsonFormatter(fmt="%(asctime)s %(name)s %(levelname)s %(message)s") +else: + formatter = logging.Formatter(fmt="[%(name)s] %(message)s") -# add formatter to ch -ch.setFormatter(formatter) +# add formatter to the handler +handler.setFormatter(formatter) # add ch to logger -logger.addHandler(ch) +logger.addHandler(handler) diff --git a/flytekit/models/common.py b/flytekit/models/common.py index 62018c1eef..4f030e25a4 100644 --- a/flytekit/models/common.py +++ b/flytekit/models/common.py @@ -1,5 +1,6 @@ import abc as _abc import json as _json +import re from flyteidl.admin import common_pb2 as _common_pb2 from google.protobuf import json_format as _json_format @@ -57,7 +58,8 @@ def short_string(self): """ :rtype: Text """ - return str(self.to_flyte_idl()) + literal_str = re.sub(r"\s+", " ", str(self.to_flyte_idl())).strip() + return f"" def verbose_string(self): """ diff --git 
a/flytekit/models/security.py b/flytekit/models/security.py index 7babb859e4..9af90a4b8a 100644 --- a/flytekit/models/security.py +++ b/flytekit/models/security.py @@ -36,15 +36,13 @@ class MountType(Enum): """ group: str - key: str + key: Optional[str] = None group_version: Optional[str] = None mount_requirement: MountType = MountType.ANY def __post_init__(self): if self.group is None: raise ValueError("Group is a required parameter") - if self.key is None: - raise ValueError("Key is also a required parameter") def to_flyte_idl(self) -> _sec.Secret: return _sec.Secret( @@ -59,7 +57,7 @@ def from_flyte_idl(cls, pb2_object: _sec.Secret) -> "Secret": return cls( group=pb2_object.group, group_version=pb2_object.group_version if pb2_object.group_version else None, - key=pb2_object.key, + key=pb2_object.key if pb2_object.key else None, mount_requirement=Secret.MountType(pb2_object.mount_requirement), ) diff --git a/flytekit/models/task.py b/flytekit/models/task.py index fc79c87a2d..f7f1d710c9 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -868,12 +868,18 @@ def from_flyte_idl(cls, pb2_object: _core_task.K8sObjectMetadata): class K8sPod(_common.FlyteIdlEntity): - def __init__(self, metadata: K8sObjectMetadata = None, pod_spec: typing.Dict[str, typing.Any] = None): + def __init__( + self, + metadata: K8sObjectMetadata = None, + pod_spec: typing.Dict[str, typing.Any] = None, + data_config: typing.Optional[DataLoadingConfig] = None, + ): """ This defines a kubernetes pod target. 
It will build the pod target during task execution """ self._metadata = metadata self._pod_spec = pod_spec + self._data_config = data_config @property def metadata(self) -> K8sObjectMetadata: @@ -883,10 +889,15 @@ def metadata(self) -> K8sObjectMetadata: def pod_spec(self) -> typing.Dict[str, typing.Any]: return self._pod_spec + @property + def data_config(self) -> typing.Optional[DataLoadingConfig]: + return self._data_config + def to_flyte_idl(self) -> _core_task.K8sPod: return _core_task.K8sPod( metadata=self._metadata.to_flyte_idl(), pod_spec=_json_format.Parse(_json.dumps(self.pod_spec), _struct.Struct()) if self.pod_spec else None, + data_config=self.data_config.to_flyte_idl() if self.data_config else None, ) @classmethod @@ -894,6 +905,9 @@ def from_flyte_idl(cls, pb2_object: _core_task.K8sPod): return cls( metadata=K8sObjectMetadata.from_flyte_idl(pb2_object.metadata), pod_spec=_json_format.MessageToDict(pb2_object.pod_spec) if pb2_object.HasField("pod_spec") else None, + data_config=DataLoadingConfig.from_flyte_idl(pb2_object.data_config) + if pb2_object.HasField("data_config") + else None, ) diff --git a/flytekit/models/types.py b/flytekit/models/types.py index 4358d7229e..3e3c778d6b 100644 --- a/flytekit/models/types.py +++ b/flytekit/models/types.py @@ -259,7 +259,7 @@ def __init__( :param flytekit.models.core.types.StructuredDatasetType structured_dataset_type: structured dataset :param dict[Text, T] metadata: Additional data describing the type :param flytekit.models.annotation.TypeAnnotation annotation: Additional data - describing the type _intended to be saturated by the client_ + describing the type intended to be saturated by the client """ self._simple = simple self._schema = schema diff --git a/flytekit/remote/backfill.py b/flytekit/remote/backfill.py index 154bf4d1b4..2f31889060 100644 --- a/flytekit/remote/backfill.py +++ b/flytekit/remote/backfill.py @@ -68,6 +68,8 @@ def create_backfill_workflow( logging.info(f"Generating backfill from 
{start_date} -> {end_date}. Parallel?[{parallel}]") wf = ImperativeWorkflow(name=f"backfill-{for_lp.name}") + + input_name = schedule.kickoff_time_input_arg date_iter = croniter(cron_schedule.schedule, start_time=start_date, ret_type=datetime) prev_node = None actual_start = None @@ -79,7 +81,10 @@ def create_backfill_workflow( if next_start_date >= end_date: break actual_end = next_start_date - next_node = wf.add_launch_plan(for_lp, t=next_start_date) + inputs = {} + if input_name: + inputs[input_name] = next_start_date + next_node = wf.add_launch_plan(for_lp, **inputs) next_node = next_node.with_overrides( name=f"b-{next_start_date}", retries=per_node_retries, timeout=per_node_timeout ) diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 03cc9a66e9..91189ede74 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -6,17 +6,19 @@ from __future__ import annotations import base64 -import functools import hashlib import os import pathlib +import tempfile import time import typing import uuid +from base64 import b64encode from collections import OrderedDict from dataclasses import asdict, dataclass from datetime import datetime, timedelta +import requests from flyteidl.admin.signal_pb2 import Signal, SignalListRequest, SignalSetRequest from flyteidl.core import literals_pb2 as literals_pb2 @@ -34,7 +36,11 @@ from flytekit.core.type_engine import LiteralsResolver, TypeEngine from flytekit.core.workflow import WorkflowBase from flytekit.exceptions import user as user_exceptions -from flytekit.exceptions.user import FlyteEntityAlreadyExistsException, FlyteEntityNotExistException +from flytekit.exceptions.user import ( + FlyteEntityAlreadyExistsException, + FlyteEntityNotExistException, + FlyteValueException, +) from flytekit.loggers import remote_logger from flytekit.models import common as common_models from flytekit.models import filters as filter_models @@ -62,7 +68,7 @@ from flytekit.remote.lazy_entity import LazyEntity from 
flytekit.remote.remote_callable import RemoteEntity from flytekit.tools.fast_registration import fast_package -from flytekit.tools.script_mode import fast_register_single_script, hash_file +from flytekit.tools.script_mode import compress_scripts, hash_file from flytekit.tools.translator import ( FlyteControlPlaneEntity, FlyteLocalEntity, @@ -615,6 +621,10 @@ def _serialize_and_register( version=version, ) is_dummy_serialization_setting = True + + if serialization_settings.version is None: + serialization_settings.version = version + _ = get_serializable(m, settings=serialization_settings, entity=entity, options=options) ident = None @@ -704,9 +714,9 @@ def fast_package(self, root: os.PathLike, deref_symlinks: bool = True, output: s md5_bytes, _ = hash_file(pathlib.Path(zip_file)) # Upload zip file to Admin using FlyteRemote. - return self._upload_file(pathlib.Path(zip_file)) + return self.upload_file(pathlib.Path(zip_file)) - def _upload_file( + def upload_file( self, to_upload: pathlib.Path, project: typing.Optional[str] = None, domain: typing.Optional[str] = None ) -> typing.Tuple[bytes, str]: """ @@ -728,7 +738,23 @@ def _upload_file( content_md5=md5_bytes, filename=to_upload.name, ) - self._ctx.file_access.put_data(str(to_upload), upload_location.signed_url) + + encoded_md5 = b64encode(md5_bytes) + with open(str(to_upload), "+rb") as local_file: + content = local_file.read() + content_length = len(content) + rsp = requests.put( + upload_location.signed_url, + data=content, + headers={"Content-Length": str(content_length), "Content-MD5": encoded_md5}, + ) + + if rsp.status_code != requests.codes["OK"]: + raise FlyteValueException( + rsp.status_code, + f"Request to send data {upload_location.signed_url} failed.", + ) + remote_logger.debug( f"Uploading {to_upload} to {upload_location.signed_url} native url {upload_location.native_url}" ) @@ -773,17 +799,19 @@ def register_script( project: typing.Optional[str] = None, domain: typing.Optional[str] = None, 
destination_dir: str = ".", - default_launch_plan: typing.Optional[bool] = True, + copy_all: bool = False, + default_launch_plan: bool = True, options: typing.Optional[Options] = None, source_path: typing.Optional[str] = None, module_name: typing.Optional[str] = None, ) -> typing.Union[FlyteWorkflow, FlyteTask]: """ Use this method to register a workflow via script mode. - :param destination_dir: - :param domain: - :param project: - :param image_config: + :param destination_dir: The destination directory where the workflow will be copied to. + :param copy_all: If true, the entire source directory will be copied over to the destination directory. + :param domain: The domain to register the workflow in. + :param project: The project to register the workflow in. + :param image_config: The image config to use for the workflow. :param version: version for the entity to be registered as :param entity: The workflow to be registered or the task to be registered :param default_launch_plan: This should be true if a default launch plan should be created for the workflow @@ -795,16 +823,16 @@ def register_script( if image_config is None: image_config = ImageConfig.auto_default_image() - upload_location, md5_bytes = fast_register_single_script( - source_path, - module_name, - functools.partial( - self.client.get_upload_signed_url, - project=project or self.default_project, - domain=domain or self.default_domain, - filename="scriptmode.tar.gz", - ), - ) + with tempfile.TemporaryDirectory() as tmp_dir: + if copy_all: + md5_bytes, upload_native_url = self.fast_package(pathlib.Path(source_path), False, tmp_dir) + else: + archive_fname = pathlib.Path(os.path.join(tmp_dir, "script_mode.tar.gz")) + compress_scripts(source_path, str(archive_fname), module_name) + md5_bytes, upload_native_url = self.upload_file( + archive_fname, project or self.default_project, domain or self.default_domain + ) + serialization_settings = SerializationSettings( project=project, domain=domain, @@ -813,7 
+841,7 @@ def register_script( fast_serialization_settings=FastSerializationSettings( enabled=True, destination_dir=destination_dir, - distribution_location=upload_location.native_url, + distribution_location=upload_native_url, ), ) diff --git a/flytekit/tools/fast_registration.py b/flytekit/tools/fast_registration.py index b2f7efcc65..f115480112 100644 --- a/flytekit/tools/fast_registration.py +++ b/flytekit/tools/fast_registration.py @@ -12,6 +12,7 @@ import click from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.utils import timeit from flytekit.tools.ignore import DockerIgnore, GitIgnore, IgnoreGroup, StandardIgnore from flytekit.tools.script_mode import tar_strip_file_attributes @@ -97,6 +98,7 @@ def get_additional_distribution_loc(remote_location: str, identifier: str) -> st return posixpath.join(remote_location, "{}.{}".format(identifier, "tar.gz")) +@timeit("Download distribution") def download_distribution(additional_distribution: str, destination: str): """ Downloads a remote code distribution and overwrites any local files. diff --git a/flytekit/tools/repo.py b/flytekit/tools/repo.py index 3c9fe64068..82d4c2c226 100644 --- a/flytekit/tools/repo.py +++ b/flytekit/tools/repo.py @@ -37,7 +37,7 @@ def serialize( :param pkgs: Dot-delimited Python packages/subpackages to look into for serialization. :param local_source_root: Where to start looking for the code. """ - + settings.source_root = local_source_root ctx = FlyteContextManager.current_context().with_serialization_settings(settings) with FlyteContextManager.with_context(ctx) as ctx: # Scan all modules. the act of loading populates the global singleton that contains all objects @@ -60,6 +60,8 @@ def serialize_to_folder( """ Serialize the given set of python packages to a folder """ + if folder is None: + folder = "." 
loaded_entities = serialize(pkgs, settings, local_source_root, options=options) persist_registrable_entities(loaded_entities, folder) diff --git a/flytekit/tools/script_mode.py b/flytekit/tools/script_mode.py index 29b617824c..ecc71a2398 100644 --- a/flytekit/tools/script_mode.py +++ b/flytekit/tools/script_mode.py @@ -8,13 +8,12 @@ import typing from pathlib import Path -from flyteidl.service import dataproxy_pb2 as _data_proxy_pb2 - -from flytekit.core import context_manager +from flytekit import PythonFunctionTask from flytekit.core.tracker import get_full_module_path +from flytekit.core.workflow import ImperativeWorkflow, WorkflowBase -def compress_single_script(source_path: str, destination: str, full_module_name: str): +def compress_scripts(source_path: str, destination: str, module_name: str): """ Compresses the single script while maintaining the folder structure for that file. @@ -39,33 +38,14 @@ def compress_single_script(source_path: str, destination: str, full_module_name: │   ├── example.py │   └── __init__.py - Note how `another_example.py` and `yet_another_example.py` were not copied to the destination. + Note: If `example.py` didn't import tasks or workflows from `another_example.py` and `yet_another_example.py`, these files were not copied to the destination.. + """ with tempfile.TemporaryDirectory() as tmp_dir: destination_path = os.path.join(tmp_dir, "code") - # This is the script relative path to the root of the project - script_relative_path = Path() - # For each package in pkgs, create a directory and copy the __init__.py in it. - # Skip the last package as that is the script file. 
- pkgs = full_module_name.split(".") - for p in pkgs[:-1]: - os.makedirs(os.path.join(destination_path, p)) - source_path = os.path.join(source_path, p) - destination_path = os.path.join(destination_path, p) - script_relative_path = Path(script_relative_path, p) - init_file = Path(os.path.join(source_path, "__init__.py")) - if init_file.exists(): - shutil.copy(init_file, Path(os.path.join(tmp_dir, "code", script_relative_path, "__init__.py"))) - - # Ensure destination path exists to cover the case of a single file and no modules. - os.makedirs(destination_path, exist_ok=True) - script_file = Path(source_path, f"{pkgs[-1]}.py") - script_file_destination = Path(destination_path, f"{pkgs[-1]}.py") - # Build the final script relative path and copy it to a known place. - shutil.copy( - script_file, - script_file_destination, - ) + + visited: typing.List[str] = [] + copy_module_to_destination(source_path, destination_path, module_name, visited) tar_path = os.path.join(tmp_dir, "tmp.tar") with tarfile.open(tar_path, "w") as tar: tar.add(os.path.join(tmp_dir, "code"), arcname="", filter=tar_strip_file_attributes) @@ -74,6 +54,54 @@ def compress_single_script(source_path: str, destination: str, full_module_name: gzipped.write(tar_file.read()) +def copy_module_to_destination( + original_source_path: str, original_destination_path: str, module_name: str, visited: typing.List[str] +): + """ + Copy the module (file) to the destination directory. If the module relative imports other modules, flytekit will + recursively copy them as well. 
+ """ + mod = importlib.import_module(module_name) + full_module_name = get_full_module_path(mod, mod.__name__) + if full_module_name in visited: + return + visited.append(full_module_name) + + source_path = original_source_path + destination_path = original_destination_path + pkgs = full_module_name.split(".") + + for p in pkgs[:-1]: + os.makedirs(os.path.join(destination_path, p), exist_ok=True) + destination_path = os.path.join(destination_path, p) + source_path = os.path.join(source_path, p) + init_file = Path(os.path.join(source_path, "__init__.py")) + if init_file.exists(): + shutil.copy(init_file, Path(os.path.join(destination_path, "__init__.py"))) + + # Ensure destination path exists to cover the case of a single file and no modules. + os.makedirs(destination_path, exist_ok=True) + script_file = Path(source_path, f"{pkgs[-1]}.py") + script_file_destination = Path(destination_path, f"{pkgs[-1]}.py") + # Build the final script relative path and copy it to a known place. + shutil.copy( + script_file, + script_file_destination, + ) + + # Try to copy other files to destination if tasks or workflows aren't in the same file + for flyte_entity_name in mod.__dict__: + flyte_entity = mod.__dict__[flyte_entity_name] + if ( + isinstance(flyte_entity, (PythonFunctionTask, WorkflowBase)) + and not isinstance(flyte_entity, ImperativeWorkflow) + and flyte_entity.instantiated_in + ): + copy_module_to_destination( + original_source_path, original_destination_path, flyte_entity.instantiated_in, visited + ) + + # Takes in a TarInfo and returns the modified TarInfo: # https://docs.python.org/3/library/tarfile.html#tarinfo-objects # intented to be passed as a filter to tarfile.add @@ -96,24 +124,6 @@ def tar_strip_file_attributes(tar_info: tarfile.TarInfo) -> tarfile.TarInfo: return tar_info -def fast_register_single_script( - source_path: str, module_name: str, create_upload_location_fn: typing.Callable -) -> (_data_proxy_pb2.CreateUploadLocationResponse, bytes): - - # Open a 
temp directory and dump the contents of the digest. - with tempfile.TemporaryDirectory() as tmp_dir: - archive_fname = os.path.join(tmp_dir, "script_mode.tar.gz") - mod = importlib.import_module(module_name) - compress_single_script(source_path, archive_fname, get_full_module_path(mod, mod.__name__)) - - flyte_ctx = context_manager.FlyteContextManager.current_context() - md5, _ = hash_file(archive_fname) - upload_location = create_upload_location_fn(content_md5=md5) - flyte_ctx.file_access.put_data(archive_fname, upload_location.signed_url) - - return upload_location, md5 - - def hash_file(file_path: typing.Union[os.PathLike, str]) -> (bytes, str): """ Hash a file and produce a digest to be used as a version @@ -131,7 +141,7 @@ def hash_file(file_path: typing.Union[os.PathLike, str]) -> (bytes, str): return h.digest(), h.hexdigest() -def _find_project_root(source_path) -> Path: +def _find_project_root(source_path) -> str: """ Find the root of the project. The root of the project is considered to be the first ancestor from source_path that does @@ -143,4 +153,4 @@ def _find_project_root(source_path) -> Path: path = Path(source_path).parent.resolve() while os.path.exists(os.path.join(path, "__init__.py")): path = path.parent - return path + return str(path) diff --git a/flytekit/tools/serialize_helpers.py b/flytekit/tools/serialize_helpers.py index 69af2b96b4..86a029d411 100644 --- a/flytekit/tools/serialize_helpers.py +++ b/flytekit/tools/serialize_helpers.py @@ -10,12 +10,10 @@ from flytekit.core import context_manager as flyte_context from flytekit.core.base_task import PythonTask from flytekit.core.workflow import WorkflowBase -from flytekit.exceptions.user import FlyteValidationException from flytekit.models import launch_plan as _launch_plan_models from flytekit.models import task as task_models from flytekit.models.admin import workflow as admin_workflow_models from flytekit.models.admin.workflow import WorkflowSpec -from flytekit.models.core import identifier 
as _identifier from flytekit.models.task import TaskSpec from flytekit.remote.remote_callable import RemoteEntity from flytekit.tools.translator import FlyteControlPlaneEntity, Options, get_serializable @@ -44,20 +42,6 @@ def _should_register_with_admin(entity) -> bool: ) and not isinstance(entity, RemoteEntity) -def _find_duplicate_tasks(tasks: typing.List[task_models.TaskSpec]) -> typing.Set[task_models.TaskSpec]: - """ - Given a list of `TaskSpec`, this function returns a set containing the duplicated `TaskSpec` if any exists. - """ - seen: typing.Set[_identifier.Identifier] = set() - duplicate_tasks: typing.Set[task_models.TaskSpec] = set() - for task in tasks: - if task.template.id not in seen: - seen.add(task.template.id) - else: - duplicate_tasks.add(task) - return duplicate_tasks - - def get_registrable_entities( ctx: flyte_context.FlyteContext, options: typing.Optional[Options] = None ) -> typing.List[FlyteControlPlaneEntity]: @@ -78,19 +62,6 @@ def get_registrable_entities( new_api_model_values = list(new_api_serializable_entities.values()) entities_to_be_serialized = list(filter(_should_register_with_admin, new_api_model_values)) - serializable_tasks: typing.List[task_models.TaskSpec] = [ - entity for entity in entities_to_be_serialized if isinstance(entity, task_models.TaskSpec) - ] - # Detect if any of the tasks is duplicated. Duplicate tasks are defined as having the same - # metadata identifiers (see :py:class:`flytekit.common.core.identifier.Identifier`). Duplicate - # tasks are considered invalid at registration - # time and usually indicate user error, so we catch this common mistake at serialization time. 
- duplicate_tasks = _find_duplicate_tasks(serializable_tasks) - if len(duplicate_tasks) > 0: - duplicate_task_names = [task.template.id.name for task in duplicate_tasks] - raise FlyteValidationException( - f"Multiple definitions of the following tasks were found: {duplicate_task_names}" - ) return entities_to_be_serialized diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index 5ec249fa4b..b2835dca10 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -9,6 +9,7 @@ from flytekit.core import constants as _common_constants from flytekit.core.base_task import PythonTask from flytekit.core.condition import BranchNode +from flytekit.core.container_task import ContainerTask from flytekit.core.gate import Gate from flytekit.core.launch_plan import LaunchPlan, ReferenceLaunchPlan from flytekit.core.map_task import MapPythonTask @@ -189,7 +190,7 @@ def get_serializable_task( # If the pod spec is not None, we have to get it again, because the one we retrieved above will be incorrect. # The reason we have to call get_k8s_pod again, instead of just modifying the command in this file, is because # the pod spec is a K8s library object, and we shouldn't be messing around with it in this file. 
- elif pod: + elif pod and not isinstance(entity, ContainerTask): if isinstance(entity, MapPythonTask): entity.set_command_prefix(get_command_prefix_for_fast_execute(settings)) pod = entity.get_k8s_pod(settings) @@ -662,11 +663,6 @@ def get_serializable( elif isinstance(entity, BranchNode): cp_entity = get_serializable_branch_node(entity_mapping, settings, entity, options) - elif isinstance(entity, GateNode): - import ipdb - - ipdb.set_trace() - elif isinstance(entity, FlyteTask) or isinstance(entity, FlyteWorkflow): if entity.should_register: if isinstance(entity, FlyteTask): diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index afb59d58d0..b31bfc855c 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -2,20 +2,25 @@ import os import pathlib +import random import typing from dataclasses import dataclass, field from pathlib import Path +from typing import Any, Generator, Tuple +from uuid import UUID +import fsspec from dataclasses_json import config, dataclass_json +from fsspec.utils import get_protocol from marshmallow import fields -from flytekit.core.context_manager import FlyteContext +from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import TypeEngine, TypeTransformer from flytekit.models import types as _type_models from flytekit.models.core import types as _core_types from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType -from flytekit.types.file import FileExt +from flytekit.types.file import FileExt, FlyteFile T = typing.TypeVar("T") PathType = typing.Union[str, os.PathLike] @@ -143,6 +148,18 @@ def __fspath__(self): def extension(cls) -> str: return "" + @classmethod + def new_remote(cls) -> FlyteDirectory: + """ + Create a new FlyteDirectory object using the currently configured default remote in the context (i.e. 
+ the raw_output_prefix configured in the current FileAccessProvider object in the context). + This is used if you explicitly have a folder somewhere that you want to create files under. + If you want to write a whole folder, you can let your task return a FlyteDirectory object, + and let flytekit handle the uploading. + """ + d = FlyteContext.current_context().file_access.get_random_remote_directory() + return FlyteDirectory(path=d) + def __class_getitem__(cls, item: typing.Union[typing.Type, str]) -> typing.Type[FlyteDirectory]: if item is None: return cls @@ -171,6 +188,12 @@ def downloaded(self) -> bool: def remote_directory(self) -> typing.Optional[str]: return self._remote_directory + @property + def sep(self) -> str: + if os.name == "nt" and get_protocol(self.path or self.remote_source or self.remote_directory) == "file": + return "\\" + return "/" + @property def remote_source(self) -> str: """ @@ -179,9 +202,67 @@ def remote_source(self) -> str: """ return typing.cast(str, self._remote_source) + def new_file(self, name: typing.Optional[str] = None) -> FlyteFile: + """ + This will create a new file under the current folder. + If given a name, it will use the name given, otherwise it'll pick a random string. + Collisions are not checked. + """ + # TODO we may want to use - https://github.com/fsspec/universal_pathlib + if not name: + name = UUID(int=random.getrandbits(128)).hex + new_path = self.sep.join([str(self.path).rstrip(self.sep), name]) # trim trailing sep if any and join + return FlyteFile(path=new_path) + + def new_dir(self, name: typing.Optional[str] = None) -> FlyteDirectory: + """ + This will create a new folder under the current folder. + If given a name, it will use the name given, otherwise it'll pick a random string. + Collisions are not checked. 
+ """ + if not name: + name = UUID(int=random.getrandbits(128)).hex + + new_path = self.sep.join([str(self.path).rstrip(self.sep), name]) # trim trailing sep if any and join + return FlyteDirectory(path=new_path) + def download(self) -> str: return self.__fspath__() + def crawl( + self, maxdepth: typing.Optional[int] = None, topdown: bool = True, **kwargs + ) -> Generator[Tuple[typing.Union[str, os.PathLike[Any]], typing.Dict[Any, Any]], None, None]: + """ + Crawl returns a generator of all files prefixed by any sub-folders under the given "FlyteDirectory". + if details=True is passed, then it will return a dictionary as specified by fsspec. + + Example: + + >>> list(fd.crawl()) + [("/base", "file1"), ("/base", "dir1/file1"), ("/base", "dir2/file1"), ("/base", "dir1/dir/file1")] + + >>> list(x.crawl(detail=True)) + [('/tmp/test', {'my-dir/ab.py': {'name': '/tmp/test/my-dir/ab.py', 'size': 0, 'type': 'file', + 'created': 1677720780.2318847, 'islink': False, 'mode': 33188, 'uid': 501, 'gid': 0, + 'mtime': 1677720780.2317934, 'ino': 1694329, 'nlink': 1}})] + """ + final_path = self.path + if self.remote_source: + final_path = self.remote_source + elif self.remote_directory: + final_path = self.remote_directory + ctx = FlyteContextManager.current_context() + fs = ctx.file_access.get_filesystem_for_path(final_path) + base_path_len = len(fsspec.core.strip_protocol(final_path)) + 1 # Add additional `/` at the end + for base, _, files in fs.walk(final_path, maxdepth, topdown, **kwargs): + current_base = base[base_path_len:] + if isinstance(files, dict): + for f, v in files.items(): + yield final_path, {os.path.join(current_base, f): v} + else: + for f in files: + yield final_path, os.path.join(current_base, f) + def __repr__(self): return self.path diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index 9fc55f76ce..9508dee2e2 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -3,12 +3,14 @@ import os import pathlib import 
typing +from contextlib import contextmanager from dataclasses import dataclass, field from dataclasses_json import config, dataclass_json from marshmallow import fields +from typing_extensions import Annotated, get_args, get_origin -from flytekit.core.context_manager import FlyteContext +from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError from flytekit.loggers import logger from flytekit.models.core.types import BlobType @@ -27,7 +29,9 @@ def noop(): @dataclass_json @dataclass class FlyteFile(os.PathLike, typing.Generic[T]): - path: typing.Union[str, os.PathLike] = field(default=None, metadata=config(mm_field=fields.String())) # type: ignore + path: typing.Union[str, os.PathLike] = field( + default=None, metadata=config(mm_field=fields.String()) + ) # type: ignore """ Since there is no native Python implementation of files and directories for the Flyte Blob type, (like how int exists for Flyte's Integer type) we need to create one so that users can express that their tasks take @@ -148,6 +152,15 @@ def t2() -> flytekit_typing.FlyteFile["csv"]: def extension(cls) -> str: return "" + @classmethod + def new_remote_file(cls, name: typing.Optional[str] = None) -> FlyteFile: + """ + Create a new FlyteFile object with a remote path. + """ + ctx = FlyteContextManager.current_context() + remote_path = ctx.file_access.get_random_remote_path(name) + return cls(path=remote_path) + def __class_getitem__(cls, item: typing.Union[str, typing.Type]) -> typing.Type[FlyteFile]: from . import FileExt @@ -226,6 +239,57 @@ def remote_source(self) -> str: def download(self) -> str: return self.__fspath__() + @contextmanager + def open( + self, + mode: str, + cache_type: typing.Optional[str] = None, + cache_options: typing.Optional[typing.Dict[str, typing.Any]] = None, + ): + """ + Returns a streaming File handle + + .. 
code-block:: python + + @task + def copy_file(ff: FlyteFile) -> FlyteFile: + new_file = FlyteFile.new_remote_file(ff.name) + with ff.open("rb", cache_type="readahead", cache={}) as r: + with new_file.open("wb") as w: + w.write(r.read()) + return new_file + + Alternatively + + .. code-block:: python + + @task + def copy_file(ff: FlyteFile) -> FlyteFile: + new_file = FlyteFile.new_remote_file(ff.name) + with fsspec.open(f"readahead::{ff.remote_path}", "rb", readahead={}) as r: + with new_file.open("wb") as w: + w.write(r.read()) + return new_file + + + :param mode: str Open mode like 'rb', 'rt', 'wb', ... + :param cache_type: optional str Specify if caching is to be used. Cache protocol can be ones supported by + fsspec https://filesystem-spec.readthedocs.io/en/latest/api.html#readbuffering, + especially useful for large file reads + :param cache_options: optional Dict[str, Any] Refer to fsspec caching options. This is strongly coupled to the + cache_protocol + """ + ctx = FlyteContextManager.current_context() + final_path = self.path + if self.remote_source: + final_path = self.remote_source + elif self.remote_path: + final_path = self.remote_path + fs = ctx.file_access.get_filesystem_for_path(final_path) + f = fs.open(final_path, mode, cache_type=cache_type, cache_options=cache_options) + yield f + f.close() + def __repr__(self): return self.path @@ -272,6 +336,10 @@ def to_literal( if python_val is None: raise TypeTransformerFailedError("None value cannot be converted to a file.") + # Correctly handle `Annotated[FlyteFile, ...]` by extracting the origin type + if get_origin(python_type) is Annotated: + python_type = get_args(python_type)[0] + if not (python_type is os.PathLike or issubclass(python_type, FlyteFile)): raise ValueError(f"Incorrect type {python_type}, must be either a FlyteFile or os.PathLike") diff --git a/flytekit/types/pickle/__init__.py b/flytekit/types/pickle/__init__.py index 65604e67bb..e5bd1c056d 100644 --- a/flytekit/types/pickle/__init__.py 
+++ b/flytekit/types/pickle/__init__.py @@ -9,4 +9,4 @@ FlytePickle """ -from .pickle import FlytePickle +from .pickle import BatchSize, FlytePickle diff --git a/flytekit/types/pickle/pickle.py b/flytekit/types/pickle/pickle.py index 3472dec7e6..3de75b765b 100644 --- a/flytekit/types/pickle/pickle.py +++ b/flytekit/types/pickle/pickle.py @@ -13,6 +13,19 @@ T = typing.TypeVar("T") +class BatchSize: + """ + Flyte-specific object used to wrap the hash function for a specific type + """ + + def __init__(self, val: int): + self._val = val + + @property + def val(self) -> int: + return self._val + + class FlytePickle(typing.Generic[T]): """ This type is only used by flytekit internally. User should not use this type. diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index 8a8d832b58..7d043910d4 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -180,7 +180,6 @@ class FlyteSchema(object): """ This is the main schema class that users should use. 
""" - logger.warning("FlyteSchema is deprecated, use Structured Dataset instead.") @classmethod def columns(cls) -> typing.Dict[str, typing.Type]: @@ -197,6 +196,7 @@ def format(cls) -> SchemaFormat: def __class_getitem__( cls, columns: typing.Dict[str, typing.Type], fmt: SchemaFormat = SchemaFormat.PARQUET ) -> Type[FlyteSchema]: + logger.warning("FlyteSchema is deprecated, use Structured Dataset instead.") if columns is None: return FlyteSchema @@ -234,6 +234,7 @@ def __init__( supported_mode: SchemaOpenMode = SchemaOpenMode.WRITE, downloader: typing.Callable[[str, os.PathLike], None] = None, ): + logger.warning("FlyteSchema is deprecated, use Structured Dataset instead.") if supported_mode == SchemaOpenMode.READ and remote_path is None: raise ValueError("To create a FlyteSchema in read mode, remote_path is required") if ( diff --git a/flytekit/types/structured/__init__.py b/flytekit/types/structured/__init__.py index 52577a650d..86fa19f4f0 100644 --- a/flytekit/types/structured/__init__.py +++ b/flytekit/types/structured/__init__.py @@ -13,15 +13,9 @@ """ -from flytekit.configuration.internal import LocalSDK +from flytekit.deck.renderer import ArrowRenderer, TopFrameRenderer from flytekit.loggers import logger -from .basic_dfs import ( - ArrowToParquetEncodingHandler, - PandasToParquetEncodingHandler, - ParquetToArrowDecodingHandler, - ParquetToPandasDecodingHandler, -) from .structured_dataset import ( StructuredDataset, StructuredDatasetDecoder, @@ -29,15 +23,42 @@ StructuredDatasetTransformerEngine, ) -try: - from .bigquery import ( - ArrowToBQEncodingHandlers, - BQToArrowDecodingHandler, - BQToPandasDecodingHandler, - PandasToBQEncodingHandlers, - ) -except ImportError: - logger.info( - "We won't register bigquery handler for structured dataset because " - "we can't find the packages google-cloud-bigquery-storage and google-cloud-bigquery" - ) + +def register_pandas_handlers(): + import pandas as pd + + from .basic_dfs import PandasToParquetEncodingHandler, 
ParquetToPandasDecodingHandler + + StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(), default_format_for_type=True) + StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(), default_format_for_type=True) + StructuredDatasetTransformerEngine.register_renderer(pd.DataFrame, TopFrameRenderer()) + + +def register_arrow_handlers(): + import pyarrow as pa + + from .basic_dfs import ArrowToParquetEncodingHandler, ParquetToArrowDecodingHandler + + StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(), default_format_for_type=True) + StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(), default_format_for_type=True) + StructuredDatasetTransformerEngine.register_renderer(pa.Table, ArrowRenderer()) + + +def register_bigquery_handlers(): + try: + from .bigquery import ( + ArrowToBQEncodingHandlers, + BQToArrowDecodingHandler, + BQToPandasDecodingHandler, + PandasToBQEncodingHandlers, + ) + + StructuredDatasetTransformerEngine.register(PandasToBQEncodingHandlers()) + StructuredDatasetTransformerEngine.register(BQToPandasDecodingHandler()) + StructuredDatasetTransformerEngine.register(ArrowToBQEncodingHandlers()) + StructuredDatasetTransformerEngine.register(BQToArrowDecodingHandler()) + except ImportError: + logger.info( + "We won't register bigquery handler for structured dataset because " + "we can't find the packages google-cloud-bigquery-storage and google-cloud-bigquery" + ) diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py index 39f8d11e24..8004867271 100644 --- a/flytekit/types/structured/basic_dfs.py +++ b/flytekit/types/structured/basic_dfs.py @@ -1,14 +1,18 @@ import os import typing +from pathlib import Path from typing import TypeVar import pandas as pd import pyarrow as pa import pyarrow.parquet as pq +from botocore.exceptions import NoCredentialsError +from fsspec.core import split_protocol, strip_protocol +from fsspec.utils 
import get_protocol -from flytekit import FlyteContext -from flytekit.deck import TopFrameRenderer -from flytekit.deck.renderer import ArrowRenderer +from flytekit import FlyteContext, logger +from flytekit.configuration import DataConfig +from flytekit.core.data_persistence import s3_setup_args from flytekit.models import literals from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import StructuredDatasetType @@ -17,12 +21,20 @@ StructuredDataset, StructuredDatasetDecoder, StructuredDatasetEncoder, - StructuredDatasetTransformerEngine, ) T = TypeVar("T") +def get_storage_options(cfg: DataConfig, uri: str, anon: bool = False) -> typing.Optional[typing.Dict]: + protocol = get_protocol(uri) + if protocol == "s3": + kwargs = s3_setup_args(cfg.s3, anon) + if kwargs: + return kwargs + return None + + class PandasToParquetEncodingHandler(StructuredDatasetEncoder): def __init__(self): super().__init__(pd.DataFrame, None, PARQUET) @@ -33,15 +45,19 @@ def encode( structured_dataset: StructuredDataset, structured_dataset_type: StructuredDatasetType, ) -> literals.StructuredDataset: - - path = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory() + uri = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory() + if not ctx.file_access.is_remote(uri): + Path(uri).mkdir(parents=True, exist_ok=True) + path = os.path.join(uri, f"{0:05}") df = typing.cast(pd.DataFrame, structured_dataset.dataframe) - local_dir = ctx.file_access.get_random_local_directory() - local_path = os.path.join(local_dir, f"{0:05}") - df.to_parquet(local_path, coerce_timestamps="us", allow_truncated_timestamps=False) - ctx.file_access.upload_directory(local_dir, path) + df.to_parquet( + path, + coerce_timestamps="us", + allow_truncated_timestamps=False, + storage_options=get_storage_options(ctx.file_access.data_config, path), + ) structured_dataset_type.format = PARQUET - return 
literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type)) + return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type)) class ParquetToPandasDecodingHandler(StructuredDatasetDecoder): @@ -54,13 +70,17 @@ def decode( flyte_value: literals.StructuredDataset, current_task_metadata: StructuredDatasetMetadata, ) -> pd.DataFrame: - path = flyte_value.uri - local_dir = ctx.file_access.get_random_local_directory() - ctx.file_access.get_data(path, local_dir, is_multipart=True) + uri = flyte_value.uri + columns = None + kwargs = get_storage_options(ctx.file_access.data_config, uri) if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] - return pd.read_parquet(local_dir, columns=columns) - return pd.read_parquet(local_dir) + try: + return pd.read_parquet(uri, columns=columns, storage_options=kwargs) + except NoCredentialsError: + logger.debug("S3 source detected, attempting anonymous S3 access") + kwargs = get_storage_options(ctx.file_access.data_config, uri, anon=True) + return pd.read_parquet(uri, columns=columns, storage_options=kwargs) class ArrowToParquetEncodingHandler(StructuredDatasetEncoder): @@ -73,13 +93,13 @@ def encode( structured_dataset: StructuredDataset, structured_dataset_type: StructuredDatasetType, ) -> literals.StructuredDataset: - path = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_path() - df = structured_dataset.dataframe - local_dir = ctx.file_access.get_random_local_directory() - local_path = os.path.join(local_dir, f"{0:05}") - pq.write_table(df, local_path) - ctx.file_access.upload_directory(local_dir, path) - return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type)) + uri = typing.cast(str, structured_dataset.uri) or 
ctx.file_access.get_random_remote_directory() + if not ctx.file_access.is_remote(uri): + Path(uri).mkdir(parents=True, exist_ok=True) + path = os.path.join(uri, f"{0:05}") + filesystem = ctx.file_access.get_filesystem_for_path(path) + pq.write_table(structured_dataset.dataframe, strip_protocol(path), filesystem=filesystem) + return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type)) class ParquetToArrowDecodingHandler(StructuredDatasetDecoder): @@ -92,19 +112,20 @@ def decode( flyte_value: literals.StructuredDataset, current_task_metadata: StructuredDatasetMetadata, ) -> pa.Table: - path = flyte_value.uri - local_dir = ctx.file_access.get_random_local_directory() - ctx.file_access.get_data(path, local_dir, is_multipart=True) + uri = flyte_value.uri + if not ctx.file_access.is_remote(uri): + Path(uri).parent.mkdir(parents=True, exist_ok=True) + _, path = split_protocol(uri) + + columns = None if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] - return pq.read_table(local_dir, columns=columns) - return pq.read_table(local_dir) - - -StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(), default_format_for_type=True) -StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(), default_format_for_type=True) -StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(), default_format_for_type=True) -StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(), default_format_for_type=True) - -StructuredDatasetTransformerEngine.register_renderer(pd.DataFrame, TopFrameRenderer()) -StructuredDatasetTransformerEngine.register_renderer(pa.Table, ArrowRenderer()) + try: + fs = ctx.file_access.get_filesystem_for_path(uri) + return pq.read_table(path, filesystem=fs, columns=columns) + except NoCredentialsError as e: + 
logger.debug("S3 source detected, attempting anonymous S3 access") + fs = ctx.file_access.get_filesystem_for_path(uri, anonymous=True) + if fs is not None: + return pq.read_table(path, filesystem=fs, columns=columns) + raise e diff --git a/flytekit/types/structured/bigquery.py b/flytekit/types/structured/bigquery.py index 85cede1544..049a21c07e 100644 --- a/flytekit/types/structured/bigquery.py +++ b/flytekit/types/structured/bigquery.py @@ -14,7 +14,6 @@ StructuredDatasetDecoder, StructuredDatasetEncoder, StructuredDatasetMetadata, - StructuredDatasetTransformerEngine, ) BIGQUERY = "bq" @@ -110,9 +109,3 @@ def decode( current_task_metadata: StructuredDatasetMetadata, ) -> pa.Table: return pa.Table.from_pandas(_read_from_bq(flyte_value, current_task_metadata)) - - -StructuredDatasetTransformerEngine.register(PandasToBQEncodingHandlers()) -StructuredDatasetTransformerEngine.register(BQToPandasDecodingHandler()) -StructuredDatasetTransformerEngine.register(ArrowToBQEncodingHandlers()) -StructuredDatasetTransformerEngine.register(BQToArrowDecodingHandler()) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 0e4649203a..fe5c3595ff 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -9,15 +9,13 @@ from typing import Dict, Generator, Optional, Type, Union import _datetime -import numpy as _np -import pandas as pd -import pyarrow as pa from dataclasses_json import config, dataclass_json +from fsspec.utils import get_protocol from marshmallow import fields from typing_extensions import Annotated, TypeAlias, get_args, get_origin +from flytekit import lazy_module from flytekit.core.context_manager import FlyteContext, FlyteContextManager -from flytekit.core.data_persistence import DataPersistencePlugins, DiskPersistence from flytekit.core.type_engine import TypeEngine, TypeTransformer from flytekit.deck.renderer import Renderable from 
flytekit.loggers import logger @@ -26,6 +24,13 @@ from flytekit.models.literals import Literal, Scalar, StructuredDatasetMetadata from flytekit.models.types import LiteralType, SchemaType, StructuredDatasetType +if typing.TYPE_CHECKING: + import pandas as pd + import pyarrow as pa +else: + pd = lazy_module("pandas") + pa = lazy_module("pyarrow") + T = typing.TypeVar("T") # StructuredDataset type or a dataframe type DF = typing.TypeVar("DF") # Dataframe type @@ -35,6 +40,7 @@ # Storage formats PARQUET: StructuredDatasetFormat = "parquet" GENERIC_FORMAT: StructuredDatasetFormat = "" +GENERIC_PROTOCOL: str = "generic protocol" @dataclass_json @@ -74,7 +80,8 @@ def __init__( # This is not for users to set, the transformer will set this. self._literal_sd: Optional[literals.StructuredDataset] = None # Not meant for users to set, will be set by an open() call - self._dataframe_type: Optional[Type[DF]] = None + self._dataframe_type: Optional[DF] = None # type: ignore + self._already_uploaded = False @property def dataframe(self) -> Optional[Type[DF]]: @@ -109,7 +116,7 @@ def iter(self) -> Generator[DF, None, None]: def extract_cols_and_format( t: typing.Any, -) -> typing.Tuple[Type[T], Optional[typing.OrderedDict[str, Type]], Optional[str], Optional[pa.lib.Schema]]: +) -> typing.Tuple[Type[T], Optional[typing.OrderedDict[str, Type]], Optional[str], Optional["pa.lib.Schema"]]: """ Helper function, just used to iterate through Annotations and extract out the following information: - base type, if not Annotated, it will just be the type that was passed in. 
@@ -143,7 +150,7 @@ def extract_cols_and_format( if ordered_dict_cols is not None: raise ValueError(f"Column information was already found {ordered_dict_cols}, cannot use {aa}") ordered_dict_cols = aa - elif isinstance(aa, pa.Schema): + elif isinstance(aa, pa.lib.Schema): if pa_schema is not None: raise ValueError(f"Arrow schema was already found {pa_schema}, cannot use {aa}") pa_schema = aa @@ -271,11 +278,6 @@ def decode( raise NotImplementedError -def protocol_prefix(uri: str) -> str: - p = DataPersistencePlugins.get_protocol(uri) - return p - - def convert_schema_type_to_structured_dataset_type( column_type: int, ) -> int: @@ -295,16 +297,8 @@ def convert_schema_type_to_structured_dataset_type( raise AssertionError(f"Unrecognized SchemaColumnType: {column_type}") -class DuplicateHandlerError(ValueError): - ... - - -class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]): - """ - Think of this transformer as a higher-level meta transformer that is used for all the dataframe types. - If you are bringing a custom data frame type, or any data frame type, to flytekit, instead of - registering with the main type engine, you should register with this transformer instead. - """ +def get_supported_types(): + import numpy as _np _SUPPORTED_TYPES: typing.Dict[Type, LiteralType] = { _np.int32: type_models.LiteralType(simple=type_models.SimpleType.INTEGER), @@ -326,6 +320,19 @@ class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]): _np.object_: type_models.LiteralType(simple=type_models.SimpleType.STRING), str: type_models.LiteralType(simple=type_models.SimpleType.STRING), } + return _SUPPORTED_TYPES + + +class DuplicateHandlerError(ValueError): + ... + + +class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]): + """ + Think of this transformer as a higher-level meta transformer that is used for all the dataframe types. 
+ If you are bringing a custom data frame type, or any data frame type, to flytekit, instead of + registering with the main type engine, you should register with this transformer instead. + """ ENCODERS: Dict[Type, Dict[str, Dict[str, StructuredDatasetEncoder]]] = {} DECODERS: Dict[Type, Dict[str, Dict[str, StructuredDatasetDecoder]]] = {} @@ -337,42 +344,54 @@ class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]): @classmethod def _finder(cls, handler_map, df_type: Type, protocol: str, format: str): - # If the incoming format requested is a specific format (e.g. "avro"), then look for that specific handler - # if missing, see if there's a generic format handler. Error if missing. - # If the incoming format requested is the generic format (""), then see if it's present, - # if not, look to see if there is a default format for the df_type and a handler for that format. - # if still missing, look to see if there's only _one_ handler for that type, if so then use that. - if format != GENERIC_FORMAT: - try: - return handler_map[df_type][protocol][format] - except KeyError: - try: - return handler_map[df_type][protocol][GENERIC_FORMAT] - except KeyError: - ... - else: - try: - return handler_map[df_type][protocol][GENERIC_FORMAT] - except KeyError: - if df_type in cls.DEFAULT_FORMATS and cls.DEFAULT_FORMATS[df_type] in handler_map[df_type][protocol]: - hh = handler_map[df_type][protocol][cls.DEFAULT_FORMATS[df_type]] - logger.debug( - f"Didn't find format specific handler {type(handler_map)} for protocol {protocol}" - f" using the generic handler {hh} instead." - ) - return hh - if len(handler_map[df_type][protocol]) == 1: - hh = list(handler_map[df_type][protocol].values())[0] - logger.debug( - f"Using {hh} with format {hh.supported_format} as it's the only one available for {df_type}" - ) - return hh + # If there's an exact match, then we should use it. + try: + return handler_map[df_type][protocol][format] + except KeyError: + ... 
+ + fsspec_handler = None + protocol_specific_handler = None + single_handler = None + default_format = cls.DEFAULT_FORMATS.get(df_type, None) + + try: + fss_handlers = handler_map[df_type]["fsspec"] + if format in fss_handlers: + fsspec_handler = fss_handlers[format] + elif GENERIC_FORMAT in fss_handlers: + fsspec_handler = fss_handlers[GENERIC_FORMAT] + else: + if default_format and default_format in fss_handlers and format == GENERIC_FORMAT: + fsspec_handler = fss_handlers[default_format] else: - logger.warning( - f"Did not automatically pick a handler for {df_type}," - f" more than one detected {handler_map[df_type][protocol].keys()}" - ) - raise ValueError(f"Failed to find a handler for {df_type}, protocol {protocol}, fmt |{format}|") + if len(fss_handlers) == 1 and format == GENERIC_FORMAT: + single_handler = list(fss_handlers.values())[0] + else: + ... + except KeyError: + ... + + try: + protocol_handlers = handler_map[df_type][protocol] + if GENERIC_FORMAT in protocol_handlers: + protocol_specific_handler = protocol_handlers[GENERIC_FORMAT] + else: + if default_format and default_format in protocol_handlers: + protocol_specific_handler = protocol_handlers[default_format] + else: + if len(protocol_handlers) == 1: + single_handler = list(protocol_handlers.values())[0] + else: + ... + + except KeyError: + ... + + if protocol_specific_handler or fsspec_handler or single_handler: + return protocol_specific_handler or fsspec_handler or single_handler + else: + raise ValueError(f"Failed to find a handler for {df_type}, protocol {protocol}, fmt |{format}|") @classmethod def get_encoder(cls, df_type: Type, protocol: str, format: str): @@ -437,18 +456,12 @@ def register( if h.protocol is None: if default_for_type: raise ValueError(f"Registering SD handler {h} with all protocols should never have default specified.") - for persistence_protocol in DataPersistencePlugins.supported_protocols(): - # TODO: Clean this up when we get to replacing the persistence layer. 
- # The behavior of the protocols given in the supported_protocols and is_supported_protocol - # is not actually the same as the one returned in get_protocol. - stripped = DataPersistencePlugins.get_protocol(persistence_protocol) - logger.debug(f"Automatically registering {persistence_protocol} as {stripped} with {h}") - try: - cls.register_for_protocol( - h, stripped, False, override, default_format_for_type, default_storage_for_type - ) - except DuplicateHandlerError: - logger.debug(f"Skipping {persistence_protocol}/{stripped} for {h} because duplicate") + try: + cls.register_for_protocol( + h, "fsspec", False, override, default_format_for_type, default_storage_for_type + ) + except DuplicateHandlerError: + logger.debug(f"Skipping generic fsspec protocol for handler {h} because duplicate") elif h.protocol == "": raise ValueError(f"Use None instead of empty string for registering handler {h}") @@ -471,8 +484,7 @@ def register_for_protocol( See the main register function instead. """ if protocol == "/": - # TODO: Special fix again, because get_protocol returns file, instead of file:// - protocol = DataPersistencePlugins.get_protocol(DiskPersistence.PROTOCOL) + protocol = "file" lowest_level = cls._handler_finder(h, protocol) if h.supported_format in lowest_level and override is False: raise DuplicateHandlerError( @@ -543,13 +555,15 @@ def to_literal( # def t1(dataset: Annotated[StructuredDataset, my_cols]) -> Annotated[StructuredDataset, my_cols]: # return dataset if python_val._literal_sd is not None: + if python_val._already_uploaded: + return Literal(scalar=Scalar(structured_dataset=python_val._literal_sd)) if python_val.dataframe is not None: raise ValueError( f"Shouldn't have specified both literal {python_val._literal_sd} and dataframe {python_val.dataframe}" ) return Literal(scalar=Scalar(structured_dataset=python_val._literal_sd)) - # 2. A task returns a python StructuredDataset with a uri. + # 2. A task returns a python StructuredDataset with an uri. 
# Note: this case is also what happens we start a local execution of a task with a python StructuredDataset. # It gets converted into a literal first, then back into a python StructuredDataset. # @@ -594,7 +608,7 @@ def _protocol_from_type_or_prefix(self, ctx: FlyteContext, df_type: Type, uri: O if df_type in self.DEFAULT_PROTOCOLS: return self.DEFAULT_PROTOCOLS[df_type] else: - protocol = protocol_prefix(uri or ctx.file_access.raw_output_prefix) + protocol = get_protocol(uri or ctx.file_access.raw_output_prefix) logger.debug( f"No default protocol for type {df_type} found, using {protocol} from output prefix {ctx.file_access.raw_output_prefix}" ) @@ -623,7 +637,10 @@ def encode( # Note that this will always be the same as the incoming format except for when the fallback handler # with a format of "" is used. sd_model.metadata._structured_dataset_type.format = handler.supported_format - return Literal(scalar=Scalar(structured_dataset=sd_model)) + lit = Literal(scalar=Scalar(structured_dataset=sd_model)) + sd._literal_sd = sd_model + sd._already_uploaded = True + return lit def to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T] | StructuredDataset @@ -770,7 +787,7 @@ def open_as( :param updated_metadata: New metadata type, since it might be different from the metadata in the literal. :return: dataframe. It could be pandas dataframe or arrow table, etc. 
""" - protocol = protocol_prefix(sd.uri) + protocol = get_protocol(sd.uri) decoder = self.get_decoder(df_type, protocol, sd.metadata.structured_dataset_type.format) result = decoder.decode(ctx, sd, updated_metadata) if isinstance(result, types.GeneratorType): @@ -783,8 +800,8 @@ def iter_as( sd: literals.StructuredDataset, df_type: Type[DF], updated_metadata: StructuredDatasetMetadata, - ) -> Generator[DF, None, None]: - protocol = protocol_prefix(sd.uri) + ) -> typing.Iterator[DF]: + protocol = get_protocol(sd.uri) decoder = self.DECODERS[df_type][protocol][sd.metadata.structured_dataset_type.format] result = decoder.decode(ctx, sd, updated_metadata) if not isinstance(result, types.GeneratorType): @@ -792,8 +809,8 @@ def iter_as( return result def _get_dataset_column_literal_type(self, t: Type) -> type_models.LiteralType: - if t in self._SUPPORTED_TYPES: - return self._SUPPORTED_TYPES[t] + if t in get_supported_types(): + return get_supported_types()[t] if hasattr(t, "__origin__") and t.__origin__ == list: return type_models.LiteralType(collection_type=self._get_dataset_column_literal_type(t.__args__[0])) if hasattr(t, "__origin__") and t.__origin__ == dict: diff --git a/plugins/flytekit-aws-sagemaker/requirements.txt b/plugins/flytekit-aws-sagemaker/requirements.txt index 64d03c18f2..37771df5ce 100644 --- a/plugins/flytekit-aws-sagemaker/requirements.txt +++ b/plugins/flytekit-aws-sagemaker/requirements.txt @@ -46,7 +46,8 @@ cryptography==39.0.1 dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 - # via retry + # via + # retry2 deprecated==1.2.13 # via flytekit diskcache==5.4.0 @@ -152,8 +153,6 @@ protoc-gen-swagger==0.1.0 # via flyteidl psutil==5.9.4 # via sagemaker-training -py==1.11.0 - # via retry pyarrow==10.0.1 # via flytekit pycparser==2.21 @@ -193,8 +192,8 @@ requests==2.28.2 # responses responses==0.22.0 # via flytekit -retry==0.9.2 - # via flytekit +retry2==0.9.5 + # via flytekitplugins-awssagemaker retrying==1.3.4 # via sagemaker-training 
s3transfer==0.6.0 diff --git a/plugins/flytekit-aws-sagemaker/setup.py b/plugins/flytekit-aws-sagemaker/setup.py index 9be6800e49..ade15df0e8 100644 --- a/plugins/flytekit-aws-sagemaker/setup.py +++ b/plugins/flytekit-aws-sagemaker/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.3.0,<2.0.0", "sagemaker-training>=3.6.2,<4.0.0"] +plugin_requires = ["flytekit>=1.1.0b0,<1.3.0,<2.0.0", "sagemaker-training>=3.6.2,<4.0.0", "retry2==0.9.5"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/__init__.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/__init__.py index cba899669b..416a021516 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/__init__.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/__init__.py @@ -11,4 +11,5 @@ BigQueryTask """ +from .backend_plugin import BigQueryPlugin from .task import BigQueryConfig, BigQueryTask diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/backend_plugin.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/backend_plugin.py new file mode 100644 index 0000000000..acd5ece430 --- /dev/null +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/backend_plugin.py @@ -0,0 +1,94 @@ +import datetime +from typing import Dict, Optional + +import grpc +from flyteidl.service.external_plugin_service_pb2 import ( + SUCCEEDED, + TaskCreateResponse, + TaskDeleteResponse, + TaskGetResponse, +) +from google.cloud import bigquery + +from flytekit import FlyteContextManager, StructuredDataset, logger +from flytekit.core.type_engine import TypeEngine +from flytekit.extend.backend.base_plugin import BackendPluginBase, BackendPluginRegistry, convert_to_flyte_state +from flytekit.models import literals +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate +from flytekit.models.types import LiteralType, StructuredDatasetType + +pythonTypeToBigQueryType: 
Dict[type, str] = { + # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#data_type_sizes + list: "ARRAY", + bool: "BOOL", + bytes: "BYTES", + datetime.datetime: "DATETIME", + float: "FLOAT64", + int: "INT64", + str: "STRING", +} + + +class BigQueryPlugin(BackendPluginBase): + def __init__(self): + super().__init__(task_type="bigquery_query_job_task") + + def create( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + ) -> TaskCreateResponse: + job_config = None + if inputs: + ctx = FlyteContextManager.current_context() + python_interface_inputs = { + name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items() + } + native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs) + + logger.info(f"Create BigQuery job config with inputs: {native_inputs}") + job_config = bigquery.QueryJobConfig( + query_parameters=[ + bigquery.ScalarQueryParameter(name, pythonTypeToBigQueryType[python_interface_inputs[name]], val) + for name, val in native_inputs.items() + ] + ) + + custom = task_template.custom + client = bigquery.Client(project=custom["ProjectID"], location=custom["Location"]) + query_job = client.query(task_template.sql.statement, job_config=job_config) + + return TaskCreateResponse(job_id=str(query_job.job_id)) + + def get(self, context: grpc.ServicerContext, job_id: str) -> TaskGetResponse: + client = bigquery.Client() + job = client.get_job(job_id) + cur_state = convert_to_flyte_state(str(job.state)) + res = None + + if cur_state == SUCCEEDED: + ctx = FlyteContextManager.current_context() + output_location = f"bq://{job.destination.project}:{job.destination.dataset_id}.{job.destination.table_id}" + res = literals.LiteralMap( + { + "results": TypeEngine.to_literal( + ctx, + StructuredDataset(uri=output_location), + StructuredDataset, + 
LiteralType(structured_dataset_type=StructuredDatasetType(format="")), + ) + } + ) + + return TaskGetResponse(state=cur_state, outputs=res.to_flyte_idl()) + + def delete(self, context: grpc.ServicerContext, job_id: str) -> TaskDeleteResponse: + client = bigquery.Client() + client.cancel_job(job_id) + return TaskDeleteResponse() + + +BackendPluginRegistry.register(BigQueryPlugin()) diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py index 1d4a7f0dbd..7c24e9e3e9 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py @@ -5,10 +5,10 @@ from google.protobuf import json_format from google.protobuf.struct_pb2 import Struct -from flytekit import StructuredDataset from flytekit.configuration import SerializationSettings from flytekit.extend import SQLTask from flytekit.models import task as _task_model +from flytekit.types.structured import StructuredDataset @dataclass @@ -81,3 +81,6 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: def get_sql(self, settings: SerializationSettings) -> Optional[_task_model.Sql]: sql = _task_model.Sql(statement=self.query_template, dialect=_task_model.Sql.Dialect.ANSI) return sql + + def execute(self, **kwargs) -> Any: + raise Exception("Cannot run a SQL Task natively, please mock.") diff --git a/plugins/flytekit-bigquery/requirements.txt b/plugins/flytekit-bigquery/requirements.txt index a9bacca60d..c89afe3cc6 100644 --- a/plugins/flytekit-bigquery/requirements.txt +++ b/plugins/flytekit-bigquery/requirements.txt @@ -48,9 +48,9 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 # via flytekit -flyteidl==1.2.9 +flyteidl==1.2.10 # via flytekit -flytekit==1.2.7 +flytekit==1.2.9 # via flytekitplugins-bigquery google-api-core[grpc]==2.11.0 # via diff --git a/plugins/flytekit-bigquery/setup.py b/plugins/flytekit-bigquery/setup.py index 
88f77429a2..4bb161301a 100644 --- a/plugins/flytekit-bigquery/setup.py +++ b/plugins/flytekit-bigquery/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.3.0,<2.0.0", "google-cloud-bigquery"] +plugin_requires = ["flytekit>=1.1.0b0,<1.3.0,<2.0.0", "google-cloud-bigquery", "flyteidl>=1.2.10,<1.3.0"] __version__ = "0.0.0+develop" @@ -33,4 +33,5 @@ "Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries :: Python Modules", ], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, ) diff --git a/plugins/flytekit-bigquery/tests/test_backend_plugin.py b/plugins/flytekit-bigquery/tests/test_backend_plugin.py new file mode 100644 index 0000000000..c95cf308a7 --- /dev/null +++ b/plugins/flytekit-bigquery/tests/test_backend_plugin.py @@ -0,0 +1,94 @@ +from datetime import timedelta +from unittest import mock +from unittest.mock import MagicMock + +import grpc +from flyteidl.service.external_plugin_service_pb2 import SUCCEEDED + +import flytekit.models.interface as interface_models +from flytekit.extend.backend.base_plugin import BackendPluginRegistry +from flytekit.interfaces.cli_identifiers import Identifier +from flytekit.models import literals, task, types +from flytekit.models.core.identifier import ResourceType +from flytekit.models.task import Sql, TaskTemplate + + +@mock.patch("google.cloud.bigquery.job.QueryJob") +@mock.patch("google.cloud.bigquery.Client") +def test_bigquery_plugin(mock_client, mock_query_job): + job_id = "dummy_id" + mock_instance = mock_client.return_value + mock_query_job_instance = mock_query_job.return_value + mock_query_job_instance.state.return_value = "SUCCEEDED" + mock_query_job_instance.job_id.return_value = job_id + + class MockDestination: + def __init__(self): + self.project = "dummy_project" + self.dataset_id = "dummy_dataset" + self.table_id = "dummy_table" + + class MockJob: + def __init__(self): + 
self.state = "SUCCEEDED" + self.job_id = job_id + self.destination = MockDestination() + + mock_instance.get_job.return_value = MockJob() + mock_instance.query.return_value = MockJob() + mock_instance.cancel_job.return_value = MockJob() + + ctx = MagicMock(spec=grpc.ServicerContext) + p = BackendPluginRegistry.get_plugin(ctx, "bigquery_query_job_task") + + task_id = Identifier( + resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" + ) + task_metadata = task.TaskMetadata( + True, + task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + "A", + ) + task_config = { + "Location": "us-central1", + "ProjectID": "dummy_project", + } + + int_type = types.LiteralType(types.SimpleType.INTEGER) + interfaces = interface_models.TypedInterface( + { + "a": interface_models.Variable(int_type, "description1"), + "b": interface_models.Variable(int_type, "description2"), + }, + {}, + ) + task_inputs = literals.LiteralMap( + { + "a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))), + "b": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))), + }, + ) + + dummy_template = TaskTemplate( + id=task_id, + custom=task_config, + metadata=task_metadata, + interface=interfaces, + type="bigquery_query_job_task", + sql=Sql("SELECT 1"), + ) + + assert p.create(ctx, "/tmp", dummy_template, task_inputs).job_id == job_id + res = p.get(ctx, job_id) + assert res.state == SUCCEEDED + assert ( + res.outputs.literals["results"].scalar.structured_dataset.uri == "bq://dummy_project:dummy_dataset.dummy_table" + ) + p.delete(ctx, job_id) + mock_instance.cancel_job.assert_called() diff --git a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py index 68ee456ed6..e69de29bb2 100644 --- 
a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py +++ b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py @@ -1,53 +0,0 @@ -""" -.. currentmodule:: flytekitplugins.fsspec - -This package contains things that are useful when extending Flytekit. - -.. autosummary:: - :template: custom.rst - :toctree: generated/ - - ArrowToParquetEncodingHandler - FSSpecPersistence - PandasToParquetEncodingHandler - ParquetToArrowDecodingHandler - ParquetToPandasDecodingHandler -""" - -__all__ = [ - "ArrowToParquetEncodingHandler", - "FSSpecPersistence", - "PandasToParquetEncodingHandler", - "ParquetToArrowDecodingHandler", - "ParquetToPandasDecodingHandler", -] - -import importlib - -from flytekit import StructuredDatasetTransformerEngine, logger - -from .arrow import ArrowToParquetEncodingHandler, ParquetToArrowDecodingHandler -from .pandas import PandasToParquetEncodingHandler, ParquetToPandasDecodingHandler -from .persist import FSSpecPersistence - -S3 = "s3" -ABFS = "abfs" -GCS = "gs" - - -def _register(protocol: str): - logger.info(f"Registering fsspec {protocol} implementations and overriding default structured encoder/decoder.") - StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(protocol), True, True) - StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(protocol), True, True) - StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(protocol), True, True) - StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(protocol), True, True) - - -if importlib.util.find_spec("adlfs"): - _register(ABFS) - -if importlib.util.find_spec("s3fs"): - _register(S3) - -if importlib.util.find_spec("gcsfs"): - _register(GCS) diff --git a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/arrow.py b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/arrow.py deleted file mode 100644 index ec8d5f975e..0000000000 --- a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/arrow.py +++ 
/dev/null @@ -1,70 +0,0 @@ -import os -import typing -from pathlib import Path - -import pyarrow as pa -import pyarrow.parquet as pq -from botocore.exceptions import NoCredentialsError -from flytekitplugins.fsspec.persist import FSSpecPersistence -from fsspec.core import split_protocol, strip_protocol - -from flytekit import FlyteContext, logger -from flytekit.models import literals -from flytekit.models.literals import StructuredDatasetMetadata -from flytekit.models.types import StructuredDatasetType -from flytekit.types.structured.structured_dataset import ( - PARQUET, - StructuredDataset, - StructuredDatasetDecoder, - StructuredDatasetEncoder, -) - - -class ArrowToParquetEncodingHandler(StructuredDatasetEncoder): - def __init__(self, protocol: str): - super().__init__(pa.Table, protocol, PARQUET) - - def encode( - self, - ctx: FlyteContext, - structured_dataset: StructuredDataset, - structured_dataset_type: StructuredDatasetType, - ) -> literals.StructuredDataset: - uri = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory() - if not ctx.file_access.is_remote(uri): - Path(uri).mkdir(parents=True, exist_ok=True) - path = os.path.join(uri, f"{0:05}") - fp = FSSpecPersistence(data_config=ctx.file_access.data_config) - filesystem = fp.get_filesystem(path) - pq.write_table(structured_dataset.dataframe, strip_protocol(path), filesystem=filesystem) - return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type)) - - -class ParquetToArrowDecodingHandler(StructuredDatasetDecoder): - def __init__(self, protocol: str): - super().__init__(pa.Table, protocol, PARQUET) - - def decode( - self, - ctx: FlyteContext, - flyte_value: literals.StructuredDataset, - current_task_metadata: StructuredDatasetMetadata, - ) -> pa.Table: - uri = flyte_value.uri - if not ctx.file_access.is_remote(uri): - Path(uri).parent.mkdir(parents=True, exist_ok=True) - _, path = split_protocol(uri) - - columns = None - if 
current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: - columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] - try: - fp = FSSpecPersistence(data_config=ctx.file_access.data_config) - fs = fp.get_filesystem(uri) - return pq.read_table(path, filesystem=fs, columns=columns) - except NoCredentialsError as e: - logger.debug("S3 source detected, attempting anonymous S3 access") - fs = FSSpecPersistence.get_anonymous_filesystem(uri) - if fs is not None: - return pq.read_table(path, filesystem=fs, columns=columns) - raise e diff --git a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/pandas.py b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/pandas.py deleted file mode 100644 index e4986ed9f6..0000000000 --- a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/pandas.py +++ /dev/null @@ -1,76 +0,0 @@ -import os -import typing -from pathlib import Path - -import pandas as pd -from botocore.exceptions import NoCredentialsError -from flytekitplugins.fsspec.persist import FSSpecPersistence, s3_setup_args - -from flytekit import FlyteContext, logger -from flytekit.configuration import DataConfig -from flytekit.models import literals -from flytekit.models.literals import StructuredDatasetMetadata -from flytekit.models.types import StructuredDatasetType -from flytekit.types.structured.structured_dataset import ( - PARQUET, - StructuredDataset, - StructuredDatasetDecoder, - StructuredDatasetEncoder, -) - - -def get_storage_options(cfg: DataConfig, uri: str) -> typing.Optional[typing.Dict]: - protocol = FSSpecPersistence.get_protocol(uri) - if protocol == "s3": - kwargs = s3_setup_args(cfg.s3) - if kwargs: - return kwargs - return None - - -class PandasToParquetEncodingHandler(StructuredDatasetEncoder): - def __init__(self, protocol: str): - super().__init__(pd.DataFrame, protocol, PARQUET) - - def encode( - self, - ctx: FlyteContext, - structured_dataset: StructuredDataset, - 
structured_dataset_type: StructuredDatasetType, - ) -> literals.StructuredDataset: - uri = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory() - if not ctx.file_access.is_remote(uri): - Path(uri).mkdir(parents=True, exist_ok=True) - path = os.path.join(uri, f"{0:05}") - df = typing.cast(pd.DataFrame, structured_dataset.dataframe) - df.to_parquet( - path, - coerce_timestamps="us", - allow_truncated_timestamps=False, - storage_options=get_storage_options(ctx.file_access.data_config, path), - ) - structured_dataset_type.format = PARQUET - return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type)) - - -class ParquetToPandasDecodingHandler(StructuredDatasetDecoder): - def __init__(self, protocol: str): - super().__init__(pd.DataFrame, protocol, PARQUET) - - def decode( - self, - ctx: FlyteContext, - flyte_value: literals.StructuredDataset, - current_task_metadata: StructuredDatasetMetadata, - ) -> pd.DataFrame: - uri = flyte_value.uri - columns = None - kwargs = get_storage_options(ctx.file_access.data_config, uri) - if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: - columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] - try: - return pd.read_parquet(uri, columns=columns, storage_options=kwargs) - except NoCredentialsError: - logger.debug("S3 source detected, attempting anonymous S3 access") - kwargs["anon"] = True - return pd.read_parquet(uri, columns=columns, storage_options=kwargs) diff --git a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/persist.py b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/persist.py deleted file mode 100644 index b890b3cc6c..0000000000 --- a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/persist.py +++ /dev/null @@ -1,144 +0,0 @@ -import os -import typing - -import fsspec -from fsspec.registry import known_implementations - -from flytekit.configuration 
import DataConfig, S3Config -from flytekit.extend import DataPersistence, DataPersistencePlugins -from flytekit.loggers import logger - -# Refer to https://github.com/fsspec/s3fs/blob/50bafe4d8766c3b2a4e1fc09669cf02fb2d71454/s3fs/core.py#L198 -# for key and secret -_FSSPEC_S3_KEY_ID = "key" -_FSSPEC_S3_SECRET = "secret" - - -def s3_setup_args(s3_cfg: S3Config): - kwargs = {} - if s3_cfg.access_key_id: - kwargs[_FSSPEC_S3_KEY_ID] = s3_cfg.access_key_id - - if s3_cfg.secret_access_key: - kwargs[_FSSPEC_S3_SECRET] = s3_cfg.secret_access_key - - # S3fs takes this as a special arg - if s3_cfg.endpoint is not None: - kwargs["client_kwargs"] = {"endpoint_url": s3_cfg.endpoint} - - return kwargs - - -class FSSpecPersistence(DataPersistence): - """ - This DataPersistence plugin uses fsspec to perform the IO. - NOTE: The put is not as performant as it can be for multiple files because of - - https://github.com/intake/filesystem_spec/issues/724. Once this bug is fixed, we can remove the `HACK` in the put - method - """ - - def __init__(self, default_prefix=None, data_config: typing.Optional[DataConfig] = None): - super(FSSpecPersistence, self).__init__(name="fsspec-persistence", default_prefix=default_prefix) - self.default_protocol = self.get_protocol(default_prefix) - self._data_cfg = data_config if data_config else DataConfig.auto() - - @staticmethod - def get_protocol(path: typing.Optional[str] = None): - if path: - return DataPersistencePlugins.get_protocol(path) - logger.info("Setting protocol to file") - return "file" - - def get_filesystem(self, path: str) -> fsspec.AbstractFileSystem: - protocol = FSSpecPersistence.get_protocol(path) - kwargs = {} - if protocol == "file": - kwargs = {"auto_mkdir": True} - elif protocol == "s3": - kwargs = s3_setup_args(self._data_cfg.s3) - return fsspec.filesystem(protocol, **kwargs) # type: ignore - - def get_anonymous_filesystem(self, path: str) -> typing.Optional[fsspec.AbstractFileSystem]: - protocol = 
FSSpecPersistence.get_protocol(path) - if protocol == "s3": - kwargs = s3_setup_args(self._data_cfg.s3) - anonymous_fs = fsspec.filesystem(protocol, anon=True, **kwargs) # type: ignore - return anonymous_fs - return None - - @staticmethod - def recursive_paths(f: str, t: str) -> typing.Tuple[str, str]: - if not f.endswith("*"): - f = os.path.join(f, "*") - if not t.endswith("/"): - t += "/" - return f, t - - def exists(self, path: str) -> bool: - try: - fs = self.get_filesystem(path) - return fs.exists(path) - except OSError as oe: - logger.debug(f"Error in exists checking {path} {oe}") - fs = self.get_anonymous_filesystem(path) - if fs is not None: - logger.debug("S3 source detected, attempting anonymous S3 exists check") - return fs.exists(path) - raise oe - - def get(self, from_path: str, to_path: str, recursive: bool = False): - fs = self.get_filesystem(from_path) - if recursive: - from_path, to_path = self.recursive_paths(from_path, to_path) - try: - return fs.get(from_path, to_path, recursive=recursive) - except OSError as oe: - logger.debug(f"Error in getting {from_path} to {to_path} rec {recursive} {oe}") - fs = self.get_anonymous_filesystem(from_path) - if fs is not None: - logger.debug("S3 source detected, attempting anonymous S3 access") - return fs.get(from_path, to_path, recursive=recursive) - raise oe - - def put(self, from_path: str, to_path: str, recursive: bool = False): - fs = self.get_filesystem(to_path) - if recursive: - from_path, to_path = self.recursive_paths(from_path, to_path) - # BEGIN HACK! 
- # Once https://github.com/intake/filesystem_spec/issues/724 is fixed, delete the special recursive handling - from fsspec.implementations.local import LocalFileSystem - from fsspec.utils import other_paths - - lfs = LocalFileSystem() - try: - lpaths = lfs.expand_path(from_path, recursive=recursive) - except FileNotFoundError: - # In some cases, there is no file in the original directory, so we just skip copying the file to the remote path - logger.debug(f"there is no file in the {from_path}") - return - rpaths = other_paths(lpaths, to_path) - for l, r in zip(lpaths, rpaths): - fs.put_file(l, r) - return - # END OF HACK!! - return fs.put(from_path, to_path, recursive=recursive) - - def construct_path(self, add_protocol: bool, add_prefix: bool, *paths) -> str: - path_list = list(paths) # make type check happy - if add_prefix: - path_list.insert(0, self.default_prefix) # type: ignore - path = "/".join(path_list) - if add_protocol: - return f"{self.default_protocol}://{path}" - return typing.cast(str, path) - - -def _register(): - logger.info("Registering fsspec known implementations and overriding all default implementations for persistence.") - DataPersistencePlugins.register_plugin("/", FSSpecPersistence, force=True) - for k, v in known_implementations.items(): - DataPersistencePlugins.register_plugin(f"{k}://", FSSpecPersistence, force=True) - - -# Registering all plugins -_register() diff --git a/plugins/flytekit-data-fsspec/setup.py b/plugins/flytekit-data-fsspec/setup.py index 5e6712f396..7102e4fb5f 100644 --- a/plugins/flytekit-data-fsspec/setup.py +++ b/plugins/flytekit-data-fsspec/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-data-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.3.0,<2.0.0", "fsspec<=2023.1", "botocore>=1.7.48", "pandas>=1.2.0"] +plugin_requires = [] __version__ = "0.0.0+develop" @@ -13,7 +13,7 @@ version=__version__, author="flyteorg", author_email="admin@flyte.org", - description="This package data-plugins for 
flytekit, that are powered by fsspec", + description="This is a deprecated plugin as of flytekit 1.5", url="https://github.com/flyteorg/flytekit/tree/master/plugins/flytekit-data-fsspec", long_description=open("README.md").read(), long_description_content_type="text/markdown", @@ -22,9 +22,9 @@ install_requires=plugin_requires, extras_require={ # https://github.com/fsspec/filesystem_spec/blob/master/setup.py#L36 - "abfs": ["adlfs>=2022.2.0"], - "aws": ["s3fs>=2021.7.0"], - "gcp": ["gcsfs>=2021.7.0"], + "abfs": [], + "aws": [], + "gcp": [], }, license="apache2", python_requires=">=3.7", @@ -42,5 +42,4 @@ "Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries :: Python Modules", ], - entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, ) diff --git a/plugins/flytekit-data-fsspec/tests/test_basic_dfs.py b/plugins/flytekit-data-fsspec/tests/test_basic_dfs.py deleted file mode 100644 index 434a763a93..0000000000 --- a/plugins/flytekit-data-fsspec/tests/test_basic_dfs.py +++ /dev/null @@ -1,44 +0,0 @@ -import pandas as pd -import pyarrow as pa -from flytekitplugins.fsspec.pandas import get_storage_options - -from flytekit import kwtypes, task -from flytekit.configuration import DataConfig, S3Config - -try: - from typing import Annotated -except ImportError: - from typing_extensions import Annotated - - -def test_get_storage_options(): - endpoint = "https://s3.amazonaws.com" - - options = get_storage_options(DataConfig(s3=S3Config(endpoint=endpoint)), "s3://bucket/somewhere") - assert options == {"client_kwargs": {"endpoint_url": endpoint}} - - options = get_storage_options(DataConfig(), "/tmp/file") - assert options is None - - -cols = kwtypes(Name=str, Age=int) -subset_cols = kwtypes(Name=str) - - -@task -def t1( - df1: Annotated[pd.DataFrame, cols], df2: Annotated[pa.Table, cols] -) -> (Annotated[pd.DataFrame, subset_cols], Annotated[pa.Table, subset_cols]): - return df1, df2 - - -def 
test_structured_dataset_wf(): - pd_df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) - pa_df = pa.Table.from_pandas(pd_df) - - subset_pd_df = pd.DataFrame({"Name": ["Tom", "Joseph"]}) - subset_pa_df = pa.Table.from_pandas(subset_pd_df) - - df1, df2 = t1(df1=pd_df, df2=pa_df) - assert df1.equals(subset_pd_df) - assert df2.equals(subset_pa_df) diff --git a/plugins/flytekit-data-fsspec/tests/test_persist.py b/plugins/flytekit-data-fsspec/tests/test_persist.py deleted file mode 100644 index 8e87c9c5eb..0000000000 --- a/plugins/flytekit-data-fsspec/tests/test_persist.py +++ /dev/null @@ -1,183 +0,0 @@ -import os -import pathlib -import tempfile - -import mock -from flytekitplugins.fsspec.persist import FSSpecPersistence, s3_setup_args -from fsspec.implementations.local import LocalFileSystem - -from flytekit.configuration import S3Config - - -def test_s3_setup_args(): - kwargs = s3_setup_args(S3Config()) - assert kwargs == {} - - kwargs = s3_setup_args(S3Config(endpoint="http://localhost:30084")) - assert kwargs == {"client_kwargs": {"endpoint_url": "http://localhost:30084"}} - - kwargs = s3_setup_args(S3Config(access_key_id="access")) - assert kwargs == {"key": "access"} - - -@mock.patch.dict(os.environ, {}, clear=True) -def test_s3_setup_args_env_empty(): - kwargs = s3_setup_args(S3Config.auto()) - assert kwargs == {} - - -@mock.patch.dict( - os.environ, - { - "AWS_ACCESS_KEY_ID": "ignore-user", - "AWS_SECRET_ACCESS_KEY": "ignore-secret", - "FLYTE_AWS_ACCESS_KEY_ID": "flyte", - "FLYTE_AWS_SECRET_ACCESS_KEY": "flyte-secret", - }, - clear=True, -) -def test_s3_setup_args_env_both(): - kwargs = s3_setup_args(S3Config.auto()) - assert kwargs == {"key": "flyte", "secret": "flyte-secret"} - - -@mock.patch.dict( - os.environ, - { - "FLYTE_AWS_ACCESS_KEY_ID": "flyte", - "FLYTE_AWS_SECRET_ACCESS_KEY": "flyte-secret", - }, - clear=True, -) -def test_s3_setup_args_env_flyte(): - kwargs = s3_setup_args(S3Config.auto()) - assert kwargs == {"key": "flyte", "secret": 
"flyte-secret"} - - -@mock.patch.dict( - os.environ, - { - "AWS_ACCESS_KEY_ID": "ignore-user", - "AWS_SECRET_ACCESS_KEY": "ignore-secret", - }, - clear=True, -) -def test_s3_setup_args_env_aws(): - kwargs = s3_setup_args(S3Config.auto()) - # not explicitly in kwargs, since fsspec/boto3 will use these env vars by default - assert kwargs == {} - - -def test_get_protocol(): - assert FSSpecPersistence.get_protocol("s3://abc") == "s3" - assert FSSpecPersistence.get_protocol("/abc") == "file" - assert FSSpecPersistence.get_protocol("file://abc") == "file" - assert FSSpecPersistence.get_protocol("gs://abc") == "gs" - assert FSSpecPersistence.get_protocol("sftp://abc") == "sftp" - assert FSSpecPersistence.get_protocol("abfs://abc") == "abfs" - - -def test_get_anonymous_filesystem(): - fp = FSSpecPersistence() - fs = fp.get_anonymous_filesystem("/abc") - assert fs is None - fs = fp.get_anonymous_filesystem("s3://abc") - assert fs is not None - assert fs.protocol == ["s3", "s3a"] - - -def test_get_filesystem(): - fp = FSSpecPersistence() - fs = fp.get_filesystem("/abc") - assert fs is not None - assert isinstance(fs, LocalFileSystem) - - -def test_recursive_paths(): - f, t = FSSpecPersistence.recursive_paths("/tmp", "/tmp") - assert (f, t) == ("/tmp/*", "/tmp/") - f, t = FSSpecPersistence.recursive_paths("/tmp/", "/tmp/") - assert (f, t) == ("/tmp/*", "/tmp/") - f, t = FSSpecPersistence.recursive_paths("/tmp/*", "/tmp") - assert (f, t) == ("/tmp/*", "/tmp/") - - -def test_exists(): - fs = FSSpecPersistence() - assert not fs.exists("/tmp/non-existent") - - with tempfile.TemporaryDirectory() as tdir: - f = os.path.join(tdir, "f.txt") - with open(f, "w") as fp: - fp.write("hello") - - assert fs.exists(f) - - -def test_get(): - fs = FSSpecPersistence() - with tempfile.TemporaryDirectory() as tdir: - f = os.path.join(tdir, "f.txt") - with open(f, "w") as fp: - fp.write("hello") - - t = os.path.join(tdir, "t.txt") - - fs.get(f, t) - with open(t, "r") as fp: - assert fp.read() == 
"hello" - - -def test_get_recursive(): - fs = FSSpecPersistence() - with tempfile.TemporaryDirectory() as tdir: - p = pathlib.Path(tdir) - d = p.joinpath("d") - d.mkdir() - f = d.joinpath(d, "f.txt") - with open(f, "w") as fp: - fp.write("hello") - - o = p.joinpath("o") - - t = o.joinpath(o, "f.txt") - fs.get(str(d), str(o), recursive=True) - with open(t, "r") as fp: - assert fp.read() == "hello" - - -def test_put(): - fs = FSSpecPersistence() - with tempfile.TemporaryDirectory() as tdir: - f = os.path.join(tdir, "f.txt") - with open(f, "w") as fp: - fp.write("hello") - - t = os.path.join(tdir, "t.txt") - - fs.put(f, t) - with open(t, "r") as fp: - assert fp.read() == "hello" - - -def test_put_recursive(): - fs = FSSpecPersistence() - with tempfile.TemporaryDirectory() as tdir: - p = pathlib.Path(tdir) - d = p.joinpath("d") - d.mkdir() - f = d.joinpath(d, "f.txt") - with open(f, "w") as fp: - fp.write("hello") - - o = p.joinpath("o") - - t = o.joinpath(o, "f.txt") - fs.put(str(d), str(o), recursive=True) - with open(t, "r") as fp: - assert fp.read() == "hello" - - -def test_construct_path(): - fs = FSSpecPersistence() - assert fs.construct_path(True, False, "abc") == "file://abc" diff --git a/plugins/flytekit-data-fsspec/tests/test_placeholder.py b/plugins/flytekit-data-fsspec/tests/test_placeholder.py new file mode 100644 index 0000000000..eb6dc82a34 --- /dev/null +++ b/plugins/flytekit-data-fsspec/tests/test_placeholder.py @@ -0,0 +1,3 @@ +# This test is here to give pytest something to run, otherwise it returns a non-zero return code. 
+def test_dummy(): + assert 1 + 1 == 2 diff --git a/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py b/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py index d2b44e0b65..c090ea6a46 100644 --- a/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py +++ b/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py @@ -1,7 +1,18 @@ -import markdown -import pandas -import plotly.express as px -from pandas_profiling import ProfileReport +from typing import TYPE_CHECKING, List, Optional, Union + +from flytekit import lazy_module +from flytekit.types.file import FlyteFile + +if TYPE_CHECKING: + import markdown + import pandas as pd + import PIL + import plotly.express as px +else: + pd = lazy_module("pandas") + markdown = lazy_module("markdown") + px = lazy_module("plotly.express") + PIL = lazy_module("PIL") class FrameProfilingRenderer: @@ -12,9 +23,11 @@ class FrameProfilingRenderer: def __init__(self, title: str = "Pandas Profiling Report"): self._title = title - def to_html(self, df: pandas.DataFrame) -> str: - assert isinstance(df, pandas.DataFrame) - profile = ProfileReport(df, title=self._title) + def to_html(self, df: "pd.DataFrame") -> str: + assert isinstance(df, pd.DataFrame) + import ydata_profiling + + profile = ydata_profiling.ProfileReport(df, title=self._title) return profile.to_html() @@ -37,7 +50,7 @@ class BoxRenderer: Each box spans from quartile 1 (Q1) to quartile 3 (Q3). The second quartile (Q2) is marked by a line inside the box. By default, the - whiskers correspond to the box' edges +/- 1.5 times the interquartile + whiskers correspond to the box edges +/- 1.5 times the interquartile range (IQR: Q3-Q1), see "points" for other options. 
""" @@ -45,6 +58,116 @@ class BoxRenderer: def __init__(self, column_name): self._column_name = column_name - def to_html(self, df: pandas.DataFrame) -> str: + def to_html(self, df: "pd.DataFrame") -> str: fig = px.box(df, y=self._column_name) return fig.to_html() + + +class ImageRenderer: + """Converts a FlyteFile or PIL.Image.Image object to an HTML string with the image data + represented as a base64-encoded string. + """ + + def to_html(self, image_src: Union[FlyteFile, "PIL.Image.Image"]) -> str: + img = self._get_image_object(image_src) + return self._image_to_html_string(img) + + @staticmethod + def _get_image_object(image_src: Union[FlyteFile, "PIL.Image.Image"]) -> "PIL.Image.Image": + if isinstance(image_src, FlyteFile): + local_path = image_src.download() + return PIL.Image.open(local_path) + elif isinstance(image_src, PIL.Image.Image): + return image_src + else: + raise ValueError("Unsupported image source type") + + @staticmethod + def _image_to_html_string(img: "PIL.Image.Image") -> str: + import base64 + from io import BytesIO + + buffered = BytesIO() + img.save(buffered, format="PNG") + img_base64 = base64.b64encode(buffered.getvalue()).decode() + return f'Rendered Image' + + +class TableRenderer: + """ + Convert a pandas DataFrame into an HTML table. + """ + + def to_html(self, df: pd.DataFrame, header_labels: Optional[List] = None, table_width: Optional[int] = None) -> str: + # Check if custom labels are provided and have the correct length + if header_labels is not None and len(header_labels) == len(df.columns): + df = df.copy() + df.columns = header_labels + + style = f""" + + """ + return style + df.to_html(classes="table-class", index=False) + + +class GanttChartRenderer: + """ + This renderer is primarily used by the timeline deck. 
The input DataFrame should + have at least the following columns: + - "Start": datetime.datetime (represents the start time) + - "Finish": datetime.datetime (represents the end time) + - "Name": string (the name of the task or event) + """ + + def to_html(self, df: pd.DataFrame, chart_width: Optional[int] = None) -> str: + fig = px.timeline(df, x_start="Start", x_end="Finish", y="Name", color="Name", width=chart_width) + + fig.update_xaxes( + tickangle=90, + rangeslider_visible=True, + tickformatstops=[ + dict(dtickrange=[None, 1], value="%3f ms"), + dict(dtickrange=[1, 60], value="%S:%3f s"), + dict(dtickrange=[60, 3600], value="%M:%S m"), + dict(dtickrange=[3600, None], value="%H:%M h"), + ], + ) + + # Remove y-axis tick labels and title since the time line deck space is limited. + fig.update_yaxes(showticklabels=False, title="") + + fig.update_layout( + autosize=True, + # Set the orientation of the legend to horizontal and move the legend anchor 2% beyond the top of the timeline graph's vertical axis + legend=dict(orientation="h", y=1.02), + ) + + return fig.to_html() diff --git a/plugins/flytekit-deck-standard/tests/test_renderer.py b/plugins/flytekit-deck-standard/tests/test_renderer.py index 79eb7e877d..1878193733 100644 --- a/plugins/flytekit-deck-standard/tests/test_renderer.py +++ b/plugins/flytekit-deck-standard/tests/test_renderer.py @@ -1,8 +1,33 @@ +import datetime +import tempfile + import markdown import pandas as pd -from flytekitplugins.deck.renderer import BoxRenderer, FrameProfilingRenderer, MarkdownRenderer +import pytest +from flytekitplugins.deck.renderer import ( + BoxRenderer, + FrameProfilingRenderer, + GanttChartRenderer, + ImageRenderer, + MarkdownRenderer, + TableRenderer, +) +from PIL import Image + +from flytekit.types.file import FlyteFile, JPEGImageFile, PNGImageFile df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [1, 22]}) +time_info_df = pd.DataFrame( + [ + dict( + Name="foo", + Start=datetime.datetime.utcnow(), + 
Finish=datetime.datetime.utcnow() + datetime.timedelta(microseconds=1000), + WallTime=1.0, + ProcessTime=1.0, + ) + ] +) def test_frame_profiling_renderer(): @@ -19,3 +44,39 @@ def test_markdown_renderer(): def test_box_renderer(): renderer = BoxRenderer("Name") assert "Plotlyconfig = {Mathjaxconfig: 'Local'}" in renderer.to_html(df).title() + + +def create_simple_image(fmt: str): + """Create a simple PNG image using PIL""" + img = Image.new("RGB", (100, 100), color="black") + tmp = tempfile.mktemp() + img.save(tmp, fmt) + return tmp + + +png_image = create_simple_image(fmt="png") +jpeg_image = create_simple_image(fmt="jpeg") + + +@pytest.mark.parametrize( + "image_src", + [ + FlyteFile(path=png_image), + JPEGImageFile(path=jpeg_image), + PNGImageFile(path=png_image), + Image.open(png_image), + ], +) +def test_image_renderer(image_src): + renderer = ImageRenderer() + assert " str: +# return "hello" +``` diff --git a/plugins/flytekit-envd/flytekitplugins/envd/__init__.py b/plugins/flytekit-envd/flytekitplugins/envd/__init__.py new file mode 100644 index 0000000000..d3dec806a1 --- /dev/null +++ b/plugins/flytekit-envd/flytekitplugins/envd/__init__.py @@ -0,0 +1,13 @@ +""" +.. currentmodule:: flytekitplugins.envd + +This plugin enables seamless integration between Flyte and envd. + +.. 
autosummary:: + :template: custom.rst + :toctree: generated/ + + EnvdImageSpecBuilder +""" + +from .image_builder import EnvdImageSpecBuilder diff --git a/plugins/flytekit-envd/flytekitplugins/envd/image_builder.py b/plugins/flytekit-envd/flytekitplugins/envd/image_builder.py new file mode 100644 index 0000000000..fec6647443 --- /dev/null +++ b/plugins/flytekit-envd/flytekitplugins/envd/image_builder.py @@ -0,0 +1,73 @@ +import pathlib +import shutil +import subprocess + +import click + +from flytekit.configuration import DefaultImages +from flytekit.core import context_manager +from flytekit.image_spec.image_spec import _F_IMG_ID, ImageBuildEngine, ImageSpec, ImageSpecBuilder + + +class EnvdImageSpecBuilder(ImageSpecBuilder): + """ + This class is used to build a docker image using envd. + """ + + def build_image(self, image_spec: ImageSpec): + cfg_path = create_envd_config(image_spec) + command = f"envd build --path {pathlib.Path(cfg_path).parent}" + if image_spec.registry: + command += f" --output type=image,name={image_spec.image_name()},push=true" + click.secho(f"Run command: {command} ", fg="blue") + p = subprocess.Popen(command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE) + for line in iter(p.stdout.readline, ""): + if p.poll() is not None: + break + if line.decode().strip() != "": + click.secho(line.decode().strip(), fg="blue") + + if p.returncode != 0: + _, stderr = p.communicate() + raise Exception( + f"failed to build the imageSpec at {cfg_path} with error {stderr}", + ) + + +def create_envd_config(image_spec: ImageSpec) -> str: + base_image = DefaultImages.default_image() if image_spec.base_image is None else image_spec.base_image + packages = [] if image_spec.packages is None else image_spec.packages + apt_packages = [] if image_spec.apt_packages is None else image_spec.apt_packages + env = {"PYTHONPATH": "/root", _F_IMG_ID: image_spec.image_name()} + if image_spec.env: + env.update(image_spec.env) + + envd_config = f"""# syntax=v1 + +def 
build(): + base(image="{base_image}", dev=False) + install.python_packages(name = [{', '.join(map(str, map(lambda x: f'"{x}"', packages)))}]) + install.apt_packages(name = [{', '.join(map(str, map(lambda x: f'"{x}"', apt_packages)))}]) + runtime.environ(env={env}) +""" + + if image_spec.python_version: + # Indentation is required by envd + envd_config += f' install.python(version="{image_spec.python_version}")\n' + + ctx = context_manager.FlyteContextManager.current_context() + cfg_path = ctx.file_access.get_random_local_path("build.envd") + pathlib.Path(cfg_path).parent.mkdir(parents=True, exist_ok=True) + + if image_spec.source_root: + shutil.copytree(image_spec.source_root, pathlib.Path(cfg_path).parent, dirs_exist_ok=True) + # Indentation is required by envd + envd_config += ' io.copy(host_path="./", envd_path="/root")' + + with open(cfg_path, "w+") as f: + f.write(envd_config) + + return cfg_path + + +ImageBuildEngine.register("envd", EnvdImageSpecBuilder()) diff --git a/plugins/flytekit-envd/requirements.in b/plugins/flytekit-envd/requirements.in new file mode 100644 index 0000000000..16b527ba7e --- /dev/null +++ b/plugins/flytekit-envd/requirements.in @@ -0,0 +1,2 @@ +. 
+-e file:.#egg=flytekitplugins-envd diff --git a/plugins/flytekit-envd/requirements.txt b/plugins/flytekit-envd/requirements.txt new file mode 100644 index 0000000000..78e9a287ca --- /dev/null +++ b/plugins/flytekit-envd/requirements.txt @@ -0,0 +1,229 @@ +# +# This file is autogenerated by pip-compile with Python 3.9 +# by the following command: +# +# pip-compile requirements.in +# +-e file:.#egg=flytekitplugins-envd + # via -r requirements.in +arrow==1.2.3 + # via jinja2-time +binaryornot==0.4.4 + # via cookiecutter +cachetools==5.3.0 + # via google-auth +certifi==2022.12.7 + # via + # kubernetes + # requests +cffi==1.15.1 + # via cryptography +chardet==5.1.0 + # via binaryornot +charset-normalizer==3.1.0 + # via requests +click==8.1.3 + # via + # cookiecutter + # flytekit +cloudpickle==2.2.1 + # via flytekit +cookiecutter==2.1.1 + # via flytekit +croniter==1.3.8 + # via flytekit +cryptography==40.0.1 + # via pyopenssl +dataclasses-json==0.5.7 + # via flytekit +decorator==5.1.1 + # via retry +deprecated==1.2.13 + # via flytekit +diskcache==5.4.0 + # via flytekit +docker==6.0.1 + # via flytekit +docker-image-py==0.1.12 + # via flytekit +docstring-parser==0.15 + # via flytekit +envd==0.3.16 + # via flytekitplugins-envd +flyteidl==1.3.15 + # via flytekit +flytekit==1.5.0 + # via flytekitplugins-envd +gitdb==4.0.10 + # via gitpython +gitpython==3.1.31 + # via flytekit +google-auth==2.17.1 + # via kubernetes +googleapis-common-protos==1.59.0 + # via + # flyteidl + # flytekit + # grpcio-status +grpcio==1.53.0 + # via + # flytekit + # grpcio-status +grpcio-status==1.53.0 + # via flytekit +idna==3.4 + # via requests +importlib-metadata==6.1.0 + # via + # flytekit + # keyring +jaraco-classes==3.2.3 + # via keyring +jinja2==3.1.2 + # via + # cookiecutter + # jinja2-time +jinja2-time==0.2.0 + # via cookiecutter +joblib==1.2.0 + # via flytekit +keyring==23.13.1 + # via flytekit +kubernetes==26.1.0 + # via flytekit +markupsafe==2.1.2 + # via jinja2 +marshmallow==3.19.0 + # 
via + # dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.13.0 + # via flytekit +more-itertools==9.1.0 + # via jaraco-classes +mypy-extensions==1.0.0 + # via typing-inspect +natsort==8.3.1 + # via flytekit +numpy==1.23.5 + # via + # flytekit + # pandas + # pyarrow +oauthlib==3.2.2 + # via requests-oauthlib +packaging==23.0 + # via + # docker + # marshmallow +pandas==1.5.3 + # via flytekit +protobuf==4.22.1 + # via + # flyteidl + # googleapis-common-protos + # grpcio-status + # protoc-gen-swagger +protoc-gen-swagger==0.1.0 + # via flyteidl +py==1.11.0 + # via retry +pyarrow==10.0.1 + # via flytekit +pyasn1==0.4.8 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.2.8 + # via google-auth +pycparser==2.21 + # via cffi +pyopenssl==23.1.1 + # via flytekit +python-dateutil==2.8.2 + # via + # arrow + # croniter + # flytekit + # kubernetes + # pandas +python-json-logger==2.0.7 + # via flytekit +python-slugify==8.0.1 + # via cookiecutter +pytimeparse==1.1.8 + # via flytekit +pytz==2023.3 + # via + # flytekit + # pandas +pyyaml==6.0 + # via + # cookiecutter + # flytekit + # kubernetes + # responses +regex==2023.3.23 + # via docker-image-py +requests==2.28.2 + # via + # cookiecutter + # docker + # flytekit + # kubernetes + # requests-oauthlib + # responses +requests-oauthlib==1.3.1 + # via kubernetes +responses==0.23.1 + # via flytekit +retry==0.9.2 + # via flytekit +rsa==4.9 + # via google-auth +six==1.16.0 + # via + # google-auth + # kubernetes + # python-dateutil +smmap==5.0.0 + # via gitdb +sortedcontainers==2.4.0 + # via flytekit +statsd==3.3.0 + # via flytekit +text-unidecode==1.3 + # via python-slugify +types-pyyaml==6.0.12.9 + # via responses +typing-extensions==4.5.0 + # via + # flytekit + # typing-inspect +typing-inspect==0.8.0 + # via dataclasses-json +urllib3==1.26.15 + # via + # docker + # flytekit + # kubernetes + # requests + # responses +websocket-client==1.5.1 + # 
via + # docker + # kubernetes +wheel==0.40.0 + # via flytekit +wrapt==1.15.0 + # via + # deprecated + # flytekit +zipp==3.15.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-envd/setup.py b/plugins/flytekit-envd/setup.py new file mode 100644 index 0000000000..d95a260958 --- /dev/null +++ b/plugins/flytekit-envd/setup.py @@ -0,0 +1,37 @@ +from setuptools import setup + +PLUGIN_NAME = "envd" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit", "envd"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="This package enables users to easily build a Docker image for tasks or workflows.", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.8", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, +) diff --git a/plugins/flytekit-envd/tests/__init__.py b/plugins/flytekit-envd/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-envd/tests/test_image_spec.py b/plugins/flytekit-envd/tests/test_image_spec.py new file mode 100644 index 0000000000..7c7ccd2151 --- /dev/null +++ 
b/plugins/flytekit-envd/tests/test_image_spec.py @@ -0,0 +1,31 @@ +from pathlib import Path + +from flytekitplugins.envd.image_builder import EnvdImageSpecBuilder, create_envd_config + +from flytekit.image_spec.image_spec import ImageSpec + + +def test_image_spec(): + image_spec = ImageSpec( + packages=["pandas"], + apt_packages=["git"], + python_version="3.8", + registry="", + base_image="cr.flyte.org/flyteorg/flytekit:py3.8-latest", + ) + + EnvdImageSpecBuilder().build_image(image_spec) + config_path = create_envd_config(image_spec) + contents = Path(config_path).read_text() + assert ( + contents + == """# syntax=v1 + +def build(): + base(image="cr.flyte.org/flyteorg/flytekit:py3.8-latest", dev=False) + install.python_packages(name = ["pandas"]) + install.apt_packages(name = ["git"]) + runtime.environ(env={'PYTHONPATH': '/root', '_F_IMG_ID': 'flytekit:yZ8jICcDTLoDArmNHbWNwg..'}) + install.python(version="3.8") +""" + ) diff --git a/plugins/flytekit-k8s-pod/tests/test_pod.py b/plugins/flytekit-k8s-pod/tests/test_pod.py index 0d6788ac92..014b88f4f3 100644 --- a/plugins/flytekit-k8s-pod/tests/test_pod.py +++ b/plugins/flytekit-k8s-pod/tests/test_pod.py @@ -355,8 +355,12 @@ def simple_pod_task(i: int): "--prev-checkpoint", "{{.prevCheckpointPrefix}}", "--resolver", - "flytekit.core.python_auto_container.default_task_resolver", + "MapTaskResolver", "--", + "vars", + "", + "resolver", + "flytekit.core.python_auto_container.default_task_resolver", "task-module", "tests.test_pod", "task-name", diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py index 6920c34e84..df5c74288e 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py @@ -10,4 +10,4 @@ MPIJob """ -from .task import MPIJob +from .task import HorovodJob, MPIJob diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py 
b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py
index 6f207b421d..e1c1be0a03 100644
--- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py
+++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py
@@ -133,5 +133,71 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
         return MessageToDict(job.to_flyte_idl())
 
 
+@dataclass
+class HorovodJob(object):
+    """
+    Configuration for a Horovod job that is executed through the MPI plugin.
+
+    Args:
+        slots: Slots per worker; ``num_workers * slots`` is passed to ``horovodrun`` as ``-np``.
+        num_launcher_replicas: Number of launcher replicas.
+        num_workers: Number of workers.
+    """
+
+    slots: int
+    num_launcher_replicas: int = 1
+    num_workers: int = 1
+
+
+class HorovodFunctionTask(MPIFunctionTask):
+    """
+    MPI function task whose container command is prefixed with a ``horovodrun`` invocation.
+
+    For more info, check out https://github.com/horovod/horovod
+    """
+
+    # Customize your setup here. Please ensure the cmd, path, volume, etc are available in the pod.
+    ssh_command = "/usr/sbin/sshd -De -f /home/jobuser/.sshd_config"
+    discovery_script_path = "/etc/mpi/discover_hosts.sh"
+
+    def __init__(self, task_config: HorovodJob, task_function: Callable, **kwargs):
+        super().__init__(task_config=task_config, task_function=task_function, **kwargs)
+
+    def get_command(self, settings: SerializationSettings) -> List[str]:
+        """Return the parent task's command with the ``horovodrun`` prefix prepended."""
+        task_cmd = super().get_command(settings)
+        return self._get_horovod_prefix() + task_cmd
+
+    def get_config(self, settings: SerializationSettings) -> Dict[str, str]:
+        """Return the parent task's config extended with the worker ssh command."""
+        merged = dict(super().get_config(settings))
+        merged["worker_spec_command"] = self.ssh_command
+        return merged
+
+    def _get_horovod_prefix(self) -> List[str]:
+        """Build the ``horovodrun`` argument list that precedes the task command."""
+        slots = str(self.task_config.slots)
+        total_procs = str(self.task_config.num_workers * self.task_config.slots)
+        return [
+            "horovodrun",
+            "-np",
+            total_procs,
+            "--verbose",
+            "--log-level",
+            "INFO",
+            "--network-interface",
+            "eth0",
+            "--min-np",
+            total_procs,
+            "--max-np",
+            total_procs,
+            "--slots-per-host",
+            slots,
+            "--host-discovery-script",
+            self.discovery_script_path,
+        ]
+
+
 # Register the MPI Plugin into the flytekit core plugin system
 TaskPlugins.register_pythontask_plugin(MPIJob, MPIFunctionTask)
+TaskPlugins.register_pythontask_plugin(HorovodJob, HorovodFunctionTask)
diff --git
a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py
index ebb0c49b58..7732d520c2 100644
--- a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py
+++ b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py
@@ -1,4 +1,4 @@
-from flytekitplugins.kfmpi.task import MPIJob, MPIJobModel
+from flytekitplugins.kfmpi.task import HorovodJob, MPIJob, MPIJobModel
 
 from flytekit import Resources, task
 from flytekit.configuration import Image, ImageConfig, SerializationSettings
@@ -41,3 +41,29 @@ def my_mpi_task(x: int, y: str) -> int:
 
     assert my_mpi_task.get_custom(settings) == {"numLauncherReplicas": 10, "numWorkers": 10, "slots": 1}
     assert my_mpi_task.task_type == "mpi"
+
+
+def test_horovod_task():
+    """Serializing a Horovod task yields a horovodrun-prefixed command and the worker ssh config."""
+
+    @task(
+        task_config=HorovodJob(num_workers=5, num_launcher_replicas=5, slots=1),
+    )
+    def my_horovod_task():
+        ...
+
+    default_img = Image(name="default", fqn="test", tag="tag")
+    settings = SerializationSettings(
+        project="project",
+        domain="domain",
+        version="version",
+        env={"FOO": "baz"},
+        image_config=ImageConfig(default_image=default_img, images=[default_img]),
+    )
+    cmd = my_horovod_task.get_command(settings)
+    # The launcher prefix comes first; -np is num_workers * slots = 5 * 1.
+    assert cmd[:3] == ["horovodrun", "-np", "5"]
+    config = my_horovod_task.get_config(settings)
+    assert "/usr/sbin/sshd" in config["worker_spec_command"]
+    custom = my_horovod_task.get_custom(settings)
+    assert isinstance(custom, dict)
diff --git a/plugins/flytekit-kf-pytorch/README.md b/plugins/flytekit-kf-pytorch/README.md
index 280fe687b6..7de27502bf 100644
--- a/plugins/flytekit-kf-pytorch/README.md
+++ b/plugins/flytekit-kf-pytorch/README.md
@@ -2,6 +2,9 @@
 
 This plugin uses the Kubeflow PyTorch Operator and provides an extremely simplified interface for executing distributed training using various PyTorch backends.
 
+This plugin can execute torch elastic training, which is equivalent to run `torchrun`.
Elastic training can be executed +in a single Pod (without requiring the PyTorch operator, see below) as well as in a distributed multi-node manner. + To install the plugin, run the following command: ```bash diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py index aedb0b192f..cb9add7302 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py @@ -8,6 +8,7 @@ :toctree: generated/ PyTorch + Elastic """ -from .task import PyTorch +from .task import Elastic, PyTorch diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/models.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/models.py deleted file mode 100644 index 517f4a9eb6..0000000000 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/models.py +++ /dev/null @@ -1,23 +0,0 @@ -from flyteidl.plugins import pytorch_pb2 as _pytorch_task - -from flytekit.models import common as _common - - -class PyTorchJob(_common.FlyteIdlEntity): - def __init__(self, workers_count): - self._workers_count = workers_count - - @property - def workers_count(self): - return self._workers_count - - def to_flyte_idl(self): - return _pytorch_task.DistributedPyTorchTrainingTask( - workers=self.workers_count, - ) - - @classmethod - def from_flyte_idl(cls, pb2_object): - return cls( - workers_count=pb2_object.workers, - ) diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index 4b0bde78b0..aea2c9a2e6 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -2,16 +2,20 @@ This Plugin adds the capability of running distributed pytorch training to Flyte using backend plugins, natively on Kubernetes. It leverages `Pytorch Job `_ Plugin from kubeflow. 
""" +import os from dataclasses import dataclass -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, Optional, Union +import cloudpickle +from flyteidl.plugins.pytorch_pb2 import DistributedPyTorchTrainingTask from google.protobuf.json_format import MessageToDict +import flytekit from flytekit import PythonFunctionTask from flytekit.configuration import SerializationSettings -from flytekit.extend import TaskPlugins +from flytekit.extend import IgnoreOutputs, TaskPlugins -from .models import PyTorchJob +TORCH_IMPORT_ERROR_MESSAGE = "PyTorch is not installed. Please install `flytekitplugins-kfpytorch['elastic']`." @dataclass @@ -29,6 +33,31 @@ class PyTorch(object): num_workers: int +@dataclass +class Elastic(object): + """ + Configuration for `torch elastic training `_. + + Use this to run single- or multi-node distributed pytorch elastic training on k8s. + + Single-node elastic training is executed in a k8s pod when `nnodes` is set to 1. + Multi-node training is executed otherwise using a `Pytorch Job `_. + + Args: + nnodes (Union[int, str]): Number of nodes, or the range of nodes in form :. + nproc_per_node (Union[int, str]): Number of workers per node. Supported values are [auto, cpu, gpu, int]. + start_method (str): Multiprocessing start method to use when creating workers. + monitor_interval (int): Interval, in seconds, to monitor the state of workers. + max_restarts (int): Maximum number of worker group restarts before failing. 
+ """ + + nnodes: Union[int, str] = 1 + nproc_per_node: Union[int, str] = "auto" + start_method: str = "spawn" + monitor_interval: int = 5 + max_restarts: int = 0 + + class PyTorchFunctionTask(PythonFunctionTask[PyTorch]): """ Plugin that submits a PyTorchJob (see https://github.com/kubeflow/pytorch-operator) @@ -46,9 +75,173 @@ def __init__(self, task_config: PyTorch, task_function: Callable, **kwargs): ) def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - job = PyTorchJob(workers_count=self.task_config.num_workers) - return MessageToDict(job.to_flyte_idl()) + job = DistributedPyTorchTrainingTask(workers=self.task_config.num_workers) + return MessageToDict(job) # Register the Pytorch Plugin into the flytekit core plugin system TaskPlugins.register_pythontask_plugin(PyTorch, PyTorchFunctionTask) + + +def spawn_helper(fn: bytes, kwargs) -> Any: + """Help to spawn worker processes. + + The purpose of this function is to 1) be pickleable so that it can be used with + the multiprocessing start method `spawn` and 2) to call a cloudpickle-serialized + function passed to it. This function itself doesn't have to be pickleable. Without + such a helper task functions, which are not pickleable, couldn't be used with the + start method `spawn`. + + Args: + fn (bytes): Cloudpickle-serialized target function to be executed in the worker process. + + Returns: + The return value of the received target function. + """ + fn = cloudpickle.loads(fn) + return_val = fn(**kwargs) + return return_val + + +class PytorchElasticFunctionTask(PythonFunctionTask[Elastic]): + """ + Plugin for distributed training with torch elastic/torchrun (see + https://pytorch.org/docs/stable/elastic/run.html). 
+ """ + + _ELASTIC_TASK_TYPE = "pytorch" + _ELASTIC_TASK_TYPE_STANDALONE = "python-task" + + def __init__(self, task_config: Elastic, task_function: Callable, **kwargs): + task_type = self._ELASTIC_TASK_TYPE_STANDALONE if task_config.nnodes == 1 else self._ELASTIC_TASK_TYPE + + super(PytorchElasticFunctionTask, self).__init__( + task_config=task_config, + task_type=task_type, + task_function=task_function, + **kwargs, + ) + try: + from torch.distributed import run + except ImportError: + raise ImportError(TORCH_IMPORT_ERROR_MESSAGE) + self.min_nodes, self.max_nodes = run.parse_min_max_nnodes(str(self.task_config.nnodes)) + + """ + c10d is the backend recommended by torch elastic. + https://pytorch.org/docs/stable/elastic/run.html#note-on-rendezvous-backend + + For c10d, no backend server has to be deployed. + https://pytorch.org/docs/stable/elastic/run.html#deployment + Instead, the workers will use the master's address as the rendezvous point. + """ + self.rdzv_backend = "c10d" + + def _execute(self, **kwargs) -> Any: + """ + This helper method will be invoked to execute the task. + + + Returns: + The result of rank zero. 
+ """ + try: + from torch.distributed import run + from torch.distributed.launcher.api import LaunchConfig, elastic_launch + except ImportError: + raise ImportError(TORCH_IMPORT_ERROR_MESSAGE) + + if isinstance(self.task_config.nproc_per_node, str): + nproc = run.determine_local_world_size(self.task_config.nproc_per_node) + else: + nproc = self.task_config.nproc_per_node + + config = LaunchConfig( + run_id=flytekit.current_context().execution_id.name, + min_nodes=self.min_nodes, + max_nodes=self.max_nodes, + nproc_per_node=nproc, + rdzv_backend=self.rdzv_backend, # rdzv settings + rdzv_endpoint=os.environ.get("PET_RDZV_ENDPOINT", "localhost:0"), + max_restarts=self.task_config.max_restarts, + monitor_interval=self.task_config.monitor_interval, + start_method=self.task_config.start_method, + ) + + if self.task_config.start_method == "spawn": + """ + We use cloudpickle to serialize the non-pickleable task function. + The torch elastic launcher then launches the spawn_helper function (which is pickleable) + instead of the task function. This helper function, in the child-process, then deserializes + the task function, again with cloudpickle, and executes it. + """ + launcher_target_func = spawn_helper + + dumped_target_function = cloudpickle.dumps(self._task_function) + launcher_args = (dumped_target_function, kwargs) + elif self.task_config.start_method == "fork": + """ + The torch elastic launcher doesn't support passing kwargs to the target function, + only args. Flyte only works with kwargs. Thus, we create a closure which already has + the task kwargs bound. We tell the torch elastic launcher to start this function in + the child processes. 
+            """
+
+            def fn_partial():
+                """Closure of the task function with kwargs already bound."""
+                return self._task_function(**kwargs)
+
+            launcher_target_func = fn_partial
+            launcher_args = ()
+
+        else:
+            raise ValueError(f"Unsupported start method {self.task_config.start_method!r}, expected 'spawn' or 'fork'")
+
+        out = elastic_launch(
+            config=config,
+            entrypoint=launcher_target_func,
+        )(*launcher_args)
+
+        # `out` is a dictionary of rank (not local rank) -> result
+        # Rank 0 returns the result of the task function
+        if 0 in out:
+            return out[0]
+        else:
+            raise IgnoreOutputs()
+
+    def execute(self, **kwargs) -> Any:
+        """
+        This method will be invoked to execute the task.
+
+        Handles the exception scope for the `_execute` method.
+        """
+        from flytekit.exceptions import scopes as exception_scopes
+
+        return exception_scopes.user_entry_point(self._execute)(**kwargs)
+
+    def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]:
+        if self.task_config.nnodes == 1:
+            """
+            Torch elastic distributed training is executed in a normal k8s pod so that this
+            works without the kubeflow train operator.
+ """ + return super().get_custom(settings) + else: + from flyteidl.plugins.pytorch_pb2 import ElasticConfig + + elastic_config = ElasticConfig( + rdzv_backend=self.rdzv_backend, + min_replicas=self.min_nodes, + max_replicas=self.max_nodes, + nproc_per_node=self.task_config.nproc_per_node, + max_restarts=self.task_config.max_restarts, + ) + job = DistributedPyTorchTrainingTask( + workers=self.max_nodes, + elastic_config=elastic_config, + ) + return MessageToDict(job) + + +# Register the PytorchElastic Plugin into the flytekit core plugin system +TaskPlugins.register_pythontask_plugin(Elastic, PytorchElasticFunctionTask) diff --git a/plugins/flytekit-kf-pytorch/requirements.txt b/plugins/flytekit-kf-pytorch/requirements.txt index 96fa577a3e..9cbdbad5c6 100644 --- a/plugins/flytekit-kf-pytorch/requirements.txt +++ b/plugins/flytekit-kf-pytorch/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.7 +# This file is autogenerated by pip-compile with Python 3.9 # by the following command: # # pip-compile requirements.in @@ -10,25 +10,31 @@ arrow==1.2.3 # via jinja2-time binaryornot==0.4.4 # via cookiecutter +cachetools==5.3.0 + # via google-auth certifi==2022.12.7 - # via requests + # via + # kubernetes + # requests cffi==1.15.1 # via cryptography chardet==5.1.0 # via binaryornot -charset-normalizer==3.0.1 +charset-normalizer==3.1.0 # via requests click==8.1.3 # via # cookiecutter # flytekit cloudpickle==2.2.1 - # via flytekit + # via + # flytekit + # flytekitplugins-kfpytorch cookiecutter==2.1.1 # via flytekit -croniter==1.3.8 +croniter==1.3.14 # via flytekit -cryptography==39.0.1 +cryptography==40.0.2 # via # pyopenssl # secretstorage @@ -38,7 +44,7 @@ decorator==5.1.1 # via retry deprecated==1.2.13 # via flytekit -diskcache==5.4.0 +diskcache==5.6.1 # via flytekit docker==6.0.1 # via flytekit @@ -46,11 +52,19 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 # via flytekit -flyteidl==1.2.9 - # via flytekit 
-flytekit==1.2.7 +flyteidl==1.2.10 + # via + # flytekit + # flytekitplugins-kfpytorch +flytekit==1.2.9 # via flytekitplugins-kfpytorch -googleapis-common-protos==1.58.0 +gitdb==4.0.10 + # via gitpython +gitpython==3.1.31 + # via flytekit +google-auth==2.17.3 + # via kubernetes +googleapis-common-protos==1.59.0 # via # flyteidl # grpcio-status @@ -62,13 +76,10 @@ grpcio-status==1.48.2 # via flytekit idna==3.4 # via requests -importlib-metadata==6.0.0 +importlib-metadata==6.6.0 # via - # click # flytekit # keyring -importlib-resources==5.12.0 - # via keyring jaraco-classes==3.2.3 # via keyring jeepney==0.8.0 @@ -85,6 +96,8 @@ joblib==1.2.0 # via flytekit keyring==23.13.1 # via flytekit +kubernetes==26.1.0 + # via flytekit markupsafe==2.1.2 # via jinja2 marshmallow==3.19.0 @@ -102,12 +115,14 @@ mypy-extensions==1.0.0 # via typing-inspect natsort==8.2.0 # via flytekit -numpy==1.21.6 +numpy==1.23.5 # via # flytekit # pandas # pyarrow -packaging==23.0 +oauthlib==3.2.2 + # via requests-oauthlib +packaging==23.1 # via # docker # marshmallow @@ -126,15 +141,22 @@ py==1.11.0 # via retry pyarrow==10.0.1 # via flytekit +pyasn1==0.5.0 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.3.0 + # via google-auth pycparser==2.21 # via cffi -pyopenssl==23.0.0 +pyopenssl==23.1.1 # via flytekit python-dateutil==2.8.2 # via # arrow # croniter # flytekit + # kubernetes # pandas python-json-logger==2.0.7 # via flytekit @@ -142,7 +164,7 @@ python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.7.1 +pytz==2023.3 # via # flytekit # pandas @@ -150,42 +172,47 @@ pyyaml==6.0 # via # cookiecutter # flytekit -regex==2022.10.31 + # kubernetes + # responses +regex==2023.3.23 # via docker-image-py requests==2.28.2 # via # cookiecutter # docker # flytekit + # kubernetes + # requests-oauthlib # responses -responses==0.22.0 +requests-oauthlib==1.3.1 + # via kubernetes +responses==0.23.1 # via flytekit retry==0.9.2 # via flytekit +rsa==4.9 + # via google-auth 
secretstorage==3.3.3 # via keyring -singledispatchmethod==1.0 - # via flytekit six==1.16.0 # via + # google-auth # grpcio + # kubernetes # python-dateutil +smmap==5.0.0 + # via gitdb sortedcontainers==2.4.0 # via flytekit statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -toml==0.10.2 - # via responses -types-toml==0.10.8.5 +types-pyyaml==6.0.12.9 # via responses typing-extensions==4.5.0 # via - # arrow # flytekit - # importlib-metadata - # responses # typing-inspect typing-inspect==0.8.0 # via dataclasses-json @@ -193,17 +220,21 @@ urllib3==1.26.14 # via # docker # flytekit + # kubernetes # requests # responses websocket-client==1.5.1 - # via docker -wheel==0.38.4 + # via + # docker + # kubernetes +wheel==0.40.0 # via flytekit wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.14.0 - # via - # importlib-metadata - # importlib-resources +zipp==3.15.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-kf-pytorch/setup.py b/plugins/flytekit-kf-pytorch/setup.py index c45e409567..a207b9381e 100644 --- a/plugins/flytekit-kf-pytorch/setup.py +++ b/plugins/flytekit-kf-pytorch/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.3.0,<2.0.0"] +plugin_requires = ["cloudpickle", "flytekit>=1.1.0b0,<1.3.0,<2.0.0", "flyteidl>=1.2.10,<1.3.0"] __version__ = "0.0.0+develop" @@ -17,6 +17,9 @@ namespace_packages=["flytekitplugins"], packages=[f"flytekitplugins.{PLUGIN_NAME}"], install_requires=plugin_requires, + extras_require={ + "elastic": ["torch>=1.9.0"], + }, license="apache2", python_requires=">=3.7", classifiers=[ diff --git a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py new file mode 100644 index 0000000000..2ca6c9cc65 --- /dev/null +++ b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py @@ -0,0 +1,67 @@ +import os +import 
typing +from dataclasses import dataclass + +import pytest +import torch +import torch.distributed as dist +from dataclasses_json import dataclass_json +from flytekitplugins.kfpytorch.task import Elastic + +from flytekit import task, workflow + + +@dataclass_json +@dataclass +class Config: + lr: float = 1e-5 + bs: int = 64 + name: str = "foo" + + +def dist_communicate() -> int: + """Communicate between distributed workers.""" + rank = torch.distributed.get_rank() + world_size = dist.get_world_size() + tensor = torch.tensor([5], dtype=torch.int64) + 2 * rank + world_size + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + + return tensor.item() + + +def train(config: Config) -> typing.Tuple[str, Config, torch.nn.Module, int]: + """Mock training a model using torch-elastic for test purposes.""" + dist.init_process_group(backend="gloo") + + local_rank = os.environ["LOCAL_RANK"] + + out_model = torch.nn.Linear(1000, int(local_rank) + 1) + config.name = "elastic-test" + + distributed_result = dist_communicate() + + return f"result from local rank {local_rank}", config, out_model, distributed_result + + +@pytest.mark.parametrize("start_method", ["spawn", "fork"]) +def test_end_to_end(start_method: str) -> None: + """Test that the workflow with elastic task runs end to end.""" + world_size = 2 + + train_task = task(train, task_config=Elastic(nnodes=1, nproc_per_node=world_size, start_method=start_method)) + + @workflow + def wf(config: Config = Config()) -> typing.Tuple[str, Config, torch.nn.Module, int]: + return train_task(config=config) + + r, cfg, m, distributed_result = wf() + assert "result from local rank 0" in r + assert cfg.name == "elastic-test" + assert m.in_features == 1000 + assert m.out_features == 1 + """ + The distributed result is calculated by the workers of the elastic train + task by performing a `dist.all_reduce` operation. The correct result can + only be obtained if the distributed process group is initialized correctly. 
+ """ + assert distributed_result == sum([5 + 2 * rank + world_size for rank in range(world_size)]) diff --git a/plugins/flytekit-mlflow/dev-requirements.txt b/plugins/flytekit-mlflow/dev-requirements.txt index 6ad9be49bb..5788aeb7d2 100644 --- a/plugins/flytekit-mlflow/dev-requirements.txt +++ b/plugins/flytekit-mlflow/dev-requirements.txt @@ -36,11 +36,7 @@ h5py==3.7.0 # via tensorflow idna==3.4 # via requests -importlib-metadata==5.0.0 - # via markdown -keras==2.10.0 - # via tensorflow -keras-preprocessing==1.1.2 +keras==2.11.0 # via tensorflow libclang==14.0.6 # via tensorflow @@ -51,7 +47,6 @@ markupsafe==2.1.1 numpy==1.23.4 # via # h5py - # keras-preprocessing # opt-einsum # tensorboard # tensorflow @@ -87,17 +82,16 @@ six==1.16.0 # google-auth # google-pasta # grpcio - # keras-preprocessing # tensorflow -tensorboard==2.10.1 +tensorboard==2.11.2 # via tensorflow tensorboard-data-server==0.6.1 # via tensorboard tensorboard-plugin-wit==1.8.1 # via tensorboard -tensorflow==2.10.0 +tensorflow==2.11.1 # via -r dev-requirements.in -tensorflow-estimator==2.10.0 +tensorflow-estimator==2.11.0 # via tensorflow tensorflow-io-gcs-filesystem==0.27.0 # via tensorflow @@ -115,8 +109,6 @@ wheel==0.38.3 # tensorboard wrapt==1.14.1 # via tensorflow -zipp==3.10.0 - # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/plugins/flytekit-mlflow/flytekitplugins/mlflow/tracking.py b/plugins/flytekit-mlflow/flytekitplugins/mlflow/tracking.py index b58aa4a120..9fa897f90e 100644 --- a/plugins/flytekit-mlflow/flytekitplugins/mlflow/tracking.py +++ b/plugins/flytekit-mlflow/flytekitplugins/mlflow/tracking.py @@ -13,7 +13,7 @@ from flytekit import FlyteContextManager from flytekit.bin.entrypoint import get_one_of from flytekit.core.context_manager import ExecutionState -from flytekit.deck import TopFrameRenderer +from flytekit.deck.renderer import TopFrameRenderer def metric_to_df(metrics: typing.List[Metric]) -> 
pd.DataFrame: diff --git a/plugins/flytekit-mlflow/tests/test_mlflow_tracking.py b/plugins/flytekit-mlflow/tests/test_mlflow_tracking.py index b196327d8d..613cbfcd76 100644 --- a/plugins/flytekit-mlflow/tests/test_mlflow_tracking.py +++ b/plugins/flytekit-mlflow/tests/test_mlflow_tracking.py @@ -29,4 +29,4 @@ def train_model(epochs: int): def test_local_exec(): train_model(epochs=1) - assert len(flytekit.current_context().decks) == 4 # mlflow metrics, params, input, and output + assert len(flytekit.current_context().decks) == 5 # mlflow metrics, params, timeline, input, and output diff --git a/plugins/flytekit-pandera/tests/test_plugin.py b/plugins/flytekit-pandera/tests/test_plugin.py index cc9b26c4fa..7e73aac932 100644 --- a/plugins/flytekit-pandera/tests/test_plugin.py +++ b/plugins/flytekit-pandera/tests/test_plugin.py @@ -55,7 +55,13 @@ def invalid_wf() -> pandera.typing.DataFrame[OutSchema]: def wf_with_df_input(df: pandera.typing.DataFrame[InSchema]) -> pandera.typing.DataFrame[OutSchema]: return transform2(df=transform1(df=df)) - with pytest.raises(pandera.errors.SchemaError, match="^expected series 'col2' to have type float64, got object"): + with pytest.raises( + pandera.errors.SchemaError, + match=( + "^Encountered error while executing workflow 'test_plugin.wf_with_df_input':\n" + " expected series 'col2' to have type float64, got object" + ), + ): wf_with_df_input(df=invalid_df) # raise error when executing workflow with invalid output @@ -67,7 +73,14 @@ def transform2_noop(df: pandera.typing.DataFrame[IntermediateSchema]) -> pandera def wf_invalid_output(df: pandera.typing.DataFrame[InSchema]) -> pandera.typing.DataFrame[OutSchema]: return transform2_noop(df=transform1(df=df)) - with pytest.raises(TypeError, match="^Failed to convert return value"): + with pytest.raises( + TypeError, + match=( + "^Encountered error while executing workflow 'test_plugin.wf_invalid_output':\n" + " Error encountered while executing 'wf_invalid_output':\n" + " Failed to 
convert outputs of task" + ), + ): wf_invalid_output(df=valid_df) diff --git a/plugins/flytekit-papermill/flytekitplugins/papermill/__init__.py b/plugins/flytekit-papermill/flytekitplugins/papermill/__init__.py index bce2ef2653..648ba1c6e0 100644 --- a/plugins/flytekit-papermill/flytekitplugins/papermill/__init__.py +++ b/plugins/flytekit-papermill/flytekitplugins/papermill/__init__.py @@ -11,4 +11,4 @@ record_outputs """ -from .task import NotebookTask, record_outputs +from .task import NotebookTask, load_flytedirectory, load_flytefile, load_structureddataset, record_outputs diff --git a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py index 6c160c2690..b1f472e99a 100644 --- a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py +++ b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py @@ -1,23 +1,29 @@ import json +import logging import os +import sys +import tempfile import typing from typing import Any import nbformat import papermill as pm +from flyteidl.core.literals_pb2 import Literal as _pb2_Literal from flyteidl.core.literals_pb2 import LiteralMap as _pb2_LiteralMap from google.protobuf import text_format as _text_format from nbconvert import HTMLExporter -from flytekit import FlyteContext, PythonInstanceTask +from flytekit import FlyteContext, PythonInstanceTask, StructuredDataset from flytekit.configuration import SerializationSettings +from flytekit.core import utils from flytekit.core.context_manager import ExecutionParameters from flytekit.deck.deck import Deck from flytekit.extend import Interface, TaskPlugins, TypeEngine from flytekit.loggers import logger from flytekit.models import task as task_models -from flytekit.models.literals import LiteralMap -from flytekit.types.file import HTMLPage, PythonNotebook +from flytekit.models.literals import Literal, LiteralMap +from flytekit.types.directory import FlyteDirectory +from flytekit.types.file import 
FlyteFile, HTMLPage, PythonNotebook T = typing.TypeVar("T") @@ -26,6 +32,8 @@ def _dummy_task_func(): return None +SAVE_AS_LITERAL = (FlyteFile, FlyteDirectory, StructuredDataset) + PAPERMILL_TASK_PREFIX = "pm.nb" @@ -86,6 +94,13 @@ class NotebookTask(PythonInstanceTask[T]): Users can access these notebooks after execution of the task locally or from remote servers. + .. note: + + By default, print statements in your notebook won't be transmitted to the pod logs/stdout. If you would + like to have logs forwarded as the notebook executes, pass the stream_logs argument. Note that notebook + logs can be quite verbose, so ensure you are prepared for any downstream log ingestion costs + (e.g., cloudwatch) + .. todo: Implicit extraction of SparkConfiguration from the notebook is not supported. @@ -114,6 +129,7 @@ def __init__( name: str, notebook_path: str, render_deck: bool = False, + stream_logs: bool = False, task_config: T = None, inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, outputs: typing.Optional[typing.Dict[str, typing.Type]] = None, @@ -135,6 +151,16 @@ def __init__( self._notebook_path = os.path.abspath(notebook_path) self._render_deck = render_deck + self._stream_logs = stream_logs + + # Send the papermill logger to stdout so that it appears in pod logs. Note that papermill doesn't allow + # injecting a logger, so we cannot redirect logs to the flyte child loggers (e.g., the userspace logger) + # and inherit their settings, but we instead must send logs to stdout directly + if self._stream_logs: + papermill_logger = logging.getLogger("papermill") + papermill_logger.addHandler(logging.StreamHandler(sys.stdout)) + # Papermill leaves the default level of DEBUG. We increase it here. 
+ papermill_logger.setLevel(logging.INFO) if not os.path.exists(self._notebook_path): raise ValueError(f"Illegal notebook path passed in {self._notebook_path}") @@ -235,8 +261,12 @@ def execute(self, **kwargs) -> Any: singleton """ logger.info(f"Hijacking the call for task-type {self.task_type}, to call notebook.") + for k, v in kwargs.items(): + if isinstance(v, SAVE_AS_LITERAL): + kwargs[k] = save_python_val_to_file(v) + # Execute Notebook via Papermill. - pm.execute_notebook(self._notebook_path, self.output_notebook_path, parameters=kwargs) # type: ignore + pm.execute_notebook(self._notebook_path, self.output_notebook_path, parameters=kwargs, log_output=self._stream_logs) # type: ignore outputs = self.extract_outputs(self.output_notebook_path) self.render_nb_html(self.output_notebook_path, self.rendered_output_path) @@ -245,6 +275,7 @@ def execute(self, **kwargs) -> Any: if outputs: m = outputs.literals output_list = [] + for k, type_v in self.python_interface.outputs.items(): if k == self._IMPLICIT_OP_NOTEBOOK: output_list.append(self.output_notebook_path) @@ -254,7 +285,7 @@ def execute(self, **kwargs) -> Any: v = TypeEngine.to_python_value(ctx=FlyteContext.current_context(), lv=m[k], expected_python_type=type_v) output_list.append(v) else: - raise RuntimeError(f"Expected output {k} of type {v} not found in the notebook outputs") + raise TypeError(f"Expected output {k} of type {type_v} not found in the notebook outputs") return tuple(output_list) @@ -287,3 +318,80 @@ def record_outputs(**kwargs) -> str: lit = TypeEngine.to_literal(ctx, python_type=type(v), python_val=v, expected=expected) m[k] = lit return LiteralMap(literals=m).to_flyte_idl() + + +def save_python_val_to_file(input: Any) -> str: + """Save a python value to a local file as a Flyte literal. 
+ + Args: + input (Any): the python value + + Returns: + str: the path to the file + """ + ctx = FlyteContext.current_context() + expected = TypeEngine.to_literal_type(type(input)) + lit = TypeEngine.to_literal(ctx, python_type=type(input), python_val=input, expected=expected) + + tmp_file = tempfile.mktemp(suffix="bin") + utils.write_proto_to_file(lit.to_flyte_idl(), tmp_file) + return tmp_file + + +def load_python_val_from_file(path: str, dtype: T) -> T: + """Loads a python value from a Flyte literal saved to a local file. + + If the path matches the type, it is returned as is. This enables + reusing the parameters cell for local development. + + Args: + path (str): path to the file + dtype (T): the type of the literal + + Returns: + T: the python value of the literal + """ + if isinstance(path, dtype): + return path + + proto = utils.load_proto_from_file(_pb2_Literal, path) + lit = Literal.from_flyte_idl(proto) + ctx = FlyteContext.current_context() + python_value = TypeEngine.to_python_value(ctx, lit, dtype) + return python_value + + +def load_flytefile(path: str) -> T: + """Loads a FlyteFile from a file. + + Args: + path (str): path to the file + + Returns: + T: the python value of the literal + """ + return load_python_val_from_file(path=path, dtype=FlyteFile) + + +def load_flytedirectory(path: str) -> T: + """Loads a FlyteDirectory from a file. + + Args: + path (str): path to the file + + Returns: + T: the python value of the literal + """ + return load_python_val_from_file(path=path, dtype=FlyteDirectory) + + +def load_structureddataset(path: str) -> T: + """Loads a StructuredDataset from a file. 
+ + Args: + path (str): path to the file + + Returns: + T: the python value of the literal + """ + return load_python_val_from_file(path=path, dtype=StructuredDataset) diff --git a/plugins/flytekit-papermill/tests/test_task.py b/plugins/flytekit-papermill/tests/test_task.py index 1947d09445..0e54e7082e 100644 --- a/plugins/flytekit-papermill/tests/test_task.py +++ b/plugins/flytekit-papermill/tests/test_task.py @@ -1,14 +1,17 @@ import datetime import os +import tempfile +import pandas as pd from flytekitplugins.papermill import NotebookTask from flytekitplugins.pod import Pod from kubernetes.client import V1Container, V1PodSpec import flytekit -from flytekit import kwtypes +from flytekit import StructuredDataset, kwtypes, task from flytekit.configuration import Image, ImageConfig -from flytekit.types.file import PythonNotebook +from flytekit.types.directory import FlyteDirectory +from flytekit.types.file import FlyteFile, PythonNotebook from .testdata.datatype import X @@ -134,3 +137,38 @@ def test_notebook_pod_task(): nb.get_command(serialization_settings) == nb.get_k8s_pod(serialization_settings).pod_spec["containers"][0]["args"] ) + + +def test_flyte_types(): + @task + def create_file() -> FlyteFile: + tmp_file = tempfile.mktemp() + with open(tmp_file, "w") as f: + f.write("abc") + return FlyteFile(path=tmp_file) + + @task + def create_dir() -> FlyteDirectory: + tmp_dir = tempfile.mkdtemp() + with open(os.path.join(tmp_dir, "file.txt"), "w") as f: + f.write("abc") + return FlyteDirectory(path=tmp_dir) + + @task + def create_sd() -> StructuredDataset: + df = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + return StructuredDataset(dataframe=df) + + ff = create_file() + fd = create_dir() + sd = create_sd() + + nb_name = "nb-types" + nb_types = NotebookTask( + name="test", + notebook_path=_get_nb_path(nb_name, abs=False), + inputs=kwtypes(ff=FlyteFile, fd=FlyteDirectory, sd=StructuredDataset), + outputs=kwtypes(success=bool), + ) + success, out, render = 
nb_types.execute(ff=ff, fd=fd, sd=sd) + assert success is True, "Notebook execution failed" diff --git a/plugins/flytekit-papermill/tests/testdata/nb-simple.ipynb b/plugins/flytekit-papermill/tests/testdata/nb-simple.ipynb index ebdf9a3c71..1ad7aaed4a 100644 --- a/plugins/flytekit-papermill/tests/testdata/nb-simple.ipynb +++ b/plugins/flytekit-papermill/tests/testdata/nb-simple.ipynb @@ -34,7 +34,6 @@ "outputs": [], "source": [ "from flytekitplugins.papermill import record_outputs\n", - "\n", "record_outputs(square=out)" ] }, @@ -49,7 +48,7 @@ "metadata": { "celltoolbar": "Tags", "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -63,9 +62,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.3" + "version": "3.10.10" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/plugins/flytekit-papermill/tests/testdata/nb-types.ipynb b/plugins/flytekit-papermill/tests/testdata/nb-types.ipynb new file mode 100644 index 0000000000..824b1d39ae --- /dev/null +++ b/plugins/flytekit-papermill/tests/testdata/nb-types.ipynb @@ -0,0 +1,100 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "tags": [ + "parameters" + ] + }, + "outputs": [], + "source": [ + "ff = None\n", + "fd = None\n", + "sd = None" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import pandas as pd\n", + "from flytekitplugins.papermill import (\n", + " load_flytefile, load_flytedirectory, load_structureddataset,\n", + " record_outputs\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ff = load_flytefile(ff)\n", + "fd = load_flytedirectory(fd)\n", + "sd = load_structureddataset(sd)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + 
"outputs": [], + "source": [ + "# read file\n", + "with open(ff.download(), 'r') as f:\n", + " text = f.read()\n", + " assert text == \"abc\", \"Text does not match\"\n", + "\n", + "# check file inside directory\n", + "with open(os.path.join(fd.download(),\"file.txt\"), 'r') as f:\n", + " text = f.read()\n", + " assert text == \"abc\", \"Text does not match\"\n", + "\n", + "# check dataset\n", + "df = sd.open(pd.DataFrame).all()\n", + "expected = pd.DataFrame({\"a\": [1, 2], \"b\": [3, 4]})\n", + "assert df.equals(expected), \"Dataframes do not match\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "outputs" + ] + }, + "outputs": [], + "source": [ + "record_outputs(success=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py index 0b5bf8e577..4290c88ae4 100644 --- a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py +++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py @@ -7,6 +7,7 @@ from flytekit.models import literals from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import StructuredDatasetType +from flytekit.types.structured.basic_dfs import get_storage_options from flytekit.types.structured.structured_dataset import ( PARQUET, StructuredDataset, @@ -62,12 +63,12 @@ def decode( flyte_value: literals.StructuredDataset, current_task_metadata: StructuredDatasetMetadata, ) -> pl.DataFrame: - local_dir = 
ctx.file_access.get_random_local_directory() - ctx.file_access.get_data(flyte_value.uri, local_dir, is_multipart=True) + uri = flyte_value.uri + kwargs = get_storage_options(ctx.file_access.data_config, uri) if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] - return pl.read_parquet(local_dir, columns=columns, use_pyarrow=True) - return pl.read_parquet(local_dir, use_pyarrow=True) + return pl.read_parquet(uri, columns=columns, use_pyarrow=True, storage_options=kwargs) + return pl.read_parquet(uri, use_pyarrow=True, storage_options=kwargs) StructuredDatasetTransformerEngine.register(PolarsDataFrameToParquetEncodingHandler()) diff --git a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py index 15a195e5d5..23fbf6d441 100644 --- a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py +++ b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py @@ -1,3 +1,5 @@ +import tempfile + import pandas as pd import polars as pl from flytekitplugins.polars.sd_transformers import PolarsDataFrameRenderer @@ -79,3 +81,13 @@ def create_sd() -> StructuredDataset: sd = create_sd() polars_df = sd.open(pl.DataFrame).all() assert pl.DataFrame(data).frame_equal(polars_df) + + tmp = tempfile.mktemp() + pl.DataFrame(data).write_parquet(tmp) + + @task + def t1(sd: StructuredDataset) -> pl.DataFrame: + return sd.open(pd.DataFrame).all() + + sd = StructuredDataset(uri=tmp) + t1(sd=sd).frame_equal(polars_df) diff --git a/plugins/flytekit-spark/Dockerfile b/plugins/flytekit-spark/Dockerfile new file mode 100644 index 0000000000..0789df45b7 --- /dev/null +++ b/plugins/flytekit-spark/Dockerfile @@ -0,0 +1,14 @@ +# https://github.com/apache/spark/blob/master/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/bindings/python/Dockerfile +FROM apache/spark-py:3.3.1 +LABEL 
org.opencontainers.image.source=https://github.com/flyteorg/flytekit + +USER 0 +RUN ln -s /usr/bin/python3 /usr/bin/python + +ARG VERSION +RUN pip install flytekitplugins-spark==$VERSION +RUN pip install flytekit==$VERSION + +RUN chown -R ${spark_uid}:${spark_uid} /root +WORKDIR /root +USER ${spark_uid} diff --git a/plugins/flytekit-spark/flytekitplugins/spark/pyspark_transformers.py b/plugins/flytekit-spark/flytekitplugins/spark/pyspark_transformers.py index e48778ad70..4afb257f9d 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/pyspark_transformers.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/pyspark_transformers.py @@ -1,4 +1,3 @@ -import pathlib from typing import Type from pyspark.ml import PipelineModel @@ -24,22 +23,17 @@ def to_literal( python_type: Type[PipelineModel], expected: LiteralType, ) -> Literal: - local_path = ctx.file_access.get_random_local_path() - pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True) - python_val.save(local_path) - + # Must write to remote directory remote_dir = ctx.file_access.get_random_remote_directory() - ctx.file_access.upload_directory(local_path, remote_dir) + python_val.write().overwrite().save(remote_dir) return Literal(scalar=Scalar(blob=Blob(uri=remote_dir, metadata=BlobMetadata(type=self._TYPE_INFO)))) def to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[PipelineModel] ) -> PipelineModel: - local_dir = ctx.file_access.get_random_local_directory() - ctx.file_access.download_directory(lv.scalar.blob.uri, local_dir) - - return PipelineModel.load(local_dir) + remote_dir = lv.scalar.blob.uri + return PipelineModel.load(remote_dir) TypeEngine.register(PySparkPipelineModelTransformer()) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 7b32e9f28b..564e55778f 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py 
@@ -1,15 +1,15 @@ import os -import typing from dataclasses import dataclass -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, Union, cast from google.protobuf.json_format import MessageToDict from pyspark.sql import SparkSession from flytekit import FlyteContextManager, PythonFunctionTask -from flytekit.configuration import SerializationSettings +from flytekit.configuration import DefaultImages, SerializationSettings from flytekit.core.context_manager import ExecutionParameters from flytekit.extend import ExecutionState, TaskPlugins +from flytekit.image_spec import ImageSpec from .models import SparkJob, SparkType @@ -48,7 +48,7 @@ class Databricks(Spark): databricks_instance: Domain name of your deployment. Use the form .cloud.databricks.com. """ - databricks_conf: typing.Optional[Dict[str, typing.Union[str, dict]]] = None + databricks_conf: Optional[Dict[str, Union[str, dict]]] = None databricks_token: Optional[str] = None databricks_instance: Optional[str] = None @@ -56,7 +56,7 @@ class Databricks(Spark): # This method does not reset the SparkSession since it's a bit hard to handle multiple # Spark sessions in a single application as it's described in: # https://stackoverflow.com/questions/41491972/how-can-i-tear-down-a-sparksession-and-create-a-new-one-within-one-application. -def new_spark_session(name: str, conf: typing.Dict[str, str] = None): +def new_spark_session(name: str, conf: Dict[str, str] = None): """ Optionally creates a new spark session and returns it. 
In cluster mode (running in hosted flyte, this will disregard the spark conf passed in) @@ -99,26 +99,43 @@ class PysparkFunctionTask(PythonFunctionTask[Spark]): _SPARK_TASK_TYPE = "spark" - def __init__(self, task_config: Spark, task_function: Callable, **kwargs): + def __init__( + self, + task_config: Spark, + task_function: Callable, + container_image: Optional[Union[str, ImageSpec]] = None, + **kwargs, + ): + self.sess: Optional[SparkSession] = None + self._default_executor_path: Optional[str] = None + self._default_applications_path: Optional[str] = None + + if isinstance(container_image, ImageSpec): + if container_image.base_image is None: + img = f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}" + container_image.base_image = img + # default executor path and applications path in apache/spark-py:3.3.1 + self._default_executor_path = "/usr/bin/python3" + self._default_applications_path = "local:///usr/local/bin/entrypoint.py" super(PysparkFunctionTask, self).__init__( task_config=task_config, task_type=self._SPARK_TASK_TYPE, task_function=task_function, + container_image=container_image, **kwargs, ) - self.sess: Optional[SparkSession] = None def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: job = SparkJob( spark_conf=self.task_config.spark_conf, hadoop_conf=self.task_config.hadoop_conf, - application_file="local://" + settings.entrypoint_settings.path, - executor_path=settings.python_interpreter, + application_file=self._default_applications_path or "local://" + settings.entrypoint_settings.path, + executor_path=self._default_executor_path or settings.python_interpreter, main_class="", spark_type=SparkType.PYTHON, ) if isinstance(self.task_config, Databricks): - cfg = typing.cast(Databricks, self.task_config) + cfg = cast(Databricks, self.task_config) job._databricks_conf = cfg.databricks_conf job._databricks_token = cfg.databricks_token job._databricks_instance = cfg.databricks_instance diff --git 
a/plugins/flytekit-spark/tests/test_pyspark_transformers.py b/plugins/flytekit-spark/tests/test_pyspark_transformers.py index cb527e16ef..212af454dd 100644 --- a/plugins/flytekit-spark/tests/test_pyspark_transformers.py +++ b/plugins/flytekit-spark/tests/test_pyspark_transformers.py @@ -6,13 +6,24 @@ import flytekit from flytekit import task, workflow +from flytekit.core.context_manager import FlyteContextManager from flytekit.core.type_engine import TypeEngine +from flytekit.types.structured.structured_dataset import StructuredDatasetTransformerEngine def test_type_resolution(): assert type(TypeEngine.get_transformer(PipelineModel)) == PySparkPipelineModelTransformer +def test_basic_get(): + + ctx = FlyteContextManager.current_context() + e = StructuredDatasetTransformerEngine() + prot = e._protocol_from_type_or_prefix(ctx, pyspark.sql.DataFrame, uri="/tmp/blah") + en = e.get_encoder(pyspark.sql.DataFrame, prot, "") + assert en is not None + + def test_pipeline_model_compatibility(): @task(task_config=Spark()) def my_dataset() -> pyspark.sql.DataFrame: diff --git a/plugins/flytekit-sqlalchemy/Dockerfile b/plugins/flytekit-sqlalchemy/Dockerfile new file mode 100644 index 0000000000..ed1a644d8f --- /dev/null +++ b/plugins/flytekit-sqlalchemy/Dockerfile @@ -0,0 +1,19 @@ +ARG PYTHON_VERSION +FROM python:${PYTHON_VERSION}-slim-buster + +WORKDIR /root +ENV LANG C.UTF-8 +ENV LC_ALL C.UTF-8 +ENV PYTHONPATH /root + +ARG VERSION + +RUN pip install sqlalchemy \ + psycopg2-binary \ + pymysql \ + flytekitplugins-sqlalchemy==$VERSION \ + flytekit==$VERSION + +RUN useradd -u 1000 flytekit +RUN chown flytekit: /root +USER flytekit diff --git a/plugins/flytekit-sqlalchemy/Dockerfile.py3.10 b/plugins/flytekit-sqlalchemy/Dockerfile.py3.10 deleted file mode 100644 index 791b13fa53..0000000000 --- a/plugins/flytekit-sqlalchemy/Dockerfile.py3.10 +++ /dev/null @@ -1,25 +0,0 @@ -FROM python:3.10-slim-buster - -WORKDIR /app -ENV VENV /opt/venv -ENV LANG C.UTF-8 -ENV LC_ALL C.UTF-8 -ENV 
PYTHONPATH /app - -RUN pip install awscli -RUN pip install gsutil - -ARG VERSION - -# Virtual environment -RUN python3.10 -m venv ${VENV} -RUN ${VENV}/bin/pip install wheel - -RUN ${VENV}/bin/pip install sqlalchemy psycopg2-binary pymysql flytekitplugins-sqlalchemy==$VERSION flytekit==$VERSION - -# Copy over the helper script that the SDK relies on -RUN cp ${VENV}/bin/flytekit_venv /usr/local/bin -RUN chmod a+x /usr/local/bin/flytekit_venv - -# Enable the virtualenv for this image. Note this relies on the VENV variable we've set in this image. -ENTRYPOINT ["/usr/local/bin/flytekit_venv"] diff --git a/plugins/flytekit-sqlalchemy/Dockerfile.py3.7 b/plugins/flytekit-sqlalchemy/Dockerfile.py3.7 deleted file mode 100644 index 879656adb5..0000000000 --- a/plugins/flytekit-sqlalchemy/Dockerfile.py3.7 +++ /dev/null @@ -1,25 +0,0 @@ -FROM python:3.7-slim-buster - -WORKDIR /app -ENV VENV /opt/venv -ENV LANG C.UTF-8 -ENV LC_ALL C.UTF-8 -ENV PYTHONPATH /app - -RUN pip install awscli -RUN pip install gsutil - -ARG VERSION - -# Virtual environment -RUN python3.7 -m venv ${VENV} -RUN ${VENV}/bin/pip install wheel - -RUN ${VENV}/bin/pip install sqlalchemy psycopg2-binary pymysql flytekitplugins-sqlalchemy==$VERSION flytekit==$VERSION - -# Copy over the helper script that the SDK relies on -RUN cp ${VENV}/bin/flytekit_venv /usr/local/bin -RUN chmod a+x /usr/local/bin/flytekit_venv - -# Enable the virtualenv for this image. Note this relies on the VENV variable we've set in this image. 
-ENTRYPOINT ["/usr/local/bin/flytekit_venv"] diff --git a/plugins/flytekit-sqlalchemy/Dockerfile.py3.8 b/plugins/flytekit-sqlalchemy/Dockerfile.py3.8 deleted file mode 100644 index 93b7048e1b..0000000000 --- a/plugins/flytekit-sqlalchemy/Dockerfile.py3.8 +++ /dev/null @@ -1,25 +0,0 @@ -FROM python:3.8-slim-buster - -WORKDIR /app -ENV VENV /opt/venv -ENV LANG C.UTF-8 -ENV LC_ALL C.UTF-8 -ENV PYTHONPATH /app - -RUN pip install awscli -RUN pip install gsutil - -ARG VERSION - -# Virtual environment -RUN python3.8 -m venv ${VENV} -RUN ${VENV}/bin/pip install wheel - -RUN ${VENV}/bin/pip install sqlalchemy psycopg2-binary pymysql flytekitplugins-sqlalchemy==$VERSION flytekit==$VERSION - -# Copy over the helper script that the SDK relies on -RUN cp ${VENV}/bin/flytekit_venv /usr/local/bin -RUN chmod a+x /usr/local/bin/flytekit_venv - -# Enable the virtualenv for this image. Note this relies on the VENV variable we've set in this image. -ENTRYPOINT ["/usr/local/bin/flytekit_venv"] diff --git a/plugins/flytekit-sqlalchemy/Dockerfile.py3.9 b/plugins/flytekit-sqlalchemy/Dockerfile.py3.9 deleted file mode 100644 index 039956dcd1..0000000000 --- a/plugins/flytekit-sqlalchemy/Dockerfile.py3.9 +++ /dev/null @@ -1,25 +0,0 @@ -FROM python:3.9-slim-buster - -WORKDIR /app -ENV VENV /opt/venv -ENV LANG C.UTF-8 -ENV LC_ALL C.UTF-8 -ENV PYTHONPATH /app - -RUN pip install awscli -RUN pip install gsutil - -ARG VERSION - -# Virtual environment -RUN python3.9 -m venv ${VENV} -RUN ${VENV}/bin/pip install wheel - -RUN ${VENV}/bin/pip install sqlalchemy psycopg2-binary pymysql flytekitplugins-sqlalchemy==$VERSION flytekit==$VERSION - -# Copy over the helper script that the SDK relies on -RUN cp ${VENV}/bin/flytekit_venv /usr/local/bin -RUN chmod a+x /usr/local/bin/flytekit_venv - -# Enable the virtualenv for this image. Note this relies on the VENV variable we've set in this image. 
-ENTRYPOINT ["/usr/local/bin/flytekit_venv"] diff --git a/plugins/setup.py b/plugins/setup.py index 1b47cc58e0..d607468081 100644 --- a/plugins/setup.py +++ b/plugins/setup.py @@ -18,6 +18,8 @@ "flytekitplugins-dbt": "flytekit-dbt", "flytekitplugins-dolt": "flytekit-dolt", "flytekitplugins-duckdb": "flytekit-duckdb", + "flytekitplugins-data-fsspec": "flytekit-data-fsspec", + "flytekitplugins-envd": "flytekit-envd", "flytekitplugins-great_expectations": "flytekit-greatexpectations", "flytekitplugins-hive": "flytekit-hive", "flytekitplugins-pod": "flytekit-k8s-pod", @@ -30,6 +32,7 @@ "flytekitplugins-onnxpytorch": "flytekit-onnx-pytorch", "flytekitplugins-pandera": "flytekit-pandera", "flytekitplugins-papermill": "flytekit-papermill", + "flytekitplugins-polars": "flytekit-polars", "flytekitplugins-ray": "flytekit-ray", "flytekitplugins-snowflake": "flytekit-snowflake", "flytekitplugins-spark": "flytekit-spark", diff --git a/setup.py b/setup.py index 6519fbccd1..36acb1a81d 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,8 @@ ] }, install_requires=[ - "flyteidl>=1.2.9,<1.3.0", + "googleapis-common-protos>=1.57", + "flyteidl>=1.2.10,<1.3.0", "wheel>=0.30.0,<1.0.0", "pandas>=1.0.0,<2.0.0", "pyarrow>=4.0.0,<11.0.0", @@ -42,6 +43,10 @@ "grpcio>=1.43.0,!=1.45.0,<1.49.1,<2.0", "grpcio-status>=1.43,!=1.45.0,<1.49.1", "importlib-metadata", + "fsspec>=2023.1.0", + "adlfs", + "s3fs>=0.6.0", + "gcsfs", "pyopenssl", "joblib", "protobuf>=3.6.1,<4", @@ -56,7 +61,6 @@ "statsd>=3.0.0,<4.0.0", "urllib3>=1.22,<2.0.0", "wrapt>=1.0.0,<2.0.0", - "retry==0.9.2", "dataclasses-json>=0.5.2", "marshmallow-jsonschema>=0.12.0", "natsort>=7.0.1", @@ -73,6 +77,8 @@ "numpy<1.24.0", "gitpython", "kubernetes>=12.0.1", + "rich", + "rich_click", ], extras_require=extras_require, scripts=[ diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index 45d50a2fc5..1a24cccb61 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ 
b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -2,6 +2,7 @@ import typing from collections import OrderedDict +import fsspec import mock import pytest from flyteidl.core.errors_pb2 import ErrorDocument @@ -10,15 +11,12 @@ from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core import context_manager from flytekit.core.base_task import IgnoreOutputs -from flytekit.core.data_persistence import DiskPersistence from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.promise import VoidPromise from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine from flytekit.exceptions import user as user_exceptions from flytekit.exceptions.scopes import system_entry_point -from flytekit.extras.persistence.gcs_gsutil import GCSPersistence -from flytekit.extras.persistence.s3_awscli import S3Persistence from flytekit.models import literals as _literal_models from flytekit.models.core import errors as error_models from flytekit.models.core import execution as execution_models @@ -311,7 +309,22 @@ def test_dispatch_execute_system_error(mock_write_to_file, mock_upload_dir, mock assert ed.error.origin == execution_models.ExecutionError.ErrorKind.SYSTEM -def test_persist_ss(): +def test_setup_disk_prefix(): + with setup_execution("qwerty") as ctx: + assert isinstance(ctx.file_access._default_remote, fsspec.AbstractFileSystem) + assert ctx.file_access._default_remote.protocol == "file" + + +def test_setup_cloud_prefix(): + with setup_execution("s3://", checkpoint_path=None, prev_checkpoint=None) as ctx: + assert ctx.file_access._default_remote.protocol[0] == "s3" + + with setup_execution("gs://", checkpoint_path=None, prev_checkpoint=None) as ctx: + assert "gs" in ctx.file_access._default_remote.protocol + + +@mock.patch("google.auth.compute_engine._metadata") # to prevent network calls +def test_persist_ss(mock_gcs): default_img = Image(name="default", fqn="test", tag="tag") ss = 
SerializationSettings( project="proj1", @@ -327,19 +340,6 @@ def test_persist_ss(): assert ctx.serialization_settings.domain == "dom" -def test_setup_disk_prefix(): - with setup_execution("qwerty") as ctx: - assert isinstance(ctx.file_access._default_remote, DiskPersistence) - - -def test_setup_cloud_prefix(): - with setup_execution("s3://", checkpoint_path=None, prev_checkpoint=None) as ctx: - assert isinstance(ctx.file_access._default_remote, S3Persistence) - - with setup_execution("gs://", checkpoint_path=None, prev_checkpoint=None) as ctx: - assert isinstance(ctx.file_access._default_remote, GCSPersistence) - - def test_normalize_inputs(): assert normalize_inputs("{{.rawOutputDataPrefix}}", "{{.checkpointOutputPrefix}}", "{{.prevCheckpointPrefix}}") == ( None, diff --git a/tests/flytekit/unit/cli/pyflyte/imageSpec.yaml b/tests/flytekit/unit/cli/pyflyte/imageSpec.yaml new file mode 100644 index 0000000000..ba67ab4b91 --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/imageSpec.yaml @@ -0,0 +1,2 @@ +python_version: 3.8 +builder: test diff --git a/tests/flytekit/unit/cli/pyflyte/image_spec_wf.py b/tests/flytekit/unit/cli/pyflyte/image_spec_wf.py new file mode 100644 index 0000000000..9d5d74ff1b --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/image_spec_wf.py @@ -0,0 +1,20 @@ +from flytekit import task, workflow +from flytekit.image_spec import ImageSpec + +image_spec = ImageSpec(packages=["numpy", "pandas"], apt_packages=["git"], registry="", builder="test") + + +@task(container_image=image_spec) +def t2() -> str: + return "flyte" + + +@task(container_image=image_spec) +def t1() -> str: + return "flyte" + + +@workflow +def wf(): + t1() + t2() diff --git a/tests/flytekit/unit/cli/pyflyte/test_backfill.py b/tests/flytekit/unit/cli/pyflyte/test_backfill.py index 8389295af2..0fd328e638 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_backfill.py +++ b/tests/flytekit/unit/cli/pyflyte/test_backfill.py @@ -39,7 +39,6 @@ def test_pyflyte_backfill(mock_remote): 
"--backfill-window", "5 day", "daily", - "--dry-run", ], ) assert result.exit_code == 0 diff --git a/tests/flytekit/unit/cli/pyflyte/test_build.py b/tests/flytekit/unit/cli/pyflyte/test_build.py new file mode 100644 index 0000000000..7b4b26fb69 --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/test_build.py @@ -0,0 +1,31 @@ +import os + +from click.testing import CliRunner + +from flytekit.clis.sdk_in_container import pyflyte +from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpecBuilder + +WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_spec_wf.py") + + +def test_build(): + class TestImageSpecBuilder(ImageSpecBuilder): + def build_image(self, img): + ... + + ImageBuildEngine.register("test", TestImageSpecBuilder()) + runner = CliRunner() + result = runner.invoke(pyflyte.main, ["build", "--fast", WORKFLOW_FILE, "wf"]) + assert result.exit_code == 0 + + result = runner.invoke(pyflyte.main, ["build", WORKFLOW_FILE, "wf"]) + assert result.exit_code == 0 + + result = runner.invoke(pyflyte.main, ["build", WORKFLOW_FILE, "wf"]) + assert result.exit_code == 0 + + result = runner.invoke(pyflyte.main, ["build", "--help"]) + assert result.exit_code == 0 + + result = runner.invoke(pyflyte.main, ["build", "../", "wf"]) + assert result.exit_code == 1 diff --git a/tests/flytekit/unit/cli/pyflyte/test_launchplan.py b/tests/flytekit/unit/cli/pyflyte/test_launchplan.py new file mode 100644 index 0000000000..1a461bfd35 --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/test_launchplan.py @@ -0,0 +1,34 @@ +import pytest +from click.testing import CliRunner +from mock import mock + +from flytekit.clis.sdk_in_container import pyflyte +from flytekit.remote import FlyteRemote + + +@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote) +@pytest.mark.parametrize( + ("action", "expected_state"), + [ + ("activate", "ACTIVE"), + ("deactivate", "INACTIVE"), + ], +) +def test_pyflyte_launchplan(mock_remote, action, 
expected_state): + mock_remote.generate_console_url.return_value = "ex" + runner = CliRunner() + with runner.isolated_filesystem(): + result = runner.invoke( + pyflyte.main, + [ + "launchplan", + f"--{action}", + "-p", + "flytesnacks", + "-d", + "development", + "daily", + ], + ) + assert result.exit_code == 0 + assert f"Launchplan was set to {expected_state}: " in result.output diff --git a/tests/flytekit/unit/cli/pyflyte/test_package.py b/tests/flytekit/unit/cli/pyflyte/test_package.py index e3ccb1d803..4d8251fc57 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_package.py +++ b/tests/flytekit/unit/cli/pyflyte/test_package.py @@ -1,7 +1,6 @@ import os import shutil -import pytest from click.testing import CliRunner import flytekit @@ -10,7 +9,6 @@ from flytekit import TaskMetadata from flytekit.clis.sdk_in_container import pyflyte from flytekit.core import context_manager -from flytekit.exceptions.user import FlyteValidationException from flytekit.models.admin.workflow import WorkflowSpec from flytekit.models.core.identifier import Identifier, ResourceType from flytekit.models.launch_plan import LaunchPlan @@ -104,56 +102,6 @@ def test_package_with_fast_registration(): shutil.rmtree("core") -def test_duplicate_registrable_entities(): - @flytekit.task - def t_1(): - pass - - # Keep a reference to a task named `t_1` that's going to be duplicated below - reference_1 = t_1 - - @flytekit.workflow - def wf_1(): - return t_1() - - # Duplicate definition of `t_1` - @flytekit.task - def t_1() -> str: - pass - - # Keep a second reference to the duplicate task named `t_1` so that we can use it later - reference_2 = t_1 - - @flytekit.task - def non_duplicate_task(): - pass - - @flytekit.workflow - def wf_2(): - non_duplicate_task() - # refers to the second definition of `t_1` - return t_1() - - ctx = context_manager.FlyteContextManager.current_context().with_serialization_settings( - flytekit.configuration.SerializationSettings( - project="p", - domain="d", - version="v", - 
image_config=flytekit.configuration.ImageConfig( - default_image=flytekit.configuration.Image("def", "docker.io/def", "latest") - ), - ) - ) - - context_manager.FlyteEntities.entities = [reference_1, wf_1, "str", reference_2, non_duplicate_task, wf_2, "str"] - - with pytest.raises( - FlyteValidationException, - match=r"Multiple definitions of the following tasks were found: \['pyflyte.test_package.t_1'\]", - ): - flytekit.tools.serialize_helpers.get_registrable_entities(ctx) - - def test_package(): runner = CliRunner() with runner.isolated_filesystem(): diff --git a/tests/flytekit/unit/cli/pyflyte/test_register.py b/tests/flytekit/unit/cli/pyflyte/test_register.py index a6c0bb91d8..0a371b76d1 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_register.py +++ b/tests/flytekit/unit/cli/pyflyte/test_register.py @@ -92,7 +92,7 @@ def test_non_fast_register(mock_client, mock_remote): def test_non_fast_register_require_version(mock_client, mock_remote): mock_remote._client = mock_client mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash" - mock_remote.return_value._upload_file.return_value = "dummy_md5_bytes", "dummy_native_url" + mock_remote.return_value.upload_file.return_value = "dummy_md5_bytes", "dummy_native_url" runner = CliRunner() context_manager.FlyteEntities.entities.clear() with runner.isolated_filesystem(): diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index e963f3dfc6..735df4af2c 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -1,6 +1,8 @@ import functools +import json import os import pathlib +import tempfile import typing from datetime import datetime, timedelta from enum import Enum @@ -8,6 +10,7 @@ import click import mock import pytest +import yaml from click.testing import CliRunner from flytekit import FlyteContextManager @@ -21,12 +24,14 @@ DurationParamType, FileParamType, FlyteLiteralConverter, + 
JsonParamType, get_entities_in_file, run_command, ) from flytekit.configuration import Config, Image, ImageConfig from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine +from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpecBuilder from flytekit.models.types import SimpleType from flytekit.remote import FlyteRemote @@ -62,8 +67,19 @@ def test_imperative_wf(): assert result.exit_code == 0 +def test_copy_all_files(): + runner = CliRunner() + result = runner.invoke( + pyflyte.main, + ["run", "--copy-all", IMPERATIVE_WORKFLOW_FILE, "wf", "--in1", "hello", "--in2", "world"], + catch_exceptions=False, + ) + assert result.exit_code == 0 + + def test_pyflyte_run_cli(): runner = CliRunner() + parquet_file = os.path.join(DIR_NAME, "testdata/df.parquet") result = runner.invoke( pyflyte.main, [ @@ -83,7 +99,7 @@ def test_pyflyte_run_cli(): "--f", '{"x":1.0, "y":2.0}', "--g", - os.path.join(DIR_NAME, "testdata/df.parquet"), + parquet_file, "--i", "2020-05-01", "--j", @@ -97,6 +113,10 @@ def test_pyflyte_run_cli(): "--image", os.path.join(DIR_NAME, "testdata"), "--h", + "--n", + json.dumps([{"x": parquet_file}]), + "--o", + json.dumps({"x": [parquet_file]}), ], catch_exceptions=False, ) @@ -148,19 +168,19 @@ def test_union_type2(input): def test_union_type_with_invalid_input(): runner = CliRunner() - with pytest.raises(ValueError, match="Failed to convert python type typing.Union"): - runner.invoke( - pyflyte.main, - [ - "--verbose", - "run", - os.path.join(DIR_NAME, "workflow.py"), - "test_union2", - "--a", - "hello", - ], - catch_exceptions=False, - ) + result = runner.invoke( + pyflyte.main, + [ + "--verbose", + "run", + os.path.join(DIR_NAME, "workflow.py"), + "test_union2", + "--a", + "hello", + ], + catch_exceptions=False, + ) + assert result.exit_code == 2 def test_get_entities_in_file(): @@ -216,6 +236,7 @@ def test_list_default_arguments(wf_path): ], catch_exceptions=False, ) + print(result.stdout) assert result.exit_code == 
0 @@ -239,6 +260,17 @@ def test_list_default_arguments(wf_path): images=[Image(name="xyz", fqn="ghcr.io/asdf/asdf", tag="latest"), Image(name="abc", fqn="docker.io/abc", tag=None)], ) +ic_result_4 = ImageConfig( + default_image=Image(name="default", fqn="flytekit", tag="4VC-c-UDrUvfySJ0aS3qCw.."), + images=[ + Image(name="default", fqn="flytekit", tag="4VC-c-UDrUvfySJ0aS3qCw.."), + Image(name="xyz", fqn="docker.io/xyz", tag="latest"), + Image(name="abc", fqn="docker.io/abc", tag=None), + ], +) + +IMAGE_SPEC = os.path.join(os.path.dirname(os.path.realpath(__file__)), "imageSpec.yaml") + @pytest.mark.parametrize( "image_string, leaf_configuration_file_name, final_image_config", @@ -246,9 +278,16 @@ def test_list_default_arguments(wf_path): ("ghcr.io/flyteorg/mydefault:py3.9-latest", "no_images.yaml", ic_result_1), ("asdf=ghcr.io/asdf/asdf:latest", "sample.yaml", ic_result_2), ("xyz=ghcr.io/asdf/asdf:latest", "sample.yaml", ic_result_3), + (IMAGE_SPEC, "sample.yaml", ic_result_4), ], ) def test_pyflyte_run_run(image_string, leaf_configuration_file_name, final_image_config): + class TestImageSpecBuilder(ImageSpecBuilder): + def build_image(self, img): + ... + + ImageBuildEngine.register("test", TestImageSpecBuilder()) + @task def a(): ... 
@@ -362,3 +401,34 @@ def test_datetime_type(): v = t.convert("now", None, None) assert v.day == now.day assert v.month == now.month + + +def test_json_type(): + t = JsonParamType() + assert t.convert(value='{"a": "b"}', param=None, ctx=None) == {"a": "b"} + + with pytest.raises(click.BadParameter): + t.convert(None, None, None) + + # test that it loads a json file + with tempfile.NamedTemporaryFile("w", delete=False) as f: + json.dump({"a": "b"}, f) + f.flush() + assert t.convert(value=f.name, param=None, ctx=None) == {"a": "b"} + + # test that if the file is not a valid json, it raises an error + with tempfile.NamedTemporaryFile("w", delete=False) as f: + f.write("asdf") + f.flush() + with pytest.raises(click.BadParameter): + t.convert(value=f.name, param="asdf", ctx=None) + + # test if the file does not exist + with pytest.raises(click.BadParameter): + t.convert(value="asdf", param=None, ctx=None) + + # test if the file is yaml and ends with .yaml it works correctly + with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f: + yaml.dump({"a": "b"}, f) + f.flush() + assert t.convert(value=f.name, param=None, ctx=None) == {"a": "b"} diff --git a/tests/flytekit/unit/cli/pyflyte/test_serve.py b/tests/flytekit/unit/cli/pyflyte/test_serve.py new file mode 100644 index 0000000000..f3ecbef547 --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/test_serve.py @@ -0,0 +1,9 @@ +from click.testing import CliRunner + +from flytekit.clis.sdk_in_container import pyflyte + + +def test_pyflyte_serve(): + runner = CliRunner() + result = runner.invoke(pyflyte.main, ["serve", "--port", "0", "--timeout", "1"], catch_exceptions=False) + assert result.exit_code == 0 diff --git a/tests/flytekit/unit/cli/pyflyte/workflow.py b/tests/flytekit/unit/cli/pyflyte/workflow.py index 85438eb00d..01621a6a01 100644 --- a/tests/flytekit/unit/cli/pyflyte/workflow.py +++ b/tests/flytekit/unit/cli/pyflyte/workflow.py @@ -56,8 +56,10 @@ def print_all( k: Color, l: dict, m: dict, + n: 
typing.List[typing.Dict[str, FlyteFile]], + o: typing.Dict[str, typing.List[FlyteFile]], ): - print(f"{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}, {j}, {k}, {l}, {m}") + print(f"{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}, {j}, {k}, {l}, {m}, {n}, {o}") @task @@ -84,6 +86,8 @@ def my_wf( j: datetime.timedelta, k: Color, l: dict, + n: typing.List[typing.Dict[str, FlyteFile]], + o: typing.Dict[str, typing.List[FlyteFile]], remote: pd.DataFrame, image: StructuredDataset, m: dict = {"hello": "world"}, @@ -91,5 +95,5 @@ def my_wf( x = get_subset_df(df=remote) # noqa: shown for demonstration; users should use the same types between tasks show_sd(in_sd=x) show_sd(in_sd=image) - print_all(a=a, b=b, c=c, d=d, e=e, f=f, g=g, h=h, i=i, j=j, k=k, l=l, m=m) + print_all(a=a, b=b, c=c, d=d, e=e, f=f, g=g, h=h, i=i, j=j, k=k, l=l, m=m, n=n, o=o) return x diff --git a/tests/flytekit/unit/clients/auth/test_authenticator.py b/tests/flytekit/unit/clients/auth/test_authenticator.py index 4c968cf0bd..fdbddb2ebe 100644 --- a/tests/flytekit/unit/clients/auth/test_authenticator.py +++ b/tests/flytekit/unit/clients/auth/test_authenticator.py @@ -8,10 +8,12 @@ ClientConfig, ClientCredentialsAuthenticator, CommandAuthenticator, + DeviceCodeAuthenticator, PKCEAuthenticator, StaticClientConfigStore, ) from flytekit.clients.auth.exceptions import AuthenticationError +from flytekit.clients.auth.token_client import DeviceCodeResponse ENDPOINT = "example.com" @@ -65,31 +67,69 @@ def test_command_authenticator(mock_subprocess: MagicMock): authn.refresh_credentials() -def test_get_basic_authorization_header(): - header = ClientCredentialsAuthenticator.get_basic_authorization_header("client_id", "abc") - assert header == "Basic Y2xpZW50X2lkOmFiYw==" - +@patch("flytekit.clients.auth.token_client.requests") +def test_client_creds_authenticator(mock_requests): + authn = ClientCredentialsAuthenticator( + ENDPOINT, client_id="client", client_secret="secret", cfg_store=static_cfg_store + ) 
-@patch("flytekit.clients.auth.authenticator.requests") -def test_get_token(mock_requests): response = MagicMock() response.status_code = 200 response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") mock_requests.post.return_value = response - access, expiration = ClientCredentialsAuthenticator.get_token("https://corp.idp.net", "abc123", ["my_scope"]) - assert access == "abc" - assert expiration == 60 + authn.refresh_credentials() + expected_scopes = static_cfg_store.get_client_config().scopes + assert authn._creds + assert authn._scopes == expected_scopes -@patch("flytekit.clients.auth.authenticator.requests") -def test_client_creds_authenticator(mock_requests): - authn = ClientCredentialsAuthenticator( - ENDPOINT, client_id="client", client_secret="secret", cfg_store=static_cfg_store +@patch("flytekit.clients.auth.authenticator.KeyringStore") +@patch("flytekit.clients.auth.token_client.get_device_code") +@patch("flytekit.clients.auth.token_client.poll_token_endpoint") +def test_device_flow_authenticator(poll_mock: MagicMock, device_mock: MagicMock, mock_keyring: MagicMock): + with pytest.raises(AuthenticationError): + DeviceCodeAuthenticator( + ENDPOINT, + static_cfg_store, + audience="x", + ) + + cfg_store = StaticClientConfigStore( + ClientConfig( + token_endpoint="token_endpoint", + authorization_endpoint="auth_endpoint", + redirect_uri="redirect_uri", + client_id="client", + device_authorization_endpoint="dev", + ) + ) + authn = DeviceCodeAuthenticator( + ENDPOINT, + cfg_store, + audience="x", ) + device_mock.return_value = DeviceCodeResponse("x", "y", "s", "m", 1000, 0) + poll_mock.return_value = ("access", 100) + authn.refresh_credentials() + assert authn._creds + + +@patch("flytekit.clients.auth.token_client.requests") +def test_client_creds_authenticator_with_custom_scopes(mock_requests): + expected_scopes = ["foo", "baz"] + authn = ClientCredentialsAuthenticator( + ENDPOINT, + client_id="client", + client_secret="secret", 
+ cfg_store=static_cfg_store, + scopes=expected_scopes, + ) response = MagicMock() response.status_code = 200 response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") mock_requests.post.return_value = response authn.refresh_credentials() + assert authn._creds + assert authn._scopes == expected_scopes diff --git a/tests/flytekit/unit/clients/auth/test_token_client.py b/tests/flytekit/unit/clients/auth/test_token_client.py new file mode 100644 index 0000000000..c22284cd38 --- /dev/null +++ b/tests/flytekit/unit/clients/auth/test_token_client.py @@ -0,0 +1,81 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest + +from flytekit.clients.auth.exceptions import AuthenticationError +from flytekit.clients.auth.token_client import ( + DeviceCodeResponse, + error_auth_pending, + get_basic_authorization_header, + get_device_code, + get_token, + poll_token_endpoint, +) + + +def test_get_basic_authorization_header(): + header = get_basic_authorization_header("client_id", "abc") + assert header == "Basic Y2xpZW50X2lkOmFiYw==" + + header = get_basic_authorization_header("client_id", "abc%%$?\\/\\/") + assert header == "Basic Y2xpZW50X2lkOmFiYyUyNSUyNSUyNCUzRiU1QyUyRiU1QyUyRg==" + + +@patch("flytekit.clients.auth.token_client.requests") +def test_get_token(mock_requests): + response = MagicMock() + response.status_code = 200 + response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") + mock_requests.post.return_value = response + access, expiration = get_token("https://corp.idp.net", client_id="abc123", scopes=["my_scope"]) + assert access == "abc" + assert expiration == 60 + + +@patch("flytekit.clients.auth.token_client.requests") +def test_get_device_code(mock_requests): + response = MagicMock() + response.ok = False + mock_requests.post.return_value = response + with pytest.raises(AuthenticationError): + get_device_code("test.com", "test") + + response.ok = True + response.json.return_value = 
{ + "device_code": "code", + "user_code": "BNDJJFXL", + "verification_uri": "url", + "verification_uri_complete": "url", + "expires_in": 600, + "interval": 5, + } + mock_requests.post.return_value = response + c = get_device_code("test.com", "test") + assert c + assert c.device_code == "code" + + +@patch("flytekit.clients.auth.token_client.requests") +def test_poll_token_endpoint(mock_requests): + response = MagicMock() + response.ok = False + response.json.return_value = {"error": error_auth_pending} + mock_requests.post.return_value = response + + r = DeviceCodeResponse( + device_code="x", user_code="y", verification_uri="v", verification_uri_complete="v1", expires_in=1, interval=1 + ) + with pytest.raises(AuthenticationError): + poll_token_endpoint(r, "test.com", "test") + + response = MagicMock() + response.ok = True + response.json.return_value = {"access_token": "abc", "expires_in": 60} + mock_requests.post.return_value = response + r = DeviceCodeResponse( + device_code="x", user_code="y", verification_uri="v", verification_uri_complete="v1", expires_in=1, interval=0 + ) + t, e = poll_token_endpoint(r, "test.com", "test") + assert t + assert e diff --git a/tests/flytekit/unit/clients/test_auth_helper.py b/tests/flytekit/unit/clients/test_auth_helper.py index 8f14de730e..3bd57918f4 100644 --- a/tests/flytekit/unit/clients/test_auth_helper.py +++ b/tests/flytekit/unit/clients/test_auth_helper.py @@ -9,6 +9,7 @@ ClientConfigStore, ClientCredentialsAuthenticator, CommandAuthenticator, + DeviceCodeAuthenticator, PKCEAuthenticator, ) from flytekit.clients.auth.exceptions import AuthenticationError @@ -31,6 +32,8 @@ OAUTH_AUTHORIZE = "https://your.domain.io/oauth2/authorize" +DEVICE_AUTH_ENDPOINT = "https://your.domain.io/..." 
+ def get_auth_service_mock() -> MagicMock: auth_stub_mock = MagicMock() @@ -66,13 +69,14 @@ def test_remote_client_config_store(mock_auth_service: MagicMock): assert ccfg.authorization_endpoint == OAUTH_AUTHORIZE -def get_client_config() -> ClientConfigStore: +def get_client_config(**kwargs) -> ClientConfigStore: cfg_store = MagicMock() cfg_store.get_client_config.return_value = ClientConfig( token_endpoint=TOKEN_ENDPOINT, authorization_endpoint=OAUTH_AUTHORIZE, redirect_uri=REDIRECT_URI, client_id=CLIENT_ID, + **kwargs ) return cfg_store @@ -135,6 +139,15 @@ def test_get_authenticator_cmd(): assert authn._cmd == ["echo"] +def test_get_authenticator_deviceflow(): + cfg = PlatformConfig(auth_mode=AuthType.DEVICEFLOW) + with pytest.raises(AuthenticationError): + get_authenticator(cfg, get_client_config()) + + authn = get_authenticator(cfg, get_client_config(device_authorization_endpoint=DEVICE_AUTH_ENDPOINT)) + assert isinstance(authn, DeviceCodeAuthenticator) + + def test_wrap_exceptions_channel(): ch = MagicMock() out_ch = wrap_exceptions_channel(PlatformConfig(), ch) diff --git a/tests/flytekit/unit/configuration/test_image_config.py b/tests/flytekit/unit/configuration/test_image_config.py index 84c767f8fb..c14832df3c 100644 --- a/tests/flytekit/unit/configuration/test_image_config.py +++ b/tests/flytekit/unit/configuration/test_image_config.py @@ -60,3 +60,7 @@ def test_image_create(): ic = ImageConfig.from_images("cr.flyte.org/im/g:latest") assert ic.default_image.fqn == "cr.flyte.org/im/g" + + +def test_get_version_suffix(): + assert DefaultImages.get_version_suffix() == "latest" diff --git a/tests/flytekit/unit/core/flyte_functools/test_decorators.py b/tests/flytekit/unit/core/flyte_functools/test_decorators.py index 3edd547c8a..20e55e9d3c 100644 --- a/tests/flytekit/unit/core/flyte_functools/test_decorators.py +++ b/tests/flytekit/unit/core/flyte_functools/test_decorators.py @@ -39,7 +39,7 @@ def test_wrapped_tasks_error(capfd): ) out = 
capfd.readouterr().out - assert out.replace("\r", "").strip().split("\n") == [ + assert out.replace("\r", "").strip().split("\n")[:5] == [ "before running my_task", "try running my_task", "error running my_task: my_task failed with input: 0", @@ -74,11 +74,11 @@ def test_unwrapped_task(): capture_output=True, ) error = completed_process.stderr - error_str = error.strip().split("\n")[-1] - assert ( - "TaskFunction cannot be a nested/inner or local function." - " It should be accessible at a module level for Flyte to execute it." in error_str - ) + error_str = "" + for line in error.strip().split("\n"): + if line.startswith("ValueError"): + error_str += line + assert error_str.startswith("ValueError: TaskFunction cannot be a nested/inner or local function.") @pytest.mark.parametrize("script", ["nested_function.py", "nested_wrapped_function.py"]) @@ -90,5 +90,8 @@ def test_nested_function(script): capture_output=True, ) error = completed_process.stderr - error_str = error.strip().split("\n")[-1] + error_str = "" + for line in error.strip().split("\n"): + if line.startswith("ValueError"): + error_str += line assert error_str.startswith("ValueError: TaskFunction cannot be a nested/inner or local function.") diff --git a/tests/flytekit/unit/core/image_spec/__init__.py b/tests/flytekit/unit/core/image_spec/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/core/image_spec/test_image_spec.py b/tests/flytekit/unit/core/image_spec/test_image_spec.py new file mode 100644 index 0000000000..be8ea61427 --- /dev/null +++ b/tests/flytekit/unit/core/image_spec/test_image_spec.py @@ -0,0 +1,50 @@ +import os + +import pytest + +from flytekit.core import context_manager +from flytekit.core.context_manager import ExecutionState +from flytekit.image_spec import ImageSpec +from flytekit.image_spec.image_spec import _F_IMG_ID, ImageBuildEngine, ImageSpecBuilder, calculate_hash_from_image_spec + + +def test_image_spec(): + image_spec = ImageSpec( 
+ packages=["pandas"], + apt_packages=["git"], + python_version="3.8", + registry="", + base_image="cr.flyte.org/flyteorg/flytekit:py3.8-latest", + ) + + assert image_spec.python_version == "3.8" + assert image_spec.base_image == "cr.flyte.org/flyteorg/flytekit:py3.8-latest" + assert image_spec.packages == ["pandas"] + assert image_spec.apt_packages == ["git"] + assert image_spec.registry == "" + assert image_spec.name == "flytekit" + assert image_spec.builder == "envd" + assert image_spec.source_root is None + assert image_spec.env is None + assert image_spec.is_container() is True + assert image_spec.image_name() == "flytekit:yZ8jICcDTLoDArmNHbWNwg.." + ctx = context_manager.FlyteContext.current_context() + with context_manager.FlyteContextManager.with_context( + ctx.with_execution_state(ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION)) + ): + os.environ[_F_IMG_ID] = "flytekit:123" + assert image_spec.is_container() is False + + class DummyImageSpecBuilder(ImageSpecBuilder): + def build_image(self, img): + ... + + ImageBuildEngine.register("dummy", DummyImageSpecBuilder()) + ImageBuildEngine._REGISTRY["dummy"].build_image(image_spec) + assert "dummy" in ImageBuildEngine._REGISTRY + assert calculate_hash_from_image_spec(image_spec) == "yZ8jICcDTLoDArmNHbWNwg.." 
+ assert image_spec.exist() is False + + with pytest.raises(Exception): + image_spec.builder = "flyte" + ImageBuildEngine.build(image_spec) diff --git a/tests/flytekit/unit/core/test_checkpoint.py b/tests/flytekit/unit/core/test_checkpoint.py index 2add1b9e7d..f412af8ca8 100644 --- a/tests/flytekit/unit/core/test_checkpoint.py +++ b/tests/flytekit/unit/core/test_checkpoint.py @@ -1,3 +1,4 @@ +import os from pathlib import Path import pytest @@ -36,11 +37,14 @@ def test_sync_checkpoint_save_file(tmpdir): def test_sync_checkpoint_save_filepath(tmpdir): - td_path = Path(tmpdir) - cp = SyncCheckpoint(checkpoint_dest=tmpdir) - dst_path = td_path.joinpath("test") + src_path = Path(os.path.join(tmpdir, "src")) + src_path.mkdir(parents=True, exist_ok=True) + chkpnt_path = Path(os.path.join(tmpdir, "dest")) + chkpnt_path.mkdir() + cp = SyncCheckpoint(checkpoint_dest=str(chkpnt_path)) + dst_path = chkpnt_path.joinpath("test") assert not dst_path.exists() - inp = td_path.joinpath("test") + inp = src_path.joinpath("test") with inp.open("wb") as f: f.write(b"blah") cp.save(inp) @@ -84,16 +88,6 @@ def test_sync_checkpoint_restore_default_path(tmpdir): assert cp.restore() == cp._prev_download_path -def test_sync_checkpoint_read_empty_dir(tmpdir): - td_path = Path(tmpdir) - dest = td_path.joinpath("dest") - dest.mkdir() - src = td_path.joinpath("src") - src.mkdir() - cp = SyncCheckpoint(checkpoint_dest=str(dest), checkpoint_src=str(src)) - assert cp.read() is None - - def test_sync_checkpoint_read_multiple_files(tmpdir): """ Read can only work with one file. 
diff --git a/tests/flytekit/unit/core/test_checkpointer.py b/tests/flytekit/unit/core/test_checkpointer.py index dda786545b..c8080d3b82 100644 --- a/tests/flytekit/unit/core/test_checkpointer.py +++ b/tests/flytekit/unit/core/test_checkpointer.py @@ -1,5 +1,4 @@ import typing -from pathlib import Path import py.path @@ -44,18 +43,6 @@ def test_sync_checkpoint_reader(tmpdir: py.path.local): assert outputs.listdir() == [expected_dst] -def test_sync_checkpoint_folder(tmpdir: py.path.local): - inputs, input_file, outputs = create_folder_write_file(tmpdir) - cp = SyncCheckpoint(checkpoint_dest=str(outputs)) - # Lets try to restore - should not work! - assert not cp.restore("/tmp") - # Now save - cp.save(Path(str(inputs))) - # Expect file in tmpdir - expected_dst = outputs.join(CHECKPOINT_FILE) - assert outputs.listdir() == [expected_dst] - - def test_sync_checkpoint_previous(tmpdir: py.path.local): inputs, input_file, outputs = create_folder_write_file(tmpdir) cp = SyncCheckpoint(checkpoint_dest=str(outputs), checkpoint_src=str(inputs)) diff --git a/tests/flytekit/unit/core/test_container_task.py b/tests/flytekit/unit/core/test_container_task.py new file mode 100644 index 0000000000..599061d403 --- /dev/null +++ b/tests/flytekit/unit/core/test_container_task.py @@ -0,0 +1,80 @@ +from kubernetes.client.models import ( + V1Affinity, + V1NodeAffinity, + V1NodeSelectorRequirement, + V1NodeSelectorTerm, + V1PodSpec, + V1PreferredSchedulingTerm, + V1Toleration, +) + +from flytekit import kwtypes +from flytekit.configuration import Image, ImageConfig, SerializationSettings +from flytekit.core.container_task import ContainerTask +from flytekit.core.pod_template import PodTemplate +from flytekit.tools.translator import get_serializable_task + + +def test_pod_template(): + ps = V1PodSpec( + containers=[], tolerations=[V1Toleration(effect="NoSchedule", key="nvidia.com/gpu", operator="Exists")] + ) + ps.runtime_class_name = "nvidia" + nsr = 
V1NodeSelectorRequirement(key="nvidia.com/gpu.memory", operator="Gt", values=["10000"]) + pref_sched = V1PreferredSchedulingTerm(preference=V1NodeSelectorTerm(match_expressions=[nsr]), weight=1) + ps.affinity = V1Affinity( + node_affinity=V1NodeAffinity(preferred_during_scheduling_ignored_during_execution=[pref_sched]) + ) + pt = PodTemplate(pod_spec=ps, labels={"somelabel": "foobar"}) + + image = "ghcr.io/flyteorg/rawcontainers-shell:v2" + cmd = [ + "./calculate-ellipse-area.sh", + "{{.inputs.a}}", + "{{.inputs.b}}", + "/var/outputs", + ] + ct = ContainerTask( + name="ellipse-area-metadata-shell", + input_data_dir="/var/inputs", + output_data_dir="/var/outputs", + inputs=kwtypes(a=float, b=float), + outputs=kwtypes(area=float, metadata=str), + image=image, + command=cmd, + pod_template=pt, + pod_template_name="my-base-template", + ) + + assert ct.metadata.pod_template_name == "my-base-template" + + default_image = Image(name="default", fqn="docker.io/xyz", tag="some-git-hash") + default_image_config = ImageConfig(default_image=default_image) + default_serialization_settings = SerializationSettings( + project="p", domain="d", version="v", image_config=default_image_config + ) + + container = ct.get_container(default_serialization_settings) + assert container is None + + k8s_pod = ct.get_k8s_pod(default_serialization_settings) + assert k8s_pod.metadata.labels == {"somelabel": "foobar"} + + primary_container = k8s_pod.pod_spec["containers"][0] + + assert primary_container["image"] == image + assert primary_container["command"] == cmd + + ################# + # Test Serialization + ################# + ts = get_serializable_task(default_serialization_settings, ct) + assert ts.template.metadata.pod_template_name == "my-base-template" + assert ts.template.container is None + assert ts.template.k8s_pod is not None + serialized_pod_spec = ts.template.k8s_pod.pod_spec + assert serialized_pod_spec["affinity"]["nodeAffinity"] is not None + assert 
serialized_pod_spec["tolerations"] == [ + {"effect": "NoSchedule", "key": "nvidia.com/gpu", "operator": "Exists"} + ] + assert serialized_pod_spec["runtimeClassName"] == "nvidia" diff --git a/tests/flytekit/unit/core/test_context_manager.py b/tests/flytekit/unit/core/test_context_manager.py index 6e68c9d4be..fe535b761c 100644 --- a/tests/flytekit/unit/core/test_context_manager.py +++ b/tests/flytekit/unit/core/test_context_manager.py @@ -115,18 +115,17 @@ def test_secrets_manager_default(): def test_secrets_manager_get_envvar(): sec = SecretsManager() - with pytest.raises(ValueError): - sec.get_secrets_env_var("test", "") with pytest.raises(ValueError): sec.get_secrets_env_var("", "x") cfg = SecretsConfig.auto() assert sec.get_secrets_env_var("group", "test") == f"{cfg.env_prefix}GROUP_TEST" + assert sec.get_secrets_env_var("group", "test", "v1") == f"{cfg.env_prefix}GROUP_V1_TEST" + assert sec.get_secrets_env_var("group", group_version="v1") == f"{cfg.env_prefix}GROUP_V1" + assert sec.get_secrets_env_var("group") == f"{cfg.env_prefix}GROUP" def test_secrets_manager_get_file(): sec = SecretsManager() - with pytest.raises(ValueError): - sec.get_secrets_file("test", "") with pytest.raises(ValueError): sec.get_secrets_file("", "x") cfg = SecretsConfig.auto() @@ -135,6 +134,12 @@ def test_secrets_manager_get_file(): "group", f"{cfg.file_prefix}test", ) + assert sec.get_secrets_file("group", "test", "v1") == os.path.join( + cfg.default_dir, + "group", + "v1", + f"{cfg.file_prefix}test", + ) def test_secrets_manager_file(tmpdir: py.path.local): @@ -145,8 +150,6 @@ def test_secrets_manager_file(tmpdir: py.path.local): with open(f, "w+") as w: w.write("my-password") - with pytest.raises(ValueError): - sec.get("test", "") with pytest.raises(ValueError): sec.get("", "x") # Group dir not exists @@ -207,7 +210,7 @@ def test_serialization_settings_transport(): ss = SerializationSettings.from_transport(tp) assert ss is not None assert ss == serialization_settings - assert 
len(tp) == 388 + assert len(tp) == 400 def test_exec_params(): diff --git a/tests/flytekit/unit/core/test_data.py b/tests/flytekit/unit/core/test_data.py new file mode 100644 index 0000000000..2d61b58d8c --- /dev/null +++ b/tests/flytekit/unit/core/test_data.py @@ -0,0 +1,330 @@ +import os +import random +import shutil +import tempfile +from uuid import UUID + +import fsspec +import mock +import pytest + +from flytekit.configuration import Config, S3Config +from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.data_persistence import FileAccessProvider, default_local_file_access_provider, s3_setup_args +from flytekit.types.directory.types import FlyteDirectory + +local = fsspec.filesystem("file") +root = os.path.abspath(os.sep) + + +@mock.patch("google.auth.compute_engine._metadata") # to prevent network calls +@mock.patch("flytekit.core.data_persistence.UUID") +def test_path_getting(mock_uuid_class, mock_gcs): + mock_uuid_class.return_value.hex = "abcdef123" + + # Testing with raw output prefix pointing to a local path + loc_sandbox = os.path.join(root, "tmp", "unittest") + loc_data = os.path.join(root, "tmp", "unittestdata") + local_raw_fp = FileAccessProvider(local_sandbox_dir=loc_sandbox, raw_output_prefix=loc_data) + assert local_raw_fp.get_random_remote_path() == os.path.join(root, "tmp", "unittestdata", "abcdef123") + assert local_raw_fp.get_random_remote_path("/fsa/blah.csv") == os.path.join( + root, "tmp", "unittestdata", "abcdef123", "blah.csv" + ) + assert local_raw_fp.get_random_remote_directory() == os.path.join(root, "tmp", "unittestdata", "abcdef123") + + # Test local path and directory + assert local_raw_fp.get_random_local_path() == os.path.join(root, "tmp", "unittest", "local_flytekit", "abcdef123") + assert local_raw_fp.get_random_local_path("xjiosa/blah.txt") == os.path.join( + root, "tmp", "unittest", "local_flytekit", "abcdef123", "blah.txt" + ) + assert local_raw_fp.get_random_local_directory() == os.path.join( + 
root, "tmp", "unittest", "local_flytekit", "abcdef123" + ) + + # Recursive paths + assert "file:///abc/happy/", "s3://my-s3-bucket/bucket1/" == local_raw_fp.recursive_paths( + "file:///abc/happy/", "s3://my-s3-bucket/bucket1/" + ) + assert "file:///abc/happy/", "s3://my-s3-bucket/bucket1/" == local_raw_fp.recursive_paths( + "file:///abc/happy", "s3://my-s3-bucket/bucket1" + ) + + # Test with remote pointed to s3. + s3_fa = FileAccessProvider(local_sandbox_dir=loc_sandbox, raw_output_prefix="s3://my-s3-bucket") + assert s3_fa.get_random_remote_path() == "s3://my-s3-bucket/abcdef123" + assert s3_fa.get_random_remote_directory() == "s3://my-s3-bucket/abcdef123" + # trailing slash should make no difference + s3_fa = FileAccessProvider(local_sandbox_dir=loc_sandbox, raw_output_prefix="s3://my-s3-bucket/") + assert s3_fa.get_random_remote_path() == "s3://my-s3-bucket/abcdef123" + assert s3_fa.get_random_remote_directory() == "s3://my-s3-bucket/abcdef123" + + # Testing with raw output prefix pointing to file:// + # Skip tests for windows + if os.name != "nt": + file_raw_fp = FileAccessProvider(local_sandbox_dir=loc_sandbox, raw_output_prefix="file:///tmp/unittestdata") + assert file_raw_fp.get_random_remote_path() == os.path.join(root, "tmp", "unittestdata", "abcdef123") + assert file_raw_fp.get_random_remote_path("/fsa/blah.csv") == os.path.join( + root, "tmp", "unittestdata", "abcdef123", "blah.csv" + ) + assert file_raw_fp.get_random_remote_directory() == os.path.join(root, "tmp", "unittestdata", "abcdef123") + + g_fa = FileAccessProvider(local_sandbox_dir=loc_sandbox, raw_output_prefix="gs://my-s3-bucket/") + assert g_fa.get_random_remote_path() == "gs://my-s3-bucket/abcdef123" + + +@mock.patch("flytekit.core.data_persistence.UUID") +def test_default_file_access_instance(mock_uuid_class): + mock_uuid_class.return_value.hex = "abcdef123" + + assert default_local_file_access_provider.get_random_local_path().endswith( + os.path.join("sandbox", "local_flytekit", 
"abcdef123") + ) + assert default_local_file_access_provider.get_random_local_path("bob.txt").endswith( + os.path.join("abcdef123", "bob.txt") + ) + + assert default_local_file_access_provider.get_random_local_directory().endswith( + os.path.join("sandbox", "local_flytekit", "abcdef123") + ) + + x = default_local_file_access_provider.get_random_remote_path() + assert x.endswith(os.path.join("raw", "abcdef123")) + x = default_local_file_access_provider.get_random_remote_path("eve.txt") + assert x.endswith(os.path.join("raw", "abcdef123", "eve.txt")) + x = default_local_file_access_provider.get_random_remote_directory() + assert x.endswith(os.path.join("raw", "abcdef123")) + + +@pytest.fixture +def source_folder(): + # Set up source directory for testing + parent_temp = tempfile.mkdtemp() + src_dir = os.path.join(parent_temp, "source", "") + nested_dir = os.path.join(src_dir, "nested") + local.mkdir(nested_dir) + local.touch(os.path.join(src_dir, "original.txt")) + with open(os.path.join(src_dir, "original.txt"), "w") as fh: + fh.write("hello original") + local.touch(os.path.join(nested_dir, "more.txt")) + yield src_dir + shutil.rmtree(parent_temp) + + +def test_local_raw_fsspec(source_folder): + # Test copying using raw fsspec local filesystem, should not create a nested folder + with tempfile.TemporaryDirectory() as dest_tmpdir: + local.put(source_folder, dest_tmpdir, recursive=True) + + new_temp_dir_2 = tempfile.mkdtemp() + new_temp_dir_2 = os.path.join(new_temp_dir_2, "doesnotexist") + local.put(source_folder, new_temp_dir_2, recursive=True) + files = local.find(new_temp_dir_2) + assert len(files) == 2 + + +def test_local_provider(source_folder): + # Test that behavior putting from a local dir to a local remote dir is the same whether or not the local + # dest folder exists. 
+ dc = Config.for_sandbox().data_config + with tempfile.TemporaryDirectory() as dest_tmpdir: + provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=dest_tmpdir, data_config=dc) + doesnotexist = provider.get_random_remote_directory() + provider.put_data(source_folder, doesnotexist, is_multipart=True) + files = provider._default_remote.find(doesnotexist) + assert len(files) == 2 + + exists = provider.get_random_remote_directory() + provider._default_remote.mkdir(exists) + provider.put_data(source_folder, exists, is_multipart=True) + files = provider._default_remote.find(exists) + assert len(files) == 2 + + +@pytest.mark.sandbox_test +def test_s3_provider(source_folder): + # Running mkdir on s3 filesystem doesn't do anything so leaving out for now + dc = Config.for_sandbox().data_config + provider = FileAccessProvider( + local_sandbox_dir="/tmp/unittest", raw_output_prefix="s3://my-s3-bucket/testdata/", data_config=dc + ) + doesnotexist = provider.get_random_remote_directory() + provider.put_data(source_folder, doesnotexist, is_multipart=True) + fs = provider.get_filesystem_for_path(doesnotexist) + files = fs.find(doesnotexist) + assert len(files) == 2 + + +def test_local_provider_get_empty(): + dc = Config.for_sandbox().data_config + with tempfile.TemporaryDirectory() as empty_source: + with tempfile.TemporaryDirectory() as dest_folder: + provider = FileAccessProvider( + local_sandbox_dir="/tmp/unittest", raw_output_prefix=empty_source, data_config=dc + ) + provider.get_data(empty_source, dest_folder, is_multipart=True) + loc = provider.get_filesystem_for_path(dest_folder) + src_files = loc.find(empty_source) + assert len(src_files) == 0 + dest_files = loc.find(dest_folder) + assert len(dest_files) == 0 + + +@mock.patch("flytekit.configuration.get_config_file") +@mock.patch("os.environ") +def test_s3_setup_args_env_empty(mock_os, mock_get_config_file): + mock_get_config_file.return_value = None + mock_os.get.return_value = None + s3c = 
S3Config.auto() + kwargs = s3_setup_args(s3c) + assert kwargs == {"cache_regions": True} + + +@mock.patch("flytekit.configuration.get_config_file") +@mock.patch("os.environ") +def test_s3_setup_args_env_both(mock_os, mock_get_config_file): + mock_get_config_file.return_value = None + ee = { + "AWS_ACCESS_KEY_ID": "ignore-user", + "AWS_SECRET_ACCESS_KEY": "ignore-secret", + "FLYTE_AWS_ACCESS_KEY_ID": "flyte", + "FLYTE_AWS_SECRET_ACCESS_KEY": "flyte-secret", + } + mock_os.get.side_effect = lambda x, y: ee.get(x) + kwargs = s3_setup_args(S3Config.auto()) + assert kwargs == {"key": "flyte", "secret": "flyte-secret", "cache_regions": True} + + +@mock.patch("flytekit.configuration.get_config_file") +@mock.patch("os.environ") +def test_s3_setup_args_env_flyte(mock_os, mock_get_config_file): + mock_get_config_file.return_value = None + ee = { + "FLYTE_AWS_ACCESS_KEY_ID": "flyte", + "FLYTE_AWS_SECRET_ACCESS_KEY": "flyte-secret", + } + mock_os.get.side_effect = lambda x, y: ee.get(x) + kwargs = s3_setup_args(S3Config.auto()) + assert kwargs == {"key": "flyte", "secret": "flyte-secret", "cache_regions": True} + + +@mock.patch("flytekit.configuration.get_config_file") +@mock.patch("os.environ") +def test_s3_setup_args_env_aws(mock_os, mock_get_config_file): + mock_get_config_file.return_value = None + ee = { + "AWS_ACCESS_KEY_ID": "ignore-user", + "AWS_SECRET_ACCESS_KEY": "ignore-secret", + } + mock_os.get.side_effect = lambda x, y: ee.get(x) + kwargs = s3_setup_args(S3Config.auto()) + # not explicitly in kwargs, since fsspec/boto3 will use these env vars by default + assert kwargs == {"cache_regions": True} + + +def test_crawl_local_nt(source_folder): + """ + running this to see what it prints + """ + if os.name != "nt": # don't + return + source_folder = os.path.join(source_folder, "") # ensure there's a trailing / or \ + fd = FlyteDirectory(path=source_folder) + res = fd.crawl() + split = [(x, y) for x, y in res] + print(f"NT split {split}") + + # Test crawling a directory 
without trailing / or \ + source_folder = source_folder[:-1] + fd = FlyteDirectory(path=source_folder) + res = fd.crawl() + files = [os.path.join(x, y) for x, y in res] + print(f"NT files joined {files}") + + +def test_crawl_local_non_nt(source_folder): + """ + crawl on the source folder fixture should return for example + ('/var/folders/jx/54tww2ls58n8qtlp9k31nbd80000gp/T/tmpp14arygf/source/', 'original.txt') + ('/var/folders/jx/54tww2ls58n8qtlp9k31nbd80000gp/T/tmpp14arygf/source/', 'nested/more.txt') + """ + if os.name == "nt": # don't + return + source_folder = os.path.join(source_folder, "") # ensure there's a trailing / or \ + fd = FlyteDirectory(path=source_folder) + res = fd.crawl() + split = [(x, y) for x, y in res] + files = [os.path.join(x, y) for x, y in split] + assert set(split) == {(source_folder, "original.txt"), (source_folder, os.path.join("nested", "more.txt"))} + expected = {os.path.join(source_folder, "original.txt"), os.path.join(source_folder, "nested", "more.txt")} + assert set(files) == expected + + # Test crawling a directory without trailing / or \ + source_folder = source_folder[:-1] + fd = FlyteDirectory(path=source_folder) + res = fd.crawl() + files = [os.path.join(x, y) for x, y in res] + assert set(files) == expected + + # Test crawling a single file + fd = FlyteDirectory(path=os.path.join(source_folder, "original.txt")) + res = fd.crawl() + files = [os.path.join(x, y) for x, y in res] + assert len(files) == 0 + + +@pytest.mark.sandbox_test +def test_crawl_s3(source_folder): + """ + ('s3://my-s3-bucket/testdata/5b31492c032893b515650f8c76008cf7', 'original.txt') + ('s3://my-s3-bucket/testdata/5b31492c032893b515650f8c76008cf7', 'nested/more.txt') + """ + # Running mkdir on s3 filesystem doesn't do anything so leaving out for now + dc = Config.for_sandbox().data_config + provider = FileAccessProvider( + local_sandbox_dir="/tmp/unittest", raw_output_prefix="s3://my-s3-bucket/testdata/", data_config=dc + ) + s3_random_target = 
provider.get_random_remote_directory() + provider.put_data(source_folder, s3_random_target, is_multipart=True) + ctx = FlyteContextManager.current_context() + expected = {f"{s3_random_target}/original.txt", f"{s3_random_target}/nested/more.txt"} + + with FlyteContextManager.with_context(ctx.with_file_access(provider)): + fd = FlyteDirectory(path=s3_random_target) + res = fd.crawl() + res = [(x, y) for x, y in res] + files = [os.path.join(x, y) for x, y in res] + assert set(files) == expected + assert set(res) == {(s3_random_target, "original.txt"), (s3_random_target, os.path.join("nested", "more.txt"))} + + fd_file = FlyteDirectory(path=f"{s3_random_target}/original.txt") + res = fd_file.crawl() + files = [r for r in res] + assert len(files) == 1 + + +@pytest.mark.sandbox_test +def test_walk_local_copy_to_s3(source_folder): + dc = Config.for_sandbox().data_config + explicit_empty_folder = UUID(int=random.getrandbits(128)).hex + raw_output_path = f"s3://my-s3-bucket/testdata/{explicit_empty_folder}" + provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output_path, data_config=dc) + + ctx = FlyteContextManager.current_context() + local_fd = FlyteDirectory(path=source_folder) + local_fd_crawl = local_fd.crawl() + local_fd_crawl = [x for x in local_fd_crawl] + with FlyteContextManager.with_context(ctx.with_file_access(provider)): + fd = FlyteDirectory.new_remote() + assert raw_output_path in fd.path + + # Write source folder files to new remote path + for root_path, suffix in local_fd_crawl: + new_file = fd.new_file(suffix) # noqa + with open(os.path.join(root_path, suffix), "rb") as r: # noqa + with new_file.open("w") as w: + print(f"Writing, t {type(w)} p {new_file.path} |{suffix}|") + w.write(str(r.read())) + + new_crawl = fd.crawl() + new_suffixes = [y for x, y in new_crawl] + assert len(new_suffixes) == 2 # should have written two files diff --git a/tests/flytekit/unit/core/test_data_persistence.py 
b/tests/flytekit/unit/core/test_data_persistence.py index af39e9e852..27b407c1ce 100644 --- a/tests/flytekit/unit/core/test_data_persistence.py +++ b/tests/flytekit/unit/core/test_data_persistence.py @@ -1,11 +1,11 @@ -from flytekit.core.data_persistence import DataPersistencePlugins, FileAccessProvider +from flytekit.core.data_persistence import FileAccessProvider def test_get_random_remote_path(): fp = FileAccessProvider("/tmp", "s3://my-bucket") path = fp.get_random_remote_path() assert path.startswith("s3://my-bucket") - assert fp.raw_output_prefix == "s3://my-bucket" + assert fp.raw_output_prefix == "s3://my-bucket/" def test_is_remote(): @@ -14,10 +14,3 @@ def test_is_remote(): assert fp.is_remote("/tmp/foo/bar") is False assert fp.is_remote("file://foo/bar") is False assert fp.is_remote("s3://my-bucket/foo/bar") is True - - -def test_lister(): - x = DataPersistencePlugins.supported_protocols() - main_protocols = {"file", "/", "gs", "http", "https", "s3"} - all_protocols = set([y.replace("://", "") for y in x]) - assert main_protocols.issubset(all_protocols) diff --git a/tests/flytekit/unit/core/test_flyte_directory.py b/tests/flytekit/unit/core/test_flyte_directory.py index 0cb4f524f9..bd20c39c53 100644 --- a/tests/flytekit/unit/core/test_flyte_directory.py +++ b/tests/flytekit/unit/core/test_flyte_directory.py @@ -49,7 +49,6 @@ def test_engine(): def test_transformer_to_literal_local(): - random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "raw")) ctx = context_manager.FlyteContext.current_context() @@ -86,6 +85,15 @@ def test_transformer_to_literal_local(): with pytest.raises(TypeError, match="No automatic conversion from "): TypeEngine.to_literal(ctx, 3, FlyteDirectory, lt) + +def test_transformer_to_literal_localss(): + random_dir = 
context_manager.FlyteContext.current_context().file_access.get_random_local_directory() + fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "raw")) + ctx = context_manager.FlyteContext.current_context() + with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)) as ctx: + + tf = FlyteDirToMultipartBlobTransformer() + lt = tf.get_literal_type(FlyteDirectory) # Can't use if it's not a directory with pytest.raises(FlyteAssertion): p = "/tmp/flyte/xyz" diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py index e2123222e0..b7f0a1aeee 100644 --- a/tests/flytekit/unit/core/test_flyte_file.py +++ b/tests/flytekit/unit/core/test_flyte_file.py @@ -5,13 +5,14 @@ from unittest.mock import MagicMock import pytest +from typing_extensions import Annotated import flytekit.configuration -from flytekit.configuration import Image, ImageConfig -from flytekit.core import context_manager -from flytekit.core.context_manager import ExecutionState +from flytekit.configuration import Config, Image, ImageConfig +from flytekit.core.context_manager import ExecutionState, FlyteContextManager from flytekit.core.data_persistence import FileAccessProvider, flyte_tmp_dir from flytekit.core.dynamic_workflow_task import dynamic +from flytekit.core.hash import HashMethod from flytekit.core.launch_plan import LaunchPlan from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine @@ -81,11 +82,10 @@ def t1() -> FlyteFile: def my_wf() -> FlyteFile: return t1() - random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() - # print(f"Random: {random_dir}") + random_dir = FlyteContextManager.current_context().file_access.get_random_local_directory() fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "mock_remote")) - ctx = context_manager.FlyteContext.current_context() - with 
context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)): + ctx = FlyteContextManager.current_context() + with FlyteContextManager.with_context(ctx.with_file_access(fs)): top_level_files = os.listdir(random_dir) assert len(top_level_files) == 1 # the flytekit_local folder @@ -108,10 +108,10 @@ def t1() -> FlyteFile: def my_wf() -> FlyteFile: return t1() - random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() + random_dir = FlyteContextManager.current_context().file_access.get_random_local_directory() fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "mock_remote")) - ctx = context_manager.FlyteContext.current_context() - with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)): + ctx = FlyteContextManager.current_context() + with FlyteContextManager.with_context(ctx.with_file_access(fs)): top_level_files = os.listdir(random_dir) assert len(top_level_files) == 1 # the flytekit_local folder @@ -137,12 +137,12 @@ def my_wf() -> FlyteFile: return t1() # This creates a random directory that we know is empty. 
- random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() + random_dir = FlyteContextManager.current_context().file_access.get_random_local_directory() # Creating a new FileAccessProvider will add two folderst to the random dir print(f"Random {random_dir}") fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "mock_remote")) - ctx = context_manager.FlyteContext.current_context() - with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)): + ctx = FlyteContextManager.current_context() + with FlyteContextManager.with_context(ctx.with_file_access(fs)): working_dir = os.listdir(random_dir) assert len(working_dir) == 1 # the local_flytekit folder @@ -189,11 +189,11 @@ def my_wf() -> FlyteFile: return t1() # This creates a random directory that we know is empty. - random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() + random_dir = FlyteContextManager.current_context().file_access.get_random_local_directory() # Creating a new FileAccessProvider will add two folderst to the random dir fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "mock_remote")) - ctx = context_manager.FlyteContext.current_context() - with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)): + ctx = FlyteContextManager.current_context() + with FlyteContextManager.with_context(ctx.with_file_access(fs)): working_dir = os.listdir(random_dir) assert len(working_dir) == 1 # the local_flytekit dir @@ -243,8 +243,8 @@ def dyn(in1: FlyteFile): fd = FlyteFile("s3://anything") - with context_manager.FlyteContextManager.with_context( - context_manager.FlyteContextManager.current_context().with_serialization_settings( + with FlyteContextManager.with_context( + FlyteContextManager.current_context().with_serialization_settings( flytekit.configuration.SerializationSettings( 
project="test_proj", domain="test_domain", @@ -254,8 +254,8 @@ def dyn(in1: FlyteFile): ) ) ): - ctx = context_manager.FlyteContextManager.current_context() - with context_manager.FlyteContextManager.with_context( + ctx = FlyteContextManager.current_context() + with FlyteContextManager.with_context( ctx.with_execution_state(ctx.new_execution_state().with_params(mode=ExecutionState.Mode.TASK_EXECUTION)) ) as ctx: lit = TypeEngine.to_literal( @@ -433,3 +433,59 @@ def wf(path: str) -> os.PathLike: return t2(ff=n1) assert flyte_tmp_dir in wf(path="s3://somewhere").path + + +def test_flyte_file_annotated_hashmethod(local_dummy_file): + def calc_hash(ff: FlyteFile) -> str: + return str(ff.path) + + @task + def t1(path: str) -> Annotated[FlyteFile, HashMethod(calc_hash)]: + return FlyteFile(path) + + @workflow + def wf(path: str) -> None: + t1(path=path) + + wf(path=local_dummy_file) + + +@pytest.mark.sandbox_test +def test_file_open_things(): + @task + def write_this_file_to_s3() -> FlyteFile: + ctx = FlyteContextManager.current_context() + dest = ctx.file_access.get_random_remote_path() + ctx.file_access.put(__file__, dest) + return FlyteFile(path=dest) + + @task + def copy_file(ff: FlyteFile) -> FlyteFile: + new_file = FlyteFile.new_remote_file(ff.remote_path) + with ff.open("r") as r: + with new_file.open("w") as w: + w.write(r.read()) + return new_file + + @task + def print_file(ff: FlyteFile): + with open(ff, "r") as fh: + print(len(fh.readlines())) + + dc = Config.for_sandbox().data_config + with tempfile.TemporaryDirectory() as new_sandbox: + provider = FileAccessProvider( + local_sandbox_dir=new_sandbox, raw_output_prefix="s3://my-s3-bucket/testdata/", data_config=dc + ) + ctx = FlyteContextManager.current_context() + local = ctx.file_access.get_filesystem("file") # get a local file system. 
+ with FlyteContextManager.with_context(ctx.with_file_access(provider)): + f = write_this_file_to_s3() + copy_file(ff=f) + files = local.find(new_sandbox) + # copy_file was done via streaming so no files should have been written + assert len(files) == 0 + print_file(ff=f) + # print_file uses traditional download semantics so now a file should have been created + files = local.find(new_sandbox) + assert len(files) == 1 diff --git a/tests/flytekit/unit/core/test_flyte_pickle.py b/tests/flytekit/unit/core/test_flyte_pickle.py index 7ceec809b1..c45e200f95 100644 --- a/tests/flytekit/unit/core/test_flyte_pickle.py +++ b/tests/flytekit/unit/core/test_flyte_pickle.py @@ -14,7 +14,7 @@ from flytekit.models.literals import BlobMetadata from flytekit.models.types import LiteralType from flytekit.tools.translator import get_serializable -from flytekit.types.pickle.pickle import FlytePickle, FlytePickleTransformer +from flytekit.types.pickle.pickle import BatchSize, FlytePickle, FlytePickleTransformer default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = flytekit.configuration.SerializationSettings( @@ -55,6 +55,11 @@ def test_get_literal_type(): ) +def test_batch_size(): + bs = BatchSize(5) + assert bs.val == 5 + + def test_nested(): class Foo(object): def __init__(self, number: int): diff --git a/tests/flytekit/unit/core/test_map_task.py b/tests/flytekit/unit/core/test_map_task.py index 95927873d0..d032aca2d1 100644 --- a/tests/flytekit/unit/core/test_map_task.py +++ b/tests/flytekit/unit/core/test_map_task.py @@ -1,3 +1,4 @@ +import functools import typing from collections import OrderedDict @@ -6,7 +7,7 @@ import flytekit.configuration from flytekit import LaunchPlan, map_task from flytekit.configuration import Image, ImageConfig -from flytekit.core.map_task import MapPythonTask +from flytekit.core.map_task import MapPythonTask, MapTaskResolver from flytekit.core.task import TaskMetadata, task from flytekit.core.workflow import workflow from 
flytekit.tools.translator import get_serializable @@ -36,6 +37,11 @@ def t2(a: int) -> str: return str(b) +@task(cache=True, cache_version="1") +def t3(a: int, b: str, c: float) -> str: + pass + + # This test is for documentation. def test_map_docs(): # test_map_task_start @@ -87,8 +93,12 @@ def test_serialization(serialization_settings): "--prev-checkpoint", "{{.prevCheckpointPrefix}}", "--resolver", - "flytekit.core.python_auto_container.default_task_resolver", + "MapTaskResolver", "--", + "vars", + "", + "resolver", + "flytekit.core.python_auto_container.default_task_resolver", "task-module", "tests.flytekit.unit.core.test_map_task", "task-name", @@ -177,15 +187,42 @@ def test_inputs_outputs_length(): def many_inputs(a: int, b: str, c: float) -> str: return f"{a} - {b} - {c}" - with pytest.raises(ValueError): - _ = map_task(many_inputs) + m = map_task(many_inputs) + assert m.python_interface.inputs == {"a": typing.List[int], "b": typing.List[str], "c": typing.List[float]} + assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_24c08b3a2f9c2e389ad9fc6a03482cf9" + r_m = MapPythonTask(many_inputs) + assert str(r_m.python_interface) == str(m.python_interface) + + p1 = functools.partial(many_inputs, c=1.0) + m = map_task(p1) + assert m.python_interface.inputs == {"a": typing.List[int], "b": typing.List[str], "c": float} + assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_697aa7389996041183cf6cfd102be4f7" + r_m = MapPythonTask(many_inputs, bound_inputs=set("c")) + assert str(r_m.python_interface) == str(m.python_interface) + + p2 = functools.partial(p1, b="hello") + m = map_task(p2) + assert m.python_interface.inputs == {"a": typing.List[int], "b": str, "c": float} + assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_cc18607da7494024a402a5fa4b3ea5c6" + r_m = MapPythonTask(many_inputs, bound_inputs={"c", "b"}) + assert str(r_m.python_interface) == str(m.python_interface) + + p3 = functools.partial(p2, a=1) 
+ m = map_task(p3) + assert m.python_interface.inputs == {"a": int, "b": str, "c": float} + assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_52fe80b04781ea77ef6f025f4b49abef" + r_m = MapPythonTask(many_inputs, bound_inputs={"a", "c", "b"}) + assert str(r_m.python_interface) == str(m.python_interface) + + with pytest.raises(TypeError): + m(a=[1, 2, 3]) @task def many_outputs(a: int) -> (int, str): return a, f"{a}" with pytest.raises(ValueError): - _ = map_task(many_inputs) + _ = map_task(many_outputs) def test_map_task_metadata(): @@ -194,3 +231,34 @@ def test_map_task_metadata(): assert mapped_1.metadata is map_meta mapped_2 = map_task(t2) assert mapped_2.metadata is t2.metadata + + +def test_map_task_resolver(serialization_settings): + list_outputs = {"o0": typing.List[str]} + mt = map_task(t3) + assert mt.python_interface.inputs == {"a": typing.List[int], "b": typing.List[str], "c": typing.List[float]} + assert mt.python_interface.outputs == list_outputs + mtr = MapTaskResolver() + assert mtr.name() == "MapTaskResolver" + args = mtr.loader_args(serialization_settings, mt) + t = mtr.load_task(loader_args=args) + assert t.python_interface.inputs == mt.python_interface.inputs + assert t.python_interface.outputs == mt.python_interface.outputs + + mt = map_task(functools.partial(t3, b="hello", c=1.0)) + assert mt.python_interface.inputs == {"a": typing.List[int], "b": str, "c": float} + assert mt.python_interface.outputs == list_outputs + mtr = MapTaskResolver() + args = mtr.loader_args(serialization_settings, mt) + t = mtr.load_task(loader_args=args) + assert t.python_interface.inputs == mt.python_interface.inputs + assert t.python_interface.outputs == mt.python_interface.outputs + + mt = map_task(functools.partial(t3, b="hello")) + assert mt.python_interface.inputs == {"a": typing.List[int], "b": str, "c": typing.List[float]} + assert mt.python_interface.outputs == list_outputs + mtr = MapTaskResolver() + args = 
mtr.loader_args(serialization_settings, mt) + t = mtr.load_task(loader_args=args) + assert t.python_interface.inputs == mt.python_interface.inputs + assert t.python_interface.outputs == mt.python_interface.outputs diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index 48d3020e88..0fb9d6677c 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -419,7 +419,7 @@ def t1(a: str) -> str: @workflow def my_wf(a: str) -> str: - return t1(a=a).with_overrides(name="foo") + return t1(a=a).with_overrides(name="foo", node_name="t_1") serialization_settings = flytekit.configuration.SerializationSettings( project="test_proj", @@ -431,6 +431,7 @@ def my_wf(a: str) -> str: wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf) assert len(wf_spec.template.nodes) == 1 assert wf_spec.template.nodes[0].metadata.name == "foo" + assert wf_spec.template.nodes[0].id == "t-1" def test_config_override(): diff --git a/tests/flytekit/unit/core/test_partials.py b/tests/flytekit/unit/core/test_partials.py new file mode 100644 index 0000000000..24e3908d1d --- /dev/null +++ b/tests/flytekit/unit/core/test_partials.py @@ -0,0 +1,219 @@ +import typing +from collections import OrderedDict +from functools import partial + +import pandas as pd +import pytest + +import flytekit.configuration +from flytekit.configuration import Image, ImageConfig +from flytekit.core.dynamic_workflow_task import dynamic +from flytekit.core.map_task import MapTaskResolver, map_task +from flytekit.core.task import TaskMetadata, task +from flytekit.core.workflow import workflow +from flytekit.tools.translator import gather_dependent_entities, get_serializable + +default_img = Image(name="default", fqn="test", tag="tag") +serialization_settings = flytekit.configuration.SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + 
image_config=ImageConfig(default_image=default_img, images=[default_img]), +) + + +df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) + + +def test_basics_1(): + @task + def t1(a: int, b: str, c: float) -> int: + return a + len(b) + int(c) + + outside_p = partial(t1, b="hello", c=3.14) + + @workflow + def my_wf_1(a: int) -> typing.Tuple[int, int]: + inner_partial = partial(t1, b="world", c=2.7) + out = outside_p(a=a) + inside = inner_partial(a=a) + return out, inside + + with pytest.raises(Exception): + get_serializable(OrderedDict(), serialization_settings, outside_p) + + # check the od todo + od = OrderedDict() + wf_1_spec = get_serializable(od, serialization_settings, my_wf_1) + tts, wspecs, lps = gather_dependent_entities(od) + tts = [t for t in tts.values()] + assert len(tts) == 1 + assert len(wf_1_spec.template.nodes) == 2 + assert wf_1_spec.template.nodes[0].task_node.reference_id.name == tts[0].id.name + assert wf_1_spec.template.nodes[1].task_node.reference_id.name == tts[0].id.name + assert wf_1_spec.template.nodes[0].inputs[0].binding.promise.var == "a" + assert wf_1_spec.template.nodes[0].inputs[1].binding.scalar is not None + assert wf_1_spec.template.nodes[0].inputs[2].binding.scalar is not None + + @task + def get_str() -> str: + return "got str" + + bind_c = partial(t1, c=2.7) + + @workflow + def my_wf_2(a: int) -> int: + s = get_str() + inner_partial = partial(bind_c, b=s) + inside = inner_partial(a=a) + return inside + + wf_2_spec = get_serializable(OrderedDict(), serialization_settings, my_wf_2) + assert len(wf_2_spec.template.nodes) == 2 + + +def test_map_task_types(): + @task(cache=True, cache_version="1") + def t3(a: int, b: str, c: float) -> str: + return str(a) + b + str(c) + + t3_bind_b1 = partial(t3, b="hello") + t3_bind_b2 = partial(t3, b="world") + t3_bind_c1 = partial(t3_bind_b1, c=3.14) + t3_bind_c2 = partial(t3_bind_b2, c=2.78) + + mt1 = map_task(t3_bind_c1, metadata=TaskMetadata(cache=True, cache_version="1")) + mt2 = 
map_task(t3_bind_c2, metadata=TaskMetadata(cache=True, cache_version="1")) + + @task + def print_lists(i: typing.List[str], j: typing.List[str]): + print(f"First: {i}") + print(f"Second: {j}") + + @workflow + def wf_out(a: typing.List[int]): + i = mt1(a=a) + j = mt2(a=[3, 4, 5]) + print_lists(i=i, j=j) + + wf_out(a=[1, 2]) + + @workflow + def wf_in(a: typing.List[int]): + mt_in1 = map_task(t3_bind_c1, metadata=TaskMetadata(cache=True, cache_version="1")) + mt_in2 = map_task(t3_bind_c2, metadata=TaskMetadata(cache=True, cache_version="1")) + i = mt_in1(a=a) + j = mt_in2(a=[3, 4, 5]) + print_lists(i=i, j=j) + + wf_in(a=[1, 2]) + + od = OrderedDict() + wf_spec = get_serializable(od, serialization_settings, wf_in) + tts, _, _ = gather_dependent_entities(od) + assert len(tts) == 2 # one map task + the print task + assert ( + wf_spec.template.nodes[0].task_node.reference_id.name == wf_spec.template.nodes[1].task_node.reference_id.name + ) + assert wf_spec.template.nodes[0].inputs[0].binding.promise is not None # comes from wf input + assert wf_spec.template.nodes[1].inputs[0].binding.collection is not None # bound to static list + assert wf_spec.template.nodes[1].inputs[1].binding.scalar is not None # these are bound + assert wf_spec.template.nodes[1].inputs[2].binding.scalar is not None + + +def test_lists_cannot_be_used_in_partials(): + @task + def t(a: int, b: typing.List[str]) -> str: + return str(a) + str(b) + + with pytest.raises(ValueError): + map_task(partial(t, b=["hello", "world"]))(a=[1, 2, 3]) + + @task + def t_multilist(a: int, b: typing.List[float], c: typing.List[int]) -> str: + return str(a) + str(b) + str(c) + + with pytest.raises(ValueError): + map_task(partial(t_multilist, b=[3.14, 12.34, 9876.5432], c=[42, 99]))(a=[1, 2, 3, 4]) + + @task + def t_list_of_lists(a: typing.List[typing.List[float]], b: int) -> str: + return str(a) + str(b) + + with pytest.raises(ValueError): + map_task(partial(t_list_of_lists, a=[[3.14]]))(b=[1, 2, 3, 4]) + + +def 
test_everything(): + @task + def get_static_list() -> typing.List[float]: + return [3.14, 2.718] + + @task + def get_list_of_pd(s: int) -> typing.List[pd.DataFrame]: + df1 = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) + df2 = pd.DataFrame({"Name": ["Rachel", "Eve", "Mary"], "Age": [22, 23, 24]}) + if s == 2: + return [df1, df2] + else: + return [df1, df2, df1] + + @task + def t3(a: int, b: str, c: typing.List[float], d: typing.List[float], a2: pd.DataFrame) -> str: + return str(a) + f"pdsize{len(a2)}" + b + str(c) + "&&" + str(d) + + t3_bind_b2 = partial(t3, b="world") + # TODO: partial lists are not supported yet. + # t3_bind_b1 = partial(t3, b="hello") + # t3_bind_c1 = partial(t3_bind_b1, c=[6.674, 1.618, 6.626], d=[1.0]) + # mt1 = map_task(t3_bind_c1) + + mt1 = map_task(t3_bind_b2) + + mr = MapTaskResolver() + aa = mr.loader_args(serialization_settings, mt1) + # Check bound vars + aa = aa[1].split(",") + aa.sort() + assert aa == ["b"] + + @task + def print_lists(i: typing.List[str], j: typing.List[str], k: typing.List[str]) -> str: + print(f"First: {i}") + print(f"Second: {j}") + print(f"Third: {k}") + return f"{i}-{j}-{k}" + + @dynamic + def dt1(a: typing.List[int], a2: typing.List[pd.DataFrame], sl: typing.List[float]) -> str: + i = mt1(a=a, a2=a2, c=[[1.1, 2.0, 3.0], [1.1, 2.0, 3.0]], d=[sl, sl]) + mt_in2 = map_task(t3_bind_b2) + dfs = get_list_of_pd(s=3) + j = mt_in2(a=[3, 4, 5], a2=dfs, c=[[1.0], [2.0], [3.0]], d=[sl, sl, sl]) + + # Test a2 bound to a fixed dataframe + t3_bind_a2 = partial(t3_bind_b2, a2=a2[0]) + + mt_in3 = map_task(t3_bind_a2) + + aa = mr.loader_args(serialization_settings, mt_in3) + # Check bound vars + aa = aa[1].split(",") + aa.sort() + assert aa == ["a2", "b"] + + k = mt_in3(a=[3, 4, 5], c=[[1.0], [2.0], [3.0]], d=[sl, sl, sl]) + return print_lists(i=i, j=j, k=k) + + @workflow + def wf_dt(a: typing.List[int]) -> str: + sl = get_static_list() + dfs = get_list_of_pd(s=2) + return dt1(a=a, a2=dfs, sl=sl) + + 
print(wf_dt(a=[1, 2])) + assert ( + wf_dt(a=[1, 2]) + == "['1pdsize2world[1.1, 2.0, 3.0]&&[3.14, 2.718]', '2pdsize3world[1.1, 2.0, 3.0]&&[3.14, 2.718]']-['3pdsize2world[1.0]&&[3.14, 2.718]', '4pdsize3world[2.0]&&[3.14, 2.718]', '5pdsize2world[3.0]&&[3.14, 2.718]']-['3pdsize2world[1.0]&&[3.14, 2.718]', '4pdsize2world[2.0]&&[3.14, 2.718]', '5pdsize2world[3.0]&&[3.14, 2.718]']" + ) diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index d8b043116e..9478cc33ba 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -3,6 +3,7 @@ import pytest from dataclasses_json import dataclass_json +from typing_extensions import Annotated from flytekit import LaunchPlan, task, workflow from flytekit.core import context_manager @@ -14,6 +15,8 @@ translate_inputs_to_literals, ) from flytekit.exceptions.user import FlyteAssertion +from flytekit.types.pickle import FlytePickle +from flytekit.types.pickle.pickle import BatchSize def test_create_and_link_node(): @@ -74,7 +77,7 @@ def wf(i: int, j: int): # without providing the _inputs_not_allowed or _ignorable_inputs, all inputs to lp become required, # which is incorrect - with pytest.raises(FlyteAssertion, match="Missing input `i` type `simple: INTEGER"): + with pytest.raises(FlyteAssertion, match="Missing input `i` type ``"): create_and_link_node_from_remote(ctx, lp) # Even if j is not provided it will default @@ -92,7 +95,7 @@ def wf(i: int, j: int): @pytest.mark.parametrize( "input", - [2.0, {"i": 1, "a": ["h", "e"]}, [1, 2, 3]], + [2.0, {"i": 1, "a": ["h", "e"]}, [1, 2, 3], ["foo"] * 5], ) def test_translate_inputs_to_literals(input): @dataclass_json @@ -102,7 +105,7 @@ class MyDataclass(object): a: typing.List[str] @task - def t1(a: typing.Union[float, typing.List[int], MyDataclass]): + def t1(a: typing.Union[float, typing.List[int], MyDataclass, Annotated[typing.List[FlytePickle], BatchSize(2)]]): print(a) ctx = 
context_manager.FlyteContext.current_context() @@ -111,7 +114,7 @@ def t1(a: typing.Union[float, typing.List[int], MyDataclass]): def test_translate_inputs_to_literals_with_wrong_types(): ctx = context_manager.FlyteContext.current_context() - with pytest.raises(TypeError, match="Not a map type union_type"): + with pytest.raises(TypeError, match="Not a map type pd.DataFrame: @@ -74,7 +75,6 @@ def t1(a: pd.DataFrame) -> pd.DataFrame: def test_setting_of_unset_formats(): - custom = Annotated[StructuredDataset, "parquet"] example = custom(dataframe=df, uri="/path") # It's okay that the annotation is not used here yet. @@ -89,7 +89,9 @@ def t2(path: str) -> StructuredDataset: def wf(path: str) -> StructuredDataset: return t2(path=path) - res = wf(path="/tmp/somewhere") + with tempfile.TemporaryDirectory() as tmp_dir: + fname = os.path.join(tmp_dir, "somewhere") + res = wf(path=fname) # Now that it's passed through an encoder however, it should be set. assert res.file_format == "parquet" @@ -281,7 +283,10 @@ def encode( # Check that registering with a / triggers the file protocol instead. 
StructuredDatasetTransformerEngine.register(TempEncoder("/")) - assert StructuredDatasetTransformerEngine.ENCODERS[MyDF].get("file") is not None + res = StructuredDatasetTransformerEngine.get_encoder(MyDF, "file", "/") + # Test that the one we got was registered under fsspec + assert res is StructuredDatasetTransformerEngine.ENCODERS[MyDF].get("fsspec")["/"] + assert res is not None def test_sd(): diff --git a/tests/flytekit/unit/core/test_structured_dataset_handlers.py b/tests/flytekit/unit/core/test_structured_dataset_handlers.py index c7aa5563f9..cef124ffd0 100644 --- a/tests/flytekit/unit/core/test_structured_dataset_handlers.py +++ b/tests/flytekit/unit/core/test_structured_dataset_handlers.py @@ -50,5 +50,5 @@ def test_arrow(): assert encoder.protocol is None assert decoder.protocol is None assert encoder.python_type is decoder.python_type - d = StructuredDatasetTransformerEngine.DECODERS[encoder.python_type]["s3"]["parquet"] + d = StructuredDatasetTransformerEngine.DECODERS[encoder.python_type]["fsspec"]["parquet"] assert d is not None diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index bd270fd360..a0aafc3c13 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -1,10 +1,12 @@ import datetime +import json import os import tempfile import typing from dataclasses import asdict, dataclass from datetime import timedelta from enum import Enum +from typing import Optional, Type import mock import pandas as pd @@ -18,7 +20,7 @@ from marshmallow_enum import LoadDumpOptions from marshmallow_jsonschema import JSONSchema from pandas._testing import assert_frame_equal -from typing_extensions import Annotated +from typing_extensions import Annotated, get_args, get_origin from flytekit import kwtypes from flytekit.core.annotation import FlyteAnnotation @@ -51,7 +53,7 @@ from flytekit.types.file import FileExt, JPEGImageFile from flytekit.types.file.file 
import FlyteFile, FlyteFilePathTransformer, noop from flytekit.types.pickle import FlytePickle -from flytekit.types.pickle.pickle import FlytePickleTransformer +from flytekit.types.pickle.pickle import BatchSize, FlytePickleTransformer from flytekit.types.schema import FlyteSchema from flytekit.types.schema.types_pandas import PandasDataFrameTransformer from flytekit.types.structured.structured_dataset import StructuredDataset @@ -170,6 +172,51 @@ class Foo(object): assert pv[0].b == Bar(v=[1, 2, 99], w=[3.1415, 2.7182]) +def test_annotated_type(): + class JsonTypeTransformer(TypeTransformer[T]): + LiteralType = LiteralType( + simple=SimpleType.STRING, annotation=TypeAnnotation(annotations=dict(protocol="json")) + ) + + def get_literal_type(self, t: Type[T]) -> LiteralType: + return self.LiteralType + + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> Optional[T]: + return json.loads(lv.scalar.primitive.string_value) + + def to_literal( + self, ctx: FlyteContext, python_val: T, python_type: typing.Type[T], expected: LiteralType + ) -> Literal: + return Literal(scalar=Scalar(primitive=Primitive(string_value=json.dumps(python_val)))) + + class JSONSerialized: + def __class_getitem__(cls, item: Type[T]): + return Annotated[item, JsonTypeTransformer(name=f"json[{item}]", t=item)] + + MyJsonDict = JSONSerialized[typing.Dict[str, int]] + _, test_transformer = get_args(MyJsonDict) + + assert TypeEngine.get_transformer(MyJsonDict) is test_transformer + assert TypeEngine.to_literal_type(MyJsonDict) == JsonTypeTransformer.LiteralType + + test_dict = {"foo": 1} + test_literal = Literal(scalar=Scalar(primitive=Primitive(string_value=json.dumps(test_dict)))) + + assert ( + TypeEngine.to_python_value( + FlyteContext.current_context(), + test_literal, + MyJsonDict, + ) + == test_dict + ) + + assert ( + TypeEngine.to_literal(FlyteContext.current_context(), test_dict, MyJsonDict, JsonTypeTransformer.LiteralType) + == test_literal + ) + + 
def test_list_of_dataclass_getting_python_value(): @dataclass_json @dataclass() @@ -1574,3 +1621,67 @@ def test_file_ext_with_flyte_file_wrong_type(): with pytest.raises(ValueError) as e: FlyteFile[WRONG_TYPE] assert str(e.value) == "Underlying type of File Extension must be of type " + + +def test_is_batchable(): + assert ListTransformer.is_batchable(typing.List[int]) is False + assert ListTransformer.is_batchable(typing.List[str]) is False + assert ListTransformer.is_batchable(typing.List[typing.Dict]) is False + assert ListTransformer.is_batchable(typing.List[typing.Dict[str, FlytePickle]]) is False + assert ListTransformer.is_batchable(typing.List[typing.List[FlytePickle]]) is False + + assert ListTransformer.is_batchable(typing.List[FlytePickle]) is True + assert ListTransformer.is_batchable(Annotated[typing.List[FlytePickle], BatchSize(3)]) is True + assert ( + ListTransformer.is_batchable(Annotated[typing.List[FlytePickle], HashMethod(function=str), BatchSize(3)]) + is True + ) + + +@pytest.mark.parametrize( + "python_val, python_type, expected_list_length", + [ + # Case 1: List of FlytePickle objects with default batch size. + # (By default, the batch_size is set to the length of the whole list.) + # After converting to literal, the result will be [batched_FlytePickle(5 items)]. + # Therefore, the expected list length is [1]. + ([{"foo"}] * 5, typing.List[FlytePickle], [1]), + # Case 2: List of FlytePickle objects with batch size 2. + # After converting to literal, the result will be + # [batched_FlytePickle(2 items), batched_FlytePickle(2 items), batched_FlytePickle(1 item)]. + # Therefore, the expected list length is [3]. + (["foo"] * 5, Annotated[typing.List[FlytePickle], HashMethod(function=str), BatchSize(2)], [3]), + # Case 3: Nested list of FlytePickle objects with batch size 2. 
+ # After converting to literal, the result will be + # [[batched_FlytePickle(3 items)], [batched_FlytePickle(3 items)]] + # Therefore, the expected list length is [2, 1] (the length of the outer list remains the same, the inner list is batched). + ([["foo", "foo", "foo"]] * 2, typing.List[Annotated[typing.List[FlytePickle], BatchSize(3)]], [2, 1]), + # Case 4: Empty list + ([[], typing.List[FlytePickle], []]), + ], +) +def test_batch_pickle_list(python_val, python_type, expected_list_length): + ctx = FlyteContext.current_context() + expected = TypeEngine.to_literal_type(python_type) + lv = TypeEngine.to_literal(ctx, python_val, python_type, expected) + + tmp_lv = lv + for length in expected_list_length: + # Check that after converting to literal, the length of the literal list is equal to: + # - the length of the original list divided by the batch size if not nested + # - the length of the original list if it contains a nested list + assert len(tmp_lv.collection.literals) == length + tmp_lv = tmp_lv.collection.literals[0] + + pv = TypeEngine.to_python_value(ctx, lv, python_type) + # Check that after converting literal to Python value, the result is equal to the original python values. + assert pv == python_val + if get_origin(python_type) is Annotated: + pv = TypeEngine.to_python_value(ctx, lv, get_args(python_type)[0]) + # Remove the annotation and check that after converting to Python value, the result is equal + # to the original input values. 
This is used to simulate the following case: + # @workflow + # def wf(): + # data = task0() # task0() -> Annotated[typing.List[FlytePickle], BatchSize(2)] + # task1(data=data) # task1(data: typing.List[FlytePickle]) + assert pv == python_val diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index 9da416c1e8..cb65981963 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -3,12 +3,12 @@ import functools import os import random +import re import tempfile import typing from collections import OrderedDict from dataclasses import dataclass from enum import Enum -from textwrap import dedent import pandas import pandas as pd @@ -492,11 +492,13 @@ def t1(path: str) -> DatasetStruct: def wf(path: str) -> DatasetStruct: return t1(path=path) - res = wf(path="/tmp/somewhere") - assert "parquet" == res.a.file_format - assert "parquet" == res.b.a.file_format - assert_frame_equal(df, res.a.open(pd.DataFrame).all()) - assert_frame_equal(df, res.b.a.open(pd.DataFrame).all()) + with tempfile.TemporaryDirectory() as tmp_dir: + fname = os.path.join(tmp_dir, "df_file") + res = wf(path=fname) + assert "parquet" == res.a.file_format + assert "parquet" == res.b.a.file_format + assert_frame_equal(df, res.a.open(pd.DataFrame).all()) + assert_frame_equal(df, res.b.a.open(pd.DataFrame).all()) def test_wf1_with_map(): @@ -1627,16 +1629,28 @@ def foo2(a: int, b: str) -> typing.Tuple[int, str]: def foo3(a: typing.Dict) -> typing.Dict: return a - with pytest.raises(TypeError, match="Type of Val 'hello' is not an instance of "): + with pytest.raises( + TypeError, + match=( + "Failed to convert inputs of task 'tests.flytekit.unit.core.test_type_hints.foo':\n" + " Failed argument 'a': Expected value of type but got 'hello' of type " + ), + ): foo(a="hello", b=10) # type: ignore with pytest.raises( TypeError, - match="Failed to convert return value for var o0 for " "function 
tests.flytekit.unit.core.test_type_hints.foo2", + match=( + "Failed to convert outputs of task 'tests.flytekit.unit.core.test_type_hints.foo2' at position 0:\n" + " Expected value of type but got 'hello' of type " + ), ): foo2(a=10, b="hello") - with pytest.raises(TypeError, match="Not a collection type simple: STRUCT\n but got a list \\[{'hello': 2}\\]"): + with pytest.raises( + TypeError, + match="Not a collection type but got a list \\[{'hello': 2}\\]", + ): foo3(a=[{"hello": 2}]) # type: ignore @@ -1672,28 +1686,12 @@ def wf2(a: typing.Union[int, str]) -> typing.Union[int, str]: with pytest.raises( TypeError, - match=dedent( - r""" - Cannot convert from scalar { - union { - value { - scalar { - primitive { - string_value: "2" - } - } - } - type { - simple: STRING - structure { - tag: "str" - } - } - } - } - to typing.Union\[float, dict\] \(using tag str\) - """ - )[1:-1], + match=re.escape( + "Error encountered while executing 'wf2':\n" + " Failed to convert inputs of task 'tests.flytekit.unit.core.test_type_hints.t2':\n" + ' Cannot convert from to typing.Union[float, dict] (using tag str)' + ), ): assert wf2(a="2") == "2" diff --git a/tests/flytekit/unit/core/test_utils.py b/tests/flytekit/unit/core/test_utils.py index 112a864b30..5c191b31ee 100644 --- a/tests/flytekit/unit/core/test_utils.py +++ b/tests/flytekit/unit/core/test_utils.py @@ -1,6 +1,8 @@ import pytest -from flytekit.core.utils import _dnsify +import flytekit +from flytekit import FlyteContextManager, task +from flytekit.core.utils import _dnsify, timeit @pytest.mark.parametrize( @@ -20,3 +22,38 @@ ) def test_dnsify(input, expected): assert _dnsify(input) == expected + + +def test_timeit(): + ctx = FlyteContextManager.current_context() + ctx.user_space_params._decks = [] + + with timeit("Set disable_deck to False"): + kwargs = {} + kwargs["disable_deck"] = False + + ctx = FlyteContextManager.current_context() + time_info_list = ctx.user_space_params.timeline_deck.time_info + names = 
[time_info["Name"] for time_info in time_info_list] + # check if timeit works for flytekit level code + assert "Set disable_deck to False" in names + + @task(**kwargs) + def t1() -> int: + @timeit("Download data") + def download_data(): + return "1" + + data = download_data() + + with timeit("Convert string to int"): + return int(data) + + t1() + + time_info_list = flytekit.current_context().timeline_deck.time_info + names = [time_info["Name"] for time_info in time_info_list] + + # check if timeit works for user level code + assert "Download data" in names + assert "Convert string to int" in names diff --git a/tests/flytekit/unit/core/test_workflows.py b/tests/flytekit/unit/core/test_workflows.py index 4f1082df63..1aeeba894d 100644 --- a/tests/flytekit/unit/core/test_workflows.py +++ b/tests/flytekit/unit/core/test_workflows.py @@ -136,6 +136,89 @@ def wf(b: int) -> nt: assert x == (7, 7) +def test_sub_wf_varying_types(): + @task + def t1l( + a: typing.List[typing.Dict[str, typing.List[int]]], + b: typing.Dict[str, typing.List[int]], + c: typing.Union[typing.List[typing.Dict[str, typing.List[int]]], typing.Dict[str, typing.List[int]], int], + d: int, + ) -> str: + xx = ",".join([f"{k}:{v}" for d in a for k, v in d.items()]) + yy = ",".join([f"{k}: {i}" for k, v in b.items() for i in v]) + if isinstance(c, list): + zz = ",".join([f"{k}:{v}" for d in c for k, v in d.items()]) + elif isinstance(c, dict): + zz = ",".join([f"{k}: {i}" for k, v in c.items() for i in v]) + else: + zz = str(c) + return f"First: {xx} Second: {yy} Third: {zz} Int: {d}" + + @task + def get_int() -> int: + return 1 + + @workflow + def subwf( + a: typing.List[typing.Dict[str, typing.List[int]]], + b: typing.Dict[str, typing.List[int]], + c: typing.Union[typing.List[typing.Dict[str, typing.List[int]]], typing.Dict[str, typing.List[int]]], + d: int, + ) -> str: + return t1l(a=a, b=b, c=c, d=d) + + @workflow + def wf() -> str: + ds = [ + {"first_map_a": [42], "first_map_b": [get_int(), 2]}, + { + 
"second_map_c": [33], + "second_map_d": [9, 99], + }, + ] + ll = { + "ll_1": [get_int(), get_int(), get_int()], + "ll_2": [4, 5, 6], + } + out = subwf(a=ds, b=ll, c=ds, d=get_int()) + return out + + wf.compile() + x = wf() + expected = ( + "First: first_map_a:[42],first_map_b:[1, 2],second_map_c:[33],second_map_d:[9, 99] " + "Second: ll_1: 1,ll_1: 1,ll_1: 1,ll_2: 4,ll_2: 5,ll_2: 6 " + "Third: first_map_a:[42],first_map_b:[1, 2],second_map_c:[33],second_map_d:[9, 99] " + "Int: 1" + ) + assert x == expected + + @workflow + def wf() -> str: + ds = [ + {"first_map_a": [42], "first_map_b": [get_int(), 2]}, + { + "second_map_c": [33], + "second_map_d": [9, 99], + }, + ] + ll = { + "ll_1": [get_int(), get_int(), get_int()], + "ll_2": [4, 5, 6], + } + out = subwf(a=ds, b=ll, c=ll, d=get_int()) + return out + + x = wf() + expected = ( + "First: first_map_a:[42],first_map_b:[1, 2],second_map_c:[33],second_map_d:[9, 99] " + "Second: ll_1: 1,ll_1: 1,ll_1: 1,ll_2: 4,ll_2: 5,ll_2: 6 " + "Third: ll_1: 1,ll_1: 1,ll_1: 1,ll_2: 4,ll_2: 5,ll_2: 6 " + "Int: 1" + ) + assert x == expected + + def test_unexpected_outputs(): @task def t1(a: int) -> int: diff --git a/tests/flytekit/unit/core/tracker/d.py b/tests/flytekit/unit/core/tracker/d.py index 9385b0f08d..c84e36fe59 100644 --- a/tests/flytekit/unit/core/tracker/d.py +++ b/tests/flytekit/unit/core/tracker/d.py @@ -9,3 +9,7 @@ def tasks(): @task def foo(): pass + + +def inner_function(a: str) -> str: + return "hello" diff --git a/tests/flytekit/unit/core/tracker/test_arrow_data.py b/tests/flytekit/unit/core/tracker/test_arrow_data.py new file mode 100644 index 0000000000..747e7f1651 --- /dev/null +++ b/tests/flytekit/unit/core/tracker/test_arrow_data.py @@ -0,0 +1,29 @@ +import typing + +import pandas as pd +import pyarrow as pa +from typing_extensions import Annotated + +from flytekit import kwtypes, task + +cols = kwtypes(Name=str, Age=int) +subset_cols = kwtypes(Name=str) + + +@task +def t1( + df1: Annotated[pd.DataFrame, cols], 
df2: Annotated[pa.Table, cols] +) -> typing.Tuple[Annotated[pd.DataFrame, subset_cols], Annotated[pa.Table, subset_cols]]: + return df1, df2 + + +def test_structured_dataset_wf(): + pd_df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) + pa_df = pa.Table.from_pandas(pd_df) + + subset_pd_df = pd.DataFrame({"Name": ["Tom", "Joseph"]}) + subset_pa_df = pa.Table.from_pandas(subset_pd_df) + + df1, df2 = t1(df1=pd_df, df2=pa_df) + assert df1.equals(subset_pd_df) + assert df2.equals(subset_pa_df) diff --git a/tests/flytekit/unit/core/tracker/test_tracking.py b/tests/flytekit/unit/core/tracker/test_tracking.py index 33ae18acd5..b33725436d 100644 --- a/tests/flytekit/unit/core/tracker/test_tracking.py +++ b/tests/flytekit/unit/core/tracker/test_tracking.py @@ -79,3 +79,10 @@ def test_extract_task_module(test_input, expected): except Exception: FeatureFlags.FLYTE_PYTHON_PACKAGE_ROOT = old raise + + +local_task = task(d.inner_function) + + +def test_local_task_wrap(): + assert local_task.instantiated_in == "tests.flytekit.unit.core.tracker.test_tracking" diff --git a/tests/flytekit/unit/deck/test_deck.py b/tests/flytekit/unit/deck/test_deck.py index a6b00e79e2..f65c94b877 100644 --- a/tests/flytekit/unit/deck/test_deck.py +++ b/tests/flytekit/unit/deck/test_deck.py @@ -1,3 +1,5 @@ +import datetime + import pandas as pd import pytest from mock import mock @@ -23,12 +25,30 @@ def test_deck(): _output_deck("test_task", ctx.user_space_params) +def test_timeline_deck(): + time_info = dict( + Name="foo", + Start=datetime.datetime.utcnow(), + Finish=datetime.datetime.utcnow() + datetime.timedelta(microseconds=1000), + WallTime=1.0, + ProcessTime=1.0, + ) + ctx = FlyteContextManager.current_context() + ctx.user_space_params._decks = [] + timeline_deck = ctx.user_space_params.timeline_deck + timeline_deck.append_time_info(time_info) + assert timeline_deck.name == "Timeline" + assert len(timeline_deck.time_info) == 1 + assert timeline_deck.time_info[0] == time_info + 
assert len(ctx.user_space_params.decks) == 1 + + @pytest.mark.parametrize( "disable_deck,expected_decks", [ - (None, 0), - (False, 2), # input and output decks - (True, 0), + (None, 1), # time line deck + (False, 3), # time line deck + input and output decks + (True, 1), # time line deck ], ) def test_deck_for_task(disable_deck, expected_decks): @@ -49,9 +69,9 @@ def t1(a: int) -> str: @pytest.mark.parametrize( "disable_deck, expected_decks", [ - (None, 1), - (False, 1 + 2), # input and output decks - (True, 1), + (None, 2), # default deck and time line deck + (False, 4), # default deck and time line deck + input and output decks + (True, 2), # default deck and time line deck ], ) def test_deck_pandas_dataframe(disable_deck, expected_decks): diff --git a/tests/flytekit/unit/extend/test_backend_plugin.py b/tests/flytekit/unit/extend/test_backend_plugin.py new file mode 100644 index 0000000000..9dfd20d99e --- /dev/null +++ b/tests/flytekit/unit/extend/test_backend_plugin.py @@ -0,0 +1,105 @@ +import typing +from datetime import timedelta +from unittest.mock import MagicMock + +import grpc +from flyteidl.service.external_plugin_service_pb2 import ( + PERMANENT_FAILURE, + SUCCEEDED, + TaskCreateRequest, + TaskCreateResponse, + TaskDeleteRequest, + TaskDeleteResponse, + TaskGetRequest, + TaskGetResponse, +) + +import flytekit.models.interface as interface_models +from flytekit.extend.backend.base_plugin import BackendPluginBase, BackendPluginRegistry +from flytekit.extend.backend.external_plugin_service import BackendPluginServer +from flytekit.models import literals, task, types +from flytekit.models.core.identifier import Identifier, ResourceType +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate + +dummy_id = "dummy_id" + + +class DummyPlugin(BackendPluginBase): + def __init__(self): + super().__init__(task_type="dummy") + + def create( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: 
TaskTemplate, + inputs: typing.Optional[LiteralMap] = None, + ) -> TaskCreateResponse: + return TaskCreateResponse(job_id=dummy_id) + + def get(self, context: grpc.ServicerContext, job_id: str) -> TaskGetResponse: + return TaskGetResponse(state=SUCCEEDED) + + def delete(self, context: grpc.ServicerContext, job_id) -> TaskDeleteResponse: + return TaskDeleteResponse() + + +BackendPluginRegistry.register(DummyPlugin()) + +task_id = Identifier(resource_type=ResourceType.TASK, project="project", domain="domain", name="t1", version="version") +task_metadata = task.TaskMetadata( + True, + task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + "A", +) + +int_type = types.LiteralType(types.SimpleType.INTEGER) +interfaces = interface_models.TypedInterface( + { + "a": interface_models.Variable(int_type, "description1"), + }, + {}, +) +task_inputs = literals.LiteralMap( + { + "a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))), + }, +) + +dummy_template = TaskTemplate( + id=task_id, + metadata=task_metadata, + interface=interfaces, + type="dummy", + custom={}, +) + + +def test_dummy_plugin(): + ctx = MagicMock(spec=grpc.ServicerContext) + p = BackendPluginRegistry.get_plugin(ctx, "dummy") + assert p.create(ctx, "/tmp", dummy_template, task_inputs).job_id == dummy_id + assert p.get(ctx, dummy_id).state == SUCCEEDED + assert p.delete(ctx, dummy_id) == TaskDeleteResponse() + + +def test_backend_plugin_server(): + server = BackendPluginServer() + ctx = MagicMock(spec=grpc.ServicerContext) + request = TaskCreateRequest( + inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=dummy_template.to_flyte_idl() + ) + + assert server.CreateTask(request, ctx).job_id == dummy_id + assert server.GetTask(TaskGetRequest(task_type="dummy", job_id=dummy_id), ctx).state == SUCCEEDED + assert 
server.DeleteTask(TaskDeleteRequest(task_type="dummy", job_id=dummy_id), ctx) == TaskDeleteResponse() + + res = server.GetTask(TaskGetRequest(task_type="fake", job_id=dummy_id), ctx) + assert res.state == PERMANENT_FAILURE diff --git a/tests/flytekit/unit/extras/persistence/test_gcs_gsutil.py b/tests/flytekit/unit/extras/persistence/test_gcs_gsutil.py deleted file mode 100644 index d2c50cc4a9..0000000000 --- a/tests/flytekit/unit/extras/persistence/test_gcs_gsutil.py +++ /dev/null @@ -1,35 +0,0 @@ -import mock - -from flytekit import GCSPersistence - - -@mock.patch("flytekit.extras.persistence.gcs_gsutil._update_cmd_config_and_execute") -@mock.patch("flytekit.extras.persistence.gcs_gsutil.GCSPersistence._check_binary") -def test_put(mock_check, mock_exec): - proxy = GCSPersistence() - proxy.put("/test", "gs://my-bucket/k1") - mock_exec.assert_called_with(["gsutil", "cp", "/test", "gs://my-bucket/k1"]) - - -@mock.patch("flytekit.extras.persistence.gcs_gsutil._update_cmd_config_and_execute") -@mock.patch("flytekit.extras.persistence.gcs_gsutil.GCSPersistence._check_binary") -def test_put_recursive(mock_check, mock_exec): - proxy = GCSPersistence() - proxy.put("/test", "gs://my-bucket/k1", True) - mock_exec.assert_called_with(["gsutil", "cp", "-r", "/test/*", "gs://my-bucket/k1/"]) - - -@mock.patch("flytekit.extras.persistence.gcs_gsutil._update_cmd_config_and_execute") -@mock.patch("flytekit.extras.persistence.gcs_gsutil.GCSPersistence._check_binary") -def test_get(mock_check, mock_exec): - proxy = GCSPersistence() - proxy.get("gs://my-bucket/k1", "/test") - mock_exec.assert_called_with(["gsutil", "cp", "gs://my-bucket/k1", "/test"]) - - -@mock.patch("flytekit.extras.persistence.gcs_gsutil._update_cmd_config_and_execute") -@mock.patch("flytekit.extras.persistence.gcs_gsutil.GCSPersistence._check_binary") -def test_get_recursive(mock_check, mock_exec): - proxy = GCSPersistence() - proxy.get("gs://my-bucket/k1", "/test", True) - mock_exec.assert_called_with(["gsutil", 
"cp", "-r", "gs://my-bucket/k1/*", "/test"]) diff --git a/tests/flytekit/unit/extras/persistence/test_http.py b/tests/flytekit/unit/extras/persistence/test_http.py deleted file mode 100644 index 893b43f364..0000000000 --- a/tests/flytekit/unit/extras/persistence/test_http.py +++ /dev/null @@ -1,20 +0,0 @@ -import pytest - -from flytekit import HttpPersistence - - -def test_put(): - proxy = HttpPersistence() - with pytest.raises(AssertionError): - proxy.put("", "", recursive=True) - - -def test_construct_path(): - proxy = HttpPersistence() - with pytest.raises(AssertionError): - proxy.construct_path(True, False, "", "") - - -def test_exists(): - proxy = HttpPersistence() - assert proxy.exists("https://flyte.org") diff --git a/tests/flytekit/unit/extras/persistence/test_s3_awscli.py b/tests/flytekit/unit/extras/persistence/test_s3_awscli.py deleted file mode 100644 index a6f29f36d6..0000000000 --- a/tests/flytekit/unit/extras/persistence/test_s3_awscli.py +++ /dev/null @@ -1,80 +0,0 @@ -from datetime import timedelta - -import mock - -from flytekit import S3Persistence -from flytekit.configuration import DataConfig, S3Config -from flytekit.extras.persistence import s3_awscli - - -def test_property(): - aws = S3Persistence("s3://raw-output") - assert aws.default_prefix == "s3://raw-output" - - -def test_construct_path(): - aws = S3Persistence() - p = aws.construct_path(True, False, "xyz") - assert p == "s3://xyz" - - -@mock.patch("flytekit.extras.persistence.s3_awscli.S3Persistence._check_binary") -@mock.patch("flytekit.extras.persistence.s3_awscli.subprocess") -def test_retries(mock_subprocess, mock_check): - mock_subprocess.check_call.side_effect = Exception("test exception (404)") - mock_check.return_value = True - - proxy = S3Persistence(data_config=DataConfig(s3=S3Config(backoff=timedelta(seconds=0)))) - assert proxy.exists("s3://test/fdsa/fdsa") is False - assert mock_subprocess.check_call.call_count == 8 - - -def test_extra_args(): - assert 
s3_awscli._extra_args({}) == [] - assert s3_awscli._extra_args({"ContentType": "ct"}) == ["--content-type", "ct"] - assert s3_awscli._extra_args({"ContentEncoding": "ec"}) == ["--content-encoding", "ec"] - assert s3_awscli._extra_args({"ACL": "acl"}) == ["--acl", "acl"] - assert s3_awscli._extra_args({"ContentType": "ct", "ContentEncoding": "ec", "ACL": "acl"}) == [ - "--content-type", - "ct", - "--content-encoding", - "ec", - "--acl", - "acl", - ] - - -@mock.patch("flytekit.extras.persistence.s3_awscli._update_cmd_config_and_execute") -def test_put(mock_exec): - proxy = S3Persistence() - proxy.put("/test", "s3://my-bucket/k1") - mock_exec.assert_called_with( - cmd=["aws", "s3", "cp", "--acl", "bucket-owner-full-control", "/test", "s3://my-bucket/k1"], - s3_cfg=S3Config.auto(), - ) - - -@mock.patch("flytekit.extras.persistence.s3_awscli._update_cmd_config_and_execute") -def test_put_recursive(mock_exec): - proxy = S3Persistence() - proxy.put("/test", "s3://my-bucket/k1", True) - mock_exec.assert_called_with( - cmd=["aws", "s3", "cp", "--recursive", "--acl", "bucket-owner-full-control", "/test", "s3://my-bucket/k1"], - s3_cfg=S3Config.auto(), - ) - - -@mock.patch("flytekit.extras.persistence.s3_awscli._update_cmd_config_and_execute") -def test_get(mock_exec): - proxy = S3Persistence() - proxy.get("s3://my-bucket/k1", "/test") - mock_exec.assert_called_with(cmd=["aws", "s3", "cp", "s3://my-bucket/k1", "/test"], s3_cfg=S3Config.auto()) - - -@mock.patch("flytekit.extras.persistence.s3_awscli._update_cmd_config_and_execute") -def test_get_recursive(mock_exec): - proxy = S3Persistence() - proxy.get("s3://my-bucket/k1", "/test", True) - mock_exec.assert_called_with( - cmd=["aws", "s3", "cp", "--recursive", "s3://my-bucket/k1", "/test"], s3_cfg=S3Config.auto() - ) diff --git a/tests/flytekit/unit/extras/sqlite3/chinook.zip b/tests/flytekit/unit/extras/sqlite3/chinook.zip new file mode 100644 index 0000000000..6dd568fa61 Binary files /dev/null and 
b/tests/flytekit/unit/extras/sqlite3/chinook.zip differ diff --git a/tests/flytekit/unit/extras/sqlite3/test_task.py b/tests/flytekit/unit/extras/sqlite3/test_task.py index 40fc94a3d2..f8014f244b 100644 --- a/tests/flytekit/unit/extras/sqlite3/test_task.py +++ b/tests/flytekit/unit/extras/sqlite3/test_task.py @@ -1,3 +1,5 @@ +import os + import pandas import pytest @@ -10,8 +12,7 @@ from flytekit.types.schema import FlyteSchema ctx = context_manager.FlyteContextManager.current_context() -EXAMPLE_DB = ctx.file_access.get_random_local_path("chinook.zip") -ctx.file_access.get_data("https://www.sqlitetutorial.net/wp-content/uploads/2018/03/chinook.zip", EXAMPLE_DB) +EXAMPLE_DB = os.path.join(os.path.dirname(os.path.realpath(__file__)), "chinook.zip") # This task belongs to test_task_static but is intentionally here to help test tracking tk = SQLite3Task( diff --git a/tests/flytekit/unit/extras/tensorflow/model/__init__.py b/tests/flytekit/unit/extras/tensorflow/model/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/extras/tensorflow/model/test_transformations.py b/tests/flytekit/unit/extras/tensorflow/model/test_transformations.py new file mode 100644 index 0000000000..392ab695c5 --- /dev/null +++ b/tests/flytekit/unit/extras/tensorflow/model/test_transformations.py @@ -0,0 +1,75 @@ +from collections import OrderedDict + +import numpy as np +import pytest +import tensorflow as tf + +import flytekit +from flytekit import task +from flytekit.configuration import Image, ImageConfig +from flytekit.core import context_manager +from flytekit.extras.tensorflow import TensorFlowModelTransformer +from flytekit.models.core.types import BlobType +from flytekit.models.literals import BlobMetadata +from flytekit.models.types import LiteralType +from flytekit.tools.translator import get_serializable + +default_img = Image(name="default", fqn="test", tag="tag") +serialization_settings = flytekit.configuration.SerializationSettings( + 
project="project", + domain="domain", + version="version", + env=None, + image_config=ImageConfig(default_image=default_img, images=[default_img]), +) + + +def get_tf_model(): + inputs = tf.keras.Input(shape=(32,)) + outputs = tf.keras.layers.Dense(1)(inputs) + tf_model = tf.keras.Model(inputs, outputs) + return tf_model + + +@pytest.mark.parametrize( + "transformer,python_type,format", + [ + (TensorFlowModelTransformer(), tf.keras.Model, TensorFlowModelTransformer.TENSORFLOW_FORMAT), + ], +) +def test_get_literal_type(transformer, python_type, format): + lt = transformer.get_literal_type(python_type) + assert lt == LiteralType(blob=BlobType(format=format, dimensionality=BlobType.BlobDimensionality.MULTIPART)) + + +@pytest.mark.parametrize( + "transformer,python_type,format,python_val", + [ + (TensorFlowModelTransformer(), tf.keras.Model, TensorFlowModelTransformer.TENSORFLOW_FORMAT, get_tf_model()), + ], +) +def test_to_python_value_and_literal(transformer, python_type, format, python_val): + ctx = context_manager.FlyteContext.current_context() + lt = transformer.get_literal_type(python_type) + + lv = transformer.to_literal(ctx, python_val, type(python_val), lt) # type: ignore + output = transformer.to_python_value(ctx, lv, python_type) + + assert lv.scalar.blob.metadata == BlobMetadata( + type=BlobType( + format=format, + dimensionality=BlobType.BlobDimensionality.MULTIPART, + ) + ) + assert lv.scalar.blob.uri is not None + for w1, w2 in zip(output.weights, python_val.weights): + np.testing.assert_allclose(w1.numpy(), w2.numpy()) + + +def test_example_model(): + @task + def t1() -> tf.keras.Model: + return get_tf_model() + + task_spec = get_serializable(OrderedDict(), serialization_settings, t1) + assert task_spec.template.interface.outputs["o0"].type.blob.format is TensorFlowModelTransformer.TENSORFLOW_FORMAT diff --git a/tests/flytekit/unit/models/core/test_security.py b/tests/flytekit/unit/models/core/test_security.py new file mode 100644 index 
0000000000..c2933f9353 --- /dev/null +++ b/tests/flytekit/unit/models/core/test_security.py @@ -0,0 +1,13 @@ +from flytekit.models.security import Secret + + +def test_secret(): + obj = Secret("grp", "key") + obj2 = Secret.from_flyte_idl(obj.to_flyte_idl()) + assert obj2.key == "key" + assert obj2.group_version is None + + obj = Secret("grp", group_version="v1") + obj2 = Secret.from_flyte_idl(obj.to_flyte_idl()) + assert obj2.key is None + assert obj2.group_version == "v1" diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index 4b8f82fb7e..5bfd7e4bf6 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -175,15 +175,7 @@ def test_more_stuff(mock_client): # Can't upload a folder with pytest.raises(ValueError): with tempfile.TemporaryDirectory() as tmp_dir: - r._upload_file(pathlib.Path(tmp_dir)) - - # Test that this copies the file. - with tempfile.TemporaryDirectory() as tmp_dir: - mm = MagicMock() - mm.signed_url = os.path.join(tmp_dir, "tmp_file") - mock_client.return_value.get_upload_signed_url.return_value = mm - - r._upload_file(pathlib.Path(__file__)) + r.upload_file(pathlib.Path(tmp_dir)) serialization_settings = flytekit.configuration.SerializationSettings( project="project", diff --git a/tests/flytekit/unit/tools/test_script_mode.py b/tests/flytekit/unit/tools/test_script_mode.py index a433769075..67898fb780 100644 --- a/tests/flytekit/unit/tools/test_script_mode.py +++ b/tests/flytekit/unit/tools/test_script_mode.py @@ -1,13 +1,51 @@ import os +import subprocess +import sys -from flytekit.tools.script_mode import compress_single_script, hash_file +from flytekit.tools.script_mode import compress_scripts, hash_file + +MAIN_WORKFLOW = """ +from flytekit import task, workflow +from wf1.test import t1 -WORKFLOW = """ @workflow def my_wf() -> str: return "hello world" """ +IMPERATIVE_WORKFLOW = """ +from flytekit import Workflow, task + +@task +def t1(a: int): + 
print(a) + + +wf = Workflow(name="my.imperative.workflow.example") +wf.add_workflow_input("a", int) +node_t1 = wf.add_entity(t1, a=wf.inputs["a"]) +""" + +T1_TASK = """ +from flytekit import task +from wf2.test import t2 + + +@task() +def t1() -> str: + print("hello") + return "hello" +""" + +T2_TASK = """ +from flytekit import task + +@task() +def t2() -> str: + print("hello") + return "hello" +""" + def test_deterministic_hash(tmp_path): workflows_dir = tmp_path / "workflows" @@ -17,19 +55,46 @@ def test_deterministic_hash(tmp_path): open(workflows_dir / "__init__.py", "a").close() # Write a dummy workflow workflow_file = workflows_dir / "hello_world.py" - workflow_file.write_text(WORKFLOW) + workflow_file.write_text(MAIN_WORKFLOW) + + imperative_workflow_file = workflows_dir / "imperative_wf.py" + imperative_workflow_file.write_text(IMPERATIVE_WORKFLOW) + + t1_dir = tmp_path / "wf1" + t1_dir.mkdir() + open(t1_dir / "__init__.py", "a").close() + t1_file = t1_dir / "test.py" + t1_file.write_text(T1_TASK) + + t2_dir = tmp_path / "wf2" + t2_dir.mkdir() + open(t2_dir / "__init__.py", "a").close() + t2_file = t2_dir / "test.py" + t2_file.write_text(T2_TASK) destination = tmp_path / "destination" - compress_single_script(workflows_dir, destination, "hello_world") - print(f"{os.listdir(tmp_path)}") + sys.path.append(str(workflows_dir.parent)) + compress_scripts(str(workflows_dir.parent), str(destination), "workflows.hello_world") digest, hex_digest = hash_file(destination) # Try again to assert digest determinism destination2 = tmp_path / "destination2" - compress_single_script(workflows_dir, destination2, "hello_world") + compress_scripts(str(workflows_dir.parent), str(destination2), "workflows.hello_world") digest2, hex_digest2 = hash_file(destination) assert digest == digest2 assert hex_digest == hex_digest2 + + test_dir = tmp_path / "test" + test_dir.mkdir() + + result = subprocess.run( + ["tar", "-xvf", str(destination), "-C", str(test_dir)], + 
stdout=subprocess.PIPE, + ) + result.check_returncode() + assert len(next(os.walk(test_dir))[1]) == 3 + + compress_scripts(str(workflows_dir.parent), str(destination), "workflows.imperative_wf") diff --git a/tests/flytekit/unit/types/directory/__init__.py b/tests/flytekit/unit/types/directory/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/types/directory/test_types.py b/tests/flytekit/unit/types/directory/test_types.py new file mode 100644 index 0000000000..199b788733 --- /dev/null +++ b/tests/flytekit/unit/types/directory/test_types.py @@ -0,0 +1,31 @@ +import mock + +from flytekit import FlyteContext +from flytekit.types.directory import FlyteDirectory +from flytekit.types.file import FlyteFile + + +def test_new_file_dir(): + fd = FlyteDirectory(path="s3://my-bucket") + assert fd.sep == "/" + inner_dir = fd.new_dir("test") + assert inner_dir.path == "s3://my-bucket/test" + fd = FlyteDirectory(path="s3://my-bucket/") + inner_dir = fd.new_dir("test") + assert inner_dir.path == "s3://my-bucket/test" + f = inner_dir.new_file("test") + assert isinstance(f, FlyteFile) + assert f.path == "s3://my-bucket/test/test" + + +def test_new_remote_dir(): + fd = FlyteDirectory.new_remote() + assert FlyteContext.current_context().file_access.raw_output_prefix in fd.path + + +@mock.patch("flytekit.types.directory.types.os.name", "nt") +def test_sep_nt(): + fd = FlyteDirectory(path="file://mypath") + assert fd.sep == "\\" + fd = FlyteDirectory(path="s3://mypath") + assert fd.sep == "/" diff --git a/tests/flytekit/unit/types/file/__init__.py b/tests/flytekit/unit/types/file/__init__.py new file mode 100644 index 0000000000..e69de29bb2