Skip to content

Commit

Permalink
TypeTransformer for reading and writing from TensorFlowRecord format (#…
Browse files Browse the repository at this point in the history
…1240)

* first commit

Signed-off-by: Ryan Nazareth <[email protected]>

* add tensorflow example tf record transformer

Signed-off-by: Ryan Nazareth <[email protected]>

* refactor

Signed-off-by: Ryan Nazareth <[email protected]>

* correct tfexample description

Signed-off-by: Ryan Nazareth <[email protected]>

* fix test_native.py

Signed-off-by: Ryan Nazareth <[email protected]>

* add tensorflow docs and reqs

Signed-off-by: Ryan Nazareth <[email protected]>

* add tensorflow docs and reqs1

Signed-off-by: Ryan Nazareth <[email protected]>

* tensorflow import in init

Signed-off-by: Ryan Nazareth <[email protected]>

* fix failing tests

Signed-off-by: Ryan Nazareth <[email protected]>

* add tensorflow pinned version to reqs

Signed-off-by: Ryan Nazareth <[email protected]>

* pin grpcio-status to remove protobuf error

Signed-off-by: Ryan Nazareth <[email protected]>

* add suggested changes

Signed-off-by: Ryan Nazareth <[email protected]>

* redesign transformer

Signed-off-by: Ryan Nazareth <[email protected]>

* remove old script

Signed-off-by: Ryan Nazareth <[email protected]>

* fix type reference for TFREcordDataset

Signed-off-by: Ryan Nazareth <[email protected]>

* refactor

Signed-off-by: Ryan Nazareth <[email protected]>

* refactor

Signed-off-by: Ryan Nazareth <[email protected]>

* spacing and uppercase

Signed-off-by: Ryan Nazareth <[email protected]>

* redesign with tfdir and tfrecordfile subclass

Signed-off-by: Ryan Nazareth <[email protected]>

* fix conflicts and typos

Signed-off-by: Ryan Nazareth <[email protected]>

* address majority of comments

Signed-off-by: Ryan Nazareth <[email protected]>

* refactor

Signed-off-by: Ryan Nazareth <[email protected]>

* fix test with flytefile and metadata annotated

Signed-off-by: Ryan Nazareth <[email protected]>

* fix check for example records in directory

Signed-off-by: Ryan Nazareth <[email protected]>

* refactor and correct typing

Signed-off-by: Ryan Nazareth <[email protected]>

* lint

Signed-off-by: Ryan Nazareth <[email protected]>

* import annotated from typing_extensions

Signed-off-by: Ryan Nazareth <[email protected]>

* tweak to tests to test case when Config not passed in as type

Signed-off-by: Ryan Nazareth <[email protected]>

* add suggested changes

Signed-off-by: Ryan Nazareth <[email protected]>

* add task for tfrecord dir with no config in test

Signed-off-by: Ryan Nazareth <[email protected]>

* get filenames from local dir instead of remote

Signed-off-by: Ryan Nazareth <[email protected]>

Signed-off-by: Ryan Nazareth <[email protected]>
  • Loading branch information
ryankarlos authored Dec 6, 2022
1 parent bc29749 commit 467a137
Show file tree
Hide file tree
Showing 13 changed files with 517 additions and 14 deletions.
2 changes: 2 additions & 0 deletions dev-requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ codespell
google-cloud-bigquery
google-cloud-bigquery-storage
IPython
tensorflow
grpcio-status<1.49.0
# Newer versions of torch bring in nvidia dependencies that are not present in windows, so
# we put this constraint while we do not have per-environment requirements files
torch<=1.12.1
Expand Down
80 changes: 67 additions & 13 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# This file is autogenerated by pip-compile with python 3.7
# This file is autogenerated by pip-compile with python 3.9
# To update, run:
#
# make dev-requirements.txt
Expand All @@ -8,12 +8,18 @@
# via
# -c requirements.txt
# pytest-flyte
absl-py==1.3.0
# via
# tensorboard
# tensorflow
appnope==0.1.3
# via ipython
arrow==1.2.3
# via
# -c requirements.txt
# jinja2-time
astunparse==1.6.3
# via tensorflow
attrs==20.3.0
# via
# -c requirements.txt
Expand All @@ -28,8 +34,6 @@ binaryornot==0.4.4
# via
# -c requirements.txt
# cookiecutter
cached-property==1.5.2
# via docker-compose
cachetools==5.2.0
# via google-auth
certifi==2022.9.24
Expand Down Expand Up @@ -127,6 +131,8 @@ flyteidl==1.2.5
# via
# -c requirements.txt
# flytekit
gast==0.5.3
# via tensorflow
google-api-core[grpc]==2.10.2
# via
# google-cloud-bigquery
Expand All @@ -135,7 +141,11 @@ google-api-core[grpc]==2.10.2
google-auth==2.14.1
# via
# google-api-core
# google-auth-oauthlib
# google-cloud-core
# tensorboard
google-auth-oauthlib==0.4.6
# via tensorboard
google-cloud-bigquery==3.4.0
# via -r dev-requirements.in
google-cloud-bigquery-storage==2.16.2
Expand All @@ -146,6 +156,8 @@ google-cloud-core==2.3.2
# via google-cloud-bigquery
google-crc32c==1.5.0
# via google-resumable-media
google-pasta==0.2.0
# via tensorflow
google-resumable-media==2.4.0
# via google-cloud-bigquery
googleapis-common-protos==1.57.0
Expand All @@ -161,11 +173,16 @@ grpcio==1.48.2
# google-api-core
# google-cloud-bigquery
# grpcio-status
# tensorboard
# tensorflow
grpcio-status==1.48.2
# via
# -c requirements.txt
# -r dev-requirements.in
# flytekit
# google-api-core
h5py==3.7.0
# via tensorflow
identify==2.5.9
# via pre-commit
idna==3.4
Expand All @@ -175,14 +192,9 @@ idna==3.4
importlib-metadata==5.0.0
# via
# -c requirements.txt
# click
# flytekit
# jsonschema
# keyring
# pluggy
# pre-commit
# pytest
# virtualenv
# markdown
iniconfig==1.1.1
# via pytest
ipython==7.34.0
Expand Down Expand Up @@ -213,14 +225,23 @@ jsonschema==3.2.0
# via
# -c requirements.txt
# docker-compose
keras==2.8.0
# via tensorflow
keras-preprocessing==1.1.2
# via tensorflow
keyring==23.11.0
# via
# -c requirements.txt
# flytekit
libclang==14.0.6
# via tensorflow
markdown==3.4.1
# via tensorboard
markupsafe==2.1.1
# via
# -c requirements.txt
# jinja2
# werkzeug
marshmallow==3.19.0
# via
# -c requirements.txt
Expand Down Expand Up @@ -259,9 +280,17 @@ nodeenv==1.7.0
numpy==1.21.6
# via
# -c requirements.txt
# flytekit
# h5py
# keras-preprocessing
# opt-einsum
# pandas
# pyarrow
# tensorboard
# tensorflow
oauthlib==3.2.2
# via requests-oauthlib
opt-einsum==3.3.0
# via tensorflow
# scikit-learn
# scipy
packaging==21.3
Expand Down Expand Up @@ -307,6 +336,8 @@ protobuf==3.20.3
# grpcio-status
# proto-plus
# protoc-gen-swagger
# tensorboard
# tensorflow
protoc-gen-swagger==0.1.0
# via
# -c requirements.txt
Expand Down Expand Up @@ -407,7 +438,11 @@ requests==2.28.1
# flytekit
# google-api-core
# google-cloud-bigquery
# requests-oauthlib
# responses
# tensorboard
requests-oauthlib==1.3.1
# via google-auth-oauthlib
responses==0.22.0
# via
# -c requirements.txt
Expand All @@ -429,12 +464,16 @@ singledispatchmethod==1.0
six==1.16.0
# via
# -c requirements.txt
# astunparse
# dockerpty
# google-auth
# google-pasta
# grpcio
# jsonschema
# keras-preprocessing
# paramiko
# python-dateutil
# tensorflow
# websocket-client
sortedcontainers==2.4.0
# via
Expand All @@ -444,6 +483,20 @@ statsd==3.3.0
# via
# -c requirements.txt
# flytekit
tensorboard==2.8.0
# via tensorflow
tensorboard-data-server==0.6.1
# via tensorboard
tensorboard-plugin-wit==1.8.1
# via tensorboard
tensorflow==2.8.1
# via -r dev-requirements.in
tensorflow-estimator==2.8.0
# via tensorflow
tensorflow-io-gcs-filesystem==0.27.0
# via tensorflow
termcolor==2.0.1
# via tensorflow
text-unidecode==1.3
# via
# -c requirements.txt
Expand Down Expand Up @@ -477,11 +530,9 @@ types-toml==0.10.8.1
typing-extensions==4.4.0
# via
# -c requirements.txt
# arrow
# flytekit
# importlib-metadata
# mypy
# responses
# tensorflow
# torch
# typing-inspect
typing-inspect==0.8.0
Expand All @@ -507,12 +558,15 @@ websocket-client==0.59.0
wheel==0.38.4
# via
# -c requirements.txt
# astunparse
# flytekit
# tensorboard
wrapt==1.14.1
# via
# -c requirements.txt
# deprecated
# flytekit
# tensorflow
zipp==3.10.0
# via
# -c requirements.txt
Expand Down
7 changes: 7 additions & 0 deletions docs/source/extras.tensorflow.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
############
TensorFlow Type
############
.. automodule:: flytekit.extras.tensorflow
:no-members:
:no-inherited-members:
:no-special-members:
1 change: 1 addition & 0 deletions docs/source/types.extend.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ Refer to the :ref:`extensibility contribution guide <cookbook:advanced_custom_ty
types.builtins.file
types.builtins.directory
extras.pytorch
extras.tensorflow
2 changes: 1 addition & 1 deletion flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@
from flytekit.core.workflow import ImperativeWorkflow as Workflow
from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow
from flytekit.deck import Deck
from flytekit.extras import pytorch
from flytekit.extras import pytorch, tensorflow
from flytekit.extras.persistence import GCSPersistence, HttpPersistence, S3Persistence
from flytekit.loggers import logger
from flytekit.models.common import Annotations, AuthRole, Labels
Expand Down
31 changes: 31 additions & 0 deletions flytekit/extras/tensorflow/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
Flytekit TensorFlow
=========================================
.. currentmodule:: flytekit.extras.tensorflow
.. autosummary::
:template: custom.rst
:toctree: generated/
TensorFlowRecord
"""
from flytekit.loggers import logger

# TODO: abstract this out so that there's an established pattern for registering plugins
# that have soft dependencies
try:
# isolate the exception to the tensorflow import
import tensorflow

_tensorflow_installed = True
except (ImportError, OSError):
_tensorflow_installed = False


if _tensorflow_installed:
from .record import TensorFlowRecordFileTransformer, TensorFlowRecordsDirTransformer
else:
logger.info(
"We won't register TensorFlowRecordFileTransformer and TensorFlowRecordsDirTransformer "
"because tensorflow is not installed."
)
Loading

0 comments on commit 467a137

Please sign in to comment.