-
Notifications
You must be signed in to change notification settings - Fork 300
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Hugging Face Datasets Plugin (#1116)
Signed-off-by: Yee Hing Tong <[email protected]>
- Loading branch information
1 parent
495894d
commit 21ae290
Showing
8 changed files
with
444 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Flytekit Hugging Face Plugin | ||
[Hugging Face](https://github.com/huggingface) is a community and data science platform that provides tools for building, training, and deploying ML models based on open-source code and technologies. | ||
|
||
This plugin supports `datasets.Dataset` as a data type with [StructuredDataset](https://docs.flyte.org/projects/cookbook/en/latest/auto/core/type_system/structured_dataset.html). | ||
|
||
To install the plugin, run the following command: | ||
|
||
```bash | ||
pip install flytekitplugins-huggingface | ||
``` |
14 changes: 14 additions & 0 deletions
14
plugins/flytekit-huggingface/flytekitplugins/huggingface/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
""" | ||
.. currentmodule:: flytekitplugins.huggingface | ||
This package contains the StructuredDataset encoding and decoding handlers for Hugging Face ``datasets.Dataset`` objects. | ||
.. autosummary:: | ||
:template: custom.rst | ||
:toctree: generated/ | ||
HuggingFaceDatasetToParquetEncodingHandler | ||
ParquetToHuggingFaceDatasetDecodingHandler | ||
""" | ||
|
||
from .sd_transformers import HuggingFaceDatasetToParquetEncodingHandler, ParquetToHuggingFaceDatasetDecodingHandler |
72 changes: 72 additions & 0 deletions
72
plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import typing | ||
|
||
import datasets | ||
|
||
from flytekit import FlyteContext | ||
from flytekit.models import literals | ||
from flytekit.models.literals import StructuredDatasetMetadata | ||
from flytekit.models.types import StructuredDatasetType | ||
from flytekit.types.structured.structured_dataset import ( | ||
PARQUET, | ||
StructuredDataset, | ||
StructuredDatasetDecoder, | ||
StructuredDatasetEncoder, | ||
StructuredDatasetTransformerEngine, | ||
) | ||
|
||
|
||
class HuggingFaceDatasetRenderer:
    """
    Render a ``datasets.Dataset`` as HTML built from its printable representation.
    """

    def to_html(self, df: datasets.Dataset) -> str:
        """
        Convert ``str(df)`` into an HTML snippet.

        :param df: the dataset to render.
        :return: the text of ``str(df)`` with ``&``, ``<`` and ``>`` escaped and
            every newline replaced by a ``<br>`` tag.
        :raises AssertionError: if ``df`` is not a ``datasets.Dataset``.
        """
        import html

        assert isinstance(df, datasets.Dataset)
        # The printable representation may contain markup characters (for example
        # in column names), so escape it before inserting our own tags. Quotes are
        # left as-is (quote=False) so representations without markup characters
        # are rendered exactly as before.
        escaped = html.escape(str(df), quote=False)
        return escaped.replace("\n", "<br>")
|
||
|
||
class HuggingFaceDatasetToParquetEncodingHandler(StructuredDatasetEncoder):
    """
    Encoder that writes a ``datasets.Dataset`` to Parquet and uploads the result.
    """

    def __init__(self):
        # Handles datasets.Dataset values in the PARQUET format.
        super().__init__(datasets.Dataset, None, PARQUET)

    def encode(
        self,
        ctx: FlyteContext,
        structured_dataset: StructuredDataset,
        structured_dataset_type: StructuredDatasetType,
    ) -> literals.StructuredDataset:
        """
        Write the wrapped dataset to a local Parquet file and upload its directory.

        :param ctx: context whose ``file_access`` provides the local/remote directories and the upload.
        :param structured_dataset: wrapper holding the dataset (``dataframe``) and an optional target ``uri``.
        :param structured_dataset_type: type information attached to the returned literal's metadata.
        :return: a literal pointing at the remote directory the Parquet file was uploaded to.
        """
        dataset = typing.cast(datasets.Dataset, structured_dataset.dataframe)

        # Stage the Parquet file in a fresh local directory under the name "00000".
        staging_dir = ctx.file_access.get_random_local_directory()
        dataset.to_parquet(f"{staging_dir}/00000")

        # Use the caller-supplied URI when present; otherwise ask for a random remote directory.
        target_uri = typing.cast(str, structured_dataset.uri)
        if not target_uri:
            target_uri = ctx.file_access.get_random_remote_directory()
        ctx.file_access.upload_directory(staging_dir, target_uri)

        metadata = StructuredDatasetMetadata(structured_dataset_type)
        return literals.StructuredDataset(uri=target_uri, metadata=metadata)
|
||
|
||
class ParquetToHuggingFaceDatasetDecodingHandler(StructuredDatasetDecoder):
    """
    Decoder that downloads Parquet data and loads it as a ``datasets.Dataset``.
    """

    def __init__(self):
        # Handles datasets.Dataset values in the PARQUET format.
        super().__init__(datasets.Dataset, None, PARQUET)

    def decode(
        self,
        ctx: FlyteContext,
        flyte_value: literals.StructuredDataset,
        current_task_metadata: StructuredDatasetMetadata,
    ) -> datasets.Dataset:
        """
        Download the data behind ``flyte_value.uri`` and read it as a dataset.

        :param ctx: context whose ``file_access`` provides the local directory and the download.
        :param flyte_value: literal whose ``uri`` points at the stored Parquet data.
        :param current_task_metadata: metadata whose ``structured_dataset_type.columns``, when
            non-empty, restricts the columns that are read.
        :return: the dataset read from the downloaded file named ``00000``.
        """
        download_dir = ctx.file_access.get_random_local_directory()
        ctx.file_access.get_data(flyte_value.uri, download_dir, is_multipart=True)
        parquet_file = f"{download_dir}/00000"

        # Only pass ``columns`` when the task metadata actually declares some.
        read_options = {}
        declared_type = current_task_metadata.structured_dataset_type
        if declared_type and declared_type.columns:
            read_options["columns"] = [column.name for column in declared_type.columns]
        return datasets.Dataset.from_parquet(parquet_file, **read_options)
|
||
|
||
# Register the encoder, the decoder and the HTML renderer when this module is imported.
for _handler_type in (
    HuggingFaceDatasetToParquetEncodingHandler,
    ParquetToHuggingFaceDatasetDecodingHandler,
):
    StructuredDatasetTransformerEngine.register(_handler_type())
StructuredDatasetTransformerEngine.register_renderer(datasets.Dataset, HuggingFaceDatasetRenderer())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
. | ||
-e file:.#egg=flytekitplugins-huggingface |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,238 @@ | ||
# | ||
# This file is autogenerated by pip-compile with python 3.9 | ||
# To update, run: | ||
# | ||
# pip-compile requirements.in | ||
# | ||
-e file:.#egg=flytekitplugins-huggingface | ||
# via -r requirements.in | ||
aiohttp==3.8.1 | ||
# via | ||
# datasets | ||
# fsspec | ||
aiosignal==1.2.0 | ||
# via aiohttp | ||
arrow==1.2.2 | ||
# via jinja2-time | ||
async-timeout==4.0.2 | ||
# via aiohttp | ||
attrs==21.4.0 | ||
# via aiohttp | ||
binaryornot==0.4.4 | ||
# via cookiecutter | ||
certifi==2022.6.15 | ||
# via requests | ||
cffi==1.15.1 | ||
# via cryptography | ||
chardet==5.0.0 | ||
# via binaryornot | ||
charset-normalizer==2.1.0 | ||
# via | ||
# aiohttp | ||
# requests | ||
click==8.1.3 | ||
# via | ||
# cookiecutter | ||
# flytekit | ||
cloudpickle==2.1.0 | ||
# via flytekit | ||
cookiecutter==2.1.1 | ||
# via flytekit | ||
croniter==1.3.5 | ||
# via flytekit | ||
cryptography==37.0.4 | ||
# via pyopenssl | ||
dataclasses-json==0.5.7 | ||
# via flytekit | ||
datasets==2.4.0 | ||
# via flytekitplugins-huggingface | ||
decorator==5.1.1 | ||
# via retry | ||
deprecated==1.2.13 | ||
# via flytekit | ||
dill==0.3.5.1 | ||
# via | ||
# datasets | ||
# multiprocess | ||
diskcache==5.4.0 | ||
# via flytekit | ||
docker==5.0.3 | ||
# via flytekit | ||
docker-image-py==0.1.12 | ||
# via flytekit | ||
docstring-parser==0.14.1 | ||
# via flytekit | ||
filelock==3.7.1 | ||
# via huggingface-hub | ||
flyteidl==1.1.9 | ||
# via flytekit | ||
flytekit==1.1.0 | ||
# via flytekitplugins-huggingface | ||
frozenlist==1.3.0 | ||
# via | ||
# aiohttp | ||
# aiosignal | ||
fsspec[http]==2022.7.0 | ||
# via datasets | ||
googleapis-common-protos==1.56.4 | ||
# via | ||
# flyteidl | ||
# grpcio-status | ||
grpcio==1.47.0 | ||
# via | ||
# flytekit | ||
# grpcio-status | ||
grpcio-status==1.47.0 | ||
# via flytekit | ||
huggingface-hub==0.8.1 | ||
# via datasets | ||
idna==3.3 | ||
# via | ||
# requests | ||
# yarl | ||
importlib-metadata==4.12.0 | ||
# via | ||
# flytekit | ||
# keyring | ||
jinja2==3.1.2 | ||
# via | ||
# cookiecutter | ||
# jinja2-time | ||
jinja2-time==0.2.0 | ||
# via cookiecutter | ||
keyring==23.7.0 | ||
# via flytekit | ||
markupsafe==2.1.1 | ||
# via jinja2 | ||
marshmallow==3.17.0 | ||
# via | ||
# dataclasses-json | ||
# marshmallow-enum | ||
# marshmallow-jsonschema | ||
marshmallow-enum==1.5.1 | ||
# via dataclasses-json | ||
marshmallow-jsonschema==0.13.0 | ||
# via flytekit | ||
multidict==6.0.2 | ||
# via | ||
# aiohttp | ||
# yarl | ||
multiprocess==0.70.13 | ||
# via datasets | ||
mypy-extensions==0.4.3 | ||
# via typing-inspect | ||
natsort==8.1.0 | ||
# via flytekit | ||
numpy==1.23.1 | ||
# via | ||
# datasets | ||
# pandas | ||
# pyarrow | ||
packaging==21.3 | ||
# via | ||
# datasets | ||
# huggingface-hub | ||
# marshmallow | ||
pandas==1.4.3 | ||
# via | ||
# datasets | ||
# flytekit | ||
protobuf==3.20.1 | ||
# via | ||
# flyteidl | ||
# flytekit | ||
# googleapis-common-protos | ||
# grpcio-status | ||
# protoc-gen-swagger | ||
protoc-gen-swagger==0.1.0 | ||
# via flyteidl | ||
py==1.11.0 | ||
# via retry | ||
pyarrow==6.0.1 | ||
# via | ||
# datasets | ||
# flytekit | ||
pycparser==2.21 | ||
# via cffi | ||
pyopenssl==22.0.0 | ||
# via flytekit | ||
pyparsing==3.0.9 | ||
# via packaging | ||
python-dateutil==2.8.2 | ||
# via | ||
# arrow | ||
# croniter | ||
# flytekit | ||
# pandas | ||
python-json-logger==2.0.4 | ||
# via flytekit | ||
python-slugify==6.1.2 | ||
# via cookiecutter | ||
pytimeparse==1.1.8 | ||
# via flytekit | ||
pytz==2022.1 | ||
# via | ||
# flytekit | ||
# pandas | ||
pyyaml==6.0 | ||
# via | ||
# cookiecutter | ||
# flytekit | ||
# huggingface-hub | ||
regex==2022.7.25 | ||
# via docker-image-py | ||
requests==2.28.1 | ||
# via | ||
# cookiecutter | ||
# datasets | ||
# docker | ||
# flytekit | ||
# fsspec | ||
# huggingface-hub | ||
# responses | ||
responses==0.18.0 | ||
# via | ||
# datasets | ||
# flytekit | ||
retry==0.9.2 | ||
# via flytekit | ||
six==1.16.0 | ||
# via | ||
# grpcio | ||
# python-dateutil | ||
sortedcontainers==2.4.0 | ||
# via flytekit | ||
statsd==3.3.0 | ||
# via flytekit | ||
text-unidecode==1.3 | ||
# via python-slugify | ||
tqdm==4.64.0 | ||
# via | ||
# datasets | ||
# huggingface-hub | ||
typing-extensions==4.3.0 | ||
# via | ||
# flytekit | ||
# huggingface-hub | ||
# typing-inspect | ||
typing-inspect==0.7.1 | ||
# via dataclasses-json | ||
urllib3==1.26.11 | ||
# via | ||
# flytekit | ||
# requests | ||
# responses | ||
websocket-client==1.3.3 | ||
# via docker | ||
wheel==0.37.1 | ||
# via flytekit | ||
wrapt==1.14.1 | ||
# via | ||
# deprecated | ||
# flytekit | ||
xxhash==3.0.0 | ||
# via datasets | ||
yarl==1.7.2 | ||
# via aiohttp | ||
zipp==3.8.1 | ||
# via importlib-metadata |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from setuptools import setup | ||
|
||
# Packaging metadata for the Hugging Face flytekit plugin.
PLUGIN_NAME = "huggingface"

# Distribution name as published, e.g. ``pip install flytekitplugins-huggingface``.
microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

# Runtime dependencies of the plugin.
plugin_requires = ["flytekit>=1.1.0b0,<1.2.0", "datasets>=2.4.0"]

__version__ = "0.0.0+develop"

# Trove classifiers passed to ``setup`` below.
_CLASSIFIERS = [
    "Intended Audience :: Science/Research",
    "Intended Audience :: Developers",
    "License :: OSI Approved :: Apache Software License",
    "Programming Language :: Python :: 3.7",
    "Programming Language :: Python :: 3.8",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Topic :: Scientific/Engineering",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "Topic :: Software Development",
    "Topic :: Software Development :: Libraries",
    "Topic :: Software Development :: Libraries :: Python Modules",
]

setup(
    name=microlib_name,
    version=__version__,
    author="Evan Sadler",
    description="Hugging Face plugin for flytekit",
    license="apache2",
    python_requires=">=3.7",
    install_requires=plugin_requires,
    namespace_packages=["flytekitplugins"],
    packages=[f"flytekitplugins.{PLUGIN_NAME}"],
    classifiers=_CLASSIFIERS,
)
Empty file.
Oops, something went wrong.