Recent advances in training deep neural networks have produced impressive machine learning models that can tackle a very diverse range of tasks. One notable downside of developing such a model is that it is a “black box”: the model learns from the data you feed it, but you don’t really know what is going on inside it.

To make it clearer: you don’t really know what your model has actually learned, and if there is a flaw in your training procedure or your data, the model might score well on your metrics while having learned the wrong thing. As a self-respecting developer you want to do better than that, so today I will show you a method for getting better insight into your model using visualization techniques.

So what is a visualization technique when we talk about deep neural networks?

The basic idea of visualization is that you try to figure out what your network is tuned towards by letting it manipulate your input image to make the image as “exciting” as possible for the network. It’s kind of like letting a kid decide which flavors and how much ice cream they want - you will end up with 9 scoops of all flavors.

Let’s look at the ice cream flavors and scoops of an example neural network. To keep things simple, I’ll focus on the well-known MNIST problem in PyTorch, so that everyone can build an intuition for what the visualization shows us.

Example PyTorch network: MNIST classification

So, first, we need code to train a PyTorch network for MNIST. As this is a simple problem we don’t need a very deep network, but the techniques presented here work the same way for very deep ones.

In MNIST, the task of the network is to classify images of the handwritten digits 0 to 9.

Example images look like this:

TODO Insert example images

I used the MNIST code from the PyTorch examples repository; its network consists of two convolutional layers followed by two fully connected layers. After training for 10 epochs on the 60000 training examples, it achieved 99% accuracy, so it is working well.
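To have something concrete to look at, a network of this shape might look like the sketch below. The layer sizes are my assumptions for 28x28 grayscale input and are not necessarily the exact ones used in the examples repository; the random batch only stands in for real MNIST images.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Two convolutional layers followed by two fully connected layers."""

    def __init__(self):
        super().__init__()
        # Layer sizes are assumptions chosen for 28x28 grayscale input.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(64 * 12 * 12, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))   # -> (N, 32, 26, 26)
        x = F.relu(self.conv2(x))   # -> (N, 64, 24, 24)
        x = F.max_pool2d(x, 2)      # -> (N, 64, 12, 12)
        x = torch.flatten(x, 1)     # -> (N, 9216)
        x = F.relu(self.fc1(x))
        # One log-probability per digit class.
        return F.log_softmax(self.fc2(x), dim=1)


model = Net()
batch = torch.randn(4, 1, 28, 28)  # random stand-in for four MNIST images
log_probs = model(batch)
```

The output has one row per image and one column per digit, which is what the visualization later pushes on.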

Visualization Intuition

When we train a neural net, we use backpropagation to adjust the weights of the network according to the error gradients that flow backwards from the output. During training the input image is fixed and the weights are variable.

In visualization we do the exact opposite: we keep the weights fixed, since we want to evaluate a given model, and use backpropagation to change the input image so that the network gives “better” results. What “better” means is where several possibilities come in. We could be interested in which input image our network considers most typical for a specific class, e.g. the digit 5 in our MNIST example. The task then becomes: create an image that makes the network's output for the class 5 as large as possible.
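As a reminder of what a normal training step does, here is a minimal sketch. The tiny linear model and the random batch are stand-ins I made up for illustration, not the network from this post.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Tiny stand-in classifier (an assumption, not the MNIST network above).
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

images = torch.randn(8, 1, 28, 28)   # random stand-in for a batch of images
labels = torch.randint(0, 10, (8,))  # random stand-in labels

weights_before = model[1].weight.detach().clone()

optimizer.zero_grad()
loss = F.cross_entropy(model(images), labels)
loss.backward()    # gradients flow backwards from the output
optimizer.step()   # only the weights are updated

weights_changed = not torch.equal(weights_before, model[1].weight)
input_has_grad = images.grad is not None  # the input is fixed, so no gradient
```

After the step the weights have moved, while the input batch never received a gradient at all.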

Visualization Implementation

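We need to make sure that our network is fixed, so its weights don’t change when we use backpropagation. One way to do so in PyTorch is to switch the model to evaluation mode with model.eval() and to turn off gradient tracking for every parameter by calling requires_grad_(False) on it.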
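Freezing a network in PyTorch can be sketched like this. The small `nn.Sequential` model is an untrained stand-in of my own; in practice you would use your trained MNIST network instead.

```python
import torch
import torch.nn as nn

# Untrained stand-in model (an assumption for illustration).
model = nn.Sequential(
    nn.Conv2d(1, 8, 3), nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 26 * 26, 10),
)

model.eval()                     # evaluation behaviour for dropout / batch norm
for param in model.parameters():
    param.requires_grad_(False)  # autograd no longer computes weight gradients

frozen = all(not p.requires_grad for p in model.parameters())
```

Note that `model.eval()` alone does not stop the weights from receiving gradients; it only changes how layers such as dropout behave, which is why both steps are used here.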

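On the other hand, we want the algorithm to change the input image to excite the network more and more, so we need to make sure that the input image can be adjusted. One way to do so in PyTorch is to create the image as a tensor with requires_grad=True and to hand that tensor, instead of the model's parameters, to the optimizer.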
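Putting both halves together, a minimal sketch of the optimization loop could look as follows. The stand-in model, the Adam optimizer, the learning rate and the number of steps are all assumptions of mine; with a trained network the resulting image is what you would inspect.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Untrained stand-in model (an assumption); use your trained network instead.
model = nn.Sequential(
    nn.Conv2d(1, 8, 3), nn.ReLU(),
    nn.Conv2d(8, 16, 3), nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(16 * 12 * 12, 32), nn.ReLU(),
    nn.Linear(32, 10),
)
model.eval()
for param in model.parameters():
    param.requires_grad_(False)   # weights stay fixed

target_class = 5
# Start from faint noise and let autograd track gradients for the image.
image = (0.1 * torch.randn(1, 1, 28, 28)).requires_grad_(True)
optimizer = torch.optim.Adam([image], lr=0.1)  # optimizes the image only

initial_score = model(image)[0, target_class].item()
for step in range(50):
    optimizer.zero_grad()
    score = model(image)[0, target_class]
    (-score).backward()   # minimizing -score is gradient ascent on the score
    optimizer.step()
final_score = model(image)[0, target_class].item()
```

The score for the target class goes up while the weights stay untouched; `image.detach()` is then the picture to plot.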

Visualizing the deeper parts

So far, we asked the high-level question: what does an input image look like that excites the network most? Another interesting question we can ask is: what do the inputs look like that most excite the individual channels in our layers?

To do this, we need to use so-called hooks to get at the inputs and outputs of these inner parts, and then adjust the input image over several iterations so that it maximally excites the chosen channel.
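Here is a minimal sketch of that idea using a forward hook. The stand-in model, the chosen layer and channel, and the use of the mean activation as the thing to maximize are my assumptions; other objectives are possible.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Untrained stand-in model (an assumption); use your trained network instead.
model = nn.Sequential(
    nn.Conv2d(1, 8, 3), nn.ReLU(),
    nn.Conv2d(8, 16, 3), nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(16 * 12 * 12, 32), nn.ReLU(),
    nn.Linear(32, 10),
)
model.eval()
for param in model.parameters():
    param.requires_grad_(False)

captured = {}

def save_output(module, inputs, output):
    # Called after every forward pass of the hooked layer.
    captured["activation"] = output

layer = model[2]   # second convolutional layer of the stand-in model
handle = layer.register_forward_hook(save_output)

channel = 3
image = (0.1 * torch.randn(1, 1, 28, 28)).requires_grad_(True)
optimizer = torch.optim.Adam([image], lr=0.1)

def channel_mean(img):
    model(img)   # the hook stores the layer output as a side effect
    return captured["activation"][0, channel].mean()

initial = channel_mean(image).item()
for step in range(50):
    optimizer.zero_grad()
    (-channel_mean(image)).backward()   # ascend on the channel's mean activation
    optimizer.step()
final = channel_mean(image).item()

handle.remove()   # detach the hook once we are done
```

Repeating this for each channel of a layer gives one image per channel to compare.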

Further possibilities

Of course, this blog post only scratches the surface of what you can do. You need to go deeper. TODO Insert go deeper meme.

You can visualize the actual inputs each part of the network receives when you present it with a particular image, use more advanced techniques to better visualize the network, and so on.

If you want to know more, leave a comment, and I will consider writing another post about this topic.
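As a small taste of the first idea, the sketch below records what every layer receives and produces for one image. The stand-in model and the random image are assumptions of mine; the layer names "0", "1", ... are simply the ones `nn.Sequential` assigns.

```python
import torch
import torch.nn as nn

# Untrained stand-in model (an assumption); use your trained network instead.
model = nn.Sequential(
    nn.Conv2d(1, 8, 3), nn.ReLU(),
    nn.Conv2d(8, 16, 3), nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(16 * 12 * 12, 32), nn.ReLU(),
    nn.Linear(32, 10),
)
model.eval()

activations = {}

def make_hook(name):
    def hook(module, inputs, output):
        # Store the layer's first input and its output, detached from autograd.
        activations[name] = (inputs[0].detach(), output.detach())
    return hook

handles = [
    module.register_forward_hook(make_hook(name))
    for name, module in model.named_children()
]

image = torch.randn(1, 1, 28, 28)   # random stand-in for one MNIST image
with torch.no_grad():
    model(image)

for handle in handles:
    handle.remove()

shapes = {name: tuple(out.shape) for name, (inp, out) in activations.items()}
```

Each entry in `activations` can then be plotted channel by channel, for example as a grid of small grayscale images.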