function_laplacian_diffusion.py
import torch
from torch import nn
import torch_sparse

from base_classes import ODEFunc
from utils import MaxNFEException


# Define the ODE function.
# Input:
# --- t: A tensor with shape [], meaning the current time.
# --- x: A tensor with shape [#nodes, dims], meaning the value of x at t.
# Output:
# --- dx/dt: A tensor with shape [#nodes, dims], meaning the derivative of x at t.
class LaplacianODEFunc(ODEFunc):

    # Currently requires in_features = out_features.
    def __init__(self, in_features, out_features, opt, data, device):
        super(LaplacianODEFunc, self).__init__(opt, data, device)
        self.in_features = in_features
        self.out_features = out_features
        self.w = nn.Parameter(torch.eye(opt['hidden_dim']))
        self.d = nn.Parameter(torch.zeros(opt['hidden_dim']) + 1)
        self.alpha_sc = nn.Parameter(torch.ones(1))
        self.beta_sc = nn.Parameter(torch.ones(1))

    def sparse_multiply(self, x):
        # Multiply x by the sparse diffusion operator A; which edge values are
        # used depends on the configured block type.
        if self.opt['block'] in ['attention']:  # adj is a multi-head attention
            mean_attention = self.attention_weights.mean(dim=1)
            ax = torch_sparse.spmm(self.edge_index, mean_attention, x.shape[0], x.shape[0], x)
        elif self.opt['block'] in ['mixed', 'hard_attention']:  # adj is a torch sparse matrix
            ax = torch_sparse.spmm(self.edge_index, self.attention_weights, x.shape[0], x.shape[0], x)
        else:  # adj is a torch sparse matrix
            ax = torch_sparse.spmm(self.edge_index, self.edge_weight, x.shape[0], x.shape[0], x)
        return ax

    def forward(self, t, x):  # the t param is needed by the ODE solver.
        # nfe counts function evaluations; bail out if the solver runs away.
        if self.nfe > self.opt['max_nfe']:
            raise MaxNFEException
        self.nfe += 1
        ax = self.sparse_multiply(x)
        if not self.opt['no_alpha_sigmoid']:
            alpha = torch.sigmoid(self.alpha_train)
        else:
            alpha = self.alpha_train
        # Laplacian diffusion: dx/dt = alpha * (A - I) x.
        f = alpha * (ax - x)
        if self.opt['add_source']:
            # Optional source term anchored at the initial features x0.
            f = f + self.beta_train * self.x0
        return f
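
For context, here is a minimal usage sketch (not part of the file above) showing what `forward` computes: it integrates dx/dt = alpha * (A x - x) on a toy graph with torchdiffeq's `odeint`. The 4-node edge list, the uniform weights, and the fixed `alpha` are made-up illustration values; in the real class, `edge_index`, `edge_weight`, and the learnable `alpha_train` come from the `ODEFunc` base class and the training pipeline.

# Usage sketch: integrate the same right-hand side on a toy graph.
import torch
import torch_sparse
from torchdiffeq import odeint

# A directed 4-cycle with uniform edge weights (illustration values).
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])
edge_weight = torch.ones(4)
x0 = torch.randn(4, 2)  # 4 nodes, 2 feature dims
alpha = 0.8             # stand-in for sigmoid(alpha_train)

def laplacian_rhs(t, x):
    # Same core computation as LaplacianODEFunc.forward, minus the bookkeeping.
    ax = torch_sparse.spmm(edge_index, edge_weight, x.shape[0], x.shape[0], x)
    return alpha * (ax - x)

t = torch.linspace(0., 1., 5)
xt = odeint(laplacian_rhs, x0, t)  # shape [5, 4, 2]: x at each requested time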