Skip to content

Commit

Permalink
Support default values in typing.List[dataclass] and typing.Dict[data…
Browse files Browse the repository at this point in the history
…class] (flyteorg#2603)

* fix: set dataclass member as optional if default value is provided

Signed-off-by: mao3267 <[email protected]>

* lint

Signed-off-by: mao3267 <[email protected]>

* feat: handle nested dataclass conversion in JsonParamType

Signed-off-by: mao3267 <[email protected]>

* fix: handle errors caused by NoneType default value

Signed-off-by: mao3267 <[email protected]>

* test: add nested dataclass unit tests

Signed-off-by: mao3267 <[email protected]>

* Sagemaker dict determinism (flyteorg#2597)

* truncate sagemaker agent outputs

Signed-off-by: Samhita Alla <[email protected]>

* fix tests and update agent output

Signed-off-by: Samhita Alla <[email protected]>

* lint

Signed-off-by: Samhita Alla <[email protected]>

* fix test

Signed-off-by: Samhita Alla <[email protected]>

* add idempotence token to workflow

Signed-off-by: Samhita Alla <[email protected]>

* fix type

Signed-off-by: Samhita Alla <[email protected]>

* fix mixin

Signed-off-by: Samhita Alla <[email protected]>

* modify output handler

Signed-off-by: Samhita Alla <[email protected]>

* make the dictionary deterministic

Signed-off-by: Samhita Alla <[email protected]>

* nit

Signed-off-by: Samhita Alla <[email protected]>

---------

Signed-off-by: Samhita Alla <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* refactor(core): Enhance return type extraction logic (flyteorg#2598)

Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* Feat: Make exception raised by external command authenticator more actionable (flyteorg#2594)

Signed-off-by: Fabio Grätz <[email protected]>
Co-authored-by: Fabio Grätz <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* Fix: Properly re-raise non-grpc exceptions during refreshing of proxy-auth credentials in auth interceptor (flyteorg#2591)

Signed-off-by: Fabio Grätz <[email protected]>
Co-authored-by: Fabio Grätz <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* validate idempotence token length in subsequent tasks (flyteorg#2604)

* validate idempotence token length in subsequent tasks

Signed-off-by: Samhita Alla <[email protected]>

* remove redundant param

Signed-off-by: Samhita Alla <[email protected]>

* add tests

Signed-off-by: Samhita Alla <[email protected]>

---------

Signed-off-by: Samhita Alla <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* Add nvidia-l4 gpu accelerator (flyteorg#2608)

Signed-off-by: Eduardo Apolinario <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* eliminate redundant literal conversion for `Iterator[JSON]` type (flyteorg#2602)

* eliminate redundant literal conversion for  type

Signed-off-by: Samhita Alla <[email protected]>

* add test

Signed-off-by: Samhita Alla <[email protected]>

* lint

Signed-off-by: Samhita Alla <[email protected]>

* add isclass check

Signed-off-by: Samhita Alla <[email protected]>

---------

Signed-off-by: Samhita Alla <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* [FlyteSchema] Fix numpy problems (flyteorg#2619)

Signed-off-by: Future-Outlier <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* add nim plugin (flyteorg#2475)

* add nim plugin

Signed-off-by: Samhita Alla <[email protected]>

* move nim to inference

Signed-off-by: Samhita Alla <[email protected]>

* import fix

Signed-off-by: Samhita Alla <[email protected]>

* fix port

Signed-off-by: Samhita Alla <[email protected]>

* add pod_template method

Signed-off-by: Samhita Alla <[email protected]>

* add containers

Signed-off-by: Samhita Alla <[email protected]>

* update

Signed-off-by: Samhita Alla <[email protected]>

* clean up

Signed-off-by: Samhita Alla <[email protected]>

* remove cloud import

Signed-off-by: Samhita Alla <[email protected]>

* fix extra config

Signed-off-by: Samhita Alla <[email protected]>

* remove decorator

Signed-off-by: Samhita Alla <[email protected]>

* add tests, update readme

Signed-off-by: Samhita Alla <[email protected]>

* add env

Signed-off-by: Samhita Alla <[email protected]>

* add support for lora adapter

Signed-off-by: Samhita Alla <[email protected]>

* minor fixes

Signed-off-by: Samhita Alla <[email protected]>

* add startup probe

Signed-off-by: Samhita Alla <[email protected]>

* increase failure threshold

Signed-off-by: Samhita Alla <[email protected]>

* remove ngc secret group

Signed-off-by: Samhita Alla <[email protected]>

* move plugin to flytekit core

Signed-off-by: Samhita Alla <[email protected]>

* fix docs

Signed-off-by: Samhita Alla <[email protected]>

* remove hf group

Signed-off-by: Samhita Alla <[email protected]>

* modify podtemplate import

Signed-off-by: Samhita Alla <[email protected]>

* fix import

Signed-off-by: Samhita Alla <[email protected]>

* fix ngc api key

Signed-off-by: Samhita Alla <[email protected]>

* fix tests

Signed-off-by: Samhita Alla <[email protected]>

* fix formatting

Signed-off-by: Samhita Alla <[email protected]>

* lint

Signed-off-by: Samhita Alla <[email protected]>

* docs fix

Signed-off-by: Samhita Alla <[email protected]>

* docs fix

Signed-off-by: Samhita Alla <[email protected]>

* update secrets interface

Signed-off-by: Samhita Alla <[email protected]>

* add secret prefix

Signed-off-by: Samhita Alla <[email protected]>

* fix tests

Signed-off-by: Samhita Alla <[email protected]>

* add urls

Signed-off-by: Samhita Alla <[email protected]>

* add urls

Signed-off-by: Samhita Alla <[email protected]>

* remove urls

Signed-off-by: Samhita Alla <[email protected]>

* minor modifications

Signed-off-by: Samhita Alla <[email protected]>

* remove secrets prefix; add failure threshold

Signed-off-by: Samhita Alla <[email protected]>

* add hard-coded prefix

Signed-off-by: Samhita Alla <[email protected]>

* add comment

Signed-off-by: Samhita Alla <[email protected]>

* make secrets prefix a required param

Signed-off-by: Samhita Alla <[email protected]>

* move nim to flytekit plugin

Signed-off-by: Samhita Alla <[email protected]>

* update readme

Signed-off-by: Samhita Alla <[email protected]>

* update readme

Signed-off-by: Samhita Alla <[email protected]>

* update readme

Signed-off-by: Samhita Alla <[email protected]>

---------

Signed-off-by: Samhita Alla <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* [Elastic/Artifacts] Pass through model card (flyteorg#2575)

Signed-off-by: Yee Hing Tong <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* Remove pyarrow as a direct dependency (flyteorg#2228)

Signed-off-by: Thomas J. Fan <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* Boolean flag to show local container logs to the terminal (flyteorg#2521)

Signed-off-by: aditya7302 <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* Enable Ray Fast Register (flyteorg#2606)

Signed-off-by: Jan Fiedler <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* [Artifacts/Elastic] Skip partitions (flyteorg#2620)

Signed-off-by: Yee Hing Tong <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* Install flyteidl from master in plugins tests (flyteorg#2621)

Signed-off-by: Eduardo Apolinario <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* Using ParamSpec to show underlying typehinting (flyteorg#2617)

Signed-off-by: JackUrb <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* Support ArrayNode mapping over Launch Plans (flyteorg#2480)

* set up array node

Signed-off-by: Paul Dittamo <[email protected]>

* wip array node task wrapper

Signed-off-by: Paul Dittamo <[email protected]>

* support function like callability

Signed-off-by: Paul Dittamo <[email protected]>

* temp check in some progress on python func wrapper

Signed-off-by: Paul Dittamo <[email protected]>

* only support launch plans in new array node class for now

Signed-off-by: Paul Dittamo <[email protected]>

* add map task array node implementation wrapper

Signed-off-by: Paul Dittamo <[email protected]>

* ArrayNode only supports LPs for now

Signed-off-by: Paul Dittamo <[email protected]>

* support local execute for new array node implementation

Signed-off-by: Paul Dittamo <[email protected]>

* add local execute unit tests for array node

Signed-off-by: Paul Dittamo <[email protected]>

* set exeucution version in array node spec

Signed-off-by: Paul Dittamo <[email protected]>

* check input types for local execute

Signed-off-by: Paul Dittamo <[email protected]>

* remove code that is un-needed for now

Signed-off-by: Paul Dittamo <[email protected]>

* clean up array node class

Signed-off-by: Paul Dittamo <[email protected]>

* improve naming

Signed-off-by: Paul Dittamo <[email protected]>

* clean up

Signed-off-by: Paul Dittamo <[email protected]>

* utilize enum execution mode to set array node execution path

Signed-off-by: Paul Dittamo <[email protected]>

* default execution mode to FULL_STATE for new array node class

Signed-off-by: Paul Dittamo <[email protected]>

* support min_successes for new array node

Signed-off-by: Paul Dittamo <[email protected]>

* add map task wrapper unit test

Signed-off-by: Paul Dittamo <[email protected]>

* set min successes for array node map task wrapper

Signed-off-by: Paul Dittamo <[email protected]>

* update docstrings

Signed-off-by: Paul Dittamo <[email protected]>

* Install flyteidl from master in plugins tests

Signed-off-by: Eduardo Apolinario <[email protected]>

* lint

Signed-off-by: Paul Dittamo <[email protected]>

* clean up min success/ratio setting

Signed-off-by: Paul Dittamo <[email protected]>

* lint

Signed-off-by: Paul Dittamo <[email protected]>

* make array node class callable

Signed-off-by: Paul Dittamo <[email protected]>

---------

Signed-off-by: Paul Dittamo <[email protected]>
Signed-off-by: Eduardo Apolinario <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* Richer printing for some artifact objects (flyteorg#2624)

Signed-off-by: Yee Hing Tong <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* ci: Add Python 3.9 to build matrix (flyteorg#2622)

Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Eduardo Apolinario <[email protected]>
Signed-off-by: Future-Outlier <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
Co-authored-by: Future-Outlier <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* bump (flyteorg#2627)

Signed-off-by: Yee Hing Tong <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* Added alt prefix head to FlyteFile.new_remote (flyteorg#2601)

* Added alt prefix head to FlyteFile.new_remote

Signed-off-by: pryce-turner <[email protected]>

* Added get_new_path method to FileAccessProvider, fixed new_remote method of FlyteFile

Signed-off-by: pryce-turner <[email protected]>

* Updated tests and added new path creator to FlyteFile/Dir new_remote methods

Signed-off-by: pryce-turner <[email protected]>

* Improved docstrings, fixed minor path sep bug, more descriptive naming, better test

Signed-off-by: pryce-turner <[email protected]>

* Formatting

Signed-off-by: pryce-turner <[email protected]>

---------

Signed-off-by: pryce-turner <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* Feature gate for FlyteMissingReturnValueException (flyteorg#2623)

Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* Remove use of multiprocessing from the OAuth client (flyteorg#2626)

* Remove use of multiprocessing from the OAuth client

Signed-off-by: Robert Deaton <[email protected]>

* Lint

Signed-off-by: Robert Deaton <[email protected]>

---------

Signed-off-by: Robert Deaton <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* Update codespell in precommit to version 2.3.0 (flyteorg#2630)

Signed-off-by: mao3267 <[email protected]>

* Fix Snowflake Agent Bug (flyteorg#2605)

* fix snowflake agent bug

Signed-off-by: Future-Outlier <[email protected]>

* a work version

Signed-off-by: Future-Outlier <[email protected]>

* Snowflake work version

Signed-off-by: Future-Outlier <[email protected]>

* fix secret encode

Signed-off-by: Future-Outlier <[email protected]>

* all works, I am so happy

Signed-off-by: Future-Outlier <[email protected]>

* improve additional protocol

Signed-off-by: Future-Outlier <[email protected]>

* fix tests

Signed-off-by: Future-Outlier <[email protected]>

* Fix Tests

Signed-off-by: Future-Outlier <[email protected]>

* update agent

Signed-off-by: Kevin Su <[email protected]>

* Add snowflake test

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* sd

Signed-off-by: Kevin Su <[email protected]>

* snowflake loglinks

Signed-off-by: Future-Outlier <[email protected]>

* add metadata

Signed-off-by: Future-Outlier <[email protected]>

* secret

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* remove table

Signed-off-by: Future-Outlier <[email protected]>

* add comment for get private key

Signed-off-by: Future-Outlier <[email protected]>

* update comments:

Signed-off-by: Future-Outlier <[email protected]>

* Fix Tests

Signed-off-by: Future-Outlier <[email protected]>

* update comments

Signed-off-by: Future-Outlier <[email protected]>

* update comments

Signed-off-by: Future-Outlier <[email protected]>

* Better Secrets

Signed-off-by: Future-Outlier <[email protected]>

* use union secret

Signed-off-by: Future-Outlier <[email protected]>

* Update Changes

Signed-off-by: Future-Outlier <[email protected]>

* use if not get_plugin().secret_requires_group()

Signed-off-by: Future-Outlier <[email protected]>

* Use Union SDK

Signed-off-by: Future-Outlier <[email protected]>

* Update

Signed-off-by: Future-Outlier <[email protected]>

* Fix Secrets

Signed-off-by: Future-Outlier <[email protected]>

* Fix Secrets

Signed-off-by: Future-Outlier <[email protected]>

* remove pacakge.json

Signed-off-by: Future-Outlier <[email protected]>

* lint

Signed-off-by: Future-Outlier <[email protected]>

* add snowflake-connector-python

Signed-off-by: Future-Outlier <[email protected]>

* fix test_snowflake

Signed-off-by: Future-Outlier <[email protected]>

* Try to fix tests

Signed-off-by: Future-Outlier <[email protected]>

* fix tests

Signed-off-by: Future-Outlier <[email protected]>

* Try Fix snowflake Import

Signed-off-by: Future-Outlier <[email protected]>

* snowflake test passed

Signed-off-by: Future-Outlier <[email protected]>

---------

Signed-off-by: Future-Outlier <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* run test_missing_return_value on python 3.10+ (flyteorg#2637)

Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* [Elastic] Fix context usage and apply fix to fork method (flyteorg#2628)

Signed-off-by: Yee Hing Tong <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* Add flytekit-omegaconf plugin (flyteorg#2299)

* add flytekit-hydra

Signed-off-by: mg515 <[email protected]>

* fix small typo readme

Signed-off-by: mg515 <[email protected]>

* ruff ruff

Signed-off-by: mg515 <[email protected]>

* lint more

Signed-off-by: mg515 <[email protected]>

* rename plugin into flytekit-omegaconf

Signed-off-by: mg515 <[email protected]>

* lint sort imports

Signed-off-by: mg515 <[email protected]>

* use flytekit logger

Signed-off-by: mg515 <[email protected]>

* use flytekit logger #2

Signed-off-by: mg515 <[email protected]>

* fix typing info in is_flatable

Signed-off-by: mg515 <[email protected]>

* use default_factory instead of mutable default value

Signed-off-by: mg515 <[email protected]>

* add python3.11 and python3.12 to setup.py

Signed-off-by: mg515 <[email protected]>

* make fmt

Signed-off-by: mg515 <[email protected]>

* define error message only once

Signed-off-by: mg515 <[email protected]>

* add docstring

Signed-off-by: mg515 <[email protected]>

* remove GenericEnumTransformer and tests

Signed-off-by: mg515 <[email protected]>

* fallback to TypeEngine.get_transformer(node_type) to find suitable transformer

Signed-off-by: mg515 <[email protected]>

* explicit valueerrors instead of asserts

Signed-off-by: mg515 <[email protected]>

* minor style improvements

Signed-off-by: mg515 <[email protected]>

* remove obsolete warnings

Signed-off-by: mg515 <[email protected]>

* import flytekit logger instead of instantiating our own

Signed-off-by: mg515 <[email protected]>

* docstrings in reST format

Signed-off-by: mg515 <[email protected]>

* refactor transformer mode

Signed-off-by: mg515 <[email protected]>

* improve docs

Signed-off-by: mg515 <[email protected]>

* refactor dictconfig class into smaller methods

Signed-off-by: mg515 <[email protected]>

* add unit tests for dictconfig transformer

Signed-off-by: mg515 <[email protected]>

* refactor of parse_type_description()

Signed-off-by: mg515 <[email protected]>

* add omegaconf plugin to pythonbuild.yaml

---------

Signed-off-by: mg515 <[email protected]>
Signed-off-by: Eduardo Apolinario <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* Adds extra-index-url to default image builder (flyteorg#2636)

Signed-off-by: Thomas J. Fan <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* reference_task should inherit from PythonTask (flyteorg#2643)

Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* Fix Get Agent Secret Using Key (flyteorg#2644)

Signed-off-by: Future-Outlier <[email protected]>
Signed-off-by: mao3267 <[email protected]>

* fix: prevent converting Flyte types as custom dataclasses

Signed-off-by: mao3267 <[email protected]>

* fix: add None to output type

Signed-off-by: mao3267 <[email protected]>

* test: add unit test for nested dataclass inputs

Signed-off-by: mao3267 <[email protected]>

* test: add unit tests for nested dataclass, dataclass default value as None, and flyte type exceptions

Signed-off-by: mao3267 <[email protected]>

* fix: handle NoneType as default value of list type dataclass members

Signed-off-by: mao3267 <[email protected]>

* fix: add comments for `has_nested_dataclass` function

Signed-off-by: mao3267 <[email protected]>

* fix: make lint

Signed-off-by: mao3267 <[email protected]>

* fix: update tests regarding input through file and pipe

Signed-off-by: mao3267 <[email protected]>

* Make JsonParamType convert faster

Signed-off-by: Future-Outlier <[email protected]>

* make has_nested_dataclass func more clean and add tests for dataclass_with_optional_fields

Signed-off-by: Future-Outlier <[email protected]>

* make logic more backward compatible

Signed-off-by: Future-Outlier <[email protected]>

* fix: handle indexing errors in dict/list while checking nested dataclass, add comments

Signed-off-by: mao3267 <[email protected]>

---------

Signed-off-by: mao3267 <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
Co-authored-by: Future-Outlier <[email protected]>
  • Loading branch information
3 people authored Aug 26, 2024
1 parent 54f0a46 commit 83b90fa
Show file tree
Hide file tree
Showing 7 changed files with 315 additions and 4 deletions.
9 changes: 8 additions & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T):

expected_type = get_underlying_type(expected_type)
expected_fields_dict = {}

for f in dataclasses.fields(expected_type):
expected_fields_dict[f.name] = f.type

Expand Down Expand Up @@ -539,11 +540,13 @@ def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]:
field.type = self._get_origin_type_in_annotation(field.type)
return python_type

def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.Any) -> T:
def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.Any) -> T | None:
# In python 3.7, 3.8, DataclassJson will deserialize Annotated[StructuredDataset, kwtypes(..)] to a dict,
# so here we convert it back to the Structured Dataset.
from flytekit.types.structured import StructuredDataset

if python_val is None:
return python_val
if python_type == StructuredDataset and type(python_val) == dict:
return StructuredDataset(**python_val)
elif get_origin(python_type) is list:
Expand Down Expand Up @@ -575,9 +578,13 @@ def _make_dataclass_serializable(self, python_val: T, python_type: Type[T]) -> t
return self._make_dataclass_serializable(python_val, get_args(python_type)[0])

if hasattr(python_type, "__origin__") and get_origin(python_type) is list:
if python_val is None:
return None
return [self._make_dataclass_serializable(v, get_args(python_type)[0]) for v in cast(list, python_val)]

if hasattr(python_type, "__origin__") and get_origin(python_type) is dict:
if python_val is None:
return None
return {
k: self._make_dataclass_serializable(v, get_args(python_type)[1])
for k, v in cast(dict, python_val).items()
Expand Down
43 changes: 42 additions & 1 deletion flytekit/interaction/click_types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import dataclasses
import datetime
import enum
import json
import logging
import os
import pathlib
import typing
from typing import cast
from typing import cast, get_args

import rich_click as click
import yaml
Expand All @@ -22,6 +23,7 @@
from flytekit.types.file import FlyteFile
from flytekit.types.iterator.json_iterator import JSONIteratorTransformer
from flytekit.types.pickle.pickle import FlytePickleTransformer
from flytekit.types.schema.types import FlyteSchema


def is_pydantic_basemodel(python_type: typing.Type) -> bool:
Expand Down Expand Up @@ -305,11 +307,50 @@ def convert(
if value is None:
raise click.BadParameter("None value cannot be converted to a Json type.")

FLYTE_TYPES = [FlyteFile, FlyteDirectory, StructuredDataset, FlyteSchema]

def has_nested_dataclass(t: typing.Type) -> bool:
"""
Recursively checks whether the given type or its nested types contain any dataclass.
This function is typically called with a dictionary or list type and will return True if
any of the nested types within the dictionary or list is a dataclass.
Note:
- A single dataclass will return True.
- The function specifically excludes certain Flyte types like FlyteFile, FlyteDirectory,
StructuredDataset, and FlyteSchema from being considered as dataclasses. This is because
these types are handled separately by Flyte and do not need to be converted to dataclasses.
Args:
t (typing.Type): The type to check for nested dataclasses.
Returns:
bool: True if the type or its nested types contain a dataclass, False otherwise.
"""

if dataclasses.is_dataclass(t):
# FlyteTypes is not supported now, we can support it in the future.
return t not in FLYTE_TYPES

return any(has_nested_dataclass(arg) for arg in get_args(t))

parsed_value = self._parse(value, param)

# We compare the origin type because the json parsed value for list or dict is always a list or dict without
# the covariant type information.
if type(parsed_value) == typing.get_origin(self._python_type) or type(parsed_value) == self._python_type:
# Indexing the return value of get_args will raise an error for native dict and list types.
# We don't support native list/dict types with nested dataclasses.
if get_args(self._python_type) == ():
return parsed_value
elif isinstance(parsed_value, list) and has_nested_dataclass(get_args(self._python_type)[0]):
j = JsonParamType(get_args(self._python_type)[0])
return [j.convert(v, param, ctx) for v in parsed_value]
elif isinstance(parsed_value, dict) and has_nested_dataclass(get_args(self._python_type)[1]):
j = JsonParamType(get_args(self._python_type)[1])
return {k: j.convert(v, param, ctx) for k, v in parsed_value.items()}

return parsed_value

if is_pydantic_basemodel(self._python_type):
Expand Down
3 changes: 3 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/my_wf_input.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
},
"p": "None",
"q": "tests/flytekit/unit/cli/pyflyte/testdata",
"r": [{"i": 1, "a": ["h", "e"]}],
"s": {"x": {"i": 1, "a": ["h", "e"]}},
"t": {"i": [{"i":1,"a":["h","e"]}]},
"remote": "tests/flytekit/unit/cli/pyflyte/testdata",
"image": "tests/flytekit/unit/cli/pyflyte/testdata"
}
17 changes: 17 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/my_wf_input.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,22 @@ o:
- tests/flytekit/unit/cli/pyflyte/testdata/df.parquet
p: 'None'
q: tests/flytekit/unit/cli/pyflyte/testdata
r:
- i: 1
a:
- h
- e
s:
x:
i: 1
a:
- h
- e
t:
i:
- i: 1
a:
- h
- e
remote: tests/flytekit/unit/cli/pyflyte/testdata
image: tests/flytekit/unit/cli/pyflyte/testdata
6 changes: 6 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,12 @@ def test_pyflyte_run_cli(workflow_file):
"Any",
"--q",
DIR_NAME,
"--r",
json.dumps([{"i": 1, "a": ["h", "e"]}]),
"--s",
json.dumps({"x": {"i": 1, "a": ["h", "e"]}}),
"--t",
json.dumps({"i": [{"i":1,"a":["h","e"]}]}),
],
catch_exceptions=False,
)
Expand Down
13 changes: 11 additions & 2 deletions tests/flytekit/unit/cli/pyflyte/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class MyDataclass(DataClassJsonMixin):
i: int
a: typing.List[str]

@dataclass
class NestedDataclass(DataClassJsonMixin):
i: typing.List[MyDataclass]

class Color(enum.Enum):
RED = "RED"
Expand All @@ -61,8 +64,11 @@ def print_all(
o: typing.Dict[str, typing.List[FlyteFile]],
p: typing.Any,
q: FlyteDirectory,
r: typing.List[MyDataclass],
s: typing.Dict[str, MyDataclass],
t: NestedDataclass,
):
print(f"{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}, {j}, {k}, {l}, {m}, {n}, {o}, {p}, {q}")
print(f"{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}, {j}, {k}, {l}, {m}, {n}, {o}, {p}, {q}, {r}, {s}, {t}")


@task
Expand Down Expand Up @@ -93,14 +99,17 @@ def my_wf(
o: typing.Dict[str, typing.List[FlyteFile]],
p: typing.Any,
q: FlyteDirectory,
r: typing.List[MyDataclass],
s: typing.Dict[str, MyDataclass],
t: NestedDataclass,
remote: pd.DataFrame,
image: StructuredDataset,
m: dict = {"hello": "world"},
) -> Annotated[StructuredDataset, subset_cols]:
x = get_subset_df(df=remote) # noqa: shown for demonstration; users should use the same types between tasks
show_sd(in_sd=x)
show_sd(in_sd=image)
print_all(a=a, b=b, c=c, d=d, e=e, f=f, g=g, h=h, i=i, j=j, k=k, l=l, m=m, n=n, o=o, p=p, q=q)
print_all(a=a, b=b, c=c, d=d, e=e, f=f, g=g, h=h, i=i, j=j, k=k, l=l, m=m, n=n, o=o, p=p, q=q, r=r, s=s, t=t)
return x


Expand Down
Loading

0 comments on commit 83b90fa

Please sign in to comment.