Skip to content
This repository has been archived by the owner on Dec 20, 2023. It is now read-only.

Commit

Permalink
Workflow fetch and sub-workflows (#75)
Browse files Browse the repository at this point in the history
Please see flyteorg/flyte#139 for additional context.

* Completes the SdkWorkflow.promote_from_model functionality.  Promoting from model now optionally takes subworkflows and tasks.  If specified, the ones provided will be used instead of fetching from Admin.
* Enables the sub-workflow behavior
* Updates flyteidl to a version that has sub-workflows in the workflow creation request to Admin.
* Calling an SdkWorkflow now produces a node.
* `SdkTaskNode` and `SdkWorkflowNode` have been moved into a separate file from nodes.py.

Related items that I didn't have time to include in this PR:
* Proper serialization of components (we want to use the same construct as the current registration call path)
* How do you not register subworkflows, but do register them if they're also standalone workflows, while always using the correct name and always having the correct dependency structure when doing the topological sort?
* Discovered that we have duplicate `CompiledTask` models while writing this PR but will fix later.  Left a todo.
  • Loading branch information
wild-endeavor authored Mar 9, 2020
1 parent 74b0df2 commit ed11847
Show file tree
Hide file tree
Showing 22 changed files with 656 additions and 139 deletions.
2 changes: 1 addition & 1 deletion flytekit/flytekit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from __future__ import absolute_import
import flytekit.plugins

__version__ = '0.5.3'
__version__ = '0.6.0b1'
2 changes: 2 additions & 0 deletions flytekit/flytekit/clis/flyte_cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,12 @@ def _welcome_message():
def _get_user_filepath_home():
    """
    Expand "~" to the current user's home directory.
    :rtype: Text
    """
    home_dir = _os.path.expanduser("~")
    return home_dir


def _get_config_file_path():
    """
    Build the path of the default config file: the user's home directory joined with the default config
    directory name and the default config file name.
    :rtype: Text
    """
    path_parts = (_get_user_filepath_home(), _default_config_file_dir, _default_config_file_name)
    return _os.path.join(*path_parts)


def _detect_default_config_file():
config_file = _get_config_file_path()
if _get_user_filepath_home() and _os.path.exists(config_file):
Expand Down
8 changes: 5 additions & 3 deletions flytekit/flytekit/clis/sdk_in_container/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

import click

from flytekit.common.tasks import task as _task
from flytekit.clis.sdk_in_container.constants import CTX_PROJECT, CTX_DOMAIN, CTX_TEST, CTX_PACKAGES, CTX_VERSION
from flytekit.common import utils as _utils
from flytekit.common.tasks import task as _task
from flytekit.configuration.internal import look_up_version_from_image_tag as _look_up_version_from_image_tag, \
IMAGE as _IMAGE
from flytekit.clis.sdk_in_container.constants import CTX_PROJECT, CTX_DOMAIN, CTX_TEST, CTX_PACKAGES, CTX_VERSION
from flytekit.configuration.sdk import WORKFLOW_PACKAGES as _WORKFLOW_PACKAGES
from flytekit.tools.module_loader import iterate_registerable_entities_in_order


Expand All @@ -18,6 +17,9 @@ def register_all(project, domain, pkgs, test, version):
click.echo('Running task, workflow, and launch plan registration for {}, {}, {} with version {}'.format(
project, domain, pkgs, version))

# m = module (i.e. python file)
# k = value of dir(m), type str
# o = object (e.g. SdkWorkflow)
for m, k, o in iterate_registerable_entities_in_order(pkgs):
name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type)

Expand Down
139 changes: 139 additions & 0 deletions flytekit/flytekit/common/component_nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from __future__ import absolute_import

import six as _six
import logging as _logging

from flytekit.common import sdk_bases as _sdk_bases
from flytekit.common.exceptions import system as _system_exceptions
from flytekit.models.core import workflow as _workflow_model


class SdkTaskNode(_six.with_metaclass(_sdk_bases.ExtendedSdkType, _workflow_model.TaskNode)):
    """
    Task node model that is backed by an SdkTask object; the node's reference id is read from that task.
    """

    def __init__(self, sdk_task):
        """
        :param flytekit.common.tasks.task.SdkTask sdk_task: The task this node wraps.
        """
        # Store the task before initializing the base model, which is handed None as its reference id
        # because reference_id is overridden below to read from the wrapped task.
        self._sdk_task = sdk_task
        super(SdkTaskNode, self).__init__(None)

    @property
    def reference_id(self):
        """
        A globally unique identifier for the task, taken from the wrapped SdkTask.
        :rtype: flytekit.models.core.identifier.Identifier
        """
        return self._sdk_task.id

    @property
    def sdk_task(self):
        """
        The SdkTask object wrapped by this node.
        :rtype: flytekit.common.tasks.task.SdkTask
        """
        return self._sdk_task

    @classmethod
    def promote_from_model(cls, base_model, tasks):
        """
        Hydrates the idl wrapper for a TaskNode into the Flytekit object for it.  When the referenced task's
        template is present in ``tasks`` it is promoted directly; otherwise the task is fetched through
        SdkTask.fetch.
        :param flytekit.models.core.workflow.TaskNode base_model:
        :param dict[flytekit.models.core.identifier.Identifier, flytekit.models.task.TaskTemplate] tasks:
        :rtype: SdkTaskNode
        """
        # Imported locally rather than at module level, matching the original code.
        from flytekit.common.tasks import task as _task

        task_id = base_model.reference_id
        if task_id not in tasks:
            # No template was supplied for this task, so retrieve it from Admin.
            _logging.debug("Fetching task template for {} from Admin".format(task_id))
            fetched = _task.SdkTask.fetch(task_id.project, task_id.domain, task_id.name, task_id.version)
            return cls(fetched)

        template = tasks[task_id]
        _logging.debug("Found existing task template for {}, will not retrieve from Admin".format(template.id))
        return cls(_task.SdkTask.promote_from_model(template))


class SdkWorkflowNode(_six.with_metaclass(_sdk_bases.ExtendedSdkType, _workflow_model.WorkflowNode)):
    """
    Workflow node model that wraps an SdkWorkflow (used as a sub-workflow) and/or an SdkLaunchPlan, and derives
    its reference identifiers from the wrapped objects.
    """

    def __init__(self, sdk_workflow=None, sdk_launch_plan=None):
        """
        :param flytekit.common.workflow.SdkWorkflow sdk_workflow:
        :param flytekit.common.launch_plan.SdkLaunchPlan sdk_launch_plan:
        """
        # Both wrapped objects are stored before the base model is initialized.
        self._sdk_workflow = sdk_workflow
        self._sdk_launch_plan = sdk_launch_plan
        super(SdkWorkflowNode, self).__init__()

    @property
    def launchplan_ref(self):
        """
        [Optional] A globally unique identifier for the launch plan.  Should map to Admin.  None when no launch
        plan is wrapped.
        :rtype: flytekit.models.core.identifier.Identifier
        """
        if self._sdk_launch_plan:
            return self._sdk_launch_plan.id
        return None

    @property
    def sub_workflow_ref(self):
        """
        [Optional] Reference to a subworkflow, that should be defined with the compiler context.  None when no
        workflow is wrapped.
        :rtype: flytekit.models.core.identifier.Identifier
        """
        if self._sdk_workflow:
            return self._sdk_workflow.id
        return None

    @property
    def sdk_launch_plan(self):
        """
        The wrapped launch plan, if any.
        :rtype: flytekit.common.launch_plan.SdkLaunchPlan
        """
        return self._sdk_launch_plan

    @property
    def sdk_workflow(self):
        """
        The wrapped workflow, if any.
        :rtype: flytekit.common.workflow.SdkWorkflow
        """
        return self._sdk_workflow

    @classmethod
    def promote_from_model(cls, base_model, sub_workflows, tasks):
        """
        Hydrates a WorkflowNode model.  A launch plan reference is fetched through SdkLaunchPlan.fetch.  A
        sub-workflow reference is promoted from ``sub_workflows`` when its template is there, and fetched
        through SdkWorkflow.fetch (with a warning) when it is not.
        :param flytekit.models.core.workflow.WorkflowNode base_model:
        :param dict[flytekit.models.core.identifier.Identifier, flytekit.models.core.workflow.WorkflowTemplate]
            sub_workflows:
        :param dict[flytekit.models.core.identifier.Identifier, flytekit.models.task.TaskTemplate] tasks:
        :rtype: SdkWorkflowNode
        :raises flytekit.common.exceptions.system.FlyteSystemException: If the model specifies neither a
            launch plan nor a sub-workflow.
        """
        # put the import statement here to prevent circular dependency error
        from flytekit.common import workflow as _workflow, launch_plan as _launch_plan

        # NOTE(review): the identifier parts are read before either ref is checked, so if
        # base_model.reference can be None the exception at the bottom is never reached - confirm against
        # the WorkflowNode model.
        reference = base_model.reference
        project, domain = reference.project, reference.domain
        name, version = reference.name, reference.version

        if base_model.launchplan_ref is not None:
            fetched_lp = _launch_plan.SdkLaunchPlan.fetch(project, domain, name, version)
            return cls(sdk_launch_plan=fetched_lp)

        if base_model.sub_workflow_ref is None:
            raise _system_exceptions.FlyteSystemException(
                "Bad workflow node model, neither subworkflow nor launchplan specified."
            )

        # The workflow templates for sub-workflows should have been included in the original response
        if reference in sub_workflows:
            template = sub_workflows[reference]
            hydrated = _workflow.SdkWorkflow.promote_from_model(template, sub_workflows=sub_workflows, tasks=tasks)
            return cls(sdk_workflow=hydrated)

        # If not found for some reason, fetch it from Admin again.
        # The reason there is a warning here but not for tasks is because sub-workflows should always be passed
        # along.  Ideally subworkflows are never even registered with Admin, so fetching from Admin ideally
        # doesn't return anything.
        _logging.warning("Your subworkflow with id {} is not included in the promote call.".format(reference))
        fetched_wf = _workflow.SdkWorkflow.fetch(project, domain, name, version)
        return cls(sdk_workflow=fetched_wf)
3 changes: 3 additions & 0 deletions flytekit/flytekit/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ class SdkTaskType(object):

GLOBAL_INPUT_NODE_ID = ''

START_NODE_ID = "start-node"
END_NODE_ID = "end-node"


class CloudProvider(object):
AWS = "aws"
Expand Down
11 changes: 10 additions & 1 deletion flytekit/flytekit/common/launch_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ def __init__(self, *args, **kwargs):
super(SdkLaunchPlan, self).__init__(*args, **kwargs)
self._id = None

# The interface is not set explicitly unless fetched in an engine context
self._interface = None

@classmethod
def promote_from_model(cls, model):
"""
Expand Down Expand Up @@ -62,12 +65,18 @@ def fetch(cls, project, domain, name, version=None):
domain, and name.
:rtype: SdkLaunchPlan
"""
from flytekit.common import workflow as _workflow
launch_plan_id = _identifier.Identifier(
_identifier_model.ResourceType.LAUNCH_PLAN, project, domain, name, version
)
lp = _engine_loader.get_engine().fetch_launch_plan(launch_plan_id)
sdk_lp = cls.promote_from_model(lp.spec)
sdk_lp._id = lp.id

# TODO: Add a test for this, and this function as a whole
wf_id = sdk_lp.workflow_id
lp_wf = _workflow.SdkWorkflow.fetch(wf_id.project, wf_id.domain, wf_id.name, wf_id.version)
sdk_lp._interface = lp_wf.interface
return sdk_lp

@_exception_scopes.system_entry_point
Expand Down Expand Up @@ -356,7 +365,7 @@ def __call__(self, *args, **input_map):
"""
if len(args) > 0:
raise _user_exceptions.FlyteAssertion(
"When adding a task as a node in a workflow, all inputs must be specified with kwargs only. We "
"When adding a launchplan as a node in a workflow, all inputs must be specified with kwargs only. We "
"detected {} positional args.".format(self, len(args))
)

Expand Down
Loading

0 comments on commit ed11847

Please sign in to comment.