
# Tutorial: Implementing Layer-Wise Relevance Propagation

first version: Jul 14, 2016
last update: Sep 17, 2019


This tutorial explains how to implement layer-wise relevance propagation (LRP) easily and efficiently, as described in the overview paper:

G. Montavon, A. Binder, S. Lapuschkin, W. Samek, K.-R. Müller
Layer-wise Relevance Propagation: An Overview
in Explainable AI: Interpreting, Explaining and Visualizing Deep Learning, Springer LNCS, vol. 11700, 2019
We consider two models: (1) a simple plain deep rectifier network trained on the MNIST handwritten digits data, and (2) the VGG-16 network trained on ImageNet and applicable to general image classification.

Note: If you are instead looking for ready-to-use software, have a look at the software section of this website. If you want to try relevance propagation without installing software, check out our interactive demos. For the original paper on LRP, see:

S. Bach, A. Binder, G. Montavon, F. Klauschen, K.-R. Müller, W. Samek
On pixel-wise explanations for non-linear classifier decisions by layer-wise relevance propagation
PLoS ONE 10 (7), e0130140, 2015

## 1 Numpy Implementation for a Fully-Connected Network

We first load 12 exemplary MNIST test digits.

```python
import utils

X,T = utils.loaddata()
%matplotlib inline
utils.digit(X.reshape(1,12,28,28).transpose(0,2,1,3).reshape(28,12*28),9,0.75)
```

![png](tutorial_files/tutorial_2_0.png)

Each digit is stored as a 784-dimensional vector of pixel values, where "-1.0" corresponds to black and "+1.0" corresponds to white.

### 1.1 Predicting the class of MNIST digits


These digits are fed to a fully connected neural network with layer sizes 784-300-100-10, with ReLU activations at each hidden layer. The architecture is depicted in the figure below.

*(Figure: the 784-300-100-10 network architecture.)*

The network we consider achieves a test error of 1.6%, which is typical performance for a neural network without particular structure or regularization. The function utils.loadparams() retrieves its parameters for us.

```python
W,B = utils.loadparams()
L = len(W)
```
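
As a quick optional check (our addition, not part of the original tutorial), we can print the shapes of the loaded parameters and verify that they match the 784-300-100-10 architecture:

```python
# Optional sanity check: the parameter shapes should match the
# 784-300-100-10 architecture described above.
for l in range(L):
    print('layer %d: weights %s  biases %s'%(l+1, W[l].shape, B[l].shape))
```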

From these parameters, the forward pass can be computed as a sequence of matrix multiplications and nonlinearities.

```python
import numpy

A = [X]+[None]*L
for l in range(L):
    A[l+1] = numpy.maximum(0,A[l].dot(W[l])+B[l])
```

Note that this code adds an extra top-layer ReLU activation compared to the original neural network. This, however, does not affect the computation when looking at positive output scores. The top-layer activations are scores measuring the evidence the network has found for each class. In the following, we show the first three digits and the scores produced for each class at the output:

```python
for i in range(3):
    utils.digit(X[i].reshape(28,28),0.75,0.75)
    p = A[L][i]
    print(" ".join(['[%1d] %.1f'%(d,p[d]) for d in range(10)]))
```

![png](tutorial_files/tutorial_8_0.png)

    [0] 0.0 [1] 3.6 [2] 49.1 [3] 8.9 [4] 0.0 [5] 0.0 [6] 0.0 [7] 1.4 [8] 1.6 [9] 0.0

![png](tutorial_files/tutorial_8_2.png)

    [0] 0.0 [1] 27.0 [2] 0.0 [3] 0.0 [4] 5.3 [5] 0.0 [6] 0.0 [7] 13.0 [8] 8.1 [9] 2.3

![png](tutorial_files/tutorial_8_4.png)

    [0] 49.1 [1] 0.0 [2] 10.6 [3] 0.0 [4] 0.0 [5] 0.2 [6] 0.0 [7] 3.0 [8] 0.0 [9] 9.2

As expected, the highest score systematically corresponds to the correct digit.
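
This can also be checked programmatically by comparing the argmax of the output scores to the labels T (a small verification we add here, not part of the original code):

```python
# Compare predicted classes (argmax of the output scores) to the true labels.
pred = numpy.argmax(A[L], axis=1)
print(pred == T)   # expected: all True for these 12 test digits
```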


### 1.2 Explaining the predictions with LRP


We now implement the layer-wise relevance propagation (LRP) procedure from the top to the bottom of the network. As a first step, we create a list to store relevance scores at each layer. The top-layer relevance scores are set to the top-layer activations, which we multiply by a label indicator in order to retain only the evidence for the actual class.

```python
R = [None]*L + [A[L]*(T[:,None]==numpy.arange(10))]
```

The LRP-0, LRP-ϵ, and LRP-γ rules described in the LRP overview paper (Section 10.2.1) for propagating relevance on the lower layers are special cases of the more general propagation rule

$$R_j = \sum_k \frac{a_j \, \rho(w_{jk})}{\epsilon + \sum_{0,j} a_j \, \rho(w_{jk})} \, R_k$$

(cf. Section 10.2.2), where ρ is a function that transforms the weights, and ϵ is a small positive increment. We define below two helper functions that perform the weight transformation and the incrementation. In practice, we would like to apply different rules at different layers (cf. Section 10.3). Therefore, we also give the layer index "l" as an argument to these functions.

```python
def rho(w,l):  return w + [None,0.1,0.0,0.0][l] * numpy.maximum(0,w)
def incr(z,l): return z + [None,0.0,0.1,0.0][l] * (z**2).mean()**.5 + 1e-9
```

In particular, these functions and the layer index they receive as a parameter let us reduce the general rule to LRP-0 for the top layer, to LRP-ϵ with ϵ = 0.1·std for the layer just below, and to LRP-γ with γ = 0.1 for the layer before that. We now come to the practical implementation of this general rule. It can be decomposed as a sequence of four computations:

$$z_k = \epsilon + \sum_{0,j} a_j \, \rho(w_{jk}) \qquad \text{(step 1)}$$

$$s_k = R_k / z_k \qquad \text{(step 2)}$$

$$c_j = \sum_k \rho(w_{jk}) \, s_k \qquad \text{(step 3)}$$

$$R_j = a_j \, c_j \qquad \text{(step 4)}$$

The layer-wise relevance propagation procedure then consists of iterating over the layers in reverse order, from the top layer down to the first layer, and applying this sequence of computations at each layer.

```python
for l in range(1,L)[::-1]:
    w = rho(W[l],l)
    b = rho(B[l],l)

    z = incr(A[l].dot(w)+b,l)    # step 1
    s = R[l+1] / z               # step 2
    c = s.dot(w.T)               # step 3
    R[l] = A[l]*c                # step 4
```

Note that the loop above stops one layer before reaching the pixels. To propagate relevance scores down to the pixels, we need to apply an alternate propagation rule that properly handles pixel values received as input (cf. Section 10.3.2). In particular, we apply for this layer the zB-rule given by:

$$R_i = \sum_j \frac{x_i w_{ij} - l_i w_{ij}^{+} - h_i w_{ij}^{-}}{\sum_i x_i w_{ij} - l_i w_{ij}^{+} - h_i w_{ij}^{-}} \, R_j$$

In this rule, $l_i$ and $h_i$ are the lower and upper bounds of pixel values, here "-1" and "+1", and $(\cdot)^+$ and $(\cdot)^-$ are shortcut notations for $\max(0,\cdot)$ and $\min(0,\cdot)$. The zB-rule can again be implemented with a four-step procedure similar to the one used in the layers above. Here, we need to create two copies of the weights, and also arrays of pixel values set to $l_i$ and $h_i$ respectively:

```python
w  = W[0]
wp = numpy.maximum(0,w)
wm = numpy.minimum(0,w)
lb = A[0]*0-1
hb = A[0]*0+1

z = A[0].dot(w)-lb.dot(wp)-hb.dot(wm)+1e-9    # step 1
s = R[1]/z                                    # step 2
c,cp,cm = s.dot(w.T),s.dot(wp.T),s.dot(wm.T)  # step 3
R[0] = A[0]*c-lb*cp-hb*cm                     # step 4
```

We have now reached the bottom layer. The obtained pixel-wise relevance scores can be rendered as a heatmap.

```python
utils.digit(X.reshape(1,12,28,28).transpose(0,2,1,3).reshape(28,12*28),9,0.75)
utils.heatmap(R[0].reshape(1,12,28,28).transpose(0,2,1,3).reshape(28,12*28),9,0.75)
```

![png](tutorial_files/tutorial_18_0.png)

![png](tutorial_files/tutorial_18_1.png)

Relevant pixels are highlighted in red. Pixels that contribute negatively to the prediction, if any, are shown in blue. On most digits, the digit itself is highlighted, as well as some parts of the background. For example, we observe two red horizontal bars next to the digit "3", highlighting the fact that if those pixels were different, the digit "3" would likely turn into an "8". Similarly, the vertical bar above the digit "4" supports the class "4" rather than the class "9".
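
As a side note, LRP is designed to approximately conserve relevance from layer to layer, so the pixel-wise scores should sum to roughly the output score of the true class. A quick check of this property (our addition; the match is not exact because the ϵ-term and the biases absorb some relevance on the way down):

```python
# Approximate conservation: total input relevance vs. true-class score.
print(R[0].sum(axis=1))                           # relevance summed over pixels
print(A[L][numpy.arange(len(T)), T.astype(int)])  # score of the true class
```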

## 2 PyTorch Implementation for the VGG-16 Network


In the example above, LRP rules could be easily expressed in terms of matrix-vector operations. In practice, state-of-the-art neural networks such as VGG-16 make use of more complex layers such as convolutions and pooling. In this case, LRP rules are more conveniently implemented by casting the operations of the four-step procedure above as forward and gradient evaluations on these layers. These operations are readily available in neural network frameworks such as PyTorch and TensorFlow, and can therefore be reused for the purpose of implementing LRP. Here, we take the VGG-16 pretrained network for image classification. For this network, we consider the task of explaining the evidence for the class "castle" it has found in the following image:

![castle](castle.jpg)

The image is first loaded in the notebook.

```python
import cv2
img = numpy.array(cv2.imread('castle.jpg'))[...,::-1]/255.0
```

It is then converted to a torch tensor of appropriate dimensions and normalized to be given as input to the VGG-16 network.

```python
import torch

mean = torch.Tensor([0.485, 0.456, 0.406]).reshape(1,-1,1,1)
std  = torch.Tensor([0.229, 0.224, 0.225]).reshape(1,-1,1,1)

X = (torch.FloatTensor(img[numpy.newaxis].transpose([0,3,1,2])*1) - mean) / std
```

The VGG-16 network is then loaded and its top-level dense layers are converted into equivalent 1x1 convolutions.

```python
import torchvision

model = torchvision.models.vgg16(pretrained=True); model.eval()

layers = list(model._modules['features']) + utils.toconv(list(model._modules['classifier']))
L = len(layers)
```
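
For reference, below is a minimal sketch of what a helper like utils.toconv might look like (the actual implementation shipped with the tutorial may differ): each nn.Linear layer of the classifier is replaced by an equivalent convolution, 7x7 for the first dense layer acting on the 512x7x7 feature maps, and 1x1 thereafter.

```python
def toconv_sketch(layers):
    # Hypothetical re-implementation of utils.toconv: replace each dense
    # layer by an equivalent convolution so that the whole network becomes
    # fully convolutional.
    newlayers = []
    for layer in layers:
        if isinstance(layer, torch.nn.Linear):
            n_out, n_in = layer.weight.shape
            if n_in == 512*7*7:  # first dense layer: acts on 512x7x7 maps
                conv = torch.nn.Conv2d(512, n_out, 7)
                conv.weight = torch.nn.Parameter(layer.weight.data.reshape(n_out, 512, 7, 7))
            else:                # remaining dense layers: 1x1 convolutions
                conv = torch.nn.Conv2d(n_in, n_out, 1)
                conv.weight = torch.nn.Parameter(layer.weight.data.reshape(n_out, n_in, 1, 1))
            conv.bias = torch.nn.Parameter(layer.bias.data)
            newlayers += [conv]
        else:
            newlayers += [layer]
    return newlayers
```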

### 2.1 Predicting the class of an image


The input can then be propagated in the network and the activations at each layer are collected:

```python
A = [X]+[None]*L
for l in range(L):
    A[l+1] = layers[l].forward(A[l])
```
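
Because the dense layers have been converted to convolutions, the output retains two trivial spatial dimensions. A quick shape check (our addition, assuming the standard 224x224 input size):

```python
# The output of the fully convolutional network has shape (1, 1000, 1, 1):
# one score per ImageNet class, with trivial spatial dimensions.
print(A[-1].shape)   # expected: torch.Size([1, 1000, 1, 1])
```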

Activations in the top layer are the scores the neural network predicts for each class. We show below the 10 classes with highest score:

```python
scores = numpy.array(A[-1].data.view(-1))
ind = numpy.argsort(-scores)

for i in ind[:10]:
    print('%20s (%3d): %6.3f'%(utils.imgclasses[i][:20],i,scores[i]))
```

                  castle (483): 11.029
    church, church build (497):  9.522
               monastery (663):  9.401
     bell cote, bell cot (442):  9.047
    cinema, movie theate (498):  8.016
            analog clock (409):  7.108
             street sign (919):  7.102
    traffic light, traff (920):  7.058
    thatch, thatched roo (853):  6.978
                     alp (970):  6.812

We observe that the neuron "castle" (index 483) has the highest score. This is expected due to the presence of a castle in the image. Note that other building-related classes are also assigned a high score, as well as classes corresponding to other objects present in the image (e.g. street sign and traffic light).


### 2.2 Explaining the prediction with LRP


The following code iterates from the top layer down to the first layer and applies propagation rules at each layer. Top-layer activations are first multiplied by a mask to retain only the evidence for the class "castle".

```python
T = torch.FloatTensor((1.0*(numpy.arange(1000)==483).reshape([1,1000,1,1])))
R = [None]*L + [(A[-1]*T).data]
```

This evidence can then be propagated backward in the network by applying propagation rules at each layer.


**Convolution layers:** Observing that convolutions are special types of linear layers, we can use the same propagation rules as in the MNIST example, and a similar four-step procedure for applying them. Steps 2 and 4 are simple element-wise computations. Step 1 can be implemented as a forward computation in the layer, where we have first transformed the layer parameters and where we apply the increment function afterwards. As shown in the LRP overview paper, Step 3 can instead be computed as a gradient in the space of input activations:

$$c_j = \Big[ \nabla \Big( \sum_k z_k(\boldsymbol{a}) \cdot s_k \Big) \Big]_j$$

where $s_k$ is treated as constant.
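
To see why this works, consider a plain linear layer $z = a \cdot w$: the gradient of (z*s).sum() with respect to a is exactly s.dot(w.T), i.e. step 3 of the MNIST implementation. A toy verification (our addition, not part of the original tutorial):

```python
# Toy check that step 3 equals a gradient computation.
torch.manual_seed(0)
w = torch.randn(5, 3)                        # weights of a toy linear layer
a = torch.randn(2, 5, requires_grad=True)    # input activations
s = torch.randn(2, 3)                        # relevance / z, held constant

z = a.matmul(w)                              # forward pass
(z*s).sum().backward()                       # gradient computation (step 3)

print(torch.allclose(a.grad, s.matmul(w.t())))   # True: matches s.dot(w.T)
```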


**Pooling layers:** It is suggested in Section 10.3.2 of the paper to treat max-pooling layers as average pooling layers in the backward pass. Observing that average pooling is also a special linear layer, the same propagation rules as for the convolution layers become applicable.


In the following code, we iterate the propagation procedure from the top layer towards the lower layers. Whenever we meet a max-pooling layer, we convert it into an average pooling layer. The functions rho and incr are set differently at each layer, following the strategy of Section 10.3.

```python
for l in range(1,L)[::-1]:

    A[l] = (A[l].data).requires_grad_(True)

    if isinstance(layers[l],torch.nn.MaxPool2d): layers[l] = torch.nn.AvgPool2d(2)

    if isinstance(layers[l],torch.nn.Conv2d) or isinstance(layers[l],torch.nn.AvgPool2d):

        if l <= 16:       rho = lambda p: p + 0.25*p.clamp(min=0); incr = lambda z: z+1e-9
        if 17 <= l <= 30: rho = lambda p: p;                       incr = lambda z: z+1e-9+0.25*((z**2).mean()**.5).data
        if l >= 31:       rho = lambda p: p;                       incr = lambda z: z+1e-9

        z = incr(utils.newlayer(layers[l],rho).forward(A[l]))  # step 1
        s = (R[l+1]/z).data                                    # step 2
        (z*s).sum().backward(); c = A[l].grad                  # step 3
        R[l] = (A[l]*c).data                                   # step 4

    else:

        R[l] = R[l+1]
```

As each layer is composed of a collection of two-dimensional feature maps, relevance scores at each layer can be visualized as a two-dimensional map. Here, relevance scores are pooled over all feature maps at a given layer. The two-dimensional maps are shown for a selection of VGG-16 layers.

```python
for i,l in enumerate([31,21,11,1]):
    utils.heatmap(numpy.array(R[l][0]).sum(axis=0),0.5*i+1.5,0.5*i+1.5)
```

![png](tutorial_files/tutorial_35_0.png)

![png](tutorial_files/tutorial_35_1.png)

![png](tutorial_files/tutorial_35_2.png)

![png](tutorial_files/tutorial_35_3.png)

We observe that the explanation becomes increasingly resolved spatially. Note that, as in the MNIST example, we have stopped the propagation procedure one layer before the pixels, because the rule we have used is not applicable to pixel layers. As in the MNIST case, we need to apply the pixel-specific zB-rule for this last layer. This rule can again be implemented in terms of forward passes and gradient computations.

```python
A[0] = (A[0].data).requires_grad_(True)

lb = (A[0].data*0+(0-mean)/std).requires_grad_(True)
hb = (A[0].data*0+(1-mean)/std).requires_grad_(True)

z = layers[0].forward(A[0]) + 1e-9                                   # step 1 (a)
z -= utils.newlayer(layers[0],lambda p: p.clamp(min=0)).forward(lb)  # step 1 (b)
z -= utils.newlayer(layers[0],lambda p: p.clamp(max=0)).forward(hb)  # step 1 (c)
s = (R[1]/z).data                                                    # step 2
(z*s).sum().backward(); c,cp,cm = A[0].grad,lb.grad,hb.grad          # step 3
R[0] = (A[0]*c+lb*cp+hb*cm).data                                     # step 4
```

The relevance scores obtained in the pixel layer can now be summed over the RGB channels to indicate actual pixel-wise contributions.

```python
utils.heatmap(numpy.array(R[0][0]).sum(axis=0),3.5,3.5)
```

![png](tutorial_files/tutorial_39_0.png)

We observe that the heatmap highlights the outline of the castle as evidence for the corresponding class. Some elements such as the traffic sign or the roof on the left are seen as having a negative effect on the neuron "castle" and are consequently highlighted in blue.