From 7ab7854982744c43ca9e306d5a905dc9e66a052d Mon Sep 17 00:00:00 2001 From: Josh McGrath Date: Fri, 21 Jun 2024 11:41:19 -0700 Subject: [PATCH 1/2] apply Jack's list patch --- flytekit/core/type_engine.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 4937703ef4..2c42152f5a 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -796,11 +796,22 @@ def tag(expected_python_type: Type[T]) -> str: def get_literal_type(self, t: Type[T]) -> LiteralType: return LiteralType(simple=SimpleType.STRUCT, metadata={ProtobufTransformer.PB_FIELD_KEY: self.tag(t)}) + + def _handle_list_literal(self, ctx: FlyteContext, elems: list) -> Literal: + if len(elems) == 0: + return Literal(collection=LiteralCollection(literals=[])) + st = type(elems[0]) + lt = TypeEngine.to_literal_type(st) + lits = [TypeEngine.to_literal(ctx, x, st, lt) for x in elems] + return Literal(collection=LiteralCollection(literals=lits)) def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: struct = Struct() try: - struct.update(_MessageToDict(cast(Message, python_val))) + message_dict = _MessageToDict(cast(Message, python_val)) + if isinstance(message_dict, list): + return self._handle_list_literal(ctx, message_dict) + struct.update(message_dict) except Exception: raise TypeTransformerFailedError("Failed to convert to generic protobuf struct") return Literal(scalar=Scalar(generic=struct)) From ccbb5a62f385aeb180580e2458a471cd350d6a47 Mon Sep 17 00:00:00 2001 From: Josh McGrath Date: Fri, 21 Jun 2024 13:11:16 -0700 Subject: [PATCH 2/2] some updates for optionals and lists --- flytekit/core/type_engine.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 2c42152f5a..567778c743 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -15,7 +15,6 @@ from collections import OrderedDict from functools import lru_cache from typing import Dict, List, NamedTuple, Optional, Type, cast - from dataclasses_json import DataClassJsonMixin, dataclass_json from flyteidl.core import literals_pb2 from google.protobuf import json_format as _json_format @@ -153,7 +152,7 @@ def type_assertions_enabled(self) -> bool: return self._type_assertions_enabled def assert_type(self, t: Type[T], v: T): - if not hasattr(t, "__origin__") and not isinstance(v, t): + if not ((get_origin(t) is not None) or isinstance(v, t)): raise TypeTransformerFailedError(f"Expected value of type {t} but got '{v}' of type {type(v)}") @abstractmethod