-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodel_configurations.py
44 lines (38 loc) · 1.14 KB
/
model_configurations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from function_transformer_attention import ODEFuncTransformerAtt
from function_GAT_attention import ODEFuncAtt
from function_laplacian_diffusion import LaplacianODEFunc
from block_transformer_attention import AttODEblock
from block_constant import ConstantODEblock
from block_mixed import MixedODEblock
from block_transformer_hard_attention import HardAttODEblock
from block_transformer_rewiring import RewireAttODEblock
class BlockNotDefined(Exception):
    """Raised by ``set_block`` when ``opt['block']`` names no known ODE block."""
class FunctionNotDefined(Exception):
    """Raised by ``set_function`` when ``opt['function']`` names no known ODE function."""
def set_block(opt):
    """Return the ODE block class selected by ``opt['block']``.

    The class itself is returned, not an instance; the caller constructs it.

    Args:
        opt: Options mapping; only the ``'block'`` entry is read.

    Returns:
        One of MixedODEblock, AttODEblock, HardAttODEblock,
        RewireAttODEblock or ConstantODEblock.

    Raises:
        BlockNotDefined: if ``opt['block']`` is not a recognised block name.
        KeyError: if ``opt`` (assumed dict-like) has no ``'block'`` entry.
    """
    blocks = {
        'mixed': MixedODEblock,
        'attention': AttODEblock,
        'hard_attention': HardAttODEblock,
        'rewire_attention': RewireAttODEblock,
        'constant': ConstantODEblock,
    }
    ode_str = opt['block']
    try:
        return blocks[ode_str]
    except (KeyError, TypeError):
        # TypeError covers unhashable values, which the original if/elif
        # chain also rejected with BlockNotDefined.
        raise BlockNotDefined(
            f"Unknown block {ode_str!r}; expected one of: "
            f"{', '.join(sorted(blocks))}"
        ) from None
def set_function(opt):
    """Return the ODE function class selected by ``opt['function']``.

    The class itself is returned, not an instance; the caller constructs it.

    Args:
        opt: Options mapping; only the ``'function'`` entry is read.

    Returns:
        One of LaplacianODEFunc, ODEFuncAtt or ODEFuncTransformerAtt.

    Raises:
        FunctionNotDefined: if ``opt['function']`` is not a recognised name.
        KeyError: if ``opt`` (assumed dict-like) has no ``'function'`` entry.
    """
    functions = {
        'laplacian': LaplacianODEFunc,
        'GAT': ODEFuncAtt,
        'transformer': ODEFuncTransformerAtt,
    }
    ode_str = opt['function']
    try:
        return functions[ode_str]
    except (KeyError, TypeError):
        # TypeError covers unhashable values, which the original if/elif
        # chain also rejected with FunctionNotDefined.
        raise FunctionNotDefined(
            f"Unknown function {ode_str!r}; expected one of: "
            f"{', '.join(sorted(functions))}"
        ) from None