Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flytekit][3] [list, dict and nested cases] Binary IDL With MessagePack #2758

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 42 additions & 21 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
from abc import ABC, abstractmethod
from collections import OrderedDict
from functools import lru_cache
from typing import Dict, List, NamedTuple, Optional, Type, cast
from typing import Any, Dict, List, NamedTuple, Optional, Type, cast

import msgpack
from dataclasses_json import DataClassJsonMixin, dataclass_json
from flyteidl.core import literals_pb2
from google.protobuf import json_format as _json_format
Expand All @@ -26,6 +27,7 @@
from google.protobuf.message import Message
from google.protobuf.struct_pb2 import Struct
from mashumaro.codecs.json import JSONDecoder, JSONEncoder
from mashumaro.codecs.msgpack import MessagePackDecoder, MessagePackEncoder
from mashumaro.mixins.json import DataClassJSONMixin
from typing_extensions import Annotated, get_args, get_origin

Expand All @@ -42,22 +44,21 @@
from flytekit.models import types as _type_models
from flytekit.models.annotation import TypeAnnotation as TypeAnnotationModel
from flytekit.models.core import types as _core_types
from flytekit.models.literals import (
Literal,
LiteralCollection,
LiteralMap,
Primitive,
Scalar,
Union,
Void,
)
from flytekit.models.literals import Binary, Literal, LiteralCollection, LiteralMap, Primitive, Scalar, Union, Void
from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType

T = typing.TypeVar("T")
DEFINITIONS = "definitions"
TITLE = "title"


# In Mashumaro, the default encoder uses strict_map_key=False, while the default decoder uses strict_map_key=True.
# This is relevant for cases like Dict[int, str].
# If strict_map_key=False is not used, the decoder will raise an error when trying to decode keys that are not strictly typed.
def _default_flytekit_decoder(data: bytes) -> Any:
    """Unpack MessagePack-encoded bytes into plain Python objects.

    The payload is unpacked with ``raw=False`` and ``strict_map_key=False``;
    the latter lets maps with non-string keys (e.g. the keys of a
    ``Dict[int, str]``) be unpacked instead of being rejected.

    :param data: MessagePack-encoded payload.
    :return: The unpacked Python object.
    """
    unpacked = msgpack.unpackb(
        data,
        raw=False,
        strict_map_key=False,
    )
    return unpacked

Check warning on line 59 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L59

Added line #L59 was not covered by tests


class BatchSize:
"""
This is used to annotate a FlyteDirectory when we want to download/upload the contents of the directory in batches. For example,
Expand Down Expand Up @@ -129,6 +130,8 @@
self._t = t
self._name = name
self._type_assertions_enabled = enable_type_assertions
self._msgpack_encoder: Dict[Type, MessagePackEncoder] = dict()
self._msgpack_decoder: Dict[Type, MessagePackDecoder] = dict()

@property
def name(self):
Expand Down Expand Up @@ -221,6 +224,17 @@
f"Conversion to python value expected type {expected_python_type} from literal not implemented"
)

def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> Optional[T]:
if binary_idl_object.tag == "msgpack":
try:
decoder = self._msgpack_decoder[expected_python_type]
except KeyError:
decoder = MessagePackDecoder(expected_python_type, pre_decoder_func=_default_flytekit_decoder)
self._msgpack_decoder[expected_python_type] = decoder
return decoder.decode(binary_idl_object.value)

Check warning on line 234 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L229-L234

Added lines #L229 - L234 were not covered by tests
else:
raise TypeTransformerFailedError(f"Unsupported binary format {binary_idl_object.tag}")

Check warning on line 236 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L236

Added line #L236 was not covered by tests

def to_html(self, ctx: FlyteContext, python_val: T, expected_python_type: Type[T]) -> str:
"""
Converts any python val (dataframe, int, float) to a html string, and it will be wrapped in the HTML div
Expand Down Expand Up @@ -271,6 +285,9 @@
f"Cannot convert to type {expected_python_type}, only {self._type} is supported"
)

if lv.scalar and lv.scalar.binary:
return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore

Check warning on line 289 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L289

Added line #L289 was not covered by tests

try: # todo(maximsmol): this is quite ugly and each transformer should really check their Literal
res = self._from_literal_transformer(lv)
if type(res) != self._type:
Expand Down Expand Up @@ -366,8 +383,8 @@

def __init__(self):
super().__init__("Object-Dataclass-Transformer", object)
self._encoder: Dict[Type, JSONEncoder] = {}
self._decoder: Dict[Type, JSONDecoder] = {}
self._encoder: Dict[Type, JSONEncoder] = dict()
self._decoder: Dict[Type, JSONDecoder] = dict()

def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T):
# Skip iterating all attributes in the dataclass if the type of v already matches the expected_type
Expand Down Expand Up @@ -1390,6 +1407,9 @@
return Literal(collection=LiteralCollection(literals=lit_list))

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[typing.Any]: # type: ignore
if lv and lv.scalar and lv.scalar.binary is not None:
return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore

Check warning on line 1411 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L1411

Added line #L1411 was not covered by tests

try:
lits = lv.collection.literals
except AttributeError:
Expand Down Expand Up @@ -1671,8 +1691,8 @@

class DictTransformer(TypeTransformer[dict]):
"""
Transformer that transforms a univariate dictionary Dict[str, T] to a Literal Map or
transforms a untyped dictionary to a JSON (struct/Generic)
Transformer that transforms a univariate dictionary Dict[str, T] to a Literal Map or
transforms an untyped dictionary to a Binary Scalar Literal with a Struct Literal Type.
"""

def __init__(self):
Expand All @@ -1697,17 +1717,15 @@
return None, None

@staticmethod
def dict_to_generic_literal(ctx: FlyteContext, v: dict, allow_pickle: bool) -> Literal:
def dict_to_binary_literal(ctx: FlyteContext, v: dict, allow_pickle: bool) -> Literal:
"""
Creates a flyte-specific ``Literal`` value from a native python dictionary.
"""
from flytekit.types.pickle import FlytePickle

try:
return Literal(
scalar=Scalar(generic=_json_format.Parse(json.dumps(v), _struct.Struct())),
metadata={"format": "json"},
)
msgpack_bytes = msgpack.dumps(v)
return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack")))

Check warning on line 1728 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L1727-L1728

Added lines #L1727 - L1728 were not covered by tests
except TypeError as e:
if allow_pickle:
remote_path = FlytePickle.to_pickle(ctx, v)
Expand All @@ -1717,7 +1735,7 @@
),
metadata={"format": "pickle"},
)
raise e
raise TypeTransformerFailedError(f"Cannot convert {v} to Flyte Literal.\n" f"Error Message: {e}")

Check warning on line 1738 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L1738

Added line #L1738 was not covered by tests

@staticmethod
def is_pickle(python_type: Type[dict]) -> typing.Tuple[bool, Type]:
Expand Down Expand Up @@ -1768,7 +1786,7 @@
allow_pickle, base_type = DictTransformer.is_pickle(python_type)

if expected and expected.simple and expected.simple == SimpleType.STRUCT:
return self.dict_to_generic_literal(ctx, python_val, allow_pickle)
return self.dict_to_binary_literal(ctx, python_val, allow_pickle)

Check warning on line 1789 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L1789

Added line #L1789 was not covered by tests

lit_map = {}
for k, v in python_val.items():
Expand All @@ -1785,6 +1803,9 @@
return Literal(map=LiteralMap(literals=lit_map))

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[dict]) -> dict:
if lv and lv.scalar and lv.scalar.binary is not None:
return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore

Check warning on line 1807 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L1807

Added line #L1807 was not covered by tests

if lv and lv.map and lv.map.literals is not None:
tp = self.dict_types(expected_python_type)

Expand Down
3 changes: 3 additions & 0 deletions plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from datetime import datetime, timedelta
from unittest import mock
import msgpack
import base64

import pytest
from flyteidl.core.execution_pb2 import TaskExecution
Expand Down Expand Up @@ -161,6 +163,7 @@ async def test_agent(mock_boto_call, mock_return_value):
if "pickle_check" in mock_return_value[0][0]:
assert "pickle_file" in outputs["result"]
else:
outputs["result"] = msgpack.loads(base64.b64decode(outputs["result"]))
assert (
outputs["result"]["EndpointConfigArn"]
== "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config"
Expand Down
5 changes: 3 additions & 2 deletions plugins/flytekit-openai/tests/openai_batch/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from datetime import timedelta
from unittest import mock
from unittest.mock import AsyncMock

import msgpack
import base64
import pytest
from flyteidl.core.execution_pb2 import TaskExecution
from flytekitplugins.openai.batch.agent import BatchEndpointMetadata
Expand Down Expand Up @@ -159,7 +160,7 @@ async def test_openai_batch_agent(mock_retrieve, mock_create, mock_context):
outputs = literal_map_string_repr(resource.outputs)
result = outputs["result"]

assert result == batch_retrieve_result.to_dict()
assert msgpack.loads(base64.b64decode(result)) == batch_retrieve_result.to_dict()

# Status: Failed
mock_retrieve.return_value = batch_retrieve_result_failure
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dependencies = [
"marshmallow-enum",
"marshmallow-jsonschema>=0.12.0",
"mashumaro>=3.11",
"msgpack>=1.1.0",
"protobuf!=4.25.0",
"pygments",
"python-json-logger>=2.0.0",
Expand Down
2 changes: 1 addition & 1 deletion tests/flytekit/unit/core/test_local_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ def test_stable_cache_key():
}
)
key = _calculate_cache_key("task_name_1", "31415", lm)
assert key == "task_name_1-31415-189e755a8f41c006889c291fcaedb4eb"
assert key == "task_name_1-31415-e3a85f91467d1e1f721ebe8129b2de31"


@pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.")
Expand Down
4 changes: 2 additions & 2 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3218,10 +3218,10 @@ def test_union_file_directory():
(typing.List[Color], [Color.RED, Color.GREEN, Color.BLUE]),
(typing.List[Annotated[int, "tag"]], [1, 2, 3]),
(typing.List[Annotated[str, "tag"]], ["a", "b", "c"]),
(typing.Dict[int, str], {"1": "a", "2": "b", "3": "c"}),
(typing.Dict[int, str], {1: "a", 2: "b", 3: "c"}),
(typing.Dict[str, int], {"a": 1, "b": 2, "c": 3}),
(typing.Dict[str, typing.List[int]], {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}),
(typing.Dict[str, typing.Dict[int, str]], {"a": {"1": "a", "2": "b", "3": "c"}, "b": {"4": "d", "5": "e", "6": "f"}}),
(typing.Dict[str, typing.Dict[int, str]], {"a": {1: "a", 2: "b", 3: "c"}, "b": {4: "d", 5: "e", 6: "f"}}),
(typing.Union[int, str], 42),
(typing.Union[int, str], "hello"),
(typing.Union[typing.List[int], typing.List[str]], [1, 2, 3]),
Expand Down
Loading
Loading