diff --git a/tests/flytekit/unit/extras/tensorflow/test_native.py b/tests/flytekit/unit/extras/tensorflow/test_native.py index 5a37f70dc5c..12bb853eaea 100644 --- a/tests/flytekit/unit/extras/tensorflow/test_native.py +++ b/tests/flytekit/unit/extras/tensorflow/test_native.py @@ -2,15 +2,23 @@ from flytekit import task, workflow +a = tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"foo", b"bar"])) +b = tf.train.Feature(float_list=tf.train.FloatList(value=[1.0, 2.0])) +c = tf.train.Feature(int64_list=tf.train.Int64List(value=[3, 4])) + @task -def generate_tf_example() -> tf.train.Example: - a = tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"foo", b"bar"])) - b = tf.train.Feature(float_list=tf.train.FloatList(value=[1.0, 2.0])) +def generate_tf_example_1() -> tf.train.Example: features = tf.train.Features(feature=dict(a=a, b=b)) return tf.train.Example(features=features) +@task +def generate_tf_example_2() -> tf.train.Example: + features = tf.train.Features(feature=dict(a=a, b=b, c=c)) + return tf.train.Example(features=features) + + @task def t1(example: tf.train.Example) -> tf.train.Example: assert example.features.feature["a"].bytes_list.value == [b"foo", b"bar"] @@ -20,21 +28,28 @@ def t1(example: tf.train.Example) -> tf.train.Example: @task def t2(example: tf.train.Example) -> tf.train.Example: - # add a third feature - int_feat = tf.train.Feature(int64_list=tf.train.Int64List(value=[3, 4])) - example.features.feature.get_or_create("c") - example.features.feature.setdefault("c", int_feat) + assert example.features.feature["c"].int64_list.value == [3, 4] return example +@task +def t3(example: tf.train.Example): + feature_description = { + "b": tf.io.RaggedFeature(dtype=float), + "c": tf.io.RaggedFeature(dtype=tf.int64), + } + parsed = tf.io.parse_example(example.SerializeToString(), feature_description) + result = dict(map(lambda x: (x[0], x[1].numpy().tolist()), parsed.items())) + assert result == {"b": [1.0, 2.0], "c": [3, 4]} + + @workflow -def wf() -> tf.train.Example: - t1(tensor=generate_tf_example()) - result = t2(tensor=generate_tf_example()) - return result +def wf(): + t1(example=generate_tf_example_1()) + input_t3 = t2(example=generate_tf_example_2()) + t3(example=input_t3) @workflow def test_wf(): - example = wf() - assert example.features.feature["c"].int64_list.value == [3, 4] + wf()