Skip to content

Commit

Permalink
pyflyte remote run: async register & run
Browse files Browse the repository at this point in the history
Signed-off-by: Austin Liu <[email protected]>

fix

Signed-off-by: Austin Liu <[email protected]>
  • Loading branch information
austin362667 committed Mar 19, 2024
1 parent 3f45131 commit 15ed365
Showing 1 changed file with 26 additions and 14 deletions.
40 changes: 26 additions & 14 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""
from __future__ import annotations

import asyncio
import base64
import configparser
import functools
Expand Down Expand Up @@ -722,7 +723,7 @@ def raw_register(

raise AssertionError(f"Unknown entity of type {type(cp_entity)}")

def _serialize_and_register(
async def _serialize_and_register(
self,
entity: FlyteLocalEntity,
settings: typing.Optional[SerializationSettings],
Expand Down Expand Up @@ -753,7 +754,8 @@ def _serialize_and_register(

_ = get_serializable(m, settings=serialization_settings, entity=entity, options=options)

ident = None
tasks = []
loop = asyncio.get_event_loop()
for entity, cp_entity in m.items():
if not isinstance(cp_entity, admin_workflow_models.WorkflowSpec) and is_dummy_serialization_setting:
# Only in the case of workflows can we use the dummy serialization settings.
Expand All @@ -763,18 +765,25 @@ def _serialize_and_register(
)

try:
ident = self.raw_register(
cp_entity,
settings=settings,
version=version,
create_default_launchplan=create_default_launchplan,
options=options,
og_entity=entity,
tasks.append(
loop.run_in_executor(
None,
functools.partial(
self.raw_register,
cp_entity=cp_entity,
settings=settings,
version=version,
create_default_launchplan=create_default_launchplan,
options=options,
og_entity=entity,
),
)
)
except RegistrationSkipped:
pass

return ident
res = await asyncio.gather(*tasks)
return res[-1]

def register_task(
self, entity: PythonTask, serialization_settings: SerializationSettings, version: typing.Optional[str] = None
Expand All @@ -788,7 +797,9 @@ def register_task(
:param version: version that will be used to register. If not specified will default to using the serialization settings default
:return:
"""
ident = self._serialize_and_register(entity=entity, settings=serialization_settings, version=version)
ident = asyncio.run(
self._serialize_and_register(entity=entity, settings=serialization_settings, version=version)
)
ft = self.fetch_task(
ident.project,
ident.domain,
Expand Down Expand Up @@ -822,7 +833,9 @@ def register_workflow(
b.domain = ident.domain
b.version = ident.version
serialization_settings = b.build()
ident = self._serialize_and_register(entity, serialization_settings, version, options, default_launch_plan)
ident = asyncio.run(
self._serialize_and_register(entity, serialization_settings, version, options, default_launch_plan)
)
fwf = self.fetch_workflow(ident.project, ident.domain, ident.name, ident.version)
fwf._python_interface = entity.python_interface
return fwf
Expand Down Expand Up @@ -1032,8 +1045,7 @@ def register_launch_plan(
domain=domain or self.default_domain,
version=version,
)

ident = self._resolve_identifier(ResourceType.LAUNCH_PLAN, entity.name, version, ss)
ident = asyncio.run(ident=self._resolve_identifier(ResourceType.LAUNCH_PLAN, entity.name, version, ss))
m = OrderedDict()
idl_lp = get_serializable_launch_plan(m, ss, entity, recurse_downstream=False, options=options)
try:
Expand Down

0 comments on commit 15ed365

Please sign in to comment.