From fe8b842bce10708ddb8a022c2d6af39091515770 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 4 May 2022 03:56:17 +0800 Subject: [PATCH 01/18] Support optional input Signed-off-by: Kevin Su --- flytekit/core/promise.py | 9 ++++++++- flytekit/core/type_engine.py | 6 ++++++ flytekit/core/workflow.py | 11 ++++++++--- tests/flytekit/unit/core/test_composition.py | 16 ++++++++++++++++ 4 files changed, 38 insertions(+), 4 deletions(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 7ecc5e205c..44113b509a 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -23,6 +23,7 @@ from flytekit.models import types as type_models from flytekit.models.core import workflow as _workflow_model from flytekit.models.literals import Primitive +from flytekit.models.types import SimpleType def translate_inputs_to_literals( @@ -871,7 +872,13 @@ def create_and_link_node( for k in sorted(interface.inputs): var = typed_interface.inputs[k] if k not in kwargs: - raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type)) + is_optional = False + for variant in var.type.union_type.variants: + if variant.simple == SimpleType.NONE: + kwargs[k] = None + is_optional = True + if not is_optional: + raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type)) v = kwargs[k] # This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte # Usually a Tuple will indicate that multiple outputs from a previous task were accidentally passed diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 275a40f451..b8ea22322f 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -31,6 +31,7 @@ from flytekit.exceptions import user as user_exceptions from flytekit.loggers import logger from flytekit.models import interface as _interface_models +from flytekit.models import literals from flytekit.models import types as _type_models from flytekit.models.annotation import TypeAnnotation as TypeAnnotationModel from flytekit.models.core import types as _core_types @@ -737,6 +738,11 @@ def literal_map_to_kwargs( """ Given a ``LiteralMap`` (usually an input into a task - intermediate), convert to kwargs for the task """ + # Assign default literal value (void) if python type is an optional type + for k, v in python_types.items(): + if k not in lm.literals and typing.get_origin(v) is typing.Union and type(None) in typing.get_args(v): + lm.literals[k] = Literal(scalar=literals.Scalar(none_type=literals.Void())) + if len(lm.literals) != len(python_types): raise ValueError( f"Received more input values {len(lm.literals)}" f" than allowed by the input spec {len(python_types)}" diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 0f36f86ec3..c7edddcf18 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -39,6 +39,7 @@ from flytekit.models import interface as _interface_models from flytekit.models import literals as _literal_models from flytekit.models.core import workflow as _workflow_model +from flytekit.models.types import SimpleType GLOBAL_START_NODE = Node( id=_common_constants.GLOBAL_INPUT_NODE_ID, @@ -243,10 +244,14 @@ def execute(self, **kwargs): def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]: # This is done to support the invariant that Workflow local executions always work with Promise objects # holding Flyte literal values. Even in a wf, a user can call a sub-workflow with a Python native value. - for k, v in kwargs.items(): - if not isinstance(v, Promise): + for k, v in self.interface.inputs.items(): + if k not in kwargs: + for variant in v.type.union_type.variants: + if variant.simple == SimpleType.NONE: + kwargs[k] = None + if not isinstance(kwargs[k], Promise): t = self.python_interface.inputs[k] - kwargs[k] = Promise(var=k, val=TypeEngine.to_literal(ctx, v, t, self.interface.inputs[k].type)) + kwargs[k] = Promise(var=k, val=TypeEngine.to_literal(ctx, kwargs[k], t, v.type)) # The output of this will always be a combination of Python native values and Promises containing Flyte # Literals. diff --git a/tests/flytekit/unit/core/test_composition.py b/tests/flytekit/unit/core/test_composition.py index 37f5d10195..93d58d468c 100644 --- a/tests/flytekit/unit/core/test_composition.py +++ b/tests/flytekit/unit/core/test_composition.py @@ -169,3 +169,19 @@ def my_wf3(a: int = 42) -> (int, str, str, str): return x, y, u, v assert my_wf2() == (44, "world-44", "world-5", "world-7") + + +def test_optional_input(): + @task() + def t1(b: typing.Optional[int]) -> str: + return str(b) + + @task() + def t2(c: str) -> str: + return c + + @workflow + def wf(a: typing.Optional[int]) -> str: + return t2(c=t1()) + + assert wf() == str(None) From 152c10b1ff04b1082d2ed80fb5c6ac3f08e6e80b Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 5 May 2022 06:25:19 +0800 Subject: [PATCH 02/18] Fixed tests Signed-off-by: Kevin Su --- flytekit/clients/friendly.py | 3 ++- flytekit/clis/flyte_cli/main.py | 42 ++++++++++++++++++++++----------- flytekit/core/promise.py | 9 +++---- flytekit/core/workflow.py | 2 +- 4 files changed, 36 insertions(+), 20 deletions(-) diff --git a/flytekit/clients/friendly.py b/flytekit/clients/friendly.py index 8db7c93a98..9bba67439a 100644 --- a/flytekit/clients/friendly.py +++ b/flytekit/clients/friendly.py @@ -911,7 +911,7 @@ def update_project_domain_attributes(self, project, domain, matching_attributes) ) ) - def update_workflow_attributes(self, project, domain, workflow, matching_attributes): + def update_workflow_attributes(self, project, domain, workflow, launch_plan, matching_attributes): """ Sets custom attributes for a project, domain, and workflow combination. :param Text project: @@ -926,6 +926,7 @@ def update_workflow_attributes(self, project, domain, workflow, matching_attribu project=project, domain=domain, workflow=workflow, + launch_plan=launch_plan, matching_attributes=matching_attributes.to_flyte_idl(), ) ) diff --git a/flytekit/clis/flyte_cli/main.py b/flytekit/clis/flyte_cli/main.py index 3628af7beb..be6ccb1761 100644 --- a/flytekit/clis/flyte_cli/main.py +++ b/flytekit/clis/flyte_cli/main.py @@ -283,6 +283,7 @@ def _get_client(host: str, insecure: bool) -> _friendly_client.SynchronousFlyteC _PROJECT_FLAGS = ["-p", "--project"] _DOMAIN_FLAGS = ["-d", "--domain"] +_LAUNCH_PLAN_FLAGS = ["-d", "--launch-plan"] _NAME_FLAGS = ["-n", "--name"] _VERSION_FLAGS = ["-v", "--version"] _HOST_FLAGS = ["-h", "--host"] @@ -313,6 +314,13 @@ def _get_client(host: str, insecure: bool) -> _friendly_client.SynchronousFlyteC default=None, help="[Optional] The name to query.", ) +_optional_launch_plan_option = _click.option( + *_LAUNCH_PLAN_FLAGS, + required=False, + type=str, + default=None, + help="[Optional] The launch plan name to query.", +) _principal_option = _click.option(*_PRINCIPAL_FLAGS, required=True, help="Your team name, or your name") _optional_principal_option = _click.option( *_PRINCIPAL_FLAGS, @@ -1936,9 +1944,10 @@ def update_launch_plan_meta(description, host, insecure, project, domain, name): @_insecure_option @_project_option @_domain_option +@_optional_launch_plan_option @_optional_name_option @_click.option("--attributes", type=(str, str), multiple=True) -def update_cluster_resource_attributes(host, insecure, project, domain, name, attributes): +def update_cluster_resource_attributes(host, insecure, project, domain, launch_plan, name, attributes): """ Sets matchable cluster resource attributes for a project, domain and optionally, workflow name. The attribute names should match the templatized values you use to configure these resource @@ -1956,7 +1965,7 @@ def update_cluster_resource_attributes(host, insecure, project, domain, name, at matching_attributes = _MatchingAttributes(cluster_resource_attributes=cluster_resource_attributes) if name is not None: - client.update_workflow_attributes(project, domain, name, matching_attributes) + client.update_workflow_attributes(project, domain, name, launch_plan, matching_attributes) _click.echo( "Successfully updated cluster resource attributes for project: {}, domain: {}, and workflow: {}".format( project, domain, name @@ -1974,9 +1983,10 @@ def update_cluster_resource_attributes(host, insecure, project, domain, name, at @_insecure_option @_project_option @_domain_option +@_optional_launch_plan_option @_optional_name_option @_click.option("--tags", multiple=True, help="Tag(s) to be applied.") -def update_execution_queue_attributes(host, insecure, project, domain, name, tags): +def update_execution_queue_attributes(host, insecure, project, domain, launch_plan, name, tags): """ Tags used for assigning execution queues for tasks belonging to a project, domain and optionally, workflow name. @@ -1990,10 +2000,10 @@ def update_execution_queue_attributes(host, insecure, project, domain, name, tag matching_attributes = _MatchingAttributes(execution_queue_attributes=execution_queue_attributes) if name is not None: - client.update_workflow_attributes(project, domain, name, matching_attributes) + client.update_workflow_attributes(project, domain, name, launch_plan, matching_attributes) _click.echo( - "Successfully updated execution queue attributes for project: {}, domain: {}, and workflow: {}".format( - project, domain, name + "Successfully updated execution queue attributes for project: {}, domain: {}, workflow: {}, and launch plan: {}".format( + project, domain, name, launch_plan ) ) else: @@ -2008,9 +2018,10 @@ def update_execution_queue_attributes(host, insecure, project, domain, name, tag @_insecure_option @_project_option @_domain_option +@_optional_launch_plan_option @_optional_name_option @_click.option("--value", help="Cluster label for which to schedule matching executions") -def update_execution_cluster_label(host, insecure, project, domain, name, value): +def update_execution_cluster_label(host, insecure, project, domain, launch_plan, name, value): """ Label value to determine where an execution's task will be run for tasks belonging to a project, domain and optionally, workflow name. @@ -2024,10 +2035,10 @@ def update_execution_cluster_label(host, insecure, project, domain, name, value) matching_attributes = _MatchingAttributes(execution_cluster_label=execution_cluster_label) if name is not None: - client.update_workflow_attributes(project, domain, name, matching_attributes) + client.update_workflow_attributes(project, domain, name, launch_plan, matching_attributes) _click.echo( - "Successfully updated execution cluster label for project: {}, domain: {}, and workflow: {}".format( - project, domain, name + "Successfully updated execution cluster label for project: {}, domain: {}, workflow: {}, and launch plan: {}".format( + project, domain, name, launch_plan ) ) else: @@ -2042,13 +2053,16 @@ def update_execution_cluster_label(host, insecure, project, domain, name, value) @_insecure_option @_project_option @_domain_option +@_optional_launch_plan_option @_optional_name_option @_click.option("--task-type", help="Task type for which to apply plugin implementation overrides") @_click.option("--plugin-id", multiple=True, help="Plugin id(s) to be used in place of the default for the task type.") @_click.option( "--missing-plugin-behavior", help="Behavior when no specified plugin_id has an associated handler.", default="FAIL" ) -def update_plugin_override(host, insecure, project, domain, name, task_type, plugin_id, missing_plugin_behavior): +def update_plugin_override( + host, insecure, project, domain, launch_plan, name, task_type, plugin_id, missing_plugin_behavior +): """ Plugin ids designating non-default plugin handlers to be used for tasks of a certain type. @@ -2064,10 +2078,10 @@ def update_plugin_override(host, insecure, project, domain, name, task_type, plu matching_attributes = _MatchingAttributes(plugin_overrides=_PluginOverrides(overrides=[plugin_override])) if name is not None: - client.update_workflow_attributes(project, domain, name, matching_attributes) + client.update_workflow_attributes(project, domain, name, launch_plan, matching_attributes) _click.echo( - "Successfully updated plugin override for project: {}, domain: {}, and workflow: {}".format( - project, domain, name + "Successfully updated plugin override for project: {}, domain: {}, workflow: {}, and launch plan: {}".format( + project, domain, name, launch_plan ) ) else: diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 44113b509a..c04067a5c2 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -873,10 +873,11 @@ def create_and_link_node( var = typed_interface.inputs[k] if k not in kwargs: is_optional = False - for variant in var.type.union_type.variants: - if variant.simple == SimpleType.NONE: - kwargs[k] = None - is_optional = True + if var.type.union_type: + for variant in var.type.union_type.variants: + if variant.simple == SimpleType.NONE: + kwargs[k] = None + is_optional = True if not is_optional: raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type)) v = kwargs[k] diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index c7edddcf18..305855163a 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -245,7 +245,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr # This is done to support the invariant that Workflow local executions always work with Promise objects # holding Flyte literal values. Even in a wf, a user can call a sub-workflow with a Python native value. for k, v in self.interface.inputs.items(): - if k not in kwargs: + if k not in kwargs and v.type.union_type: for variant in v.type.union_type.variants: if variant.simple == SimpleType.NONE: kwargs[k] = None From bd93a8138de35aa260e44516529a8c8906532895 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 5 May 2022 06:32:46 +0800 Subject: [PATCH 03/18] Fixed tests Signed-off-by: Kevin Su --- flytekit/core/type_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index b8ea22322f..7d9b6949c3 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -740,7 +740,7 @@ def literal_map_to_kwargs( """ # Assign default literal value (void) if python type is an optional type for k, v in python_types.items(): - if k not in lm.literals and typing.get_origin(v) is typing.Union and type(None) in typing.get_args(v): + if k not in lm.literals and get_origin(v) is typing.Union and type(None) in typing.get_args(v): lm.literals[k] = Literal(scalar=literals.Scalar(none_type=literals.Void())) if len(lm.literals) != len(python_types): From 06dfeeed69f653fdd223b4be060d207471d34234 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 5 May 2022 06:40:26 +0800 Subject: [PATCH 04/18] Fixed tests Signed-off-by: Kevin Su --- flytekit/clients/friendly.py | 3 +-- flytekit/clis/flyte_cli/main.py | 44 +++++++++++---------------------- 2 files changed, 16 insertions(+), 31 deletions(-) diff --git a/flytekit/clients/friendly.py b/flytekit/clients/friendly.py index 9bba67439a..8db7c93a98 100644 --- a/flytekit/clients/friendly.py +++ b/flytekit/clients/friendly.py @@ -911,7 +911,7 @@ def update_project_domain_attributes(self, project, domain, matching_attributes) ) ) - def update_workflow_attributes(self, project, domain, workflow, launch_plan, matching_attributes): + def update_workflow_attributes(self, project, domain, workflow, matching_attributes): """ Sets custom attributes for a project, domain, and workflow combination. :param Text project: @@ -926,7 +926,6 @@ def update_workflow_attributes(self, project, domain, workflow, launch_plan, mat project=project, domain=domain, workflow=workflow, - launch_plan=launch_plan, matching_attributes=matching_attributes.to_flyte_idl(), ) ) diff --git a/flytekit/clis/flyte_cli/main.py b/flytekit/clis/flyte_cli/main.py index be6ccb1761..d038ddabd4 100644 --- a/flytekit/clis/flyte_cli/main.py +++ b/flytekit/clis/flyte_cli/main.py @@ -283,7 +283,6 @@ def _get_client(host: str, insecure: bool) -> _friendly_client.SynchronousFlyteC _PROJECT_FLAGS = ["-p", "--project"] _DOMAIN_FLAGS = ["-d", "--domain"] -_LAUNCH_PLAN_FLAGS = ["-d", "--launch-plan"] _NAME_FLAGS = ["-n", "--name"] _VERSION_FLAGS = ["-v", "--version"] _HOST_FLAGS = ["-h", "--host"] @@ -314,13 +313,6 @@ def _get_client(host: str, insecure: bool) -> _friendly_client.SynchronousFlyteC default=None, help="[Optional] The name to query.", ) -_optional_launch_plan_option = _click.option( - *_LAUNCH_PLAN_FLAGS, - required=False, - type=str, - default=None, - help="[Optional] The launch plan name to query.", -) _principal_option = _click.option(*_PRINCIPAL_FLAGS, required=True, help="Your team name, or your name") _optional_principal_option = _click.option( *_PRINCIPAL_FLAGS, @@ -1944,10 +1936,9 @@ def update_launch_plan_meta(description, host, insecure, project, domain, name): @_insecure_option @_project_option @_domain_option -@_optional_launch_plan_option @_optional_name_option @_click.option("--attributes", type=(str, str), multiple=True) -def update_cluster_resource_attributes(host, insecure, project, domain, launch_plan, name, attributes): +def update_cluster_resource_attributes(host, insecure, project, domain, name, attributes): """ Sets matchable cluster resource attributes for a project, domain and optionally, workflow name. The attribute names should match the templatized values you use to configure these resource @@ -1965,7 +1956,7 @@ def update_cluster_resource_attributes(host, insecure, project, domain, launch_p matching_attributes = _MatchingAttributes(cluster_resource_attributes=cluster_resource_attributes) if name is not None: - client.update_workflow_attributes(project, domain, name, launch_plan, matching_attributes) + client.update_workflow_attributes(project, domain, name, matching_attributes) _click.echo( "Successfully updated cluster resource attributes for project: {}, domain: {}, and workflow: {}".format( project, domain, name @@ -1983,10 +1974,9 @@ def update_cluster_resource_attributes(host, insecure, project, domain, launch_p @_insecure_option @_project_option @_domain_option -@_optional_launch_plan_option @_optional_name_option @_click.option("--tags", multiple=True, help="Tag(s) to be applied.") -def update_execution_queue_attributes(host, insecure, project, domain, launch_plan, name, tags): +def update_execution_queue_attributes(host, insecure, project, domain, name, tags): """ Tags used for assigning execution queues for tasks belonging to a project, domain and optionally, workflow name. @@ -2000,10 +1990,10 @@ def update_execution_queue_attributes(host, insecure, project, domain, launch_pl matching_attributes = _MatchingAttributes(execution_queue_attributes=execution_queue_attributes) if name is not None: - client.update_workflow_attributes(project, domain, name, launch_plan, matching_attributes) + client.update_workflow_attributes(project, domain, name, matching_attributes) _click.echo( - "Successfully updated execution queue attributes for project: {}, domain: {}, workflow: {}, and launch plan: {}".format( - project, domain, name, launch_plan + "Successfully updated execution queue attributes for project: {}, domain: {}, and workflow: {}".format( + project, domain, name ) ) else: @@ -2018,10 +2008,9 @@ def update_execution_queue_attributes(host, insecure, project, domain, launch_pl @_insecure_option @_project_option @_domain_option -@_optional_launch_plan_option @_optional_name_option @_click.option("--value", help="Cluster label for which to schedule matching executions") -def update_execution_cluster_label(host, insecure, project, domain, launch_plan, name, value): +def update_execution_cluster_label(host, insecure, project, domain, name, value): """ Label value to determine where an execution's task will be run for tasks belonging to a project, domain and optionally, workflow name. @@ -2035,10 +2024,10 @@ def update_execution_cluster_label(host, insecure, project, domain, launch_plan, matching_attributes = _MatchingAttributes(execution_cluster_label=execution_cluster_label) if name is not None: - client.update_workflow_attributes(project, domain, name, launch_plan, matching_attributes) + client.update_workflow_attributes(project, domain, name, matching_attributes) _click.echo( - "Successfully updated execution cluster label for project: {}, domain: {}, workflow: {}, and launch plan: {}".format( - project, domain, name, launch_plan + "Successfully updated execution cluster label for project: {}, domain: {}, and workflow: {}".format( + project, domain, name ) ) else: @@ -2053,16 +2042,13 @@ def update_execution_cluster_label(host, insecure, project, domain, launch_plan, @_insecure_option @_project_option @_domain_option -@_optional_launch_plan_option @_optional_name_option @_click.option("--task-type", help="Task type for which to apply plugin implementation overrides") @_click.option("--plugin-id", multiple=True, help="Plugin id(s) to be used in place of the default for the task type.") @_click.option( "--missing-plugin-behavior", help="Behavior when no specified plugin_id has an associated handler.", default="FAIL" ) -def update_plugin_override( - host, insecure, project, domain, launch_plan, name, task_type, plugin_id, missing_plugin_behavior -): +def update_plugin_override(host, insecure, project, domain, name, task_type, plugin_id, missing_plugin_behavior): """ Plugin ids designating non-default plugin handlers to be used for tasks of a certain type. @@ -2078,10 +2064,10 @@ def update_plugin_override( matching_attributes = _MatchingAttributes(plugin_overrides=_PluginOverrides(overrides=[plugin_override])) if name is not None: - client.update_workflow_attributes(project, domain, name, launch_plan, matching_attributes) + client.update_workflow_attributes(project, domain, name, matching_attributes) _click.echo( - "Successfully updated plugin override for project: {}, domain: {}, workflow: {}, and launch plan: {}".format( - project, domain, name, launch_plan + "Successfully updated plugin override for project: {}, domain: {}, and workflow: {}".format( + project, domain, name ) ) else: @@ -2221,4 +2207,4 @@ def setup_config(host, insecure): if __name__ == "__main__": - _flyte_cli() + _flyte_cli() \ No newline at end of file From 2375cebc41c86cabf7fb9be76f8d835f97acf3db Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 5 May 2022 06:41:17 +0800 Subject: [PATCH 05/18] Fixed tests Signed-off-by: Kevin Su --- flytekit/clis/flyte_cli/main.py | 2 +- flytekit/core/type_engine.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flytekit/clis/flyte_cli/main.py b/flytekit/clis/flyte_cli/main.py index d038ddabd4..3628af7beb 100644 --- a/flytekit/clis/flyte_cli/main.py +++ b/flytekit/clis/flyte_cli/main.py @@ -2207,4 +2207,4 @@ def setup_config(host, insecure): if __name__ == "__main__": - _flyte_cli() \ No newline at end of file + _flyte_cli() diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 7d9b6949c3..385b54c860 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -740,7 +740,7 @@ def literal_map_to_kwargs( """ # Assign default literal value (void) if python type is an optional type for k, v in python_types.items(): - if k not in lm.literals and get_origin(v) is typing.Union and type(None) in typing.get_args(v): + if k not in lm.literals and get_origin(v) is typing.Union and type(None) in get_args(v): lm.literals[k] = Literal(scalar=literals.Scalar(none_type=literals.Void())) if len(lm.literals) != len(python_types): From 6bc171966de36158322f78b83412ce2f1be14b9f Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 19 May 2022 15:43:20 +0800 Subject: [PATCH 06/18] set default value for optional inputs when compiling Signed-off-by: Kevin Su --- flytekit/core/interface.py | 19 ++++++++++++++----- tests/flytekit/unit/core/test_interface.py | 11 +++++++++++ 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index fed0793e4a..bc711cf741 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -7,12 +7,15 @@ from collections import OrderedDict from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union +from typing_extensions import get_args, get_origin + from flytekit.core import context_manager from flytekit.core.docstring import Docstring from flytekit.core.type_engine import TypeEngine from flytekit.exceptions.user import FlyteValidationException from flytekit.loggers import logger from flytekit.models import interface as _interface_models +from flytekit.models.literals import Void from flytekit.types.pickle import FlytePickle T = typing.TypeVar("T") @@ -182,11 +185,17 @@ def transform_inputs_to_parameters( inputs_with_def = interface.inputs_with_defaults for k, v in inputs_vars.items(): val, _default = inputs_with_def[k] - required = _default is None - default_lv = None - if _default is not None: - default_lv = TypeEngine.to_literal(ctx, _default, python_type=interface.inputs[k], expected=v.type) - params[k] = _interface_models.Parameter(var=v, default=default_lv, required=required) + if _default is None and get_origin(val) is typing.Union and type(None) in get_args(val): + from flytekit import Literal, Scalar + + literal = Literal(scalar=Scalar(none_type=Void())) + params[k] = _interface_models.Parameter(var=v, default=literal, required=False) + else: + required = _default is None + default_lv = None + if _default is not None: + default_lv = TypeEngine.to_literal(ctx, _default, python_type=interface.inputs[k], expected=v.type) + params[k] = _interface_models.Parameter(var=v, default=default_lv, required=required) return _interface_models.ParameterMap(params) diff --git a/tests/flytekit/unit/core/test_interface.py b/tests/flytekit/unit/core/test_interface.py index 2317bab1c2..9e12147ecf 100644 --- a/tests/flytekit/unit/core/test_interface.py +++ b/tests/flytekit/unit/core/test_interface.py @@ -12,6 +12,7 @@ transform_variable_map, ) from flytekit.models.core import types as _core_types +from flytekit.models.literals import Void from flytekit.types.file import FlyteFile from flytekit.types.pickle import FlytePickle @@ -199,6 +200,16 @@ def z(a: Annotated[int, "some annotation"]) -> Annotated[int, "some annotation"] assert our_interface.inputs == {"a": Annotated[int, "some annotation"]} assert our_interface.outputs == {"o0": Annotated[int, "some annotation"]} + def z(a: typing.Optional[int], b: typing.Optional[str] = "eleven") -> typing.Tuple[int, str]: + ... + + our_interface = transform_function_to_interface(z) + params = transform_inputs_to_parameters(ctx, our_interface) + assert not params.parameters["a"].required + assert params.parameters["a"].default.scalar.none_type == Void() + assert not params.parameters["b"].required + assert params.parameters["b"].default.scalar.union.value.scalar.primitive.string_value == "eleven" + def test_parameters_with_docstring(): ctx = context_manager.FlyteContext.current_context() From 0453fccb308747d8082382146edb62b105fbfa36 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 2 Jun 2022 11:10:22 +0800 Subject: [PATCH 07/18] wip Signed-off-by: Kevin Su --- flytekit/core/base_task.py | 1 + flytekit/core/workflow.py | 11 +++-------- tests/flytekit/unit/core/test_interface.py | 4 ++-- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index cc3319e359..1f51512cba 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -281,6 +281,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr return create_task_output(vals, self.python_interface) def __call__(self, *args, **kwargs): + print(kwargs) return flyte_entity_call_handler(self, *args, **kwargs) def compile(self, ctx: FlyteContext, *args, **kwargs): diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 305855163a..e3ae8e7d64 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -244,15 +244,10 @@ def execute(self, **kwargs): def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]: # This is done to support the invariant that Workflow local executions always work with Promise objects # holding Flyte literal values. Even in a wf, a user can call a sub-workflow with a Python native value. - for k, v in self.interface.inputs.items(): - if k not in kwargs and v.type.union_type: - for variant in v.type.union_type.variants: - if variant.simple == SimpleType.NONE: - kwargs[k] = None - if not isinstance(kwargs[k], Promise): + for k, v in kwargs.items(): + if not isinstance(v, Promise): t = self.python_interface.inputs[k] - kwargs[k] = Promise(var=k, val=TypeEngine.to_literal(ctx, kwargs[k], t, v.type)) - + kwargs[k] = Promise(var=k, val=TypeEngine.to_literal(ctx, v, t, self.interface.inputs[k].type)) # The output of this will always be a combination of Python native values and Promises containing Flyte # Literals. function_outputs = self.execute(**kwargs) diff --git a/tests/flytekit/unit/core/test_interface.py b/tests/flytekit/unit/core/test_interface.py index 9e12147ecf..eb96fa05e9 100644 --- a/tests/flytekit/unit/core/test_interface.py +++ b/tests/flytekit/unit/core/test_interface.py @@ -200,7 +200,7 @@ def z(a: Annotated[int, "some annotation"]) -> Annotated[int, "some annotation"] assert our_interface.inputs == {"a": Annotated[int, "some annotation"]} assert our_interface.outputs == {"o0": Annotated[int, "some annotation"]} - def z(a: typing.Optional[int], b: typing.Optional[str] = "eleven") -> typing.Tuple[int, str]: + def z(a: typing.Optional[int] = None, b: typing.Optional[str] = None) -> typing.Tuple[int, str]: ... our_interface = transform_function_to_interface(z) @@ -208,7 +208,7 @@ def z(a: typing.Optional[int], b: typing.Optional[str] = "eleven") -> typing.Tup assert not params.parameters["a"].required assert params.parameters["a"].default.scalar.none_type == Void() assert not params.parameters["b"].required - assert params.parameters["b"].default.scalar.union.value.scalar.primitive.string_value == "eleven" + assert params.parameters["b"].default.scalar.none_type == Void() def test_parameters_with_docstring(): From f1934a6144c9e928760b3c597c6acca7f0f22f61 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 8 Jun 2022 20:21:33 +0800 Subject: [PATCH 08/18] wip Signed-off-by: Kevin Su --- flytekit/core/base_task.py | 2 ++ flytekit/core/interface.py | 1 - flytekit/core/promise.py | 3 +++ flytekit/core/type_engine.py | 9 +++++---- tests/flytekit/unit/core/test_composition.py | 4 ++-- 5 files changed, 12 insertions(+), 7 deletions(-) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 1f51512cba..c8636e6d83 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -236,7 +236,9 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr flyte_interface_types=self.interface.inputs, # type: ignore native_types=self.get_input_types(), ) + print("kwargs", kwargs) input_literal_map = _literal_models.LiteralMap(literals=kwargs) + print("input_literal_map", input_literal_map) # if metadata.cache is set, check memoized version if self.metadata.cache: diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index bc711cf741..11ef175464 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -187,7 +187,6 @@ def transform_inputs_to_parameters( val, _default = inputs_with_def[k] if _default is None and get_origin(val) is typing.Union and type(None) in get_args(val): from flytekit import Literal, Scalar - literal = Literal(scalar=Scalar(none_type=Void())) params[k] = _interface_models.Parameter(var=v, default=literal, required=False) else: diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index f75bdecfd6..580fbf6dcb 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -886,6 +886,9 @@ def create_and_link_node( if var.type.union_type: for variant in var.type.union_type.variants: if variant.simple == SimpleType.NONE: + val, _default = interface.inputs_with_defaults[k] + if _default is not None: + raise ValueError(f"The default value for the optional type must be None, but got {_default}") kwargs[k] = None is_optional = True if not is_optional: diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index b8eac91424..2f9715631a 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -739,15 +739,16 @@ def literal_map_to_kwargs( Given a ``LiteralMap`` (usually an input into a task - intermediate), convert to kwargs for the task """ # Assign default literal value (void) if python type is an optional type + expected_input_len = len(python_types) for k, v in python_types.items(): if k not in lm.literals and get_origin(v) is typing.Union and type(None) in get_args(v): - lm.literals[k] = Literal(scalar=literals.Scalar(none_type=literals.Void())) + expected_input_len -= 1 - if len(lm.literals) != len(python_types): + if len(lm.literals) != expected_input_len: raise ValueError( - f"Received more input values {len(lm.literals)}" f" than allowed by the input spec {len(python_types)}" + f"Received more input values {len(lm.literals)}" f" than allowed by the input spec {expected_input_len}" ) - return {k: TypeEngine.to_python_value(ctx, lm.literals[k], v) for k, v in python_types.items()} + return {k: TypeEngine.to_python_value(ctx, lm.literals[k], python_types[k]) for k, v in lm.literals.items()} @classmethod def dict_to_literal_map( diff --git a/tests/flytekit/unit/core/test_composition.py b/tests/flytekit/unit/core/test_composition.py index 93d58d468c..0ef99e7a25 100644 --- a/tests/flytekit/unit/core/test_composition.py +++ b/tests/flytekit/unit/core/test_composition.py @@ -173,7 +173,7 @@ def my_wf3(a: int = 42) -> (int, str, str, str): def test_optional_input(): @task() - def t1(b: typing.Optional[int]) -> str: + def t1(b: typing.Optional[int] = None) -> str: return str(b) @task() @@ -181,7 +181,7 @@ def t2(c: str) -> str: return c @workflow - def wf(a: typing.Optional[int]) -> str: + def wf(a: typing.Optional[int] = 1) -> str: return t2(c=t1()) assert wf() == str(None) From c7f245428f5643c907d7e819830f0f03a6231e4a Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 8 Jun 2022 20:31:25 +0800 Subject: [PATCH 09/18] wip Signed-off-by: Kevin Su --- flytekit/core/base_task.py | 3 --- flytekit/core/interface.py | 1 + flytekit/core/promise.py | 8 +++----- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 71747d28ea..6aef36305e 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -236,9 +236,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr flyte_interface_types=self.interface.inputs, # type: ignore native_types=self.get_input_types(), ) - print("kwargs", kwargs) input_literal_map = _literal_models.LiteralMap(literals=kwargs) - print("input_literal_map", input_literal_map) # if metadata.cache is set, check memoized version if self.metadata.cache: @@ -283,7 +281,6 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr return create_task_output(vals, self.python_interface) def __call__(self, *args, **kwargs): - print(kwargs) return flyte_entity_call_handler(self, *args, **kwargs) def compile(self, ctx: FlyteContext, *args, **kwargs): diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index b2d35d30e4..a3b1b80dd8 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -187,6 +187,7 @@ def transform_inputs_to_parameters( val, _default = inputs_with_def[k] if _default is None and get_origin(val) is typing.Union and type(None) in get_args(val): from flytekit import Literal, Scalar + literal = Literal(scalar=Scalar(none_type=Void())) params[k] = _interface_models.Parameter(var=v, default=literal, required=False) else: diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 7dba61a4ec..91928b5431 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -886,12 +886,10 @@ def create_and_link_node( if var.type.union_type: for variant in var.type.union_type.variants: if variant.simple == SimpleType.NONE: - val, _default = interface.inputs_with_defaults[k] - if _default is not None: - raise ValueError(f"The default value for the optional type must be None, but got {_default}") - kwargs[k] = None is_optional = True - if not is_optional: + if is_optional: + continue + else: raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type)) v = kwargs[k] # This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte From d1938f54ff2a6a50f84ae468ab2280cb8c58657d Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 8 Jun 2022 20:50:50 +0800 Subject: [PATCH 10/18] wip Signed-off-by: Kevin Su --- flytekit/core/promise.py | 15 +++++++++++---- flytekit/core/type_engine.py | 12 +++--------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 91928b5431..f1efa61693 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -5,7 +5,7 @@ from enum import Enum from typing import Any, Dict, List, Optional, Tuple, Union, cast -from typing_extensions import Protocol +from typing_extensions import Protocol, get_args, get_origin from flytekit.core import constants as _common_constants from flytekit.core import context_manager as _flyte_context @@ -122,6 +122,9 @@ def extract_value( t = native_types[k] result[k] = extract_value(ctx, v, t, var.type) + for k, v in native_types.items(): + if k not in result and get_origin(v) is typing.Union and type(None) in get_args(v): + result[k] = extract_value(ctx, None, native_types[k], flyte_interface_types[k].type) return result @@ -886,10 +889,14 @@ def create_and_link_node( if var.type.union_type: for variant in var.type.union_type.variants: if variant.simple == SimpleType.NONE: + val, _default = interface.inputs_with_defaults[k] + if _default is not None: + raise ValueError( + f"The default value for the optional type must be None, but got {_default}" + ) + kwargs[k] = None is_optional = True - if is_optional: - continue - else: + if not is_optional: raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type)) v = kwargs[k] # This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 2f9715631a..a5925dbe62 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -738,17 +738,11 @@ def literal_map_to_kwargs( """ Given a ``LiteralMap`` (usually an input into a task - intermediate), convert to kwargs for the task """ - # Assign default literal value (void) if python type is an optional type - expected_input_len = len(python_types) - for k, v in python_types.items(): - if k not in lm.literals and get_origin(v) is typing.Union and type(None) in get_args(v): - expected_input_len -= 1 - - if len(lm.literals) != expected_input_len: + if len(lm.literals) != len(python_types): raise ValueError( - f"Received more input values {len(lm.literals)}" f" than allowed by the input spec {expected_input_len}" + f"Received more input values {len(lm.literals)}" f" than allowed by the input spec {len(python_types)}" ) - return {k: TypeEngine.to_python_value(ctx, lm.literals[k], python_types[k]) for k, v in lm.literals.items()} + return {k: TypeEngine.to_python_value(ctx, lm.literals[k], v) for k, v in python_types.items()} @classmethod def dict_to_literal_map( From bce2224cac49f755657803f69e5ebbfed14f6b76 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 8 Jun 2022 21:09:09 +0800 Subject: [PATCH 11/18] wip Signed-off-by: Kevin Su --- flytekit/core/interface.py | 5 +---- flytekit/core/workflow.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index a3b1b80dd8..02e05d7046 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -186,10 +186,7 @@ def transform_inputs_to_parameters( for k, v in inputs_vars.items(): val, _default = inputs_with_def[k] if _default is None and get_origin(val) is typing.Union and type(None) in get_args(val): - from flytekit import Literal, Scalar - - literal = Literal(scalar=Scalar(none_type=Void())) - params[k] = _interface_models.Parameter(var=v, default=literal, required=False) + params[k] = _interface_models.Parameter(var=v, required=False) else: required = _default is None default_lv = None diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index e3ae8e7d64..0f36f86ec3 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -39,7 +39,6 @@ from flytekit.models import interface as _interface_models from flytekit.models import literals as _literal_models from flytekit.models.core import workflow as _workflow_model -from flytekit.models.types import SimpleType GLOBAL_START_NODE = Node( id=_common_constants.GLOBAL_INPUT_NODE_ID, @@ -248,6 +247,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr if not isinstance(v, Promise): t = self.python_interface.inputs[k] kwargs[k] = Promise(var=k, val=TypeEngine.to_literal(ctx, v, t, self.interface.inputs[k].type)) + # The output of this will always be a combination of Python native values and Promises containing Flyte # Literals. function_outputs = self.execute(**kwargs) From 602865dad851fd8b5a3381c96f6f1c3f9f5e81ff Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 8 Jun 2022 21:22:55 +0800 Subject: [PATCH 12/18] wip Signed-off-by: Kevin Su --- flytekit/core/interface.py | 5 ++++- flytekit/core/promise.py | 3 --- flytekit/core/type_engine.py | 5 ++--- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index 02e05d7046..a3b1b80dd8 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -186,7 +186,10 @@ def transform_inputs_to_parameters( for k, v in inputs_vars.items(): val, _default = inputs_with_def[k] if _default is None and get_origin(val) is typing.Union and type(None) in get_args(val): - params[k] = _interface_models.Parameter(var=v, required=False) + from flytekit import Literal, Scalar + + literal = Literal(scalar=Scalar(none_type=Void())) + params[k] = _interface_models.Parameter(var=v, default=literal, required=False) else: required = _default is None default_lv = None diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index f1efa61693..3014f04da9 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -122,9 +122,6 @@ def extract_value( t = native_types[k] result[k] = extract_value(ctx, v, t, var.type) - for k, v in native_types.items(): - if k not in result and get_origin(v) is typing.Union and type(None) in get_args(v): - result[k] = extract_value(ctx, None, native_types[k], flyte_interface_types[k].type) return result diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index a5925dbe62..d4ca18582f 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -31,7 +31,6 @@ from flytekit.exceptions import user as user_exceptions from flytekit.loggers import logger from flytekit.models import interface as _interface_models -from flytekit.models import literals from flytekit.models import types as _type_models from flytekit.models.annotation import TypeAnnotation as TypeAnnotationModel from flytekit.models.core import types as _core_types @@ -738,11 +737,11 @@ def literal_map_to_kwargs( """ Given a ``LiteralMap`` (usually an input into a task - intermediate), convert to kwargs for the task """ - if len(lm.literals) != len(python_types): + if len(lm.literals) > len(python_types): raise ValueError( f"Received more input values {len(lm.literals)}" f" than allowed by the input spec {len(python_types)}" ) - return {k: TypeEngine.to_python_value(ctx, lm.literals[k], v) for k, v in python_types.items()} + return {k: TypeEngine.to_python_value(ctx, lm.literals[k], python_types[k]) for k, v in lm.literals.items()} @classmethod def dict_to_literal_map( From 03d73d286305963a5d91f5d32cdcf45d4b5fdc14 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 8 Jun 2022 22:34:22 +0800 Subject: [PATCH 13/18] wip Signed-off-by: Kevin Su --- flytekit/core/promise.py | 3 ++- tests/flytekit/unit/core/test_composition.py | 16 +++++++++------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 3014f04da9..4d98c370dd 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -891,10 +891,11 @@ def create_and_link_node( raise ValueError( f"The default value for the optional type must be None, but got {_default}" ) - kwargs[k] = None is_optional = True if not is_optional: raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type)) + else: + continue v = kwargs[k] # This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte # Usually a Tuple will indicate that multiple outputs from a previous task were accidentally passed diff --git a/tests/flytekit/unit/core/test_composition.py b/tests/flytekit/unit/core/test_composition.py index 0ef99e7a25..ecf437c625 100644 --- a/tests/flytekit/unit/core/test_composition.py +++ b/tests/flytekit/unit/core/test_composition.py @@ -173,15 +173,17 @@ def my_wf3(a: int = 42) -> (int, str, str, str): def test_optional_input(): @task() - def t1(b: typing.Optional[int] = None) -> str: - return str(b) + def t1(b: typing.Optional[int] = None) -> typing.Optional[int]: + print(b) + ... @task() - def t2(c: str) -> str: - return c + def t2(c: typing.Optional[int] = None) -> typing.Optional[int]: + print(c) + ... @workflow - def wf(a: typing.Optional[int] = 1) -> str: - return t2(c=t1()) + def wf(a: typing.Optional[int] = 1) -> typing.Optional[int]: + return t2(c=t1(b=a)) - assert wf() == str(None) + assert wf() is None From 111707fca16763307d2a1ec98ac873f29c7cba9d Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 11 Jun 2022 02:55:42 +0800 Subject: [PATCH 14/18] lint Signed-off-by: Kevin Su --- flytekit/core/promise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 4d98c370dd..6b0aafdc1b 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -5,7 +5,7 @@ from enum import Enum from typing import Any, Dict, List, Optional, Tuple, Union, cast -from typing_extensions import Protocol, get_args, get_origin +from typing_extensions import Protocol from flytekit.core import constants as _common_constants from flytekit.core import context_manager as _flyte_context From df0c4a8340dd3d6ee55ee74cebe61c7566af9284 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 11 Jun 2022 03:30:07 +0800 Subject: [PATCH 15/18] nit Signed-off-by: Kevin Su --- tests/flytekit/unit/core/test_composition.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/flytekit/unit/core/test_composition.py b/tests/flytekit/unit/core/test_composition.py index ecf437c625..db1ffe99d1 100644 --- a/tests/flytekit/unit/core/test_composition.py +++ b/tests/flytekit/unit/core/test_composition.py @@ -174,16 +174,15 @@ def my_wf3(a: int = 42) -> (int, str, str, str): def test_optional_input(): @task() def t1(b: typing.Optional[int] = None) -> typing.Optional[int]: - print(b) ... @task() def t2(c: typing.Optional[int] = None) -> typing.Optional[int]: - print(c) ... @workflow def wf(a: typing.Optional[int] = 1) -> typing.Optional[int]: - return t2(c=t1(b=a)) + t1() + return t2(c=a) assert wf() is None From 3038755ec3380a09303ed758de380b0b74edae14 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 11 Jun 2022 03:35:03 +0800 Subject: [PATCH 16/18] add tests Signed-off-by: Kevin Su --- tests/flytekit/unit/core/test_composition.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/flytekit/unit/core/test_composition.py b/tests/flytekit/unit/core/test_composition.py index db1ffe99d1..edc7827215 100644 --- a/tests/flytekit/unit/core/test_composition.py +++ b/tests/flytekit/unit/core/test_composition.py @@ -1,5 +1,7 @@ import typing +import pytest + from flytekit.core import launch_plan from flytekit.core.task import task from flytekit.core.workflow import workflow @@ -186,3 +188,13 @@ def wf(a: typing.Optional[int] = 1) -> typing.Optional[int]: return t2(c=a) assert wf() is None + + @task() + def t3(c: typing.Optional[int] = 3) -> typing.Optional[int]: + ... + + with pytest.raises(ValueError, match="The default value for the optional type must be None, but got 3"): + + @workflow + def wf(): + return t3() From 33764dfadf37d315c40270221d4707552cf52a13 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 17 Jun 2022 03:07:43 +0800 Subject: [PATCH 17/18] More tests Signed-off-by: Kevin Su --- tests/flytekit/unit/core/test_composition.py | 22 ++++++++++---------- tests/flytekit/unit/core/test_interface.py | 6 +++++- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/tests/flytekit/unit/core/test_composition.py b/tests/flytekit/unit/core/test_composition.py index edc7827215..28cbcdf706 100644 --- a/tests/flytekit/unit/core/test_composition.py +++ b/tests/flytekit/unit/core/test_composition.py @@ -1,4 +1,4 @@ -import typing +from typing import Dict, List, NamedTuple, Optional, Union import pytest @@ -10,7 +10,7 @@ def test_wf1_with_subwf(): @task - def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str): + def t1(a: int) -> NamedTuple("OutputsBC", t1_int_output=int, c=str): a = a + 2 return a, "world-" + str(a) @@ -132,7 +132,7 @@ def my_wf2(a: int, b: int = 42) -> (str, str, int, int): def test_wf1_with_lp_node(): @task - def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str): + def t1(a: int) -> NamedTuple("OutputsBC", t1_int_output=int, c=str): a = a + 2 return a, "world-" + str(a) @@ -175,26 +175,26 @@ def my_wf3(a: int = 42) -> (int, str, str, str): def test_optional_input(): @task() - def t1(b: typing.Optional[int] = None) -> typing.Optional[int]: + def t1(a: Optional[int] = None, b: Optional[List[int]] = None, c: Optional[Dict[str, int]] = None) -> Optional[int]: ... @task() - def t2(c: typing.Optional[int] = None) -> typing.Optional[int]: + def t2(a: Union[int, Optional[List[int]], None] = None) -> Union[int, Optional[List[int]], None]: ... @workflow - def wf(a: typing.Optional[int] = 1) -> typing.Optional[int]: + def wf(a: Optional[int] = 1) -> Optional[int]: t1() - return t2(c=a) + return t2(a=a) assert wf() is None - @task() - def t3(c: typing.Optional[int] = 3) -> typing.Optional[int]: - ... - with pytest.raises(ValueError, match="The default value for the optional type must be None, but got 3"): + @task() + def t3(c: Optional[int] = 3) -> Optional[int]: + ... + @workflow def wf(): return t3() diff --git a/tests/flytekit/unit/core/test_interface.py b/tests/flytekit/unit/core/test_interface.py index eb96fa05e9..62c45c346e 100644 --- a/tests/flytekit/unit/core/test_interface.py +++ b/tests/flytekit/unit/core/test_interface.py @@ -200,7 +200,9 @@ def z(a: Annotated[int, "some annotation"]) -> Annotated[int, "some annotation"] assert our_interface.inputs == {"a": Annotated[int, "some annotation"]} assert our_interface.outputs == {"o0": Annotated[int, "some annotation"]} - def z(a: typing.Optional[int] = None, b: typing.Optional[str] = None) -> typing.Tuple[int, str]: + def z( + a: typing.Optional[int] = None, b: typing.Optional[str] = None, c: typing.Union[typing.List[int], None] = None + ) -> typing.Tuple[int, str]: ... our_interface = transform_function_to_interface(z) @@ -209,6 +211,8 @@ def z(a: typing.Optional[int] = None, b: typing.Optional[str] = None) -> typing. assert params.parameters["a"].default.scalar.none_type == Void() assert not params.parameters["b"].required assert params.parameters["b"].default.scalar.none_type == Void() + assert not params.parameters["c"].required + assert params.parameters["c"].default.scalar.none_type == Void() def test_parameters_with_docstring(): From 6740affb7aedcbb7d090d7e05e928ed442839890 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 17 Jun 2022 03:21:22 +0800 Subject: [PATCH 18/18] Fix tests Signed-off-by: Kevin Su --- tests/flytekit/unit/core/test_composition.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/flytekit/unit/core/test_composition.py b/tests/flytekit/unit/core/test_composition.py index 28cbcdf706..3963c77c8d 100644 --- a/tests/flytekit/unit/core/test_composition.py +++ b/tests/flytekit/unit/core/test_composition.py @@ -35,7 +35,7 @@ def my_wf(a: int, b: str) -> (int, str, str): def test_single_named_output_subwf(): - nt = typing.NamedTuple("SubWfOutput", sub_int=int) + nt = NamedTuple("SubWfOutput", sub_int=int) @task def t1(a: int) -> nt: @@ -70,7 +70,7 @@ def my_wf(a: int) -> int: def test_lp_default_handling(): @task - def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str): + def t1(a: int) -> NamedTuple("OutputsBC", t1_int_output=int, c=str): a = a + 2 return a, "world-" + str(a)