Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add TypeTransformer for Tensorflow Model #1241

Closed
wants to merge 12 commits into from
Closed
1 change: 1 addition & 0 deletions dev-requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ google-cloud-bigquery
google-cloud-bigquery-storage
IPython
torch
tensorflow
93 changes: 78 additions & 15 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,18 @@
# via
# -c requirements.txt
# pytest-flyte
absl-py==1.3.0
# via
# tensorboard
# tensorflow
appnope==0.1.3
# via ipython
arrow==1.2.3
# via
# -c requirements.txt
# jinja2-time
astunparse==1.6.3
# via tensorflow
attrs==20.3.0
# via
# -c requirements.txt
Expand Down Expand Up @@ -77,7 +85,6 @@ cryptography==38.0.1
# -c requirements.txt
# paramiko
# pyopenssl
# secretstorage
dataclasses-json==0.5.7
# via
# -c requirements.txt
Expand Down Expand Up @@ -118,12 +125,18 @@ docstring-parser==0.15
# via
# -c requirements.txt
# flytekit
exceptiongroup==1.0.0rc9
# via pytest
filelock==3.8.0
# via virtualenv
flatbuffers==22.9.24
# via tensorflow
flyteidl==1.1.22
# via
# -c requirements.txt
# flytekit
gast==0.4.0
# via tensorflow
google-api-core[grpc]==2.10.2
# via
# google-cloud-bigquery
Expand All @@ -132,7 +145,11 @@ google-api-core[grpc]==2.10.2
google-auth==2.13.0
# via
# google-api-core
# google-auth-oauthlib
# google-cloud-core
# tensorboard
google-auth-oauthlib==0.4.6
# via tensorboard
google-cloud-bigquery==3.3.5
# via -r dev-requirements.in
google-cloud-bigquery-storage==2.16.2
Expand All @@ -143,6 +160,8 @@ google-cloud-core==2.3.2
# via google-cloud-bigquery
google-crc32c==1.5.0
# via google-resumable-media
google-pasta==0.2.0
# via tensorflow
google-resumable-media==2.4.0
# via google-cloud-bigquery
googleapis-common-protos==1.56.4
Expand All @@ -158,12 +177,16 @@ grpcio==1.47.0
# google-api-core
# google-cloud-bigquery
# grpcio-status
# tensorboard
# tensorflow
grpcio-status==1.47.0
# via
# -c requirements.txt
# flytekit
# google-api-core
identify==2.5.6
h5py==3.7.0
# via tensorflow
identify==2.5.7
# via pre-commit
idna==3.4
# via
Expand All @@ -176,6 +199,7 @@ importlib-metadata==5.0.0
# flytekit
# jsonschema
# keyring
# markdown
# pluggy
# pre-commit
# pytest
Expand All @@ -190,11 +214,6 @@ jaraco-classes==3.2.3
# keyring
jedi==0.18.1
# via ipython
jeepney==0.8.0
# via
# -c requirements.txt
# keyring
# secretstorage
jinja2==3.1.2
# via
# -c requirements.txt
Expand All @@ -214,14 +233,23 @@ jsonschema==3.2.0
# via
# -c requirements.txt
# docker-compose
keras==2.10.0
# via tensorflow
keras-preprocessing==1.1.2
# via tensorflow
keyring==23.9.3
# via
# -c requirements.txt
# flytekit
libclang==14.0.6
# via tensorflow
markdown==3.4.1
# via tensorboard
markupsafe==2.1.1
# via
# -c requirements.txt
# jinja2
# werkzeug
marshmallow==3.18.0
# via
# -c requirements.txt
Expand Down Expand Up @@ -261,15 +289,25 @@ numpy==1.21.6
# via
# -c requirements.txt
# flytekit
# h5py
# keras-preprocessing
# opt-einsum
# pandas
# pyarrow
# tensorboard
# tensorflow
oauthlib==3.2.2
# via requests-oauthlib
opt-einsum==3.3.0
# via tensorflow
packaging==21.3
# via
# -c requirements.txt
# docker
# google-cloud-bigquery
# marshmallow
# pytest
# tensorflow
pandas==1.3.5
# via
# -c requirements.txt
Expand All @@ -294,7 +332,7 @@ proto-plus==1.22.1
# via
# google-cloud-bigquery
# google-cloud-bigquery-storage
protobuf==3.20.3
protobuf==3.19.6
# via
# -c requirements.txt
# flyteidl
Expand All @@ -306,6 +344,8 @@ protobuf==3.20.3
# grpcio-status
# proto-plus
# protoc-gen-swagger
# tensorboard
# tensorflow
protoc-gen-swagger==0.1.0
# via
# -c requirements.txt
Expand All @@ -315,7 +355,6 @@ ptyprocess==0.7.0
py==1.11.0
# via
# -c requirements.txt
# pytest
# retry
pyarrow==6.0.1
# via
Expand Down Expand Up @@ -348,7 +387,7 @@ pyrsistent==0.18.1
# via
# -c requirements.txt
# jsonschema
pytest==7.1.3
pytest==7.2.0
# via
# -r dev-requirements.in
# pytest-cov
Expand Down Expand Up @@ -407,7 +446,11 @@ requests==2.28.1
# flytekit
# google-api-core
# google-cloud-bigquery
# requests-oauthlib
# responses
# tensorboard
requests-oauthlib==1.3.1
# via google-auth-oauthlib
responses==0.22.0
# via
# -c requirements.txt
Expand All @@ -418,23 +461,23 @@ retry==0.9.2
# flytekit
rsa==4.9
# via google-auth
secretstorage==3.3.3
# via
# -c requirements.txt
# keyring
singledispatchmethod==1.0
# via
# -c requirements.txt
# flytekit
six==1.16.0
# via
# -c requirements.txt
# astunparse
# dockerpty
# google-auth
# google-pasta
# grpcio
# jsonschema
# keras-preprocessing
# paramiko
# python-dateutil
# tensorflow
# websocket-client
sortedcontainers==2.4.0
# via
Expand All @@ -444,6 +487,20 @@ statsd==3.3.0
# via
# -c requirements.txt
# flytekit
tensorboard==2.10.1
# via tensorflow
tensorboard-data-server==0.6.1
# via tensorboard
tensorboard-plugin-wit==1.8.1
# via tensorboard
tensorflow==2.10.0
# via -r dev-requirements.in
tensorflow-estimator==2.10.0
# via tensorflow
tensorflow-io-gcs-filesystem==0.27.0
# via tensorflow
termcolor==2.0.1
# via tensorflow
text-unidecode==1.3
# via
# -c requirements.txt
Expand Down Expand Up @@ -480,6 +537,7 @@ typing-extensions==4.4.0
# importlib-metadata
# mypy
# responses
# tensorflow
# torch
# typing-inspect
typing-inspect==0.8.0
Expand All @@ -493,7 +551,7 @@ urllib3==1.26.12
# flytekit
# requests
# responses
virtualenv==20.16.5
virtualenv==20.16.6
# via pre-commit
wcwidth==0.2.5
# via prompt-toolkit
Expand All @@ -502,15 +560,20 @@ websocket-client==0.59.0
# -c requirements.txt
# docker
# docker-compose
werkzeug==2.2.2
# via tensorboard
wheel==0.37.1
# via
# -c requirements.txt
# astunparse
# flytekit
# tensorboard
wrapt==1.14.1
# via
# -c requirements.txt
# deprecated
# flytekit
# tensorflow
zipp==3.9.0
# via
# -c requirements.txt
Expand Down
2 changes: 1 addition & 1 deletion flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@
from flytekit.core.workflow import ImperativeWorkflow as Workflow
from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow
from flytekit.deck import Deck
from flytekit.extras import pytorch
from flytekit.extras import pytorch, tensorflow
from flytekit.extras.persistence import GCSPersistence, HttpPersistence, S3Persistence
from flytekit.loggers import logger
from flytekit.models.common import Annotations, AuthRole, Labels
Expand Down
26 changes: 26 additions & 0 deletions flytekit/extras/tensorflow/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
Flytekit TensorFlow
=========================================
.. currentmodule:: flytekit.extras.tensorflow

"""
from flytekit.loggers import logger

# TODO: abstract this out so that there's an established pattern for registering plugins
# that have soft dependencies
try:
# isolate the exception to the tensorflow import
import tensorflow

_tensorflow_installed = True
except (ImportError, OSError):
_tensorflow_installed = False


if _tensorflow_installed:
from .layer import TensorFlowLayerTransformer
from .model import TensorFlowModelTransformer
else:
logger.info(
"Unable to register TensorFlowModelTransformer, TensorFlowLayerTransformer because tensorflow is not installed."
)
46 changes: 46 additions & 0 deletions flytekit/extras/tensorflow/layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import json
from typing import Type

import tensorflow as tf

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.models.literals import Literal, Primitive, Scalar
from flytekit.models.types import LiteralType, SimpleType


class TensorFlowLayerTransformer(TypeTransformer[tf.keras.layers.Layer]):
def __init__(self):
super().__init__(name="TensorFlow Layer", t=tf.keras.layers.Layer)

def get_literal_type(self, t: Type[tf.keras.layers.Layer]) -> LiteralType:
return LiteralType(simple=SimpleType.STRING)

def to_literal(
self,
ctx: FlyteContext,
python_val: tf.keras.layers.Layer,
python_type: Type[tf.keras.layers.Layer],
expected: LiteralType,
) -> Literal:
layer_config = tf.keras.layers.serialize(python_val)

return Literal(Scalar(primitive=Primitive(string_value=json.dumps(layer_config))))

def to_python_value(
self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[tf.keras.layers.Layer]
) -> tf.keras.layers.Layer:
if not (lv and lv.scalar and lv.scalar.primitive):
raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")

layer_config = json.loads(lv.scalar.primitive.string_value)
return tf.keras.layers.deserialize(layer_config)

def guess_python_type(self, literal_type: LiteralType) -> Type[tf.keras.layers.Layer]:
if literal_type.simple == literal_type.simple.STRUCT:
techytushar marked this conversation as resolved.
Show resolved Hide resolved
return tf.keras.layers.Layer

raise ValueError(f"Transformer {self} cannot reverse {literal_type}")


TypeEngine.register(TensorFlowLayerTransformer())
Loading