# How to create a variational autoencoder with Keras?

Last Updated on 17 August 2020

In a different blog post, we studied the concept of a Variational Autoencoder (or VAE) in detail. These models are generative: by learning the distribution of the input data, they can be used to generate new samples that resemble it.

But there's a difference between theory and practice. While it's always nice to understand neural networks in theory, it's even more fun to actually create them with a particular framework. That's what makes them really usable.

Today, we’ll use the Keras deep learning framework to create a convolutional variational autoencoder. We subsequently train it on the MNIST dataset, and also show you what our latent space looks like as well as new samples generated from the latent space.

But first, let’s take a look at what VAEs are.

Let’s go! 😎

Update 17/08/2020: added a fix for an issue with vae.fit().

## Recap: what are variational autoencoders?

If you are already familiar with variational autoencoders or wish to find the implementation straight away, I'd suggest skipping this section. In any other case, it may be worth the read.

### How do VAEs work?

Contrary to a normal autoencoder, which learns to encode some input into a point in latent space, Variational Autoencoders (VAEs) learn to encode inputs as multivariate probability distributions in latent space, usually Gaussian ones, given their configuration:

Sampling from the distribution gives a point in latent space that, given the distribution, is oriented around some mean value $$\mu$$ and standard deviation $$\sigma$$, like the points in this two-dimensional distribution:

Combining this with a Kullback-Leibler divergence term in the loss function leads to a latent space that is both continuous and complete: for every point sampled close to the mean and standard deviation of the distribution (in our case, the standard normal distribution), the output should both resemble nearby samples and make sense on its own.
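For intuition, the KL divergence between a Gaussian $$N(\mu, \sigma^2)$$ and the standard normal prior has a closed form. A small NumPy sketch (illustrative only, not part of the model code we build below):

```python
import numpy as np

def kl_to_standard_normal(mu, sigma):
  # KL( N(mu, sigma^2) || N(0, 1) ), computed per dimension and summed
  return np.sum(0.5 * (mu**2 + sigma**2 - 1.0 - np.log(sigma**2)))

# The penalty vanishes when the encoded distribution already matches the prior...
print(kl_to_standard_normal(np.array([0.0]), np.array([1.0])))  # 0.0

# ...and grows as the encoded distribution drifts away from it
print(kl_to_standard_normal(np.array([2.0]), np.array([0.5])))
```

This is exactly the pressure that keeps the encoded distributions packed around the origin, giving us the continuity and completeness described above.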

### What can you do with VAEs?

Besides the regular stuff one can do with an autoencoder (like denoising and dimensionality reduction), the principles of a VAE outlined above allow us to use variational autoencoders for generative purposes.

I would really recommend my blog “What is a Variational Autoencoder (VAE)?” if you are interested in understanding VAEs in more detail. However, based on the high-level recap above, I hope that you now both understand (1) how VAEs work at a high level and (2) what this allows you to do with them: using them for generative purposes.

Let’s now take a look at how we will use VAEs today 😊

## Let's pause for a second! 👩‍💻

Blogs at MachineCurve teach Machine Learning for Developers. Sign up to MachineCurve's free Machine Learning update today! You will learn new things and better understand concepts you already know.

We send emails at least every Friday. Welcome!
By signing up, you consent that any information you receive can include services and special offers by email.

## Creating a VAE with Keras

### What we’ll create today

Today, we’ll use the Keras deep learning framework for creating a VAE. It consists of three individual parts: the encoder, the decoder and the VAE as a whole. We do so using the Keras Functional API, which allows us to combine layers very easily.

The MNIST dataset will be used for training the autoencoder. This dataset contains thousands of 28 x 28 pixel images of handwritten digits, as we can see below. As such, our autoencoder will learn the distribution of handwritten digits across a two-dimensional latent space, which we can then use to manipulate samples into a format we like.

This is the structure of the encoder:

Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
encoder_input (InputLayer)      (None, 28, 28, 1)    0
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 14, 14, 8)    80          encoder_input[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 14, 14, 8)    32          conv2d_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 7, 7, 16)     1168        batch_normalization_1[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 7, 7, 16)     64          conv2d_2[0][0]
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 784)          0           batch_normalization_2[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 20)           15700       flatten_1[0][0]
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 20)           80          dense_1[0][0]
__________________________________________________________________________________________________
latent_mu (Dense)               (None, 2)            42          batch_normalization_3[0][0]
__________________________________________________________________________________________________
latent_sigma (Dense)            (None, 2)            42          batch_normalization_3[0][0]
__________________________________________________________________________________________________
z (Lambda)                      (None, 2)            0           latent_mu[0][0]
latent_sigma[0][0]
==================================================================================================
Total params: 17,208
Trainable params: 17,120
Non-trainable params: 88

And the decoder:

__________________________________________________________________________________________________
Model: "decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
decoder_input (InputLayer)   (None, 2)                 0
_________________________________________________________________
dense_2 (Dense)              (None, 784)               2352
_________________________________________________________________
batch_normalization_4 (Batch (None, 784)               3136
_________________________________________________________________
reshape_1 (Reshape)          (None, 7, 7, 16)          0
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 14, 14, 16)        2320
_________________________________________________________________
batch_normalization_5 (Batch (None, 14, 14, 16)        64
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 28, 28, 8)         1160
_________________________________________________________________
batch_normalization_6 (Batch (None, 28, 28, 8)         32
_________________________________________________________________
decoder_output (Conv2DTransp (None, 28, 28, 1)         73
=================================================================
Total params: 9,137
Trainable params: 7,521
Non-trainable params: 1,616

And, finally, the VAE as a whole:

_________________________________________________________________
Model: "vae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
encoder_input (InputLayer)   (None, 28, 28, 1)         0
_________________________________________________________________
encoder (Model)              [(None, 2), (None, 2), (N 17208
_________________________________________________________________
decoder (Model)              (None, 28, 28, 1)         9137
=================================================================
Total params: 26,345
Trainable params: 24,641
Non-trainable params: 1,704

From the final summary, we can see that indeed, the VAE takes in samples of shape $$(28, 28, 1)$$ and returns samples in the same format. Great! 😊

Let’s now start working on our model. Open up your Explorer / Finder, navigate to some folder, and create a new Python file, e.g. variational_autoencoder.py. Now, open this file in your code editor, and let’s start coding! 😎

### What you’ll need to run the model

Before we begin, it’s important that you ensure that you have all the required dependencies installed on your system:

• First of all, you’ll need the Keras deep learning framework, with which we are creating the VAE.
• It's best if you use the TensorFlow backend (on top of which Keras can run). However, Theano and CNTK work as well.
• By consequence, it's preferred if you run Keras with Python 3.6+.
• You’ll also need Numpy, for number processing, and Matplotlib, for visualization purposes.

### Model imports

Let’s now declare everything that we will import:

• Keras, obviously.
• From Keras Layers, we’ll need convolutional layers and transposed convolutions, which we’ll use for the autoencoder. Given our usage of the Functional API, we also need Input, Lambda and Reshape, as well as Dense and Flatten.
• We'll import BatchNormalization as well to ensure that the mean and variance of our layers' inputs remain close to (0, 1) during training. This benefits the training process.
• We’ll import the Model container from keras.models. This allows us to instantiate the models eventually.
• The mnist dataset is the dataset we’ll be training our VAE on.
• With binary_crossentropy, we can compute reconstruction loss.
• Our backend (K) contains calls for tensor manipulations, which we’ll use.
• Numpy is used for number processing and Matplotlib for plotting the visualizations on screen.

This is the code that includes our imports:

'''
Variational Autoencoder (VAE) with the Keras Functional API.
'''

import keras
from keras.layers import Conv2D, Conv2DTranspose, Input, Flatten, Dense, Lambda, Reshape
from keras.layers import BatchNormalization
from keras.models import Model
from keras.datasets import mnist
from keras.losses import binary_crossentropy
from keras import backend as K
import numpy as np
import matplotlib.pyplot as plt

Next thing: importing the MNIST dataset. Since MNIST is part of the Keras Datasets, we can import it easily – by calling mnist.load_data(). Love the Keras simplicity!

# Load MNIST dataset
(input_train, target_train), (input_test, target_test) = mnist.load_data()

### Model configuration

Importing the data is followed by setting config parameters for data and model.

# Data & model configuration
img_width, img_height = input_train.shape[1], input_train.shape[2]
batch_size = 128
no_epochs = 100
validation_split = 0.2
verbosity = 1
latent_dim = 2
num_channels = 1

The width and height in our configuration settings are determined by the training data. In our case, they will be img_width = img_height = 28, as the MNIST dataset contains samples that are 28 x 28 pixels.

Batch size is set to 128 samples per (mini)batch, which is quite normal. The same is true for the number of epochs, which was set to 100. 20% of the training data is used for validation purposes. This is also quite normal. Nothing special here.

Verbosity mode is set to True (by means of 1), which means that all the output is shown on screen.


The final two configuration settings are of relatively more interest. First, the latent space will be two-dimensional – which means that a significant information bottleneck will be created which should yield good results with autoencoders on relatively simple datasets. Finally, the num_channels parameter can be configured to equal the number of image channels: for RGB data, it’s 3 (red – green – blue), and for grayscale data (such as MNIST), it’s 1.

### Data preprocessing

The next thing is data preprocessing:

# Reshape data
input_train = input_train.reshape(input_train.shape[0], img_height, img_width, num_channels)
input_test = input_test.reshape(input_test.shape[0], img_height, img_width, num_channels)
input_shape = (img_height, img_width, num_channels)

# Parse numbers as floats
input_train = input_train.astype('float32')
input_test = input_test.astype('float32')

# Normalize data
input_train = input_train / 255
input_test = input_test / 255

First, we reshape the data so that it takes the shape (X, 28, 28, 1), where X is the number of samples in either the training or testing dataset. We also set (28, 28, 1) as input_shape.

Next, we parse the numbers as floats, which presumably speeds up the training process, and normalize it, which the neural network appreciates. And that’s it already for data preprocessing 🙂

### Creating the encoder

Now, it's time to create the encoder. This is a three-step process: first, we define it. Second, we perform something known as the reparameterization trick, which allows us to link the encoder to the decoder later in order to instantiate the VAE as a whole. Third and finally, we instantiate the encoder.

#### Encoder definition

The first step in the three-step process is the definition of our encoder. Following the connection process of the Keras Functional API, we link the layers together:

# # =================
# # Encoder
# # =================

# Definition
i       = Input(shape=input_shape, name='encoder_input')
cx      = Conv2D(filters=8, kernel_size=3, strides=2, padding='same', activation='relu')(i)
cx      = BatchNormalization()(cx)
cx      = Conv2D(filters=16, kernel_size=3, strides=2, padding='same', activation='relu')(cx)
cx      = BatchNormalization()(cx)
x       = Flatten()(cx)
x       = Dense(20, activation='relu')(x)
x       = BatchNormalization()(x)
mu      = Dense(latent_dim, name='latent_mu')(x)
sigma   = Dense(latent_dim, name='latent_sigma')(x)

Let’s now take a look at the individual lines of code in more detail.

• The first layer is the Input layer. It accepts data with input_shape = (28, 28, 1) and is named encoder_input. It’s actually a pretty dumb layer, haha 😂
• Next up is a two-dimensional convolutional layer, or Conv2D in Keras terms. It learns 8 filters by deploying a 3 x 3 kernel which it convolves over the input. It has a stride of two, which means the kernel moves two pixels at a time, halving the spatial dimensions and speeding up the learning process. It employs ‘same’ padding and ReLU activation. Do note that officially, it’s best to use He init with ReLU activating layers. However, since the dataset is relatively small, it shouldn’t be too much of a problem if you don’t.
• Subsequently, we use Batch Normalization. This layer ensures that the outputs of the Conv2D layer that are input to the next Conv2D layer have a steady mean and variance, likely $$\mu = 0.0, \sigma = 1.0$$ (plus some $$\epsilon$$, an error term to ensure numerical stability). This benefits the learning process.
• Once again, a Conv2D layer. It learns 16 filters and for the rest is equal to the first Conv2D layer.
• BatchNormalization once more.
• Next up, a Flatten layer. It’s a relatively dumb layer too, and only serves to flatten the multidimensional data from the convolutional layers into one-dimensional shape. This has to be done because the densely-connected layers that we use next require data to have this shape.
• The next layer is a Dense layer with 20 output neurons. It’s the autoencoder bottleneck we’ve been talking about.
• BatchNormalization once more.
• The next two layers, mu and sigma, are actually not separate from each other – look at the previous layer they are linked to (both x, i.e. the Dense(20) layer). The first outputs the mean values $$\mu$$ of the encoded input and the second one outputs the stddevs $$\sigma$$. With these, we can sample the random variables that constitute the point in latent space onto which some input is mapped.

That's it for the layers of our encoder 😄 The next step is to retrieve the shape of the final Conv2D output:

# Get Conv2D shape for Conv2DTranspose operation in decoder
conv_shape = K.int_shape(cx)

We’ll need it when defining the layers of our decoder. I won’t bother you with the details yet, as they are best explained when we’re a bit further down the road. However, just remember to come back here if you wonder why we need some conv_shape value in the decoder, okay? 😉

Let’s now take a look at the second part of our encoder segment: the reparameterization trick.

#### Reparameterization trick

While for a mathematically sound explanation of the so-called “reparameterization trick” introduced to VAEs by Kingma & Welling (2013) I must refer you to Gregory Gunderson’s “The Reparameterization Trick”, I’ll try to explain the need for reparameterization briefly.

If you use neural networks (or, to be more precise, gradient descent) for optimizing the variational autoencoder, you effectively minimize some expected loss value, which can be estimated with Monte Carlo techniques (Huang, n.d.). However, this requires that the loss function is differentiable, which is not necessarily the case here, because it depends on the parameters of a probability distribution that we don't know. It is possible to rewrite the equation, but then it no longer has the form of an expectation, making the Monte Carlo techniques we could use before inapplicable.

However, if we can reparameterize the sample fed to the function into the shape $$\mu + \sigma \times \epsilon$$, where $$\epsilon \sim N(0, 1)$$, it becomes possible to use gradient descent for estimating the gradients accurately (Gunderson, n.d.; Huang, n.d.).

And that's precisely what we'll do in our code. We "sample" the value for $$z$$ from the computed $$\mu$$ and $$\sigma$$ values as mu + K.exp(sigma / 2) * eps. Note that sigma here represents the log-variance, so K.exp(sigma / 2) is the standard deviation.

# Define sampling with reparameterization trick
def sample_z(args):
  mu, sigma = args
  batch     = K.shape(mu)[0]
  dim       = K.int_shape(mu)[1]
  eps       = K.random_normal(shape=(batch, dim))
  return mu + K.exp(sigma / 2) * eps
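The same computation in plain NumPy, with a fixed ε so the example is deterministic (a sketch for intuition only; the real layer draws ε from K.random_normal):

```python
import numpy as np

# mu and log-variance as the encoder would output them, for a batch of one
mu      = np.array([[0.5, -1.0]])
log_var = np.array([[0.0, 0.2]])

# A fixed epsilon instead of a random draw, to keep the example deterministic
eps = np.array([[1.0, -1.0]])

# z = mu + std * eps, where std = exp(log_var / 2)
z = mu + np.exp(log_var / 2) * eps
print(z)
```

Because the randomness is isolated in ε, the path from mu and log_var to z is a plain differentiable expression, which is exactly what gradient descent needs.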

We then use this with a Lambda to ensure that correct gradients are computed during the backwards pass based on our values for mu and sigma:

# Use reparameterization trick to ensure correct gradient
z       = Lambda(sample_z, output_shape=(latent_dim, ), name='z')([mu, sigma])

#### Encoder instantiation

Now, it’s time to instantiate the encoder – taking inputs through input layer i, and outputting the values generated by the mu, sigma and z layers (i.e., the individual means and standard deviations, and the point sampled from the random variable represented by them):

# Instantiate encoder
encoder = Model(i, [mu, sigma, z], name='encoder')
encoder.summary()

Now that we’ve got the encoder, it’s time to start working on the decoder 🙂

### Creating the decoder

Creating the decoder is a bit simpler and boils down to a two-step process: defining it, and instantiating it.

#### Decoder definition

Firstly, we’ll define the layers of our decoder – just as we’ve done when defining the structure of our encoder.

# =================
# Decoder
# =================

# Definition
d_i   = Input(shape=(latent_dim, ), name='decoder_input')
x     = Dense(conv_shape[1] * conv_shape[2] * conv_shape[3], activation='relu')(d_i)
x     = BatchNormalization()(x)
x     = Reshape((conv_shape[1], conv_shape[2], conv_shape[3]))(x)
cx    = Conv2DTranspose(filters=16, kernel_size=3, strides=2, padding='same', activation='relu')(x)
cx    = BatchNormalization()(cx)
cx    = Conv2DTranspose(filters=8, kernel_size=3, strides=2, padding='same',  activation='relu')(cx)
cx    = BatchNormalization()(cx)
o     = Conv2DTranspose(filters=num_channels, kernel_size=3, activation='sigmoid', padding='same', name='decoder_output')(cx)
• Our decoder also starts with an Input layer, the decoder_input layer. It takes input with the shape (latent_dim, ), which as we will see is the vector we sampled for z with our encoder.
• If we’d like to upsample the point in latent space with Conv2DTranspose layers, in exactly the opposite symmetrical order to how we downsampled with our encoder, we must first bring the data back from shape (latent_dim, ) into some shape that can be reshaped into the output shape of the last convolutional layer of our encoder.
• This is why you needed the conv_shape variable. We thus now add a Dense layer with conv_shape[1] * conv_shape[2] * conv_shape[3] outputs, which converts the latent space into many outputs.
• We next use a Reshape layer to convert the output of the Dense layer into the output shape of the last convolutional layer: (conv_shape[1], conv_shape[2], conv_shape[3]) = (7, 7, 16). Sixteen filters learnt, with 7 x 7 pixels per filter.
• We then use Conv2DTranspose and BatchNormalization in the exact opposite order to our encoder, upsampling our data to 28 x 28 pixels (equal to the width and height of our inputs). At this point, we still have 8 filters, so the shape so far is (28, 28, 8).
• We therefore add a final Conv2DTranspose layer which does nothing to the width and height of the data, but ensures that the number of filters learnt equals num_channels. For MNIST data, where num_channels = 1, this means that the shape of our output will be (28, 28, 1), just as it has to be 🙂 This last layer also uses Sigmoid activation, which allows us to use binary crossentropy loss when computing the reconstruction loss part of our loss function.
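The shape bookkeeping above can be checked with a few lines of arithmetic (a sketch; the 7 x 7 x 16 figures come from the encoder summary shown earlier):

```python
# For a stride-2, 'same'-padded Conv2DTranspose, each spatial dimension doubles
def transpose_output_size(size, stride=2):
  return size * stride

conv_shape = (None, 7, 7, 16)  # as captured from the encoder with K.int_shape(cx)

# Units the Dense layer needs before the Reshape back to (7, 7, 16)
dense_units = conv_shape[1] * conv_shape[2] * conv_shape[3]
print(dense_units)  # 784

# Two stride-2 transposed convolutions: 7 -> 14 -> 28
size = conv_shape[1]
for _ in range(2):
  size = transpose_output_size(size)
print(size)  # 28
```

So the decoder ends at 28 x 28 spatial dimensions, and the final Conv2DTranspose only has to fix the channel count.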

#### Decoder instantiation

The next thing we do is instantiate the decoder:

# Instantiate decoder
decoder = Model(d_i, o, name='decoder')
decoder.summary()

It takes the inputs from the decoder input layer d_i and outputs whatever is output by the output layer o. Simple 🙂

### Creating the whole VAE

Now that the encoder and decoder are complete, we can create the VAE as a whole:

# =================
# VAE as a whole
# =================

# Instantiate VAE
vae_outputs = decoder(encoder(i)[2])
vae         = Model(i, vae_outputs, name='vae')
vae.summary()

If you think about it, the outputs of the entire VAE are the original inputs, encoded by the encoder, and decoded by the decoder.

That’s how we arrive at vae_outputs = decoder(encoder(i)[2]): inputs i are encoded by the encoder into [mu, sigma, z] (the individual means and standard deviations with the sampled z as well). We then take the sampled z values (hence the [2]) and feed it to the decoder, which ensures that we arrive at correct VAE output.

We then instantiate the model: i are our inputs indeed, and vae_outputs are the outputs. We call the model vae, because that's simply what it is.

### Defining custom VAE loss function

Now that we have defined our model, we can proceed with model configuration. Usually, with neural networks, this is done with model.compile, where a loss function is specified such as binary crossentropy. However, when we look at how VAEs are optimized, we see that it’s not a simple loss function that is used: we use reconstruction loss (in our case, binary crossentropy loss) together with KL divergence loss to ensure that our latent space is both continuous and complete.

We define it as follows:

# Define loss
def kl_reconstruction_loss(true, pred):
  # Reconstruction loss
  reconstruction_loss = binary_crossentropy(K.flatten(true), K.flatten(pred)) * img_width * img_height
  # KL divergence loss
  kl_loss = 1 + sigma - K.square(mu) - K.exp(sigma)
  kl_loss = K.sum(kl_loss, axis=-1)
  kl_loss *= -0.5
  # Total loss = mean over the batch of (reconstruction loss + KL divergence loss)
  return K.mean(reconstruction_loss + kl_loss)
• Our reconstruction_loss is the binary crossentropy value computed for the flattened true values (representing our targets, i.e. our ground truth) and the pred prediction values generated by our VAE. It's multiplied by img_width and img_height because binary_crossentropy averages over the flattened pixels, and we want the loss summed over all of them.
• Our KL divergence loss follows the rewritten closed-form expression for Gaussians (Wiseodd, 2016).
• The total loss is the sum of the reconstruction loss and the KL divergence loss, of which we return the mean over the batch.
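As a sanity check, the kl_loss expression above is (after the sign flip and -0.5 factor) exactly the textbook closed-form KL divergence to a standard normal. A NumPy sketch comparing the two forms:

```python
import numpy as np

mu    = np.array([[0.3, -0.7]])
sigma = np.array([[0.1, -0.2]])  # log-variance, as in the encoder

# The expression used in kl_reconstruction_loss
kl_code = -0.5 * np.sum(1 + sigma - np.square(mu) - np.exp(sigma), axis=-1)

# The textbook closed form: KL( N(mu, var) || N(0, 1) )
var = np.exp(sigma)
kl_closed = np.sum(0.5 * (np.square(mu) + var - 1.0 - np.log(var)), axis=-1)

print(np.allclose(kl_code, kl_closed))  # True
```

The two expressions match term by term once you substitute var = exp(sigma), which is why the code can stay so compact.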

### Compilation & training

Now that we have defined our custom loss function, we can compile our model. We do so using the Adam optimizer and our kl_reconstruction_loss custom loss function.

# Compile VAE
vae.compile(optimizer='adam', loss=kl_reconstruction_loss)

# Train autoencoder
vae.fit(input_train, input_train, epochs = no_epochs, batch_size = batch_size, validation_split = validation_split)

Once compiled, we can call vae.fit to start the training process. Note that we set input_train both as our features and targets, as is usual with autoencoders. For the rest, we configure the training process as defined previously, in the model configuration step.

## Visualizing VAE results

Even though you can now actually train your VAE, it’s best to wait just a bit more – because we’ll add some code for visualization purposes:

• We will visualize our test set inputs mapped onto the latent space. This allows us to check the continuity and completeness of our latent space.
• We will also visualize a uniform walk across latent space to see how sampling from it results in output that actually makes sense. This is the end result we’d love to see 🙂

Some credits first, though: the code for the two visualizers was originally created (and found by me) in the Keras Docs, at the link here, as well as in François Chollet’s blog post, here. All credits for the original ideas go to the authors of these articles. I made some adaptations to the code to accommodate this blog post:

• First of all, I split the visualizers into two separate definitions. Originally, there was one definition, that generated both visualizations. However, I think that having them separated gives you more flexibility.
• Additionally, I ensured that multi-channel data can be visualized as well. The original code was created specifically for MNIST, which has only one channel. RGB datasets, such as CIFAR10, have three. This required some extra code to make it work based on the autoencoder we created before.

### Visualizing inputs mapped onto latent space

Visualizing inputs mapped onto the latent space is simply taking some input data, feeding it to the encoder, taking the mean values $$\mu$$ for the predictions, and plotting them in a scatter plot:

# =================
# Results visualization
# Credits for original visualization code: https://keras.io/examples/variational_autoencoder_deconv/
# (François Chollet).
# Adapted to accommodate this VAE.
# =================
def viz_latent_space(encoder, data):
  input_data, target_data = data
  mu, _, _ = encoder.predict(input_data)
  plt.figure(figsize=(8, 10))
  plt.scatter(mu[:, 0], mu[:, 1], c=target_data)
  plt.xlabel('z - dim 1')
  plt.ylabel('z - dim 2')
  plt.colorbar()
  plt.show()

### Visualizing samples from the latent space

Visualizing samples from the latent space entails a bit more work. First, we’ll have to create a figure filled with zeros, as well as a linear space around $$(\mu = 0, \sigma = 1)$$ we can iterate over (from $$domain = range = [-4, +4]$$). We take a sample from the grid (determined by our current $$x$$ and $$y$$ positions) and feed it to the decoder. We then replace the zeros in our figure with the output, and finally plot the entire figure on screen. This includes reshaping one-dimensional (i.e., grayscale) input if necessary.


def viz_decoded(encoder, decoder, data):
  num_samples = 15
  figure = np.zeros((img_width * num_samples, img_height * num_samples, num_channels))
  grid_x = np.linspace(-4, 4, num_samples)
  grid_y = np.linspace(-4, 4, num_samples)[::-1]
  for i, yi in enumerate(grid_y):
    for j, xi in enumerate(grid_x):
      z_sample = np.array([[xi, yi]])
      x_decoded = decoder.predict(z_sample)
      digit = x_decoded[0].reshape(img_width, img_height, num_channels)
      figure[i * img_width: (i + 1) * img_width,
             j * img_height: (j + 1) * img_height] = digit
  plt.figure(figsize=(10, 10))
  start_range = img_width // 2
  end_range = num_samples * img_width + start_range + 1
  pixel_range = np.arange(start_range, end_range, img_width)
  sample_range_x = np.round(grid_x, 1)
  sample_range_y = np.round(grid_y, 1)
  plt.xticks(pixel_range, sample_range_x)
  plt.yticks(pixel_range, sample_range_y)
  plt.xlabel('z - dim 1')
  plt.ylabel('z - dim 2')
  # matplotlib.pyplot.imshow() needs a 2D array, or a 3D array with the third dimension being of shape 3 or 4!
  # So reshape if necessary
  fig_shape = np.shape(figure)
  if fig_shape[2] == 1:
    figure = figure.reshape((fig_shape[0], fig_shape[1]))
  # Show image
  plt.imshow(figure)
  plt.show()
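The tiling logic can also be checked independently of Keras, using a stand-in decoder that returns a constant image (dummy_decode is a hypothetical placeholder here; the real code feeds z_sample to decoder.predict):

```python
import numpy as np

img_width = img_height = 28
num_channels = 1
num_samples = 15

def dummy_decode(z_sample):
  # Stand-in for decoder.predict: a constant 28 x 28 x 1 "image"
  return np.full((1, img_width * img_height * num_channels), 0.5)

figure = np.zeros((img_width * num_samples, img_height * num_samples, num_channels))
grid_x = np.linspace(-4, 4, num_samples)
grid_y = np.linspace(-4, 4, num_samples)[::-1]

for i, yi in enumerate(grid_y):
  for j, xi in enumerate(grid_x):
    digit = dummy_decode(np.array([[xi, yi]]))[0].reshape(img_width, img_height, num_channels)
    figure[i * img_width:(i + 1) * img_width,
           j * img_height:(j + 1) * img_height] = digit

print(figure.shape)  # (420, 420, 1)
```

Every one of the 15 x 15 grid cells receives one decoded 28 x 28 tile, which is why the final canvas is 420 x 420 pixels.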

### Calling the visualizers

Using the visualizers is however much easier:

# Plot results
data = (input_test, target_test)
viz_latent_space(encoder, data)
viz_decoded(encoder, decoder, data)

## Time to run it!

Let’s now run our model. Open up a terminal which has access to all the required dependencies, cd to the folder where your Python file is located, and run it, e.g. python variational_autoencoder.py.

The training process should now begin with some visualizations being output after it finishes! 🙂

### If you get an error with vae.fit()

Marc, one of our readers, reported an issue with the model when running the VAE with TensorFlow 2.3.0 (and possibly also newer versions): https://github.com/tensorflow/probability/issues/519

By adding the following line of code, this issue can be resolved:

tf.config.experimental_run_functions_eagerly(True)
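Note that in more recent TensorFlow 2.x releases this experimental function has been deprecated in favor of a non-experimental counterpart, which should behave the same way. A minimal sketch, assuming a recent TF 2.x install:

```python
import tensorflow as tf

# Non-experimental replacement for experimental_run_functions_eagerly()
# in newer TensorFlow 2.x releases; forces tf.functions to run eagerly.
tf.config.run_functions_eagerly(True)
```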

## Full VAE code

Even though I would recommend reading the entire post first before you start playing with the code (because the structures are intrinsically linked), you may wish to take the full code and start fiddling right away. In that case, here you go 😊

'''
Variational Autoencoder (VAE) with the Keras Functional API.
'''

import keras
from keras.layers import Conv2D, Conv2DTranspose, Input, Flatten, Dense, Lambda, Reshape
from keras.layers import BatchNormalization
from keras.models import Model
from keras.datasets import mnist
from keras.losses import binary_crossentropy
from keras import backend as K
import numpy as np
import matplotlib.pyplot as plt

(input_train, target_train), (input_test, target_test) = mnist.load_data()

# Data & model configuration
img_width, img_height = input_train.shape[1], input_train.shape[2]
batch_size = 128
no_epochs = 100
validation_split = 0.2
verbosity = 1
latent_dim = 2
num_channels = 1

# Reshape data
input_train = input_train.reshape(input_train.shape[0], img_height, img_width, num_channels)
input_test = input_test.reshape(input_test.shape[0], img_height, img_width, num_channels)
input_shape = (img_height, img_width, num_channels)

# Parse numbers as floats
input_train = input_train.astype('float32')
input_test = input_test.astype('float32')

# Normalize data
input_train = input_train / 255
input_test = input_test / 255

# =================
# Encoder
# =================

# Definition
i       = Input(shape=input_shape, name='encoder_input')
cx      = Conv2D(filters=8, kernel_size=3, strides=2, padding='same', activation='relu')(i)
cx      = BatchNormalization()(cx)
cx      = Conv2D(filters=16, kernel_size=3, strides=2, padding='same', activation='relu')(cx)
cx      = BatchNormalization()(cx)
x       = Flatten()(cx)
x       = Dense(20, activation='relu')(x)
x       = BatchNormalization()(x)
mu      = Dense(latent_dim, name='latent_mu')(x)
sigma   = Dense(latent_dim, name='latent_sigma')(x)

# Get Conv2D shape for Conv2DTranspose operation in decoder
conv_shape = K.int_shape(cx)

# Define sampling with reparameterization trick
def sample_z(args):
    mu, sigma = args
    batch     = K.shape(mu)[0]
    dim       = K.int_shape(mu)[1]
    eps       = K.random_normal(shape=(batch, dim))
    return mu + K.exp(sigma / 2) * eps

# Use the reparameterization trick, so that gradients can flow through the sampling step
z       = Lambda(sample_z, output_shape=(latent_dim, ), name='z')([mu, sigma])

# Instantiate encoder
encoder = Model(i, [mu, sigma, z], name='encoder')
encoder.summary()

# =================
# Decoder
# =================

# Definition
d_i   = Input(shape=(latent_dim, ), name='decoder_input')
x     = Dense(conv_shape[1] * conv_shape[2] * conv_shape[3], activation='relu')(d_i)
x     = BatchNormalization()(x)
x     = Reshape((conv_shape[1], conv_shape[2], conv_shape[3]))(x)
cx    = Conv2DTranspose(filters=16, kernel_size=3, strides=2, padding='same', activation='relu')(x)
cx    = BatchNormalization()(cx)
cx    = Conv2DTranspose(filters=8, kernel_size=3, strides=2, padding='same',  activation='relu')(cx)
cx    = BatchNormalization()(cx)
o     = Conv2DTranspose(filters=num_channels, kernel_size=3, activation='sigmoid', padding='same', name='decoder_output')(cx)

# Instantiate decoder
decoder = Model(d_i, o, name='decoder')
decoder.summary()

# =================
# VAE as a whole
# =================

# Instantiate VAE
vae_outputs = decoder(encoder(i)[2])
vae         = Model(i, vae_outputs, name='vae')
vae.summary()

# Define loss
def kl_reconstruction_loss(true, pred):
    # Reconstruction loss
    reconstruction_loss = binary_crossentropy(K.flatten(true), K.flatten(pred)) * img_width * img_height
    # KL divergence loss
    kl_loss = 1 + sigma - K.square(mu) - K.exp(sigma)
    kl_loss = K.sum(kl_loss, axis=-1)
    kl_loss *= -0.5
    # Total loss = reconstruction loss + KL divergence loss, equally weighted
    return K.mean(reconstruction_loss + kl_loss)

# Compile VAE
vae.compile(optimizer='adam', loss=kl_reconstruction_loss)

# Train autoencoder
vae.fit(input_train, input_train, epochs = no_epochs, batch_size = batch_size, validation_split = validation_split)

# =================
# Results visualization
# Credits for original visualization code: https://keras.io/examples/variational_autoencoder_deconv/
# (François Chollet).
# Adapted to accommodate this VAE.
# =================
def viz_latent_space(encoder, data):
    input_data, target_data = data
    mu, _, _ = encoder.predict(input_data)
    plt.figure(figsize=(8, 10))
    plt.scatter(mu[:, 0], mu[:, 1], c=target_data)
    plt.xlabel('z - dim 1')
    plt.ylabel('z - dim 2')
    plt.colorbar()
    plt.show()

def viz_decoded(encoder, decoder, data):
    num_samples = 15
    figure = np.zeros((img_width * num_samples, img_height * num_samples, num_channels))
    grid_x = np.linspace(-4, 4, num_samples)
    grid_y = np.linspace(-4, 4, num_samples)[::-1]
    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]])
            x_decoded = decoder.predict(z_sample)
            digit = x_decoded[0].reshape(img_width, img_height, num_channels)
            figure[i * img_width: (i + 1) * img_width,
                   j * img_height: (j + 1) * img_height] = digit
    plt.figure(figsize=(10, 10))
    start_range = img_width // 2
    end_range = num_samples * img_width + start_range + 1
    pixel_range = np.arange(start_range, end_range, img_width)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel('z - dim 1')
    plt.ylabel('z - dim 2')
    # matplotlib.pyplot.imshow() needs a 2D array, or a 3D array with the
    # third dimension being of shape 3 or 4, so reshape if necessary
    fig_shape = np.shape(figure)
    if fig_shape[2] == 1:
        figure = figure.reshape((fig_shape[0], fig_shape[1]))
    # Show image
    plt.imshow(figure)
    plt.show()

# Plot results
data = (input_test, target_test)
viz_latent_space(encoder, data)
viz_decoded(encoder, decoder, data)
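To make the reparameterization trick and the KL divergence term from the loss function above concrete, here is a small pure-Python sketch (standard library only). Note that the `sigma` output of the encoder is interpreted as a log-variance, which is why `sample_z` uses `K.exp(sigma / 2)` as the standard deviation:

```python
import math
import random

def sample_z(mu, log_var):
    # Reparameterization trick: z = mu + std * eps, with eps ~ N(0, 1).
    # The encoder's 'sigma' output is a log-variance, so the standard
    # deviation equals exp(log_var / 2).
    eps = random.gauss(0.0, 1.0)
    return mu + math.exp(log_var / 2) * eps

def kl_divergence(mu, log_var):
    # Closed-form KL divergence between N(mu, exp(log_var)) and N(0, 1),
    # i.e. the per-dimension kl_loss term from kl_reconstruction_loss.
    return -0.5 * (1 + log_var - mu ** 2 - math.exp(log_var))

# A latent that already matches the standard normal incurs no KL penalty;
# shifting its mean away from zero is penalized.
print(kl_divergence(1.0, 0.0))  # 0.5
```

This penalty is exactly what pulls the encoded distributions toward the standard normal, giving the continuous and complete latent space discussed earlier.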

## Results

Now, time for the results 🙂

Training the model for 100 epochs yields this visualization of the latent space:

As we can see, around $$(0, 0)$$ our latent space is pretty continuous as well as complete. Somewhere around $$(0, -1.5)$$ we see some holes, as well as near the edges (e.g. $$(3, -3)$$). We can see these issues in the actual sampling too:

Especially in the right corners, we see the issue with completeness, which yields outputs that do not make sense. Some issues with continuity are visible wherever the samples are blurred. Generally speaking, however, I'm quite happy with the results! 😎

However, let’s see if we can make them even better 🙂

## DCGAN-like architecture

In their paper “Unsupervised representation learning with deep convolutional generative adversarial networks“, Radford et al. (2015) introduce the concept of a deep convolutional generative adversarial network, or DCGAN. While a GAN represents the other branch of generative models, results have suggested that deep convolutional architectures for generative models may produce better results with VAEs as well.

So, as an extension of our original post, we’ve changed the architecture of our model into deeper and wider convolutional layers, in line with Radford et al. (2015). I changed the encoder into:

i       = Input(shape=input_shape, name='encoder_input')
cx      = Conv2D(filters=128, kernel_size=5, strides=2, padding='same', activation='relu')(i)
cx      = BatchNormalization()(cx)
cx      = Conv2D(filters=256, kernel_size=5, strides=2, padding='same', activation='relu')(cx)
cx      = BatchNormalization()(cx)
cx      = Conv2D(filters=512, kernel_size=5, strides=2, padding='same', activation='relu')(cx)
cx      = BatchNormalization()(cx)
cx      = Conv2D(filters=1024, kernel_size=5, strides=2, padding='same', activation='relu')(cx)
cx      = BatchNormalization()(cx)
x       = Flatten()(cx)
x       = Dense(20, activation='relu')(x)
x       = BatchNormalization()(x)
mu      = Dense(latent_dim, name='latent_mu')(x)
sigma   = Dense(latent_dim, name='latent_sigma')(x)

And the decoder into:

# Definition
d_i   = Input(shape=(latent_dim, ), name='decoder_input')
x     = Dense(conv_shape[1] * conv_shape[2] * conv_shape[3], activation='relu')(d_i)
x     = BatchNormalization()(x)
x     = Reshape((conv_shape[1], conv_shape[2], conv_shape[3]))(x)
cx    = Conv2DTranspose(filters=1024, kernel_size=5, strides=2, padding='same', activation='relu')(x)
cx    = BatchNormalization()(cx)
cx    = Conv2DTranspose(filters=512, kernel_size=5, strides=2, padding='same', activation='relu')(cx)
cx    = BatchNormalization()(cx)
cx    = Conv2DTranspose(filters=256, kernel_size=5, strides=2, padding='same', activation='relu')(cx)
cx    = BatchNormalization()(cx)
cx    = Conv2DTranspose(filters=128, kernel_size=5, strides=2, padding='same', activation='relu')(cx)
cx    = BatchNormalization()(cx)
o     = Conv2DTranspose(filters=num_channels, kernel_size=3, activation='sigmoid', padding='same', name='decoder_output')(cx)

While our original VAE had approximately 26,000 trainable parameters, this one has approximately 9 million:

_________________________________________________________________
Model: "vae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
encoder_input (InputLayer)   (None, 28, 28, 1)         0
_________________________________________________________________
encoder (Model)              [(None, 2), (None, 2), (N 4044984
_________________________________________________________________
decoder (Model)              (None, 28, 28, 1)         4683521
=================================================================
Total params: 8,728,505
Trainable params: 8,324,753
Non-trainable params: 403,752
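These counts can be sanity-checked by hand: a Conv2D layer has filters × (kernel_height × kernel_width × input_channels + 1) trainable parameters, the + 1 being one bias per filter. A quick sketch for the first layers of both encoders:

```python
def conv2d_params(filters, kernel_size, in_channels):
    # Each filter has kernel_size * kernel_size * in_channels weights plus one bias
    return filters * (kernel_size * kernel_size * in_channels + 1)

# First Conv2D of the original encoder: 8 filters, 3x3 kernel, 1 input channel
print(conv2d_params(8, 3, 1))      # 80

# First two Conv2D layers of the DCGAN-like encoder (5x5 kernels)
print(conv2d_params(128, 5, 1))    # 3328
print(conv2d_params(256, 5, 128))  # 819456
```

The deeper and wider 5x5 layers dominate the total, which is why the DCGAN-like variant ends up orders of magnitude larger than the original.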

However, even after training it for only 5 epochs, results have become considerably better:

Latent space (left) also looks better compared to our initial VAE (right):

Interestingly, however, the left plot is actually zoomed in, as we now also have some outliers:

Interesting result 🙂

## Summary

In this blog post, we’ve seen how to create a variational autoencoder with Keras. We first looked at what VAEs are, and why they are different from regular autoencoders. We then created a neural network implementation with Keras and explained it step by step, so that you can easily reproduce it yourself while understanding what happens.

To compare with our initial 26K-parameter VAE, we expanded the architecture to resemble a DCGAN-like architecture of approximately 9M parameters, for both the encoder and the decoder. This yielded better results, but also increased the number of outliers.

I hope you’ve learnt something from this article 🙂 If you did, please let me know by leaving a comment in the comments section below! 👇 If you have questions or remarks, please do the same!

Thank you for reading MachineCurve today and happy engineering 😎

## References

Kingma, D. P., & Welling, M. (2013). Auto-encoding variational Bayes. arXiv preprint arXiv:1312.6114.

Wiseodd. (2016, December 10). Variational Autoencoder: Intuition and Implementation. Retrieved from http://wiseodd.github.io/techblog/2016/12/10/variational-autoencoder/

Radford, A., Metz, L., & Chintala, S. (2015). Unsupervised representation learning with deep convolutional generative adversarial networks. arXiv preprint arXiv:1511.06434.


## 30 thoughts on “How to create a variational autoencoder with Keras?”

Please fix how the code is presented on this website. I am sure you couldnt make it run yourself copy pasting it from this site. Spacing is weird, code is only in one single line instead of multiple lines. Fix this please for the love of god.

1. Chris

Hi there,
Thanks for your comment. I am aware of the issue and am looking for a fix. Most likely, I can spend some time on the matter tomorrow.
Regards,
Chris

2. Chris

Hi there,
Things should be normal again!
Regards,
Chris

2. Robin

Nice post – I think there is an error though in the larger decoder. You pass x instead of cx

1. Chris

Hi Robin,
I’ve looked at the post and if by “larger decoder” you mean the DCGAN-like one, indeed, I noticed the error there. I’ve adapted the post. Thanks!
Regards,
Chris

1. Edwin

Hi Chris. I am runing your code but the size of the input of the encode and the output of the decoder are different:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
encoder_input (InputLayer) (None, 28, 28, 1) 0
_________________________________________________________________
encoder (Model) [(None, 2), (None, 2), (N 17298104
_________________________________________________________________
decoder (Model) (None, 32, 32, 1) 43457025
=================================================================

I think something is missing in decoder part. I copy and page the code above.

3. Marc

Amazing Post!
Very detailed write-up and well explained 🙂

That said I can’t run the code because I keep getting an issue with vae.fit() (I run tf v2.3.0)
https://github.com/tensorflow/probability/issues/519

Worked around it by adding the line
tf.config.experimental_run_functions_eagerly(True)

1. Chris

Hi Marc,

Thanks for the comment and the addition!
I’ll make sure to add your fix to the post, for those who can’t run it because of their newer TF versions.

Thanks again, best,
Chris

2. pankaj

Thanks for this comment, It helped me fix the code instantly. Also Kudos for this amazing post Chris.

1. Chris

Thank you Pankaj!

4. Herbz

Hi i am running on my own dataset (100x 100 px grayscale images).
It works fine when I use a autoencoder, but when I use VAE, the custom loss with kl_loss, my training process just produce loss: 0.0000e+00.

I have been trying whole afternoon. Do you have any suggestions?

1. Chris

Hi Herbz,

You might need to scale your 0-255 data to 0-1 by e.g. applying min-max normalization.

Best,
Chris

1. Rytis

I noticed the same behavior, it gives 0.0000e+00. Issue is related with ‘kl_reconstruction_loss’ I can see that it produces valid error value with binary_crossentropy, however when it comes to kl_loss it gives (Tensor(“Mean:0”, shape=(), dtype=float32)

1. Chris

Hi Rytis,

That’s interesting. Could be due to TF updates. I’ll look into it. Please make sure to post a solution if you have one. Thanks!

Best,
Chris

Hi Chris! I got an error when calculating KL divergence loss (somewhere at these lines, yet I cannot figure out)
# KL divergence loss
kl_loss = 1 + sigma - K.square(mu) - K.exp(sigma)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5

I try removing the codes above (which means now I’m only using reconstruction loss) and I found that the model starts to train properly, even though it is not how VAE supposed to be. Please let me know if you know how to make the KL divergence loss work as it should be. Thankyou!

Here’s the error by the way:
_SymbolicException: Inputs to eager execution function cannot be Keras symbolic tensors, but found [, ]

1. Chris

Thanks for your response. The issue seems to be related to the fact that you’re trying to run the VAE with Eager execution enabled, while the code was created for Graph based computation (see https://www.machinecurve.com/index.php/2020/09/13/tensorflow-eager-execution-what-is-it/) for the difference.

However, as eager execution is now the default, I must look for a solution. However – https://medium.com/tensorflow/variational-autoencoders-with-tensorflow-probability-layers-d06c658931b7#57c6 – seems to provide one for you. Instead of computing KL divergence yourself, you could add a MultivariateNormalTriL layer with KLDivergenceRegularizer activity regularizer, which apparently shows the same behavior.

Training with reconstruction loss will work, but should produce less adequate results given the fact that it does not “contract” training space and should not necessarily produce completeness (https://www.machinecurve.com/index.php/2019/12/24/what-is-a-variational-autoencoder-vae/#second-difference-kl-divergence-reconstruction-error-for-optimization). Good to hear that you’ve found my post useful and I hope that this answer helps!

Best,
Chris

Thanks a lot for the references! But here I got simpler approach to solve my problem (or probably the simplest one I guess). So what I did was importing tensorflow as tf, then I do the following steps:

First I noticed that indeed the eager mode is enabled (by default) as the following code returns True.
tf.executing_eagerly()

What I need to do then is just to run tf.compat.v1.disable_eager_execution() which completely disables eager mode.

Now if we run tf.executing_eagerly() then the output is going to be False.

Finally, as eager mode is already deactivated, we can just run the entire code without error at all.

Thanks again!

Note: Feel free to visit my Medium 🙂 https://medium.com/@muhammad_ardi

1. Chris

That’s indeed a simpler solution!

Best,
Chris

1. Atin Vikram Singh

Hi Chris,
The blog and this comment were indeed helpful.
However, I am getting an error on running the code after disabling eager mode

“ /usr/local/lib/python3.7/dist-packages/tensorflow/python/client/session.py in __call__(self, *args, **kwargs)
1480 ret = tf_session.TF_SessionRunCallable(self._session._session,
1481 self._handle, args,

FailedPreconditionError: Could not find variable training/Adam/beta_1. This could mean that the variable has been deleted. In TF1, it can also mean the variable is uninitialized. Debug info: container=localhost, status=Not found: Resource localhost/training/Adam/beta_1/N10tensorflow3VarE does not exist.

I’ve tried doing tf.compat.v1.global_variables_initializer() but the error persists. tf version = 2.5

Regards, Atin

2. Chris

Hi Atin,
What does the code for specifying the optimizer look like?
Is it simply optimizer='adam' or are you using a different way (e.g. by passing an Adam() object)?
Best,
Chris

3. Atin Vikram Singh

Hi Chris,
Apparently changing keras to tensorflow.keras at the top did the trick. Must be some issue with tensorflow 2.5.0
Best Wishes

4. Chris

Great that it helped!

Best,
Chris

6. Faaiz

Hello,
I am using your approach for my own 1d dataset with shape like (samples, timesteps, dimensions). I get an error at vae.fit. the error is:
ValueError: Dimensions must be equal, but are 376 and 32 for ‘{{node AddV2_4}} = AddV2[T=DT_FLOAT, _cloned=true](inputs, inputs_1)’ with input shapes: [32,376], [32].

Do you know why this is?

1. Chris

Hi Faaiz,

This error is very common in cases where there is a mismatch related to the shape of your input data.
What is the precise error stack (i.e. where does the error occur precisely; in what part of the model?)
And could you share the relevant parts of your code?

If I have some additional time, I can try and take a look.

Best,
Chris

7. Pablo

Hi, excellent post, very clear and well explained, however when I run the code, I got this error:

“ValueError: Output tensors to a Model must be the output of a TensorFlow Layer (thus holding past layer metadata). Found: Tensor(“lambda/add:0″, shape=(?, 2), dtype=float32)”

Debbuging code I found the error is generate by z variable in this line:
# Instantiate encoder
encoder = Model(i, [mu, sigma, z], name='encoder')

Thanks

1. Chris

Hi Pablo,

Could be the case that newer TensorFlow versions don’t support this as natively anymore, and I think that the issue resolves with some addition. Perhaps seeing how Add() layers work can help you?
I’d have to check and run this code some day, but something related to time … 🙁

Best,
Chris

8. David

Very useful thank you!

1. Chris

Thank you David!

Best,
Chris