Single node GPU training example (flyteorg#333)
* Single node GPU training example

Signed-off-by: Ketan Umare <[email protected]>

* Minor fix related to tensorboard in PyTorch (flyteorg#334)

Signed-off-by: Jinserk Baik <[email protected]>

* updated pytorch training example

Signed-off-by: Ketan Umare <[email protected]>

* updated

Signed-off-by: Ketan Umare <[email protected]>

* wandb integration, code lint, content

Signed-off-by: Samhita Alla <[email protected]>

* remove misplaced text

Signed-off-by: Samhita Alla <[email protected]>

* add pytorch in tests' manifest

Signed-off-by: Samhita Alla <[email protected]>

* changed pytorch to mnist

Signed-off-by: Samhita Alla <[email protected]>

* dockerfile

Signed-off-by: Samhita Alla <[email protected]>

* update link

Signed-off-by: cosmicBboy <[email protected]>

* update deps

Signed-off-by: cosmicBboy <[email protected]>

Co-authored-by: Jinserk Baik <[email protected]>
Co-authored-by: Samhita Alla <[email protected]>
Co-authored-by: cosmicBboy <[email protected]>
4 people authored Jul 27, 2021
1 parent b3b34d7 commit ff4e179
Showing 18 changed files with 825 additions and 24 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ghcr_push.yml
@@ -41,6 +41,8 @@ jobs:
path: integrations/flytekit_plugins
- name: house_price_prediction
path: case_studies/ml_training
- name: mnist_classifier
path: case_studies/ml_training
steps:
- uses: actions/checkout@v2
with:
Empty file.
@@ -0,0 +1 @@
wandb/
31 changes: 31 additions & 0 deletions cookbook/case_studies/ml_training/mnist_classifier/Dockerfile
@@ -0,0 +1,31 @@
FROM nvcr.io/nvidia/pytorch:21.06-py3
LABEL org.opencontainers.image.source https://github.com/flyteorg/flytesnacks

WORKDIR /root
ENV LANG C.UTF-8
ENV LC_ALL C.UTF-8
ENV PYTHONPATH /root

# Give your wandb API key. Get it from https://wandb.ai/authorize.
# ENV WANDB_API_KEY your-api-key

# Install the AWS cli for AWS support
RUN pip install awscli

ENV VENV /opt/venv

# Virtual environment
RUN python3 -m venv ${VENV}
ENV PATH="${VENV}/bin:$PATH"

# Install Python dependencies
COPY mnist_classifier/requirements.txt /root
RUN pip install -r /root/requirements.txt

# Copy the actual code
COPY mnist_classifier/ /root/mnist_classifier/

# This tag is supplied by the build script and will be used to determine the version
# when registering tasks, workflows, and launch plans
ARG tag
ENV FLYTE_INTERNAL_IMAGE $tag
3 changes: 3 additions & 0 deletions cookbook/case_studies/ml_training/mnist_classifier/Makefile
@@ -0,0 +1,3 @@
PREFIX=mnist_classifier
include ../../../common/Makefile
include ../../../common/leaf.mk
75 changes: 75 additions & 0 deletions cookbook/case_studies/ml_training/mnist_classifier/README.rst
@@ -0,0 +1,75 @@
.. _mnist-classifier-training:

MNIST Classification With PyTorch and W&B
-----------------------------------------

PyTorch
=======

`PyTorch <https://pytorch.org/>`__ is a machine learning framework that accelerates the path from research prototyping
to production deployment. You can build *Tensors* and *Dynamic neural networks* in Python with strong GPU acceleration
using PyTorch.

In a nutshell, it is a Python package that provides two high-level features, both illustrated in the short sketch after this list:

- Tensor computation (like NumPy) with strong GPU acceleration
- Deep neural networks built on a tape-based autograd system
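
A tiny sketch illustrating both features (it assumes only that PyTorch is installed; the shapes and values are arbitrary):

.. code-block:: python

    import torch

    # Tensor computation with GPU acceleration; falls back to the CPU if no GPU is present.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(3, 3, device=device, requires_grad=True)

    # Tape-based autograd: operations are recorded as they run and differentiated on backward().
    y = (x * x).sum()
    y.backward()
    print(x.grad)  # equals 2 * x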

Flyte itself has no special understanding of PyTorch; as far as Flyte is concerned, PyTorch is just another Python library.
When the two are combined, however, Flyte helps provision and bootstrap the infrastructure that PyTorch needs and ensures that things work well!
It also brings the other benefits of tasks and workflows -- checkpointing, separation of concerns, and auto-memoization.

Specify GPU Requirement
=======================

When working on deep learning models, one directive you typically need is an explicit request for one or more GPUs.
This is done by adding resource requests to the task declaration as follows:

.. code-block:: python

    from flytekit import Resources, task

    @task(requests=Resources(gpu="1"), limits=Resources(gpu="1"))
    def my_deep_learning_task():
        ...

.. tip::
    It is recommended to set the same ``requests`` and ``limits`` for a GPU, as automatic GPU scaling is not supported.

Moreover, to utilize the power of a GPU, ensure that your Flyte backend has GPU nodes provisioned.

Distributed Data-Parallel Training
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Flyte also supports distributed training for PyTorch models, though not out of the box; it is enabled through one of the optional plugins, such as the following (a brief sketch appears after this list):

- Natively on Kubernetes using :ref:`kf-pytorch-op`
- On AWS using :ref:`aws-sagemaker` training

*Other distributed training plugins are coming soon -- MPIOperator, Google Vertex AI, etc. You can add your favorite services, too!*
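
As an illustration, requesting a distributed run through the Kubeflow PyTorch plugin looks roughly as follows.
This is a minimal sketch, assuming the ``flytekitplugins-kfpytorch`` plugin is installed; the task name and worker count are illustrative.

.. code-block:: python

    from flytekit import Resources, task
    from flytekitplugins.kfpytorch import PyTorch

    # Run the task as a PyTorchJob with two worker replicas, each requesting one GPU.
    @task(
        task_config=PyTorch(num_workers=2),
        requests=Resources(gpu="1"),
        limits=Resources(gpu="1"),
    )
    def my_distributed_training_task():
        ...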

Weights & Biases Integration
============================

`Weights & Biases <https://wandb.ai/site>`__, or simply ``wandb``, helps build better models faster with experiment tracking, dataset versioning, and model management.

We'll use ``wandb`` alongside PyTorch to track our ML experiment and its associated model parameters.

.. note::
    Before running the example, create a ``wandb`` account and log in to access the API.
    If you're running the code locally, run the command ``wandb login``.
    If it's a remote cluster, you have to include the API key in the Dockerfile.
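
Tracking from inside a Flyte task then uses the standard ``wandb`` calls.
The snippet below is a rough sketch under the assumptions above (account created, API key available); the project name, metric names, and the placeholder loop are illustrative only.

.. code-block:: python

    import wandb
    from flytekit import Resources, task

    @task(requests=Resources(gpu="1"), limits=Resources(gpu="1"))
    def train() -> float:
        # Start a run; "mnist-classifier" is an example project name.
        wandb.init(project="mnist-classifier")
        loss = 0.0
        for epoch in range(10):
            loss = 1.0 / (epoch + 1)  # stand-in for a real training step
            wandb.log({"epoch": epoch, "loss": loss})
        return loss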

.. _pytorch-dockerfile:

PyTorch Dockerfile for Deployment
=================================

To use a GPU from within PyTorch, it is essential to build the Docker image with GPU support.
The example in this section uses an NVIDIA-supplied GPU image as the base, and the rest of the construction is similar to the other Dockerfiles.

.. literalinclude:: ../../../../../case_studies/ml_training/mnist_classifier/Dockerfile
    :language: docker

.. note::
    Run your code in the ``ml_training`` directory, both locally and within the sandbox.
Empty file.
@@ -0,0 +1,4 @@
-r ../../../common/requirements-common.in
torch
torchvision
wandb
187 changes: 187 additions & 0 deletions cookbook/case_studies/ml_training/mnist_classifier/requirements.txt
@@ -0,0 +1,187 @@
#
# This file is autogenerated by pip-compile with python 3.8
# To update, run:
#
# /Library/Developer/CommandLineTools/usr/bin/make requirements.txt
#
attrs==21.2.0
# via scantree
certifi==2021.5.30
# via
# requests
# sentry-sdk
charset-normalizer==2.0.3
# via requests
click==7.1.2
# via
# flytekit
# wandb
configparser==5.0.2
# via wandb
croniter==1.0.15
# via flytekit
cycler==0.10.0
# via matplotlib
dataclasses-json==0.5.4
# via flytekit
decorator==5.0.9
# via retry
deprecated==1.2.12
# via flytekit
dirhash==0.2.1
# via flytekit
docker-image-py==0.1.10
# via flytekit
docker-pycreds==0.4.0
# via wandb
flyteidl==0.19.14
# via flytekit
flytekit==0.20.1
# via -r ../../../common/requirements-common.in
gitdb==4.0.7
# via gitpython
gitpython==3.1.18
# via wandb
grpcio==1.39.0
# via flytekit
idna==3.2
# via requests
importlib-metadata==4.6.1
# via keyring
keyring==23.0.1
# via flytekit
kiwisolver==1.3.1
# via matplotlib
marshmallow==3.13.0
# via
# dataclasses-json
# marshmallow-enum
# marshmallow-jsonschema
marshmallow-enum==1.5.1
# via dataclasses-json
marshmallow-jsonschema==0.12.0
# via flytekit
matplotlib==3.4.2
# via -r ../../../common/requirements-common.in
mypy-extensions==0.4.3
# via typing-inspect
natsort==7.1.1
# via flytekit
numpy==1.21.1
# via
# matplotlib
# pandas
# pyarrow
# torchvision
pandas==1.3.0
# via flytekit
pathspec==0.9.0
# via scantree
pathtools==0.1.2
# via wandb
pillow==8.3.1
# via
# matplotlib
# torchvision
promise==2.3
# via wandb
protobuf==3.17.3
# via
# flyteidl
# flytekit
# wandb
psutil==5.8.0
# via wandb
py==1.10.0
# via retry
pyarrow==3.0.0
# via flytekit
pyparsing==2.4.7
# via matplotlib
python-dateutil==2.8.1
# via
# croniter
# flytekit
# matplotlib
# pandas
# wandb
python-json-logger==2.0.1
# via flytekit
pytimeparse==1.1.8
# via flytekit
pytz==2018.4
# via
# flytekit
# pandas
pyyaml==5.4.1
# via wandb
regex==2021.7.6
# via docker-image-py
requests==2.26.0
# via
# flytekit
# responses
# wandb
responses==0.13.3
# via flytekit
retry==0.9.2
# via flytekit
scantree==0.0.1
# via dirhash
sentry-sdk==1.3.0
# via wandb
shortuuid==1.0.1
# via wandb
six==1.16.0
# via
# cycler
# docker-pycreds
# flytekit
# grpcio
# promise
# protobuf
# python-dateutil
# responses
# scantree
# wandb
smmap==4.0.0
# via gitdb
sortedcontainers==2.4.0
# via flytekit
statsd==3.3.0
# via flytekit
stringcase==1.2.0
# via dataclasses-json
subprocess32==3.5.4
# via wandb
torch==1.9.0
# via
# -r requirements.in
# torchvision
torchvision==0.10.0
# via -r requirements.in
typing-extensions==3.10.0.0
# via
# torch
# typing-inspect
typing-inspect==0.7.1
# via dataclasses-json
urllib3==1.26.6
# via
# flytekit
# requests
# responses
# sentry-sdk
# wandb
wandb==0.11.0
# via -r requirements.in
wheel==0.36.2
# via
# -r ../../../common/requirements-common.in
# flytekit
wrapt==1.12.1
# via
# deprecated
# flytekit
zipp==3.5.0
# via importlib-metadata
@@ -0,0 +1,3 @@
[sdk]
workflow_packages=mnist_classifier
python_venv=flytekit_venv