Skip to content

Commit

Permalink
add task for tfrecord dir with no config in test
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Nazareth <[email protected]>
  • Loading branch information
ryankarlos committed Dec 2, 2022
1 parent d45efc5 commit d6ae50c
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 17 deletions.
12 changes: 1 addition & 11 deletions flytekit/extras/tensorflow/record.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from dataclasses import dataclass
from typing import Optional, Tuple, Type, Union, overload
from typing import Optional, Tuple, Type, Union

import tensorflow as tf
from dataclasses_json import dataclass_json
Expand Down Expand Up @@ -36,16 +36,6 @@ class TFRecordDatasetConfig:
name: Optional[str] = None


@overload
def extract_metadata_and_uri(t: TFRecordFile) -> Tuple[TFRecordFile, TFRecordDatasetConfig]:
...


@overload
def extract_metadata_and_uri(t: TFRecordsDirectory) -> Tuple[TFRecordsDirectory, TFRecordDatasetConfig]:
...


def extract_metadata_and_uri(
lv: Literal, t: Type[Union[TFRecordFile, TFRecordsDirectory]]
) -> Tuple[Union[TFRecordFile, TFRecordsDirectory], TFRecordDatasetConfig]:
Expand Down
2 changes: 1 addition & 1 deletion flytekit/types/directory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
TensorBoard.
"""

tfrecords_dir = typing.TypeVar("tfrecords_dir")
tfrecords_dir = typing.TypeVar("tfrecord")
TFRecordsDirectory = FlyteDirectory[tfrecords_dir]
"""
This type can be used to denote that the output is a folder that contains TensorFlow record files.
Expand Down
2 changes: 1 addition & 1 deletion flytekit/types/file/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def check_and_convert_to_str(item: typing.Union[typing.Type, str]) -> str:
#: decoration and useful for attaching content type information with the file and automatically documenting code.
ONNXFile = FlyteFile[onnx]

tfrecords_file = Annotated[str, FileExt("tfrecords_file")]
tfrecords_file = Annotated[str, FileExt("tfrecord")]
#: Can be used to receive or return a TFRecordFile. The underlying type is a FlyteFile type. This is just a
#: decoration and useful for attaching content type information with the file and automatically documenting code.
TFRecordFile = FlyteFile[tfrecords_file]
19 changes: 15 additions & 4 deletions tests/flytekit/unit/extras/tensorflow/record/test_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,22 @@ def t2(dataset: TFRecordFile):


@task
def t3(dataset: Annotated[TFRecordFile, TFRecordDatasetConfig(buffer_size=1024)]) -> Dict[str, np.ndarray]:
def t3(dataset: TFRecordsDirectory):

# if not annotated with TFRecordDatasetConfig, all attributes should default to None
assert isinstance(dataset, TFRecordDatasetV2)
assert dataset._compression_type is None
assert dataset._buffer_size is None
assert dataset._num_parallel_reads is None


@task
def t4(dataset: Annotated[TFRecordFile, TFRecordDatasetConfig(buffer_size=1024)]) -> Dict[str, np.ndarray]:
return decode_fn(dataset)


@task
def t4(dataset: Annotated[TFRecordsDirectory, TFRecordDatasetConfig(buffer_size=1024)]) -> Dict[str, np.ndarray]:
def t5(dataset: Annotated[TFRecordsDirectory, TFRecordDatasetConfig(buffer_size=1024)]) -> Dict[str, np.ndarray]:
return decode_fn(dataset)


Expand All @@ -90,8 +100,9 @@ def wf() -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
files = generate_tf_record_dir()
t1(dataset=file)
t2(dataset=file)
files_res = t3(dataset=file)
dir_res = t4(dataset=files)
t3(dataset=files)
files_res = t4(dataset=file)
dir_res = t5(dataset=files)
return files_res, dir_res


Expand Down

0 comments on commit d6ae50c

Please sign in to comment.