diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index ca7a6cf20d..a9b7c313f0 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -10,7 +10,6 @@ import click as _click from flyteidl.core import literals_pb2 as _literals_pb2 -from flytekit import PythonFunctionTask from flytekit.configuration import ( SERIALIZED_CONTEXT_ENV_VAR, FastSerializationSettings, @@ -23,7 +22,7 @@ from flytekit.core.checkpointer import SyncCheckpoint from flytekit.core.context_manager import ExecutionParameters, ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.data_persistence import FileAccessProvider -from flytekit.core.map_task import MapPythonTask +from flytekit.core.map_task import MapTaskResolver from flytekit.core.promise import VoidPromise from flytekit.exceptions import scopes as _scoped_exceptions from flytekit.exceptions import scopes as _scopes @@ -391,12 +390,8 @@ def _execute_map_task( with setup_execution( raw_output_data_prefix, checkpoint_path, prev_checkpoint, dynamic_addl_distro, dynamic_dest_dir ) as ctx: - resolver_obj = load_object_from_module(resolver) - # Use the resolver to load the actual task object - _task_def = resolver_obj.load_task(loader_args=resolver_args) - if not isinstance(_task_def, PythonFunctionTask): - raise Exception("Map tasks cannot be run with instance tasks.") - map_task = MapPythonTask(_task_def, max_concurrency) + mtr = MapTaskResolver() + map_task = mtr.load_task(loader_args=resolver_args, max_concurrency=max_concurrency) task_index = _compute_array_job_index() output_prefix = os.path.join(output_prefix, str(task_index)) diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index 3c24e65db2..eae7a8e0cf 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -21,6 +21,28 @@ T = typing.TypeVar("T") +def repr_kv(k: str, v: Union[Type, Tuple[Type, Any]]) -> str: + if isinstance(v, tuple): + if v[1]: + return f"{k}: {v[0]}={v[1]}" + return f"{k}: {v[0]}" + return f"{k}: {v}" + + +def repr_type_signature(io: Union[Dict[str, Tuple[Type, Any]], Dict[str, Type]]) -> str: + """ + Converts an inputs and outputs to a type signature + """ + s = "(" + i = 0 + for k, v in io.items(): + if i > 0: + s += ", " + s += repr_kv(k, v) + i = i + 1 + return s + ")" + + class Interface(object): """ A Python native interface object, like inspect.signature but simpler. @@ -57,7 +79,9 @@ def __init__( variables = [k for k in outputs.keys()] # TODO: This class is a duplicate of the one in create_task_outputs. Over time, we should move to this one. - class Output(collections.namedtuple(output_tuple_name or "DefaultNamedTupleOutput", variables)): # type: ignore + class Output( # type: ignore + collections.namedtuple(output_tuple_name or "DefaultNamedTupleOutput", variables) # type: ignore + ): # type: ignore """ This class can be used in two different places. For multivariate-return entities this class is used to rewrap the outputs so that our with_overrides function can work. @@ -167,6 +191,12 @@ def with_outputs(self, extra_outputs: Dict[str, Type]) -> Interface: new_outputs[k] = v return Interface(self._inputs, new_outputs) + def __str__(self): + return f"{repr_type_signature(self._inputs)} -> {repr_type_signature(self._outputs)}" + + def __repr__(self): + return str(self) + def transform_inputs_to_parameters( ctx: context_manager.FlyteContext, interface: Interface @@ -220,7 +250,7 @@ def transform_interface_to_typed_interface( return _interface_models.TypedInterface(inputs_map, outputs_map) -def transform_types_to_list_of_type(m: Dict[str, type]) -> Dict[str, type]: +def transform_types_to_list_of_type(m: Dict[str, type], bound_inputs: typing.Set[str]) -> Dict[str, type]: """ Converts a given variables to be collections of their type. This is useful for array jobs / map style code. It will create a collection of types even if any one these types is not a collection type @@ -230,6 +260,10 @@ def transform_types_to_list_of_type(m: Dict[str, type]) -> Dict[str, type]: all_types_are_collection = True for k, v in m.items(): + if k in bound_inputs: + # Skip the inputs that are bound. If they are bound, it does not matter if they are collection or + # singletons + continue v_type = type(v) if v_type != typing.List and v_type != list: all_types_are_collection = False @@ -240,17 +274,22 @@ def transform_types_to_list_of_type(m: Dict[str, type]) -> Dict[str, type]: om = {} for k, v in m.items(): - om[k] = typing.List[v] # type: ignore + if k in bound_inputs: + om[k] = v + else: + om[k] = typing.List[v] # type: ignore return om # type: ignore -def transform_interface_to_list_interface(interface: Interface) -> Interface: +def transform_interface_to_list_interface(interface: Interface, bound_inputs: typing.Set[str]) -> Interface: """ Takes a single task interface and interpolates it to an array interface - to allow performing distributed python map like functions + :param interface: Interface to be upgraded toa list interface + :param bound_inputs: fixed inputs that should not upgraded to a list and will be maintained as scalars. """ - map_inputs = transform_types_to_list_of_type(interface.inputs) - map_outputs = transform_types_to_list_of_type(interface.outputs) + map_inputs = transform_types_to_list_of_type(interface.inputs, bound_inputs) + map_outputs = transform_types_to_list_of_type(interface.outputs, set()) return Interface(inputs=map_inputs, outputs=map_outputs) @@ -288,7 +327,6 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Doc For now the fancy object, maybe in the future a dumb object. """ - type_hints = get_type_hints(fn, include_extras=True) signature = inspect.signature(fn) return_annotation = type_hints.get("return", None) diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index 48d0f0b335..83b2542fe3 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -2,71 +2,87 @@ Flytekit map tasks specify how to run a single task across a list of inputs. Map tasks themselves are constructed with a reference task as well as run-time parameters that limit execution concurrency and failure tolerations. """ - +import functools +import hashlib +import logging import os import typing from contextlib import contextmanager -from itertools import count -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Set from flytekit.configuration import SerializationSettings from flytekit.core import tracker -from flytekit.core.base_task import PythonTask +from flytekit.core.base_task import PythonTask, Task, TaskResolverMixin from flytekit.core.constants import SdkTaskType from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.interface import transform_interface_to_list_interface from flytekit.core.python_function_task import PythonFunctionTask +from flytekit.core.tracker import TrackedInstance from flytekit.exceptions import scopes as exception_scopes from flytekit.models.array_job import ArrayJob from flytekit.models.interface import Variable from flytekit.models.task import Container, K8sPod, Sql +from flytekit.tools.module_loader import load_object_from_module class MapPythonTask(PythonTask): """ A MapPythonTask defines a :py:class:`flytekit.PythonTask` which specifies how to run an inner :py:class:`flytekit.PythonFunctionTask` across a range of inputs in parallel. - TODO: support lambda functions """ - # To support multiple map tasks declared around identical python function tasks, we keep a global count of - # MapPythonTask instances to uniquely differentiate map task names for each declared instance. - _ids = count(0) - def __init__( self, - python_function_task: PythonFunctionTask, + python_function_task: typing.Union[PythonFunctionTask, functools.partial], concurrency: Optional[int] = None, min_success_ratio: Optional[float] = None, + bound_inputs: Optional[Set[str]] = None, **kwargs, ): """ + Wrapper that creates a MapPythonTask + :param python_function_task: This argument is implicitly passed and represents the repeatable function :param concurrency: If specified, this limits the number of mapped tasks than can run in parallel to the given - batch size + batch size :param min_success_ratio: If specified, this determines the minimum fraction of total jobs which can complete - successfully before terminating this task and marking it successful. + successfully before terminating this task and marking it successful + :param bound_inputs: List[str] specifies a list of variable names within the interface of python_function_task, + that are already bound and should not be considered as list inputs, but scalar values. This is mostly + useful at runtime and is passed in by MapTaskResolver. This field is not required when a `partial` method + is specified. The bound_vars will be auto-deduced from the `partial.keywords`. """ - if len(python_function_task.python_interface.inputs.keys()) > 1: - raise ValueError("Map tasks only accept python function tasks with 0 or 1 inputs") + self._partial = None + if isinstance(python_function_task, functools.partial): + self._partial = python_function_task + actual_task = self._partial.func + else: + actual_task = python_function_task + + if not isinstance(actual_task, PythonFunctionTask): + raise ValueError("Map tasks can only compose of Python Functon Tasks currently") - if len(python_function_task.python_interface.outputs.keys()) > 1: + if len(actual_task.python_interface.outputs.keys()) > 1: raise ValueError("Map tasks only accept python function tasks with 0 or 1 outputs") - collection_interface = transform_interface_to_list_interface(python_function_task.python_interface) - instance = next(self._ids) - _, mod, f, _ = tracker.extract_task_module(python_function_task.task_function) - name = f"{mod}.mapper_{f}_{instance}" - - self._cmd_prefix = None - self._run_task = python_function_task - self._max_concurrency = concurrency - self._min_success_ratio = min_success_ratio - self._array_task_interface = python_function_task.python_interface - if "metadata" not in kwargs and python_function_task.metadata: - kwargs["metadata"] = python_function_task.metadata - if "security_ctx" not in kwargs and python_function_task.security_context: - kwargs["security_ctx"] = python_function_task.security_context + self._bound_inputs: typing.Set[str] = set(bound_inputs) if bound_inputs else set() + if self._partial: + self._bound_inputs = set(self._partial.keywords.keys()) + + collection_interface = transform_interface_to_list_interface(actual_task.python_interface, self._bound_inputs) + self._run_task: PythonFunctionTask = actual_task + _, mod, f, _ = tracker.extract_task_module(actual_task.task_function) + h = hashlib.md5(collection_interface.__str__().encode("utf-8")).hexdigest() + name = f"{mod}.map_{f}_{h}" + + self._cmd_prefix: typing.Optional[typing.List[str]] = None + self._max_concurrency: typing.Optional[int] = concurrency + self._min_success_ratio: typing.Optional[float] = min_success_ratio + self._array_task_interface = actual_task.python_interface + if "metadata" not in kwargs and actual_task.metadata: + kwargs["metadata"] = actual_task.metadata + if "security_ctx" not in kwargs and actual_task.security_context: + kwargs["security_ctx"] = actual_task.security_context super().__init__( name=name, interface=collection_interface, @@ -76,7 +92,15 @@ def __init__( **kwargs, ) + @property + def bound_inputs(self) -> Set[str]: + return self._bound_inputs + def get_command(self, settings: SerializationSettings) -> List[str]: + """ + TODO ADD bound variables to the resolver. Maybe we need a different resolver? + """ + mt = MapTaskResolver() container_args = [ "pyflyte-map-execute", "--inputs", @@ -90,9 +114,9 @@ def get_command(self, settings: SerializationSettings) -> List[str]: "--prev-checkpoint", "{{.prevCheckpointPrefix}}", "--resolver", - self._run_task.task_resolver.location, + mt.name(), "--", - *self._run_task.task_resolver.loader_args(settings, self._run_task), + *mt.loader_args(settings, self), ] if self._cmd_prefix: @@ -100,7 +124,7 @@ def get_command(self, settings: SerializationSettings) -> List[str]: return container_args def set_command_prefix(self, cmd: typing.Optional[typing.List[str]]): - self._cmd_prefix = cmd # type: ignore + self._cmd_prefix = cmd @contextmanager def prepare_target(self): @@ -135,6 +159,18 @@ def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str] def run_task(self) -> PythonFunctionTask: return self._run_task + def __call__(self, *args, **kwargs): + """ + This call method modifies the kwargs and adds kwargs from partial. + This is mostly done in the local_execute and compilation only. + At runtime, the map_task is created with all the inputs filled in. to support this, we have modified + the map_task interface in the constructor. + """ + if self._partial: + """If partial exists, then mix-in all partial values""" + kwargs = {**self._partial.keywords, **kwargs} + return super().__call__(*args, **kwargs) + def execute(self, **kwargs) -> Any: ctx = FlyteContextManager.current_context() if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION: @@ -191,7 +227,11 @@ def _execute_map_task(self, _: FlyteContext, **kwargs) -> Any: task_index = self._compute_array_job_index() map_task_inputs = {} for k in self.interface.inputs.keys(): - map_task_inputs[k] = kwargs[k][task_index] + v = kwargs[k] + if isinstance(v, list) and k not in self.bound_inputs: + map_task_inputs[k] = v[task_index] + else: + map_task_inputs[k] = v return exception_scopes.user_entry_point(self._run_task.execute)(**map_task_inputs) def _raw_execute(self, **kwargs) -> Any: @@ -213,7 +253,11 @@ def _raw_execute(self, **kwargs) -> Any: for i in range(len(kwargs[any_input_key])): single_instance_inputs = {} for k in self.interface.inputs.keys(): - single_instance_inputs[k] = kwargs[k][i] + v = kwargs[k] + if isinstance(v, list) and k not in self.bound_inputs: + single_instance_inputs[k] = kwargs[k][i] + else: + single_instance_inputs[k] = kwargs[k] o = exception_scopes.user_entry_point(self._run_task.execute)(**single_instance_inputs) if outputs_expected: outputs.append(o) @@ -221,7 +265,12 @@ def _raw_execute(self, **kwargs) -> Any: return outputs -def map_task(task_function: PythonFunctionTask, concurrency: int = 0, min_success_ratio: float = 1.0, **kwargs): +def map_task( + task_function: typing.Union[PythonFunctionTask, functools.partial], + concurrency: int = 0, + min_success_ratio: float = 1.0, + **kwargs, +): """ Use a map task for parallelizable tasks that run across a list of an input type. A map task can be composed of any individual :py:class:`flytekit.PythonFunctionTask`. @@ -267,8 +316,63 @@ def map_task(task_function: PythonFunctionTask, concurrency: int = 0, min_succes successfully before terminating this task and marking it successful. """ - if not isinstance(task_function, PythonFunctionTask): - raise ValueError( - f"Only Flyte python task types are supported in map tasks currently, received {type(task_function)}" - ) return MapPythonTask(task_function, concurrency=concurrency, min_success_ratio=min_success_ratio, **kwargs) + + +class MapTaskResolver(TrackedInstance, TaskResolverMixin): + """ + Special resolver that is used for MapTasks. + This exists because it is possible that MapTasks are created using nested "partial" subtasks. + When a maptask is created its interface is interpolated from the interface of the subtask - the interpolation, + simply converts every input into a list/collection input. + + For example: + interface -> (i: int, j: str) -> str => map_task interface -> (i: List[int], j: List[str]) -> List[str] + + But in cases in which `j` is bound to a fixed value by using `functools.partial` we need a way to ensure that + the interface is not simply interpolated, but only the unbound inputs are interpolated. + + .. code-block:: python + + def foo((i: int, j: str) -> str: + ... + + mt = map_task(functools.partial(foo, j=10)) + + print(mt.interface) + + output: + + (i: List[int], j: str) -> List[str] + + But, at runtime this information is lost. To reconstruct this, we use MapTaskResolver that records the "bound vars" + and then at runtime reconstructs the interface with this knowledge + """ + + def name(self) -> str: + return "MapTaskResolver" + + def load_task(self, loader_args: List[str], max_concurrency: int = 0) -> MapPythonTask: + """ + Loader args should be of the form + vars "var1,var2,.." resolver "resolver" [resolver_args] + """ + _, bound_vars, _, resolver, *resolver_args = loader_args + logging.info(f"MapTask found task resolver {resolver} and arguments {resolver_args}") + resolver_obj = load_object_from_module(resolver) + # Use the resolver to load the actual task object + _task_def = resolver_obj.load_task(loader_args=resolver_args) + bound_inputs = set(bound_vars.split(",")) + return MapPythonTask(python_function_task=_task_def, max_concurrency=max_concurrency, bound_inputs=bound_inputs) + + def loader_args(self, settings: SerializationSettings, t: MapPythonTask) -> List[str]: # type:ignore + return [ + "vars", + f'{",".join(t.bound_inputs)}', + "resolver", + t.run_task.task_resolver.location, + *t.run_task.task_resolver.loader_args(settings, t.run_task), + ] + + def get_all_tasks(self) -> List[Task]: + raise NotImplementedError("MapTask resolver cannot return every instance of the map task") diff --git a/plugins/flytekit-k8s-pod/tests/test_pod.py b/plugins/flytekit-k8s-pod/tests/test_pod.py index 0d6788ac92..014b88f4f3 100644 --- a/plugins/flytekit-k8s-pod/tests/test_pod.py +++ b/plugins/flytekit-k8s-pod/tests/test_pod.py @@ -355,8 +355,12 @@ def simple_pod_task(i: int): "--prev-checkpoint", "{{.prevCheckpointPrefix}}", "--resolver", - "flytekit.core.python_auto_container.default_task_resolver", + "MapTaskResolver", "--", + "vars", + "", + "resolver", + "flytekit.core.python_auto_container.default_task_resolver", "task-module", "tests.test_pod", "task-name", diff --git a/tests/flytekit/unit/core/test_map_task.py b/tests/flytekit/unit/core/test_map_task.py index 95927873d0..d032aca2d1 100644 --- a/tests/flytekit/unit/core/test_map_task.py +++ b/tests/flytekit/unit/core/test_map_task.py @@ -1,3 +1,4 @@ +import functools import typing from collections import OrderedDict @@ -6,7 +7,7 @@ import flytekit.configuration from flytekit import LaunchPlan, map_task from flytekit.configuration import Image, ImageConfig -from flytekit.core.map_task import MapPythonTask +from flytekit.core.map_task import MapPythonTask, MapTaskResolver from flytekit.core.task import TaskMetadata, task from flytekit.core.workflow import workflow from flytekit.tools.translator import get_serializable @@ -36,6 +37,11 @@ def t2(a: int) -> str: return str(b) +@task(cache=True, cache_version="1") +def t3(a: int, b: str, c: float) -> str: + pass + + # This test is for documentation. def test_map_docs(): # test_map_task_start @@ -87,8 +93,12 @@ def test_serialization(serialization_settings): "--prev-checkpoint", "{{.prevCheckpointPrefix}}", "--resolver", - "flytekit.core.python_auto_container.default_task_resolver", + "MapTaskResolver", "--", + "vars", + "", + "resolver", + "flytekit.core.python_auto_container.default_task_resolver", "task-module", "tests.flytekit.unit.core.test_map_task", "task-name", @@ -177,15 +187,42 @@ def test_inputs_outputs_length(): def many_inputs(a: int, b: str, c: float) -> str: return f"{a} - {b} - {c}" - with pytest.raises(ValueError): - _ = map_task(many_inputs) + m = map_task(many_inputs) + assert m.python_interface.inputs == {"a": typing.List[int], "b": typing.List[str], "c": typing.List[float]} + assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_24c08b3a2f9c2e389ad9fc6a03482cf9" + r_m = MapPythonTask(many_inputs) + assert str(r_m.python_interface) == str(m.python_interface) + + p1 = functools.partial(many_inputs, c=1.0) + m = map_task(p1) + assert m.python_interface.inputs == {"a": typing.List[int], "b": typing.List[str], "c": float} + assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_697aa7389996041183cf6cfd102be4f7" + r_m = MapPythonTask(many_inputs, bound_inputs=set("c")) + assert str(r_m.python_interface) == str(m.python_interface) + + p2 = functools.partial(p1, b="hello") + m = map_task(p2) + assert m.python_interface.inputs == {"a": typing.List[int], "b": str, "c": float} + assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_cc18607da7494024a402a5fa4b3ea5c6" + r_m = MapPythonTask(many_inputs, bound_inputs={"c", "b"}) + assert str(r_m.python_interface) == str(m.python_interface) + + p3 = functools.partial(p2, a=1) + m = map_task(p3) + assert m.python_interface.inputs == {"a": int, "b": str, "c": float} + assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_52fe80b04781ea77ef6f025f4b49abef" + r_m = MapPythonTask(many_inputs, bound_inputs={"a", "c", "b"}) + assert str(r_m.python_interface) == str(m.python_interface) + + with pytest.raises(TypeError): + m(a=[1, 2, 3]) @task def many_outputs(a: int) -> (int, str): return a, f"{a}" with pytest.raises(ValueError): - _ = map_task(many_inputs) + _ = map_task(many_outputs) def test_map_task_metadata(): @@ -194,3 +231,34 @@ def test_map_task_metadata(): assert mapped_1.metadata is map_meta mapped_2 = map_task(t2) assert mapped_2.metadata is t2.metadata + + +def test_map_task_resolver(serialization_settings): + list_outputs = {"o0": typing.List[str]} + mt = map_task(t3) + assert mt.python_interface.inputs == {"a": typing.List[int], "b": typing.List[str], "c": typing.List[float]} + assert mt.python_interface.outputs == list_outputs + mtr = MapTaskResolver() + assert mtr.name() == "MapTaskResolver" + args = mtr.loader_args(serialization_settings, mt) + t = mtr.load_task(loader_args=args) + assert t.python_interface.inputs == mt.python_interface.inputs + assert t.python_interface.outputs == mt.python_interface.outputs + + mt = map_task(functools.partial(t3, b="hello", c=1.0)) + assert mt.python_interface.inputs == {"a": typing.List[int], "b": str, "c": float} + assert mt.python_interface.outputs == list_outputs + mtr = MapTaskResolver() + args = mtr.loader_args(serialization_settings, mt) + t = mtr.load_task(loader_args=args) + assert t.python_interface.inputs == mt.python_interface.inputs + assert t.python_interface.outputs == mt.python_interface.outputs + + mt = map_task(functools.partial(t3, b="hello")) + assert mt.python_interface.inputs == {"a": typing.List[int], "b": str, "c": typing.List[float]} + assert mt.python_interface.outputs == list_outputs + mtr = MapTaskResolver() + args = mtr.loader_args(serialization_settings, mt) + t = mtr.load_task(loader_args=args) + assert t.python_interface.inputs == mt.python_interface.inputs + assert t.python_interface.outputs == mt.python_interface.outputs diff --git a/tests/flytekit/unit/core/test_partials.py b/tests/flytekit/unit/core/test_partials.py new file mode 100644 index 0000000000..0a78c825f8 --- /dev/null +++ b/tests/flytekit/unit/core/test_partials.py @@ -0,0 +1,181 @@ +import typing +from collections import OrderedDict +from functools import partial + +import pandas as pd +import pytest + +import flytekit.configuration +from flytekit.configuration import Image, ImageConfig +from flytekit.core.dynamic_workflow_task import dynamic +from flytekit.core.map_task import MapTaskResolver, map_task +from flytekit.core.task import TaskMetadata, task +from flytekit.core.workflow import workflow +from flytekit.tools.translator import gather_dependent_entities, get_serializable + +default_img = Image(name="default", fqn="test", tag="tag") +serialization_settings = flytekit.configuration.SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + image_config=ImageConfig(default_image=default_img, images=[default_img]), +) + + +df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) + + +def test_basics_1(): + @task + def t1(a: int, b: str, c: float) -> int: + return a + len(b) + int(c) + + outside_p = partial(t1, b="hello", c=3.14) + + @workflow + def my_wf_1(a: int) -> typing.Tuple[int, int]: + inner_partial = partial(t1, b="world", c=2.7) + out = outside_p(a=a) + inside = inner_partial(a=a) + return out, inside + + with pytest.raises(Exception): + get_serializable(OrderedDict(), serialization_settings, outside_p) + + # check the od todo + od = OrderedDict() + wf_1_spec = get_serializable(od, serialization_settings, my_wf_1) + tts, wspecs, lps = gather_dependent_entities(od) + tts = [t for t in tts.values()] + assert len(tts) == 1 + assert len(wf_1_spec.template.nodes) == 2 + assert wf_1_spec.template.nodes[0].task_node.reference_id.name == tts[0].id.name + assert wf_1_spec.template.nodes[1].task_node.reference_id.name == tts[0].id.name + assert wf_1_spec.template.nodes[0].inputs[0].binding.promise.var == "a" + assert wf_1_spec.template.nodes[0].inputs[1].binding.scalar is not None + assert wf_1_spec.template.nodes[0].inputs[2].binding.scalar is not None + + @task + def get_str() -> str: + return "got str" + + bind_c = partial(t1, c=2.7) + + @workflow + def my_wf_2(a: int) -> int: + s = get_str() + inner_partial = partial(bind_c, b=s) + inside = inner_partial(a=a) + return inside + + wf_2_spec = get_serializable(OrderedDict(), serialization_settings, my_wf_2) + assert len(wf_2_spec.template.nodes) == 2 + + +def test_map_task_types(): + @task(cache=True, cache_version="1") + def t3(a: int, b: str, c: float) -> str: + return str(a) + b + str(c) + + t3_bind_b1 = partial(t3, b="hello") + t3_bind_b2 = partial(t3, b="world") + t3_bind_c1 = partial(t3_bind_b1, c=3.14) + t3_bind_c2 = partial(t3_bind_b2, c=2.78) + + mt1 = map_task(t3_bind_c1, metadata=TaskMetadata(cache=True, cache_version="1")) + mt2 = map_task(t3_bind_c2, metadata=TaskMetadata(cache=True, cache_version="1")) + + @task + def print_lists(i: typing.List[str], j: typing.List[str]): + print(f"First: {i}") + print(f"Second: {j}") + + @workflow + def wf_out(a: typing.List[int]): + i = mt1(a=a) + j = mt2(a=[3, 4, 5]) + print_lists(i=i, j=j) + + wf_out(a=[1, 2]) + + @workflow + def wf_in(a: typing.List[int]): + mt_in1 = map_task(t3_bind_c1, metadata=TaskMetadata(cache=True, cache_version="1")) + mt_in2 = map_task(t3_bind_c2, metadata=TaskMetadata(cache=True, cache_version="1")) + i = mt_in1(a=a) + j = mt_in2(a=[3, 4, 5]) + print_lists(i=i, j=j) + + wf_in(a=[1, 2]) + + od = OrderedDict() + wf_spec = get_serializable(od, serialization_settings, wf_in) + tts, _, _ = gather_dependent_entities(od) + assert len(tts) == 2 # one map task + the print task + assert ( + wf_spec.template.nodes[0].task_node.reference_id.name == wf_spec.template.nodes[1].task_node.reference_id.name + ) + assert wf_spec.template.nodes[0].inputs[0].binding.promise is not None # comes from wf input + assert wf_spec.template.nodes[1].inputs[0].binding.collection is not None # bound to static list + assert wf_spec.template.nodes[1].inputs[1].binding.scalar is not None # these are bound + assert wf_spec.template.nodes[1].inputs[2].binding.scalar is not None + + +def test_everything(): + @task + def get_static_list() -> typing.List[float]: + return [3.14, 2.718] + + @task + def get_list_of_pd(s: int) -> typing.List[pd.DataFrame]: + df1 = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) + df2 = pd.DataFrame({"Name": ["Rachel", "Eve", "Mary"], "Age": [22, 23, 24]}) + if s == 2: + return [df1, df2] + else: + return [df1, df2, df1] + + @task + def t3(a: int, b: str, c: typing.List[float], d: typing.List[float], a2: pd.DataFrame) -> str: + return str(a) + f"pdsize{len(a2)}" + b + str(c) + "&&" + str(d) + + t3_bind_b1 = partial(t3, b="hello") + t3_bind_b2 = partial(t3, b="world") + t3_bind_c1 = partial(t3_bind_b1, c=[6.674, 1.618, 6.626], d=[1.0]) + + mt1 = map_task(t3_bind_c1) + + mr = MapTaskResolver() + aa = mr.loader_args(serialization_settings, mt1) + # Check bound vars + aa = aa[1].split(",") + aa.sort() + assert aa == ["b", "c", "d"] + + @task + def print_lists(i: typing.List[str], j: typing.List[str]) -> str: + print(f"First: {i}") + print(f"Second: {j}") + return f"{i}-{j}" + + @dynamic + def dt1(a: typing.List[int], a2: typing.List[pd.DataFrame], sl: typing.List[float]) -> str: + i = mt1(a=a, a2=a2) + t3_bind_c2 = partial(t3_bind_b2, c=[1.0, 2.0, 3.0], d=sl) + mt_in2 = map_task(t3_bind_c2) + dfs = get_list_of_pd(s=3) + j = mt_in2(a=[3, 4, 5], a2=dfs) + return print_lists(i=i, j=j) + + @workflow + def wf_dt(a: typing.List[int]) -> str: + sl = get_static_list() + dfs = get_list_of_pd(s=2) + return dt1(a=a, a2=dfs, sl=sl) + + print(wf_dt(a=[1, 2])) + assert ( + wf_dt(a=[1, 2]) + == "['1pdsize2hello[6.674, 1.618, 6.626]&&[1.0]', '2pdsize3hello[6.674, 1.618, 6.626]&&[1.0]']-['3pdsize2world[1.0, 2.0, 3.0]&&[3.14, 2.718]', '4pdsize3world[1.0, 2.0, 3.0]&&[3.14, 2.718]', '5pdsize2world[1.0, 2.0, 3.0]&&[3.14, 2.718]']" + )