Fix: Make sure decks created in elastic task workers are transferred to parent process #1837

Merged 4 commits on Sep 19, 2023
plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py (30 additions, 6 deletions)
@@ -5,7 +5,7 @@
 import os
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Callable, Dict, Optional, Union
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union
 
 import cloudpickle
 from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common
@@ -203,7 +203,22 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
 TaskPlugins.register_pythontask_plugin(PyTorch, PyTorchFunctionTask)
 
 
-def spawn_helper(fn: bytes, raw_output_prefix: str, checkpoint_dest: str, checkpoint_src: str, kwargs) -> Any:
+class ElasticWorkerResult(NamedTuple):
+    """
+    A named tuple representing the result of a torch elastic worker process.
+
+    Attributes:
+        return_value (Any): The value returned by the task function in the worker process.
+        decks (list[flytekit.Deck]): A list of flytekit Deck objects created in the worker process.
+    """
+
+    return_value: Any
+    decks: List[flytekit.Deck]
+
+
+def spawn_helper(
+    fn: bytes, raw_output_prefix: str, checkpoint_dest: str, checkpoint_src: str, kwargs
+) -> ElasticWorkerResult:
     """Help to spawn worker processes.
 
     The purpose of this function is to 1) be pickleable so that it can be used with
@@ -220,7 +235,8 @@ def spawn_helper(fn: bytes, raw_output_prefix: str, checkpoint_dest: str, checkpoint_src: str, kwargs) -> Any:
         checkpoint_src (str): Location where the new checkpoint should be copied to.
 
     Returns:
-        The return value of the received target function.
+        ElasticWorkerResult: A named tuple containing the return value of the task function and a list of
+        flytekit Deck objects created in the worker process.
     """
     from flytekit.bin.entrypoint import setup_execution
 
@@ -231,7 +247,8 @@ def spawn_helper(fn: bytes, raw_output_prefix: str, checkpoint_dest: str, checkpoint_src: str, kwargs) -> Any:
     ):
         fn = cloudpickle.loads(fn)
        return_val = fn(**kwargs)
-        return return_val
+
+        return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks)
 
 
 class PytorchElasticFunctionTask(PythonFunctionTask[Elastic]):
@@ -336,7 +353,8 @@ def _execute(self, **kwargs) -> Any:
 
         def fn_partial():
             """Closure of the task function with kwargs already bound."""
-            return self._task_function(**kwargs)
+            return_val = self._task_function(**kwargs)
+            return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks)
 
         launcher_target_func = fn_partial
         launcher_args = ()
@@ -365,7 +383,13 @@ def fn_partial():
         # `out` is a dictionary of rank (not local rank) -> result
         # Rank 0 returns the result of the task function
         if 0 in out:
-            return out[0]
+            # For rank 0, we transfer the decks created in the worker process to the parent process
+            ctx = flytekit.current_context()
+            for deck in out[0].decks:
+                if not isinstance(deck, flytekit.deck.deck.TimeLineDeck):
+                    ctx.decks.append(deck)
+
+            return out[0].return_value
         else:
             raise IgnoreOutputs()
 
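Why the explicit hand-off is needed: with elastic training, the task function executes in separate worker processes, so any flytekit.Deck created there lives in the worker's memory and would be lost when the worker exits. The diff therefore bundles the decks into the returned ElasticWorkerResult and re-registers them in the parent, skipping TimeLineDeck instances, presumably because the parent process maintains its own timeline deck. Below is a minimal, hypothetical sketch of the same pattern using only the standard library; WorkerResult, worker, and the module-level decks list are illustrative stand-ins, not flytekit APIs.

# Minimal sketch, assuming nothing beyond the Python standard library.
import multiprocessing as mp
from typing import Any, List, NamedTuple

decks: List[str] = []  # stand-in for a per-process deck registry; each spawned child gets its own copy


class WorkerResult(NamedTuple):  # analogous to ElasticWorkerResult
    return_value: Any
    decks: List[str]


def worker(queue) -> None:
    decks.append("created-in-worker")  # mutates only the child process's registry
    queue.put(WorkerResult(return_value=42, decks=list(decks)))


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    queue = ctx.Queue()
    proc = ctx.Process(target=worker, args=(queue,))
    proc.start()
    result = queue.get()
    proc.join()

    assert decks == []          # the parent's registry was never touched by the child
    decks.extend(result.decks)  # explicit transfer, as the PR does for rank 0
    print(result.return_value, decks)  # 42 ['created-in-worker']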
plugins/flytekit-kf-pytorch/tests/test_elastic_task.py (38 additions, 0 deletions)
@@ -111,3 +111,41 @@ def test_task():
     with mock.patch("torch.distributed.launcher.api.LaunchConfig", side_effect=LaunchConfig) as mock_launch_config:
         test_task()
         assert mock_launch_config.call_args[1]["rdzv_configs"] == rdzv_configs
+
+
+@pytest.mark.parametrize("start_method", ["spawn", "fork"])
+def test_deck(start_method: str) -> None:
+    """Test that decks created in the main worker process are transferred to the parent process."""
+    world_size = 2
+
+    @task(
+        task_config=Elastic(nnodes=1, nproc_per_node=world_size, start_method=start_method),
+        disable_deck=False,
+    )
+    def train():
+        import os
+
+        ctx = flytekit.current_context()
+        deck = flytekit.Deck("test-deck", f"Hello Flyte Deck viewer from worker process {os.environ.get('RANK')}")
+        ctx.decks.append(deck)
+        default_deck = ctx.default_deck
+        default_deck.append("Hello from default deck")
+
+    @workflow
+    def wf():
+        train()
+
+    wf()
+
+    ctx = flytekit.current_context()
+
+    expected_deck_names = {"timeline", "default", "test-deck"}
+    found_deck_names = set(d.name for d in ctx.decks)
+
+    assert expected_deck_names.issubset(found_deck_names)
+
+    default_deck = [d for d in ctx.decks if d.name == "default"][0]
+    assert "Hello from default deck" == default_deck.html.strip()
+
+    test_deck = [d for d in ctx.decks if d.name == "test-deck"][0]
+    assert "Hello Flyte Deck viewer from worker process 0" in test_deck.html
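To exercise just the new test locally, something along these lines should work (assuming the plugin and its test dependencies, including torch, are installed):

pytest plugins/flytekit-kf-pytorch/tests/test_elastic_task.py -k test_deck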