# Demo code of GNN-LRP

This notebook provides a demo of the interpretation method [GNN-LRP](https://arxiv.org/abs/2006.03589), introduced in

<blockquote style='background-color:#EEEEEE; padding: 3px; border: 1px dashed #999999'>
T Schnake, O Eberle, J Lederer, S Nakajima, K T. Schütt, KR Müller, G Montavon<br><a href="https://arxiv.org/abs/2006.03589">Higher-Order Explanations of Graph Neural Networks via Relevant Walks</a><br><font color="#008800">arXiv:2006.03589, 2020</font>
</blockquote>

which explains the network prediction strategy of a GNN by extracting relevant walks on the input graph.

We will train a [GCN](https://arxiv.org/abs/1609.02907) on scale-free Barabási-Albert graphs to detect their growth parameter. We will give an implementation of the GNN-LRP method and apply it to the trained network. Finally, we will show qualitative evidence for the network predictions by visualizing heatmaps of relevant walks.

```python
import sys
import numpy
import random
import torch
import igraph
import utils
import matplotlib.pyplot as plt
```

## 1. The Data
We train the model on scale-free graphs, more precisely on Barabási-Albert (BA) graphs with different growth parameters. Below we define a synthetic data set and visualize some of its graphs. We fix the number of nodes per graph to 10 and let the growth parameter of the BA graphs be either 1 or 2.


```python
# Function to generate a BA graph
def scalefreegraph(seed=0,embed=False,growth=None):
    random = numpy.random.RandomState(seed)  # per-sample RNG (shadows the imported random module)
    N = 10
    A = numpy.zeros([N,N])
    A[1,0] = 1
    A[0,1] = 1
    growth = growth if growth is not None else random.randint(1,3)
    N0 = 2
    for i in range(N0,N):
        if   growth==1: tt = 1   # Barabasi-Albert 1
        elif growth==2: tt = 2   # Barabasi-Albert 2
        else:
            tt = 1 + 1*((growth-1)>random.uniform(0,1))
        p = A.sum(axis=0)/A.sum()
        for j in random.choice(N,tt,p=p,replace=False):
            A[i,j] = 1
            A[j,i] = 1
    r = random.permutation(len(A))
    A = A[r][:,r]*1.0

    # Add Self-Connections
    A = A + numpy.identity(len(A))

    # Build Data Structures
    D = A.sum(axis=1)
    L = torch.FloatTensor(A/(numpy.outer(D,D)**.5+1e-9)) # renormalized adjacency D^-1/2 A D^-1/2 (the GCN operator, stored under the key 'laplacian')

    return {
        'adjacency':torch.FloatTensor(A),
        'laplacian':L,
        'target':growth,
        'layout':utils.layout(A,seed) if embed else None,
        'walks':utils.walks(A)
    }

```
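
As a quick sanity check, we can generate a single graph and inspect it (the seed and growth values here are chosen arbitrarily):

```python
# Hypothetical quick check: one BA graph with growth parameter 2
g = scalefreegraph(seed=0, growth=2)
print(g['adjacency'].shape)  # torch.Size([10, 10])
print(g['target'])           # 2
```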

### Looking at the data


```python
sample_ids = [1,3,4,5]
####################

# Function to visualise a graph
def vis_graph(g, ax):
    # Arrange graph layout
    r = g['layout']
    r = r - r.min(axis=0)
    r = r / r.max(axis=0) * 2 - 1

    # Plot the graph
    N = len(g['adjacency'])
    for i in numpy.arange(N):
        for j in numpy.arange(N):
            if g['adjacency'][i,j] > 0 and i != j:
                ax.plot([r[i,0], r[j,0]], [r[i,1], r[j,1]], color='gray', lw=0.5, ls='dotted')
    ax.plot(r[:,0],r[:,1],'o',color='black',ms=3)

# Plotting
plt.figure(figsize=(3*len(sample_ids), 3))
for ids, seed in enumerate(sample_ids):
    ax = plt.subplot(1, len(sample_ids), ids+1)
    sfg = scalefreegraph(seed=seed, embed=True)
    vis_graph(sfg, ax=ax)
    plt.subplots_adjust(left=0,right=1,bottom=0,top=1)

    plt.axis('off')
    plt.xlim(-1.2,1.2)
    plt.ylim(-1.2,1.2)
    ax.set_title('growth={}'.format( sfg['target']))

plt.show()
plt.close()
```


![png](demo_gnn_lrp_files/demo_gnn_lrp_5_0.png)


## 2. The Model and its Explanation
We want to learn a model <img src="https://latex.codecogs.com/gif.latex?f(\Lambda)" /> that predicts the growth parameter of a scale-free graph encoded in <img src="https://latex.codecogs.com/gif.latex?\Lambda" />. We choose <img src="https://latex.codecogs.com/gif.latex?f(\Lambda)" /> to consist of an FFN that maps the initial node representations into the hidden space, followed by several *interaction blocks* that learn the graph structure, followed by an FFN that maps the final hidden representation onto the target space. The *interaction blocks* are [GCN](https://arxiv.org/abs/1609.02907) layers, whose update can be described by the two steps

<img src="https://latex.codecogs.com/gif.latex?\text{aggregate:} \quad Z_t = \Lambda H_{t-1}" /><br/><br/>
<img src="https://latex.codecogs.com/gif.latex?\text{combine:} \quad  H_t = \rho(Z_t W_t)" />

where <img src="https://latex.codecogs.com/gif.latex?\rho" /> is the ReLU activation, <img src="https://latex.codecogs.com/gif.latex?W_t" /> are the learnable parameters, and <img src="https://latex.codecogs.com/gif.latex?\Lambda" /> is the renormalized convolution operator proposed in [[Kipf et al.]](https://arxiv.org/abs/1609.02907). Note that we limit our model to a depth of 2 interaction blocks, so that the extracted walks consist of 3 nodes.
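
For concreteness, here is a minimal PyTorch sketch of one interaction block, with random tensors standing in for the operator, hidden state, and weights (a sketch only; the actual model is implemented further below):

```python
import torch

N, h = 10, 64
Lam = torch.randn(N, N)          # stand-in for the operator Lambda
H   = torch.randn(N, h)          # stand-in for node representations H_{t-1}
W   = torch.randn(h, h)          # stand-in for the weights W_t

Z = Lam.matmul(H)                # aggregate: Z_t = Lambda H_{t-1}
H = Z.matmul(W).clamp(min=0)     # combine:   H_t = rho(Z_t W_t)
```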

To explain the prediction of <img src="https://latex.codecogs.com/gif.latex?f(\Lambda)" /> with GNN-LRP, we first rewrite the combine step in the interaction blocks as

<img src="https://latex.codecogs.com/gif.latex? P_t \gets Z_t W_t^\wedge" /><br/><br/>
<img src="https://latex.codecogs.com/gif.latex?Q_t \gets P_t \odot [\kern1pt \rho(Z_t W_t) \oslash  P_t]{.data}" /><br/><br/>
<img src="https://latex.codecogs.com/gif.latex?H_t \gets Q_t \odot M_K + [Q_t]{.data} \odot (1-M_K)" /><br/><br/>

where <img src="https://latex.codecogs.com/gif.latex?W_t^\wedge = W_t + \gamma \max(0, W_t)" /> are the <img src="https://latex.codecogs.com/gif.latex?\gamma" />-modified weights of the LRP-<img src="https://latex.codecogs.com/gif.latex?\gamma" /> rule, and the masking array <img src="https://latex.codecogs.com/gif.latex?M_K" /> is one for neurons associated with node <img src="https://latex.codecogs.com/gif.latex?K" /> and zero elsewhere. The <img src="https://latex.codecogs.com/gif.latex?{.data}" /> operation detaches the considered tensor from the gradient computation. Second, with the masks chosen such that they select some walk <img src="https://latex.codecogs.com/gif.latex?\mathcal{W} = (I,J,K)" /> of interest, we compute the relevance score of <img src="https://latex.codecogs.com/gif.latex?\mathcal{W}" /> as

<img src="https://latex.codecogs.com/gif.latex?R_\mathcal{W}  = \big\langle \text{Autograd}( f, H_{0,I}) \,,\, H_{0,I}\big\rangle" /> <br/> <br/>
where the initial representation <img src="https://latex.codecogs.com/gif.latex?H_{0,I}" /> is in our case the one-hot encoding of node <img src="https://latex.codecogs.com/gif.latex?I" />.
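
Below is a minimal sketch of one rewritten combine step on random stand-in tensors (the aggregate step is omitted for brevity; the full implementation follows):

```python
import torch

torch.manual_seed(0)
H0 = torch.randn(10, 64, requires_grad=True)  # incoming node representations (stand-in)
W  = torch.randn(64, 64)                      # weights W_t (stand-in)
gamma = 0.1
Wp = W + gamma * W.clamp(min=0)               # gamma-modified weights W_t^

Z  = H0.matmul(W)                             # standard pre-activations Z_t W_t
Zp = H0.matmul(Wp)                            # P_t = Z_t W_t^
H  = (Zp * (Z / (Zp + 1e-6)).data).clamp(min=0)  # same value as rho(Z_t W_t)

Mk = torch.zeros(10, 1)
Mk[3] = 1.0                                   # mask M_K for node K = 3
H  = H * Mk + H.data * (1 - Mk)               # gradient flows only through node K
```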


```python
class GraphNet:
    def __init__(self,d,h,c):
        # Input FFN (U), two interaction blocks (W1, W2), output FFN (V)
        self.U  = torch.nn.Parameter(torch.FloatTensor(numpy.random.normal(0,d**-.5,[d,h])))
        self.W1 = torch.nn.Parameter(torch.FloatTensor(numpy.random.normal(0,h**-.5,[h,h])))
        self.W2 = torch.nn.Parameter(torch.FloatTensor(numpy.random.normal(0,h**-.5,[h,h])))
        self.V  = torch.nn.Parameter(torch.FloatTensor(numpy.random.normal(0,h**-.5,[h,c])))
        self.params = [self.U,self.W1,self.W2,self.V]

    def forward(self,A):
        H = torch.eye(len(A))                                          # one-hot initial node features
        H = H.matmul(self.U).clamp(min=0)                              # input FFN
        H = (A.transpose(1,0).matmul(H.matmul(self.W1))).clamp(min=0)  # interaction block 1
        H = (A.transpose(1,0).matmul(H.matmul(self.W2))).clamp(min=0)  # interaction block 2
        H = H.matmul(self.V).clamp(min=0)                              # output FFN
        return H.mean(dim=0)                                           # mean pooling over nodes

    def lrp(self,A,gamma,l,inds):
        # Masks selecting the intermediate walk nodes j and k
        if inds is not None:
            j,k = inds
            Mj = torch.FloatTensor(numpy.eye(len(A))[j][:,numpy.newaxis])
            Mk = torch.FloatTensor(numpy.eye(len(A))[k][:,numpy.newaxis])

        # Gamma-modified weights W^ = W + gamma * max(0, W)
        W1p = self.W1+gamma*self.W1.clamp(min=0)
        W2p = self.W2+gamma*self.W2.clamp(min=0)
        Vp  = self.V+gamma*self.V.clamp(min=0)

        # One-hot input encoding, with gradient tracking enabled below
        X = torch.eye(len(A))
        X.requires_grad_(True)

        H  = X.matmul(self.U).clamp(min=0)

        # Interaction block 1 with the LRP-gamma forward rewrite
        Z  = A.transpose(1,0).matmul(H.matmul(self.W1))
        Zp = A.transpose(1,0).matmul(H.matmul(W1p))
        H  = (Zp*(Z/(Zp+1e-6)).data).clamp(min=0)

        # Keep the gradient path open only at node j
        if inds is not None: H = H * Mj + (1-Mj) * (H.data)

        # Interaction block 2 with the same rewrite
        Z  = A.transpose(1,0).matmul(H.matmul(self.W2))
        Zp = A.transpose(1,0).matmul(H.matmul(W2p))
        H  = (Zp*(Z/(Zp+1e-6)).data).clamp(min=0)

        # Keep the gradient path open only at node k
        if inds is not None: H = H * Mk + (1-Mk) * (H.data)

        # Output FFN with the same rewrite
        Z  = H.matmul(self.V)
        Zp = H.matmul(Vp)
        H  = (Zp*(Z/(Zp+1e-6)).data).clamp(min=0)

        # Select the output neuron (class l) to explain
        Y = H.mean(dim=0)[l]

        Y.backward()

        # Relevance scores: (one-hot) input times gradient
        return X.data*X.grad
```
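
As a usage sketch (here on an untrained model, with arbitrarily chosen walk and class indices), the relevance of a single walk can be read off as follows:

```python
# Hypothetical usage: relevance of the walk (i,j,k) = (0,1,2) for class l=0
net = GraphNet(10, 64, 2)
g = scalefreegraph(seed=0, growth=2)
R = net.lrp(g['laplacian'], gamma=0.1, l=0, inds=(1, 2))[0].sum()
print(float(R))
```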

### Train the model
We choose a hidden dimension of 64 and train the model on 20000 samples. We use stochastic gradient descent on the mean squared error to optimize the parameters of the model.


```python
def train_scalefree():
    model = GraphNet(10,64,2)
    optimizer = torch.optim.SGD(model.params, lr=0.001, momentum=0.99)
    erravg = 0.5
    print('Train model:')
    print('   iter | err')
    print('   -----------')
    for it in range(0,20001):
        optimizer.zero_grad()
        g = scalefreegraph(seed=it, embed=False)
        y = model.forward(g['laplacian'])
        err = (y[0]-(g['target']==1)*1.0)**2 + (y[1]-(g['target']==2)*1.0)**2  # squared error against the one-hot growth target
        erravg = 0.999*erravg + 0.001*err.data.numpy()
        err.backward()
        optimizer.step()
        if it%1000==0:
            print('% 8d %.3f'%(it,erravg))
    return model
```


```python
# Train the model
model = train_scalefree()
```

    Train model:
       iter | err
       -----------
           0 0.500
        1000 0.466
        2000 0.218
        3000 0.098
        4000 0.043
        5000 0.022
        6000 0.012
        7000 0.009
        8000 0.009
        9000 0.007
       10000 0.007
       11000 0.006
       12000 0.006
       13000 0.008
       14000 0.007
       15000 0.006
       16000 0.005
       17000 0.005
       18000 0.005
       19000 0.004
       20000 0.004


### The accuracy of the model


```python
test_size = 200

num_false = 0
for it in range(20001, 20001 + test_size):
    g = scalefreegraph(seed=it, embed=False)
    y = model.forward(g['laplacian'])
    prediction = int(y.data.argmax()) + 1  # output classes correspond to growth parameters 1 and 2

    if prediction != g['target']: num_false += 1

print('For {} test samples, the model predicts the growth parameter with an accuracy of {} %'.format(test_size, 100 * (test_size - num_false)/test_size))
```

    For 200 test samples, the model predicts the growth parameter with an accuracy of 100.0 %


## 3. Explaining with GNN-LRP
To understand the model prediction, we would like to visualize the relevance scores <img src="https://latex.codecogs.com/gif.latex?R_\mathcal{W}" /> for all walks <img src="https://latex.codecogs.com/gif.latex?\mathcal{W}" /> on <img src="https://latex.codecogs.com/gif.latex?\Lambda" />. The plot below shows a set of graphs with different growth parameters and their corresponding relevant walks. For each graph we show separately the evidence for growth parameter 1 and for growth parameter 2.


```python
def explain(g, nn, t, gamma=None, ax=None):
    # Arrange graph layout
    r = g['layout']
    r = r - r.min(axis=0)
    r = r / r.max(axis=0) * 2 - 1

    # Plot the graph
    N = len(g['adjacency'])
    for i in numpy.arange(N):
        for j in numpy.arange(N):
            if g['adjacency'][i,j] > 0 and i != j:
                ax.plot([r[i,0],r[j,0]],[r[i,1],r[j,1]], color='gray', lw=0.5, ls='dotted')
    ax.plot(r[:,0],r[:,1],'o',color='black',ms=3)

    # Draw each walk with opacity scaled by its relevance score
    for (i,j,k) in g['walks']:
        R = nn.lrp(g['laplacian'], gamma, t, (j,k))[i].sum()
        tx,ty = utils.shrink([r[i,0],r[j,0],r[k,0]],[r[i,1],r[j,1],r[k,1]])

        if R > 0.0:
            alpha = numpy.clip(20*R.data.numpy(),0,1)
            ax.plot(tx,ty,alpha=alpha,color='red',lw=1.2)

        if R < 0.0:
            alpha = numpy.clip(-20*R.data.numpy(),0,1)
            ax.plot(tx,ty,alpha=alpha,color='blue',lw=1.2)
```


```python
gamma=0.1

for target in [0,1]:
    plt.figure(figsize=(3*len(sample_ids), 3))
    for ids, seed in enumerate(sample_ids):
        ax = plt.subplot(1, len(sample_ids), ids+1)
        sfg = scalefreegraph(seed=seed, embed=True)

        # Explain
        explain(sfg, model, target, gamma=gamma, ax=ax)
        plt.subplots_adjust(left=0,right=1,bottom=0,top=1)

        plt.axis('off')
        plt.xlim(-1.2,1.2)
        plt.ylim(-1.2,1.2)

    plt.suptitle(r'Evidence for growth={} with $\gamma={}$'.format(target+1, gamma), size=14)
    plt.show()
    plt.close()
```


![png](demo_gnn_lrp_files/demo_gnn_lrp_15_0.png)





![png](demo_gnn_lrp_files/demo_gnn_lrp_15_1.png)