When you have a big data set and a complicated machine learning problem, chances are that training your model takes a couple of days even on a modern GPU.

However, the cycle of having a new idea, implementing it, and verifying it should be as short as possible, so that you can test new ideas efficiently.

If you have to wait several days, or even a whole week, for each training run, this cycle becomes very slow.

Luckily, we can parallelize the training across multiple GPUs and get big speedups by doing so.

This blog post will show you how to do this using PyTorch.

I will demonstrate on MNIST, but you can adapt this to your own problems.

The best solution in PyTorch is usually DistributedDataParallel, because it tends to give you the biggest speedup. There is also the easier-to-use DataParallel class, but it drives all GPUs from a single process, which adds communication overhead and waiting time and thereby reduces the speedup.

How does distributed training work?

The idea is that you copy the model to each GPU and each GPU gets a complementary part of every batch. Each GPU then does a normal forward pass followed by a backward pass, which calculates the gradients for its part of the data.

Then the gradients of all the GPUs are communicated to each other and averaged. Each GPU applies the optimizer step to its local model with these averaged gradients. Because all copies start from the same weights and apply the same update, the models stay in sync.

This is repeated over the course of many batches and epochs.
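
To make the averaging step concrete, here is a minimal pure-Python sketch. It is not how DistributedDataParallel is implemented; the one-parameter model, the helper names and the toy data are all made up for illustration. Two simulated "GPUs" each compute the gradient on their half of a batch, the gradients are averaged, and every replica applies the same update.

```python
# Toy simulation of synchronous data-parallel SGD (no PyTorch, no real GPUs).
# Model: y = w * x with a mean-squared-error loss. All names are illustrative.

def local_gradient(w, shard):
    """Gradient of the mean squared error of y = w * x on one shard."""
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)

def data_parallel_step(weights, shards, lr):
    """One step: every replica computes its gradient, then all apply the average."""
    grads = [local_gradient(w, shard) for w, shard in zip(weights, shards)]
    avg_grad = sum(grads) / len(grads)           # stands in for the gradient exchange
    return [w - lr * avg_grad for w in weights]  # same update on every replica

batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]  # samples of y = 2x
shards = [batch[0::2], batch[1::2]]   # complementary halves, one per "GPU"
weights = [0.0, 0.0]                  # both replicas start from the same weight

for _ in range(50):
    weights = data_parallel_step(weights, shards, lr=0.05)

print(weights)  # both entries are identical and close to 2.0
```

Since every replica starts from the same weight and applies the same averaged gradient, the two weights never drift apart in this sketch.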

What do you need to do in PyTorch?

There are three things that you need to take care of:

  • You need to start a process for each GPU and set up the distributed process group of participating processes.
  • Your machine learning model needs to be wrapped by DistributedDataParallel.
  • Your data sampler should be a DistributedSampler.

We will tackle all these things now!

Setting up the distributed processes

PyTorch provides the torch.distributed package, which we import as dist. Each process that we start needs to call init_process_group. The call blocks until world_size processes have actually started and joined the process group.

For example, if you want to train on 4 GPUs in parallel, you start 4 separate processes, and each process passes a world_size of 4. What differs between the processes is the GPU index, which selects the GPU the training runs on, and the rank, which is just an index that identifies the process within the process group. When training on 4 GPUs, both range from 0 to 3.

And this is what it looks like in code:

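os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '8889'
# 'nccl' is the usual backend for GPU training (an assumption of this example)
dist.init_process_group(backend='nccl', world_size=world_size, rank=rank)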


The environment variables MASTER_ADDR and MASTER_PORT specify the address and port of the rank 0 process, which the other processes contact to synchronize.

In this example I train on 4 GPUs on a single machine, which is why the address is localhost.
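
The waiting behaviour of the process group can be imitated without any GPUs. The following is only an analogy, not what torch.distributed does internally: threads stand in for the training processes, and a threading.Barrier plays the role of the process group. The function name simulate_rendezvous and the start delay are made up for this sketch.

```python
import threading
import time

def simulate_rendezvous(world_size, start_delay=0.05):
    """Start world_size workers one after the other and let them wait for each other."""
    barrier = threading.Barrier(world_size)
    joined = []            # ranks in the order in which they arrived
    seen_at_release = []   # how many ranks had joined when a worker was released
    lock = threading.Lock()

    def worker(rank):
        time.sleep(rank * start_delay)   # later ranks are started later
        with lock:
            joined.append(rank)
        barrier.wait()                   # blocks until all ranks have arrived
        with lock:
            seen_at_release.append(len(joined))

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return joined, seen_at_release

joined, seen_at_release = simulate_rendezvous(world_size=4)
print("ranks joined:", sorted(joined))
print("ranks present when each worker continued:", seen_at_release)
```

No worker continues before all four have arrived, which is the same guarantee that lets you start the real training processes one after the other.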

Using DistributedDataParallel for your machine learning model

This is an easy one:

model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu_index])

You simply wrap your existing model with the DistributedDataParallel class and pass the GPU index on which it should run. The model has to be moved to that GPU before it is wrapped. The GPU index is different for each process we start. In this simple example, gpu_index is always the same as the rank of the process.

Your data sampler should be a DistributedSampler

This is simple if you have no custom sampler, because PyTorch provides a DistributedSampler:

train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset, num_replicas=world_size, rank=rank)

We pass the training set, the number of processes we are running (num_replicas, which should be the number of GPUs you want to train on) and the rank. The rank is the unique index of the current process, as described above.

When you have a custom sampler, you should reimplement it, so that it splits the data appropriately depending on the number of processes. See the DistributedSampler implementation and adapt accordingly or send me a message for details.
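
As a starting point for such a custom sampler, here is a simplified pure-Python sketch of the splitting idea: every process builds the same permutation, pads it so that it divides evenly, and then takes every num_replicas-th index starting at its rank. The function shard_indices and its parameters are hypothetical names for this illustration; the real DistributedSampler has more options, so check its source before relying on the details.

```python
import random

def shard_indices(dataset_len, num_replicas, rank, epoch=0, seed=0, shuffle=True):
    """Return the dataset indices that the process with this rank should see."""
    indices = list(range(dataset_len))
    if shuffle:
        # Same seed on every process, so every process builds the same permutation
        random.Random(seed + epoch).shuffle(indices)
    # Pad by repeating the first indices so that the total divides evenly
    per_replica = -(-dataset_len // num_replicas)  # ceiling division
    total_size = per_replica * num_replicas
    indices += indices[: total_size - len(indices)]
    # Every process takes every num_replicas-th index, offset by its rank
    return indices[rank:total_size:num_replicas]

shards = [shard_indices(10, num_replicas=4, rank=r) for r in range(4)]
for r, shard in enumerate(shards):
    print(r, shard)
```

With 10 samples and 4 processes, each process gets 3 indices, and two samples appear twice because of the padding. Passing a different epoch changes the permutation, which is how the shuffling can vary between epochs while all processes still agree on it.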

Starting your training

As mentioned, the process group ensures that the processes which are started first wait for all other processes until the specified world_size is reached and all processes have registered themselves with the process group.

Thus, there is no need to rush. You can simply start one process after the next, each in its own terminal (or in the background), since every process blocks until all of them are running:

python run_multi_gpu_train --world_size 4 --rank 0
python run_multi_gpu_train --world_size 4 --rank 1
python run_multi_gpu_train --world_size 4 --rank 2
python run_multi_gpu_train --world_size 4 --rank 3
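
If you would rather not open one terminal per process, a small launcher can start them all. This is a sketch using only the standard library; build_commands and launch_all are hypothetical helper names, and the script name is taken from the commands above.

```python
import subprocess
import sys

def build_commands(world_size, script="run_multi_gpu_train"):
    """One command line per rank, mirroring the manual invocations."""
    return [
        [sys.executable, script, "--world_size", str(world_size), "--rank", str(rank)]
        for rank in range(world_size)
    ]

def launch_all(world_size, script="run_multi_gpu_train"):
    """Start all ranks without waiting in between, then wait for all of them."""
    procs = [subprocess.Popen(cmd) for cmd in build_commands(world_size, script)]
    return [p.wait() for p in procs]  # exit codes, one per rank

# Only print the commands here; call launch_all(4) to actually start training.
for cmd in build_commands(4):
    print(" ".join(cmd[1:]))
```

Calling launch_all(4) starts the four processes right away and returns their exit codes once all of them have finished.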

Full example code

Here is the full example code which you can run:

import os
from datetime import datetime
import argparse
import torchvision
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.distributed as dist

class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out

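def train(world_size, rank, num_epochs):
    # Join the process group; blocks until world_size processes have joined.
    # 'nccl' is the usual backend for GPU training (an assumption of this example).
    dist.init_process_group(backend='nccl', world_size=world_size, rank=rank)

    # In this simple case, the GPU index equals the rank
    gpu_index = rank
    torch.cuda.set_device(gpu_index)
    model = ConvNet().cuda(gpu_index)  # the model must be on its GPU before wrapping
    batch_size = 100
    criterion = nn.CrossEntropyLoss().cuda(gpu_index)
    lr = 1e-4 * world_size  # Larger world_size implies larger batches -> scale LR
    optimizer = torch.optim.SGD(model.parameters(), lr)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu_index])

    train_dataset = torchvision.datasets.MNIST(
        root='./data', train=True, transform=transforms.ToTensor(), download=True)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=world_size, rank=rank)
    # The DataLoader arguments are assumptions; the sampler does the shuffling,
    # so shuffle stays False here.
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset, batch_size=batch_size, shuffle=False,
        sampler=train_sampler, pin_memory=True)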

    start = datetime.now()
    total_step = len(train_loader)
    num_total = 0
    correct = 0
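    for epoch in range(num_epochs):
        train_sampler.set_epoch(epoch)  # shuffle differently in every epoch
        for i, (images, labels) in enumerate(train_loader):
            images = images.cuda(gpu_index, non_blocking=True)
            labels = labels.cuda(gpu_index, non_blocking=True)
            outputs = model(images)
            loss = criterion(outputs, labels)

            # DDP averages the gradients across all processes during backward()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            num_total += labels.shape[0]
            correct += (torch.argmax(outputs, dim=1) == labels).sum().item()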

            if (i + 1) % 100 == 0:
                print('Accuracy: {}'.format(correct / num_total))
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))

    print("Training completed in: " + str(datetime.now() - start))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--world_size', default=1, type=int)
    parser.add_argument('--rank', default=0, type=int)
    parser.add_argument('--epochs', default=2, type=int)
    args = parser.parse_args()
    os.environ['MASTER_ADDR'] = 'localhost'  # single machine, as described above
    os.environ['MASTER_PORT'] = '8889'
    train(args.world_size, args.rank, args.epochs)
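
To get a feeling for the numbers in this example, the following sketch works out how many steps each process runs per epoch and what the learning-rate scaling amounts to. It assumes the 60,000 MNIST training images are split evenly over the processes; per_process_steps is a hypothetical helper, not part of PyTorch.

```python
import math

def per_process_steps(dataset_len, world_size, batch_size):
    """Batches per epoch for one process when the data is split evenly."""
    samples_per_process = math.ceil(dataset_len / world_size)
    return math.ceil(samples_per_process / batch_size)

world_size, batch_size, base_lr = 4, 100, 1e-4

steps = per_process_steps(60000, world_size, batch_size)
effective_batch = batch_size * world_size  # samples contributing to one update
scaled_lr = base_lr * world_size           # the scaling used in train()

print(steps)            # 150
print(effective_batch)  # 400
print(scaled_lr)
```

With 4 processes, each one runs 150 steps per epoch instead of 600, so the progress message that is printed every 100 steps appears once per epoch in each process.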