How to visualize a model with TensorFlow 2 and Keras?

How to visualize a model with TensorFlow 2 and Keras?

Last Updated on 26 January 2021

Every now and then, you might need to demonstrate your Keras model structure. There’s one or two things that you may do when this need arises. First, you may send the person who needs this overview your code, requiring them to derive the model architecture themselves. If you’re nicer, you send them a model of your architecture.

…but creating such models is often a hassle when you have to do it manually. Solutions like www.draw.io are used quite often in those cases, because they are (relatively) quick and dirty, allowing you to create models fast.

However, there’s a better solution: the built-in plot_model facility within Keras. It allows you to create a visualization of your model architecture. In this blog, I’ll show you how to create such a visualization. Specifically, I focus on the model itself, discussing its architecture so that you fully understand what happens. Subsquently, I’ll list some software dependencies that you’ll need – including a highlight about a bug in Keras that results in a weird error related to pydot and GraphViz, which are used for visualization. Finally, I present you the code used for visualization and the end result.

After reading this tutorial, you will…

  • Understand what the plot_model() util in TensorFlow 2.0/Keras does.
  • Know what types of plots it generates.
  • Have created a neural network that visualizes its structure.

Note that model code is also available on GitHub.


Update 22/Jan/2021: ensured that the tutorial is up-to-date and reflects code for TensorFlow 2.0. It can now be used with recent versions of the library. Also performed some header changes and textual improvements based on the switch from Keras 1.0 to TensorFlow 2.0. Also added an exampl of horizontal plotting.



Code example: using plot_model for visualizing the model

If you want to get started straight away, here is the code that you can use for visualizing your TensorFlow 2.0/Keras model with plot_model:

from tensorflow.keras.utils import plot_model plot_model(model, to_file='model.png')
Code language: Python (python)

Make sure to read the rest of this tutorial if you want to understand everything in more detail!


Today’s to-be-visualized model

To show you how to visualize a Keras model, I think it’s best if we discussed one first.

Today, we will visualize the Convolutional Neural Network that we created earlier to demonstrate the benefits of using CNNs over densely-connected ones.

This is the code of that model:

import tensorflow from tensorflow.keras.datasets import mnist from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Dense, Dropout, Flatten from tensorflow.keras.layers import Conv2D, MaxPooling2D # Model configuration img_width, img_height = 28, 28 batch_size = 250 no_epochs = 25 no_classes = 10 validation_split = 0.2 verbosity = 1 # Load MNIST dataset (input_train, target_train), (input_test, target_test) = mnist.load_data() # Set input shape sample_shape = input_train[0].shape img_width, img_height = sample_shape[0], sample_shape[1] input_shape = (img_width, img_height, 1) # Reshape data input_train = input_train.reshape(len(input_train), input_shape[0], input_shape[1], input_shape[2]) input_test = input_test.reshape(len(input_test), input_shape[0], input_shape[1], input_shape[2]) # Parse numbers as floats input_train = input_train.astype('float32') input_test = input_test.astype('float32') # Convert them into black or white: [0, 1]. input_train = input_train / 255 input_test = input_test / 255 # Convert target vectors to categorical targets target_train = tensorflow.keras.utils.to_categorical(target_train, no_classes) target_test = tensorflow.keras.utils.to_categorical(target_test, no_classes) # Create the model model = Sequential() model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape)) model.add(MaxPooling2D(pool_size=(2, 2))) model.add(Dropout(0.25)) model.add(Conv2D(64, kernel_size=(3, 3), activation='relu')) model.add(MaxPooling2D(pool_size=(2, 2))) model.add(Dropout(0.25)) model.add(Flatten()) model.add(Dense(256, activation='relu')) model.add(Dense(no_classes, activation='softmax')) # Compile the model model.compile(loss=tensorflow.keras.losses.categorical_crossentropy, optimizer=tensorflow.keras.optimizers.Adam(), metrics=['accuracy']) # Fit data to model model.fit(input_train, target_train, batch_size=batch_size, epochs=no_epochs, verbose=verbosity, validation_split=validation_split) # Generate generalization metrics score = model.evaluate(input_test, target_test, verbose=0) print(f'Test loss: {score[0]} / Test accuracy: {score[1]}')
Code language: Python (python)

What does it do?

I’d suggest that you read the post if you wish to understand it very deeply, but I’ll briefly cover it here.

It simply classifies the MNIST dataset. This dataset contains 28 x 28 pixel images of digits, or numbers between 0 and 9, and our CNN classifies them with a staggering 99% accuracy. It does so by combining two convolutional blocks (which consist of a two-dimensional convolutional layer, two-dimensional max pooling and dropout) with densely-conneted layers. It’s the best of both worlds in terms of interpreting the image and generating final predictions.

Let's pause for a second! πŸ‘©β€πŸ’»

Blogs at MachineCurve teach Machine Learning for Developers. Sign up to MachineCurve's free Machine Learning update today! You will learn new things and better understand concepts you already know.

We send emails at least every Friday. Welcome!

By signing up, you consent that any information you receive can include services and special offers by email.

But how to visualize this model’s architecture? Let’s find out.


Built-in plot_model util

Utilities. I love them, because they make my life easier. They’re often relatively simple functions that can be called upon to perform some relatively simple actions. Don’t be fooled, however, because these actions often benefit one’s efficiently greatly – in this case, not having to visualize a model architecture yourself in tools like draw.io

I’m talking about the plot_model util, which comes delivered with Keras.

It allows you to create a visualization of your Keras neural network.

More specifically, the Keras docs define it as follows:

from tensorflow.keras.utils import plot_model plot_model(model, to_file='model.png')
Code language: Python (python)

From the Keras utilities, one needs to import the function, after which it can be used with very minimal parameters:

  • The model instance, or the model that you created – whether you created it now or preloaded it instead from a model saved to disk.
  • And the to_file parameter, which essentially specifies a location on disk where the model visualization is stored.

If you wish, you can supply some additional parameters as well:

  • The show_shapes argument (which is False by default) which controls whether the shape of the layer outputs are shown in the graph. This would be beneficial if besides the architecture you also need to understand how it transforms data.
  • With show_dtypes (False by default) you can indicate whether to show layer data types on the plot.
  • The show_layer_names argument (True by default) which determines whether the names of the layers are displayed.
  • The rankdir (TB by default) can be used to indicate whether you want a vertical or horizontal plot. TB is vertical, LR is horizontal.
  • The expand_nested (False by default) controls how nested models are displayed.
  • Dpi controls the dpi value of the image.

However, likely, for a simple visualization, you don’t need them. Let’s now take a look what we would need if we were to create such a visualization.


Software dependencies

If you wish to run the code presented in this blog successfully, you need to install certain software dependencies. You’ll need those to run it:

  • TensorFlow 2.0, or any subsequent version, which makes sense given the fact that we’re using a Keras util for model visualization;
  • Python, preferably 3.8+, which is required if you wish to run Keras.
  • Graphviz, which is a graph visualization library for Python. Keras uses it to generate the visualization of your neural network. You can install Graphviz from their website.

Preferably, you’ll run this from an Anaconda environment, which allows you to run these packages in an isolated fashion. Note that many people report that a pip based installation of Graphviz doesn’t work; rather, you’ll have to install it separately into your host OS from their website. Bummer!

Keras bug: pydot failed to call GraphViz

When trying to visualize my Keras neural network with plot_model, I ran into this error:

'`pydot` failed to call GraphViz.' OSError: `pydot` failed to call GraphViz.Please install GraphViz (https://www.graphviz.org/) and ensure that its executables are in the $PATH.
Code language: JavaScript (javascript)

…which essentially made sense at first, because I didn’t have Graphviz installed.

…but which didn’t after I installed it, because the error kept reappearing, even after restarting the Anaconda terminal.

Sometimes, it helps to install pydotplus as well with pip install pydotplus. Another solution, although not preferred, is to downgrade your pydot version.


Visualization code

When adapting the code from my original CNN, scrapping away the elements I don’t need for visualizing the model architecture, I end up with this:

Never miss new Machine Learning articles βœ…

Blogs at MachineCurve teach Machine Learning for Developers. Sign up to MachineCurve's free Machine Learning update today! You will learn new things and better understand concepts you already know.

We send emails at least every Friday. Welcome!

By signing up, you consent that any information you receive can include services and special offers by email.
import tensorflow from tensorflow.keras.datasets import mnist from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Dense, Dropout, Flatten from tensorflow.keras.layers import Conv2D, MaxPooling2D from tensorflow.keras.utils import plot_model # Load MNIST dataset (input_train, target_train), (input_test, target_test) = mnist.load_data() # Set input shape sample_shape = input_train[0].shape img_width, img_height = sample_shape[0], sample_shape[1] input_shape = (img_width, img_height, 1) # Number of classes no_classes = 10 # Reshape data input_train = input_train.reshape(len(input_train), input_shape[0], input_shape[1], input_shape[2]) input_test = input_test.reshape(len(input_test), input_shape[0], input_shape[1], input_shape[2]) # Create the model model = Sequential() model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape)) model.add(MaxPooling2D(pool_size=(2, 2))) model.add(Dropout(0.25)) model.add(Conv2D(64, kernel_size=(3, 3), activation='relu')) model.add(MaxPooling2D(pool_size=(2, 2))) model.add(Dropout(0.25)) model.add(Flatten()) model.add(Dense(256, activation='relu')) model.add(Dense(no_classes, activation='softmax')) plot_model(model, to_file='model.png')
Code language: Python (python)

You’ll first perform the imports that you still need in order to successfully run the Python code. Specifically, you’ll import the Keras library, the Sequential API and certain layers – this is obviously dependent on what you want. Do you want to use the Functional API? That’s perfectly fine. Other layers? Fine too. I just used them since the CNN is exemplary.

Note that I also imported plot_model with from tensorflow.keras.utils import plot_model and reshaped the data to accomodate for the Conv2D layer.

Speaking about architecture: that’s what I kept in. Based on the Keras Sequential API, I apply the two convolutional blocks as discussed previously, before flattening their output and feeding it to the densely-connected layers generating the final prediction. And, of course, we need no_classes = 10 to ensure that our final Dense layer works as well.

However, in this case, no such prediction is generated. Rather, the model instance is used by plot_model to generate a model visualization stored at disk as model.png. Likely, you’ll add hyperparameter tuning and data fitting later on – but hey, that’s not the purpose of this blog.


End result

And your final end result looks like this:


Making a horizontal TF/Keras model plot

Indeed, above we saw that we can use the rankdir attribute (which is set to TB i.e. vertical by default) to generate a horizontal plot! This is new, and highly preferred, as we sometimes don’t want these massive vertical plots.

Making a horizontal plot of your TensorFlow/Keras model simply involves adding the rankdir='LR' a.k.a. horizontal attribute:

plot_model(model, to_file='model.png', rankdir='LR')
Code language: Python (python)

Which gets you this:

Awesome!


Summary

In this blog, you’ve seen how to create a Keras model visualization based on the plot_model util provided by the library. I hope you found it useful – let me know in the comments section, I’d appreciate it! 😎 If not, let me know as well, so I can improve. For now: happy engineering! πŸ‘©β€πŸ’»

Note that model code is also available on GitHub.


References

How to create a CNN classifier with Keras? – MachineCurve. (2019, September 24). Retrieved from https://www.machinecurve.com/index.php/2019/09/17/how-to-create-a-cnn-classifier-with-keras/

Keras. (n.d.). Visualization. Retrieved from https://keras.io/visualization/

Avoid wasting resources with EarlyStopping and ModelCheckpoint in Keras – MachineCurve. (2019, June 3). Retrieved from https://www.machinecurve.com/index.php/2019/05/30/avoid-wasting-resources-with-earlystopping-and-modelcheckpoint-in-keras/

pydot issue Β· Issue #7 Β· XifengGuo/CapsNet-Keras. (n.d.). Retrieved from https://github.com/XifengGuo/CapsNet-Keras/issues/7#issuecomment-536100376

Join hundreds of other learners! 😎

Blogs at MachineCurve teach Machine Learning for Developers. Sign up to MachineCurve's free Machine Learning update today! You will learn new things and better understand concepts you already know.

We send emails at least every Friday. Welcome!

By signing up, you consent that any information you receive can include services and special offers by email.

TensorFlow. (2021). Tf.keras.utils.plot_modelhttps://www.tensorflow.org/api_docs/python/tf/keras/utils/plot_model

Do you want to start learning ML from a developer perspective? πŸ‘©β€πŸ’»

Blogs at MachineCurve teach Machine Learning for Developers. Sign up to learn new things and better understand concepts you already know. We send emails every Friday.

By signing up, you consent that any information you receive can include services and special offers by email.

9 thoughts on “How to visualize a model with TensorFlow 2 and Keras?

  1. Prem K Tiwari

    Graphviz module does not sits well with pydot. i tried everything !

    1. Chris

      What errors are you getting?

      Best,
      Chris

    2. Chris

      Hi Prem,

      When updating the article, I ran into this issue as well. You might try the following – installing pydotplus as well, with pip install pydotplus.

      All the best,
      Chris

Leave a Reply

Your email address will not be published. Required fields are marked *