Skip to content

Commit

Permalink
refactor: move layer transformer into separate file
Browse files Browse the repository at this point in the history
Signed-off-by: Tushar Mittal <[email protected]>
  • Loading branch information
techytushar authored and Tushar Mittal committed Oct 28, 2022
1 parent 7fbe8df commit cc95068
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 25 deletions.
5 changes: 4 additions & 1 deletion flytekit/extras/tensorflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@


if _tensorflow_installed:
from .layer import TensorflowLayerTransformer
from .model import TensorflowModelTransformer
else:
logger.info("Unable to register TensorflowModelTransformer because tensorflow is not installed.")
logger.info(
"Unable to register TensorflowModelTransformer, TensorflowLayerTransformer because tensorflow is not installed."
)
48 changes: 48 additions & 0 deletions flytekit/extras/tensorflow/layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import json as _json
from typing import Type

import tensorflow as tf

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.models.literals import Literal, Primitive, Scalar
from flytekit.models.types import LiteralType, SimpleType


class TensorflowLayerTransformer(TypeTransformer[tf.keras.layers.Layer]):
def __init__(
self,
):
super().__init__(name="Tensorflow Layer", t=tf.keras.layers.Layer)

def get_literal_type(self, t: Type[tf.keras.layers.Layer]) -> LiteralType:
return LiteralType(simple=SimpleType.STRING)

def to_literal(
self,
ctx: FlyteContext,
python_val: tf.keras.layers.Layer,
python_type: Type[tf.keras.layers.Layer],
expected: LiteralType,
) -> Literal:
layer_config = tf.keras.layers.serialize(python_val)

return Literal(Scalar(primitive=Primitive(string_value=_json.dumps(layer_config))))

def to_python_value(
self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[tf.keras.layers.Layer]
) -> tf.keras.layers.Layer:
if not (lv and lv.scalar and lv.scalar.primitive):
raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")

layer_config = _json.loads(lv.scalar.primitive.string_value)
return tf.keras.layers.deserialize(layer_config)

def guess_python_type(self, literal_type: LiteralType) -> Type[tf.keras.layers.Layer]:
if literal_type.simple == literal_type.simple.STRUCT:
return tf.keras.layers.Layer

raise ValueError(f"Transformer {self} cannot reverse {literal_type}")


TypeEngine.register(TensorflowLayerTransformer())
4 changes: 2 additions & 2 deletions flytekit/extras/tensorflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def to_literal(
local_path = ctx.file_access.get_random_local_path() + ".h5"
pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)

# save tensorflow model to a folder in SavedModel format
# save tensorflow model in h5 format
tf.keras.models.save_model(python_val, local_path)

remote_path = ctx.file_access.get_random_remote_path(local_path)
Expand All @@ -61,7 +61,7 @@ def to_python_value(
local_path = ctx.file_access.get_random_local_path()
ctx.file_access.get_data(uri, local_path, is_multipart=False)

# load tensorflow model from the SavedModel folder
# load tensorflow model from the h5 format
return tf.keras.models.load_model(local_path)

def guess_python_type(self, literal_type: LiteralType) -> Type[tf.keras.Model]:
Expand Down
48 changes: 48 additions & 0 deletions tests/flytekit/unit/extras/tensorflow/test_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from typing import Any, List

import tensorflow as tf

from flytekit import task, workflow


@task
def get_layer() -> tf.keras.layers.Dense:
layer = tf.keras.layers.Dense(10)
layer(tf.ones((10, 1)))
return layer


@task
def generate_sequential_model() -> List[tf.keras.layers.Layer]:
model = tf.keras.Sequential(
[
tf.keras.layers.Input(shape=(32,)),
tf.keras.layers.Dense(1),
]
)
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
loss=tf.keras.losses.BinaryCrossentropy(),
metrics=[
tf.keras.metrics.BinaryAccuracy(),
],
)
return model.layers


@task
def get_layers_weights(layers: List[tf.keras.layers.Layer]) -> List[Any]:
return layers[-1].weights


@workflow
def wf():
dense_layer = get_layer()
layers = generate_sequential_model()
get_layers_weights(layers=[dense_layer])
get_layers_weights(layers=layers)


@workflow
def test_wf():
wf()
11 changes: 5 additions & 6 deletions tests/flytekit/unit/extras/tensorflow/test_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import List

import tensorflow as tf

from flytekit import task, workflow
Expand Down Expand Up @@ -39,16 +37,17 @@ def generate_sequential_model() -> tf.keras.Model:


@task
def get_model_layers(model: tf.keras.Model) -> List[tf.keras.layers.Layer]:
return model.layers
def model_forward_pass(model: tf.keras.Model) -> tf.Tensor:
x: tf.Tensor = tf.ones((1, 32))
return model(x)


@workflow
def wf():
model1 = generate_model()
model2 = generate_sequential_model()
get_model_layers(model=model1)
get_model_layers(model=model2)
model_forward_pass(model=model1)
model_forward_pass(model=model2)


@workflow
Expand Down
50 changes: 34 additions & 16 deletions tests/flytekit/unit/extras/tensorflow/test_transformations.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from collections import OrderedDict

import numpy as np
Expand All @@ -8,10 +9,10 @@
from flytekit import task
from flytekit.configuration import Image, ImageConfig
from flytekit.core import context_manager
from flytekit.extras.tensorflow import TensorflowModelTransformer
from flytekit.extras.tensorflow import TensorflowLayerTransformer, TensorflowModelTransformer
from flytekit.models.core.types import BlobType
from flytekit.models.literals import BlobMetadata
from flytekit.models.types import LiteralType
from flytekit.models.types import LiteralType, SimpleType
from flytekit.tools.translator import get_serializable

default_img = Image(name="default", fqn="test", tag="tag")
Expand All @@ -35,40 +36,48 @@ def get_tf_model():
"transformer,python_type,format",
[
(TensorflowModelTransformer(), tf.keras.Model, TensorflowModelTransformer.TENSORFLOW_FORMAT),
(TensorflowLayerTransformer(), tf.keras.layers.Layer, None),
],
)
def test_get_literal_type(transformer, python_type, format):
lt = transformer.get_literal_type(python_type)
assert lt == LiteralType(blob=BlobType(format=format, dimensionality=BlobType.BlobDimensionality.SINGLE))
if isinstance(python_type, tf.keras.Model):
assert lt == LiteralType(blob=BlobType(format=format, dimensionality=BlobType.BlobDimensionality.SINGLE))
elif isinstance(python_type, tf.keras.layers.Layer):
assert lt == LiteralType(simple=SimpleType.STRING)


@pytest.mark.parametrize(
"transformer,python_type,format,python_val",
[
(TensorflowModelTransformer(), tf.keras.Model, TensorflowModelTransformer.TENSORFLOW_FORMAT, get_tf_model()),
(TensorflowLayerTransformer(), tf.keras.layers.Layer, None, tf.keras.layers.Dense(4)),
(TensorflowLayerTransformer(), tf.keras.layers.Layer, None, tf.keras.layers.Conv1D(8, 1, activation="relu")),
(TensorflowLayerTransformer(), tf.keras.layers.Layer, None, tf.keras.layers.Softmax()),
],
)
def test_to_python_value_and_literal(transformer, python_type, format, python_val):
ctx = context_manager.FlyteContext.current_context()
python_val = python_val
lt = transformer.get_literal_type(python_type)

lv = transformer.to_literal(ctx, python_val, type(python_val), lt) # type: ignore
assert lv.scalar.blob.metadata == BlobMetadata(
type=BlobType(
format=format,
dimensionality=BlobType.BlobDimensionality.SINGLE,
)
)
assert lv.scalar.blob.uri is not None

output = transformer.to_python_value(ctx, lv, python_type)
if isinstance(python_val, tf.keras.Model):

if isinstance(python_type, tf.keras.Model):
assert lv.scalar.blob.metadata == BlobMetadata(
type=BlobType(
format=format,
dimensionality=BlobType.BlobDimensionality.SINGLE,
)
)
assert lv.scalar.blob.uri is not None
for w1, w2 in zip(output.weights, python_val.weights):
np.testing.assert_allclose(w1.numpy(), w2.numpy())
assert True
else:
assert isinstance(output, dict)

elif isinstance(python_type, tf.keras.layers.Layer):
assert bool(lv.scalar.primitive.string_value)
json.loads(lv.scalar.primitive.string_value)
assert output.get_config() == python_val.get_config()


def test_example_model():
Expand All @@ -78,3 +87,12 @@ def t1() -> tf.keras.Model:

task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
assert task_spec.template.interface.outputs["o0"].type.blob.format is TensorflowModelTransformer.TENSORFLOW_FORMAT


def test_example_layer():
@task
def t1() -> tf.keras.layers.Layer:
return tf.keras.layers.Dense(10)

task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
assert task_spec.template.interface.outputs["o0"].type == LiteralType(simple=SimpleType.STRING)

0 comments on commit cc95068

Please sign in to comment.