enable JSON serialization of dataclasses produced by hydra-zen builds #5

Open
cameronraysmith opened this issue Nov 27, 2023 · 2 comments
Labels
enhancement New feature or request


cameronraysmith (Contributor) commented Nov 27, 2023

flytekit's dataclass transformer requires JSON-serializable dataclasses (see the flytekit docs). Currently, we construct JSON-serializable dataclasses for individual arguments or entire function interfaces along the lines of the following pseudocode:

from dataclasses import make_dataclass
from dataclasses_json import DataClassJsonMixin
# OR
# from mashumaro.mixins.json import DataClassJSONMixin
from hydra_zen import builds
from sklearn.linear_model import LogisticRegression

# ...
# passing DataClassJsonMixin to bases via make_dataclass
logistic_regression_fields = create_dataclass_from_callable(
    LogisticRegression, custom_types_defaults
)

LRI_DataClass = make_dataclass(
    "LRI_DataClass",
    logistic_regression_fields,
    bases=(DataClassJsonMixin,),
)
LRI_DataClass.__module__ = __name__

The pseudocode approximates the usage in the logistic regression example and depends on the following helper:

import dataclasses
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, get_type_hints

# infer_type_from_default is a helper defined elsewhere (not shown in this issue).


def create_dataclass_from_callable(
    callable_obj: Callable,
    overrides: Optional[Dict[str, Tuple[Type, Any]]] = None,
) -> List[Tuple[str, Type, Any]]:
    """
    Create the fields of a dataclass from a `Callable`, with one typed field
    per parameter of the callable. Field types come from the callable's type
    hints (or are inferred from the parameter's default), and defaults come
    from its signature. An optional dictionary mapping parameter names to a
    (type, default) tuple specifies or overrides the type and default of any
    parameter, typed or untyped, of the target callable.

    Args:
        callable_obj (Callable): The callable object to create a dataclass from.
        overrides (Optional[Dict[str, Tuple[Type, Any]]]): Dictionary to
            override inferred types and default values. Each dict value is a
            tuple (Type, default_value).

    Returns:
        Fields that can be used to construct a new dataclass type that
        represents the interface of the callable.

    Examples:
        >>> from pprint import pprint
        >>> from mashumaro.mixins.json import DataClassJSONMixin
        >>> from sklearn.linear_model import LogisticRegression
        >>> custom_types_defaults: Dict[str, Tuple[Type, Any]] = {
        ...     "penalty": (str, "l2"),
        ...     "class_weight": (Optional[dict], None),
        ...     "random_state": (Optional[int], None),
        ...     "max_iter": (int, 2000),
        ...     "n_jobs": (Optional[int], None),
        ...     "l1_ratio": (Optional[float], None),
        ... }
        >>> fields = create_dataclass_from_callable(LogisticRegression, custom_types_defaults)
        >>> LogisticRegressionInterface = dataclasses.make_dataclass(
        ...     "LogisticRegressionInterface", fields, bases=(DataClassJSONMixin,)
        ... )
        >>> lr_instance = LogisticRegressionInterface()
        >>> isinstance(lr_instance, DataClassJSONMixin)
        True
        >>> pprint(lr_instance)
        LogisticRegressionInterface(penalty='l2',
                                    dual=False,
                                    tol=0.0001,
                                    C=1.0,
                                    fit_intercept=True,
                                    intercept_scaling=1,
                                    class_weight=None,
                                    random_state=None,
                                    solver='lbfgs',
                                    max_iter=2000,
                                    multi_class='auto',
                                    verbose=0,
                                    warm_start=False,
                                    n_jobs=None,
                                    l1_ratio=None)
    """
    # For classes, inspect __init__ (its "self" parameter is skipped below).
    if inspect.isclass(callable_obj):
        func = callable_obj.__init__
    else:
        func = callable_obj
    signature = inspect.signature(func)
    type_hints = get_type_hints(func)
    fields = []
    for name, param in signature.parameters.items():
        if name == "self":
            continue
        if overrides and name in overrides:
            # Use the caller-supplied (type, default) pair as-is.
            field_type, default_value = overrides[name]
        else:
            # Prefer the annotated type; fall back to a type inferred from the default.
            inferred_type = infer_type_from_default(param.default)
            field_type = type_hints.get(name, inferred_type)
            # Parameters without a default get a default of None (via
            # default_factory), so every generated field has a default.
            default_value = (
                param.default
                if param.default is not inspect.Parameter.empty
                else dataclasses.field(default_factory=lambda: None)
            )
        fields.append((name, field_type, default_value))
    return fields

Accounting for this dependency on create_dataclass_from_callable, the approach is verbose, even though it behaves as expected.
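The overall pattern (read a callable's signature, turn it into (name, type, default) triples, and pass them to make_dataclass with a JSON mixin as a base) can be sketched with the standard library alone. This is a simplified sketch, not flytezen's code: fields_from_signature, ToyEstimator and JsonMixin are hypothetical names, JsonMixin only mimics the to_json/from_json methods of DataClassJsonMixin, and unannotated parameters fall back to the type of their default (or Any) instead of calling infer_type_from_default.

```python
import dataclasses
import inspect
import json
from typing import Any, Callable, Dict, List, Optional, Tuple, get_type_hints


class JsonMixin:
    """Hypothetical stand-in for DataClassJsonMixin (to_json/from_json only)."""

    def to_json(self) -> str:
        # Serialize every dataclass field of the instance to a JSON string.
        return json.dumps(dataclasses.asdict(self))

    @classmethod
    def from_json(cls, s: str) -> Any:
        # Rebuild an instance from a JSON string produced by to_json.
        return cls(**json.loads(s))


class ToyEstimator:
    """Hypothetical estimator standing in for LogisticRegression."""

    def __init__(self, penalty: str = "l2", C: float = 1.0,
                 max_iter=100, random_state=None):
        self.penalty, self.C = penalty, C
        self.max_iter, self.random_state = max_iter, random_state


def fields_from_signature(
    obj: Callable,
    overrides: Optional[Dict[str, Tuple[Any, Any]]] = None,
) -> List[Tuple[str, Any, Any]]:
    # For classes, read the signature of __init__ and skip "self".
    func = obj.__init__ if inspect.isclass(obj) else obj
    hints = get_type_hints(func)
    fields = []
    for name, param in inspect.signature(func).parameters.items():
        if name == "self" or param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue
        if overrides and name in overrides:
            field_type, default = overrides[name]
        else:
            has_default = param.default is not inspect.Parameter.empty
            default = param.default if has_default else None
            # Prefer the annotation; otherwise use the default's type, or Any for None.
            fallback = Any if default is None else type(default)
            field_type = hints.get(name, fallback)
        fields.append((name, field_type, default))
    return fields


fields = fields_from_signature(ToyEstimator, {"random_state": (Optional[int], 0)})
ToyInterface = dataclasses.make_dataclass("ToyInterface", fields, bases=(JsonMixin,))

cfg = ToyInterface(max_iter=500)
print(cfg)  # ToyInterface(penalty='l2', C=1.0, max_iter=500, random_state=0)
print(cfg.to_json())  # {"penalty": "l2", "C": 1.0, "max_iter": 500, "random_state": 0}
```

Because every field gets a default, the generated dataclass can be instantiated with no arguments, as in the docstring example above.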

Based on the documentation for the zen_dataclass argument of hydra_zen.builds, it seems it should be possible to use the dataclasses constructed by hydra-zen instead:

# passing DataClassJsonMixin to bases via zen_dataclass
Builds_LRI = builds(
    LogisticRegression,
    populate_full_signature=True,
    dataclass_name="Builds_LRI",
    zen_dataclass={"bases": (DataClassJsonMixin,), "module": __name__},
)

Doing so would eliminate create_dataclass_from_callable from flytezen altogether, but this produces:

TypeError: dataclass option `bases` must be a tuple of dataclass types
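The error message suggests that hydra-zen only accepts bases that are themselves dataclass types, whereas DataClassJsonMixin appears to be a plain mixin class that is not decorated with @dataclass. The sketch below is a hypothetical reconstruction of that condition, not hydra-zen's code: check_bases and PlainJsonMixin are illustrative names, and hydra-zen's actual validation may be implemented differently.

```python
import dataclasses
from typing import Any


class PlainJsonMixin:
    """Hypothetical plain (non-dataclass) mixin, like DataClassJsonMixin."""

    def describe(self) -> str:
        return type(self).__name__


def check_bases(bases: Any) -> None:
    # Hypothetical check implied by the error message: a tuple whose items
    # are all dataclass *types* (classes, not instances).
    ok = isinstance(bases, tuple) and all(
        isinstance(b, type) and dataclasses.is_dataclass(b) for b in bases
    )
    if not ok:
        raise TypeError("dataclass option `bases` must be a tuple of dataclass types")


try:
    check_bases((PlainJsonMixin,))
except TypeError as err:
    print(err)  # dataclass option `bases` must be a tuple of dataclass types
```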
rsokl commented Dec 31, 2023

Sorry for lurking but I just came across this!

I think you can make DataClassJsonMixin inheritable via builds by making a dataclass-typed subclass of it:

from dataclasses import dataclass

from dataclasses_json import DataClassJsonMixin as _DataClassJsonMixin

@dataclass
class DataClassJsonMixin(_DataClassJsonMixin):
    pass

Then you can pass DataClassJsonMixin into bases.

cameronraysmith (Contributor, Author) commented

> Sorry for lurking but I just came across this!

Not at all! Thank you for chiming in. I would have posted this in a discussion on the hydra-zen repository, but had not yet invested enough time to understand what I was missing, so I logged it here for future work.

> I think you can make inheritable via builds by making a dataclass-typed subclass of it
> [...]
> then you can pass DataClassJsonMixin into bases

That sounds like it should essentially lead to the resolution of this issue. Thanks again!
