Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allowing dataclasses which require list protobuf deserialization #2614

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@
return self._type_assertions_enabled

def assert_type(self, t: Type[T], v: T):
if not hasattr(t, "__origin__") and not isinstance(v, t):
if not ((get_origin(t) is not None) or isinstance(v, t)):
raise TypeTransformerFailedError(f"Expected value of type {t} but got '{v}' of type {type(v)}")

@abstractmethod
Expand Down Expand Up @@ -706,10 +706,24 @@
def get_literal_type(self, t: Type[T]) -> LiteralType:
return LiteralType(simple=SimpleType.STRUCT, metadata={ProtobufTransformer.PB_FIELD_KEY: self.tag(t)})

def _handle_list_literal(self, ctx: FlyteContext, elems: list) -> Literal:
if len(elems) == 0:
return Literal(collection=LiteralCollection(literals=[]))
st = type(elems[0])
lt = TypeEngine.to_literal_type(st)

Check warning on line 713 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L711-L713

Added lines #L711 - L713 were not covered by tests
lits = [TypeEngine.to_literal(ctx, x, st, lt) for x in elems]
return Literal(collection=LiteralCollection(literals=lits))

Check warning on line 715 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L715

Added line #L715 was not covered by tests

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
struct = Struct()
try:
struct.update(_MessageToDict(cast(Message, python_val)))
message_dict = _MessageToDict(cast(Message, python_val))
try:
struct.update(message_dict)
except Exception:

Check warning on line 723 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L720-L723

Added lines #L720 - L723 were not covered by tests
if isinstance(message_dict, list):
return self._handle_list_literal(ctx, message_dict)
Comment on lines +722 to +725
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: could we do this?

if isinstance(message_dict, list):
    return self._handle_list_literal(ctx, message_dict)
else:
    struct.update(message_dict)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like yes, will update

raise

Check warning on line 726 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L725-L726

Added lines #L725 - L726 were not covered by tests
except Exception:
raise TypeTransformerFailedError("Failed to convert to generic protobuf struct")
return Literal(scalar=Scalar(generic=struct))
Expand Down
29 changes: 29 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3010,6 +3010,35 @@ class DatumDataUnion:
assert datum_mashumaro_orjson.z.isoformat() == pv.z


def test_DataclassTransformer_for_list_attributes():
dataclass_bases = [DataClassJSONMixin, DataClassJsonMixin, object]
for dataclass_base in dataclass_bases:
@dataclass
class DataclassWithList(dataclass_base):
some_strings: list[str]

@task
def create_dataclass_with_list(
some_strings: list[str],
) -> DataclassWithList:
return DataclassWithList(
some_strings=some_strings,
)

@workflow
def convert_list_workflow(
inputs: DataclassWithList,
) -> DataclassWithList:
result_list = create_dataclass_with_list(some_strings=inputs.some_strings)
return result_list

workflow_input = DataclassWithList(
some_strings=["hello", "world"],
)
output = convert_list_workflow(workflow_input)
assert output.some_strings == ["hello", "world"]


def test_dataclass_encoder_and_decoder_registry():
iterations = 10

Expand Down
Loading