From ff4e179abe231d65fe06f24a181f96ad737d40ee Mon Sep 17 00:00:00 2001 From: Ketan Umare <16888709+kumare3@users.noreply.github.com> Date: Tue, 27 Jul 2021 11:31:11 -0700 Subject: [PATCH] Single node GPU training example (#333) * Single node GPU training example Signed-off-by: Ketan Umare * Minor fix related to tensorboard in PyTorch (#334) Signed-off-by: Jinserk Baik <823222+jinserk@users.noreply.github.com> * updated pytorch training example Signed-off-by: Ketan Umare * updated Signed-off-by: Ketan Umare * wandb integration, code lint, content Signed-off-by: Samhita Alla * remove misplaced text Signed-off-by: Samhita Alla * add pytorch in tests' manifest Signed-off-by: Samhita Alla * changed pytorch to mnist Signed-off-by: Samhita Alla * dockerfile Signed-off-by: Samhita Alla * update link Signed-off-by: cosmicBboy * update deps Signed-off-by: cosmicBboy Co-authored-by: Jinserk Baik <823222+jinserk@users.noreply.github.com> Co-authored-by: Samhita Alla Co-authored-by: cosmicBboy --- .github/workflows/ghcr_push.yml | 2 + cookbook/case_studies/ml_training/__init__.py | 0 .../ml_training/mnist_classifier/.gitignore | 1 + .../ml_training/mnist_classifier/Dockerfile | 31 ++ .../ml_training/mnist_classifier/Makefile | 3 + .../ml_training/mnist_classifier/README.rst | 75 ++++ .../ml_training/mnist_classifier/__init__.py | 0 .../mnist_classifier/requirements.in | 4 + .../mnist_classifier/requirements.txt | 187 ++++++++ .../mnist_classifier/sandbox.config | 3 + .../mnist_classifier/single_node.py | 406 ++++++++++++++++++ cookbook/dev-requirements.in | 10 +- cookbook/dev-requirements.txt | 75 +++- cookbook/docs/conf.py | 26 +- cookbook/docs/ml_training.rst | 12 +- cookbook/flyte_tests_manifest.json | 10 +- .../aws/sagemaker_training/README.rst | 2 + .../kubernetes/kfpytorch/README.rst | 2 + 18 files changed, 825 insertions(+), 24 deletions(-) create mode 100644 cookbook/case_studies/ml_training/__init__.py create mode 100644 
cookbook/case_studies/ml_training/mnist_classifier/.gitignore create mode 100644 cookbook/case_studies/ml_training/mnist_classifier/Dockerfile create mode 100644 cookbook/case_studies/ml_training/mnist_classifier/Makefile create mode 100644 cookbook/case_studies/ml_training/mnist_classifier/README.rst create mode 100644 cookbook/case_studies/ml_training/mnist_classifier/__init__.py create mode 100644 cookbook/case_studies/ml_training/mnist_classifier/requirements.in create mode 100644 cookbook/case_studies/ml_training/mnist_classifier/requirements.txt create mode 100644 cookbook/case_studies/ml_training/mnist_classifier/sandbox.config create mode 100644 cookbook/case_studies/ml_training/mnist_classifier/single_node.py diff --git a/.github/workflows/ghcr_push.yml b/.github/workflows/ghcr_push.yml index 0f52d4d11b..caf9c11022 100644 --- a/.github/workflows/ghcr_push.yml +++ b/.github/workflows/ghcr_push.yml @@ -41,6 +41,8 @@ jobs: path: integrations/flytekit_plugins - name: house_price_prediction path: case_studies/ml_training + - name: mnist_classifier + path: case_studies/ml_training steps: - uses: actions/checkout@v2 with: diff --git a/cookbook/case_studies/ml_training/__init__.py b/cookbook/case_studies/ml_training/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/cookbook/case_studies/ml_training/mnist_classifier/.gitignore b/cookbook/case_studies/ml_training/mnist_classifier/.gitignore new file mode 100644 index 0000000000..813919db4f --- /dev/null +++ b/cookbook/case_studies/ml_training/mnist_classifier/.gitignore @@ -0,0 +1 @@ +wandb/ \ No newline at end of file diff --git a/cookbook/case_studies/ml_training/mnist_classifier/Dockerfile b/cookbook/case_studies/ml_training/mnist_classifier/Dockerfile new file mode 100644 index 0000000000..5fe6cc70ba --- /dev/null +++ b/cookbook/case_studies/ml_training/mnist_classifier/Dockerfile @@ -0,0 +1,31 @@ +FROM nvcr.io/nvidia/pytorch:21.06-py3 +LABEL org.opencontainers.image.source 
https://github.com/flyteorg/flytesnacks + +WORKDIR /root +ENV LANG C.UTF-8 +ENV LC_ALL C.UTF-8 +ENV PYTHONPATH /root + +# Give your wandb API key. Get it from https://wandb.ai/authorize. +# ENV WANDB_API_KEY your-api-key + +# Install the AWS cli for AWS support +RUN pip install awscli + +ENV VENV /opt/venv + +# Virtual environment +RUN python3 -m venv ${VENV} +ENV PATH="${VENV}/bin:$PATH" + +# Install Python dependencies +COPY mnist_classifier/requirements.txt /root +RUN pip install -r /root/requirements.txt + +# Copy the actual code +COPY mnist_classifier/ /root/mnist_classifier/ + +# This tag is supplied by the build script and will be used to determine the version +# when registering tasks, workflows, and launch plans +ARG tag +ENV FLYTE_INTERNAL_IMAGE $tag diff --git a/cookbook/case_studies/ml_training/mnist_classifier/Makefile b/cookbook/case_studies/ml_training/mnist_classifier/Makefile new file mode 100644 index 0000000000..de62aed227 --- /dev/null +++ b/cookbook/case_studies/ml_training/mnist_classifier/Makefile @@ -0,0 +1,3 @@ +PREFIX=mnist_classifier +include ../../../common/Makefile +include ../../../common/leaf.mk diff --git a/cookbook/case_studies/ml_training/mnist_classifier/README.rst b/cookbook/case_studies/ml_training/mnist_classifier/README.rst new file mode 100644 index 0000000000..e895fd5710 --- /dev/null +++ b/cookbook/case_studies/ml_training/mnist_classifier/README.rst @@ -0,0 +1,75 @@ +.. _mnist-classifier-training: + +MNIST Classification With PyTorch and W&B +----------------------------------------- + +PyTorch +======= + +`Pytorch `__ is a machine learning framework that accelerates the path from research prototyping +to production deployment. You can build *Tensors* and *Dynamic neural networks* in Python with strong GPU acceleration +using PyTorch. 
+ +In a nutshell, it is a Python package that provides two high-level features: + +- Tensor computation (like NumPy) with strong GPU acceleration +- Deep neural networks built on a tape-based autograd system + +Flyte directly has no unique understanding of PyTorch. As per Flyte, PyTorch is just a Python library. +However, when merged with Flyte, the combo helps utilize and bootstrap the infrastructure for PyTorch and ensures that things work well! +Additionally, it also offers other benefits of using tasks and workflows -- checkpointing, separation of concerns, and auto-memoization. + +Specify GPU Requirement +======================= + +One of the necessary directives applicable when working on deep learning models is explicitly requesting one or more GPUs. +This can be done by giving a simple directive to the task declaration as follows: + +.. code-block:: python + + from flytekit import Resources, task + + @task(requests=Resources(gpu="1"), limits=Resources(gpu="1")) + def my_deep_learning_task(): + ... + +.. tip:: + It is recommended to use the same ``requests`` and ``limits`` for a GPU as automatic GPU scaling is not supported. + +Moreover, to utilize the power of a GPU, ensure that your Flyte backend has GPU nodes provisioned. + +Distributed Data-Parallel Training +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Flyte also supports distributed training for PyTorch models, but this is not native. This is achieved using one of the optional plugins, such as: + +- Natively on Kubernetes using :ref:`kf-pytorch-op` +- On AWS using :ref:`aws-sagemaker` training + +*Other distributed training plugins are coming soon -- MPIOperator, Google Vertex AI, etc. You can add your favorite services, too!* + +Weights & Biases Integration +============================ + +`Weights & Biases `__, or simply, ``wandb`` helps build better models faster with experiment tracking, dataset versioning, and model management. 
+ +We'll use ``wandb`` alongside PyTorch to track our ML experiment and its concerned model parameters. + +.. note:: + Before running the example, create a ``wandb`` account and log in to access the API. + If you're running the code locally, run the command ``wandb login``. + If it's a remote cluster, you have to include the API key in the Dockerfile. + +.. _pytorch-dockerfile: + +PyTorch Dockerfile for Deployment +================================= + +It is essential to build the Dockerfile with GPU support to use a GPU within PyTorch. +The example in this section uses a simple ``nvidia-supplied GPU Docker image`` as the base, and the rest of the construction is similar to the other Dockerfiles. + +.. literalinclude:: ../../../../../case_studies/ml_training/mnist_classifier/Dockerfile + :language: docker + +.. note:: + Run your code in the ``ml_training`` directory, both locally and within the sandbox. diff --git a/cookbook/case_studies/ml_training/mnist_classifier/__init__.py b/cookbook/case_studies/ml_training/mnist_classifier/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/cookbook/case_studies/ml_training/mnist_classifier/requirements.in b/cookbook/case_studies/ml_training/mnist_classifier/requirements.in new file mode 100644 index 0000000000..65ebe53222 --- /dev/null +++ b/cookbook/case_studies/ml_training/mnist_classifier/requirements.in @@ -0,0 +1,4 @@ +-r ../../../common/requirements-common.in +torch +torchvision +wandb diff --git a/cookbook/case_studies/ml_training/mnist_classifier/requirements.txt b/cookbook/case_studies/ml_training/mnist_classifier/requirements.txt new file mode 100644 index 0000000000..1967aea3ea --- /dev/null +++ b/cookbook/case_studies/ml_training/mnist_classifier/requirements.txt @@ -0,0 +1,187 @@ +# +# This file is autogenerated by pip-compile with python 3.8 +# To update, run: +# +# /Library/Developer/CommandLineTools/usr/bin/make requirements.txt +# +attrs==21.2.0 + # via scantree +certifi==2021.5.30 + # 
via + # requests + # sentry-sdk +charset-normalizer==2.0.3 + # via requests +click==7.1.2 + # via + # flytekit + # wandb +configparser==5.0.2 + # via wandb +croniter==1.0.15 + # via flytekit +cycler==0.10.0 + # via matplotlib +dataclasses-json==0.5.4 + # via flytekit +decorator==5.0.9 + # via retry +deprecated==1.2.12 + # via flytekit +dirhash==0.2.1 + # via flytekit +docker-image-py==0.1.10 + # via flytekit +docker-pycreds==0.4.0 + # via wandb +flyteidl==0.19.14 + # via flytekit +flytekit==0.20.1 + # via -r ../../../common/requirements-common.in +gitdb==4.0.7 + # via gitpython +gitpython==3.1.18 + # via wandb +grpcio==1.39.0 + # via flytekit +idna==3.2 + # via requests +importlib-metadata==4.6.1 + # via keyring +keyring==23.0.1 + # via flytekit +kiwisolver==1.3.1 + # via matplotlib +marshmallow==3.13.0 + # via + # dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.12.0 + # via flytekit +matplotlib==3.4.2 + # via -r ../../../common/requirements-common.in +mypy-extensions==0.4.3 + # via typing-inspect +natsort==7.1.1 + # via flytekit +numpy==1.21.1 + # via + # matplotlib + # pandas + # pyarrow + # torchvision +pandas==1.3.0 + # via flytekit +pathspec==0.9.0 + # via scantree +pathtools==0.1.2 + # via wandb +pillow==8.3.1 + # via + # matplotlib + # torchvision +promise==2.3 + # via wandb +protobuf==3.17.3 + # via + # flyteidl + # flytekit + # wandb +psutil==5.8.0 + # via wandb +py==1.10.0 + # via retry +pyarrow==3.0.0 + # via flytekit +pyparsing==2.4.7 + # via matplotlib +python-dateutil==2.8.1 + # via + # croniter + # flytekit + # matplotlib + # pandas + # wandb +python-json-logger==2.0.1 + # via flytekit +pytimeparse==1.1.8 + # via flytekit +pytz==2018.4 + # via + # flytekit + # pandas +pyyaml==5.4.1 + # via wandb +regex==2021.7.6 + # via docker-image-py +requests==2.26.0 + # via + # flytekit + # responses + # wandb +responses==0.13.3 + # via flytekit +retry==0.9.2 + # via 
flytekit +scantree==0.0.1 + # via dirhash +sentry-sdk==1.3.0 + # via wandb +shortuuid==1.0.1 + # via wandb +six==1.16.0 + # via + # cycler + # docker-pycreds + # flytekit + # grpcio + # promise + # protobuf + # python-dateutil + # responses + # scantree + # wandb +smmap==4.0.0 + # via gitdb +sortedcontainers==2.4.0 + # via flytekit +statsd==3.3.0 + # via flytekit +stringcase==1.2.0 + # via dataclasses-json +subprocess32==3.5.4 + # via wandb +torch==1.9.0 + # via + # -r requirements.in + # torchvision +torchvision==0.10.0 + # via -r requirements.in +typing-extensions==3.10.0.0 + # via + # torch + # typing-inspect +typing-inspect==0.7.1 + # via dataclasses-json +urllib3==1.26.6 + # via + # flytekit + # requests + # responses + # sentry-sdk + # wandb +wandb==0.11.0 + # via -r requirements.in +wheel==0.36.2 + # via + # -r ../../../common/requirements-common.in + # flytekit +wrapt==1.12.1 + # via + # deprecated + # flytekit +zipp==3.5.0 + # via importlib-metadata diff --git a/cookbook/case_studies/ml_training/mnist_classifier/sandbox.config b/cookbook/case_studies/ml_training/mnist_classifier/sandbox.config new file mode 100644 index 0000000000..ad5f7ddb33 --- /dev/null +++ b/cookbook/case_studies/ml_training/mnist_classifier/sandbox.config @@ -0,0 +1,3 @@ +[sdk] +workflow_packages=mnist_classifier +python_venv=flytekit_venv diff --git a/cookbook/case_studies/ml_training/mnist_classifier/single_node.py b/cookbook/case_studies/ml_training/mnist_classifier/single_node.py new file mode 100644 index 0000000000..b9c4574c1c --- /dev/null +++ b/cookbook/case_studies/ml_training/mnist_classifier/single_node.py @@ -0,0 +1,406 @@ +""" +Single GPU Training +------------------- + +Training a model on a single node on one GPU is as trivial as writing any Flyte task and simply setting the GPU to ``1``. 
+As long as the Docker image is built correctly with the right version of the GPU drivers and the Flyte backend is +provisioned to have GPU machines, Flyte will execute the task on a node that has GPU(s). + +Currently, Flyte does not provide any specific task type for PyTorch (though it is entirely possible to provide a task-type +that supports *PyTorch-Ignite* or *PyTorch Lightning*, but this is not critical). One can request a GPU simply +by setting the ``gpu="1"`` resource request; the GPU will then be provisioned at runtime. + +In this example, we'll see how we can create any PyTorch model, train it using Flyte and a specialized container. +""" + +# %% +# First, let's import the libraries. +import json +import os +import typing +from dataclasses import dataclass + +import torch +import torch.nn.functional as F +import wandb +from dataclasses_json import dataclass_json +from flytekit import Resources, task, workflow +from flytekit.types.file import PythonPickledFile +from torch import distributed as dist +from torch import nn, optim +from torchvision import datasets, transforms + +# %% +# Let's define some variables to be used later. +WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1)) + +# %% +# The following variables are specific to ``wandb``: +# +# - ``NUM_BATCHES_TO_LOG``: Number of batches to log from the test data for each test step +# - ``LOG_IMAGES_PER_BATCH``: Number of images to log per test batch +NUM_BATCHES_TO_LOG = 10 +LOG_IMAGES_PER_BATCH = 32 + +# %% +# If running remotely, copy your ``wandb`` API key to the Dockerfile. Next, log in to ``wandb``. +# You can disable this if you're already logged in on your local machine. +wandb.login() + +# %% +# Next, we initialize the ``wandb`` project. +# +# .. admonition:: MUST DO! +# +# Replace ``entity`` value with your username.
+wandb.init(project="mnist-single-node", entity="your-user-name") + +# %% +# Creating the Network +# ==================== +# +# We use a simple PyTorch model with :py:class:`pytorch:torch.nn.Conv2d` and :py:class:`pytorch:torch.nn.Linear` layers. +# Let's also use :py:func:`pytorch:torch.nn.functional.relu`, :py:func:`pytorch:torch.nn.functional.max_pool2d`, and +# :py:func:`pytorch:torch.nn.functional.log_softmax` to define the forward pass. +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 20, 5, 1) + self.conv2 = nn.Conv2d(20, 50, 5, 1) + self.fc1 = nn.Linear(4 * 4 * 50, 500) + self.fc2 = nn.Linear(500, 10) + + def forward(self, x): + x = F.relu(self.conv1(x)) + x = F.max_pool2d(x, 2, 2) + x = F.relu(self.conv2(x)) + x = F.max_pool2d(x, 2, 2) + x = x.view(-1, 4 * 4 * 50) + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + + +# %% +# Training +# ======== +# +# We define a ``train`` function to enclose the training loop per epoch, i.e., this gets called for every successive epoch. +# Additionally, we log the loss and epoch progression, which can later be visualized in a ``wandb`` dashboard.
+def train(model, device, train_loader, optimizer, epoch, log_interval): + model.train() + + # hooks into the model to collect gradients and the topology + wandb.watch(model) + + # loop through the training batches + for batch_idx, (data, target) in enumerate(train_loader): + + # device conversion + data, target = data.to(device), target.to(device) + + # clear gradient + optimizer.zero_grad() + + # forward pass + output = model(data) + + # compute loss + loss = F.nll_loss(output, target) + + # propagate the loss backward + loss.backward() + + # update the model parameters + optimizer.step() + + if batch_idx % log_interval == 0: + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}".format( + epoch, + batch_idx * len(data), + len(train_loader.dataset), + 100.0 * batch_idx / len(train_loader), + loss.item(), + ) + ) + + # log epoch and loss + wandb.log({"loss": loss, "epoch": epoch}) + + +# %% +# We define a test logger function which will be called when we run the model on the test dataset. +def log_test_predictions(images, labels, outputs, predicted, my_table, log_counter): + """ + Convenience function to log predictions for a batch of test images + """ + + # obtain confidence scores for all classes + scores = F.softmax(outputs.data, dim=1) + log_scores = scores.cpu().numpy() + log_images = images.cpu().numpy() + log_labels = labels.cpu().numpy() + log_preds = predicted.cpu().numpy() + + # assign ids based on the order of the images + _id = 0 + for i, l, p, s in zip(log_images, log_labels, log_preds, log_scores): + + # add required info to data table: + # id, image pixels, model's guess, true label, scores for all classes + img_id = str(_id) + "_" + str(log_counter) + my_table.add_data(img_id, wandb.Image(i), p, l, *s) + _id += 1 + if _id == LOG_IMAGES_PER_BATCH: + break + + +# %% +# Evaluation +# ========== +# +# We define a ``test`` function to test the model on the test dataset. +# +# We log ``accuracy``, ``test_loss``, and a ``wandb`` `table `__.
+# The ``wandb`` table can help in depicting the model's performance in a structured format. +def test(model, device, test_loader): + + # ``wandb`` tabular columns + columns = ["id", "image", "guess", "truth"] + for digit in range(10): + columns.append("score_" + str(digit)) + my_table = wandb.Table(columns=columns) + + model.eval() + + # hooks into the model to collect gradients and the topology + wandb.watch(model) + + test_loss = 0 + correct = 0 + log_counter = 0 + + # disable gradient + with torch.no_grad(): + + # loop through the test data loader + for images, targets in test_loader: + + # device conversion + images, targets = images.to(device), targets.to(device) + + # forward pass -- generate predictions + outputs = model(images) + + # sum up batch loss + test_loss += F.nll_loss(outputs, targets, reduction="sum").item() + + # get the index of the max log-probability + _, predicted = torch.max(outputs.data, 1) + + # compare predictions to true label + correct += (predicted == targets).sum().item() + + # log predictions to the ``wandb`` table + if log_counter < NUM_BATCHES_TO_LOG: + log_test_predictions( + images, targets, outputs, predicted, my_table, log_counter + ) + log_counter += 1 + + # compute the average loss + test_loss /= len(test_loader.dataset) + + print("\naccuracy={:.4f}\n".format(float(correct) / len(test_loader.dataset))) + accuracy = float(correct) / len(test_loader.dataset) + + # log the average loss, accuracy, and table + wandb.log( + {"test_loss": test_loss, "accuracy": accuracy, "mnist_predictions": my_table} + ) + + return accuracy + + +# %% +# Next, we define a function that runs for every epoch. It calls the ``train`` and ``test`` functions. 
+def epoch_step( + model, device, train_loader, test_loader, optimizer, epoch, log_interval +): + train(model, device, train_loader, optimizer, epoch, log_interval) + return test(model, device, test_loader) + + +# %% +# Hyperparameters +# =============== +# +# We define a few hyperparameters for training our model. +@dataclass_json +@dataclass +class Hyperparameters(object): + """ + Args: + batch_size: input batch size for training (default: 64) + test_batch_size: input batch size for testing (default: 1000) + epochs: number of epochs to train (default: 10) + learning_rate: learning rate (default: 0.01) + sgd_momentum: SGD momentum (default: 0.5) + seed: random seed (default: 1) + log_interval: how many batches to wait before logging training status (default: 10) + backend: ``torch.distributed`` backend (default: ``dist.Backend.GLOO``) + """ + + backend: str = dist.Backend.GLOO + sgd_momentum: float = 0.5 + seed: int = 1 + log_interval: int = 10 + batch_size: int = 64 + test_batch_size: int = 1000 + epochs: int = 10 + learning_rate: float = 0.01 + + +# %% +# Training and Evaluating +# ======================= +# +# The output model using :py:func:`pytorch:torch.save` saves the `state_dict` as described +# `in pytorch docs `_. +# A common convention is to have the ``.pt`` extension for the model file. +# +# .. note:: +# Note the usage of ``requests=Resources(gpu="1")``. This will force Flyte to allocate this task onto a machine with GPU(s). +# The task will be queued up until a machine with GPU(s) can be procured. Also, for the GPU Training to work, the +# Dockerfile needs to be built as explained in the :ref:`pytorch-dockerfile` section.
+TrainingOutputs = typing.NamedTuple( + "TrainingOutputs", + epoch_accuracies=typing.List[float], + model_state=PythonPickledFile, +) + + +@task(retries=2, cache=True, cache_version="1.0", requests=Resources(gpu="1")) +def train_mnist(hp: Hyperparameters) -> TrainingOutputs: + + # store the hyperparameters' config in ``wandb`` + cfg = wandb.config + cfg.update(json.loads(hp.to_json())) + print(wandb.config) + + # set random seed + torch.manual_seed(hp.seed) + + # ideally, if GPU training is required, and if cuda is not available, we can raise an exception + # however, as we want this algorithm to work locally as well (and most users don't have a GPU locally), we will fallback to using a CPU + use_cuda = torch.cuda.is_available() + print(f"Use cuda {use_cuda}") + device = torch.device("cuda" if use_cuda else "cpu") + + print("Using device: {}, world size: {}".format(device, WORLD_SIZE)) + + # load Data + kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {} + training_data_loader = torch.utils.data.DataLoader( + datasets.MNIST( + "../data", + train=True, + download=True, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ), + ), + batch_size=hp.batch_size, + shuffle=True, + **kwargs, + ) + test_data_loader = torch.utils.data.DataLoader( + datasets.MNIST( + "../data", + train=False, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ), + ), + batch_size=hp.test_batch_size, + shuffle=False, + **kwargs, + ) + + # train the model + model = Net().to(device) + + optimizer = optim.SGD( + model.parameters(), lr=hp.learning_rate, momentum=hp.sgd_momentum + ) + + # run multiple epochs and capture the accuracies for each epoch + accuracies = [ + epoch_step( + model, + device, + train_loader=training_data_loader, + test_loader=test_data_loader, + optimizer=optimizer, + epoch=epoch, + log_interval=hp.log_interval, + ) + for epoch in range(1, hp.epochs + 1) + ] 
+ + # after training the model, we can simply save it to disk and return it from the Flyte task as a :py:class:`flytekit.types.file.FlyteFile` + # type, which is the ``PythonPickledFile``. ``PythonPickledFile`` is simply a decorator on the ``FlyteFile`` that records the format + # of the serialized model as ``pickled`` + model_file = "mnist_cnn.pt" + torch.save(model.state_dict(), model_file) + + return TrainingOutputs( + epoch_accuracies=accuracies, model_state=PythonPickledFile(model_file) + ) + + +# %% +# Finally, we define a workflow to run the training algorithm. We return the model and accuracies. +@workflow +def pytorch_training_wf( + hp: Hyperparameters, +) -> (PythonPickledFile, typing.List[float]): + accuracies, model = train_mnist(hp=hp) + return model, accuracies + + +# %% +# Running the Model Locally +# ========================= +# +# It is possible to run the model locally with almost no modifications (as long as the code takes care of resolving +# if the code is distributed or not). This is how we can do it: +if __name__ == "__main__": + model, accuracies = pytorch_training_wf( + hp=Hyperparameters(epochs=10, batch_size=128) + ) + print(f"Model: {model}, Accuracies: {accuracies}") + +# %% +# Weights & Biases Report +# ======================= +# +# Lastly, let's look at the reports that are generated by the model. +# +# .. figure:: https://raw.githubusercontent.com/flyteorg/flyte/static-resources/img/flytesnacks/pytorch/single-node/wandb_graphs.png +# :alt: Wandb Graphs +# :class: with-shadow +# +# Wandb Graphs +# +# .. figure:: https://raw.githubusercontent.com/flyteorg/flyte/static-resources/img/flytesnacks/pytorch/single-node/wandb_table.png +# :alt: Wandb Table +# :class: with-shadow +# +# Wandb Table +# +# You can refer to the complete ``wandb`` report `here `__. +# +# .. tip:: +# A lot more customizations can be done to the report according to your requirement! 
diff --git a/cookbook/dev-requirements.in b/cookbook/dev-requirements.in index dc9a62fa47..b339797a12 100644 --- a/cookbook/dev-requirements.in +++ b/cookbook/dev-requirements.in @@ -1,12 +1,14 @@ --r ./integrations/kubernetes/pod/requirements.in --r ./integrations/kubernetes/k8s_spark/requirements.in +-r ./docs-requirements.in -r ./integrations/aws/sagemaker_training/requirements.in --r ./integrations/kubernetes/kfpytorch/requirements.in -r ./integrations/aws/sagemaker_pytorch/requirements.in -r ./integrations/external_services/hive/requirements.in -r ./integrations/flytekit_plugins/dolt/requirements.in +-r ./integrations/kubernetes/pod/requirements.in +-r ./integrations/kubernetes/k8s_spark/requirements.in +-r ./integrations/kubernetes/kfpytorch/requirements.in +-r ./case_studies/ml_training/house_price_prediction/requirements.in +-r ./case_studies/ml_training/mnist_classifier/requirements.in -r ./case_studies/ml_training/pima_diabetes/requirements.in --r ./docs-requirements.in black==19.10b0 coverage diff --git a/cookbook/dev-requirements.txt b/cookbook/dev-requirements.txt index 31b018f6a0..6cbcf7a7e0 100644 --- a/cookbook/dev-requirements.txt +++ b/cookbook/dev-requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.9 +# This file is autogenerated by pip-compile with python 3.8 # To update, run: # # /Library/Developer/CommandLineTools/usr/bin/make dev-requirements.txt @@ -52,6 +52,7 @@ certifi==2021.5.30 # via # kubernetes # requests + # sentry-sdk cffi==1.14.6 # via # bcrypt @@ -63,6 +64,9 @@ click==7.1.2 # via # black # flytekit + # wandb +configparser==5.0.2 + # via wandb coverage==5.5 # via -r dev-requirements.in croniter==1.0.15 @@ -85,6 +89,8 @@ dirhash==0.2.1 # via flytekit docker-image-py==0.1.12 # via flytekit +docker-pycreds==0.4.0 + # via wandb docutils==0.16 # via # -r ./docs-requirements.in @@ -112,6 +118,8 @@ flyteidl==0.19.15 flytekit==0.20.1 # via # -r ././common/requirements-common.in + # -r 
./case_studies/ml_training/house_price_prediction/../../../common/requirements-common.in + # -r ./case_studies/ml_training/mnist_classifier/../../../common/requirements-common.in # -r ./case_studies/ml_training/pima_diabetes/../../../common/requirements-common.in # -r ./integrations/aws/sagemaker_pytorch/../../../common/requirements-common.in # -r ./integrations/aws/sagemaker_training/../../../common/requirements-common.in @@ -148,6 +156,10 @@ gast==0.4.0 # via tensorflow gevent==21.1.2 # via sagemaker-training +gitdb==4.0.7 + # via gitpython +gitpython==3.1.18 + # via wandb google-auth==1.33.1 # via # google-auth-oauthlib @@ -174,6 +186,8 @@ imagesize==1.2.0 # via sphinx importlib-metadata==4.6.1 # via keyring +importlib-resources==5.2.0 + # via tensorflow-datasets iniconfig==1.1.1 # via pytest inotify_simple==1.2.1 @@ -192,6 +206,7 @@ jmespath==0.10.0 # botocore joblib==1.0.1 # via + # -r ./case_studies/ml_training/house_price_prediction/requirements.in # -r ./case_studies/ml_training/pima_diabetes/requirements.in # scikit-learn keras-nightly==2.5.0.dev2021032900 @@ -222,6 +237,9 @@ marshmallow-jsonschema==0.12.0 matplotlib==3.4.2 # via # -r ././common/requirements-common.in + # -r ./case_studies/ml_training/house_price_prediction/../../../common/requirements-common.in + # -r ./case_studies/ml_training/house_price_prediction/requirements.in + # -r ./case_studies/ml_training/mnist_classifier/../../../common/requirements-common.in # -r ./case_studies/ml_training/pima_diabetes/../../../common/requirements-common.in # -r ./case_studies/ml_training/pima_diabetes/requirements.in # -r ./integrations/aws/sagemaker_pytorch/../../../common/requirements-common.in @@ -254,6 +272,7 @@ numpy==1.19.5 # tensorboardx # tensorflow # tensorflow-datasets + # torchvision # xgboost oauthlib==3.1.1 # via requests-oauthlib @@ -273,14 +292,19 @@ pathspec==0.9.0 # via # black # scantree +pathtools==0.1.2 + # via wandb pillow==8.3.1 # via # -r ./docs-requirements.in # matplotlib + # 
torchvision pluggy==0.13.1 # via pytest promise==2.3 - # via tensorflow-datasets + # via + # tensorflow-datasets + # wandb protobuf==3.17.3 # via # flyteidl @@ -292,8 +316,11 @@ protobuf==3.17.3 # tensorflow # tensorflow-datasets # tensorflow-metadata + # wandb psutil==5.8.0 - # via sagemaker-training + # via + # sagemaker-training + # wandb py==1.10.0 # via # pytest @@ -339,6 +366,7 @@ python-dateutil==2.8.1 # kubernetes # matplotlib # pandas + # wandb python-json-logger==2.0.1 # via flytekit pytimeparse==1.1.8 @@ -352,6 +380,7 @@ pyyaml==5.4.1 # via # kubernetes # sphinx-autoapi + # wandb regex==2021.7.6 # via # black @@ -365,6 +394,7 @@ requests==2.26.0 # sphinx # tensorboard # tensorflow-datasets + # wandb requests-oauthlib==1.3.0 # via # google-auth-oauthlib @@ -390,6 +420,10 @@ scipy==1.7.0 # sagemaker-training # scikit-learn # xgboost +sentry-sdk==1.3.1 + # via wandb +shortuuid==1.0.1 + # via wandb six==1.15.0 # via # absl-py @@ -397,6 +431,7 @@ six==1.15.0 # astunparse # bcrypt # cycler + # docker-pycreds # flytekit # google-auth # google-pasta @@ -415,8 +450,13 @@ six==1.15.0 # sphinxext-remoteliteralinclude # tensorflow # tensorflow-datasets + # wandb sklearn==0.0 - # via -r ./case_studies/ml_training/pima_diabetes/requirements.in + # via + # -r ./case_studies/ml_training/house_price_prediction/requirements.in + # -r ./case_studies/ml_training/pima_diabetes/requirements.in +smmap==4.0.0 + # via gitdb snowballstemmer==2.1.0 # via sphinx sortedcontainers==2.4.0 @@ -475,8 +515,12 @@ statsd==3.3.0 # via flytekit stringcase==1.2.0 # via dataclasses-json +subprocess32==3.5.4 + # via wandb tabulate==0.8.9 - # via -r ./case_studies/ml_training/pima_diabetes/requirements.in + # via + # -r ./case_studies/ml_training/house_price_prediction/requirements.in + # -r ./case_studies/ml_training/pima_diabetes/requirements.in tensorboard==2.5.0 # via tensorflow tensorboard-data-server==0.6.1 @@ -508,6 +552,12 @@ toml==0.10.2 # black # flake8-black # pytest +torch==1.9.0 + # 
via + # -r ./case_studies/ml_training/mnist_classifier/requirements.in + # torchvision +torchvision==0.10.0 + # via -r ./case_studies/ml_training/mnist_classifier/requirements.in tqdm==4.61.2 # via tensorflow-datasets typed-ast==1.4.3 @@ -515,6 +565,7 @@ typed-ast==1.4.3 typing-extensions==3.7.4.3 # via # tensorflow + # torch # typing-inspect typing-inspect==0.7.1 # via dataclasses-json @@ -527,6 +578,10 @@ urllib3==1.26.6 # kubernetes # requests # responses + # sentry-sdk + # wandb +wandb==0.11.0 + # via -r ./case_studies/ml_training/mnist_classifier/requirements.in websocket-client==1.1.0 # via kubernetes werkzeug==2.0.1 @@ -536,6 +591,8 @@ werkzeug==2.0.1 wheel==0.36.2 # via # -r ././common/requirements-common.in + # -r ./case_studies/ml_training/house_price_prediction/../../../common/requirements-common.in + # -r ./case_studies/ml_training/mnist_classifier/../../../common/requirements-common.in # -r ./case_studies/ml_training/pima_diabetes/../../../common/requirements-common.in # -r ./integrations/aws/sagemaker_pytorch/../../../common/requirements-common.in # -r ./integrations/aws/sagemaker_training/../../../common/requirements-common.in @@ -555,9 +612,13 @@ wrapt==1.12.1 # flytekit # tensorflow xgboost==1.4.2 - # via -r ./case_studies/ml_training/pima_diabetes/requirements.in + # via + # -r ./case_studies/ml_training/house_price_prediction/requirements.in + # -r ./case_studies/ml_training/pima_diabetes/requirements.in zipp==3.5.0 - # via importlib-metadata + # via + # importlib-metadata + # importlib-resources zope.event==4.5.0 # via gevent zope.interface==5.4.0 diff --git a/cookbook/docs/conf.py b/cookbook/docs/conf.py index 09e0fda7db..e5d84ba78e 100644 --- a/cookbook/docs/conf.py +++ b/cookbook/docs/conf.py @@ -130,6 +130,7 @@ class CustomSorter(FileNameSortKey): "multiregion_house_price_predictor.py", "datacleaning_tasks.py", "datacleaning_workflow.py", + "single_node.py", ] """ Take a look at the code for the default sorter included in the sphinx_gallery 
to see how this works. @@ -238,6 +239,7 @@ def __call__(self, filename): "../core/type_system", "../case_studies/ml_training/pima_diabetes", "../case_studies/ml_training/house_price_prediction", + "../case_studies/ml_training/mnist_classifier", "../case_studies/feature_engineering/sqlite_datacleaning", "../testing", "../core/containerization", @@ -264,6 +266,7 @@ def __call__(self, filename): "auto/core/type_system", "auto/case_studies/ml_training/pima_diabetes", "auto/case_studies/ml_training/house_price_prediction", + "auto/case_studies/ml_training/mnist_classifier", "auto/case_studies/feature_engineering/sqlite_datacleaning", "auto/testing", "auto/core/containerization", @@ -362,10 +365,10 @@ def hide_example_page(file_handler): no_imports = False return ( - example_content.startswith('"""') - and example_content.endswith('"""') - and no_percent_comments - and no_imports + example_content.startswith('"""') + and example_content.endswith('"""') + and no_percent_comments + and no_imports ) @@ -382,10 +385,10 @@ def hide_example_page(file_handler): if hide_example_page(fh): page_id = ( str(f) - .replace("..", "auto") - .replace("/", "-") - .replace(".", "-") - .replace("_", "-") + .replace("..", "auto") + .replace("/", "-") + .replace(".", "-") + .replace("_", "-") ) hide_download_page_ids.append(f"sphx-glr-download-{page_id}") @@ -419,12 +422,13 @@ def hide_example_page(file_handler): # "flytekit": ("/Users/ytong/go/src/github.com/lyft/flytekit/docs/build/html", None), "flyteidl": ("https://docs.flyte.org/projects/flyteidl/en/latest", None), "flytectl": ("https://docs.flyte.org/projects/flytectl/en/latest/", None), + "pytorch": ("https://pytorch.org/docs/stable/", None), } # Sphinx-tabs config -sphinx_tabs_valid_builders = ['linkcheck'] +sphinx_tabs_valid_builders = ["linkcheck"] # Sphinx-mermaid config -mermaid_output_format = 'raw' -mermaid_version = 'latest' +mermaid_output_format = "raw" +mermaid_version = "latest" mermaid_init_js = 
"mermaid.initialize({startOnLoad:false});" diff --git a/cookbook/docs/ml_training.rst b/cookbook/docs/ml_training.rst index b5a5ecb243..55c222a126 100644 --- a/cookbook/docs/ml_training.rst +++ b/cookbook/docs/ml_training.rst @@ -19,7 +19,16 @@ ML Training :text: House Price Regression :classes: btn-block stretched-link ^^^^^^^^^^^^ - Use dynamic workflows to train a multiregion house price prediction model. + Use dynamic workflows to train a multiregion house price prediction model using XGBoost + + --- + + .. link-button:: auto/case_studies/ml_training/mnist_classifier/index + :type: ref + :text: MNIST Classification + :classes: btn-block stretched-link + ^^^^^^^^^^^^ + Train a neural network on MNIST with PyTorch and W&B .. toctree:: @@ -29,6 +38,7 @@ ML Training auto/case_studies/ml_training/pima_diabetes/index auto/case_studies/ml_training/house_price_prediction/index + auto/case_studies/ml_training/mnist_classifier/index .. TODO: write tutorials for data parallel training, distributed training, and single node training diff --git a/cookbook/flyte_tests_manifest.json b/cookbook/flyte_tests_manifest.json index b3001891a7..ffae38df19 100644 --- a/cookbook/flyte_tests_manifest.json +++ b/cookbook/flyte_tests_manifest.json @@ -86,5 +86,13 @@ "exit_success": true, "exit_message": "" } - }] + },{ + "name": "case-studies-mnist-classifier", + "priority": "P2", + "path": "case_studies/ml_training/mnist_classifier", + "exitCondition": { + "exit_success": true, + "exit_message": "" + } +}] diff --git a/cookbook/integrations/aws/sagemaker_training/README.rst b/cookbook/integrations/aws/sagemaker_training/README.rst index 3c4c41d1e8..9e7789e658 100644 --- a/cookbook/integrations/aws/sagemaker_training/README.rst +++ b/cookbook/integrations/aws/sagemaker_training/README.rst @@ -1,3 +1,5 @@ +.. 
_aws-sagemaker: + Sagemaker ========= This section provides examples of Flyte Plugins that are designed to work with diff --git a/cookbook/integrations/kubernetes/kfpytorch/README.rst b/cookbook/integrations/kubernetes/kfpytorch/README.rst index 733cdedd22..a15dbb9227 100644 --- a/cookbook/integrations/kubernetes/kfpytorch/README.rst +++ b/cookbook/integrations/kubernetes/kfpytorch/README.rst @@ -1,3 +1,5 @@ +.. _kf-pytorch-op: + Pytorch Operator =================