diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 88c2b39c02..440c95c4c6 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -351,7 +351,7 @@ def local_execute( if len(output_names) == 0: return VoidPromise(self.name) - vals = [Promise(var, outputs_literals[var]) for var in output_names] + vals = [Promise(var, outputs_literals[var], type=self.interface.outputs[var].type) for var in output_names] return create_task_output(vals, self.python_interface) def __call__(self, *args: object, **kwargs: object) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]: diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index c4f71eb2d6..3276a25e5b 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -3,6 +3,7 @@ import collections import inspect import typing +import dataclasses from copy import deepcopy from enum import Enum from typing import Any, Coroutine, Dict, Hashable, List, Optional, Set, Tuple, Union, cast, get_args @@ -90,8 +91,7 @@ def my_wf(in1: int, in2: int) -> int: var = flyte_interface_types[k] t = native_types[k] try: - if type(v) is Promise: - v = resolve_attr_path_in_promise(v) + v = resolve_any_nested_promises(v) result[k] = TypeEngine.to_literal(ctx, v, t, var.type) except TypeTransformerFailedError as exc: raise TypeTransformerFailedError(f"Failed argument '{k}': {exc}") from exc @@ -99,6 +99,23 @@ def my_wf(in1: int, in2: int) -> int: return result +def resolve_any_nested_promises(v: Any): + """Iterate through v in many forms to resolve any nested promises""" + if isinstance(v, Promise): + return resolve_attr_path_in_promise(v) + if isinstance(v, list): + return [resolve_any_nested_promises(x) for x in v] + if isinstance(v, dict): + return {k: resolve_any_nested_promises(v) for k, v in v.items()} + if isinstance(v, tuple): + return tuple(resolve_any_nested_promises(x) for x in v) + if dataclasses.is_dataclass(v): + # Set the fields of the dataclass to the resolved values + for field in dataclasses.fields(v): + setattr(v, field.name, resolve_any_nested_promises(getattr(v, field.name))) + return v + + def resolve_attr_path_in_promise(p: Promise) -> Promise: """ resolve_attr_path_in_promise resolves the attribute path in a promise and returns a new promise with the resolved value @@ -141,6 +158,7 @@ def resolve_attr_path_in_promise(p: Promise) -> Promise: ): st = curr_val.value.value new_st = resolve_attr_path_in_pb_struct(st, attr_path=p.attr_path[used:]) + new_st = _maybe_fix_deserialized_ints(p, new_st) literal_type = TypeEngine.to_literal_type(type(new_st)) # Reconstruct the resolved result to flyte literal (because the resolved result might not be struct) curr_val = TypeEngine.to_literal(FlyteContextManager.current_context(), new_st, type(new_st), literal_type) @@ -149,6 +167,28 @@ def resolve_attr_path_in_promise(p: Promise) -> Promise: return p +def _maybe_fix_deserialized_ints(p: Promise, new_st: Any) -> Any: + """ + This function is used to fix the deserialized integers in the promise, in the case where + the promise has a type of int, but the value is deserialized as a float. + """ + if p._type is None: + # No typing, nothing to do + return new_st + + if p._type.simple != SimpleType.INTEGER: + # Not an integer, nothing to do + return new_st + + if type(new_st) is not int: + if type(new_st) is float: + if int(new_st) == new_st: + return int(new_st) + raise ValueError(f"Resolved value {new_st} is a float, but the promise is an integer") + raise ValueError(f"Resolved value {new_st} is not an integer, but the promise is an integer") + return new_st + + def resolve_attr_path_in_pb_struct(st: _struct.Struct, attr_path: List[Union[str, int]]) -> _struct.Struct: curr_val = st for attr in attr_path: @@ -596,6 +636,12 @@ def _append_attr(self, key) -> Promise: # The attr_path on the ref node is for remote execute new_promise._ref = new_promise.ref.with_attr(key) + if self._type is not None: + if self._type.simple == SimpleType.STRUCT and self._type.structure is not None: + # We should specify the type of this node, such that if it's used alone + # it can be resolved correctly. + new_promise._type = self._type.structure.dataclass_type[key] + return new_promise diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index bd617a161a..1353f63118 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -14,8 +14,8 @@ from abc import ABC, abstractmethod from collections import OrderedDict from functools import lru_cache +from types import NoneType from typing import Dict, List, NamedTuple, Optional, Type, cast - from dataclasses_json import DataClassJsonMixin, dataclass_json from flyteidl.core import literals_pb2 from google.protobuf import json_format as _json_format @@ -149,7 +149,7 @@ def type_assertions_enabled(self) -> bool: return self._type_assertions_enabled def assert_type(self, t: Type[T], v: T): - if not hasattr(t, "__origin__") and not isinstance(v, t): + if not ((get_origin(t) is not None) or isinstance(v, t)): raise TypeTransformerFailedError(f"Expected value of type {t} but got '{v}' of type {type(v)}") @abstractmethod @@ -493,22 +493,27 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp self._make_dataclass_serializable(python_val, python_type) - # The function looks up or creates a JSONEncoder specifically designed for the object's type. - # This encoder is then used to convert a data class into a JSON string. - try: - encoder = self._encoder[python_type] - except KeyError: - encoder = JSONEncoder(python_type) - self._encoder[python_type] = encoder + # The `to_json` function is integrated through either the `dataclasses_json` decorator or by inheriting from `DataClassJsonMixin`. + # It deserializes a JSON string into a data class, and provides additional functionality over JSONEncoder + if hasattr(python_val, "to_json"): + json_str = python_val.to_json() + else: + # The function looks up or creates a JSONEncoder specifically designed for the object's type. + # This encoder is then used to convert a data class into a JSON string. + try: + encoder = self._encoder[python_type] + except KeyError: + encoder = JSONEncoder(python_type) + self._encoder[python_type] = encoder - try: - json_str = encoder.encode(python_val) - except NotImplementedError: - # you can refer FlyteFile, FlyteDirectory and StructuredDataset to see how flyte types can be implemented. - raise NotImplementedError( - f"{python_type} should inherit from mashumaro.types.SerializableType" - f" and implement _serialize and _deserialize methods." - ) + try: + json_str = encoder.encode(python_val) + except NotImplementedError: + # you can refer FlyteFile, FlyteDirectory and StructuredDataset to see how flyte types can be implemented. + raise NotImplementedError( + f"{python_type} should inherit from mashumaro.types.SerializableType" + f" and implement _serialize and _deserialize methods." + ) return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) # type: ignore @@ -652,15 +657,20 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: json_str = _json_format.MessageToJson(lv.scalar.generic) - # The function looks up or creates a JSONDecoder specifically designed for the object's type. - # This decoder is then used to convert a JSON string into a data class. - try: - decoder = self._decoder[expected_python_type] - except KeyError: - decoder = JSONDecoder(expected_python_type) - self._decoder[expected_python_type] = decoder + # The `from_json` function is integrated through either the `dataclasses_json` decorator or by inheriting from `DataClassJsonMixin`. + # It deserializes a JSON string into a data class, and supports additional functionality over JSONDecoder + if hasattr(expected_python_type, "from_json"): + dc = expected_python_type.from_json(json_str) # type: ignore + else: + # The function looks up or creates a JSONDecoder specifically designed for the object's type. + # This decoder is then used to convert a JSON string into a data class. + try: + decoder = self._decoder[expected_python_type] + except KeyError: + decoder = JSONDecoder(expected_python_type) + self._decoder[expected_python_type] = decoder - dc = decoder.decode(json_str) + dc = decoder.decode(json_str) dc = self._fix_structured_dataset_type(expected_python_type, dc) return self._fix_dataclass_int(expected_python_type, dc) @@ -696,11 +706,22 @@ def tag(expected_python_type: Type[T]) -> str: def get_literal_type(self, t: Type[T]) -> LiteralType: return LiteralType(simple=SimpleType.STRUCT, metadata={ProtobufTransformer.PB_FIELD_KEY: self.tag(t)}) + + def _handle_list_literal(self, ctx: FlyteContext, elems: list) -> Literal: + if len(elems) == 0: + return Literal(collection=LiteralCollection(literals=[])) + st = type(elems[0]) + lt = TypeEngine.to_literal_type(st) + lits = [TypeEngine.to_literal(ctx, x, st, lt) for x in elems] + return Literal(collection=LiteralCollection(literals=lits)) def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: struct = Struct() try: - struct.update(_MessageToDict(cast(Message, python_val))) + message_dict = _MessageToDict(cast(Message, python_val)) + if isinstance(message_dict, list): + return self._handle_list_literal(ctx, message_dict) + struct.update(message_dict) except Exception: raise TypeTransformerFailedError("Failed to convert to generic protobuf struct") return Literal(scalar=Scalar(generic=struct)) @@ -1051,7 +1072,7 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type "actual attribute that you want to use. For example, in NamedTuple('OP', x=int) then" "return v.x, instead of v, even if this has a single element" ) - if python_val is None and expected and expected.union_type is None: + if (python_val is None and python_type != NoneType) and expected and expected.union_type is None: raise TypeTransformerFailedError(f"Python value cannot be None, expected {python_type}/{expected}") transformer = cls.get_transformer(python_type) if transformer.type_assertions_enabled: diff --git a/test_dataclass_elem_list_construction.py b/test_dataclass_elem_list_construction.py new file mode 100644 index 0000000000..1dcc9d92ad --- /dev/null +++ b/test_dataclass_elem_list_construction.py @@ -0,0 +1,76 @@ +from flytekit import task, dynamic, workflow +from dataclasses import dataclass +from mashumaro.mixins.json import DataClassJSONMixin + + +@dataclass +class IntWrapper(DataClassJSONMixin): + x: int + +@task +def get_int() -> int: + return 3 + +@task +def get_wrapped_int() -> IntWrapper: + return IntWrapper(x=3) + +@task +def sum_list(input_list: list[int]) -> int: + return sum(input_list) + + +@dataclass +class StrWrapper(DataClassJSONMixin): + x: str + +@task +def get_str() -> str: + return "5" + +@task +def get_wrapped_str() -> StrWrapper: + return StrWrapper(x="3") + +@task +def concat_list(input_list: list[str]) -> str: + return "".join(input_list) + + + +@workflow +def convert_list_workflow1() -> int: + """Here's a simple workflow that takes a list of strings and returns a dataclass with that list.""" + promised_int = get_int() + joined_list = [4, promised_int] + return sum_list(input_list=joined_list) + +@workflow +def convert_list_workflow2() -> int: + wrapped_int = get_wrapped_int() + joined_list = [4, wrapped_int.x] + return sum_list(input_list=joined_list) + +@workflow +def convert_list_workflow3() -> str: + """Here's a simple workflow that takes a list of strings and returns a dataclass with that list.""" + promised_str = get_str() + joined_list = ["4", promised_str] + return concat_list(input_list=joined_list) + +@workflow +def convert_list_workflow4() -> str: + wrapped_str = get_wrapped_str() + joined_list = ["4", wrapped_str.x] + return concat_list(input_list=joined_list) + + +if __name__ == "__main__": + print("Run 1") + print(convert_list_workflow1()) + print("Run 2") + print(convert_list_workflow2()) + print("Run 3") + print(convert_list_workflow3()) + print("Run 4") + print(convert_list_workflow4())