Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remote workflow & task execution #2094

Merged
merged 3 commits into from
Jan 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 102 additions & 43 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase
from flytekit.exceptions.system import FlyteSystemException
from flytekit.interaction.click_types import FlyteLiteralConverter, key_value_callback
from flytekit.interaction.string_literals import literal_string_repr
from flytekit.loggers import logger
from flytekit.models import security
from flytekit.models.common import RawOutputDataConfig
Expand Down Expand Up @@ -240,10 +241,10 @@ class RunLevelParams(PyFlyteParams):
param_decls=["--limit", "limit"],
required=False,
type=int,
default=10,
default=50,
hidden=True,
show_default=True,
help="Use this to limit number of launch plans retreived from the backend, "
"if `from-server` option is used",
help="Use this to limit number of entities to fetch",
)
)
cluster_pool: str = make_click_option_field(
Expand Down Expand Up @@ -553,54 +554,72 @@ def _run(*args, **kwargs):
return _run


class DynamicLaunchPlanCommand(click.RichCommand):
class DynamicEntityLaunchCommand(click.RichCommand):
"""
This is a dynamic command that is created for each launch plan. This is used to execute a launch plan.
It will fetch the launch plan from remote and create parameters from all the inputs of the launch plan.
"""

def __init__(self, name: str, h: str, lp_name: str, **kwargs):
LP_LAUNCHER = "lp"
TASK_LAUNCHER = "task"

def __init__(self, name: str, h: str, entity_name: str, launcher: str, **kwargs):
super().__init__(name=name, help=h, **kwargs)
self._lp_name = lp_name
self._lp = None
self._entity_name = entity_name
self._launcher = launcher
self._entity = None

def _fetch_launch_plan(self, ctx: click.Context) -> FlyteLaunchPlan:
if self._lp:
return self._lp
def _fetch_entity(self, ctx: click.Context) -> typing.Union[FlyteLaunchPlan, FlyteTask]:
if self._entity:
return self._entity
run_level_params: RunLevelParams = ctx.obj
r = run_level_params.remote_instance()
self._lp = r.fetch_launch_plan(run_level_params.project, run_level_params.domain, self._lp_name)
return self._lp
if self._launcher == self.LP_LAUNCHER:
entity = r.fetch_launch_plan(run_level_params.project, run_level_params.domain, self._entity_name)
else:
entity = r.fetch_task(run_level_params.project, run_level_params.domain, self._entity_name)
self._entity = entity
return entity

def _get_params(
self,
ctx: click.Context,
inputs: typing.Dict[str, Variable],
native_inputs: typing.Dict[str, type],
fixed: typing.Dict[str, Literal],
defaults: typing.Dict[str, Parameter],
fixed: typing.Optional[typing.Dict[str, Literal]] = None,
defaults: typing.Optional[typing.Dict[str, Parameter]] = None,
) -> typing.List["click.Parameter"]:
params = []
flyte_ctx = context_manager.FlyteContextManager.current_context()
for name, var in inputs.items():
if fixed and name in fixed:
continue
required = True
default_val = None
if defaults and name in defaults:
required = False
params.append(to_click_option(ctx, flyte_ctx, name, var, native_inputs[name], None, required))
if not defaults[name].required:
required = False
default_val = literal_string_repr(defaults[name].default) if defaults[name].default else None
params.append(to_click_option(ctx, flyte_ctx, name, var, native_inputs[name], default_val, required))
return params

def get_params(self, ctx: click.Context) -> typing.List["click.Parameter"]:
if not self.params:
self.params = []
lp = self._fetch_launch_plan(ctx)
if lp.interface:
if lp.interface.inputs:
types = TypeEngine.guess_python_types(lp.interface.inputs)
self.params = self._get_params(
ctx, lp.interface.inputs, types, lp.fixed_inputs.literals, lp.default_inputs.parameters
)
entity = self._fetch_entity(ctx)
if entity.interface:
if entity.interface.inputs:
types = TypeEngine.guess_python_types(entity.interface.inputs)
if isinstance(entity, FlyteLaunchPlan):
self.params = self._get_params(
ctx,
entity.interface.inputs,
types,
entity.fixed_inputs.literals,
entity.default_inputs.parameters,
)
else:
self.params = self._get_params(ctx, entity.interface.inputs, types)

return super().get_params(ctx)

Expand All @@ -611,40 +630,61 @@ def invoke(self, ctx: click.Context) -> typing.Any:
"""
run_level_params: RunLevelParams = ctx.obj
r = run_level_params.remote_instance()
lp = self._fetch_launch_plan(ctx)
entity = self._fetch_entity(ctx)
run_remote(
r,
lp,
entity,
run_level_params.project,
run_level_params.domain,
ctx.params,
run_level_params,
type_hints=lp.python_interface.inputs if lp.python_interface else None,
type_hints=entity.python_interface.inputs if entity.python_interface else None,
)


class RemoteLaunchPlanGroup(click.RichGroup):
class RemoteEntityGroup(click.RichGroup):
"""
click multicommand that retrieves launchplans from a remote flyte instance and executes them.
"""

COMMAND_NAME = "remote-launchplan"
LAUNCHPLAN_COMMAND = "remote-launchplan"
WORKFLOW_COMMAND = "remote-workflow"
TASK_COMMAND = "remote-task"

def __init__(self):
def __init__(self, command_name: str):
super().__init__(
name="from-server",
help="Retrieve launchplans from a remote flyte instance and execute them.",
name=command_name,
help=f"Retrieve {command_name} from a remote flyte instance and execute them.",
params=[
click.Option(
["--limit"], help="Limit the number of launchplans to retrieve.", default=10, show_default=True
["--limit", "limit"],
help=f"Limit the number of {command_name}'s to retrieve.",
default=50,
show_default=True,
)
],
)
self._lps = []
self._command_name = command_name
self._entities = []

def _get_entities(self, r: FlyteRemote, project: str, domain: str, limit: int) -> typing.List[str]:
"""
Retreieves the right entities from the remote flyte instance.
"""
if self._command_name == self.LAUNCHPLAN_COMMAND:
lps = r.client.list_launch_plan_ids_paginated(project=project, domain=domain, limit=limit)
return [l.name for l in lps[0]]
elif self._command_name == self.WORKFLOW_COMMAND:
wfs = r.client.list_workflow_ids_paginated(project=project, domain=domain, limit=limit)
return [w.name for w in wfs[0]]
elif self._command_name == self.TASK_COMMAND:
tasks = r.client.list_task_ids_paginated(project=project, domain=domain, limit=limit)
return [t.name for t in tasks[0]]
return []

def list_commands(self, ctx):
if self._lps or ctx.obj is None:
return self._lps
if self._entities or ctx.obj is None:
return self._entities

run_level_params: RunLevelParams = ctx.obj
r = run_level_params.remote_instance()
Expand All @@ -653,17 +693,28 @@ def list_commands(self, ctx):
with progress:
progress.start_task(task)
try:
lps = r.client.list_launch_plan_ids_paginated(
project=run_level_params.project, domain=run_level_params.domain, limit=run_level_params.limit
self._entities = self._get_entities(
r, run_level_params.project, run_level_params.domain, run_level_params.limit
)
self._lps = [l.name for l in lps[0]]
return self._lps
return self._entities
except FlyteSystemException as e:
pretty_print_exception(e)
return []

def get_command(self, ctx, name):
return DynamicLaunchPlanCommand(name=name, h="Execute a launchplan from remote.", lp_name=name)
if self._command_name in [self.LAUNCHPLAN_COMMAND, self.WORKFLOW_COMMAND]:
return DynamicEntityLaunchCommand(
name=name,
h=f"Execute a {self._command_name}.",
entity_name=name,
launcher=DynamicEntityLaunchCommand.LP_LAUNCHER,
)
return DynamicEntityLaunchCommand(
name=name,
h=f"Execute a {self._command_name}.",
entity_name=name,
launcher=DynamicEntityLaunchCommand.TASK_LAUNCHER,
)


class WorkflowCommand(click.RichGroup):
Expand Down Expand Up @@ -789,7 +840,11 @@ def list_commands(self, ctx, add_remote: bool = True):
self._files = [str(p) for p in pathlib.Path(".").glob("*.py") if str(p) != "__init__.py"]
self._files = sorted(self._files)
if add_remote:
self._files = self._files + [RemoteLaunchPlanGroup.COMMAND_NAME]
self._files = self._files + [
RemoteEntityGroup.LAUNCHPLAN_COMMAND,
RemoteEntityGroup.WORKFLOW_COMMAND,
RemoteEntityGroup.TASK_COMMAND,
]
return self._files

def get_command(self, ctx, filename):
Expand All @@ -800,8 +855,12 @@ def get_command(self, ctx, filename):
params.update(ctx.params)
params.update(ctx.obj)
ctx.obj = self._run_params.from_dict(params)
if filename == RemoteLaunchPlanGroup.COMMAND_NAME:
return RemoteLaunchPlanGroup()
if filename == RemoteEntityGroup.LAUNCHPLAN_COMMAND:
return RemoteEntityGroup(RemoteEntityGroup.LAUNCHPLAN_COMMAND)
elif filename == RemoteEntityGroup.WORKFLOW_COMMAND:
return RemoteEntityGroup(RemoteEntityGroup.WORKFLOW_COMMAND)
elif filename == RemoteEntityGroup.TASK_COMMAND:
return RemoteEntityGroup(RemoteEntityGroup.TASK_COMMAND)
return WorkflowCommand(filename, name=filename, help=f"Run a [workflow|task] from {filename}")


Expand Down