Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wip] Updating flytekit to handle dereferencing lists of promises (local) #5

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def local_execute(
if len(output_names) == 0:
return VoidPromise(self.name)

vals = [Promise(var, outputs_literals[var]) for var in output_names]
vals = [Promise(var, outputs_literals[var], type=self.interface.outputs[var].type) for var in output_names]
return create_task_output(vals, self.python_interface)

def __call__(self, *args: object, **kwargs: object) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]:
Expand Down
50 changes: 48 additions & 2 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import collections
import inspect
import typing
import dataclasses
from copy import deepcopy
from enum import Enum
from typing import Any, Coroutine, Dict, Hashable, List, Optional, Set, Tuple, Union, cast, get_args
Expand Down Expand Up @@ -90,15 +91,31 @@ def my_wf(in1: int, in2: int) -> int:
var = flyte_interface_types[k]
t = native_types[k]
try:
if type(v) is Promise:
v = resolve_attr_path_in_promise(v)
v = resolve_any_nested_promises(v)
result[k] = TypeEngine.to_literal(ctx, v, t, var.type)
except TypeTransformerFailedError as exc:
raise TypeTransformerFailedError(f"Failed argument '{k}': {exc}") from exc

return result


def resolve_any_nested_promises(v: Any):
    """
    Recursively replace every Promise found inside ``v`` with its resolved form.

    - A Promise is passed through ``resolve_attr_path_in_promise``.
    - Lists, dicts and tuples are rebuilt with their elements resolved recursively
      (for dicts only the values are resolved; keys are kept as they are).
    - Dataclass *instances* are updated in place, field by field, and the same instance is returned.
    - Anything else (including dataclass classes) is returned unchanged.

    :param v: a Promise, a possibly nested container, a dataclass instance, or any other value
    :return: ``v`` with nested Promises resolved
    """
    if isinstance(v, Promise):
        return resolve_attr_path_in_promise(v)
    if isinstance(v, list):
        return [resolve_any_nested_promises(item) for item in v]
    if isinstance(v, dict):
        # Distinct loop names: the original rebound ``v`` inside the comprehension while iterating ``v.items()``.
        return {key: resolve_any_nested_promises(value) for key, value in v.items()}
    if isinstance(v, tuple):
        return tuple(resolve_any_nested_promises(item) for item in v)
    # dataclasses.is_dataclass() is also True for dataclass classes; only instances carry field values.
    if dataclasses.is_dataclass(v) and not isinstance(v, type):
        for field in dataclasses.fields(v):
            current = getattr(v, field.name)
            resolved = resolve_any_nested_promises(current)
            # Only write back when resolution produced a different object, so fields with
            # nothing to resolve are not reassigned.
            if resolved is not current:
                setattr(v, field.name, resolved)
    return v


def resolve_attr_path_in_promise(p: Promise) -> Promise:
"""
resolve_attr_path_in_promise resolves the attribute path in a promise and returns a new promise with the resolved value
Expand Down Expand Up @@ -141,6 +158,7 @@ def resolve_attr_path_in_promise(p: Promise) -> Promise:
):
st = curr_val.value.value
new_st = resolve_attr_path_in_pb_struct(st, attr_path=p.attr_path[used:])
new_st = _maybe_fix_deserialized_ints(p, new_st)
literal_type = TypeEngine.to_literal_type(type(new_st))
# Reconstruct the resolved result to flyte literal (because the resolved result might not be struct)
curr_val = TypeEngine.to_literal(FlyteContextManager.current_context(), new_st, type(new_st), literal_type)
Expand All @@ -149,6 +167,28 @@ def resolve_attr_path_in_promise(p: Promise) -> Promise:
return p


def _maybe_fix_deserialized_ints(p: Promise, new_st: Any) -> Any:
    """
    Coerce a resolved value back to ``int`` when the promise is typed as an integer.

    This covers the case where the promise has a type of int, but the value was
    deserialized as a float.

    :param p: the promise whose ``_type`` describes the expected literal type (may be None)
    :param new_st: the value resolved for that promise
    :return: ``new_st`` unchanged when the promise is untyped or not an integer, or when the value
        already is an int; otherwise the integral float converted to int
    :raises ValueError: if the promise is an integer but the value is a non-integral (or non-finite)
        float, or is neither an int nor a float
    """
    # No typing information, or the promise is not an integer: nothing to fix.
    if p._type is None or p._type.simple != SimpleType.INTEGER:
        return new_st

    # Exact type comparisons (not isinstance) are kept from the original, so bool values and
    # int/float subclasses end up in the final "not an integer" error.
    if type(new_st) is int:
        return new_st

    if type(new_st) is float:
        # float.is_integer() is False for inf and nan, so those get the descriptive ValueError
        # below instead of the OverflowError / unrelated ValueError that int() would raise.
        if new_st.is_integer():
            return int(new_st)
        raise ValueError(f"Resolved value {new_st} is a float, but the promise is an integer")

    raise ValueError(f"Resolved value {new_st} is not an integer, but the promise is an integer")


def resolve_attr_path_in_pb_struct(st: _struct.Struct, attr_path: List[Union[str, int]]) -> _struct.Struct:
curr_val = st
for attr in attr_path:
Expand Down Expand Up @@ -596,6 +636,12 @@ def _append_attr(self, key) -> Promise:
# The attr_path on the ref node is for remote execute
new_promise._ref = new_promise.ref.with_attr(key)

if self._type is not None:
if self._type.simple == SimpleType.STRUCT and self._type.structure is not None:
# We should specify the type of this node, such that if it's used alone
# it can be resolved correctly.
new_promise._type = self._type.structure.dataclass_type[key]

return new_promise


Expand Down
75 changes: 48 additions & 27 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from abc import ABC, abstractmethod
from collections import OrderedDict
from functools import lru_cache
from types import NoneType
from typing import Dict, List, NamedTuple, Optional, Type, cast

from dataclasses_json import DataClassJsonMixin, dataclass_json
from flyteidl.core import literals_pb2
from google.protobuf import json_format as _json_format
Expand Down Expand Up @@ -149,7 +149,7 @@ def type_assertions_enabled(self) -> bool:
return self._type_assertions_enabled

def assert_type(self, t: Type[T], v: T):
if not hasattr(t, "__origin__") and not isinstance(v, t):
if not ((get_origin(t) is not None) or isinstance(v, t)):
raise TypeTransformerFailedError(f"Expected value of type {t} but got '{v}' of type {type(v)}")

@abstractmethod
Expand Down Expand Up @@ -493,22 +493,27 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp

self._make_dataclass_serializable(python_val, python_type)

# The function looks up or creates a JSONEncoder specifically designed for the object's type.
# This encoder is then used to convert a data class into a JSON string.
try:
encoder = self._encoder[python_type]
except KeyError:
encoder = JSONEncoder(python_type)
self._encoder[python_type] = encoder
# The `to_json` function is integrated through either the `dataclasses_json` decorator or by inheriting from `DataClassJsonMixin`.
# It serializes a data class into a JSON string, and provides additional functionality over JSONEncoder
if hasattr(python_val, "to_json"):
json_str = python_val.to_json()
else:
# The function looks up or creates a JSONEncoder specifically designed for the object's type.
# This encoder is then used to convert a data class into a JSON string.
try:
encoder = self._encoder[python_type]
except KeyError:
encoder = JSONEncoder(python_type)
self._encoder[python_type] = encoder

try:
json_str = encoder.encode(python_val)
except NotImplementedError:
# you can refer FlyteFile, FlyteDirectory and StructuredDataset to see how flyte types can be implemented.
raise NotImplementedError(
f"{python_type} should inherit from mashumaro.types.SerializableType"
f" and implement _serialize and _deserialize methods."
)
try:
json_str = encoder.encode(python_val)
except NotImplementedError:
# you can refer FlyteFile, FlyteDirectory and StructuredDataset to see how flyte types can be implemented.
raise NotImplementedError(
f"{python_type} should inherit from mashumaro.types.SerializableType"
f" and implement _serialize and _deserialize methods."
)

return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) # type: ignore

Expand Down Expand Up @@ -652,15 +657,20 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:

json_str = _json_format.MessageToJson(lv.scalar.generic)

# The function looks up or creates a JSONDecoder specifically designed for the object's type.
# This decoder is then used to convert a JSON string into a data class.
try:
decoder = self._decoder[expected_python_type]
except KeyError:
decoder = JSONDecoder(expected_python_type)
self._decoder[expected_python_type] = decoder
# The `from_json` function is integrated through either the `dataclasses_json` decorator or by inheriting from `DataClassJsonMixin`.
# It deserializes a JSON string into a data class, and supports additional functionality over JSONDecoder
if hasattr(expected_python_type, "from_json"):
dc = expected_python_type.from_json(json_str) # type: ignore
else:
# The function looks up or creates a JSONDecoder specifically designed for the object's type.
# This decoder is then used to convert a JSON string into a data class.
try:
decoder = self._decoder[expected_python_type]
except KeyError:
decoder = JSONDecoder(expected_python_type)
self._decoder[expected_python_type] = decoder

dc = decoder.decode(json_str)
dc = decoder.decode(json_str)

dc = self._fix_structured_dataset_type(expected_python_type, dc)
return self._fix_dataclass_int(expected_python_type, dc)
Expand Down Expand Up @@ -696,11 +706,22 @@ def tag(expected_python_type: Type[T]) -> str:

def get_literal_type(self, t: Type[T]) -> LiteralType:
return LiteralType(simple=SimpleType.STRUCT, metadata={ProtobufTransformer.PB_FIELD_KEY: self.tag(t)})

def _handle_list_literal(self, ctx: FlyteContext, elems: list) -> Literal:
    """
    Build a collection Literal out of a list of python values.

    An empty list yields an empty collection. Otherwise the python type of the first element is
    used for every element, both to look up the literal type and to convert each value.

    :param ctx: the Flyte context handed to ``TypeEngine.to_literal``
    :param elems: the python values to convert
    :return: a Literal wrapping a LiteralCollection of the converted elements
    """
    converted = []
    if elems:
        # The first element decides the type used for the whole list.
        elem_type = type(elems[0])
        elem_literal_type = TypeEngine.to_literal_type(elem_type)
        converted = [TypeEngine.to_literal(ctx, elem, elem_type, elem_literal_type) for elem in elems]
    return Literal(collection=LiteralCollection(literals=converted))

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
struct = Struct()
try:
struct.update(_MessageToDict(cast(Message, python_val)))
message_dict = _MessageToDict(cast(Message, python_val))
if isinstance(message_dict, list):
return self._handle_list_literal(ctx, message_dict)
struct.update(message_dict)
except Exception:
raise TypeTransformerFailedError("Failed to convert to generic protobuf struct")
return Literal(scalar=Scalar(generic=struct))
Expand Down Expand Up @@ -1051,7 +1072,7 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type
"actual attribute that you want to use. For example, in NamedTuple('OP', x=int) then"
"return v.x, instead of v, even if this has a single element"
)
if python_val is None and expected and expected.union_type is None:
if (python_val is None and python_type != NoneType) and expected and expected.union_type is None:
raise TypeTransformerFailedError(f"Python value cannot be None, expected {python_type}/{expected}")
transformer = cls.get_transformer(python_type)
if transformer.type_assertions_enabled:
Expand Down
76 changes: 76 additions & 0 deletions test_dataclass_elem_list_construction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from flytekit import task, dynamic, workflow
from dataclasses import dataclass
from mashumaro.mixins.json import DataClassJSONMixin


@dataclass
class IntWrapper(DataClassJSONMixin):
    """Dataclass with a single integer field ``x``."""

    x: int

@task
def get_int() -> int:
    """Task that returns the constant integer 3."""
    return 3

@task
def get_wrapped_int() -> IntWrapper:
    """Task that returns an ``IntWrapper`` whose ``x`` is 3."""
    wrapped = IntWrapper(x=3)
    return wrapped

@task
def sum_list(input_list: list[int]) -> int:
    """Task that returns the sum of the integers in ``input_list``."""
    total = sum(input_list)
    return total


@dataclass
class StrWrapper(DataClassJSONMixin):
    """Dataclass with a single string field ``x``."""

    x: str

@task
def get_str() -> str:
    """Task that returns the constant string "5"."""
    return "5"

@task
def get_wrapped_str() -> StrWrapper:
    """Task that returns a ``StrWrapper`` whose ``x`` is "3"."""
    wrapped = StrWrapper(x="3")
    return wrapped

@task
def concat_list(input_list: list[str]) -> str:
    """Task that joins the strings in ``input_list`` with no separator."""
    joined = "".join(input_list)
    return joined



@workflow
def convert_list_workflow1() -> int:
    """Sum a list that mixes a literal int with the promised output of ``get_int``."""
    promised_int = get_int()
    # The list holds a plain value next to a task output; it is passed as a single list input.
    joined_list = [4, promised_int]
    return sum_list(input_list=joined_list)

@workflow
def convert_list_workflow2() -> int:
    """Sum a list that mixes a literal int with the ``x`` attribute of the output of ``get_wrapped_int``."""
    wrapped = get_wrapped_int()
    # The second element is an attribute access on the task output.
    return sum_list(input_list=[4, wrapped.x])

@workflow
def convert_list_workflow3() -> str:
    """Concatenate a list that mixes a literal string with the promised output of ``get_str``."""
    promised_str = get_str()
    # The list holds a plain value next to a task output; it is passed as a single list input.
    joined_list = ["4", promised_str]
    return concat_list(input_list=joined_list)

@workflow
def convert_list_workflow4() -> str:
    """Concatenate a list that mixes a literal string with the ``x`` attribute of the output of ``get_wrapped_str``."""
    wrapped = get_wrapped_str()
    # The second element is an attribute access on the task output.
    return concat_list(input_list=["4", wrapped.x])


if __name__ == "__main__":
    # Execute each example workflow in order, printing its label and then its result.
    runs = (
        ("Run 1", convert_list_workflow1),
        ("Run 2", convert_list_workflow2),
        ("Run 3", convert_list_workflow3),
        ("Run 4", convert_list_workflow4),
    )
    for label, run_workflow in runs:
        print(label)
        print(run_workflow())