Skip to content

Commit

Permalink
Eager workflows to support async workflows
Browse files Browse the repository at this point in the history
Signed-off-by: Niels Bantilan <[email protected]>
  • Loading branch information
cosmicBboy committed Apr 4, 2023
1 parent dd7fbe9 commit 4886519
Show file tree
Hide file tree
Showing 12 changed files with 621 additions and 67 deletions.
24 changes: 24 additions & 0 deletions Dockerfile.eager
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
FROM python:3.8-slim-buster

WORKDIR /root
ENV VENV /opt/venv
ENV LANG C.UTF-8
ENV LC_ALL C.UTF-8
ENV PYTHONPATH /root

RUN apt-get update && apt-get install -y build-essential git

# Install the AWS cli separately to prevent issues with boto being written over
RUN pip3 install awscli

ENV VENV /opt/venv
# Virtual environment
RUN python3 -m venv ${VENV}
ENV PATH="${VENV}/bin:$PATH"

# Install Python dependencies
RUN pip install scikit-learn pandas

ARG gitsha
RUN pip install flytekitplugins-deck-standard
RUN pip install git+https://github.com/flyteorg/flytekit@${gitsha}
Empty file added async_prototype/__init__.py
Empty file.
109 changes: 109 additions & 0 deletions async_prototype/flytekit_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""Async workflows prototype."""

import asyncio
from typing import NamedTuple

import pandas as pd
from sklearn.datasets import load_wine
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

from flytekit import workflow, task, Secret
from flytekit.configuration import Config, PlatformConfig
from flytekit.experimental import eager
from flytekit.remote import FlyteRemote

CACHE_VERSION = "3"


remote = FlyteRemote(
# config=Config.for_sandbox(),
config=Config(
platform=PlatformConfig(
endpoint="development.uniondemo.run",
auth_mode="Pkce",
client_id="flytepropeller",
insecure=False,
),
),
default_project="flytesnacks",
default_domain="development",
data_upload_location="s3://flyte-development-data/data",
)

class CustomException(Exception): ...

BestModel = NamedTuple("BestModel", model=LogisticRegression, metric=float)


@task(cache=True, cache_version=CACHE_VERSION)
def get_data() -> pd.DataFrame:
"""Get the wine dataset."""
return load_wine(as_frame=True).frame


@task(cache=True, cache_version=CACHE_VERSION)
def process_data(data: pd.DataFrame) -> pd.DataFrame:
"""Simplify the task from a 3-class to a binary classification problem."""
return data.assign(target=lambda x: x["target"].where(x["target"] == 0, 1))


@task(cache=True, cache_version=CACHE_VERSION)
def train_model(data: pd.DataFrame, hyperparameters: dict) -> LogisticRegression:
"""Train a model on the wine dataset."""
features = data.drop("target", axis="columns")
target = data["target"]
return LogisticRegression(max_iter=3000, **hyperparameters).fit(features, target)


@task
def evaluate_model(data: pd.DataFrame, model: LogisticRegression) -> float:
"""Train a model on the wine dataset."""
features = data.drop("target", axis="columns")
target = data["target"]
return float(accuracy_score(target, model.predict(features)))


@eager(
remote=remote,
force_remote=True,
secret_requests=[Secret(group="async-client-secret", key="client_secret")],
disable_deck=False,
)
async def main() -> BestModel:
data = await get_data()
processed_data = await process_data(data=data)

# split the data
try:
train, test = train_test_split(processed_data, test_size=0.2)
except Exception as exc:
raise CustomException(str(exc)) from exc

models = await asyncio.gather(*[
train_model(data=train, hyperparameters={"C": x})
for x in [0.1, 0.01, 0.001, 0.0001, 0.00001]
])
results = await asyncio.gather(*[
evaluate_model(data=test, model=model) for model in models
])

best_model, best_result = None, float("-inf")
for model, result in zip(models, results):
if result > best_result:
best_model, best_result = model, result

assert best_model is not None, "model cannot be None!"
return best_model, best_result


@workflow
def wf() -> BestModel:
return main()


if __name__ == "__main__":
print("training model")
model = asyncio.run(main())
print(f"trained model: {model}")
7 changes: 7 additions & 0 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import contextlib
import datetime as _datetime
import inspect
import os
import pathlib
import subprocess
Expand Down Expand Up @@ -89,6 +91,11 @@ def _dispatch_execute(
# Decorate the dispatch execute function before calling it, this wraps all exceptions into one
# of the FlyteScopedExceptions
outputs = _scoped_exceptions.system_entry_point(task_def.dispatch_execute)(ctx, idl_input_literals)
if inspect.iscoroutine(outputs):
# Handle eager-mode (async) tasks
logger.info("Output is a coroutine")
outputs = asyncio.run(outputs)

# Step3a
if isinstance(outputs, VoidPromise):
logger.warning("Task produces no outputs")
Expand Down
5 changes: 5 additions & 0 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import datetime
import functools
import importlib
import inspect
import json
import logging
import os
Expand Down Expand Up @@ -540,6 +542,9 @@ def _run(*args, **kwargs):

if not ctx.obj[REMOTE_FLAG_KEY]:
output = entity(**inputs)
if inspect.iscoroutine(output):
# TODO: make eager mode workflows run with local-mode
output = asyncio.run(output)
click.echo(output)
return

Expand Down
128 changes: 73 additions & 55 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import collections
import datetime
import inspect
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union, cast
Expand Down Expand Up @@ -278,6 +279,10 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
# Code is simpler with duplication and less metaprogramming, but introduces regressions
# if one is changed and not the other.
outputs_literal_map = self.sandbox_execute(ctx, input_literal_map)

if inspect.iscoroutine(outputs_literal_map):
return outputs_literal_map

outputs_literals = outputs_literal_map.literals

# TODO maybe this is the part that should be done for local execution, we pass the outputs to some special
Expand Down Expand Up @@ -489,6 +494,65 @@ def compile(self, ctx: FlyteContext, *args, **kwargs):
def _outputs_interface(self) -> Dict[Any, Variable]:
return self.interface.outputs # type: ignore

def _output_to_literal_map(self, native_outputs, exec_ctx):
expected_output_names = list(self._outputs_interface.keys())
if len(expected_output_names) == 1:
# Here we have to handle the fact that the task could've been declared with a typing.NamedTuple of
# length one. That convention is used for naming outputs - and single-length-NamedTuples are
# particularly troublesome but elegant handling of them is not a high priority
# Again, we're using the output_tuple_name as a proxy.
if self.python_interface.output_tuple_name and isinstance(native_outputs, tuple):
native_outputs_as_map = {expected_output_names[0]: native_outputs[0]}
else:
native_outputs_as_map = {expected_output_names[0]: native_outputs}
elif len(expected_output_names) == 0:
native_outputs_as_map = {}
else:
native_outputs_as_map = {expected_output_names[i]: native_outputs[i] for i, _ in enumerate(native_outputs)}

# We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption
# built into the IDL that all the values of a literal map are of the same type.
literals = {}
for k, v in native_outputs_as_map.items():
literal_type = self._outputs_interface[k].type
py_type = self.get_type_for_output_var(k, v)

if isinstance(v, tuple):
raise TypeError(f"Output({k}) in task{self.name} received a tuple {v}, instead of {py_type}")
try:
literals[k] = TypeEngine.to_literal(exec_ctx, v, py_type, literal_type)
except Exception as e:
logger.error(f"Failed to convert return value for var {k} with error {type(e)}: {e}")
raise TypeError(
f"Failed to convert return value for var {k} for function {self.name} with error {type(e)}: {e}"
) from e

return _literal_models.LiteralMap(literals=literals), native_outputs_as_map

def _write_decks(self, native_inputs, native_outputs_as_map, ctx, new_user_params):
from flytekit.deck.deck import _output_deck

if self._disable_deck is False:
INPUT = "input"
OUTPUT = "output"

input_deck = Deck(INPUT)
for k, v in native_inputs.items():
input_deck.append(TypeEngine.to_html(ctx, v, self.get_type_for_input_var(k, v)))

output_deck = Deck(OUTPUT)
for k, v in native_outputs_as_map.items():
output_deck.append(TypeEngine.to_html(ctx, v, self.get_type_for_output_var(k, v)))

_output_deck(self.name.split(".")[-1], new_user_params)

async def _async_execute(self, native_inputs, native_outputs, ctx, exec_ctx, new_user_params):
out = await native_outputs
native_outputs = self.post_execute(new_user_params, native_outputs)
literals_map, native_outputs_as_map = self._output_to_literal_map(out, exec_ctx)
self._write_decks(native_inputs, native_outputs_as_map, ctx, new_user_params)
return literals_map

def dispatch_execute(
self, ctx: FlyteContext, input_literal_map: _literal_models.LiteralMap
) -> Union[_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec]:
Expand All @@ -501,10 +565,8 @@ def dispatch_execute(
may be none
* ``DynamicJobSpec`` is returned when a dynamic workflow is executed
"""

# Invoked before the task is executed
new_user_params = self.pre_execute(ctx.user_space_params)
from flytekit.deck.deck import _output_deck

# Create another execution context with the new user params, but let's keep the same working dir
with FlyteContextManager.with_context(
Expand All @@ -526,6 +588,11 @@ def dispatch_execute(
logger.exception(f"Exception when executing {e}")
raise e

if inspect.iscoroutine(native_outputs):
if exec_ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
return native_outputs
return self._async_execute(native_inputs, native_outputs, ctx, exec_ctx, new_user_params)

logger.debug("Task executed successfully in user level")
# Lets run the post_execute method. This may result in a IgnoreOutputs Exception, which is
# bubbled up to be handled at the callee layer.
Expand All @@ -534,62 +601,13 @@ def dispatch_execute(
# Short circuit the translation to literal map because what's returned may be a dj spec (or an
# already-constructed LiteralMap if the dynamic task was a no-op), not python native values
# dynamic_execute returns a literal map in local execute so this also gets triggered.
if isinstance(native_outputs, _literal_models.LiteralMap) or isinstance(
native_outputs, _dynamic_job.DynamicJobSpec
):
if isinstance(native_outputs, (_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec)):
return native_outputs

expected_output_names = list(self._outputs_interface.keys())
if len(expected_output_names) == 1:
# Here we have to handle the fact that the task could've been declared with a typing.NamedTuple of
# length one. That convention is used for naming outputs - and single-length-NamedTuples are
# particularly troublesome but elegant handling of them is not a high priority
# Again, we're using the output_tuple_name as a proxy.
if self.python_interface.output_tuple_name and isinstance(native_outputs, tuple):
native_outputs_as_map = {expected_output_names[0]: native_outputs[0]}
else:
native_outputs_as_map = {expected_output_names[0]: native_outputs}
elif len(expected_output_names) == 0:
native_outputs_as_map = {}
else:
native_outputs_as_map = {
expected_output_names[i]: native_outputs[i] for i, _ in enumerate(native_outputs)
}

# We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption
# built into the IDL that all the values of a literal map are of the same type.
literals = {}
for k, v in native_outputs_as_map.items():
literal_type = self._outputs_interface[k].type
py_type = self.get_type_for_output_var(k, v)

if isinstance(v, tuple):
raise TypeError(f"Output({k}) in task{self.name} received a tuple {v}, instead of {py_type}")
try:
literals[k] = TypeEngine.to_literal(exec_ctx, v, py_type, literal_type)
except Exception as e:
logger.error(f"Failed to convert return value for var {k} with error {type(e)}: {e}")
raise TypeError(
f"Failed to convert return value for var {k} for function {self.name} with error {type(e)}: {e}"
) from e

if self._disable_deck is False:
INPUT = "input"
OUTPUT = "output"

input_deck = Deck(INPUT)
for k, v in native_inputs.items():
input_deck.append(TypeEngine.to_html(ctx, v, self.get_type_for_input_var(k, v)))

output_deck = Deck(OUTPUT)
for k, v in native_outputs_as_map.items():
output_deck.append(TypeEngine.to_html(ctx, v, self.get_type_for_output_var(k, v)))

_output_deck(self.name.split(".")[-1], new_user_params)

outputs_literal_map = _literal_models.LiteralMap(literals=literals)
literals_map, native_outputs_as_map = self._output_to_literal_map(native_outputs, exec_ctx)
self._write_decks(native_inputs, native_outputs_as_map, ctx, new_user_params)
# After the execute has been successfully completed
return outputs_literal_map
return literals_map

def pre_execute(self, user_params: Optional[ExecutionParameters]) -> Optional[ExecutionParameters]: # type: ignore
"""
Expand Down
4 changes: 4 additions & 0 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import collections
import inspect
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast

Expand Down Expand Up @@ -1093,6 +1094,9 @@ def flyte_entity_call_handler(
else:
raise Exception(f"Received an output when workflow local execution expected None. Received: {result}")

if inspect.iscoroutine(result):
return result

if (1 < expected_outputs == len(cast(Tuple[Promise], result))) or (
result is not None and expected_outputs == 1
):
Expand Down
Loading

0 comments on commit 4886519

Please sign in to comment.