diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index a656a2901c..1353f63118 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -493,22 +493,27 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp self._make_dataclass_serializable(python_val, python_type) - # The function looks up or creates a JSONEncoder specifically designed for the object's type. - # This encoder is then used to convert a data class into a JSON string. - try: - encoder = self._encoder[python_type] - except KeyError: - encoder = JSONEncoder(python_type) - self._encoder[python_type] = encoder + # The `to_json` function is integrated through either the `dataclasses_json` decorator or by inheriting from `DataClassJsonMixin`. + # It deserializes a JSON string into a data class, and provides additional functionality over JSONEncoder + if hasattr(python_val, "to_json"): + json_str = python_val.to_json() + else: + # The function looks up or creates a JSONEncoder specifically designed for the object's type. + # This encoder is then used to convert a data class into a JSON string. + try: + encoder = self._encoder[python_type] + except KeyError: + encoder = JSONEncoder(python_type) + self._encoder[python_type] = encoder - try: - json_str = encoder.encode(python_val) - except NotImplementedError: - # you can refer FlyteFile, FlyteDirectory and StructuredDataset to see how flyte types can be implemented. - raise NotImplementedError( - f"{python_type} should inherit from mashumaro.types.SerializableType" - f" and implement _serialize and _deserialize methods." - ) + try: + json_str = encoder.encode(python_val) + except NotImplementedError: + # you can refer FlyteFile, FlyteDirectory and StructuredDataset to see how flyte types can be implemented. + raise NotImplementedError( + f"{python_type} should inherit from mashumaro.types.SerializableType" + f" and implement _serialize and _deserialize methods." + ) return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) # type: ignore @@ -652,15 +657,20 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: json_str = _json_format.MessageToJson(lv.scalar.generic) - # The function looks up or creates a JSONDecoder specifically designed for the object's type. - # This decoder is then used to convert a JSON string into a data class. - try: - decoder = self._decoder[expected_python_type] - except KeyError: - decoder = JSONDecoder(expected_python_type) - self._decoder[expected_python_type] = decoder + # The `from_json` function is integrated through either the `dataclasses_json` decorator or by inheriting from `DataClassJsonMixin`. + # It deserializes a JSON string into a data class, and supports additional functionality over JSONDecoder + if hasattr(expected_python_type, "from_json"): + dc = expected_python_type.from_json(json_str) # type: ignore + else: + # The function looks up or creates a JSONDecoder specifically designed for the object's type. + # This decoder is then used to convert a JSON string into a data class. + try: + decoder = self._decoder[expected_python_type] + except KeyError: + decoder = JSONDecoder(expected_python_type) + self._decoder[expected_python_type] = decoder - dc = decoder.decode(json_str) + dc = decoder.decode(json_str) dc = self._fix_structured_dataset_type(expected_python_type, dc) return self._fix_dataclass_int(expected_python_type, dc) @@ -1062,7 +1072,7 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type "actual attribute that you want to use. For example, in NamedTuple('OP', x=int) then" "return v.x, instead of v, even if this has a single element" ) - if python_val is None and python_type != NoneType and expected and expected.union_type is None: + if (python_val is None and python_type != NoneType) and expected and expected.union_type is None: raise TypeTransformerFailedError(f"Python value cannot be None, expected {python_type}/{expected}") transformer = cls.get_transformer(python_type) if transformer.type_assertions_enabled: