diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 9c2bb2c17a..79e7f0fce5 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -32,6 +32,7 @@ from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase from flytekit.exceptions.system import FlyteSystemException from flytekit.interaction.click_types import FlyteLiteralConverter, key_value_callback +from flytekit.interaction.string_literals import literal_string_repr from flytekit.loggers import logger from flytekit.models import security from flytekit.models.common import RawOutputDataConfig @@ -240,10 +241,10 @@ class RunLevelParams(PyFlyteParams): param_decls=["--limit", "limit"], required=False, type=int, - default=10, + default=50, + hidden=True, show_default=True, - help="Use this to limit number of launch plans retreived from the backend, " - "if `from-server` option is used", + help="Use this to limit number of entities to fetch", ) ) cluster_pool: str = make_click_option_field( @@ -553,32 +554,40 @@ def _run(*args, **kwargs): return _run -class DynamicLaunchPlanCommand(click.RichCommand): +class DynamicEntityLaunchCommand(click.RichCommand): """ This is a dynamic command that is created for each launch plan. This is used to execute a launch plan. It will fetch the launch plan from remote and create parameters from all the inputs of the launch plan. """ - def __init__(self, name: str, h: str, lp_name: str, **kwargs): + LP_LAUNCHER = "lp" + TASK_LAUNCHER = "task" + + def __init__(self, name: str, h: str, entity_name: str, launcher: str, **kwargs): super().__init__(name=name, help=h, **kwargs) - self._lp_name = lp_name - self._lp = None + self._entity_name = entity_name + self._launcher = launcher + self._entity = None - def _fetch_launch_plan(self, ctx: click.Context) -> FlyteLaunchPlan: - if self._lp: - return self._lp + def _fetch_entity(self, ctx: click.Context) -> typing.Union[FlyteLaunchPlan, FlyteTask]: + if self._entity: + return self._entity run_level_params: RunLevelParams = ctx.obj r = run_level_params.remote_instance() - self._lp = r.fetch_launch_plan(run_level_params.project, run_level_params.domain, self._lp_name) - return self._lp + if self._launcher == self.LP_LAUNCHER: + entity = r.fetch_launch_plan(run_level_params.project, run_level_params.domain, self._entity_name) + else: + entity = r.fetch_task(run_level_params.project, run_level_params.domain, self._entity_name) + self._entity = entity + return entity def _get_params( self, ctx: click.Context, inputs: typing.Dict[str, Variable], native_inputs: typing.Dict[str, type], - fixed: typing.Dict[str, Literal], - defaults: typing.Dict[str, Parameter], + fixed: typing.Optional[typing.Dict[str, Literal]] = None, + defaults: typing.Optional[typing.Dict[str, Parameter]] = None, ) -> typing.List["click.Parameter"]: params = [] flyte_ctx = context_manager.FlyteContextManager.current_context() @@ -586,21 +595,31 @@ def _get_params( if fixed and name in fixed: continue required = True + default_val = None if defaults and name in defaults: - required = False - params.append(to_click_option(ctx, flyte_ctx, name, var, native_inputs[name], None, required)) + if not defaults[name].required: + required = False + default_val = literal_string_repr(defaults[name].default) if defaults[name].default else None + params.append(to_click_option(ctx, flyte_ctx, name, var, native_inputs[name], default_val, required)) return params def get_params(self, ctx: click.Context) -> typing.List["click.Parameter"]: if not self.params: self.params = [] - lp = self._fetch_launch_plan(ctx) - if lp.interface: - if lp.interface.inputs: - types = TypeEngine.guess_python_types(lp.interface.inputs) - self.params = self._get_params( - ctx, lp.interface.inputs, types, lp.fixed_inputs.literals, lp.default_inputs.parameters - ) + entity = self._fetch_entity(ctx) + if entity.interface: + if entity.interface.inputs: + types = TypeEngine.guess_python_types(entity.interface.inputs) + if isinstance(entity, FlyteLaunchPlan): + self.params = self._get_params( + ctx, + entity.interface.inputs, + types, + entity.fixed_inputs.literals, + entity.default_inputs.parameters, + ) + else: + self.params = self._get_params(ctx, entity.interface.inputs, types) return super().get_params(ctx) @@ -611,40 +630,61 @@ def invoke(self, ctx: click.Context) -> typing.Any: """ run_level_params: RunLevelParams = ctx.obj r = run_level_params.remote_instance() - lp = self._fetch_launch_plan(ctx) + entity = self._fetch_entity(ctx) run_remote( r, - lp, + entity, run_level_params.project, run_level_params.domain, ctx.params, run_level_params, - type_hints=lp.python_interface.inputs if lp.python_interface else None, + type_hints=entity.python_interface.inputs if entity.python_interface else None, ) -class RemoteLaunchPlanGroup(click.RichGroup): +class RemoteEntityGroup(click.RichGroup): """ click multicommand that retrieves launchplans from a remote flyte instance and executes them. """ - COMMAND_NAME = "remote-launchplan" + LAUNCHPLAN_COMMAND = "remote-launchplan" + WORKFLOW_COMMAND = "remote-workflow" + TASK_COMMAND = "remote-task" - def __init__(self): + def __init__(self, command_name: str): super().__init__( - name="from-server", - help="Retrieve launchplans from a remote flyte instance and execute them.", + name=command_name, + help=f"Retrieve {command_name} from a remote flyte instance and execute them.", params=[ click.Option( - ["--limit"], help="Limit the number of launchplans to retrieve.", default=10, show_default=True + ["--limit", "limit"], + help=f"Limit the number of {command_name}'s to retrieve.", + default=50, + show_default=True, ) ], ) - self._lps = [] + self._command_name = command_name + self._entities = [] + + def _get_entities(self, r: FlyteRemote, project: str, domain: str, limit: int) -> typing.List[str]: + """ + Retreieves the right entities from the remote flyte instance. + """ + if self._command_name == self.LAUNCHPLAN_COMMAND: + lps = r.client.list_launch_plan_ids_paginated(project=project, domain=domain, limit=limit) + return [l.name for l in lps[0]] + elif self._command_name == self.WORKFLOW_COMMAND: + wfs = r.client.list_workflow_ids_paginated(project=project, domain=domain, limit=limit) + return [w.name for w in wfs[0]] + elif self._command_name == self.TASK_COMMAND: + tasks = r.client.list_task_ids_paginated(project=project, domain=domain, limit=limit) + return [t.name for t in tasks[0]] + return [] def list_commands(self, ctx): - if self._lps or ctx.obj is None: - return self._lps + if self._entities or ctx.obj is None: + return self._entities run_level_params: RunLevelParams = ctx.obj r = run_level_params.remote_instance() @@ -653,17 +693,28 @@ def list_commands(self, ctx): with progress: progress.start_task(task) try: - lps = r.client.list_launch_plan_ids_paginated( - project=run_level_params.project, domain=run_level_params.domain, limit=run_level_params.limit + self._entities = self._get_entities( + r, run_level_params.project, run_level_params.domain, run_level_params.limit ) - self._lps = [l.name for l in lps[0]] - return self._lps + return self._entities except FlyteSystemException as e: pretty_print_exception(e) return [] def get_command(self, ctx, name): - return DynamicLaunchPlanCommand(name=name, h="Execute a launchplan from remote.", lp_name=name) + if self._command_name in [self.LAUNCHPLAN_COMMAND, self.WORKFLOW_COMMAND]: + return DynamicEntityLaunchCommand( + name=name, + h=f"Execute a {self._command_name}.", + entity_name=name, + launcher=DynamicEntityLaunchCommand.LP_LAUNCHER, + ) + return DynamicEntityLaunchCommand( + name=name, + h=f"Execute a {self._command_name}.", + entity_name=name, + launcher=DynamicEntityLaunchCommand.TASK_LAUNCHER, + ) class WorkflowCommand(click.RichGroup): @@ -789,7 +840,11 @@ def list_commands(self, ctx, add_remote: bool = True): self._files = [str(p) for p in pathlib.Path(".").glob("*.py") if str(p) != "__init__.py"] self._files = sorted(self._files) if add_remote: - self._files = self._files + [RemoteLaunchPlanGroup.COMMAND_NAME] + self._files = self._files + [ + RemoteEntityGroup.LAUNCHPLAN_COMMAND, + RemoteEntityGroup.WORKFLOW_COMMAND, + RemoteEntityGroup.TASK_COMMAND, + ] return self._files def get_command(self, ctx, filename): @@ -800,8 +855,12 @@ def get_command(self, ctx, filename): params.update(ctx.params) params.update(ctx.obj) ctx.obj = self._run_params.from_dict(params) - if filename == RemoteLaunchPlanGroup.COMMAND_NAME: - return RemoteLaunchPlanGroup() + if filename == RemoteEntityGroup.LAUNCHPLAN_COMMAND: + return RemoteEntityGroup(RemoteEntityGroup.LAUNCHPLAN_COMMAND) + elif filename == RemoteEntityGroup.WORKFLOW_COMMAND: + return RemoteEntityGroup(RemoteEntityGroup.WORKFLOW_COMMAND) + elif filename == RemoteEntityGroup.TASK_COMMAND: + return RemoteEntityGroup(RemoteEntityGroup.TASK_COMMAND) return WorkflowCommand(filename, name=filename, help=f"Run a [workflow|task] from {filename}")