From b1db9d0e8f22e5ddb4fe649a236c24dea55b0df3 Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Mon, 18 Mar 2024 15:36:43 +0800 Subject: [PATCH 1/8] pyflyte remote `run`: async register & run Signed-off-by: Austin Liu fix Signed-off-by: Austin Liu --- flytekit/remote/remote.py | 40 +++++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 0948de5065..b80e1eb2d7 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -5,6 +5,7 @@ """ from __future__ import annotations +import asyncio import base64 import configparser import functools @@ -724,7 +725,7 @@ def raw_register( raise AssertionError(f"Unknown entity of type {type(cp_entity)}") - def _serialize_and_register( + async def _serialize_and_register( self, entity: FlyteLocalEntity, settings: typing.Optional[SerializationSettings], @@ -755,7 +756,8 @@ def _serialize_and_register( _ = get_serializable(m, settings=serialization_settings, entity=entity, options=options) - ident = None + tasks = [] + loop = asyncio.get_event_loop() for entity, cp_entity in m.items(): if not isinstance(cp_entity, admin_workflow_models.WorkflowSpec) and is_dummy_serialization_setting: # Only in the case of workflows can we use the dummy serialization settings. @@ -765,18 +767,25 @@ def _serialize_and_register( ) try: - ident = self.raw_register( - cp_entity, - settings=settings, - version=version, - create_default_launchplan=create_default_launchplan, - options=options, - og_entity=entity, + tasks.append( + loop.run_in_executor( + None, + functools.partial( + self.raw_register, + cp_entity=cp_entity, + settings=settings, + version=version, + create_default_launchplan=create_default_launchplan, + options=options, + og_entity=entity, + ), + ) ) except RegistrationSkipped: pass - return ident + res = await asyncio.gather(*tasks) + return res[-1] def register_task( self, @@ -803,7 +812,9 @@ def register_task( source_root=project_root, ) - ident = self._serialize_and_register(entity=entity, settings=serialization_settings, version=version) + ident = asyncio.run( + self._serialize_and_register(entity=entity, settings=serialization_settings, version=version) + ) ft = self.fetch_task( ident.project, ident.domain, @@ -837,7 +848,9 @@ def register_workflow( b.domain = ident.domain b.version = ident.version serialization_settings = b.build() - ident = self._serialize_and_register(entity, serialization_settings, version, options, default_launch_plan) + ident = asyncio.run( + self._serialize_and_register(entity, serialization_settings, version, options, default_launch_plan) + ) fwf = self.fetch_workflow(ident.project, ident.domain, ident.name, ident.version) fwf._python_interface = entity.python_interface return fwf @@ -1048,8 +1061,7 @@ def register_launch_plan( domain=domain or self.default_domain, version=version, ) - - ident = self._resolve_identifier(ResourceType.LAUNCH_PLAN, entity.name, version, ss) + ident = asyncio.run(ident=self._resolve_identifier(ResourceType.LAUNCH_PLAN, entity.name, version, ss)) m = OrderedDict() idl_lp = get_serializable_launch_plan(m, ss, entity, recurse_downstream=False, options=options) try: From 02e28b69b390bbc65924b83fc4d9366f57764b07 Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Wed, 20 Mar 2024 03:41:57 +0800 Subject: [PATCH 2/8] pyflyte `register`: async register Signed-off-by: Austin Liu lint Signed-off-by: Austin Liu fix dependency Signed-off-by: Austin Liu fix dependency Signed-off-by: Austin Liu fix dependency Signed-off-by: Austin Liu fix dependency Signed-off-by: Austin Liu fix Signed-off-by: Austin Liu lint Signed-off-by: Austin Liu --- flytekit/remote/remote.py | 4 +++- flytekit/tools/repo.py | 18 +++++++++++++++++- flytekit/tools/translator.py | 6 +++--- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index b80e1eb2d7..779052cfb0 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -1036,7 +1036,9 @@ def _get_image_names(entity: typing.Union[PythonAutoContainerTask, WorkflowBase] if isinstance(entity, PythonTask): return self.register_task(entity, serialization_settings, version) - return self.register_workflow(entity, serialization_settings, version, default_launch_plan, options) + fwf = self.register_workflow(entity, serialization_settings, version, default_launch_plan, options) + fwf._python_interface = entity.python_interface + return fwf def register_launch_plan( self, diff --git a/flytekit/tools/repo.py b/flytekit/tools/repo.py index c3a456c20b..fb85c3be23 100644 --- a/flytekit/tools/repo.py +++ b/flytekit/tools/repo.py @@ -1,3 +1,5 @@ +import asyncio +import functools import os import tarfile import tempfile @@ -269,7 +271,7 @@ def register( click.secho("No Flyte entities were detected. Aborting!", fg="red") return - for cp_entity in registrable_entities: + def _raw_register(cp_entity: FlyteControlPlaneEntity): is_lp = False if isinstance(cp_entity, launch_plan.LaunchPlan): og_id = cp_entity.id @@ -296,4 +298,18 @@ def register( secho(og_id, reason="Dry run Mode!") except RegistrationSkipped: secho(og_id, "failed") + + async def _register(entities: typing.List[FlyteControlPlaneEntity]): + loop = asyncio.get_event_loop() + tasks = [] + for entity in entities: + tasks.append(loop.run_in_executor(None, functools.partial(_raw_register, entity))) + if tasks: + await asyncio.wait(tasks) + return + + for type_ in FlyteControlPlaneEntity.__args__: + cp_entities = list(filter(lambda x: isinstance(x, type_), registrable_entities)) + asyncio.run(_register(cp_entities)) + click.secho(f"Successfully registered {len(registrable_entities)} entities", fg="green") diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index b49639d23a..137e71d8b9 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -50,11 +50,11 @@ ] FlyteControlPlaneEntity = Union[ TaskSpec, - _launch_plan_models.LaunchPlan, - admin_workflow_models.WorkflowSpec, - workflow_model.Node, BranchNodeModel, ArrayNodeModel, + workflow_model.Node, + admin_workflow_models.WorkflowSpec, + _launch_plan_models.LaunchPlan, ] From 4695ea00f01ec64136e4422e1bbc495f854b0186 Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Sat, 23 Mar 2024 08:53:24 +0800 Subject: [PATCH 3/8] only concurrent register tasks Signed-off-by: Austin Liu --- flytekit/tools/repo.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/flytekit/tools/repo.py b/flytekit/tools/repo.py index fb85c3be23..074d449f28 100644 --- a/flytekit/tools/repo.py +++ b/flytekit/tools/repo.py @@ -11,7 +11,7 @@ from flytekit.configuration import FastSerializationSettings, ImageConfig, SerializationSettings from flytekit.core.context_manager import FlyteContextManager from flytekit.loggers import logger -from flytekit.models import launch_plan +from flytekit.models import launch_plan, task from flytekit.models.core.identifier import Identifier from flytekit.remote import FlyteRemote from flytekit.remote.remote import RegistrationSkipped, _get_git_repo_url @@ -308,8 +308,12 @@ async def _register(entities: typing.List[FlyteControlPlaneEntity]): await asyncio.wait(tasks) return - for type_ in FlyteControlPlaneEntity.__args__: - cp_entities = list(filter(lambda x: isinstance(x, type_), registrable_entities)) - asyncio.run(_register(cp_entities)) + # concurrent register + cp_task_entities = list(filter(lambda x: isinstance(x, task.TaskSpec), registrable_entities)) + asyncio.run(_register(cp_task_entities)) + # serial register + cp_other_entities = list(filter(lambda x: not isinstance(x, task.TaskSpec), registrable_entities)) + for entity in cp_other_entities: + _raw_register(entity) click.secho(f"Successfully registered {len(registrable_entities)} entities", fg="green") From 4a3f28788851b122613627d6b0e90277d9fc8fb1 Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Tue, 26 Mar 2024 13:59:33 +0800 Subject: [PATCH 4/8] rollback cp_entity_type order Signed-off-by: Austin Liu --- flytekit/tools/translator.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index 137e71d8b9..e6d3d5d9f8 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -48,13 +48,14 @@ ReferenceLaunchPlan, ReferenceEntity, ] + FlyteControlPlaneEntity = Union[ TaskSpec, + _launch_plan_models.LaunchPlan, + admin_workflow_models.WorkflowSpec, + workflow_model.Node, BranchNodeModel, ArrayNodeModel, - workflow_model.Node, - admin_workflow_models.WorkflowSpec, - _launch_plan_models.LaunchPlan, ] From 847ee62a3734ad08ad66f7c9408364ad969fb59d Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Tue, 26 Mar 2024 14:00:46 +0800 Subject: [PATCH 5/8] only concurrent register tasks in remote `run` Signed-off-by: Austin Liu --- flytekit/remote/remote.py | 90 +++++++++++++++++++++------------------ 1 file changed, 48 insertions(+), 42 deletions(-) diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 779052cfb0..3118f4c2dc 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -725,7 +725,42 @@ def raw_register( raise AssertionError(f"Unknown entity of type {type(cp_entity)}") - async def _serialize_and_register( + async def _register( + self, + entity_map: OrderedDict, + settings: SerializationSettings, + version: str, + create_default_launchplan: bool = True, + options: Options = None, + og_entity: FlyteLocalEntity = None, + ) -> typing.Optional[Identifier]: + """ + Raw register method, can be used to register control plane entities. Usually if you have a Flyte Entity like a + WorkflowBase, Task, LaunchPlan then use other methods. This should be used only if you have already serialized entities + + :param cp_entity: The controlplane "serializable" version of a flyte entity. This is in the form that FlyteAdmin + understands. + :param settings: SerializationSettings to be used for registration - especially to identify the id + :param version: Version to be registered + :param create_default_launchplan: boolean that indicates if a default launch plan should be created + :param options: Options to be used if registering a default launch plan + :param og_entity: Pass in the original workflow (flytekit type) if create_default_launchplan is true + :return: Identifier of the created entity + """ + loop = asyncio.get_event_loop() + tasks = [] + res = [] + for entity, cp_entity in entity_map.items(): + tasks.append( + loop.run_in_executor( + None, functools.partial(self.raw_register, cp_entity, settings, version, og_entity=entity) + ) + ) + if tasks: + res, _ = await asyncio.wait(tasks) + return res + + def _serialize_and_register( self, entity: FlyteLocalEntity, settings: typing.Optional[SerializationSettings], @@ -741,7 +776,6 @@ async def _serialize_and_register( # Create dummy serialization settings for now. # TODO: Clean this up by using lazy usage of serialization settings in translator.py serialization_settings = settings - is_dummy_serialization_setting = False if not settings: serialization_settings = SerializationSettings( ImageConfig.auto_default_image(), @@ -749,43 +783,18 @@ async def _serialize_and_register( domain=self.default_domain, version=version, ) - is_dummy_serialization_setting = True - if serialization_settings.version is None: serialization_settings.version = version _ = get_serializable(m, settings=serialization_settings, entity=entity, options=options) - - tasks = [] - loop = asyncio.get_event_loop() - for entity, cp_entity in m.items(): - if not isinstance(cp_entity, admin_workflow_models.WorkflowSpec) and is_dummy_serialization_setting: - # Only in the case of workflows can we use the dummy serialization settings. - raise user_exceptions.FlyteValueException( - settings, - f"No serialization settings set, but workflow contains entities that need to be registered. {cp_entity.id.name}", - ) - - try: - tasks.append( - loop.run_in_executor( - None, - functools.partial( - self.raw_register, - cp_entity=cp_entity, - settings=settings, - version=version, - create_default_launchplan=create_default_launchplan, - options=options, - og_entity=entity, - ), - ) - ) - except RegistrationSkipped: - pass - - res = await asyncio.gather(*tasks) - return res[-1] + # concurrent register + cp_task_entity_map = OrderedDict(filter(lambda x: isinstance(x[1], task_models.TaskSpec), m.items())) + ident = asyncio.run(self._register(cp_task_entity_map, settings, version)) + # serial register + cp_other_entities = OrderedDict(filter(lambda x: not isinstance(x[1], task_models.TaskSpec), m.items())) + for entity, cp_entity in cp_other_entities.items(): + ident = self.raw_register(cp_entity, settings, version, og_entity=entity) + return ident def register_task( self, @@ -812,9 +821,8 @@ def register_task( source_root=project_root, ) - ident = asyncio.run( - self._serialize_and_register(entity=entity, settings=serialization_settings, version=version) - ) + ident = self._serialize_and_register(entity=entity, settings=serialization_settings, version=version) + ft = self.fetch_task( ident.project, ident.domain, @@ -848,9 +856,7 @@ def register_workflow( b.domain = ident.domain b.version = ident.version serialization_settings = b.build() - ident = asyncio.run( - self._serialize_and_register(entity, serialization_settings, version, options, default_launch_plan) - ) + ident = self._serialize_and_register(entity, serialization_settings, version, options, default_launch_plan) fwf = self.fetch_workflow(ident.project, ident.domain, ident.name, ident.version) fwf._python_interface = entity.python_interface return fwf @@ -1063,7 +1069,7 @@ def register_launch_plan( domain=domain or self.default_domain, version=version, ) - ident = asyncio.run(ident=self._resolve_identifier(ResourceType.LAUNCH_PLAN, entity.name, version, ss)) + ident = ident = self._resolve_identifier(ResourceType.LAUNCH_PLAN, entity.name, version, ss) m = OrderedDict() idl_lp = get_serializable_launch_plan(m, ss, entity, recurse_downstream=False, options=options) try: From 2844869240ef8ecbe4b1f4db025063ff9567814d Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Tue, 26 Mar 2024 14:01:51 +0800 Subject: [PATCH 6/8] test child_workflow.py parent_wf by `pyflyte run --remote` Signed-off-by: Austin Liu --- .../integration/remote/test_remote.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 1e03b55098..854672db53 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -50,6 +50,37 @@ def register(): assert out.returncode == 0 +# only accept one arg for now. +# be careful with different wf_name in each wf. +def run(file_name, wf_name, arg_key, arg_value): + out = subprocess.run( + [ + "pyflyte", + "--verbose", + "-c", + CONFIG, + "run", + "--remote", + "--image", + IMAGE, + "--project", + PROJECT, + "--domain", + DOMAIN, + MODULE_PATH / file_name, + wf_name, + arg_key, + arg_value, + ] + ) + assert out.returncode == 0 + + +# test child_workflow.parent_wf asynchronously register a parent wf1 with child lp from another wf2. +def test_remote_run(): + run("child_workflow.py", "parent_wf", "--a", "3") + + def test_fetch_execute_launch_plan(register): remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) flyte_launch_plan = remote.fetch_launch_plan(name="basic.hello_world.my_wf", version=VERSION) From da600318de03600b73ed51ad62e41926b3d6d66b Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Tue, 26 Mar 2024 14:40:19 +0800 Subject: [PATCH 7/8] clean up Signed-off-by: Austin Liu clean up Signed-off-by: Austin Liu clean up Signed-off-by: Austin Liu clean up Signed-off-by: Austin Liu clean up Signed-off-by: Austin Liu --- flytekit/remote/remote.py | 62 ++++++++++++------------------------ flytekit/tools/repo.py | 3 +- flytekit/tools/translator.py | 1 - 3 files changed, 21 insertions(+), 45 deletions(-) diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 3118f4c2dc..856563f0db 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -725,42 +725,7 @@ def raw_register( raise AssertionError(f"Unknown entity of type {type(cp_entity)}") - async def _register( - self, - entity_map: OrderedDict, - settings: SerializationSettings, - version: str, - create_default_launchplan: bool = True, - options: Options = None, - og_entity: FlyteLocalEntity = None, - ) -> typing.Optional[Identifier]: - """ - Raw register method, can be used to register control plane entities. Usually if you have a Flyte Entity like a - WorkflowBase, Task, LaunchPlan then use other methods. This should be used only if you have already serialized entities - - :param cp_entity: The controlplane "serializable" version of a flyte entity. This is in the form that FlyteAdmin - understands. - :param settings: SerializationSettings to be used for registration - especially to identify the id - :param version: Version to be registered - :param create_default_launchplan: boolean that indicates if a default launch plan should be created - :param options: Options to be used if registering a default launch plan - :param og_entity: Pass in the original workflow (flytekit type) if create_default_launchplan is true - :return: Identifier of the created entity - """ - loop = asyncio.get_event_loop() - tasks = [] - res = [] - for entity, cp_entity in entity_map.items(): - tasks.append( - loop.run_in_executor( - None, functools.partial(self.raw_register, cp_entity, settings, version, og_entity=entity) - ) - ) - if tasks: - res, _ = await asyncio.wait(tasks) - return res - - def _serialize_and_register( + async def _serialize_and_register( self, entity: FlyteLocalEntity, settings: typing.Optional[SerializationSettings], @@ -789,12 +754,21 @@ def _serialize_and_register( _ = get_serializable(m, settings=serialization_settings, entity=entity, options=options) # concurrent register cp_task_entity_map = OrderedDict(filter(lambda x: isinstance(x[1], task_models.TaskSpec), m.items())) - ident = asyncio.run(self._register(cp_task_entity_map, settings, version)) + tasks = [] + loop = asyncio.get_event_loop() + for entity, cp_entity in cp_task_entity_map.items(): + tasks.append( + loop.run_in_executor( + None, functools.partial(self.raw_register, cp_entity, settings, version, og_entity=entity) + ) + ) + ident = [] + ident.extend(await asyncio.gather(*tasks)) # serial register cp_other_entities = OrderedDict(filter(lambda x: not isinstance(x[1], task_models.TaskSpec), m.items())) for entity, cp_entity in cp_other_entities.items(): - ident = self.raw_register(cp_entity, settings, version, og_entity=entity) - return ident + ident.append(self.raw_register(cp_entity, settings, version, og_entity=entity)) + return ident[-1] def register_task( self, @@ -821,7 +795,9 @@ def register_task( source_root=project_root, ) - ident = self._serialize_and_register(entity=entity, settings=serialization_settings, version=version) + ident = asyncio.run( + self._serialize_and_register(entity=entity, settings=serialization_settings, version=version) + ) ft = self.fetch_task( ident.project, @@ -856,7 +832,9 @@ def register_workflow( b.domain = ident.domain b.version = ident.version serialization_settings = b.build() - ident = self._serialize_and_register(entity, serialization_settings, version, options, default_launch_plan) + ident = asyncio.run( + self._serialize_and_register(entity, serialization_settings, version, options, default_launch_plan) + ) fwf = self.fetch_workflow(ident.project, ident.domain, ident.name, ident.version) fwf._python_interface = entity.python_interface return fwf @@ -1069,7 +1047,7 @@ def register_launch_plan( domain=domain or self.default_domain, version=version, ) - ident = ident = self._resolve_identifier(ResourceType.LAUNCH_PLAN, entity.name, version, ss) + ident = self._resolve_identifier(ResourceType.LAUNCH_PLAN, entity.name, version, ss) m = OrderedDict() idl_lp = get_serializable_launch_plan(m, ss, entity, recurse_downstream=False, options=options) try: diff --git a/flytekit/tools/repo.py b/flytekit/tools/repo.py index 074d449f28..e01d05279b 100644 --- a/flytekit/tools/repo.py +++ b/flytekit/tools/repo.py @@ -304,8 +304,7 @@ async def _register(entities: typing.List[FlyteControlPlaneEntity]): tasks = [] for entity in entities: tasks.append(loop.run_in_executor(None, functools.partial(_raw_register, entity))) - if tasks: - await asyncio.wait(tasks) + await asyncio.gather(*tasks) return # concurrent register diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index e6d3d5d9f8..b49639d23a 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -48,7 +48,6 @@ ReferenceLaunchPlan, ReferenceEntity, ] - FlyteControlPlaneEntity = Union[ TaskSpec, _launch_plan_models.LaunchPlan, From f0eab0a553f374042c4175087fe547da0dbb3eb0 Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Thu, 11 Apr 2024 15:32:29 +0800 Subject: [PATCH 8/8] clean up Signed-off-by: Austin Liu --- flytekit/tools/repo.py | 2 +- tests/flytekit/integration/remote/test_remote.py | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/flytekit/tools/repo.py b/flytekit/tools/repo.py index e01d05279b..b782f35496 100644 --- a/flytekit/tools/repo.py +++ b/flytekit/tools/repo.py @@ -299,7 +299,7 @@ def _raw_register(cp_entity: FlyteControlPlaneEntity): except RegistrationSkipped: secho(og_id, "failed") - async def _register(entities: typing.List[FlyteControlPlaneEntity]): + async def _register(entities: typing.List[task.TaskSpec]): loop = asyncio.get_event_loop() tasks = [] for entity in entities: diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 854672db53..f23fc061d9 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -50,9 +50,7 @@ def register(): assert out.returncode == 0 -# only accept one arg for now. -# be careful with different wf_name in each wf. -def run(file_name, wf_name, arg_key, arg_value): +def run(file_name, wf_name, *args): out = subprocess.run( [ "pyflyte", @@ -69,8 +67,7 @@ def run(file_name, wf_name, arg_key, arg_value): DOMAIN, MODULE_PATH / file_name, wf_name, - arg_key, - arg_value, + *args, ] ) assert out.returncode == 0