Skip to content

Commit

Permalink
fix test_native.py
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Nazareth <[email protected]>
  • Loading branch information
ryankarlos committed Oct 18, 2022
1 parent abd30b9 commit 8918108
Showing 1 changed file with 28 additions and 13 deletions.
41 changes: 28 additions & 13 deletions tests/flytekit/unit/extras/tensorflow/test_native.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,23 @@

from flytekit import task, workflow

a = tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"foo", b"bar"]))
b = tf.train.Feature(float_list=tf.train.FloatList(value=[1.0, 2.0]))
c = tf.train.Feature(int64_list=tf.train.Int64List(value=[3, 4]))


@task
def generate_tf_example() -> tf.train.Example:
a = tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"foo", b"bar"]))
b = tf.train.Feature(float_list=tf.train.FloatList(value=[1.0, 2.0]))
def generate_tf_example_1() -> tf.train.Example:
features = tf.train.Features(feature=dict(a=a, b=b))
return tf.train.Example(features=features)


@task
def generate_tf_example_2() -> tf.train.Example:
features = tf.train.Features(feature=dict(a=a, b=b, c=c))
return tf.train.Example(features=features)


@task
def t1(example: tf.train.Example) -> tf.train.Example:
assert example.features.feature["a"].bytes_list.value == [b"foo", b"bar"]
Expand All @@ -20,21 +28,28 @@ def t1(example: tf.train.Example) -> tf.train.Example:

@task
def t2(example: tf.train.Example) -> tf.train.Example:
# add a third feature
int_feat = tf.train.Feature(int64_list=tf.train.Int64List(value=[3, 4]))
example.features.feature.get_or_create("c")
example.features.feature.setdefault("c", int_feat)
assert example.features.feature["c"].int64_list.value == [3, 4]
return example


@task
def t3(example: tf.train.Example):
feature_description = {
"b": tf.io.RaggedFeature(dtype=float),
"c": tf.io.RaggedFeature(dtype=tf.int64),
}
parsed = tf.io.parse_example(example.SerializeToString(), feature_description)
result = dict(map(lambda x: (x[0], x[1].numpy().tolist()), parsed.items()))
assert result == {"b": [1.0, 2.0], "c": [3, 4]}


@workflow
def wf() -> tf.train.Example:
t1(tensor=generate_tf_example())
result = t2(tensor=generate_tf_example())
return result
def wf():
t1(example=generate_tf_example_1())
input_t3 = t2(example=generate_tf_example_2())
t3(example=input_t3)


@workflow
def test_wf():
example = wf()
assert example.features.feature["c"].int64_list.value == [3, 4]
wf()

0 comments on commit 8918108

Please sign in to comment.