The Dropout technique can be used for avoiding overfitting in your neural network. It has been around for some time and is widely available in a variety of neural network libraries. Let's take a look at how Dropout can be implemented with PyTorch.
In this article, you will learn...
Let's take a look! đ
In our article about the trade-off between bias and variance, it became clear that models can be high in bias or high in variance. Preferably, there is a balance between both.
To summarize that article briefly, models high in bias are relatively rigid. Linear models are a good example - they assume that your input data has a linear pattern. Models high in variance, however, do not make such assumptions -- but they are sensitive to changes in your training data.
As you can imagine, striking a balance between rigidity and sensitivity.
Dropout is related to the fact that deep neural networks have high variance. As you know when you have dealt with neural networks for a while, such models are sensitive to overfitting - capturing noise in the data as if it is part of the real function that must be modeled.
In their paper âDropout: A Simple Way to Prevent Neural Networks from Overfittingâ, Srivastava et al. (2014) describe Dropout, which is a technique that temporarily removes neurons from the neural network.
With Dropout, the training process essentially drops out neurons in a neural network.
When certain neurons are dropped, no data flows through them anymore. Dropout is modeled as Bernoulli variables, which are either zero (0) or one (1). They can be configured with a variable, \(p\), which illustrates the probability (between 0 and 1) with which neurons are dropped.
When neurons are dropped, they are not dropped permanently: instead, at every epoch (or even minibatch) the network randomly selects neurons that are dropped this time. Neurons that had been dropped before can be activated again during future iterations.
Now that we understand what Dropout is, we can take a look at how Dropout can be implemented with the PyTorch framework. For this example, we are using a basic example that models a Multilayer Perceptron. We will be applying it to the MNIST dataset (but note that Convolutional Neural Networks are more applicable, generally speaking, for image datasets).
In the example, you'll see that:
os
for Python operating system interfaces, torch
representing PyTorch, and a variety of sub components, such as its neural networks library (nn
), the MNIST
dataset, the DataLoader
for loading the data, and transforms
for a Tensor transform.MLP
class, which is a PyTorch neural network module (nn.Module
). Its constructor initializes the nn.Module
super class and then initializes a Sequential
network (i.e., a network where layers are stacked on top of each other). It begins by flattening the three-dimensional input (width, height, channels) into a one-dimensional input, then applies a Linear
layer (MLP layer), followed by Dropout, Rectified Linear Unit. This is then repeated once more, before we end with a final Linear
layer for the final multiclass prediction.forward
definition is a relatively standard PyTorch definition that must be included in a nn.Module
: it ensures that the forward pass of the network (i.e., when the data is fed to the network), is performed by feeding input data x
through the layers
defined in the constructor.main
check, a random seed is fixed, the dataset is loaded and prepared; the MLP, loss function and optimizer are initialized; then the model is trained. This is the classic PyTorch training loop: gradients are zeroed, a forward pass is performed, loss is computed and backpropagated through the network, and optimization is performed. Finally, after every iteration, statistics are printed.import os
import torch
from torch import nn
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms
class MLP(nn.Module):
'''
Multilayer Perceptron.
'''
def __init__(self):
super().__init__()
self.layers = nn.Sequential(
nn.Flatten(),
nn.Linear(28 * 28 * 1, 64),
nn.Dropout(p=0.5),
nn.ReLU(),
nn.Linear(64, 32),
nn.Dropout(p=0.5),
nn.ReLU(),
nn.Linear(32, 10)
)
def forward(self, x):
'''Forward pass'''
return self.layers(x)
if __name__ == '__main__':
# Set fixed random number seed
torch.manual_seed(42)
# Prepare CIFAR-10 dataset
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True, num_workers=1)
# Initialize the MLP
mlp = MLP()
# Define the loss function and optimizer
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-4)
# Run the training loop
for epoch in range(0, 5): # 5 epochs at maximum
# Print epoch
print(f'Starting epoch {epoch+1}')
# Set current loss value
current_loss = 0.0
# Iterate over the DataLoader for training data
for i, data in enumerate(trainloader, 0):
# Get inputs
inputs, targets = data
# Zero the gradients
optimizer.zero_grad()
# Perform forward pass
outputs = mlp(inputs)
# Compute loss
loss = loss_function(outputs, targets)
# Perform backward pass
loss.backward()
# Perform optimization
optimizer.step()
# Print statistics
current_loss += loss.item()
if i % 500 == 499:
print('Loss after mini-batch %5d: %.3f' %
(i + 1, current_loss / 500))
current_loss = 0.0
# Process is complete.
print('Training process has finished.')
PyTorch. (n.d.). Dropout â PyTorch 1.9.0 documentation. https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html
MachineCurve. (2019, December 17). What is dropout? Reduce overfitting in your neural networks. https://www.machinecurve.com/index.php/2019/12/16/what-is-dropout-reduce-overfitting-in-your-neural-networks/
Learn how large language models and other foundation models are working and how you can train open source ones yourself.
Keras is a high-level API for TensorFlow. It is one of the most popular deep learning frameworks.
Read about the fundamentals of machine learning, deep learning and artificial intelligence.
To get in touch with me, please connect with me on LinkedIn. Make sure to write me a message saying hi!
The content on this website is written for educational purposes. In writing the articles, I have attempted to be as correct and precise as possible. Should you find any errors, please let me know by creating an issue or pull request in this GitHub repository.
All text on this website written by me is copyrighted and may not be used without prior permission. Creating citations using content from this website is allowed if a reference is added, including an URL reference to the referenced article.
If you have any questions or remarks, feel free to get in touch.
TensorFlow, the TensorFlow logo and any related marks are trademarks of Google Inc.
PyTorch, the PyTorch logo and any related marks are trademarks of The Linux Foundation.
Montserrat and Source Sans are fonts licensed under the SIL Open Font License version 1.1.
Mathjax is licensed under the Apache License, Version 2.0.