
Using Dropout with PyTorch

July 7, 2021 by Chris

The Dropout technique can be used for avoiding overfitting in your neural network. It has been around for some time and is widely available in a variety of neural network libraries. Let's take a look at how Dropout can be implemented with PyTorch.

In this article, you will learn what Dropout is, how it works, and how you can add it to a PyTorch neural network with nn.Dropout.

Let's take a look! 😎

Variance and overfitting

In our article about the trade-off between bias and variance, it became clear that models can be high in bias or high in variance. Preferably, there is a balance between both.

To summarize that article briefly, models high in bias are relatively rigid. Linear models are a good example - they assume that your input data has a linear pattern. Models high in variance, however, do not make such assumptions -- but they are sensitive to changes in your training data.

As you can imagine, striking a balance between rigidity and sensitivity is not always easy.

Dropout is related to the fact that deep neural networks have high variance. As you will know if you have worked with neural networks for a while, such models are sensitive to overfitting - capturing noise in the data as if it were part of the real function that must be modeled.

Dropout

In their paper “Dropout: A Simple Way to Prevent Neural Networks from Overfitting”, Srivastava et al. (2014) describe Dropout, which is a technique that temporarily removes neurons from the neural network.

With Dropout, the training process essentially drops out neurons in a neural network.

For a more thorough conceptual discussion, see our article What is Dropout? Reduce overfitting in your neural networks (MachineCurve, 2019).

When certain neurons are dropped, no data flows through them anymore. Dropout is modeled by means of Bernoulli variables, which take either the value zero (0) or one (1). The behavior is configured with a parameter \(p\), the probability (between 0 and 1) with which a neuron is dropped.
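
As a rough illustration (and explicitly not PyTorch's internal implementation), this Bernoulli masking can be sketched with plain tensor operations. The rescaling by \(1/(1-p)\) mirrors the inverted dropout that nn.Dropout applies during training:

import torch

p = 0.5  # probability of dropping a neuron
x = torch.randn(1, 8)  # example activations

# Each neuron is kept with probability (1 - p), modeled as a Bernoulli variable
mask = torch.bernoulli(torch.full_like(x, 1 - p))

# Zero out the dropped neurons and rescale the kept ones (inverted dropout)
dropped = x * mask / (1 - p)

print(mask)     # a tensor of zeros and ones
print(dropped)  # dropped positions are zero; kept positions are scaled up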

When neurons are dropped, they are not dropped permanently: instead, for every minibatch (in fact, for every forward pass during training) the network randomly selects the neurons that are dropped this time. Neurons that were dropped before can become active again in later iterations.
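
You can observe this behavior directly by feeding the same tensor through an nn.Dropout module a few times: in training mode a fresh random mask is sampled on every forward pass, while in evaluation mode the module simply passes the data through unchanged. A minimal sketch:

import torch
from torch import nn

dropout = nn.Dropout(p=0.5)
x = torch.ones(1, 6)

dropout.train()    # training mode: neurons are dropped
print(dropout(x))  # a random mask, kept values scaled by 1/(1-p)
print(dropout(x))  # very likely a different mask this time

dropout.eval()     # evaluation mode: Dropout is a no-op
print(dropout(x))  # identical to the input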

Using Dropout with PyTorch: full example

Now that we understand what Dropout is, we can take a look at how Dropout can be implemented with the PyTorch framework. For this purpose, we use a basic Multilayer Perceptron. We will apply it to the MNIST dataset (but note that Convolutional Neural Networks are generally better suited to image datasets).

In the example, you'll see that:

- The MLP is defined as an nn.Sequential stack, with an nn.Dropout(p=0.5) layer after each hidden nn.Linear layer.
- The MNIST dataset is loaded with torchvision and wrapped in a DataLoader.
- The network is trained for 5 epochs with CrossEntropyLoss and the Adam optimizer.

import os
import torch
from torch import nn
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms

class MLP(nn.Module):
  '''
    Multilayer Perceptron.
  '''
  def __init__(self):
    super().__init__()
    self.layers = nn.Sequential(
      nn.Flatten(),
      nn.Linear(28 * 28 * 1, 64),      
      nn.Dropout(p=0.5),
      nn.ReLU(),
      nn.Linear(64, 32),
      nn.Dropout(p=0.5),
      nn.ReLU(),
      nn.Linear(32, 10)
    )


  def forward(self, x):
    '''Forward pass'''
    return self.layers(x)


if __name__ == '__main__':

  # Set fixed random number seed
  torch.manual_seed(42)

  # Prepare the MNIST dataset
  dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
  trainloader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=1)

  # Initialize the MLP
  mlp = MLP()

  # Define the loss function and optimizer
  loss_function = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4)

  # Run the training loop
  for epoch in range(0, 5): # Train for 5 epochs

    # Print epoch
    print(f'Starting epoch {epoch+1}')

    # Set current loss value
    current_loss = 0.0

    # Iterate over the DataLoader for training data
    for i, data in enumerate(trainloader, 0):

      # Get inputs
      inputs, targets = data

      # Zero the gradients
      optimizer.zero_grad()

      # Perform forward pass
      outputs = mlp(inputs)

      # Compute loss
      loss = loss_function(outputs, targets)

      # Perform backward pass
      loss.backward()

      # Perform optimization
      optimizer.step()

      # Print statistics
      current_loss += loss.item()
      if i % 500 == 499:
          print('Loss after mini-batch %5d: %.3f' %
                (i + 1, current_loss / 500))
          current_loss = 0.0

  # Process is complete.
  print('Training process has finished.')
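
After training, it is important to switch the network to evaluation mode with mlp.eval(), so that the Dropout layers stop dropping neurons when you run inference. The snippet below is a minimal sketch of such an evaluation step; it assumes the trained mlp and the imports from the example above, and uses the MNIST test split (train=False) for illustration:

# Put the network in evaluation mode so that the Dropout layers are disabled
mlp.eval()

# Load the MNIST test split with the same transform as before
testset = MNIST(os.getcwd(), download=True, train=False, transform=transforms.ToTensor())
testloader = DataLoader(testset, batch_size=10, shuffle=False, num_workers=1)

# Compute accuracy on the test set without tracking gradients
correct, total = 0, 0
with torch.no_grad():
  for inputs, targets in testloader:
    outputs = mlp(inputs)
    predictions = outputs.argmax(dim=1)
    correct += (predictions == targets).sum().item()
    total += targets.size(0)

print(f'Test accuracy: {100.0 * correct / total:.2f}%')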

References

PyTorch. (n.d.). Dropout — PyTorch 1.9.0 documentation. https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html

MachineCurve. (2019, December 17). What is dropout? Reduce overfitting in your neural networks. https://www.machinecurve.com/index.php/2019/12/16/what-is-dropout-reduce-overfitting-in-your-neural-networks/
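
Srivastava, N., Hinton, G., Krizhevsky, A., Sutskever, I., & Salakhutdinov, R. (2014). Dropout: A simple way to prevent neural networks from overfitting. Journal of Machine Learning Research, 15(56), 1929–1958. https://jmlr.org/papers/v15/srivastava14a.html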
