-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodels_dmon.py
More file actions
308 lines (210 loc) · 10.1 KB
/
models_dmon.py
File metadata and controls
308 lines (210 loc) · 10.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
import math
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
import torch.nn.functional as F
from utils import *
class GCNLayer_Dmon(nn.Module):
    """Graph convolution with a separate skip (self) connection.

    Computes ``(A @ X) @ weight + X @ skip_weight + bias``. A node's own
    features enter through ``skip_weight``, so the adjacency passed to
    ``forward`` is a normalized adjacency WITHOUT self-loops.
    """

    def __init__(self, n_in, n_out, bias=True):
        """
        args:
            n_in: input features
            n_out: number of output dimensions
            bias: if True, learn an additive bias of size n_out
        """
        super(GCNLayer_Dmon, self).__init__()
        self.n_in = n_in
        self.n_out = n_out
        self.weight = Parameter(torch.FloatTensor(n_in, n_out))
        self.skip_weight = Parameter(torch.FloatTensor(n_in, n_out))
        if bias:
            self.bias = Parameter(torch.FloatTensor(n_out))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every parameter from U(-1/sqrt(n_out), 1/sqrt(n_out))."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        # Order (skip_weight, weight, bias) is kept so seeded runs reproduce.
        self.skip_weight.data.uniform_(-stdv, stdv)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, A, X=None):
        """
        args:
            A: normalized adjacency matrix (not including self-loops)
               shape: [batch_size, num_nodes, num_nodes]
            X: node features; shape: [batch_size, num_nodes, n_in], or
               [1, num_nodes, n_in] / [num_nodes, n_in] to share one feature
               matrix across the batch. If X is None, an identity matrix is
               used (which requires n_in == num_nodes).
        returns:
            tensor of shape [batch_size, num_nodes, n_out]
        """
        num_nodes = A.size(1)
        if X is None:
            # Build the identity on A's device/dtype so the layer works on GPU.
            X = torch.eye(num_nodes, device=A.device, dtype=A.dtype)
        # Broadcast features to the batch size; -1 keeps the node/feature axes
        # and also lets a 2-D feature matrix gain a leading batch axis.
        X = X.expand(A.size(0), -1, -1)
        agg = torch.matmul(torch.matmul(A, X), self.weight)  # neighbour aggregation
        skip = torch.matmul(X, self.skip_weight)  # node's own features
        out = skip + agg
        if self.bias is not None:
            out = out + self.bias
        return out
class GCNLayer_Kipf(nn.Module):
    """Plain graph convolution: ``(A @ X) @ weight + bias``.

    There is no separate skip term, so any self-contribution of a node must
    come from self-loops already present in the normalized adjacency.
    """

    def __init__(self, n_in, n_out, bias=True):
        """
        args:
            n_in: input features
            n_out: number of output dimensions
            bias: if True, learn an additive bias of size n_out
        """
        super(GCNLayer_Kipf, self).__init__()
        self.n_in = n_in
        self.n_out = n_out
        self.weight = Parameter(torch.FloatTensor(n_in, n_out))
        if bias:
            self.bias = Parameter(torch.FloatTensor(n_out))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every parameter from U(-1/sqrt(n_out), 1/sqrt(n_out))."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, A, X=None):
        """
        args:
            A: normalized adjacency matrix
               shape: [batch_size, num_nodes, num_nodes]
            X: node features; shape: [batch_size, num_nodes, n_in], or
               [1, num_nodes, n_in] / [num_nodes, n_in] to share one feature
               matrix across the batch. If X is None, an identity matrix is
               used (which requires n_in == num_nodes).
        returns:
            tensor of shape [batch_size, num_nodes, n_out]
        """
        num_nodes = A.size(1)
        if X is None:
            # Build the identity on A's device/dtype so the layer works on GPU.
            X = torch.eye(num_nodes, device=A.device, dtype=A.dtype)
        # Broadcast features to the batch size; -1 keeps the node/feature axes
        # and also lets a 2-D feature matrix gain a leading batch axis.
        X = X.expand(A.size(0), -1, -1)
        # Aggregate neighbours first, then project (same order as before).
        outputs = torch.matmul(torch.matmul(A, X), self.weight)
        if self.bias is not None:
            outputs = outputs + self.bias
        return outputs
class DMoN(nn.Module):
    """Implementation of Deep Modularity Network (DMoN) Layer.

    DMoN optimizes a modularity clustering objective in a fully
    unsupervised mode: a linear layer + softmax produces soft cluster
    assignments, and forward() returns a spectral (modularity-based) loss
    and a collapse regularizer alongside them.

    args:
        n_in: number of input features per node
        n_clusters: Number of clusters in the model
        collapse_regularization: Collapse regularization weight. It is only
            stored on the module; forward() returns the UNWEIGHTED collapse
            loss, so callers must apply the weight themselves.
        dropout_rate: Dropout rate. The dropout is applied to the
            intermediate representations (cluster logits) before softmax
        do_unpooling: Parameter controlling whether to perform unpooling of
            the features with respect to their soft clusters. Stored only;
            forward() does not currently use it.
    """

    def __init__(self, n_in, n_clusters, collapse_regularization=0.1,
                 dropout_rate=0, do_unpooling=False):
        super(DMoN, self).__init__()
        self.n_in = n_in
        self.n_clusters = n_clusters
        self.collapse_regularization = collapse_regularization
        self.dropout_rate = dropout_rate
        self.do_unpooling = do_unpooling
        self.fc = nn.Linear(n_in, n_clusters)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, A, X=None):
        """Perform DMoN clustering from node features and the input graph.

        args:
            A: adjacency matrix, shape: [batch_size, num_nodes, num_nodes]
            X: node features, shape: [batch_size, num_nodes, n_in] (or a
               single [num_nodes, n_in] matrix shared across the batch).
               If None, an identity matrix is used (requires
               n_in == num_nodes).
        returns:
            assignments: soft assignments, [batch_size, num_nodes, n_clusters]
            spectral_loss: scalar, per-graph spectral loss averaged over
                the batch
            collapse_loss: scalar, per-graph (unweighted) collapse loss
                averaged over the batch
        """
        batch_size = A.size(0)
        num_nodes = A.size(1)
        if X is None:
            # Identity features on A's device, in the dtype fc expects.
            X = torch.eye(num_nodes, device=A.device, dtype=self.fc.weight.dtype)
        X = X.expand(batch_size, -1, -1)  # expand to batch size

        # shape: [batch_size, num_nodes, n_clusters]
        assignments = torch.softmax(self.dropout(self.fc(X)), dim=-1)
        cluster_sizes = assignments.sum(1)  # shape: [batch_size, n_clusters]

        degrees = A.sum(-1).unsqueeze(-1)  # shape: [batch_size, num_nodes, 1]
        # Per-graph sum of degrees, shaped [batch_size, 1, 1] so it divides
        # each graph's [n_clusters, n_clusters] matrix rather than being
        # broadcast along the last (cluster) axis.
        edge_weights = degrees.sum(dim=(1, 2)).view(-1, 1, 1)

        assignments_t = assignments.transpose(-1, -2)
        # S^T A S, shape: [batch_size, n_clusters, n_clusters]
        graph_pooled = torch.matmul(torch.matmul(assignments_t, A), assignments)

        # Rank-1 normalizer matrix S^T d d^T S
        normalizer_left = torch.matmul(assignments_t, degrees)
        # shape: [batch_size, n_clusters, 1]
        normalizer_right = torch.matmul(degrees.transpose(-1, -2), assignments)
        # shape: [batch_size, 1, n_clusters]
        normalizer = torch.matmul(normalizer_left, normalizer_right) / 2 / edge_weights
        # shape: [batch_size, n_clusters, n_clusters]

        # Trace per graph, normalized per graph, then averaged over the batch.
        traces = torch.diagonal(graph_pooled - normalizer, dim1=-2, dim2=-1).sum(-1)
        spectral_loss = -(traces / 2 / edge_weights.view(-1)).mean()

        # Per-graph norm of the cluster sizes (not the norm over the whole
        # batch); math.sqrt keeps this device-agnostic.
        collapse_loss = (torch.norm(cluster_sizes, dim=-1) / num_nodes
                         * math.sqrt(self.n_clusters) - 1).mean()
        return assignments, spectral_loss, collapse_loss
class GCN_DMoN(nn.Module):
    """Two-layer GCN encoder followed by a DMoN clustering head.

    args:
        n_in: input feature dimension
        n_hid: output dimension of the first GCN layer
        n_out: output dimension of the second GCN layer (DMoN input size)
        n_clusters: number of clusters
        gcn_type: "dmon" (case-insensitive) selects GCNLayer_Dmon; any
            other value selects GCNLayer_Kipf
        activation: "relu" (case-insensitive) selects F.relu; any other
            value selects F.selu
        collapse_regularization: forwarded to DMoN
        dropout_rate: forwarded to DMoN
    """

    def __init__(self, n_in, n_hid, n_out, n_clusters, gcn_type="dmon",
                 activation="selu", collapse_regularization=0.1,
                 dropout_rate=0):
        super().__init__()
        layer_cls = GCNLayer_Dmon if gcn_type.lower() == "dmon" else GCNLayer_Kipf
        # Creation order (gcn_h, gcn_o, dmon) fixes which RNG draws each
        # module's initialization consumes under a fixed seed.
        self.gcn_h = layer_cls(n_in, n_hid)
        self.gcn_o = layer_cls(n_hid, n_out)
        self.dmon = DMoN(n_in=n_out, n_clusters=n_clusters,
                         collapse_regularization=collapse_regularization,
                         dropout_rate=dropout_rate)
        self.activation = F.relu if activation.lower() == "relu" else F.selu

    def forward(self, A, X):
        """
        args:
            A: adjacency matrix; shape: [batch_size, num_nodes, num_nodes]
            X: node features; shape: [batch_size, num_nodes, n_in]
        returns:
            (assignments, spectral_loss, collapse_loss) as returned by DMoN
        """
        # GCNLayer_Dmon carries its own skip term, so it is fed a graph
        # normalized without self-loops; the Kipf layer gets them added.
        with_self_loops = not isinstance(self.gcn_h, GCNLayer_Dmon)
        A_normalized = normalize_graph(A, add_self_loops=with_self_loops)

        hidden = X
        for layer in (self.gcn_h, self.gcn_o):
            # shape after gcn_h: [batch_size, num_nodes, n_hid]
            hidden = self.activation(layer(A_normalized, hidden))

        # DMoN receives the raw adjacency, not the normalized one.
        return self.dmon(A, hidden)
class DMoN_GALA(nn.Module):
    """GCN encoder + DMoN clustering head + GALA-style feature decoder.

    The two encoder layers embed the nodes, DMoN clusters the embedding,
    and two GCNLayer_Kipf decoder layers run on ``laplacian_sharpen(A)`` to
    reconstruct the input features.

    args:
        n_in: input feature dimension (also the decoder's output width)
        n_hid: hidden width of encoder and decoder
        n_out: embedding width (DMoN input size)
        n_clusters: number of clusters
        gcn_type: "dmon" (case-insensitive) selects GCNLayer_Dmon for the
            encoder; any other value selects GCNLayer_Kipf
        activation: "relu" (case-insensitive) selects F.relu; any other
            value selects F.selu
        collapse_regularization: forwarded to DMoN
        dropout_rate: forwarded to DMoN
    """

    def __init__(self, n_in, n_hid, n_out, n_clusters, gcn_type="dmon",
                 activation="selu", collapse_regularization=0.1, dropout_rate=0):
        super().__init__()
        encoder_cls = GCNLayer_Dmon if gcn_type.lower() == "dmon" else GCNLayer_Kipf
        # Creation order (encoder, decoder, dmon) fixes which RNG draws each
        # module's initialization consumes under a fixed seed.
        self.gcn_h = encoder_cls(n_in, n_hid)
        self.gcn_o = encoder_cls(n_hid, n_out)
        # GALA decoder part: maps the embedding back to n_in features.
        self.gcn_dh = GCNLayer_Kipf(n_out, n_hid)
        self.gcn_do = GCNLayer_Kipf(n_hid, n_in)
        self.dmon = DMoN(n_in=n_out, n_clusters=n_clusters,
                         collapse_regularization=collapse_regularization,
                         dropout_rate=dropout_rate)
        self.activation = F.relu if activation.lower() == "relu" else F.selu

    def forward(self, A, X):
        """
        args:
            A: adjacency matrix; shape: [batch_size, num_nodes, num_nodes]
            X: node features; shape: [batch_size, num_nodes, n_in]
        returns:
            assignments, spectral_loss, collapse_loss (from DMoN) and
            rec_loss = ||X - X_rec|| / (batch_size * num_nodes)
        """
        # GCNLayer_Dmon carries its own skip term, so it is fed a graph
        # normalized without self-loops; the Kipf layer gets them added.
        with_self_loops = not isinstance(self.gcn_h, GCNLayer_Dmon)
        A_normalized = normalize_graph(A, add_self_loops=with_self_loops)
        A_sharp = laplacian_sharpen(A)

        embedding = X
        for encoder_layer in (self.gcn_h, self.gcn_o):
            embedding = self.activation(encoder_layer(A_normalized, embedding))

        assignments, spectral_loss, collapse_loss = self.dmon(A, embedding)

        decoded = self.activation(self.gcn_dh(A_sharp, embedding))
        X_rec = torch.sigmoid(self.gcn_do(A_sharp, decoded))
        batch_nodes = X.size(0) * X.size(1)
        rec_loss = torch.norm(X - X_rec) / batch_nodes
        return assignments, spectral_loss, collapse_loss, rec_loss