Skip to content

Commit

Permalink
[Elastic] Fix context usage and apply fix to fork method (flyteorg#2628)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored Aug 1, 2024
1 parent 86c201d commit 3549597
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 3 deletions.
12 changes: 9 additions & 3 deletions plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import flytekit
from flytekit import PythonFunctionTask, Resources, lazy_module
from flytekit.configuration import SerializationSettings
from flytekit.core.context_manager import OutputMetadata
from flytekit.core.context_manager import FlyteContextManager, OutputMetadata
from flytekit.core.pod_template import PodTemplate
from flytekit.core.resources import convert_resources_to_resource_model
from flytekit.exceptions.user import FlyteRecoverableException
Expand Down Expand Up @@ -429,13 +429,18 @@ def fn_partial():
"""Closure of the task function with kwargs already bound."""
try:
return_val = self._task_function(**kwargs)
core_context = FlyteContextManager.current_context()
omt = core_context.output_metadata_tracker
om = None
if omt:
om = omt.get(return_val)
except Exception as e:
# See explanation in `create_recoverable_error_file` why we check
# for recoverable errors here in the worker processes.
if isinstance(e, FlyteRecoverableException):
create_recoverable_error_file()
raise
return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks, om=None)
return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks, om=om)

launcher_target_func = fn_partial
launcher_args = ()
Expand Down Expand Up @@ -470,7 +475,8 @@ def fn_partial():
if not isinstance(deck, flytekit.deck.deck.TimeLineDeck):
ctx.decks.append(deck)
if out[0].om:
ctx.output_metadata_tracker.add(out[0].return_value, out[0].om)
core_context = FlyteContextManager.current_context()
core_context.output_metadata_tracker.add(out[0].return_value, out[0].om)

return out[0].return_value
else:
Expand Down
40 changes: 40 additions & 0 deletions plugins/flytekit-kf-pytorch/tests/test_elastic_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
import typing
from dataclasses import dataclass
from unittest import mock
from typing_extensions import Annotated, cast
from flytekitplugins.kfpytorch.task import Elastic

from flytekit import Artifact

import pytest
import torch
Expand All @@ -11,6 +15,7 @@

import flytekit
from flytekit import task, workflow
from flytekit.core.context_manager import FlyteContext, FlyteContextManager, ExecutionState, ExecutionParameters, OutputMetadataTracker
from flytekit.configuration import SerializationSettings
from flytekit.exceptions.user import FlyteRecoverableException

Expand Down Expand Up @@ -159,6 +164,41 @@ def wf():
assert "Hello Flyte Deck viewer from worker process 0" in test_deck.html


class Card(object):
def __init__(self, text: str):
self.text = text

def serialize_to_string(self, ctx: FlyteContext, variable_name: str):
print(f"In serialize_to_string: {id(ctx)}")
return "card", "card"


@pytest.mark.parametrize("start_method", ["spawn", "fork"])
def test_output_metadata_passing(start_method: str) -> None:
ea = Artifact(name="elastic-artf")

@task(
task_config=Elastic(start_method=start_method),
)
def train2() -> Annotated[str, ea]:
return ea.create_from("hello flyte", Card("## card"))

@workflow
def wf():
train2()

ctx = FlyteContext.current_context()
omt = OutputMetadataTracker()
with FlyteContextManager.with_context(
ctx.with_execution_state(ctx.new_execution_state().with_params(mode=ExecutionState.Mode.LOCAL_TASK_EXECUTION)).with_output_metadata_tracker(omt)
) as child_ctx:
cast(ExecutionParameters, child_ctx.user_space_params)._decks = []
# call execute directly so as to be able to get at the same FlyteContext object.
res = train2.execute()
om = child_ctx.output_metadata_tracker.get(res)
assert len(om.additional_items) == 1


@pytest.mark.parametrize(
"recoverable,start_method",
[
Expand Down

0 comments on commit 3549597

Please sign in to comment.