From 942ee876994d8a09a1e897e4248f0e43da528de0 Mon Sep 17 00:00:00 2001 From: Niels Bantilan Date: Tue, 17 Aug 2021 12:01:10 -0400 Subject: [PATCH] Single node GPU training example (#333) (#352) * update pytorch multi-gpu example, incorporate comments @samhita-alla @kumare3 Signed-off-by: Niels Bantilan * Apply suggestions from code review Co-authored-by: Samhita Alla Signed-off-by: Niels Bantilan Co-authored-by: Samhita Alla --- .../ml_training/mnist_classifier/Dockerfile | 14 +- ...node.py => pytorch_single_node_and_gpu.py} | 207 ++++------ .../pytorch_single_node_multi_gpu.py | 382 ++++++++++++++++++ 3 files changed, 466 insertions(+), 137 deletions(-) rename cookbook/case_studies/ml_training/mnist_classifier/{single_node.py => pytorch_single_node_and_gpu.py} (72%) create mode 100644 cookbook/case_studies/ml_training/mnist_classifier/pytorch_single_node_multi_gpu.py diff --git a/cookbook/case_studies/ml_training/mnist_classifier/Dockerfile b/cookbook/case_studies/ml_training/mnist_classifier/Dockerfile index 5fe6cc70ba..bd076c758f 100644 --- a/cookbook/case_studies/ml_training/mnist_classifier/Dockerfile +++ b/cookbook/case_studies/ml_training/mnist_classifier/Dockerfile @@ -1,4 +1,4 @@ -FROM nvcr.io/nvidia/pytorch:21.06-py3 +FROM pytorch/pytorch:1.9.0-cuda10.2-cudnn7-runtime LABEL org.opencontainers.image.source https://github.com/flyteorg/flytesnacks WORKDIR /root @@ -6,15 +6,15 @@ ENV LANG C.UTF-8 ENV LC_ALL C.UTF-8 ENV PYTHONPATH /root -# Give your wandb API key. Get it from https://wandb.ai/authorize. -# ENV WANDB_API_KEY your-api-key +# Set your wandb API key and user name. Get the API key from https://wandb.ai/authorize. +# ENV WANDB_API_KEY +# ENV WANDB_USERNAME # Install the AWS cli for AWS support RUN pip install awscli -ENV VENV /opt/venv - # Virtual environment +ENV VENV /opt/venv RUN python3 -m venv ${VENV} ENV PATH="${VENV}/bin:$PATH" @@ -25,6 +25,10 @@ RUN pip install -r /root/requirements.txt # Copy the actual code COPY mnist_classifier/ /root/mnist_classifier/ +# Copy the makefile targets to expose on the container. This makes it easier to register. +COPY in_container.mk /root/Makefile +COPY mnist_classifier/sandbox.config /root + # This tag is supplied by the build script and will be used to determine the version # when registering tasks, workflows, and launch plans ARG tag diff --git a/cookbook/case_studies/ml_training/mnist_classifier/single_node.py b/cookbook/case_studies/ml_training/mnist_classifier/pytorch_single_node_and_gpu.py similarity index 72% rename from cookbook/case_studies/ml_training/mnist_classifier/single_node.py rename to cookbook/case_studies/ml_training/mnist_classifier/pytorch_single_node_and_gpu.py index b9c4574c1c..b53e1513fa 100644 --- a/cookbook/case_studies/ml_training/mnist_classifier/single_node.py +++ b/cookbook/case_studies/ml_training/mnist_classifier/pytorch_single_node_and_gpu.py @@ -1,6 +1,6 @@ """ -Single GPU Training -------------------- +Single Node, Single GPU Training +-------------------------------- Training a model on a single node on one GPU is as trivial as writing any Flyte task and simply setting the GPU to ``1``. As long as the Docker image is built correctly with the right version of the GPU drivers and the Flyte backend is @@ -31,11 +31,7 @@ from torchvision import datasets, transforms # %% -# Let's define some variables to be used later. -WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1)) - -# %% -# The following variables are specific to ``wandb``: +# Let's define some variables to be used later. The following variables are specific to ``wandb``: # # - ``NUM_BATCHES_TO_LOG``: Number of batches to log from the test data for each test step # - ``LOG_IMAGES_PER_BATCH``: Number of images to log per test batch @@ -43,17 +39,14 @@ LOG_IMAGES_PER_BATCH = 32 # %% -# If running remotely, copy your ``wandb`` API key to the Dockerfile. Next, login to ``wandb``. -# You can disable this if you're already logged in on your local machine. -wandb.login() - -# %% -# Next, we initialize the ``wandb`` project. +# If running remotely, copy your ``wandb`` API key to the Dockerfile under the environment variable ``WANDB_API_KEY``. +# This function logs into ``wandb`` and initializes the project. If you built your Docker image with the +# ``WANDB_USERNAME``, this will work. Otherwise, replace ``my-user-name`` with your ``wandb`` user name. # -# .. admonition:: MUST DO! -# -# Replace ``entity`` value with your username. -wandb.init(project="mnist-single-node", entity="your-user-name") +# We'll call this function in the ``pytorch_mnist_task`` defined below. +def wandb_setup(): + wandb.login() + wandb.init(project="mnist-single-node-single-gpu", entity=os.environ.get("WANDB_USERNAME", "my-user-name")) # %% # Creating the Network @@ -81,6 +74,26 @@ def forward(self, x): return F.log_softmax(x, dim=1) +# %% +# The Data Loader +# =============== + +def mnist_dataloader(batch_size, train=True, **kwargs): + return torch.utils.data.DataLoader( + datasets.MNIST( + "./data", + train=train, + download=True, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ), + ), + batch_size=batch_size, + shuffle=True, + **kwargs, + ) + + # %% # Training # ======== @@ -95,24 +108,12 @@ def train(model, device, train_loader, optimizer, epoch, log_interval): # loop through the training batches for batch_idx, (data, target) in enumerate(train_loader): - - # device conversion - data, target = data.to(device), target.to(device) - - # clear gradient - optimizer.zero_grad() - - # forward pass - output = model(data) - - # compute loss - loss = F.nll_loss(output, target) - - # propagate the loss backward - loss.backward() - - # update the model parameters - optimizer.step() + data, target = data.to(device), target.to(device) # device conversion + optimizer.zero_grad() # clear gradient + output = model(data) # forward pass + loss = F.nll_loss(output, target) # compute loss + loss.backward() # propagate the loss backward + optimizer.step() # update the model parameters if batch_idx % log_interval == 0: print( @@ -133,26 +134,19 @@ def train(model, device, train_loader, optimizer, epoch, log_interval): # We define a test logger function which will be called when we run the model on test dataset. def log_test_predictions(images, labels, outputs, predicted, my_table, log_counter): """ - Convenience funtion to log predictions for a batch of test images + Convenience function to log predictions for a batch of test images """ # obtain confidence scores for all classes scores = F.softmax(outputs.data, dim=1) - log_scores = scores.cpu().numpy() - log_images = images.cpu().numpy() - log_labels = labels.cpu().numpy() - log_preds = predicted.cpu().numpy() # assign ids based on the order of the images - _id = 0 - for i, l, p, s in zip(log_images, log_labels, log_preds, log_scores): - - # add required info to data table: - # id, image pixels, model's guess, true label, scores for all classes - img_id = str(_id) + "_" + str(log_counter) - my_table.add_data(img_id, wandb.Image(i), p, l, *s) - _id += 1 - if _id == LOG_IMAGES_PER_BATCH: + for i, (image, pred, label, score) in enumerate( + zip(*[x.cpu().numpy() for x in (images, predicted, labels, scores)]) + ): + # add required info to data table: id, image pixels, model's guess, true label, scores for all classes + my_table.add_data(f"{i}_{log_counter}", wandb.Image(image), pred, label, *score) + if i == LOG_IMAGES_PER_BATCH: break @@ -186,21 +180,11 @@ def test(model, device, test_loader): # loop through the test data loader for images, targets in test_loader: - - # device conversion - images, targets = images.to(device), targets.to(device) - - # forward pass -- generate predictions - outputs = model(images) - - # sum up batch loss - test_loss += F.nll_loss(outputs, targets, reduction="sum").item() - - # get the index of the max log-probability - _, predicted = torch.max(outputs.data, 1) - - # compare predictions to true label - correct += (predicted == targets).sum().item() + images, targets = images.to(device), targets.to(device) # device conversion + outputs = model(images) # forward pass -- generate predictions + test_loss += F.nll_loss(outputs, targets, reduction="sum").item() # sum up batch loss + _, predicted = torch.max(outputs.data, 1) # get the index of the max log-probability + correct += (predicted == targets).sum().item() # compare predictions to true label # log predictions to the ``wandb`` table if log_counter < NUM_BATCHES_TO_LOG: @@ -216,22 +200,11 @@ def test(model, device, test_loader): accuracy = float(correct) / len(test_loader.dataset) # log the average loss, accuracy, and table - wandb.log( - {"test_loss": test_loss, "accuracy": accuracy, "mnist_predictions": my_table} - ) + wandb.log({"test_loss": test_loss, "accuracy": accuracy, "mnist_predictions": my_table}) return accuracy -# %% -# Next, we define a function that runs for every epoch. It calls the ``train`` and ``test`` functions. -def epoch_step( - model, device, train_loader, test_loader, optimizer, epoch, log_interval -): - train(model, device, train_loader, optimizer, epoch, log_interval) - return test(model, device, test_loader) - - # %% # Hyperparameters # =============== @@ -242,14 +215,14 @@ def epoch_step( class Hyperparameters(object): """ Args: + backend: pytorch backend to use, e.g. "gloo" or "nccl" + sgd_momentum: SGD momentum (default: 0.5) + seed: random seed (default: 1) + log_interval: how many batches to wait before logging training status batch_size: input batch size for training (default: 64) test_batch_size: input batch size for testing (default: 1000) epochs: number of epochs to train (default: 10) learning_rate: learning rate (default: 0.01) - sgd_momentum: SGD momentum (default: 0.5) - seed: random seed (default: 1) - log_interval: how many batches to wait before logging training status - dir: directory where summary logs are stored """ backend: str = dist.Backend.GLOO @@ -281,13 +254,18 @@ class Hyperparameters(object): ) -@task(retries=2, cache=True, cache_version="1.0", requests=Resources(gpu="1")) -def train_mnist(hp: Hyperparameters) -> TrainingOutputs: +@task( + retries=2, + cache=True, + cache_version="1.0", + requests=Resources(gpu="1", mem="3Gi", storage="1Gi"), + limits=Resources(gpu="1", mem="3Gi", storage="1Gi"), +) +def pytorch_mnist_task(hp: Hyperparameters) -> TrainingOutputs: + wandb_setup() # store the hyperparameters' config in ``wandb`` - cfg = wandb.config - cfg.update(json.loads(hp.to_json())) - print(wandb.config) + wandb.config.update(json.loads(hp.to_json())) # set random seed torch.manual_seed(hp.seed) @@ -298,35 +276,10 @@ def train_mnist(hp: Hyperparameters) -> TrainingOutputs: print(f"Use cuda {use_cuda}") device = torch.device("cuda" if use_cuda else "cpu") - print("Using device: {}, world size: {}".format(device, WORLD_SIZE)) - - # load Data + # load data kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {} - training_data_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "../data", - train=True, - download=True, - transform=transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] - ), - ), - batch_size=hp.batch_size, - shuffle=True, - **kwargs, - ) - test_data_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "../data", - train=False, - transform=transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] - ), - ), - batch_size=hp.test_batch_size, - shuffle=False, - **kwargs, - ) + training_data_loader = mnist_dataloader(hp.batch_size, train=True, **kwargs) + test_data_loader = mnist_dataloader(hp.batch_size, train=False, **kwargs) # train the model model = Net().to(device) @@ -336,18 +289,11 @@ def train_mnist(hp: Hyperparameters) -> TrainingOutputs: ) # run multiple epochs and capture the accuracies for each epoch - accuracies = [ - epoch_step( - model, - device, - train_loader=training_data_loader, - test_loader=test_data_loader, - optimizer=optimizer, - epoch=epoch, - log_interval=hp.log_interval, - ) - for epoch in range(1, hp.epochs + 1) - ] + # train the model: run multiple epochs and capture the accuracies for each epoch + accuracies = [] + for epoch in range(1, hp.epochs + 1): + train(model, device, training_data_loader, optimizer, epoch, hp.log_interval) + accuracies.append(test(model, device, test_data_loader)) # after training the model, we can simply save it to disk and return it from the Flyte task as a :py:class:`flytekit.types.file.FlyteFile` # type, which is the ``PythonPickledFile``. ``PythonPickledFile`` is simply a decorator on the ``FlyteFile`` that records the format @@ -364,10 +310,9 @@ def train_mnist(hp: Hyperparameters) -> TrainingOutputs: # Finally, we define a workflow to run the training algorithm. We return the model and accuracies. @workflow def pytorch_training_wf( - hp: Hyperparameters, -) -> (PythonPickledFile, typing.List[float]): - accuracies, model = train_mnist(hp=hp) - return model, accuracies + hp: Hyperparameters = Hyperparameters(epochs=10, batch_size=128) +) -> TrainingOutputs: + return pytorch_mnist_task(hp=hp) # %% @@ -377,9 +322,7 @@ def pytorch_training_wf( # It is possible to run the model locally with almost no modifications (as long as the code takes care of resolving # if the code is distributed or not). This is how we can do it: if __name__ == "__main__": - model, accuracies = pytorch_training_wf( - hp=Hyperparameters(epochs=10, batch_size=128) - ) + model, accuracies = pytorch_training_wf(hp=Hyperparameters(epochs=10, batch_size=128)) print(f"Model: {model}, Accuracies: {accuracies}") # %% diff --git a/cookbook/case_studies/ml_training/mnist_classifier/pytorch_single_node_multi_gpu.py b/cookbook/case_studies/ml_training/mnist_classifier/pytorch_single_node_multi_gpu.py new file mode 100644 index 0000000000..7710c2cbcf --- /dev/null +++ b/cookbook/case_studies/ml_training/mnist_classifier/pytorch_single_node_multi_gpu.py @@ -0,0 +1,382 @@ +""" +Single Node, Multi GPU Training +-------------------------------- + +When you need to scale up model training in pytorch, you can use the :py:class:`~pytorch:torch.nn.DataParallel` for +single node, multi-gpu/cpu training or :py:class:`~pytorch:torch.nn.parallel.DistributedDataParallel` for multi-node, +multi-gpu training. + +This tutorial will cover how to write a simple training script on the MNIST dataset that uses +``DistributedDataParallel`` since its functionality is a superset of ``DataParallel``, supporting both single- and +multi-node training, and this is the `recommended way `__ +of distributing your training workload. Note, however, that this tutorial will only work for single-node, multi-gpu +settings. + +For training on a single node and gpu see +:ref:`this tutorial `, and for more +information on distributed training, check out the +`pytorch documentation `__. +""" + +# %% +# Import the required libraries. +import json +import os +import typing +from dataclasses import dataclass + +import torch +import torch.nn.functional as F +import wandb +from flytekit import Resources, task, workflow +from flytekit.types.file import PythonPickledFile +from torch import distributed as dist +from torch import nn, multiprocessing as mp, optim +from torchvision import datasets, transforms + +# %% +# We'll re-use certain classes and functions from the +# :ref:`single node and gpu tutorial ` +# such as the ``Net`` model architecture, ``Hyperparameters``, and ``log_test_predictions``. +from mnist_classifier.pytorch_single_node_and_gpu import Net, Hyperparameters, log_test_predictions + +# %% +# Let's define some variables to be used later. +# +# ``WORLD_SIZE`` defines the total number of GPUs we want to use to distribute our training job and ``DATA_DIR`` +# specifies where the downloaded data should be written to. +WORLD_SIZE = 2 +DATA_DIR = "./data" + +# %% +# The following variables are specific to ``wandb``: +# +# - ``NUM_BATCHES_TO_LOG``: Number of batches to log from the test data for each test step +# - ``LOG_IMAGES_PER_BATCH``: Number of images to log per test batch +NUM_BATCHES_TO_LOG = 10 +LOG_IMAGES_PER_BATCH = 32 + +# %% +# If running remotely, copy your ``wandb`` API key to the Dockerfile under the environment variable ``WANDB_API_KEY``. +# This function logs into ``wandb`` and initializes the project. If you built your Docker image with the +# ``WANDB_USERNAME``, this will work. Otherwise, replace ``my-user-name`` with your ``wandb`` user name. +# +# We'll call this function in the ``pytorch_mnist_task`` defined below. +def wandb_setup(): + wandb.login() + wandb.init(project="mnist-single-node-multi-gpu", entity=os.environ.get("WANDB_USERNAME", "my-user-name")) + +# %% +# Re-using the Network from the Single GPU Example +# ================================================ +# +# We'll use the same neural network architecture as the one we define in the +# :ref:`single node and gpu tutorial `. + + +# %% +# Data Downloader +# =============== +# +# We'll use this helper function to download the training and test sets before-hand to avoid race conditions when +# initializing the train and test dataloaders during training. +def download_mnist(data_dir): + for train in [True, False]: + datasets.MNIST(data_dir, train=train, download=True) + +# %% +# The Data Loader +# =============== +# +# This function will be called in the training function to be distributed across all available GPUs. Note that +# we set ``download=False`` here to avoid race conditions as mentioned above. +def mnist_dataloader(data_dir, batch_size, train=True, distributed=False, rank=None, world_size=None, **kwargs): + dataset = datasets.MNIST( + data_dir, + train=train, + download=False, + transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307), (0.3081))]), + ) + if distributed: + assert rank is not None, "rank needs to be specified when doing distributed training." + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, rank=rank, num_replicas=1 if world_size is None else world_size, shuffle=True + ) + else: + sampler = None + return torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + sampler=sampler, + **kwargs, + ) + + +# %% +# Training +# ======== +# +# We define a ``train`` function to enclose the training loop per epoch, and we log the loss and epoch progression, +# which can later be visualized in a ``wandb`` dashboard. +def train(model, rank, train_loader, optimizer, epoch, log_interval): + model.train() + + # hooks into the model to collect gradients and the topology + if rank == 0: + wandb.watch(model) + + # loop through the training batches + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(rank), target.to(rank) # device conversion + optimizer.zero_grad() # clear gradient + output = model(data) # forward pass + loss = F.nll_loss(output, target) # compute loss + loss.backward() # propagate the loss backward + optimizer.step() # update the model parameters + + if rank == 0 and batch_idx % log_interval == 0: + # log epoch and loss + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}".format( + epoch, + batch_idx * len(data), + len(train_loader.dataset), + 100.0 * batch_idx / len(train_loader), + loss.item(), + ) + ) + wandb.log({"loss": loss, "epoch": epoch}) + + +# %% +# Evaluation +# ========== +# +# We define a ``test`` function to test the model on the test dataset, logging ``accuracy``, and ``test_loss`` to a +# ``wandb`` `table `__, which helps us visualize the model's +# performance in a structured format. +def test(model, rank, test_loader): + + model.eval() + + # define ``wandb`` tabular columns and hooks into the model to collect gradients and the topology + columns = ["id", "image", "guess", "truth", *[f"score_{i}" for i in range(10)]] + if rank == 0: + my_table = wandb.Table(columns=columns) + wandb.watch(model) + + test_loss = 0 + correct = 0 + log_counter = 0 + + # disable gradient + with torch.no_grad(): + + # loop through the test data loader + total = 0. + for images, targets in test_loader: + total += len(targets) + images, targets = images.to(rank), targets.to(rank) # device conversion + outputs = model(images) # forward pass -- generate predictions + test_loss += F.nll_loss(outputs, targets, reduction="sum").item() # sum up batch loss + _, predicted = torch.max(outputs.data, 1) # get the index of the max log-probability + correct += (predicted == targets).sum().item() # compare predictions to true label + + # log predictions to the ``wandb`` table + if log_counter < NUM_BATCHES_TO_LOG: + if rank == 0: + log_test_predictions(images, targets, outputs, predicted, my_table, log_counter) + log_counter += 1 + + # compute the average loss + test_loss /= total + accuracy = float(correct) / total + + if rank == 0: + print("\ntest_loss={:.4f}\naccuracy={:.4f}\n".format(test_loss, accuracy)) + # log the average loss, accuracy, and table + wandb.log({"test_loss": test_loss, "accuracy": accuracy, "mnist_predictions": my_table}) + + return accuracy + + +# %% +# Training and Evaluating +# ======================= + +TrainingOutputs = typing.NamedTuple( + "TrainingOutputs", + epoch_accuracies=typing.List[float], + model_state=PythonPickledFile, +) + + +# %% +# Setting up Distributed Training +# =============================== +# +# ``dist_setup`` is a helper function that instantiates a distributed environment. We're pointing all of the +# processes across all available GPUs to the address of the main process. + +def dist_setup(rank, world_size, backend): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "8888" + dist.init_process_group(backend, rank=rank, world_size=world_size) + + +# %% +# These global variables point to the location of where to save the model and validation accuracies. +MODEL_FILE = "./mnist_cnn.pt" +ACCURACIES_FILE = "./mnist_cnn_accuracies.json" + +# %% +# Then we define the ``train_mnist`` function. Note the conditionals that check for ``rank == 0``. These parts of the +# functions are only called in the main process, which is the ``0``th rank. The reason for this is that we only want the +# main process to perform certain actions such as: +# +# - log metrics via ``wandb`` +# - save the trained model to disk +# - keep track of validation metrics + +def train_mnist(rank: int, world_size: int, hp: Hyperparameters): + + # store the hyperparameters' config in ``wandb`` + if rank == 0: + wandb_setup() + wandb.config.update(json.loads(hp.to_json())) + + # set random seed + torch.manual_seed(hp.seed) + + use_cuda = torch.cuda.is_available() + print(f"Using distributed PyTorch with {hp.backend} backend") + print(f"Running MNIST training on rank {rank}, world size: {world_size}") + print(f"Use cuda: {use_cuda}") + dist_setup(rank, world_size, hp.backend) + print(f"Rank {rank + 1}/{world_size} process initialized.\n") + + # load data + kwargs = {"num_workers": 0, "pin_memory": True} if use_cuda else {} + print("Getting data loaders") + training_data_loader = mnist_dataloader( + DATA_DIR, hp.batch_size, train=True, distributed=use_cuda, rank=rank, world_size=world_size, **kwargs + ) + test_data_loader = mnist_dataloader(DATA_DIR, hp.test_batch_size, train=False, **kwargs) + + # define the distributed model and optimizer + print("Defining model") + model = Net().cuda(rank) + model = nn.parallel.DistributedDataParallel(model, device_ids=[rank]) + + optimizer = optim.SGD(model.parameters(), lr=hp.learning_rate, momentum=hp.sgd_momentum) + + # train the model: run multiple epochs and capture the accuracies for each epoch + print(f"Training for {hp.epochs} epochs") + accuracies = [] + for epoch in range(1, hp.epochs + 1): + train(model, rank, training_data_loader, optimizer, epoch, hp.log_interval) + + # only compute validation metrics in the main process + if rank == 0: + accuracies.append(test(model, rank, test_data_loader)) + + # wait for the main process to complete validation before continuing the training process + dist.barrier() + + if rank == 0: + # tell wandb that we're done logging metrics + wandb.finish() + + # after training the model, we can simply save it to disk and return it from the Flyte + # task as a `flytekit.types.file.FlyteFile` type, which is the `PythonPickledFile`. + # `PythonPickledFile` is simply a decorator on the `FlyteFile` that records the format + # of the serialized model as `pickled` + print("Saving model") + torch.save(model.state_dict(), MODEL_FILE) + + # save epoch accuracies + print("Saving accuracies") + with open(ACCURACIES_FILE, "w") as fp: + json.dump(accuracies, fp) + + print(f"Rank {rank + 1}/{world_size} process complete.\n") + dist.destroy_process_group() # clean up + + +# %% +# The output model using :py:func:`pytorch:torch.save` saves the `state_dict` as described +# `in pytorch docs `_. +# A common convention is to have the ``.pt`` extension for the model file. +# +# .. note:: +# Note the usage of ``requests=Resources(gpu=WORLD_SIZE)``. This will force Flyte to allocate this task onto a +# machine with GPU(s), which in our case is 2 gpus. The task will be queued up until a machine with GPU(s) can be +# procured. Also, for the GPU Training to work, the Dockerfile needs to be built as explained in the +# :ref:`pytorch-dockerfile` section. + +# %% +# Defining the ``task`` +# ===================== +# +# Next we define the flyte task that kicks off the distributed training process. Here we call the +# pytorch :ref:`multiprocessing ` function to initiate a process on each +# available GPU. Since we're parallelizing the data, each process will contain a copy of the model and pytorch +# will handle syncing the weights across all processes on ``optimizer.step()`` calls. +# +# See `here `_ to read more about pytorch distributed +# training. + +@task( + retries=2, + cache=True, + cache_version="1.2", + requests=Resources(gpu=str(WORLD_SIZE), mem="30Gi", storage="20Gi", ephemeral_storage="500Mi"), + limits=Resources(gpu=str(WORLD_SIZE), mem="30Gi", storage="20Gi", ephemeral_storage="500Mi"), +) +def pytorch_mnist_task(hp: Hyperparameters) -> TrainingOutputs: + print("Start MNIST training:") + + world_size = torch.cuda.device_count() + print(f"Device count: {world_size}") + download_mnist(DATA_DIR) + mp.spawn( + train_mnist, + args=(world_size, hp), + nprocs=world_size, + join=True, + ) + print("Training Complete") + with open(ACCURACIES_FILE) as fp: + accuracies = json.load(fp) + return TrainingOutputs( + epoch_accuracies=accuracies, model_state=PythonPickledFile(MODEL_FILE) + ) + + + +# %% +# Finally, we define a workflow to run the training algorithm. We return the model and accuracies. +@workflow +def pytorch_training_wf(hp: Hyperparameters = Hyperparameters(epochs=10, batch_size=128)) -> TrainingOutputs: + return pytorch_mnist_task(hp=hp) + + +# %% +# Running the Model Locally +# ========================= +# +# It is possible to run the model locally with almost no modifications (as long as the code takes care of resolving +# if the code is distributed or not). This is how we can do it: +if __name__ == "__main__": + model, accuracies = pytorch_training_wf(hp=Hyperparameters(epochs=10, batch_size=128)) + print(f"Model: {model}, Accuracies: {accuracies}") + +# %% +# Weights & Biases Report +# ======================= +# +# You can refer to the complete ``wandb`` report `here `__. +# +# .. tip:: +# A lot more customizations can be done to the report according to your requirement!