diff --git a/flytekit/extras/tensorflow/__init__.py b/flytekit/extras/tensorflow/__init__.py index 6fb2bd7375..754adcf6a2 100644 --- a/flytekit/extras/tensorflow/__init__.py +++ b/flytekit/extras/tensorflow/__init__.py @@ -19,8 +19,8 @@ if _tensorflow_installed: from .layer import TensorFlowLayerTransformer - from .model import TensorFlowModelTransformer + from .model import TensorFlowModelTransformer, TensorFlowSequentialTransformer else: logger.info( - "Unable to register TensorFlowModelTransformer, TensorFlowLayerTransformer because tensorflow is not installed." + "Unable to register TensorFlowModelTransformer, TensorFlowSequentialTransformer, TensorFlowLayerTransformer because tensorflow is not installed." ) diff --git a/flytekit/extras/tensorflow/layer.py b/flytekit/extras/tensorflow/layer.py index 3e189dc073..03f922f41e 100644 --- a/flytekit/extras/tensorflow/layer.py +++ b/flytekit/extras/tensorflow/layer.py @@ -37,7 +37,7 @@ def to_python_value( return tf.keras.layers.deserialize(layer_config) def guess_python_type(self, literal_type: LiteralType) -> Type[tf.keras.layers.Layer]: - if literal_type.simple == literal_type.simple.STRUCT: + if literal_type is not None and literal_type.simple == SimpleType.STRING: return tf.keras.layers.Layer raise ValueError(f"Transformer {self} cannot reverse {literal_type}") diff --git a/flytekit/extras/tensorflow/model.py b/flytekit/extras/tensorflow/model.py index c121400ee9..1d2beee039 100644 --- a/flytekit/extras/tensorflow/model.py +++ b/flytekit/extras/tensorflow/model.py @@ -1,5 +1,5 @@ import pathlib -from typing import Type +from typing import Generic, Type, TypeVar import tensorflow as tf @@ -9,14 +9,13 @@ from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType +T = TypeVar("T") -class TensorFlowModelTransformer(TypeTransformer[tf.keras.Model]): - TENSORFLOW_FORMAT = "TensorflowModel" - def __init__(self): - super().__init__(name="TensorFlow Model", t=tf.keras.Model) +class TensorFlowModelTransformerBase(TypeTransformer, Generic[T]): + TENSORFLOW_FORMAT = "TensorFlowModel" - def get_literal_type(self, t: Type[tf.keras.Model]) -> LiteralType: + def get_literal_type(self, t: Type[T]) -> LiteralType: return LiteralType( blob=_core_types.BlobType( format=self.TENSORFLOW_FORMAT, @@ -27,8 +26,8 @@ def get_literal_type(self, t: Type[tf.keras.Model]) -> LiteralType: def to_literal( self, ctx: FlyteContext, - python_val: tf.keras.Model, - python_type: Type[tf.keras.Model], + python_val: T, + python_type: Type[T], expected: LiteralType, ) -> Literal: meta = BlobMetadata( @@ -38,39 +37,48 @@ def to_literal( ) ) - local_path = ctx.file_access.get_random_local_path() + ".h5" + local_path = ctx.file_access.get_random_local_path() pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True) - # save tensorflow model in h5 format + # save tensorflow model in SavedModel format tf.keras.models.save_model(python_val, local_path) remote_path = ctx.file_access.get_random_remote_path(local_path) - ctx.file_access.put_data(local_path, remote_path, is_multipart=False) + ctx.file_access.put_data(local_path, remote_path, is_multipart=True) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) - def to_python_value( - self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[tf.keras.Model] - ) -> tf.keras.Model: + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: try: uri = lv.scalar.blob.uri except AttributeError: TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") local_path = ctx.file_access.get_random_local_path() - ctx.file_access.get_data(uri, local_path, is_multipart=False) + ctx.file_access.get_data(uri, local_path, is_multipart=True) - # load tensorflow model from the h5 format + # load tensorflow model from the SavedModel format return tf.keras.models.load_model(local_path) - def guess_python_type(self, literal_type: LiteralType) -> Type[tf.keras.Model]: + def guess_python_type(self, literal_type: LiteralType) -> Type[T]: if ( literal_type.blob is not None and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE and literal_type.blob.format == self.TENSORFLOW_FORMAT ): - return tf.keras.Model + return T raise ValueError(f"Transformer {self} cannot reverse {literal_type}") +class TensorFlowModelTransformer(TensorFlowModelTransformerBase[tf.keras.Model]): + def __init__(self): + super().__init__(name="TensorFlow Model", t=tf.keras.Model) + + +class TensorFlowSequentialTransformer(TensorFlowModelTransformerBase[tf.keras.Sequential]): + def __init__(self): + super().__init__(name="TensorFlow Sequential Model", t=tf.keras.Sequential) + + TypeEngine.register(TensorFlowModelTransformer()) +TypeEngine.register(TensorFlowSequentialTransformer()) diff --git a/tests/flytekit/unit/extras/tensorflow/test_layer.py b/tests/flytekit/unit/extras/tensorflow/test_layer.py index aafd6d4e83..15dca9ad41 100644 --- a/tests/flytekit/unit/extras/tensorflow/test_layer.py +++ b/tests/flytekit/unit/extras/tensorflow/test_layer.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import List import tensorflow as tf @@ -31,11 +31,10 @@ def generate_sequential_model() -> List[tf.keras.layers.Layer]: @task -def get_layers_weights(layers: List[tf.keras.layers.Layer]) -> List[Any]: +def get_layers_weights(layers: List[tf.keras.layers.Layer]) -> List[tf.Variable]: return layers[-1].weights -@workflow def wf(): dense_layer = get_layer() layers = generate_sequential_model() diff --git a/tests/flytekit/unit/extras/tensorflow/test_model.py b/tests/flytekit/unit/extras/tensorflow/test_model.py index 74efbc2a5e..4838d2ed16 100644 --- a/tests/flytekit/unit/extras/tensorflow/test_model.py +++ b/tests/flytekit/unit/extras/tensorflow/test_model.py @@ -19,7 +19,7 @@ def generate_model() -> tf.keras.Model: @task -def generate_sequential_model() -> tf.keras.Model: +def generate_sequential_model() -> tf.keras.Sequential: model = tf.keras.Sequential( [ tf.keras.layers.Input(shape=(32,)), @@ -42,7 +42,6 @@ def model_forward_pass(model: tf.keras.Model) -> tf.Tensor: return model(x) -@workflow def wf(): model1 = generate_model() model2 = generate_sequential_model()