How to visualize a model with Keras?

Every now and then, you might need to demonstrate your Keras model structure. There are a couple of things you can do when this need arises. You could send the person who needs this overview your code, requiring them to derive the model architecture themselves. If you're kinder, you send them a diagram of the architecture.

…but creating such diagrams is often a hassle when you have to do it manually. Solutions like www.draw.io are used quite often in those cases, because they are (relatively) quick and dirty, allowing you to sketch a model fast.

However, there's a better solution: the built-in plot_model facility within Keras. It allows you to create a visualization of your model architecture. In this blog, I'll show you how to create such a visualization. Specifically, I first focus on the model itself, discussing its architecture so that you fully understand what happens. Subsequently, I'll list some software dependencies that you'll need – including a note about a bug in Keras that results in a weird error related to pydot and GraphViz, which are used for visualization. Finally, I present the code used for visualization and the end result.

Note that model code is also available on GitHub.

Today’s to-be-visualized model

To show you how to visualize a Keras model, I think it's best if we discuss one first.

Today, we will visualize the Convolutional Neural Network that we created earlier to demonstrate the benefits of using CNNs over densely-connected ones.

This is the code of that model:

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K

# Model configuration
img_width, img_height = 28, 28
batch_size = 250
no_epochs = 25
no_classes = 10
validation_split = 0.2
verbosity = 1

# Load MNIST dataset
(input_train, target_train), (input_test, target_test) = mnist.load_data()

# Reshape data based on channels first / channels last strategy.
# This is dependent on whether you use TF, Theano or CNTK as backend.
# Source: https://github.com/keras-team/keras/blob/master/examples/mnist_cnn.py
if K.image_data_format() == 'channels_first':
    input_train = input_train.reshape(input_train.shape[0], 1, img_width, img_height)
    input_test = input_test.reshape(input_test.shape[0], 1, img_width, img_height)
    input_shape = (1, img_width, img_height)
else:
    input_train = input_train.reshape(input_train.shape[0], img_width, img_height, 1)
    input_test = input_test.reshape(input_test.shape[0], img_width, img_height, 1)
    input_shape = (img_width, img_height, 1)

# Parse numbers as floats
input_train = input_train.astype('float32')
input_test = input_test.astype('float32')

# Normalize the data into the [0, 1] range.
input_train = input_train / 255
input_test = input_test / 255

# Convert target vectors to categorical targets
target_train = keras.utils.to_categorical(target_train, no_classes)
target_test = keras.utils.to_categorical(target_test, no_classes)

# Create the model
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dense(no_classes, activation='softmax'))

# Compile the model
model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adam(),
              metrics=['accuracy'])

# Fit data to model
model.fit(input_train, target_train,
          batch_size=batch_size,
          epochs=no_epochs,
          verbose=verbosity,
          validation_split=validation_split)

# Generate generalization metrics
score = model.evaluate(input_test, target_test, verbose=0)
print(f'Test loss: {score[0]} / Test accuracy: {score[1]}')

What does it do?

I'd suggest that you read the original post if you wish to understand it in depth, but I'll briefly cover it here.

It simply classifies the MNIST dataset. This dataset contains 28 x 28 pixel images of handwritten digits, i.e. the numbers 0 to 9, and our CNN classifies them with approximately 99% accuracy. It does so by combining two convolutional blocks (each consisting of a two-dimensional convolutional layer, two-dimensional max pooling and dropout) with densely-connected layers. It's the best of both worlds in terms of interpreting the image and generating the final predictions.
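
If you want a quick, text-based overview of this architecture before plotting anything, Keras's built-in model.summary() prints the layers, their output shapes and their parameter counts. This is a minimal sketch that assumes the model object from the listing above:

# Assumes the `model` object built in the listing above.
# Prints each layer, its output shape and its parameter count to the console.
model.summary()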

But how to visualize this model’s architecture? Let’s find out.

Built-in plot_model util

Utilities. I love them, because they make my life easier. They're often relatively simple functions that can be called upon to perform some relatively simple actions. Don't be fooled, however, because these actions often benefit one's efficiency greatly – in this case, by not having to visualize a model architecture yourself in tools like draw.io.

I’m talking about the plot_model util, which comes delivered with Keras.

It allows you to create a visualization of your Keras neural network.

More specifically, the Keras docs define it as follows:

from keras.utils import plot_model
plot_model(model, to_file='model.png')

From the Keras utilities, you need to import the function, after which it can be used with minimal parameters:

  • The model instance: the model that you created, whether you just built it in code or preloaded it from a model saved to disk (see the sketch below this list).
  • The to_file parameter, which specifies the location on disk where the model visualization is stored.
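
For example, a hedged sketch of plotting a model that was previously saved to disk could look like this (the file name my_model.h5 is just a placeholder):

from keras.models import load_model
from keras.utils import plot_model

# Load a previously saved model (hypothetical file name) and plot it.
model = load_model('my_model.h5')
plot_model(model, to_file='model.png')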

If you wish, you can supply some additional parameters as well:

  • The show_shapes argument (False by default), which controls whether the output shapes of the layers are shown in the graph. This is useful if, besides the architecture, you also need to understand how the model transforms data.
  • The show_layer_names argument (True by default), which determines whether the names of the layers are displayed.
  • The expand_nested argument (False by default), which controls whether nested models are expanded in the graph.
  • The dpi argument, which controls the resolution (dots per inch) of the generated image. An example call that uses these arguments is shown below.
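
For instance, a call that uses these optional arguments might look as follows (a sketch that assumes a model object like the one above; the argument values are merely examples):

from keras.utils import plot_model

# Plot the architecture including output shapes and layer names,
# with an explicit dpi value for a sharper image.
plot_model(model, to_file='model.png',
           show_shapes=True,
           show_layer_names=True,
           dpi=200)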

However, for a simple visualization, you likely don't need them. Let's now take a look at what we would need if we were to create such a visualization.

Software dependencies

If you wish to run the code presented in this blog successfully, you need to install certain software dependencies. You'll need the following to run it:

  • Keras, which makes sense given the fact that we're using a Keras util for model visualization;
  • TensorFlow, Theano or CNTK, the numerical processing frameworks that lie underneath the Keras framework. Even though you're not effectively training anything here, you'll need to have at least one installed and linked to Keras.
  • Python, preferably 3.6+, which is required if you wish to run Keras.
  • Graphviz, which is graph visualization software. Keras uses it, through the pydot Python package, to generate the visualization of your neural network. You can install Graphviz from its website.

Preferably, you'll run this from an Anaconda environment, which allows you to run these packages in an isolated fashion. Note that many people report that a pip-based installation of Graphviz doesn't work; rather, you'll have to install it separately into your host OS from the Graphviz website. Bummer!
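
If you want to verify up front that both pydot and the Graphviz executables can be found, a small check along these lines might help (a sketch, not part of the original code):

import shutil

# Check whether the pydot Python package is importable.
try:
    import pydot
    print('pydot installed:', pydot.__version__)
except ImportError:
    print('pydot is not installed')

# Check whether the Graphviz `dot` executable is on the system PATH.
print('Graphviz dot found:', shutil.which('dot') is not None)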

Keras bug: pydot failed to call GraphViz

When trying to visualize my Keras neural network with plot_model, I ran into this error:

'`pydot` failed to call GraphViz.'
OSError: `pydot` failed to call GraphViz.Please install GraphViz (https://www.graphviz.org/) and ensure that its executables are in the $PATH.

…which essentially made sense at first, because I didn’t have Graphviz installed.

…but which made less sense after I installed it, because the error kept reappearing, even after restarting the Anaconda terminal.

Fortunately, the internet comes to the rescue in those cases:

Or you can install pydot 1.2.3 by pip.
pip install pydot==1.2.3

XifengGuo

Although downgrading packages is not likely to be the best long-term solution, it certainly worked in this case. The error was resolved and I could generate model visualizations. Let's therefore now take a look at the visualization code.

Visualization code

When adapting the code from my original CNN, stripping away the elements I don't need for visualizing the model architecture, I end up with this:

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.utils.vis_utils import plot_model
from keras import backend as K

# Model configuration
img_width, img_height = 28, 28
no_classes = 10

# Load MNIST dataset
(input_train, target_train), (input_test, target_test) = mnist.load_data()

# Reshape data based on channels first / channels last strategy.
# This is dependent on whether you use TF, Theano or CNTK as backend.
# Source: https://github.com/keras-team/keras/blob/master/examples/mnist_cnn.py
if K.image_data_format() == 'channels_first':
    input_train = input_train.reshape(input_train.shape[0], 1, img_width, img_height)
    input_test = input_test.reshape(input_test.shape[0], 1, img_width, img_height)
    input_shape = (1, img_width, img_height)
else:
    input_train = input_train.reshape(input_train.shape[0], img_width, img_height, 1)
    input_test = input_test.reshape(input_test.shape[0], img_width, img_height, 1)
    input_shape = (img_width, img_height, 1)

# Create the model
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dense(no_classes, activation='softmax'))

plot_model(model, to_file='model.png')

You'll first perform the imports that are still needed to run the Python code successfully. Specifically, you import the Keras library, the Sequential API and certain layers – this obviously depends on what you want. Do you want to use the Functional API? That's perfectly fine. Other layers? Fine too. I just used these because the CNN serves as an example.

Note that I also imported plot_model with from keras.utils.vis_utils import plot_model, as well as the Keras backend as K, which is needed for the channels-first / channels-last check in the reshaping code.

Subsequently, I kept the MNIST-specific reshaping based on the channels-first / channels-last convention of the backend. Although this might not be necessary in your model, I had to keep it in because the first Conv2D layer's input shape depends on input_shape, which is itself produced by the reshaping process. For the sake of simplicity, I thus kept it MNIST-specific. If you don't use MNIST, however, you can simply leave this out, because it's the architecture that actually matters – a stripped-down sketch follows below.
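
To illustrate that point, a stripped-down sketch could hard-code the input shape instead of deriving it from the MNIST data (this assumes a channels-last backend such as TensorFlow and 28 x 28 grayscale inputs):

from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D
from keras.utils.vis_utils import plot_model

# Hard-coded input shape: 28 x 28 pixels, one (grayscale) channel, channels last.
input_shape = (28, 28, 1)
no_classes = 10

# Same architecture as above, without any dataset-specific preprocessing.
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dense(no_classes, activation='softmax'))

plot_model(model, to_file='model.png')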

Speaking about architecture: that's what I finally kept in. Using the Keras Sequential API, I apply the two convolutional blocks discussed previously, before flattening their output and feeding it to the densely-connected layers that generate the final prediction.

However, in this case, no such prediction is generated. Rather, the model instance is used by plot_model to generate a model visualization stored on disk as model.png. Likely, you'll add hyperparameter tuning and data fitting later on – but hey, that's not the purpose of this blog.

End result

The end result looks like this:

Summary

In this blog, you've seen how to create a Keras model visualization based on the plot_model util provided by the library. I hope you found it useful – let me know in the comments section, I'd appreciate it! 😎 If not, let me know as well, so I can improve. For now: happy engineering! 👩‍💻

Note that model code is also available on GitHub.

References

How to create a CNN classifier with Keras? – MachineCurve. (2019, September 24). Retrieved from https://www.machinecurve.com/index.php/2019/09/17/how-to-create-a-cnn-classifier-with-keras/

Keras. (n.d.). Visualization. Retrieved from https://keras.io/visualization/

Avoid wasting resources with EarlyStopping and ModelCheckpoint in Keras – MachineCurve. (2019, June 3). Retrieved from https://www.machinecurve.com/index.php/2019/05/30/avoid-wasting-resources-with-earlystopping-and-modelcheckpoint-in-keras/

pydot issue · Issue #7 · XifengGuo/CapsNet-Keras. (n.d.). Retrieved from https://github.com/XifengGuo/CapsNet-Keras/issues/7#issuecomment-536100376
