Skip to content

Commit

Permalink
refactor: update tensorflow name
Browse files Browse the repository at this point in the history
Signed-off-by: Tushar Mittal <[email protected]>
  • Loading branch information
techytushar authored and Tushar Mittal committed Oct 28, 2022
1 parent cc95068 commit 010413b
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 24 deletions.
8 changes: 4 additions & 4 deletions flytekit/extras/tensorflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Flytekit Tensorflow
Flytekit TensorFlow
=========================================
.. currentmodule:: flytekit.extras.tensorflow
Expand All @@ -18,9 +18,9 @@


if _tensorflow_installed:
from .layer import TensorflowLayerTransformer
from .model import TensorflowModelTransformer
from .layer import TensorFlowLayerTransformer
from .model import TensorFlowModelTransformer
else:
logger.info(
"Unable to register TensorflowModelTransformer, TensorflowLayerTransformer because tensorflow is not installed."
"Unable to register TensorFlowModelTransformer, TensorFlowLayerTransformer because tensorflow is not installed."
)
10 changes: 4 additions & 6 deletions flytekit/extras/tensorflow/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@
from flytekit.models.types import LiteralType, SimpleType


class TensorflowLayerTransformer(TypeTransformer[tf.keras.layers.Layer]):
def __init__(
self,
):
super().__init__(name="Tensorflow Layer", t=tf.keras.layers.Layer)
class TensorFlowLayerTransformer(TypeTransformer[tf.keras.layers.Layer]):
def __init__(self):
super().__init__(name="TensorFlow Layer", t=tf.keras.layers.Layer)

def get_literal_type(self, t: Type[tf.keras.layers.Layer]) -> LiteralType:
return LiteralType(simple=SimpleType.STRING)
Expand Down Expand Up @@ -45,4 +43,4 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[tf.keras.layers.L
raise ValueError(f"Transformer {self} cannot reverse {literal_type}")


TypeEngine.register(TensorflowLayerTransformer())
TypeEngine.register(TensorFlowLayerTransformer())
10 changes: 4 additions & 6 deletions flytekit/extras/tensorflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,11 @@
from flytekit.models.types import LiteralType


class TensorflowModelTransformer(TypeTransformer[tf.keras.Model]):
class TensorFlowModelTransformer(TypeTransformer[tf.keras.Model]):
TENSORFLOW_FORMAT = "TensorflowModel"

def __init__(
self,
):
super().__init__(name="Tensorflow Model", t=tf.keras.Model)
def __init__(self):
super().__init__(name="TensorFlow Model", t=tf.keras.Model)

def get_literal_type(self, t: Type[tf.keras.Model]) -> LiteralType:
return LiteralType(
Expand Down Expand Up @@ -75,4 +73,4 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[tf.keras.Model]:
raise ValueError(f"Transformer {self} cannot reverse {literal_type}")


TypeEngine.register(TensorflowModelTransformer())
TypeEngine.register(TensorFlowModelTransformer())
16 changes: 8 additions & 8 deletions tests/flytekit/unit/extras/tensorflow/test_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from flytekit import task
from flytekit.configuration import Image, ImageConfig
from flytekit.core import context_manager
from flytekit.extras.tensorflow import TensorflowLayerTransformer, TensorflowModelTransformer
from flytekit.extras.tensorflow import TensorFlowLayerTransformer, TensorFlowModelTransformer
from flytekit.models.core.types import BlobType
from flytekit.models.literals import BlobMetadata
from flytekit.models.types import LiteralType, SimpleType
Expand All @@ -35,8 +35,8 @@ def get_tf_model():
@pytest.mark.parametrize(
"transformer,python_type,format",
[
(TensorflowModelTransformer(), tf.keras.Model, TensorflowModelTransformer.TENSORFLOW_FORMAT),
(TensorflowLayerTransformer(), tf.keras.layers.Layer, None),
(TensorFlowModelTransformer(), tf.keras.Model, TensorFlowModelTransformer.TENSORFLOW_FORMAT),
(TensorFlowLayerTransformer(), tf.keras.layers.Layer, None),
],
)
def test_get_literal_type(transformer, python_type, format):
Expand All @@ -50,10 +50,10 @@ def test_get_literal_type(transformer, python_type, format):
@pytest.mark.parametrize(
"transformer,python_type,format,python_val",
[
(TensorflowModelTransformer(), tf.keras.Model, TensorflowModelTransformer.TENSORFLOW_FORMAT, get_tf_model()),
(TensorflowLayerTransformer(), tf.keras.layers.Layer, None, tf.keras.layers.Dense(4)),
(TensorflowLayerTransformer(), tf.keras.layers.Layer, None, tf.keras.layers.Conv1D(8, 1, activation="relu")),
(TensorflowLayerTransformer(), tf.keras.layers.Layer, None, tf.keras.layers.Softmax()),
(TensorFlowModelTransformer(), tf.keras.Model, TensorFlowModelTransformer.TENSORFLOW_FORMAT, get_tf_model()),
(TensorFlowLayerTransformer(), tf.keras.layers.Layer, None, tf.keras.layers.Dense(4)),
(TensorFlowLayerTransformer(), tf.keras.layers.Layer, None, tf.keras.layers.Conv1D(8, 1, activation="relu")),
(TensorFlowLayerTransformer(), tf.keras.layers.Layer, None, tf.keras.layers.Softmax()),
],
)
def test_to_python_value_and_literal(transformer, python_type, format, python_val):
Expand Down Expand Up @@ -86,7 +86,7 @@ def t1() -> tf.keras.Model:
return get_tf_model()

task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
assert task_spec.template.interface.outputs["o0"].type.blob.format is TensorflowModelTransformer.TENSORFLOW_FORMAT
assert task_spec.template.interface.outputs["o0"].type.blob.format is TensorFlowModelTransformer.TENSORFLOW_FORMAT


def test_example_layer():
Expand Down

0 comments on commit 010413b

Please sign in to comment.