Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed optional typing on non-serializable types #2939

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 37 additions & 13 deletions src/zenml/steps/entrypoint_function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@
Sequence,
Type,
Union,
get_args,
get_origin,
)

from pydantic import ConfigDict, ValidationError, create_model
from pydantic import ValidationError, create_model

from zenml.constants import ENFORCE_TYPE_ANNOTATIONS
from zenml.exceptions import StepInterfaceError
Expand Down Expand Up @@ -185,7 +187,6 @@ def validate_input(self, key: str, value: Any) -> None:
)

parameter = self.inputs[key]

if isinstance(
value,
(
Expand Down Expand Up @@ -235,17 +236,40 @@ def _validate_input_value(
parameter: The function parameter for which the value was provided.
value: The input value.
"""
config_dict = ConfigDict(arbitrary_types_allowed=False)

# Create a pydantic model with just a single required field with the
# type annotation of the parameter to verify the input type including
# pydantics type coercion
validation_model_class = create_model(
"input_validation_model",
__config__=config_dict,
value=(parameter.annotation, ...),
)
validation_model_class(value=value)
annotation = parameter.annotation

# Handle Optional types
origin = get_origin(annotation)
if origin is Union:
AlexejPenner marked this conversation as resolved.
Show resolved Hide resolved
args = get_args(annotation)
if type(None) in args:
if value is None:
return # None is valid for Optional types
# Remove NoneType from args as this case is handled from here
args = tuple(arg for arg in args if arg is not type(None))
annotation = args[0] if len(args) == 1 else Union[args]

# Handle None values for non-Optional types
if value is None and annotation is not type(None):
AlexejPenner marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(f"Expected {annotation}, but got None")

# Use Pydantic for all types to take advantage of its coercion abilities
try:
config_dict = {"arbitrary_types_allowed": True}
validation_model_class = create_model(
"input_validation_model",
__config__=type("Config", (), config_dict),
value=(annotation, ...),
)
validation_model_class(value=value)
except ValidationError as e:
raise ValueError(f"Invalid input: {e}")
AlexejPenner marked this conversation as resolved.
Show resolved Hide resolved
except Exception:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be

try:
  ...
except ValidationError:
  raise
except Exception:
  ...

instead? In case pydantic fails with a regular validation error, we actually want to raise that because it means the types don't match right?

# If Pydantic can't handle it, fall back to isinstance
if not isinstance(value, annotation):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also this might fail with the raw annotation:

isinstance(int, Union[int, str])

# TypeError: Subscripted generics cannot be used with class and instance checks

raise TypeError(
f"Expected {annotation}, but got {type(value)}"
)


def validate_entrypoint_function(
Expand Down
135 changes: 133 additions & 2 deletions tests/unit/steps/test_base_step_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# permissions and limitations under the License.
import sys
from contextlib import ExitStack as does_not_raise
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple, Union

import pytest
from pydantic import BaseModel
Expand Down Expand Up @@ -50,7 +50,7 @@ def test_input_validation_inside_pipeline():
def test_pipeline(step_input):
return step_with_int_input(step_input)

with pytest.raises(RuntimeError):
with pytest.raises(ValueError):
test_pipeline(step_input="wrong_type")

with does_not_raise():
Expand Down Expand Up @@ -187,3 +187,134 @@ def test_pipeline():

with pytest.raises(StepInterfaceError):
test_pipeline()


# ------------------------ Optional Input types
@step
def some_step(some_optional_int: Optional[int]) -> None:
pass


def test_step_can_have_optional_input_types():
"""Tests that a step allows None values for optional input types"""

@pipeline
def p():
some_step(some_optional_int=None)

with does_not_raise():
p()


def test_step_fails_on_none_inputs_for_non_optional_input_types():
"""Tests that a step does not allow None values for non-optional input types"""

@step
def some_step(some_optional_int: int) -> None:
pass

@pipeline
def p():
some_step(some_optional_int=None)

with pytest.raises(ValueError):
p().run(unlisted=True)


# --------- Test type coercion


@step
def coerce_step(some_int: int, some_float: float) -> None:
pass


def test_step_with_type_coercion():
"""Tests that a step can coerce types when possible"""

@pipeline
def p():
coerce_step(some_int="42", some_float="3.14")

with does_not_raise():
p()


def test_step_fails_on_invalid_type_coercion():
"""Tests that a step fails when type coercion is not possible"""

@step
def coerce_step(some_int: int) -> None:
pass

@pipeline
def p():
coerce_step(some_int="not an int")

with pytest.raises(ValueError):
p().run(unlisted=True)


# ------------- Non-Json-Serializable types


class NonSerializable:
def __init__(self, value):
self.value = value


@step
def non_serializable_step(some_obj: NonSerializable) -> None:
pass


def test_step_with_non_serializable_type_as_parameter_fails():
"""Tests that a step can handle non-JSON serializable types, but fails if these are passed as parameters"""

@pipeline
def p():
non_serializable_step(some_obj=NonSerializable(42))

with pytest.raises(StepInterfaceError):
p().run(unlisted=True)


def test_step_fails_on_wrong_non_serializable_type():
"""Tests that a step fails when given the wrong non-serializable type"""

@step
def non_serializable_step(some_obj: NonSerializable) -> None:
pass

@pipeline
def p():
non_serializable_step(some_obj=None)

with pytest.raises(ValueError):
p().run(unlisted=True)


# --------- Test union types


@step
def union_step(some_union: Union[int, str]) -> None:
pass


def test_step_with_union_type():
"""Tests that a step can handle Union types"""

@pipeline
def p():
union_step(some_union=42)

with does_not_raise():
p()

@pipeline
def p():
union_step(some_union="fourtytwo")

with does_not_raise():
p()
Loading