diff --git a/flytekit/extras/tensorflow/__init__.py b/flytekit/extras/tensorflow/__init__.py index fe10c9024b..0c4d12f443 100644 --- a/flytekit/extras/tensorflow/__init__.py +++ b/flytekit/extras/tensorflow/__init__.py @@ -26,9 +26,10 @@ if _tensorflow_installed: + from .model import TensorFlowModelTransformer from .record import TensorFlowRecordFileTransformer, TensorFlowRecordsDirTransformer else: logger.info( - "We won't register TensorFlowRecordFileTransformer and TensorFlowRecordsDirTransformer " + "We won't register TensorFlowRecordFileTransformer, TensorFlowRecordsDirTransformer and TensorFlowModelTransformer" "because tensorflow is not installed." ) diff --git a/flytekit/extras/tensorflow/model.py b/flytekit/extras/tensorflow/model.py new file mode 100644 index 0000000000..857ec2c984 --- /dev/null +++ b/flytekit/extras/tensorflow/model.py @@ -0,0 +1,76 @@ +import pathlib +from typing import Type + +import tensorflow as tf + +from flytekit.core.context_manager import FlyteContext +from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.models.core import types as _core_types +from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar +from flytekit.models.types import LiteralType + + +class TensorFlowModelTransformer(TypeTransformer[tf.keras.Model]): + TENSORFLOW_FORMAT = "TensorFlowModel" + + def __init__(self): + super().__init__(name="TensorFlow Model", t=tf.keras.Model) + + def get_literal_type(self, t: Type[tf.keras.Model]) -> LiteralType: + return LiteralType( + blob=_core_types.BlobType( + format=self.TENSORFLOW_FORMAT, + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + ) + ) + + def to_literal( + self, + ctx: FlyteContext, + python_val: tf.keras.Model, + python_type: Type[tf.keras.Model], + expected: LiteralType, + ) -> Literal: + meta = BlobMetadata( + type=_core_types.BlobType( + format=self.TENSORFLOW_FORMAT, + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + ) + ) + + local_path = ctx.file_access.get_random_local_path() + pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True) + + # save model in SavedModel format + tf.keras.models.save_model(python_val, local_path) + + remote_path = ctx.file_access.get_random_remote_path() + ctx.file_access.put_data(local_path, remote_path, is_multipart=True) + return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) + + def to_python_value( + self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[tf.keras.Model] + ) -> tf.keras.Model: + try: + uri = lv.scalar.blob.uri + except AttributeError: + TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") + + local_path = ctx.file_access.get_random_local_path() + ctx.file_access.get_data(uri, local_path, is_multipart=True) + + # load model + return tf.keras.models.load_model(local_path) + + def guess_python_type(self, literal_type: LiteralType) -> Type[tf.keras.Model]: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.MULTIPART + and literal_type.blob.format == self.TENSORFLOW_FORMAT + ): + return tf.keras.Model + + raise ValueError(f"Transformer {self} cannot reverse {literal_type}") + + +TypeEngine.register(TensorFlowModelTransformer()) diff --git a/flytekit/extras/tensorflow/record.py b/flytekit/extras/tensorflow/record.py index d5d750b521..17e7c37ddd 100644 --- a/flytekit/extras/tensorflow/record.py +++ b/flytekit/extras/tensorflow/record.py @@ -159,7 +159,6 @@ def to_literal( def to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[TFRecordsDirectory] ) -> TFRecordDatasetV2: - uri, metadata = extract_metadata_and_uri(lv, expected_python_type) local_dir = ctx.file_access.get_random_local_directory() ctx.file_access.get_data(uri, local_dir, is_multipart=True) diff --git a/tests/flytekit/unit/extras/tensorflow/model/__init__.py b/tests/flytekit/unit/extras/tensorflow/model/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/extras/tensorflow/model/test_model.py b/tests/flytekit/unit/extras/tensorflow/model/test_model.py new file mode 100644 index 0000000000..2464345986 --- /dev/null +++ b/tests/flytekit/unit/extras/tensorflow/model/test_model.py @@ -0,0 +1,54 @@ +import tensorflow as tf + +from flytekit import task, workflow + + +@task +def generate_model() -> tf.keras.Model: + inputs = tf.keras.Input(shape=(32,)) + outputs = tf.keras.layers.Dense(1)(inputs) + model = tf.keras.Model(inputs, outputs) + model.compile( + optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), + loss=tf.keras.losses.BinaryCrossentropy(), + metrics=[ + tf.keras.metrics.BinaryAccuracy(), + ], + ) + return model + + +@task +def generate_sequential_model() -> tf.keras.Sequential: + model = tf.keras.Sequential( + [ + tf.keras.layers.Input(shape=(32,)), + tf.keras.layers.Dense(1), + ] + ) + model.compile( + optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), + loss=tf.keras.losses.BinaryCrossentropy(), + metrics=[ + tf.keras.metrics.BinaryAccuracy(), + ], + ) + return model + + +@task +def model_forward_pass(model: tf.keras.Model) -> tf.Tensor: + x: tf.Tensor = tf.ones((1, 32)) + return model(x) + + +@workflow +def wf(): + model1 = generate_model() + model2 = generate_sequential_model() + model_forward_pass(model=model1) + model_forward_pass(model=model2) + + +def test_wf(): + wf() diff --git a/tests/flytekit/unit/extras/tensorflow/model/test_transformations.py b/tests/flytekit/unit/extras/tensorflow/model/test_transformations.py new file mode 100644 index 0000000000..392ab695c5 --- /dev/null +++ b/tests/flytekit/unit/extras/tensorflow/model/test_transformations.py @@ -0,0 +1,75 @@ +from collections import OrderedDict + +import numpy as np +import pytest +import tensorflow as tf + +import flytekit +from flytekit import task +from flytekit.configuration import Image, ImageConfig +from flytekit.core import context_manager +from flytekit.extras.tensorflow import TensorFlowModelTransformer +from flytekit.models.core.types import BlobType +from flytekit.models.literals import BlobMetadata +from flytekit.models.types import LiteralType +from flytekit.tools.translator import get_serializable + +default_img = Image(name="default", fqn="test", tag="tag") +serialization_settings = flytekit.configuration.SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + image_config=ImageConfig(default_image=default_img, images=[default_img]), +) + + +def get_tf_model(): + inputs = tf.keras.Input(shape=(32,)) + outputs = tf.keras.layers.Dense(1)(inputs) + tf_model = tf.keras.Model(inputs, outputs) + return tf_model + + +@pytest.mark.parametrize( + "transformer,python_type,format", + [ + (TensorFlowModelTransformer(), tf.keras.Model, TensorFlowModelTransformer.TENSORFLOW_FORMAT), + ], +) +def test_get_literal_type(transformer, python_type, format): + lt = transformer.get_literal_type(python_type) + assert lt == LiteralType(blob=BlobType(format=format, dimensionality=BlobType.BlobDimensionality.MULTIPART)) + + +@pytest.mark.parametrize( + "transformer,python_type,format,python_val", + [ + (TensorFlowModelTransformer(), tf.keras.Model, TensorFlowModelTransformer.TENSORFLOW_FORMAT, get_tf_model()), + ], +) +def test_to_python_value_and_literal(transformer, python_type, format, python_val): + ctx = context_manager.FlyteContext.current_context() + lt = transformer.get_literal_type(python_type) + + lv = transformer.to_literal(ctx, python_val, type(python_val), lt) # type: ignore + output = transformer.to_python_value(ctx, lv, python_type) + + assert lv.scalar.blob.metadata == BlobMetadata( + type=BlobType( + format=format, + dimensionality=BlobType.BlobDimensionality.MULTIPART, + ) + ) + assert lv.scalar.blob.uri is not None + for w1, w2 in zip(output.weights, python_val.weights): + np.testing.assert_allclose(w1.numpy(), w2.numpy()) + + +def test_example_model(): + @task + def t1() -> tf.keras.Model: + return get_tf_model() + + task_spec = get_serializable(OrderedDict(), serialization_settings, t1) + assert task_spec.template.interface.outputs["o0"].type.blob.format is TensorFlowModelTransformer.TENSORFLOW_FORMAT