-
Notifications
You must be signed in to change notification settings - Fork 305
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Eager workflows to support async workflows
Signed-off-by: Niels Bantilan <[email protected]>
- Loading branch information
1 parent
dd7fbe9
commit 4886519
Showing
12 changed files
with
621 additions
and
67 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
FROM python:3.8-slim-buster | ||
|
||
WORKDIR /root | ||
ENV VENV /opt/venv | ||
ENV LANG C.UTF-8 | ||
ENV LC_ALL C.UTF-8 | ||
ENV PYTHONPATH /root | ||
|
||
RUN apt-get update && apt-get install -y build-essential git | ||
|
||
# Install the AWS cli separately to prevent issues with boto being written over | ||
RUN pip3 install awscli | ||
|
||
ENV VENV /opt/venv | ||
# Virtual environment | ||
RUN python3 -m venv ${VENV} | ||
ENV PATH="${VENV}/bin:$PATH" | ||
|
||
# Install Python dependencies | ||
RUN pip install scikit-learn pandas | ||
|
||
ARG gitsha | ||
RUN pip install flytekitplugins-deck-standard | ||
RUN pip install git+https://github.com/flyteorg/flytekit@${gitsha} |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
"""Async workflows prototype.""" | ||
|
||
import asyncio | ||
from typing import NamedTuple | ||
|
||
import pandas as pd | ||
from sklearn.datasets import load_wine | ||
from sklearn.linear_model import LogisticRegression | ||
from sklearn.model_selection import train_test_split | ||
from sklearn.metrics import accuracy_score | ||
|
||
from flytekit import workflow, task, Secret | ||
from flytekit.configuration import Config, PlatformConfig | ||
from flytekit.experimental import eager | ||
from flytekit.remote import FlyteRemote | ||
|
||
CACHE_VERSION = "3" | ||
|
||
|
||
remote = FlyteRemote( | ||
# config=Config.for_sandbox(), | ||
config=Config( | ||
platform=PlatformConfig( | ||
endpoint="development.uniondemo.run", | ||
auth_mode="Pkce", | ||
client_id="flytepropeller", | ||
insecure=False, | ||
), | ||
), | ||
default_project="flytesnacks", | ||
default_domain="development", | ||
data_upload_location="s3://flyte-development-data/data", | ||
) | ||
|
||
class CustomException(Exception): ... | ||
|
||
BestModel = NamedTuple("BestModel", model=LogisticRegression, metric=float) | ||
|
||
|
||
@task(cache=True, cache_version=CACHE_VERSION) | ||
def get_data() -> pd.DataFrame: | ||
"""Get the wine dataset.""" | ||
return load_wine(as_frame=True).frame | ||
|
||
|
||
@task(cache=True, cache_version=CACHE_VERSION) | ||
def process_data(data: pd.DataFrame) -> pd.DataFrame: | ||
"""Simplify the task from a 3-class to a binary classification problem.""" | ||
return data.assign(target=lambda x: x["target"].where(x["target"] == 0, 1)) | ||
|
||
|
||
@task(cache=True, cache_version=CACHE_VERSION) | ||
def train_model(data: pd.DataFrame, hyperparameters: dict) -> LogisticRegression: | ||
"""Train a model on the wine dataset.""" | ||
features = data.drop("target", axis="columns") | ||
target = data["target"] | ||
return LogisticRegression(max_iter=3000, **hyperparameters).fit(features, target) | ||
|
||
|
||
@task | ||
def evaluate_model(data: pd.DataFrame, model: LogisticRegression) -> float: | ||
"""Train a model on the wine dataset.""" | ||
features = data.drop("target", axis="columns") | ||
target = data["target"] | ||
return float(accuracy_score(target, model.predict(features))) | ||
|
||
|
||
@eager( | ||
remote=remote, | ||
force_remote=True, | ||
secret_requests=[Secret(group="async-client-secret", key="client_secret")], | ||
disable_deck=False, | ||
) | ||
async def main() -> BestModel: | ||
data = await get_data() | ||
processed_data = await process_data(data=data) | ||
|
||
# split the data | ||
try: | ||
train, test = train_test_split(processed_data, test_size=0.2) | ||
except Exception as exc: | ||
raise CustomException(str(exc)) from exc | ||
|
||
models = await asyncio.gather(*[ | ||
train_model(data=train, hyperparameters={"C": x}) | ||
for x in [0.1, 0.01, 0.001, 0.0001, 0.00001] | ||
]) | ||
results = await asyncio.gather(*[ | ||
evaluate_model(data=test, model=model) for model in models | ||
]) | ||
|
||
best_model, best_result = None, float("-inf") | ||
for model, result in zip(models, results): | ||
if result > best_result: | ||
best_model, best_result = model, result | ||
|
||
assert best_model is not None, "model cannot be None!" | ||
return best_model, best_result | ||
|
||
|
||
@workflow | ||
def wf() -> BestModel: | ||
return main() | ||
|
||
|
||
if __name__ == "__main__": | ||
print("training model") | ||
model = asyncio.run(main()) | ||
print(f"trained model: {model}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.