Skip to content

Commit

Permalink
refactor: pr comments
Browse files Browse the repository at this point in the history
Signed-off-by: Tushar Mittal <[email protected]>
  • Loading branch information
techytushar authored and Tushar Mittal committed Oct 31, 2022
1 parent b3229cd commit cb2cf77
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 26 deletions.
4 changes: 2 additions & 2 deletions flytekit/extras/tensorflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

if _tensorflow_installed:
from .layer import TensorFlowLayerTransformer
from .model import TensorFlowModelTransformer
from .model import TensorFlowModelTransformer, TensorFlowSequentialTransformer
else:
logger.info(
"Unable to register TensorFlowModelTransformer, TensorFlowLayerTransformer because tensorflow is not installed."
"Unable to register TensorFlowModelTransformer, TensorFlowSequentialTransformer, TensorFlowLayerTransformer because tensorflow is not installed."
)
2 changes: 1 addition & 1 deletion flytekit/extras/tensorflow/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def to_python_value(
return tf.keras.layers.deserialize(layer_config)

def guess_python_type(self, literal_type: LiteralType) -> Type[tf.keras.layers.Layer]:
if literal_type.simple == literal_type.simple.STRUCT:
if literal_type is not None and literal_type.simple == SimpleType.STRING:
return tf.keras.layers.Layer

raise ValueError(f"Transformer {self} cannot reverse {literal_type}")
Expand Down
44 changes: 26 additions & 18 deletions flytekit/extras/tensorflow/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pathlib
from typing import Type
from typing import Generic, Type, TypeVar

import tensorflow as tf

Expand All @@ -9,14 +9,13 @@
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType

T = TypeVar("T")

class TensorFlowModelTransformer(TypeTransformer[tf.keras.Model]):
TENSORFLOW_FORMAT = "TensorflowModel"

def __init__(self):
super().__init__(name="TensorFlow Model", t=tf.keras.Model)
class TensorFlowModelTransformerBase(TypeTransformer, Generic[T]):
TENSORFLOW_FORMAT = "TensorFlowModel"

def get_literal_type(self, t: Type[tf.keras.Model]) -> LiteralType:
def get_literal_type(self, t: Type[T]) -> LiteralType:
return LiteralType(
blob=_core_types.BlobType(
format=self.TENSORFLOW_FORMAT,
Expand All @@ -27,8 +26,8 @@ def get_literal_type(self, t: Type[tf.keras.Model]) -> LiteralType:
def to_literal(
self,
ctx: FlyteContext,
python_val: tf.keras.Model,
python_type: Type[tf.keras.Model],
python_val: T,
python_type: Type[T],
expected: LiteralType,
) -> Literal:
meta = BlobMetadata(
Expand All @@ -38,39 +37,48 @@ def to_literal(
)
)

local_path = ctx.file_access.get_random_local_path() + ".h5"
local_path = ctx.file_access.get_random_local_path()
pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)

# save tensorflow model in h5 format
# save tensorflow model in SavedModel format
tf.keras.models.save_model(python_val, local_path)

remote_path = ctx.file_access.get_random_remote_path(local_path)
ctx.file_access.put_data(local_path, remote_path, is_multipart=False)
ctx.file_access.put_data(local_path, remote_path, is_multipart=True)
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))

def to_python_value(
self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[tf.keras.Model]
) -> tf.keras.Model:
def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
try:
uri = lv.scalar.blob.uri
except AttributeError:
TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")

local_path = ctx.file_access.get_random_local_path()
ctx.file_access.get_data(uri, local_path, is_multipart=False)
ctx.file_access.get_data(uri, local_path, is_multipart=True)

# load tensorflow model from the h5 format
# load tensorflow model from the SavedModel format
return tf.keras.models.load_model(local_path)

def guess_python_type(self, literal_type: LiteralType) -> Type[tf.keras.Model]:
def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
if (
literal_type.blob is not None
and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE
and literal_type.blob.format == self.TENSORFLOW_FORMAT
):
return tf.keras.Model
return T

raise ValueError(f"Transformer {self} cannot reverse {literal_type}")


class TensorFlowModelTransformer(TensorFlowModelTransformerBase[tf.keras.Model]):
def __init__(self):
super().__init__(name="TensorFlow Model", t=tf.keras.Model)


class TensorFlowSequentialTransformer(TensorFlowModelTransformerBase[tf.keras.Sequential]):
def __init__(self):
super().__init__(name="TensorFlow Sequential Model", t=tf.keras.Sequential)


TypeEngine.register(TensorFlowModelTransformer())
TypeEngine.register(TensorFlowSequentialTransformer())
5 changes: 2 additions & 3 deletions tests/flytekit/unit/extras/tensorflow/test_layer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List
from typing import List

import tensorflow as tf

Expand Down Expand Up @@ -31,11 +31,10 @@ def generate_sequential_model() -> List[tf.keras.layers.Layer]:


@task
def get_layers_weights(layers: List[tf.keras.layers.Layer]) -> List[Any]:
def get_layers_weights(layers: List[tf.keras.layers.Layer]) -> List[tf.Variable]:
return layers[-1].weights


@workflow
def wf():
dense_layer = get_layer()
layers = generate_sequential_model()
Expand Down
3 changes: 1 addition & 2 deletions tests/flytekit/unit/extras/tensorflow/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def generate_model() -> tf.keras.Model:


@task
def generate_sequential_model() -> tf.keras.Model:
def generate_sequential_model() -> tf.keras.Sequential:
model = tf.keras.Sequential(
[
tf.keras.layers.Input(shape=(32,)),
Expand All @@ -42,7 +42,6 @@ def model_forward_pass(model: tf.keras.Model) -> tf.Tensor:
return model(x)


@workflow
def wf():
model1 = generate_model()
model2 = generate_sequential_model()
Expand Down

0 comments on commit cb2cf77

Please sign in to comment.