diff --git a/plugins/flytekit-huggingface/README.md b/plugins/flytekit-huggingface/README.md
new file mode 100644
index 0000000000..394c489ab9
--- /dev/null
+++ b/plugins/flytekit-huggingface/README.md
@@ -0,0 +1,10 @@
+# Flytekit Hugging Face Plugin
+[Hugging Face](https://github.com/huggingface) is a community and data science platform that provides tools for building, training, and deploying ML models based on open source code and technologies.
+
+This plugin supports `datasets.Dataset` as a data type with [StructuredDataset](https://docs.flyte.org/projects/cookbook/en/latest/auto/core/type_system/structured_dataset.html).
+
+To install the plugin, run the following command:
+
+```bash
+pip install flytekitplugins-huggingface
+```
diff --git a/plugins/flytekit-huggingface/flytekitplugins/huggingface/__init__.py b/plugins/flytekit-huggingface/flytekitplugins/huggingface/__init__.py
new file mode 100644
index 0000000000..30a877bb40
--- /dev/null
+++ b/plugins/flytekit-huggingface/flytekitplugins/huggingface/__init__.py
@@ -0,0 +1,14 @@
+"""
+.. currentmodule:: flytekitplugins.huggingface
+
+This package contains things that are useful when extending Flytekit.
+
+.. autosummary::
+   :template: custom.rst
+   :toctree: generated/
+
+   HuggingFaceDatasetToParquetEncodingHandler
+   ParquetToHuggingFaceDatasetDecodingHandler
+"""
+
+from .sd_transformers import HuggingFaceDatasetToParquetEncodingHandler, ParquetToHuggingFaceDatasetDecodingHandler
diff --git a/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py b/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py
new file mode 100644
index 0000000000..0690179bb1
--- /dev/null
+++ b/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py
@@ -0,0 +1,72 @@
+import typing
+
+import datasets
+
+from flytekit import FlyteContext
+from flytekit.models import literals
+from flytekit.models.literals import StructuredDatasetMetadata
+from flytekit.models.types import StructuredDatasetType
+from flytekit.types.structured.structured_dataset import (
+    PARQUET,
+    StructuredDataset,
+    StructuredDatasetDecoder,
+    StructuredDatasetEncoder,
+    StructuredDatasetTransformerEngine,
+)
+
+
+class HuggingFaceDatasetRenderer:
+    """
+    The datasets.Dataset printable representation is saved to HTML.
+    """
+
+    def to_html(self, df: datasets.Dataset) -> str:
+        assert isinstance(df, datasets.Dataset)
+        return str(df).replace("\n", "<br>")
+
+
+class HuggingFaceDatasetToParquetEncodingHandler(StructuredDatasetEncoder):
+    def __init__(self):
+        super().__init__(datasets.Dataset, None, PARQUET)
+
+    def encode(
+        self,
+        ctx: FlyteContext,
+        structured_dataset: StructuredDataset,
+        structured_dataset_type: StructuredDatasetType,
+    ) -> literals.StructuredDataset:
+        df = typing.cast(datasets.Dataset, structured_dataset.dataframe)
+
+        local_dir = ctx.file_access.get_random_local_directory()
+        local_path = f"{local_dir}/00000"
+
+        df.to_parquet(local_path)
+
+        remote_dir = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory()
+        ctx.file_access.upload_directory(local_dir, remote_dir)
+        return literals.StructuredDataset(uri=remote_dir, metadata=StructuredDatasetMetadata(structured_dataset_type))
+
+
+class ParquetToHuggingFaceDatasetDecodingHandler(StructuredDatasetDecoder):
+    def __init__(self):
+        super().__init__(datasets.Dataset, None, PARQUET)
+
+    def decode(
+        self,
+        ctx: FlyteContext,
+        flyte_value: literals.StructuredDataset,
+        current_task_metadata: StructuredDatasetMetadata,
+    ) -> datasets.Dataset:
+        local_dir = ctx.file_access.get_random_local_directory()
+        ctx.file_access.get_data(flyte_value.uri, local_dir, is_multipart=True)
+        path = f"{local_dir}/00000"
+
+        if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
+            columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
+            return datasets.Dataset.from_parquet(path, columns=columns)
+        return datasets.Dataset.from_parquet(path)
+
+
+StructuredDatasetTransformerEngine.register(HuggingFaceDatasetToParquetEncodingHandler())
+StructuredDatasetTransformerEngine.register(ParquetToHuggingFaceDatasetDecodingHandler())
+StructuredDatasetTransformerEngine.register_renderer(datasets.Dataset, HuggingFaceDatasetRenderer())
diff --git a/plugins/flytekit-huggingface/requirements.in b/plugins/flytekit-huggingface/requirements.in
new file mode 100644
index 0000000000..9419fdddce
--- /dev/null
+++ b/plugins/flytekit-huggingface/requirements.in
@@ -0,0 +1,2 @@
+.
+-e file:.#egg=flytekitplugins-huggingface
diff --git a/plugins/flytekit-huggingface/requirements.txt b/plugins/flytekit-huggingface/requirements.txt
new file mode 100644
index 0000000000..d811a7b62d
--- /dev/null
+++ b/plugins/flytekit-huggingface/requirements.txt
@@ -0,0 +1,238 @@
+#
+# This file is autogenerated by pip-compile with python 3.9
+# To update, run:
+#
+#    pip-compile requirements.in
+#
+-e file:.#egg=flytekitplugins-huggingface
+    # via -r requirements.in
+aiohttp==3.8.1
+    # via
+    #   datasets
+    #   fsspec
+aiosignal==1.2.0
+    # via aiohttp
+arrow==1.2.2
+    # via jinja2-time
+async-timeout==4.0.2
+    # via aiohttp
+attrs==21.4.0
+    # via aiohttp
+binaryornot==0.4.4
+    # via cookiecutter
+certifi==2022.6.15
+    # via requests
+cffi==1.15.1
+    # via cryptography
+chardet==5.0.0
+    # via binaryornot
+charset-normalizer==2.1.0
+    # via
+    #   aiohttp
+    #   requests
+click==8.1.3
+    # via
+    #   cookiecutter
+    #   flytekit
+cloudpickle==2.1.0
+    # via flytekit
+cookiecutter==2.1.1
+    # via flytekit
+croniter==1.3.5
+    # via flytekit
+cryptography==37.0.4
+    # via pyopenssl
+dataclasses-json==0.5.7
+    # via flytekit
+datasets==2.4.0
+    # via flytekitplugins-huggingface
+decorator==5.1.1
+    # via retry
+deprecated==1.2.13
+    # via flytekit
+dill==0.3.5.1
+    # via
+    #   datasets
+    #   multiprocess
+diskcache==5.4.0
+    # via flytekit
+docker==5.0.3
+    # via flytekit
+docker-image-py==0.1.12
+    # via flytekit
+docstring-parser==0.14.1
+    # via flytekit
+filelock==3.7.1
+    # via huggingface-hub
+flyteidl==1.1.9
+    # via flytekit
+flytekit==1.1.0
+    # via flytekitplugins-huggingface
+frozenlist==1.3.0
+    # via
+    #   aiohttp
+    #   aiosignal
+fsspec[http]==2022.7.0
+    # via datasets
+googleapis-common-protos==1.56.4
+    # via
+    #   flyteidl
+    #   grpcio-status
+grpcio==1.47.0
+    # via
+    #   flytekit
+    #   grpcio-status
+grpcio-status==1.47.0
+    # via flytekit
+huggingface-hub==0.8.1
+    # via datasets
+idna==3.3
+    # via
+    #   requests
+    #   yarl
+importlib-metadata==4.12.0
+    # via
+    #   flytekit
+    #   keyring
+jinja2==3.1.2
+    # via
+    #   cookiecutter
+    #   jinja2-time
+jinja2-time==0.2.0
+    # via cookiecutter
+keyring==23.7.0
+    # via flytekit
+markupsafe==2.1.1
+    # via jinja2
+marshmallow==3.17.0
+    # via
+    #   dataclasses-json
+    #   marshmallow-enum
+    #   marshmallow-jsonschema
+marshmallow-enum==1.5.1
+    # via dataclasses-json
+marshmallow-jsonschema==0.13.0
+    # via flytekit
+multidict==6.0.2
+    # via
+    #   aiohttp
+    #   yarl
+multiprocess==0.70.13
+    # via datasets
+mypy-extensions==0.4.3
+    # via typing-inspect
+natsort==8.1.0
+    # via flytekit
+numpy==1.23.1
+    # via
+    #   datasets
+    #   pandas
+    #   pyarrow
+packaging==21.3
+    # via
+    #   datasets
+    #   huggingface-hub
+    #   marshmallow
+pandas==1.4.3
+    # via
+    #   datasets
+    #   flytekit
+protobuf==3.20.1
+    # via
+    #   flyteidl
+    #   flytekit
+    #   googleapis-common-protos
+    #   grpcio-status
+    #   protoc-gen-swagger
+protoc-gen-swagger==0.1.0
+    # via flyteidl
+py==1.11.0
+    # via retry
+pyarrow==6.0.1
+    # via
+    #   datasets
+    #   flytekit
+pycparser==2.21
+    # via cffi
+pyopenssl==22.0.0
+    # via flytekit
+pyparsing==3.0.9
+    # via packaging
+python-dateutil==2.8.2
+    # via
+    #   arrow
+    #   croniter
+    #   flytekit
+    #   pandas
+python-json-logger==2.0.4
+    # via flytekit
+python-slugify==6.1.2
+    # via cookiecutter
+pytimeparse==1.1.8
+    # via flytekit
+pytz==2022.1
+    # via
+    #   flytekit
+    #   pandas
+pyyaml==6.0
+    # via
+    #   cookiecutter
+    #   flytekit
+    #   huggingface-hub
+regex==2022.7.25
+    # via docker-image-py
+requests==2.28.1
+    # via
+    #   cookiecutter
+    #   datasets
+    #   docker
+    #   flytekit
+    #   fsspec
+    #   huggingface-hub
+    #   responses
+responses==0.18.0
+    # via
+    #   datasets
+    #   flytekit
+retry==0.9.2
+    # via flytekit
+six==1.16.0
+    # via
+    #   grpcio
+    #   python-dateutil
+sortedcontainers==2.4.0
+    # via flytekit
+statsd==3.3.0
+    # via flytekit
+text-unidecode==1.3
+    # via python-slugify
+tqdm==4.64.0
+    # via
+    #   datasets
+    #   huggingface-hub
+typing-extensions==4.3.0
+    # via
+    #   flytekit
+    #   huggingface-hub
+    #   typing-inspect
+typing-inspect==0.7.1
+    # via dataclasses-json
+urllib3==1.26.11
+    # via
+    #   flytekit
+    #   requests
+    #   responses
+websocket-client==1.3.3
+    # via docker
+wheel==0.37.1
+    # via flytekit
+wrapt==1.14.1
+    # via
+    #   deprecated
+    #   flytekit
+xxhash==3.0.0
+    # via datasets
+yarl==1.7.2
+    # via aiohttp
+zipp==3.8.1
+    # via importlib-metadata
diff --git a/plugins/flytekit-huggingface/setup.py b/plugins/flytekit-huggingface/setup.py
new file mode 100644
index 0000000000..f44864216b
--- /dev/null
+++ b/plugins/flytekit-huggingface/setup.py
@@ -0,0 +1,38 @@
+from setuptools import setup
+
+PLUGIN_NAME = "huggingface"
+
+microlib_name = f"flytekitplugins-{PLUGIN_NAME}"
+
+plugin_requires = [
+    "flytekit>=1.1.0b0,<1.2.0",
+    "datasets>=2.4.0",
+]
+
+__version__ = "0.0.0+develop"
+
+setup(
+    name=microlib_name,
+    version=__version__,
+    author="Evan Sadler",
+    description="Hugging Face plugin for flytekit",
+    namespace_packages=["flytekitplugins"],
+    packages=[f"flytekitplugins.{PLUGIN_NAME}"],
+    install_requires=plugin_requires,
+    license="apache2",
+    python_requires=">=3.7",
+    classifiers=[
+        "Intended Audience :: Science/Research",
+        "Intended Audience :: Developers",
+        "License :: OSI Approved :: Apache Software License",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
+        "Programming Language :: Python :: 3.10",
+        "Topic :: Scientific/Engineering",
+        "Topic :: Scientific/Engineering :: Artificial Intelligence",
+        "Topic :: Software Development",
+        "Topic :: Software Development :: Libraries",
+        "Topic :: Software Development :: Libraries :: Python Modules",
+    ],
+)
diff --git a/plugins/flytekit-huggingface/tests/__init__.py b/plugins/flytekit-huggingface/tests/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/plugins/flytekit-huggingface/tests/test_huggingface_plugin_sd.py b/plugins/flytekit-huggingface/tests/test_huggingface_plugin_sd.py
new file mode 100644
index 0000000000..170fdc3789
--- /dev/null
+++ b/plugins/flytekit-huggingface/tests/test_huggingface_plugin_sd.py
@@ -0,0 +1,70 @@
+from typing import Annotated
+
+import datasets
+import pandas as pd
+from flytekitplugins.huggingface.sd_transformers import HuggingFaceDatasetRenderer
+
+from flytekit import kwtypes, task, workflow
+from flytekit.types.structured.structured_dataset import PARQUET, StructuredDataset
+
+subset_schema = Annotated[StructuredDataset, kwtypes(col2=str), PARQUET]
+full_schema = Annotated[StructuredDataset, PARQUET]
+
+
+def test_huggingface_dataset_workflow_subset():
+    @task
+    def generate() -> subset_schema:
+        df = pd.DataFrame({"col1": [1, 3, 2], "col2": list("abc")})
+        dataset = datasets.Dataset.from_pandas(df)
+        return StructuredDataset(dataframe=dataset)
+
+    @task
+    def consume(df: subset_schema) -> subset_schema:
+        dataset = df.open(datasets.Dataset).all()
+
+        assert dataset[0]["col2"] == "a"
+        assert dataset[1]["col2"] == "b"
+        assert dataset[2]["col2"] == "c"
+
+        return StructuredDataset(dataframe=dataset)
+
+    @workflow
+    def wf() -> subset_schema:
+        return consume(df=generate())
+
+    result = wf()
+    assert result is not None
+
+
+def test_huggingface_dataset_workflow_full():
+    @task
+    def generate() -> full_schema:
+        df = pd.DataFrame({"col1": [1, 3, 2], "col2": list("abc")})
+        dataset = datasets.Dataset.from_pandas(df)
+        return StructuredDataset(dataframe=dataset)
+
+    @task
+    def consume(df: full_schema) -> full_schema:
+        dataset = df.open(datasets.Dataset).all()
+
+        assert dataset[0]["col1"] == 1
+        assert dataset[1]["col1"] == 3
+        assert dataset[2]["col1"] == 2
+        assert dataset[0]["col2"] == "a"
+        assert dataset[1]["col2"] == "b"
+        assert dataset[2]["col2"] == "c"
+
+        return StructuredDataset(dataframe=dataset)
+
+    @workflow
+    def wf() -> full_schema:
+        return consume(df=generate())
+
+    result = wf()
+    assert result is not None
+
+
+def test_datasets_renderer():
+    df = pd.DataFrame({"col1": [1, 3, 2], "col2": list("abc")})
+    dataset = datasets.Dataset.from_pandas(df)
+    assert HuggingFaceDatasetRenderer().to_html(dataset) == str(dataset).replace("\n", "<br>")