Skip to content

Commit

Permalink
pyflyte register: async register
Browse files Browse the repository at this point in the history
Signed-off-by: Austin Liu <[email protected]>
  • Loading branch information
austin362667 committed Mar 19, 2024
1 parent 15ed365 commit 3bbb889
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 19 deletions.
37 changes: 20 additions & 17 deletions flytekit/clis/sdk_in_container/register.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import os
import typing

Expand Down Expand Up @@ -180,23 +181,25 @@ def register(
remote = get_and_save_remote_with_click_context(ctx, project, domain, data_upload_location="flyte://data")
click.secho(f"Registering against {remote.config.platform.endpoint}")
try:
repo.register(
project,
domain,
image_config,
output,
destination_dir,
service_account,
raw_data_prefix,
version,
deref_symlinks,
fast=not non_fast,
package_or_module=package_or_module,
remote=remote,
env=env,
dry_run=dry_run,
activate_launchplans=activate_launchplans,
skip_errors=skip_errors,
asyncio.run(
repo.register(
project,
domain,
image_config,
output,
destination_dir,
service_account,
raw_data_prefix,
version,
deref_symlinks,
fast=not non_fast,
package_or_module=package_or_module,
remote=remote,
env=env,
dry_run=dry_run,
activate_launchplans=activate_launchplans,
skip_errors=skip_errors,
)
)
except Exception as e:
raise e
12 changes: 10 additions & 2 deletions flytekit/tools/repo.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio
import functools
import os
import tarfile
import tempfile
Expand Down Expand Up @@ -206,7 +208,7 @@ def secho(i: Identifier, state: str = "success", reason: str = None, op: str = "
)


def register(
async def register(
project: str,
domain: str,
image_config: ImageConfig,
Expand Down Expand Up @@ -269,7 +271,7 @@ def register(
click.secho("No Flyte entities were detected. Aborting!", fg="red")
return

for cp_entity in registrable_entities:
def _raw_register(cp_entity):
is_lp = False
if isinstance(cp_entity, launch_plan.LaunchPlan):
og_id = cp_entity.id
Expand All @@ -296,4 +298,10 @@ def register(
secho(og_id, reason="Dry run Mode!")
except RegistrationSkipped:
secho(og_id, "failed")

tasks = []
loop = asyncio.get_event_loop()
for cp_entity in registrable_entities:
tasks.append(loop.run_in_executor(None, functools.partial(_raw_register, cp_entity)))
res = await asyncio.gather(*tasks)
click.secho(f"Successfully registered {len(registrable_entities)} entities", fg="green")

0 comments on commit 3bbb889

Please sign in to comment.