Cropping layers with PyTorch

Last Updated on 10 November 2021

Sometimes you may want to crop the input images that you feed to your neural network. Although cropping is, strictly speaking, usually part of data preprocessing, it can be useful to move it into the neural network itself, because then you do not need to adapt the full dataset in advance.

In TensorFlow and Keras, cropping your input data is relatively easy, using the Cropping layers readily available there.

In PyTorch, this is different, because Cropping layers are not part of the PyTorch API.

In this article, you will learn how to perform cropping within PyTorch anyway – by using the ZeroPad2d layer, which performs zero padding. By giving it a negative padding value, we remove pixels from the border of the input (and hence perform cropping) instead of adding them.

Ready? Let’s take a look. 😎



Using ZeroPad2d for Cropping

For creating our Cropping layer, we will be using the ZeroPad2d layer that is available within PyTorch.

Normally, it’s used for adding a box of pixels around the input data – which is what padding does. In that case, it’s used with positive padding. In the image below, on the left, you can see what happens when it’s called with a +1 padding – an extra box of zero-valued pixels is added around the input image.

Now, what if we used a -1 padding instead? You would expect that padding then works in the opposite direction, meaning that a box is not added, but removed. And precisely this effect is what we will use for creating a Cropping layer for your PyTorch model.

Calling Zero Padding with a positive padding results in a zero-valued box of pixels being added to your input image. Using a negative padding removes data from your image.
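To make the difference concrete, here is a minimal sketch that applies both a positive and a negative padding to a tiny tensor. The 3×3 input and the variable names (`image`, `padded`, `cropped`) are illustrative assumptions and not part of the article's example.

```python
import torch
from torch import nn

# A tiny "image" in (batch, channels, height, width) layout with values 1..9
image = torch.arange(1.0, 10.0).reshape(1, 1, 3, 3)

# Positive padding adds a border of zeros: 3x3 becomes 5x5
padded = nn.ZeroPad2d(1)(image)

# Negative padding removes the border: 3x3 becomes 1x1,
# so only the centre pixel is left
cropped = nn.ZeroPad2d(-1)(image)

print(padded.shape)    # torch.Size([1, 1, 5, 5])
print(cropped.shape)   # torch.Size([1, 1, 1, 1])
print(cropped.item())  # 5.0
```

The same layer class does both jobs; only the sign of the padding decides whether pixels are added or removed.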

Full Cropping layers example using PyTorch

Let’s now take a look at how we can implement ZeroPad2d for generating a Cropping layer with PyTorch. First, it’s time to write down our imports.

import os
import torch
from torch import nn
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt

These are relatively straightforward: most are torch-related imports, which are explained in our articles on PyTorch-based networks such as the ConvNet. In addition, os is used for the dataset path and matplotlib for visualizing the results.

Time to move forward with the CroppingNetwork. Here it is:

class CroppingNetwork(nn.Module):
    '''
      Simple network with four stacked Cropping layers
    '''
    def __init__(self):
        super().__init__()
        # Each ZeroPad2d(-1) removes a 1-pixel border from all four sides
        self.layers = nn.Sequential(
            nn.ZeroPad2d(-1),
            nn.ZeroPad2d(-1),
            nn.ZeroPad2d(-1),
            nn.ZeroPad2d(-1),
        )

    def forward(self, x):
        '''Forward pass'''
        return self.layers(x)

It is actually really simple! By specifying nn.ZeroPad2d with a padding of -1, we remove 1 column of pixels on the left and 1 on the right, as well as 1 row from the top and 1 from the bottom of the image.
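Cropping does not have to be symmetric. ZeroPad2d also accepts a 4-tuple in the order (left, right, top, bottom), so each side can be cropped by a different amount. The sketch below assumes a random dummy batch instead of real images; the crop sizes and variable names are only illustrative.

```python
import torch
from torch import nn

# Dummy batch: 10 single-channel images of 28x28 pixels
batch = torch.rand(10, 1, 28, 28)

# Tuple order is (left, right, top, bottom); negative values crop.
# Here: 2 columns off the left, 2 off the right, 6 rows off the top, none off the bottom
crop = nn.ZeroPad2d((-2, -2, -6, 0))
out = crop(batch)

# Height: 28 - 6 - 0 = 22, width: 28 - 2 - 2 = 24
print(out.shape)  # torch.Size([10, 1, 22, 24])
```

Only the last two dimensions (height and width) are affected; the batch and channel dimensions pass through unchanged.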

Our input images – MNIST images – have an input shape of (1, 28, 28) – or (28, 28) when we reshape them. Since we repeat the layer four times, we remove 4 pixels from the left, 4 from the right, 4 from the top and 4 from the bottom. Each dimension therefore shrinks by 2 × 4 = 8 pixels, which means that the shape of our outputs will be (20, 20).
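The shape arithmetic can be checked without downloading any data. The following sketch assumes a random dummy batch with MNIST-like dimensions and compares three routes that should give the same result: four stacked ZeroPad2d(-1) layers, a single ZeroPad2d(-4) layer, and plain tensor slicing. The variable names are illustrative.

```python
import torch
from torch import nn

# Dummy batch with MNIST-like dimensions: (batch, channels, height, width)
batch = torch.rand(10, 1, 28, 28)

# Four stacked layers, each removing a 1-pixel border
stacked = nn.Sequential(*[nn.ZeroPad2d(-1) for _ in range(4)])

# One layer removing a 4-pixel border in a single step
single = nn.ZeroPad2d(-4)

out_stacked = stacked(batch)
out_single = single(batch)

# Plain slicing of the height and width dimensions, for comparison
out_sliced = batch[:, :, 4:-4, 4:-4]

print(out_stacked.shape)                     # torch.Size([10, 1, 20, 20])
print(torch.equal(out_stacked, out_single))  # True
print(torch.equal(out_stacked, out_sliced))  # True
```

Stacking four layers mirrors the network in this article; in your own models, a single layer with a larger negative padding is a shorter way to express the same crop.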

What remains is stitching everything together:

if __name__ == '__main__':

    # Set fixed random number seed
    torch.manual_seed(42)

    # Prepare MNIST dataset
    dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
    trainloader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=1)

    # Initialize the CroppingNetwork
    croppingnet = CroppingNetwork()

    # Load the first minibatch and unpack inputs and targets
    inputs, targets = next(iter(trainloader))

    # Feed samples through the network
    cropped_samples = croppingnet(inputs)

    # Visualize every sample in the minibatch next to its cropped version
    for original, cropped in zip(inputs, cropped_samples):

        # Reshape the samples from (1, H, W) to (H, W) for plotting
        reshaped_original = original.reshape(28, 28)
        reshaped_cropped = cropped.reshape(20, 20)

        fig, (ax1, ax2) = plt.subplots(1, 2)
        fig.set_size_inches(9, 5, forward=True)
        fig.suptitle('Original sample (left) and Cropped sample (right)')
        ax1.imshow(reshaped_original)
        ax2.imshow(reshaped_cropped)
        plt.show()

The code above uses the PyTorch DataLoader for loading the first minibatch of samples, feeds them through the CroppingNetwork, and visualizes the results.


Examples of PyTorch cropping layers

And here they are – some examples of what is produced by the cropping network:


References

PyTorch. (n.d.). ZeroPad2d — PyTorch 1.10.0 documentation. https://pytorch.org/docs/stable/generated/torch.nn.ZeroPad2d.html
