Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pyflyte run & register asynchronously #2276

Merged
merged 8 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 28 additions & 30 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""
from __future__ import annotations

import asyncio
import base64
import configparser
import functools
Expand Down Expand Up @@ -724,7 +725,7 @@ def raw_register(

raise AssertionError(f"Unknown entity of type {type(cp_entity)}")

def _serialize_and_register(
async def _serialize_and_register(
self,
entity: FlyteLocalEntity,
settings: typing.Optional[SerializationSettings],
Expand All @@ -740,43 +741,34 @@ def _serialize_and_register(
# Create dummy serialization settings for now.
# TODO: Clean this up by using lazy usage of serialization settings in translator.py
serialization_settings = settings
is_dummy_serialization_setting = False
if not settings:
serialization_settings = SerializationSettings(
ImageConfig.auto_default_image(),
project=self.default_project,
domain=self.default_domain,
version=version,
)
is_dummy_serialization_setting = True

if serialization_settings.version is None:
serialization_settings.version = version

_ = get_serializable(m, settings=serialization_settings, entity=entity, options=options)

ident = None
for entity, cp_entity in m.items():
if not isinstance(cp_entity, admin_workflow_models.WorkflowSpec) and is_dummy_serialization_setting:
# Only in the case of workflows can we use the dummy serialization settings.
raise user_exceptions.FlyteValueException(
settings,
f"No serialization settings set, but workflow contains entities that need to be registered. {cp_entity.id.name}",
)

try:
ident = self.raw_register(
cp_entity,
settings=settings,
version=version,
create_default_launchplan=create_default_launchplan,
options=options,
og_entity=entity,
# concurrent register
cp_task_entity_map = OrderedDict(filter(lambda x: isinstance(x[1], task_models.TaskSpec), m.items()))
tasks = []
loop = asyncio.get_event_loop()
for entity, cp_entity in cp_task_entity_map.items():
tasks.append(
loop.run_in_executor(
None, functools.partial(self.raw_register, cp_entity, settings, version, og_entity=entity)
)
except RegistrationSkipped:
pass

return ident
)
ident = []
ident.extend(await asyncio.gather(*tasks))
# serial register
cp_other_entities = OrderedDict(filter(lambda x: not isinstance(x[1], task_models.TaskSpec), m.items()))
for entity, cp_entity in cp_other_entities.items():
ident.append(self.raw_register(cp_entity, settings, version, og_entity=entity))
return ident[-1]

def register_task(
self,
Expand All @@ -803,7 +795,10 @@ def register_task(
source_root=project_root,
)

ident = self._serialize_and_register(entity=entity, settings=serialization_settings, version=version)
ident = asyncio.run(
self._serialize_and_register(entity=entity, settings=serialization_settings, version=version)
)

ft = self.fetch_task(
ident.project,
ident.domain,
Expand Down Expand Up @@ -837,7 +832,9 @@ def register_workflow(
b.domain = ident.domain
b.version = ident.version
serialization_settings = b.build()
ident = self._serialize_and_register(entity, serialization_settings, version, options, default_launch_plan)
ident = asyncio.run(
self._serialize_and_register(entity, serialization_settings, version, options, default_launch_plan)
)
fwf = self.fetch_workflow(ident.project, ident.domain, ident.name, ident.version)
fwf._python_interface = entity.python_interface
return fwf
Expand Down Expand Up @@ -1023,7 +1020,9 @@ def _get_image_names(entity: typing.Union[PythonAutoContainerTask, WorkflowBase]

if isinstance(entity, PythonTask):
return self.register_task(entity, serialization_settings, version)
return self.register_workflow(entity, serialization_settings, version, default_launch_plan, options)
fwf = self.register_workflow(entity, serialization_settings, version, default_launch_plan, options)
fwf._python_interface = entity.python_interface
return fwf

def register_launch_plan(
self,
Expand All @@ -1048,7 +1047,6 @@ def register_launch_plan(
domain=domain or self.default_domain,
version=version,
)

ident = self._resolve_identifier(ResourceType.LAUNCH_PLAN, entity.name, version, ss)
m = OrderedDict()
idl_lp = get_serializable_launch_plan(m, ss, entity, recurse_downstream=False, options=options)
Expand Down
23 changes: 21 additions & 2 deletions flytekit/tools/repo.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio
import functools
import os
import tarfile
import tempfile
Expand All @@ -9,7 +11,7 @@
from flytekit.configuration import FastSerializationSettings, ImageConfig, SerializationSettings
from flytekit.core.context_manager import FlyteContextManager
from flytekit.loggers import logger
from flytekit.models import launch_plan
from flytekit.models import launch_plan, task
from flytekit.models.core.identifier import Identifier
from flytekit.remote import FlyteRemote
from flytekit.remote.remote import RegistrationSkipped, _get_git_repo_url
Expand Down Expand Up @@ -269,7 +271,7 @@ def register(
click.secho("No Flyte entities were detected. Aborting!", fg="red")
return

for cp_entity in registrable_entities:
def _raw_register(cp_entity: FlyteControlPlaneEntity):
is_lp = False
if isinstance(cp_entity, launch_plan.LaunchPlan):
og_id = cp_entity.id
Expand All @@ -296,4 +298,21 @@ def register(
secho(og_id, reason="Dry run Mode!")
except RegistrationSkipped:
secho(og_id, "failed")

async def _register(entities: typing.List[task.TaskSpec]):
    """Register the given entities concurrently.

    Every entity is handed to ``_raw_register`` through the event loop's
    ``run_in_executor``, and this coroutine finishes once all of those calls
    have completed. An exception raised by any call propagates to the caller
    via ``asyncio.gather``.

    :param entities: the entities to register; an empty list is a no-op.
    """
    # A coroutine always executes under a running loop, so get_running_loop()
    # is the documented, unambiguous way to obtain it (get_event_loop() has
    # policy-dependent behaviour and is discouraged inside coroutines).
    loop = asyncio.get_running_loop()
    # Passing None as the executor selects the loop's default executor.
    pending = [loop.run_in_executor(None, functools.partial(_raw_register, entity)) for entity in entities]
    await asyncio.gather(*pending)

# concurrent register
cp_task_entities = list(filter(lambda x: isinstance(x, task.TaskSpec), registrable_entities))
asyncio.run(_register(cp_task_entities))
# serial register
cp_other_entities = list(filter(lambda x: not isinstance(x, task.TaskSpec), registrable_entities))
for entity in cp_other_entities:
_raw_register(entity)

click.secho(f"Successfully registered {len(registrable_entities)} entities", fg="green")
28 changes: 28 additions & 0 deletions tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,34 @@ def register():
assert out.returncode == 0


def run(file_name, wf_name, *args):
    """Invoke ``pyflyte run --remote`` for an entity and assert the CLI exits with code 0.

    :param file_name: name of the file, relative to ``MODULE_PATH``, that holds the entity.
    :param wf_name: name of the entity passed to ``pyflyte run``.
    :param args: extra command-line arguments appended after ``wf_name``.
    """
    # Global options and the sub-command.
    command = ["pyflyte", "--verbose", "-c", CONFIG, "run", "--remote"]
    # Options of the ``run`` sub-command.
    command += ["--image", IMAGE, "--project", PROJECT, "--domain", DOMAIN]
    # Positional arguments: script path, entity name, then the entity's own inputs.
    command += [MODULE_PATH / file_name, wf_name, *args]

    completed = subprocess.run(command)
    assert completed.returncode == 0


# Test that running child_workflow.parent_wf asynchronously registers a parent
# workflow (wf1) together with a child launch plan taken from another workflow (wf2).
def test_remote_run():
    """Run ``parent_wf`` from ``child_workflow.py`` through the ``run`` helper with ``--a 3``."""
    cli_inputs = ("--a", "3")
    run("child_workflow.py", "parent_wf", *cli_inputs)


def test_fetch_execute_launch_plan(register):
remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
flyte_launch_plan = remote.fetch_launch_plan(name="basic.hello_world.my_wf", version=VERSION)
Expand Down
Loading