forked from flyteorg/flyte
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Single node GPU training example (flyteorg#333)
* Single node GPU training example Signed-off-by: Ketan Umare <[email protected]> * Minor fix related to tensorboard in PyTorch (flyteorg#334) Signed-off-by: Jinserk Baik <[email protected]> * updated pytorch training example Signed-off-by: Ketan Umare <[email protected]> * updated Signed-off-by: Ketan Umare <[email protected]> * wandb integration, code lint, content Signed-off-by: Samhita Alla <[email protected]> * remove misplaced text Signed-off-by: Samhita Alla <[email protected]> * add pytorch in tests' manifest Signed-off-by: Samhita Alla <[email protected]> * changed pytorch to mnist Signed-off-by: Samhita Alla <[email protected]> * dockerfile Signed-off-by: Samhita Alla <[email protected]> * update link Signed-off-by: cosmicBboy <[email protected]> * update deps Signed-off-by: cosmicBboy <[email protected]> Co-authored-by: Jinserk Baik <[email protected]> Co-authored-by: Samhita Alla <[email protected]> Co-authored-by: cosmicBboy <[email protected]>
- Loading branch information
1 parent
b3b34d7
commit ff4e179
Showing
18 changed files
with
825 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
wandb/ |
31 changes: 31 additions & 0 deletions
31
cookbook/case_studies/ml_training/mnist_classifier/Dockerfile
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
FROM nvcr.io/nvidia/pytorch:21.06-py3 | ||
LABEL org.opencontainers.image.source https://github.com/flyteorg/flytesnacks | ||
|
||
WORKDIR /root | ||
ENV LANG C.UTF-8 | ||
ENV LC_ALL C.UTF-8 | ||
ENV PYTHONPATH /root | ||
|
||
# Give your wandb API key. Get it from https://wandb.ai/authorize. | ||
# ENV WANDB_API_KEY your-api-key | ||
|
||
# Install the AWS cli for AWS support | ||
RUN pip install awscli | ||
|
||
ENV VENV /opt/venv | ||
|
||
# Virtual environment | ||
RUN python3 -m venv ${VENV} | ||
ENV PATH="${VENV}/bin:$PATH" | ||
|
||
# Install Python dependencies | ||
COPY mnist_classifier/requirements.txt /root | ||
RUN pip install -r /root/requirements.txt | ||
|
||
# Copy the actual code | ||
COPY mnist_classifier/ /root/mnist_classifier/ | ||
|
||
# This tag is supplied by the build script and will be used to determine the version | ||
# when registering tasks, workflows, and launch plans | ||
ARG tag | ||
ENV FLYTE_INTERNAL_IMAGE $tag |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
PREFIX=mnist_classifier | ||
include ../../../common/Makefile | ||
include ../../../common/leaf.mk |
75 changes: 75 additions & 0 deletions
75
cookbook/case_studies/ml_training/mnist_classifier/README.rst
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
.. _mnist-classifier-training: | ||
|
||
MNIST Classification With PyTorch and W&B | ||
----------------------------------------- | ||
|
||
PyTorch | ||
======= | ||
|
||
`Pytorch <https://pytorch.org/>`__ is a machine learning framework that accelerates the path from research prototyping | ||
to production deployment. You can build *Tensors* and *Dynamic neural networks* in Python with strong GPU acceleration | ||
using PyTorch. | ||
|
||
In a nutshell, it is a Python package that provides two high-level features: | ||
|
||
- Tensor computation (like NumPy) with strong GPU acceleration | ||
- Deep neural networks built on a tape-based autograd system | ||
|
||
Flyte directly has no unique understanding of PyTorch. As per Flyte, PyTorch is just a Python library. | ||
However, when merged with Flyte, the combo helps utilize and bootstrap the infrastructure for PyTorch and ensures that things work well! | ||
Additionally, it also offers other benefits of using tasks and workflows -- checkpointing, separation of concerns, and auto-memoization. | ||
|
||
Specify GPU Requirement | ||
======================= | ||
|
||
One of the necessary directives applicable when working on deep learning models is explicitly requesting one or more GPUs. | ||
This can be done by giving a simple directive to the task declaration as follows: | ||
|
||
.. code-block:: python | ||
from flytekit import Resources, task | ||
@task(requests=Resources(gpu="1"), limits=Resources(gpu="1")) | ||
def my_deep_learning_task(): | ||
... | ||
.. tip:: | ||
It is recommended to use the same ``requests`` and ``limits`` for a GPU as automatic GPU scaling is not supported. | ||
|
||
Moreover, to utilize the power of a GPU, ensure that your Flyte backend has GPU nodes provisioned. | ||
|
||
Distributed Data-Parallel Training | ||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||
|
||
Flyte also supports distributed training for PyTorch models, but this is not native. This is achieved using one of the optional plugins, such as: | ||
|
||
- Natively on Kubernetes using :ref:`kf-pytorch-op` | ||
- On AWS using :ref:`aws-sagemaker` training | ||
|
||
*Other distributed training plugins are coming soon -- MPIOperator, Google Vertex AI, etc. You can add your favorite services, too!* | ||
|
||
Weights & Biases Integration | ||
============================ | ||
|
||
`Weights & Biases <https://wandb.ai/site>`__, or simply, ``wandb`` helps build better models faster with experiment tracking, dataset versioning, and model management. | ||
|
||
We'll use ``wandb`` alongside PyTorch to track our ML experiment and its concerned model parameters. | ||
|
||
.. note:: | ||
Before running the example, create a ``wandb`` account and log in to access the API. | ||
If you're running the code locally, run the command ``wandb login``. | ||
If it's a remote cluster, you have to include the API key in the Dockerfile. | ||
|
||
.. _pytorch-dockerfile: | ||
|
||
PyTorch Dockerfile for Deployment | ||
================================= | ||
|
||
It is essential to build the Dockerfile with GPU support to use a GPU within PyTorch. | ||
The example in this section uses a simple ``nvidia-supplied GPU Docker image`` as the base, and the rest of the construction is similar to the other Dockerfiles. | ||
|
||
.. literalinclude:: ../../../../../case_studies/ml_training/mnist_classifier/Dockerfile | ||
:language: docker | ||
|
||
.. note:: | ||
Run your code in the ``ml_training`` directory, both locally and within the sandbox. |
Empty file.
4 changes: 4 additions & 0 deletions
4
cookbook/case_studies/ml_training/mnist_classifier/requirements.in
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
-r ../../../common/requirements-common.in | ||
torch | ||
torchvision | ||
wandb |
187 changes: 187 additions & 0 deletions
187
cookbook/case_studies/ml_training/mnist_classifier/requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
# | ||
# This file is autogenerated by pip-compile with python 3.8 | ||
# To update, run: | ||
# | ||
# /Library/Developer/CommandLineTools/usr/bin/make requirements.txt | ||
# | ||
attrs==21.2.0 | ||
# via scantree | ||
certifi==2021.5.30 | ||
# via | ||
# requests | ||
# sentry-sdk | ||
charset-normalizer==2.0.3 | ||
# via requests | ||
click==7.1.2 | ||
# via | ||
# flytekit | ||
# wandb | ||
configparser==5.0.2 | ||
# via wandb | ||
croniter==1.0.15 | ||
# via flytekit | ||
cycler==0.10.0 | ||
# via matplotlib | ||
dataclasses-json==0.5.4 | ||
# via flytekit | ||
decorator==5.0.9 | ||
# via retry | ||
deprecated==1.2.12 | ||
# via flytekit | ||
dirhash==0.2.1 | ||
# via flytekit | ||
docker-image-py==0.1.10 | ||
# via flytekit | ||
docker-pycreds==0.4.0 | ||
# via wandb | ||
flyteidl==0.19.14 | ||
# via flytekit | ||
flytekit==0.20.1 | ||
# via -r ../../../common/requirements-common.in | ||
gitdb==4.0.7 | ||
# via gitpython | ||
gitpython==3.1.18 | ||
# via wandb | ||
grpcio==1.39.0 | ||
# via flytekit | ||
idna==3.2 | ||
# via requests | ||
importlib-metadata==4.6.1 | ||
# via keyring | ||
keyring==23.0.1 | ||
# via flytekit | ||
kiwisolver==1.3.1 | ||
# via matplotlib | ||
marshmallow==3.13.0 | ||
# via | ||
# dataclasses-json | ||
# marshmallow-enum | ||
# marshmallow-jsonschema | ||
marshmallow-enum==1.5.1 | ||
# via dataclasses-json | ||
marshmallow-jsonschema==0.12.0 | ||
# via flytekit | ||
matplotlib==3.4.2 | ||
# via -r ../../../common/requirements-common.in | ||
mypy-extensions==0.4.3 | ||
# via typing-inspect | ||
natsort==7.1.1 | ||
# via flytekit | ||
numpy==1.21.1 | ||
# via | ||
# matplotlib | ||
# pandas | ||
# pyarrow | ||
# torchvision | ||
pandas==1.3.0 | ||
# via flytekit | ||
pathspec==0.9.0 | ||
# via scantree | ||
pathtools==0.1.2 | ||
# via wandb | ||
pillow==8.3.1 | ||
# via | ||
# matplotlib | ||
# torchvision | ||
promise==2.3 | ||
# via wandb | ||
protobuf==3.17.3 | ||
# via | ||
# flyteidl | ||
# flytekit | ||
# wandb | ||
psutil==5.8.0 | ||
# via wandb | ||
py==1.10.0 | ||
# via retry | ||
pyarrow==3.0.0 | ||
# via flytekit | ||
pyparsing==2.4.7 | ||
# via matplotlib | ||
python-dateutil==2.8.1 | ||
# via | ||
# croniter | ||
# flytekit | ||
# matplotlib | ||
# pandas | ||
# wandb | ||
python-json-logger==2.0.1 | ||
# via flytekit | ||
pytimeparse==1.1.8 | ||
# via flytekit | ||
pytz==2018.4 | ||
# via | ||
# flytekit | ||
# pandas | ||
pyyaml==5.4.1 | ||
# via wandb | ||
regex==2021.7.6 | ||
# via docker-image-py | ||
requests==2.26.0 | ||
# via | ||
# flytekit | ||
# responses | ||
# wandb | ||
responses==0.13.3 | ||
# via flytekit | ||
retry==0.9.2 | ||
# via flytekit | ||
scantree==0.0.1 | ||
# via dirhash | ||
sentry-sdk==1.3.0 | ||
# via wandb | ||
shortuuid==1.0.1 | ||
# via wandb | ||
six==1.16.0 | ||
# via | ||
# cycler | ||
# docker-pycreds | ||
# flytekit | ||
# grpcio | ||
# promise | ||
# protobuf | ||
# python-dateutil | ||
# responses | ||
# scantree | ||
# wandb | ||
smmap==4.0.0 | ||
# via gitdb | ||
sortedcontainers==2.4.0 | ||
# via flytekit | ||
statsd==3.3.0 | ||
# via flytekit | ||
stringcase==1.2.0 | ||
# via dataclasses-json | ||
subprocess32==3.5.4 | ||
# via wandb | ||
torch==1.9.0 | ||
# via | ||
# -r requirements.in | ||
# torchvision | ||
torchvision==0.10.0 | ||
# via -r requirements.in | ||
typing-extensions==3.10.0.0 | ||
# via | ||
# torch | ||
# typing-inspect | ||
typing-inspect==0.7.1 | ||
# via dataclasses-json | ||
urllib3==1.26.6 | ||
# via | ||
# flytekit | ||
# requests | ||
# responses | ||
# sentry-sdk | ||
# wandb | ||
wandb==0.11.0 | ||
# via -r requirements.in | ||
wheel==0.36.2 | ||
# via | ||
# -r ../../../common/requirements-common.in | ||
# flytekit | ||
wrapt==1.12.1 | ||
# via | ||
# deprecated | ||
# flytekit | ||
zipp==3.5.0 | ||
# via importlib-metadata |
3 changes: 3 additions & 0 deletions
3
cookbook/case_studies/ml_training/mnist_classifier/sandbox.config
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[sdk] | ||
workflow_packages=mnist_classifier | ||
python_venv=flytekit_venv |
Oops, something went wrong.