#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Author: XiaShan
@Contact: [email protected]
@Time: 2024/4/17 20:40
"""
import torch
from torch import nn
from torch_geometric.utils import degree
class CentralityEncoding(nn.Module):
def __init__(self, max_in_degree: int, max_out_degree: int, node_dim: int):
"""
        :param max_in_degree: maximum in-degree of nodes
        :param max_out_degree: maximum out-degree of nodes
        :param node_dim: hidden dimension of node features
"""
super().__init__()
self.max_in_degree = max_in_degree
self.max_out_degree = max_out_degree
self.node_dim = node_dim
self.z_in = nn.Parameter(torch.randn((max_in_degree, node_dim)))
self.z_out = nn.Parameter(torch.randn((max_out_degree, node_dim)))
def forward(self, x: torch.Tensor, edge_index: torch.LongTensor) -> torch.Tensor:
"""
:param x: node feature matrix
:param edge_index: edge_index of graph (adjacency list)
:return: torch.Tensor, node embeddings after Centrality encoding
"""
num_nodes = x.shape[0]
in_degree = self.decrease_to_max_value(degree(index=edge_index[1], num_nodes=num_nodes).long(),
                                               self.max_in_degree - 1)  # in-degree of each node
out_degree = self.decrease_to_max_value(degree(index=edge_index[0], num_nodes=num_nodes).long(),
self.max_out_degree - 1)
        # index z_in / z_out by each node's (clipped) in-/out-degree to get its additive centrality embedding
        x += self.z_in[in_degree] + self.z_out[out_degree]
return x
def decrease_to_max_value(self, x, max_value):
"限制节点度的最大值"
x[x > max_value] = max_value
return x
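
# A node with, e.g., in-degree 3 and out-degree 1 therefore receives the additive encoding
# z_in[3] + z_out[1]; degrees beyond max_in_degree - 1 / max_out_degree - 1 all share the last row.
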
class SpatialEncoding(nn.Module):
def __init__(self, max_path_distance: int):
"""
:param max_path_distance: max pairwise distance between nodes
"""
super().__init__()
self.max_path_distance = max_path_distance
self.b = nn.Parameter(torch.randn(self.max_path_distance))
def forward(self, x: torch.Tensor, paths) -> torch.Tensor:
"""
:param x: node feature matrix
:param paths: pairwise node paths
:return: torch.Tensor, spatial Encoding matrix
"""
spatial_matrix = torch.zeros((x.shape[0], x.shape[0])).to(next(self.parameters()).device) # (num_nodes, num_nodes)
for src in paths:
for dst in paths[src]:
                spatial_matrix[src][dst] = self.b[min(len(paths[src][dst]), self.max_path_distance) - 1]  # indices run from 0 to max_path_distance - 1
return spatial_matrix
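
# The `paths` argument is assumed (from its use above and the repo's preprocessing) to map
# src -> dst -> the list of node indices on the shortest path between them, e.g.
# paths = {0: {2: [0, 1, 2]}} gives pair (0, 2) the bias b[min(3, max_path_distance) - 1] = b[2]
# (for max_path_distance >= 3).
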
class EdgeEncoding(nn.Module):
def __init__(self, edge_dim: int, max_path_distance: int):
"""
        :param edge_dim: number of dimensions of the edge feature matrix
        :param max_path_distance: max pairwise distance between nodes
        """
super().__init__()
self.edge_dim = edge_dim
self.max_path_distance = max_path_distance
self.edge_vector = nn.Parameter(torch.randn(self.max_path_distance, self.edge_dim))
def forward(self, x: torch.Tensor, edge_attr: torch.Tensor, edge_paths) -> torch.Tensor:
"""
:param x: node feature matrix
:param edge_attr: edge feature matrix
:param edge_paths: pairwise node paths in edge indexes
:return: torch.Tensor, Edge Encoding matrix
"""
cij = torch.zeros((x.shape[0], x.shape[0])).to(next(self.parameters()).device)
for src in edge_paths:
for dst in edge_paths[src]:
                path_ij = edge_paths[src][dst][:self.max_path_distance]  # shortest path, truncated to max_path_distance edges
                weight_inds = list(range(len(path_ij)))
cij[src][dst] = self.dot_product(self.edge_vector[weight_inds], edge_attr[path_ij]).mean()
        cij = torch.nan_to_num(cij)  # node pairs with no path would otherwise yield NaN
return cij
def dot_product(self, x1, x2) -> torch.Tensor:
        return (x1 * x2).sum(dim=1)  # row-wise dot product over the feature dimension, returns a 1-D tensor
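
# Similarly, `edge_paths` is assumed to map src -> dst -> the list of edge indices (rows of
# edge_attr) along the shortest path, e.g. edge_paths = {0: {2: [0, 3]}}: the path from node 0
# to node 2 uses edges 0 and 3, and c_ij is the mean dot product of the learnable per-hop
# vectors with those edge features.
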
class GraphormerAttentionHead(nn.Module):
def __init__(self, dim_in: int, dim_q: int, dim_k: int, edge_dim: int, max_path_distance: int):
"""
        :param dim_in: number of input dimensions of the node feature matrix
        :param dim_q: number of dimensions of the query projection
        :param dim_k: number of dimensions of the key/value projections
        :param edge_dim: number of dimensions of the edge feature matrix
        :param max_path_distance: max pairwise distance between nodes
"""
super().__init__()
self.edge_encoding = EdgeEncoding(edge_dim, max_path_distance)
self.q = nn.Linear(dim_in, dim_q)
self.k = nn.Linear(dim_in, dim_k)
self.v = nn.Linear(dim_in, dim_k)
def forward(self,
x: torch.Tensor,
edge_attr: torch.Tensor,
b: torch.Tensor,
edge_paths,
ptr=None) -> torch.Tensor:
"""
        :param x: node feature matrix
:param edge_attr: edge feature matrix
:param b: spatial Encoding matrix
:param edge_paths: pairwise node paths in edge indexes
:param ptr: batch pointer that shows graph indexes in batch of graphs
:return: torch.Tensor, node embeddings after attention operation
"""
batch_mask_neg_inf = torch.full(size=(x.shape[0], x.shape[0]), fill_value=-1e6).to(next(self.parameters()).device)
batch_mask_zeros = torch.zeros(size=(x.shape[0], x.shape[0])).to(next(self.parameters()).device)
        if ptr is None:
batch_mask_neg_inf = torch.ones(size=(x.shape[0], x.shape[0])).to(next(self.parameters()).device)
batch_mask_zeros += 1
else:
            # mask for a batch of graphs: only the block-diagonal (within-graph) entries are kept
for i in range(len(ptr) - 1):
batch_mask_neg_inf[ptr[i]:ptr[i + 1], ptr[i]:ptr[i + 1]] = 1
batch_mask_zeros[ptr[i]:ptr[i + 1], ptr[i]:ptr[i + 1]] = 1
query = self.q(x)
key = self.k(x)
value = self.v(x)
c = self.edge_encoding(x, edge_attr, edge_paths)
a = self.compute_a(key, query, ptr)
a = (a + b + c) * batch_mask_neg_inf
        softmax = torch.softmax(a, dim=-1) * batch_mask_zeros  # e^(-inf) -> 0: attention outside a graph's block is forced to zero
x = softmax.mm(value)
return x
def compute_a(self, key, query, ptr=None):
"Query-Key product(normalization)"
        if ptr is None:
a = query.mm(key.transpose(0, 1)) / query.size(-1) ** 0.5
else:
a = torch.zeros((query.shape[0], query.shape[0]), device=key.device)
for i in range(len(ptr) - 1):
a[ptr[i]:ptr[i + 1], ptr[i]:ptr[i + 1]] = query[ptr[i]:ptr[i + 1]].mm(
key[ptr[i]:ptr[i + 1]].transpose(0, 1)) / query.size(-1) ** 0.5
return a
class GraphormerMultiHeadAttention(nn.Module):
def __init__(self, num_heads: int, dim_in: int, dim_q: int, dim_k: int, edge_dim: int, max_path_distance: int):
"""
:param num_heads: number of attention heads
        :param dim_in: number of input dimensions of the node feature matrix
        :param dim_q: number of dimensions of the query projection
        :param dim_k: number of dimensions of the key/value projections
        :param edge_dim: number of dimensions of the edge feature matrix
        :param max_path_distance: max pairwise distance between nodes
"""
super().__init__()
self.heads = nn.ModuleList(
[GraphormerAttentionHead(dim_in, dim_q, dim_k, edge_dim, max_path_distance) for _ in range(num_heads)]
)
self.linear = nn.Linear(num_heads * dim_k, dim_in)
def forward(self,
x: torch.Tensor,
edge_attr: torch.Tensor,
b: torch.Tensor,
edge_paths,
ptr) -> torch.Tensor:
"""
:param x: node feature matrix
:param edge_attr: edge feature matrix
:param b: spatial Encoding matrix
:param edge_paths: pairwise node paths in edge indexes
:param ptr: batch pointer that shows graph indexes in batch of graphs
:return: torch.Tensor, node embeddings after all attention heads
"""
return self.linear(
torch.cat([
attention_head(x, edge_attr, b, edge_paths, ptr) for attention_head in self.heads
], dim=-1)
)
class GraphormerEncoderLayer(nn.Module):
def __init__(self, node_dim, edge_dim, num_heads, max_path_distance):
"""
        :param node_dim: number of dimensions of the node feature matrix
        :param edge_dim: number of dimensions of the edge feature matrix
        :param num_heads: number of attention heads
        :param max_path_distance: max pairwise distance between nodes
"""
super().__init__()
self.node_dim = node_dim
self.edge_dim = edge_dim
self.num_heads = num_heads
self.attention = GraphormerMultiHeadAttention(
dim_in=node_dim,
dim_k=node_dim,
dim_q=node_dim,
num_heads=num_heads,
edge_dim=edge_dim,
max_path_distance=max_path_distance,
)
self.ln_1 = nn.LayerNorm(node_dim)
self.ln_2 = nn.LayerNorm(node_dim)
self.ff = nn.Linear(node_dim, node_dim)
def forward(self,
x: torch.Tensor,
edge_attr: torch.Tensor,
                b: torch.Tensor,
edge_paths,
                ptr) -> torch.Tensor:
"""
h′(l) = MHA(LN(h(l−1))) + h(l−1)
h(l) = FFN(LN(h′(l))) + h′(l)
:param x: node feature matrix
:param edge_attr: edge feature matrix
:param b: spatial Encoding matrix
:param edge_paths: pairwise node paths in edge indexes
:param ptr: batch pointer that shows graph indexes in batch of graphs
:return: torch.Tensor, node embeddings after Graphormer layer operations
"""
x_prime = self.attention(self.ln_1(x), edge_attr, b, edge_paths, ptr) + x
x_new = self.ff(self.ln_2(x_prime)) + x_prime
return x_new
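

if __name__ == "__main__":
    # Minimal smoke test: a sketch with toy, hand-built inputs. The example graph, shapes and
    # shortest-path dicts below are illustrative assumptions, not part of the training pipeline.
    num_nodes, node_dim, edge_dim, max_dist = 4, 16, 8, 5
    x = torch.randn(num_nodes, node_dim)
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])        # directed chain 0 -> 1 -> 2 -> 3
    edge_attr = torch.randn(edge_index.shape[1], edge_dim)

    # Hand-built shortest paths: node paths for SpatialEncoding, edge-index paths for EdgeEncoding.
    paths = {0: {1: [0, 1], 2: [0, 1, 2]}, 1: {2: [1, 2]}}
    edge_paths = {0: {1: [0], 2: [0, 1]}, 1: {2: [1]}}

    centrality = CentralityEncoding(max_in_degree=5, max_out_degree=5, node_dim=node_dim)
    spatial = SpatialEncoding(max_path_distance=max_dist)
    layer = GraphormerEncoderLayer(node_dim, edge_dim, num_heads=2, max_path_distance=max_dist)

    x = centrality(x, edge_index)                            # add degree-based centrality embeddings
    b = spatial(x, paths)                                    # (num_nodes, num_nodes) spatial bias
    out = layer(x, edge_attr, b, edge_paths, ptr=None)       # single graph, so no batch pointer
    print(out.shape)                                         # torch.Size([4, 16])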