Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
iaroslav-ciupin committed Sep 19, 2023
1 parent 998810d commit 132c2cc
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
NodeExecutionGetDataResponse,
NotificationList,
WorkflowExecutionGetDataResponse,
ClusterAssignment,
)
from flytekit.models.launch_plan import LaunchPlanState
from flytekit.models.literals import Literal, LiteralMap
Expand Down Expand Up @@ -962,6 +963,7 @@ def _execute(
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""Common method for execution across all entities.
Expand All @@ -978,6 +980,7 @@ def _execute(
be available, overwriting the stored data once execution finishes successfully.
:param envs: Environment variables to set for the execution.
:param tags: Tags to set for the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
:returns: :class:`~flytekit.remote.workflow_execution.FlyteWorkflowExecution`
"""
if execution_name is not None and execution_name_prefix is not None:
Expand All @@ -986,6 +989,8 @@ def _execute(
execution_name = execution_name or (execution_name_prefix or "f") + uuid.uuid4().hex[:19]
if not options:
options = Options()
if cluster_pool:
options.cluster_assignment = ClusterAssignment(cluster_pool=cluster_pool)
if options.disable_notifications is not None:
if options.disable_notifications:
notifications = None
Expand Down Expand Up @@ -1107,6 +1112,7 @@ def execute(
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""
Execute a task, workflow, or launchplan, either something that's been declared locally, or a fetched entity.
Expand Down Expand Up @@ -1145,6 +1151,7 @@ def execute(
be available, overwriting the stored data once execution finishes successfully.
:param envs: Environment variables to be set for the execution.
:param tags: Tags to be set for the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
.. note:
Expand All @@ -1167,6 +1174,7 @@ def execute(
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)
if isinstance(entity, FlyteWorkflow):
return self.execute_remote_wf(
Expand All @@ -1182,6 +1190,7 @@ def execute(
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)
if isinstance(entity, PythonTask):
return self.execute_local_task(
Expand All @@ -1198,6 +1207,7 @@ def execute(
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)
if isinstance(entity, WorkflowBase):
return self.execute_local_workflow(
Expand All @@ -1215,6 +1225,7 @@ def execute(
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)
if isinstance(entity, LaunchPlan):
return self.execute_local_launch_plan(
Expand All @@ -1230,6 +1241,7 @@ def execute(
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)
raise NotImplementedError(f"entity type {type(entity)} not recognized for execution")

Expand All @@ -1250,6 +1262,7 @@ def execute_remote_task_lp(
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""Execute a FlyteTask, or FlyteLaunchplan.
Expand All @@ -1268,6 +1281,7 @@ def execute_remote_task_lp(
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)

def execute_remote_wf(
Expand All @@ -1284,6 +1298,7 @@ def execute_remote_wf(
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""Execute a FlyteWorkflow.
Expand All @@ -1303,6 +1318,7 @@ def execute_remote_wf(
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)

# Flytekit Entities
Expand All @@ -1323,6 +1339,7 @@ def execute_local_task(
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""
Execute a @task-decorated function or TaskTemplate task.
Expand All @@ -1339,6 +1356,7 @@ def execute_local_task(
:param overwrite_cache: If True, will overwrite the cache.
:param envs: Environment variables to set for the execution.
:param tags: Tags to set for the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
:return: FlyteWorkflowExecution object.
"""
resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version)
Expand Down Expand Up @@ -1369,6 +1387,7 @@ def execute_local_task(
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)

def execute_local_workflow(
Expand All @@ -1387,6 +1406,7 @@ def execute_local_workflow(
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""
Execute an @workflow decorated function.
Expand All @@ -1403,6 +1423,7 @@ def execute_local_workflow(
:param overwrite_cache:
:param envs:
:param tags:
:param cluster_pool:
:return:
"""
resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version)
Expand Down Expand Up @@ -1451,6 +1472,7 @@ def execute_local_workflow(
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)

def execute_local_launch_plan(
Expand All @@ -1467,6 +1489,7 @@ def execute_local_launch_plan(
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
cluster_pool: typing.Optional[str] = None,
) -> FlyteWorkflowExecution:
"""
Expand All @@ -1481,6 +1504,7 @@ def execute_local_launch_plan(
:param overwrite_cache: If True, will overwrite the cache.
:param envs: Environment variables to be passed into the execution.
:param tags: Tags to be passed into the execution.
:param cluster_pool: Specify cluster pool on which newly created execution should be placed.
:return: FlyteWorkflowExecution object
"""
try:
Expand Down Expand Up @@ -1510,6 +1534,7 @@ def execute_local_launch_plan(
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
cluster_pool=cluster_pool,
)

###################################
Expand Down

0 comments on commit 132c2cc

Please sign in to comment.