Skip to content

Commit

Permalink
pyflyte run can now execute a task either locally or remote (#995)
Browse files Browse the repository at this point in the history
* pyflyte run can now execute a task either locally or remote

Signed-off-by: Ketan Umare <[email protected]>

* Fix namedtuple and test

Signed-off-by: Ketan Umare <[email protected]>
  • Loading branch information
kumare3 authored and eapolinario committed Jun 17, 2022
1 parent 8c94362 commit 58aca28
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 30 deletions.
72 changes: 46 additions & 26 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from flytekit.configuration import Config, ImageConfig, SerializationSettings
from flytekit.configuration.default_images import DefaultImages
from flytekit.core import context_manager, tracker
from flytekit.core.base_task import PythonTask
from flytekit.core.context_manager import FlyteContext
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.core.type_engine import TypeEngine
Expand Down Expand Up @@ -395,7 +396,7 @@ def get_workflow_command_base_params() -> typing.List[click.Option]:
]


def load_naive_entity(module_name: str, workflow_name: str) -> WorkflowBase:
def load_naive_entity(module_name: str, entity_name: str) -> typing.Union[WorkflowBase, PythonTask]:
"""
Load the entity (workflow or task) defined in the script file.
N.B.: it assumes that the file is self-contained, in other words, there are no relative imports.
Expand All @@ -406,7 +407,7 @@ def load_naive_entity(module_name: str, workflow_name: str) -> WorkflowBase:
with context_manager.FlyteContextManager.with_context(flyte_ctx):
with module_loader.add_sys_path(os.getcwd()):
importlib.import_module(module_name)
return module_loader.load_object_from_module(f"{module_name}.{workflow_name}")
return module_loader.load_object_from_module(f"{module_name}.{entity_name}")


def dump_flyte_remote_snippet(execution: FlyteWorkflowExecution, project: str, domain: str):
Expand All @@ -424,30 +425,49 @@ def dump_flyte_remote_snippet(execution: FlyteWorkflowExecution, project: str, d
)


def get_workflows_in_file(filename: str) -> typing.List[str]:
class Entities(typing.NamedTuple):
"""
Returns a list of flyte workflow names in a file.
NamedTuple to group all entities in a file
"""

workflows: typing.List[str]
tasks: typing.List[str]

def all(self) -> typing.List[str]:
e = []
e.extend(self.workflows)
e.extend(self.tasks)
return e


def get_entities_in_file(filename: str) -> Entities:
"""
Returns the names of the Flyte workflows and Flyte tasks found in a file, grouped in an Entities tuple.
"""
flyte_ctx = context_manager.FlyteContextManager.current_context().with_serialization_settings(
SerializationSettings(None)
)
module_name = os.path.splitext(filename)[0].replace(os.path.sep, ".")
module_name = os.path.splitext(os.path.relpath(filename))[0].replace(os.path.sep, ".")
with context_manager.FlyteContextManager.with_context(flyte_ctx):
with module_loader.add_sys_path(os.getcwd()):
importlib.import_module(module_name)

workflows = []
tasks = []
module = importlib.import_module(module_name)
for k in dir(module):
o = module.__dict__[k]
if isinstance(o, PythonFunctionWorkflow):
_, _, fn, _ = tracker.extract_task_module(o)
workflows.append(fn)
elif isinstance(o, PythonTask):
_, _, fn, _ = tracker.extract_task_module(o)
tasks.append(fn)

return workflows
return Entities(workflows, tasks)


def run_command(ctx: click.Context, wf_entity: PythonFunctionWorkflow):
def run_command(ctx: click.Context, entity: typing.Union[PythonFunctionWorkflow, PythonTask]):
"""
Returns a function that is used to implement WorkflowCommand and execute a flyte workflow or task.
"""
Expand All @@ -456,11 +476,11 @@ def _run(*args, **kwargs):
run_level_params = ctx.obj[RUN_LEVEL_PARAMS_KEY]
project, domain = run_level_params.get("project"), run_level_params.get("domain")
inputs = {}
for input_name, _ in wf_entity.python_interface.inputs.items():
for input_name, _ in entity.python_interface.inputs.items():
inputs[input_name] = kwargs.get(input_name)

if not ctx.obj[REMOTE_FLAG_KEY]:
output = wf_entity(**inputs)
output = entity(**inputs)
click.echo(output)
return

Expand All @@ -470,8 +490,8 @@ def _run(*args, **kwargs):
# PandasToParquetDataProxyEncodingHandler(get_upload_url_fn), default_for_type=True
# )

wf = remote.register_script(
wf_entity,
remote_entity = remote.register_script(
entity,
project=project,
domain=domain,
image_config=run_level_params.get("image_config", None),
Expand All @@ -486,14 +506,14 @@ def _run(*args, **kwargs):
options = Options.default_from(k8s_service_account=service_account)

execution = remote.execute(
wf,
remote_entity,
inputs=inputs,
project=project,
domain=domain,
name=run_level_params.get("name"),
wait=run_level_params.get("wait_execution"),
options=options,
type_hints=wf_entity.python_interface.inputs,
type_hints=entity.python_interface.inputs,
)

console_url = remote.generate_console_url(execution)
Expand All @@ -515,18 +535,18 @@ def __init__(self, filename: str, *args, **kwargs):
self._filename = filename

def list_commands(self, ctx):
workflows = get_workflows_in_file(self._filename)
return workflows
entities = get_entities_in_file(self._filename)
return entities.all()

def get_command(self, ctx, workflow):
def get_command(self, ctx, exe_entity):
rel_path = os.path.relpath(self._filename)
if rel_path.startswith(".."):
raise ValueError(
f"You must call pyflyte from the same or parent dir, {self._filename} not under {os.getcwd()}"
)

module = os.path.splitext(rel_path)[0].replace(os.path.sep, ".")
wf_entity = load_naive_entity(module, workflow)
entity = load_naive_entity(module, exe_entity)

# If this is a remote execution, which we should know at this point, then create the remote object
p = ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_PROJECT)
Expand All @@ -544,24 +564,24 @@ def get_command(self, ctx, workflow):

# Add options for each of the entity's (workflow or task) inputs
params = []
for input_name, input_type_val in wf_entity.python_interface.inputs_with_defaults.items():
literal_var = wf_entity.interface.inputs.get(input_name)
for input_name, input_type_val in entity.python_interface.inputs_with_defaults.items():
literal_var = entity.interface.inputs.get(input_name)
python_type, default_val = input_type_val
params.append(
to_click_option(ctx, flyte_ctx, input_name, literal_var, python_type, default_val, get_upload_url_fn)
)
cmd = click.Command(
name=workflow,
name=exe_entity,
params=params,
callback=run_command(ctx, wf_entity),
help=f"Run {module}.{workflow} in script mode",
callback=run_command(ctx, entity),
help=f"Run {module}.{exe_entity} in script mode",
)
return cmd


class RunCommand(click.MultiCommand):
"""
A click command group for registering and executing flyte workflows in a file.
A click command group for registering and executing flyte workflows & tasks in a file.
"""

def __init__(self, *args, **kwargs):
Expand All @@ -573,11 +593,11 @@ def list_commands(self, ctx):

def get_command(self, ctx, filename):
ctx.obj[RUN_LEVEL_PARAMS_KEY] = ctx.params
return WorkflowCommand(filename, name=filename, help="Run a workflow in a file using script mode")
return WorkflowCommand(filename, name=filename, help="Run a [workflow|task] in a file using script mode")


run = RunCommand(
name="run",
help="Run command, a.k.a. script mode. It allows for a a single script to be "
+ "registered and run from the command line (e.g. Jupyter notebooks).",
help="Run command: This command can execute either a workflow or a task from the commandline, for "
"fully self-contained scripts. Tasks and workflows cannot be imported from other files currently.",
)
8 changes: 5 additions & 3 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,23 +509,23 @@ def register_workflow(

def register_script(
self,
entity: WorkflowBase,
entity: typing.Union[WorkflowBase, PythonTask],
image_config: typing.Optional[ImageConfig] = None,
version: typing.Optional[str] = None,
project: typing.Optional[str] = None,
domain: typing.Optional[str] = None,
destination_dir: str = ".",
default_launch_plan: typing.Optional[bool] = True,
options: typing.Optional[Options] = None,
) -> FlyteWorkflow:
) -> typing.Union[FlyteWorkflow, FlyteTask]:
"""
Use this method to register a workflow or a task via script mode.
:param destination_dir:
:param domain:
:param project:
:param image_config:
:param version: version for the entity to be registered as
:param entity: The workflow to be registered
:param entity: The workflow to be registered or the task to be registered
:param default_launch_plan: This should be true if a default launch plan should be created for the workflow
:param options: Additional execution options that can be configured for the default launchplan
:return:
Expand Down Expand Up @@ -567,6 +567,8 @@ def register_script(
h.update(bytes(__version__, "utf-8"))
version = base64.urlsafe_b64encode(h.digest())

if isinstance(entity, PythonTask):
return self.register_task(entity, serialization_settings, version)
return self.register_workflow(entity, serialization_settings, version, default_launch_plan, options)

def register_launch_plan(
Expand Down
12 changes: 11 additions & 1 deletion tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
from click.testing import CliRunner

from flytekit.clis.sdk_in_container import pyflyte
from flytekit.clis.sdk_in_container.run import get_entities_in_file

WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "workflow.py")


def test_pyflyte_run_wf():
runner = CliRunner()
module_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "workflow.py")
module_path = WORKFLOW_FILE
result = runner.invoke(pyflyte.main, ["run", module_path, "my_wf", "--help"], catch_exceptions=False)

assert result.exit_code == 0
Expand Down Expand Up @@ -52,3 +55,10 @@ def test_pyflyte_run_cli():
)
print(result.stdout)
assert result.exit_code == 0


def test_get_entities_in_file():
e = get_entities_in_file(WORKFLOW_FILE)
assert e.workflows == ["my_wf"]
assert e.tasks == ["get_subset_df", "print_all", "show_sd"]
assert e.all() == ["my_wf", "get_subset_df", "print_all", "show_sd"]

0 comments on commit 58aca28

Please sign in to comment.