-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathbase_classes.py
146 lines (117 loc) · 4.64 KB
/
base_classes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import torch
from torch import nn
from torch_geometric.nn.conv import MessagePassing
from utils import Meter
from regularized_ODE_function import RegularizedODEfunc
import regularized_ODE_function as reg_lib
import six
# Map each CLI/config option name to the regularization penalty it enables.
REGULARIZATION_FNS = dict(
    kinetic_energy=reg_lib.quadratic_cost,
    jacobian_norm2=reg_lib.jacobian_frobenius_regularization_fn,
    total_deriv=reg_lib.total_derivative,
    directional_penalty=reg_lib.directional_derivative,
)
def create_regularization_fns(args):
    """Collect the regularization functions enabled in ``args``.

    Parameters
    ----------
    args : dict-like
        Maps each key of ``REGULARIZATION_FNS`` to a coefficient, or to
        ``None`` when that regularizer is disabled.

    Returns
    -------
    tuple[list, list]
        Two parallel lists: the enabled regularization callables and their
        corresponding coefficients, in ``REGULARIZATION_FNS`` iteration order.
    """
    regularization_fns = []
    regularization_coeffs = []
    # dict.items() replaces six.iteritems: this codebase is Python 3 only
    # (torch / torch_geometric), so the six compatibility shim is unnecessary.
    for arg_key, reg_fn in REGULARIZATION_FNS.items():
        if args[arg_key] is not None:
            regularization_fns.append(reg_fn)
            regularization_coeffs.append(args[arg_key])
    return regularization_fns, regularization_coeffs
class ODEblock(nn.Module):
    """Integrates a graph ODE function over the time interval ``t``.

    Wraps ``odefunc`` (optionally dimension-augmented and/or regularized)
    together with the torchdiffeq integrator selected by ``opt['adjoint']``.
    """

    def __init__(self, odefunc, regularization_fns, opt, data, device, t):
        super(ODEblock, self).__init__()
        self.opt = opt
        self.t = t
        # Bug fix: set_time() reads self.device, but the original __init__
        # never stored the `device` argument, so set_time raised AttributeError.
        self.device = device
        # Augmentation doubles the state width so extra channels can evolve.
        self.aug_dim = 2 if opt['augment'] else 1
        self.odefunc = odefunc(self.aug_dim * opt['hidden_dim'], self.aug_dim * opt['hidden_dim'], opt, data, device)
        self.nreg = len(regularization_fns)
        self.reg_odefunc = RegularizedODEfunc(self.odefunc, regularization_fns)

        # The adjoint method trades extra compute for O(1) memory in backprop.
        if opt['adjoint']:
            from torchdiffeq import odeint_adjoint as odeint
        else:
            from torchdiffeq import odeint
        self.train_integrator = odeint
        self.test_integrator = None  # presumably filled in by subclasses — not set here
        self.set_tol()

    def set_x0(self, x0):
        """Store a detached copy of the initial state on both ODE functions."""
        self.odefunc.x0 = x0.clone().detach()
        self.reg_odefunc.odefunc.x0 = x0.clone().detach()

    def set_tol(self):
        """Scale solver tolerances by opt['tol_scale'] (and the adjoint variant)."""
        self.atol = self.opt['tol_scale'] * 1e-7
        self.rtol = self.opt['tol_scale'] * 1e-9
        if self.opt['adjoint']:
            self.atol_adjoint = self.opt['tol_scale_adjoint'] * 1e-7
            self.rtol_adjoint = self.opt['tol_scale_adjoint'] * 1e-9

    def reset_tol(self):
        """Restore the default (unscaled) solver tolerances."""
        self.atol = 1e-7
        self.rtol = 1e-9
        self.atol_adjoint = 1e-7
        self.rtol_adjoint = 1e-9

    def set_time(self, time):
        """Reset the integration interval to [0, time] on this block's device."""
        self.t = torch.tensor([0, time]).to(self.device)

    def __repr__(self):
        return self.__class__.__name__ + '( Time Interval ' + str(self.t[0].item()) + ' -> ' + str(self.t[1].item()) \
               + ")"
class ODEFunc(MessagePassing):
    """Base class for graph ODE dynamics.

    Currently requires in_features == out_features.
    """

    def __init__(self, opt, data, device):
        super(ODEFunc, self).__init__()
        self.opt, self.device = opt, device
        # Graph-structure / attention caches — left unset here; presumably
        # populated by subclasses or callers (not visible in this class).
        self.edge_index, self.edge_weight = None, None
        self.attention_weights = None
        # Learnable scalars (registration order preserved from the original).
        self.alpha_train = nn.Parameter(torch.tensor(0.0))
        self.beta_train = nn.Parameter(torch.tensor(0.0))
        # Initial state and number-of-function-evaluations counter.
        self.x0, self.nfe = None, 0
        self.alpha_sc = nn.Parameter(torch.ones(1))
        self.beta_sc = nn.Parameter(torch.ones(1))

    def __repr__(self):
        return type(self).__name__
class BaseGNN(MessagePassing):
    """Shared scaffolding for GNN models driven by an ODE block.

    Builds encoder/decoder linear layers around ``opt['hidden_dim']`` and
    tracks ODE solver function-evaluation counts. WARNING: ``__init__``
    mutates ``opt['hidden_dim']`` in place (beltrami and use_labels paths),
    so every layer created after those branches — and any later reader of
    ``opt`` — sees the updated width. Statement order matters here.
    """

    def __init__(self, opt, dataset, device=torch.device('cpu')):
        super(BaseGNN, self).__init__()
        self.opt = opt
        self.T = opt['time']  # total integration time for the ODE
        self.num_classes = dataset.num_classes
        self.num_features = dataset.data.num_features
        self.num_nodes = dataset.data.num_nodes
        self.device = device
        # Forward/backward meters (utils.Meter); their exact semantics are
        # defined in utils — presumably they accumulate NFE statistics.
        self.fm = Meter()
        self.bm = Meter()
        if opt['beltrami']:
            # Separate encoders for node features and positional encodings;
            # hidden_dim becomes their concatenated width (in-place opt update).
            self.mx = nn.Linear(self.num_features, opt['feat_hidden_dim'])
            self.mp = nn.Linear(opt['pos_enc_dim'], opt['pos_enc_hidden_dim'])
            opt['hidden_dim'] = opt['feat_hidden_dim'] + opt['pos_enc_hidden_dim']
        else:
            self.m1 = nn.Linear(self.num_features, opt['hidden_dim'])
        if self.opt['use_mlp']:
            # Two extra hidden layers on the encoder side.
            self.m11 = nn.Linear(opt['hidden_dim'], opt['hidden_dim'])
            self.m12 = nn.Linear(opt['hidden_dim'], opt['hidden_dim'])
        if opt['use_labels']:
            # todo - fastest way to propagate this everywhere, but error prone - refactor later
            opt['hidden_dim'] = opt['hidden_dim'] + dataset.num_classes
        else:
            # NOTE(review): self.hidden_dim is only set on this branch; with
            # use_labels the attribute does not exist — confirm callers cope.
            self.hidden_dim = opt['hidden_dim']
        if opt['fc_out']:
            # Optional extra fully-connected layer before the decoder.
            self.fc = nn.Linear(opt['hidden_dim'], opt['hidden_dim'])
        self.m2 = nn.Linear(opt['hidden_dim'], dataset.num_classes)
        if self.opt['batch_norm']:
            self.bn_in = torch.nn.BatchNorm1d(opt['hidden_dim'])
            self.bn_out = torch.nn.BatchNorm1d(opt['hidden_dim'])
        self.regularization_fns, self.regularization_coeffs = create_regularization_fns(self.opt)

    def getNFE(self):
        # Total function evaluations across the plain and regularized ODE funcs.
        return self.odeblock.odefunc.nfe + self.odeblock.reg_odefunc.odefunc.nfe

    def resetNFE(self):
        # Zero both evaluation counters.
        self.odeblock.odefunc.nfe = 0
        self.odeblock.reg_odefunc.odefunc.nfe = 0

    def reset(self):
        # Re-initialize encoder/decoder weights.
        # NOTE(review): only m1/m2 are reset; mx/mp/m11/m12/fc (when present)
        # keep their weights, and m1 does not exist on the beltrami path —
        # confirm this is intentional.
        self.m1.reset_parameters()
        self.m2.reset_parameters()

    def __repr__(self):
        return self.__class__.__name__