Skip to content

Commit

Permalink
more tests
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor committed Aug 19, 2022
1 parent b2e2cf3 commit 148d9a4
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 4 deletions.
3 changes: 2 additions & 1 deletion flytekit/types/structured/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
from flytekit.models import literals
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured.structured_dataset import (
BIGQUERY,
StructuredDataset,
StructuredDatasetDecoder,
StructuredDatasetEncoder,
StructuredDatasetMetadata,
StructuredDatasetTransformerEngine,
)

BIGQUERY = "bq"


def _write_to_bq(structured_dataset: StructuredDataset):
table_id = typing.cast(str, structured_dataset.uri).split("://", 1)[1].replace(":", ".")
Expand Down
6 changes: 4 additions & 2 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ def __init__(self, python_type: Type[T], protocol: Optional[str] = None, support
:param python_type: The dataframe class in question that you want to register this encoder with
:param protocol: A prefix representing the storage driver (e.g. 's3', 'gs', 'bq', etc.). You can use either
"s3" or "s3://". They are the same since the "://" will just be stripped by the constructor.
If None, this encoder will be registered with all protocols that flytekit's data persistence layer
is capable of handling.
:param supported_format: Arbitrary string representing the format. If not supplied then an empty string
will be used. An empty string implies that the encoder works with any format. If the format being asked
for does not exist, the transformer engine will look for the "" encoder instead and write a warning.
Expand Down Expand Up @@ -231,6 +233,8 @@ def __init__(self, python_type: Type[DF], protocol: Optional[str] = None, suppor
:param python_type: The dataframe class in question that you want to register this decoder with
:param protocol: A prefix representing the storage driver (e.g. 's3', 'gs', 'bq', etc.). You can use either
"s3" or "s3://". They are the same since the "://" will just be stripped by the constructor.
If None, this decoder will be registered with all protocols that flytekit's data persistence layer
is capable of handling.
:param supported_format: Arbitrary string representing the format. If not supplied then an empty string
will be used. An empty string implies that the decoder works with any format. If the format being asked
for does not exist, the transformer engine will look for the "" decoder instead and write a warning.
Expand Down Expand Up @@ -413,8 +417,6 @@ def register(cls, h: Handlers, default_for_type: Optional[bool] = False, overrid
cls.register_for_protocol(h, stripped, False, override)
except DuplicateHandlerError:
...
# Add this?
# cls.register_for_protocol(h, "/", False, override)

elif h.protocol == "":
raise ValueError(f"Use None instead of empty string for registering handler {h}")
Expand Down
2 changes: 1 addition & 1 deletion tests/flytekit/unit/core/test_data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ def test_is_remote():

def test_lister():
x = DataPersistencePlugins.supported_protocols()
main_protocols = set(["file", "/", "gs", "http", "https", "s3"])
main_protocols = {"file", "/", "gs", "http", "https", "s3"}
all_protocols = set([y.replace("://", "") for y in x])
assert main_protocols.issubset(all_protocols)
26 changes: 26 additions & 0 deletions tests/flytekit/unit/core/test_structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ def encode(self):
with pytest.raises(ValueError):
StructuredDatasetTransformerEngine.register(TempEncoder("gs://"), default_for_type=False)

with pytest.raises(ValueError, match="Use None instead"):
e = TempEncoder("")
e._protocol = ""
StructuredDatasetTransformerEngine.register(e)

class TempEncoder:
pass

Expand Down Expand Up @@ -202,6 +207,24 @@ def encode(
assert res is empty_format_temp_encoder


def test_slash_register():
    """Register an encoder built with ``protocol=None`` and check that a "file" entry exists for ``MyDF``."""

    class TempEncoder(StructuredDatasetEncoder):
        """Minimal encoder for ``MyDF`` with no protocol and a caller-supplied format string."""

        def __init__(self, fmt: str):
            super().__init__(MyDF, protocol=None, supported_format=fmt)

        def encode(
            self,
            ctx: FlyteContext,
            structured_dataset: StructuredDataset,
            structured_dataset_type: StructuredDatasetType,
        ) -> literals.StructuredDataset:
            # The returned literal is never inspected by this test.
            return literals.StructuredDataset(uri="")

    # Check that registering with a / triggers the file protocol instead.
    slash_encoder = TempEncoder("/")
    StructuredDatasetTransformerEngine.register(slash_encoder)

    file_handlers = StructuredDatasetTransformerEngine.ENCODERS[MyDF].get("file")
    assert file_handlers is not None


def test_sd():
sd = StructuredDataset(dataframe="hi")
sd.uri = "my uri"
Expand Down Expand Up @@ -266,6 +289,9 @@ def test_convert_schema_type_to_structured_dataset_type():
with pytest.raises(AssertionError, match="Unrecognized SchemaColumnType"):
convert_schema_type_to_structured_dataset_type(int)

with pytest.raises(AssertionError, match="Unrecognized SchemaColumnType"):
convert_schema_type_to_structured_dataset_type(20)


def test_to_python_value_with_incoming_columns():
# make a literal with a type that has two columns
Expand Down
11 changes: 11 additions & 0 deletions tests/flytekit/unit/core/test_structured_dataset_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
StructuredDataset,
StructuredDatasetDecoder,
StructuredDatasetEncoder,
StructuredDatasetTransformerEngine,
)

my_cols = kwtypes(w=typing.Dict[str, typing.Dict[str, int]], x=typing.List[typing.List[int]], y=int, z=str)
Expand Down Expand Up @@ -41,3 +42,13 @@ def test_base_isnt_instantiable():

with pytest.raises(TypeError):
StructuredDatasetDecoder(pd.DataFrame, "", "")


def test_arrow():
    """Check the Arrow/Parquet handlers in ``basic_dfs``.

    Both handlers must report no protocol and share the same Python type, and a
    decoder must be present in the engine under that type for "s3"/"parquet".
    """
    arrow_encoder = basic_dfs.ArrowToParquetEncodingHandler()
    arrow_decoder = basic_dfs.ParquetToArrowDecodingHandler()

    # Neither handler is bound to a specific protocol.
    assert arrow_encoder.protocol is None
    assert arrow_decoder.protocol is None

    # Encoder and decoder must target the very same Python type object.
    assert arrow_encoder.python_type is arrow_decoder.python_type

    decoders_for_type = StructuredDatasetTransformerEngine.DECODERS[arrow_encoder.python_type]
    s3_parquet_decoder = decoders_for_type["s3"]["parquet"]
    assert s3_parquet_decoder is not None

0 comments on commit 148d9a4

Please sign in to comment.