[flytekit][1][SimpleTransformer] Binary IDL With MessagePack #2756

Closed
wants to merge 6 commits into from

Member

@Future-Outlier Future-Outlier commented Sep 18, 2024

Tracking issue

flyteorg/flyte#5318

Why are the changes needed?

These changes enable the SimpleTransformer to use from_binary_idl when receiving binary IDL inputs. This will be particularly useful when dealing with attribute access.

What changes were proposed in this pull request?

  1. Add the from_binary_idl method to the TypeTransformer class.
  2. Use MessagePackDecoder to decode msgpack bytes into Python values.
  3. Add _default_flytekit_decoder as the pre_decoder_func of MessagePackDecoder to support future cases such as Dict[int, str] (by setting strict_map_key=False).

How was this patch tested?

Unit tests, local execution, and remote execution.

Setup process

from dataclasses import dataclass, field
from datetime import datetime, date, timedelta
from flytekit import workflow, task, ImageSpec

flytekit_hash = "c24077bce6e63bf8df0d80dbc2c5e2ff3322bca8"

flytekit = f"git+https://github.com/flyteorg/flytekit.git@{flytekit_hash}"

image = ImageSpec(
    packages=[flytekit],
    apt_packages=["git"],
    registry="localhost:30000",
)

@dataclass
class DC:
    a: int = 1
    b: float = 1.0
    c: bool = True
    d: str = "hello"
    e: datetime = field(default_factory=datetime.now)
    f: date = field(default_factory=date.today)
    g: timedelta = field(default_factory=lambda: timedelta(days=1))

@task(container_image=image)
def t_a(a: int):
    print(a)
    assert type(a) == int

@task(container_image=image)
def t_b(b: float):
    print(b)
    assert type(b) == float

@task(container_image=image)
def t_c(c: bool):
    print(c)
    assert(type(c) == bool)

@task(container_image=image)
def t_d(d: str):
    print(d)
    assert(type(d) == str)

@task(container_image=image)
def t_e(e: datetime):
    print(e)
    assert(type(e) == datetime)

@task(container_image=image)
def t_f(f: date):
    print(f)
    assert(type(f) == date)

@task(container_image=image)
def t_g(g: timedelta):
    print(g)
    assert(type(g) == timedelta)

@workflow
def dc_wf(dc: DC):
    t_a(dc.a)
    t_b(dc.b)
    t_c(dc.c)
    t_d(dc.d)
    t_e(dc.e)
    t_f(dc.f)
    t_g(dc.g)

if __name__ == "__main__":
    from flytekit.clis.sdk_in_container import pyflyte
    from click.testing import CliRunner
    import os

    runner = CliRunner()
    path = os.path.realpath(__file__)
    input_val = '{"a": 1, "b": 3.14, "c": true, "d": "hello"}'
    result = runner.invoke(pyflyte.main,
                           ["run", path, "dc_wf", "--dc", input_val])

    print("Local Execution: ", result.output)
    #
    result = runner.invoke(pyflyte.main,
                           ["run", "--remote", path,
                            "dc_wf", "--dc", input_val])
    print("Remote Execution: ", result.output)

Screenshots

  • local execution
image
  • remote execution
image

Check all the applicable boxes

  • I updated the documentation accordingly.
  • All new and existing tests passed.
  • All commits are signed-off.

codecov bot commented Sep 18, 2024

Codecov Report

Attention: Patch coverage is 42.10526% with 11 lines in your changes missing coverage. Please review.

Project coverage is 75.84%. Comparing base (7f54171) to head (e3a258a).
Report is 6 commits behind head on master.

Files with missing lines Patch % Lines
flytekit/core/type_engine.py 42.10% 10 Missing and 1 partial ⚠️
Additional details and impacted files
@@             Coverage Diff              @@
##            master    #2756       +/-   ##
============================================
- Coverage   100.00%   75.84%   -24.16%     
============================================
  Files            5      194      +189     
  Lines          122    19784    +19662     
  Branches         0     3899     +3899     
============================================
+ Hits           122    15005    +14883     
- Misses           0     4100     +4100     
- Partials         0      679      +679     


Signed-off-by: Future-Outlier <[email protected]>
Signed-off-by: Future-Outlier <[email protected]>
Contributor

@wild-endeavor wild-endeavor left a comment

could you help me understand the usage of mashumaro please?

from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType

T = typing.TypeVar("T")
DEFINITIONS = "definitions"
TITLE = "title"


# In Mashumaro, the default encoder uses strict_map_key=False, while the default decoder uses strict_map_key=True.
Contributor

this function doesn't have anything to do with Mashumaro right? why mention mashumaro in the comments?

Member Author

Because we want to use it in our mashumaro decoder.

Member Author

@Future-Outlier Future-Outlier Sep 19, 2024

This function will be put into _default_flytekit_decoder.
For example, if we access dataclass.dict_int_str in a workflow, we will use from_binary_idl here to turn the Binary IDL object into Dict[int, str].

Note:

  1. dict_int_str is Dict[int, str].
  2. Dict[int, str] is a non-strict type: its keys are not str or bytes, so msgpack must decode it with strict_map_key=False.

flytekit/core/type_engine.py Outdated Show resolved Hide resolved
try:
    decoder = self._msgpack_decoder[expected_python_type]
except KeyError:
    decoder = MessagePackDecoder(expected_python_type, pre_decoder_func=_default_flytekit_decoder)
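
The lookup-then-construct pattern in this hunk can be sketched in isolation. The following is a minimal stdlib-only illustration, not flytekit's code: `DecoderCache` and `make_decoder` are hypothetical names, and storing the newly built decoder back into the dictionary is an assumption, since the hunk does not show what follows the `except` branch.

```python
from typing import Any, Callable, Dict, Hashable


class DecoderCache:
    """Build one decoder per expected type and reuse it on later lookups."""

    def __init__(self, make_decoder: Callable[[Any], Any]) -> None:
        self._make_decoder = make_decoder
        self._decoders: Dict[Hashable, Any] = {}

    def get(self, expected_type: Any) -> Any:
        try:
            return self._decoders[expected_type]
        except KeyError:
            # Assumption: the freshly built decoder is cached for next time.
            decoder = self._make_decoder(expected_type)
            self._decoders[expected_type] = decoder
            return decoder


built = []  # records every type a decoder was actually constructed for


def make_decoder(expected_type):
    # Hypothetical stand-in for constructing a real decoder object.
    built.append(expected_type)
    return lambda raw: expected_type(raw)


cache = DecoderCache(make_decoder)
first = cache.get(float)
second = cache.get(float)
print(first is second, built)
```

The second lookup returns the cached object, so the (potentially expensive) decoder construction happens once per type.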
Contributor

This part isn't clear to me... why are we using the MessagePackDecoder from Mashumaro? isn't that just for dataclasses?

Member Author

> This part isn't clear to me... why are we using the MessagePackDecoder from Mashumaro? isn't that just for dataclasses?

We use MessagePackDecoder because when accessing attributes from a dataclass, we receive a Binary IDL from propeller, which can be any type (int, float, bool, str, list, dict, dataclass, Pydantic BaseModel, or Flyte types).

The expected flow is: Binary IDL -> msgpack bytes -> python val.

MessagePackDecoder(expected_python_type).decode is more reliable than msgpack.loads because it guarantees the resulting type is always correct.
(It takes expected_python_type as a hint, so it can handle cases such as
expected type: float, actual type: int, and convert the value back to float.)
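
The float/int case mentioned above can be illustrated in plain Python. This is only a sketch of the idea of decoding against an expected type; `decode_with_hint` is a hypothetical helper and does not reflect how mashumaro implements it.

```python
from typing import Any


def decode_with_hint(value: Any, expected_type: type) -> Any:
    """Return `value` as `expected_type`, widening int to float when asked."""
    # bool is a subclass of int, so exclude it from the numeric widening.
    if expected_type is float and isinstance(value, int) and not isinstance(value, bool):
        return float(value)
    if type(value) is not expected_type:
        raise TypeError(
            f"expected {expected_type.__name__}, got {type(value).__name__}"
        )
    return value


untyped = 3  # what a schema-less decode might hand back for a float field
typed = decode_with_hint(untyped, float)
print(type(untyped).__name__, "->", type(typed).__name__)
```

Without the hint, the downstream task annotated with `float` would receive an `int`; with it, the value is normalized before the task sees it.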

Collaborator

From a purist point of view what you're saying is true, @wild-endeavor, but given that mashumaro is already a dependency, and given that msgpack.loads is unable to unmarshal some values, I'm in favor of keeping this more complex implementation of the top-level from_binary_idl.

Simple cases like this one fail when using msgpack directly (here, msgpack.dumps on a datetime):

Python 3.12.5 (main, Aug 14 2024, 04:32:18) [Clang 18.1.8 ] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> from datetime import datetime
>>> from mashumaro.codecs.msgpack import MessagePackEncoder
>>> import msgpack
>>> encoder = MessagePackEncoder(type(datetime.now()))
>>> encoder.encode(datetime.now())
b'\xba2024-09-24T21:52:43.704551'
>>> msgpack.dumps(datetime.now())
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/eduardo/repos/flyte-examples/.venv/lib/python3.12/site-packages/msgpack/__init__.py", line 36, in packb
    return Packer(**kwargs).pack(o)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "msgpack/_packer.pyx", line 279, in msgpack._cmsgpack.Packer.pack
  File "msgpack/_packer.pyx", line 276, in msgpack._cmsgpack.Packer.pack
  File "msgpack/_packer.pyx", line 270, in msgpack._cmsgpack.Packer._pack
  File "msgpack/_packer.pyx", line 257, in msgpack._cmsgpack.Packer._pack_inner
TypeError: can not serialize 'datetime.datetime' object
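
The encoder output in this session, `b'\xba2024-09-24T21:52:43.704551'`, is a MessagePack fixstr: a header byte `0xa0 | length` followed by the ISO 8601 text. The stdlib-only sketch below reproduces those bytes by hand and shows why the decoder needs the expected type to get a `datetime` back rather than a `str`. `pack_fixstr` and `unpack_fixstr` are hypothetical helpers written for illustration, not part of msgpack or mashumaro.

```python
from datetime import datetime


def pack_fixstr(text: str) -> bytes:
    """Encode a short string as a MessagePack fixstr (at most 31 bytes)."""
    raw = text.encode("utf-8")
    if len(raw) > 31:
        raise ValueError("fixstr holds at most 31 bytes")
    return bytes([0xA0 | len(raw)]) + raw


def unpack_fixstr(data: bytes) -> str:
    """Decode a MessagePack fixstr back into a Python str."""
    header = data[0]
    if header & 0xE0 != 0xA0:
        raise ValueError("not a fixstr")
    length = header & 0x1F
    return data[1 : 1 + length].decode("utf-8")


stamp = datetime(2024, 9, 24, 21, 52, 43, 704551)
packed = pack_fixstr(stamp.isoformat())
print(packed)

# The wire format only carries a string; nothing marks it as a datetime.
text = unpack_fixstr(packed)
# Knowing the expected type is what lets the value be turned back into a datetime.
restored = datetime.fromisoformat(text)
print(type(text).__name__, "->", type(restored).__name__)
```

Because the bytes carry no type tag for the timestamp, a decoder that is told the expected Python type can recover it, while an untyped decode stops at the string.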

Signed-off-by: Future-Outlier <[email protected]>