forked from juho-lee/set_transformer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
models.py
executable file
·44 lines (40 loc) · 1.72 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from modules import *
class DeepSet(nn.Module):
    """Set model: per-element MLP encoder, mean pooling over the set, MLP decoder.

    Pooling is a mean over axis -2, so the output does not depend on the
    order of elements along that axis.

    Args:
        dim_input: feature size of each set element (last axis of ``X``).
        num_outputs: number of output vectors produced per set.
        dim_output: feature size of each output vector.
        dim_hidden: width of every hidden layer.
    """

    def __init__(self, dim_input, num_outputs, dim_output, dim_hidden=128):
        super().__init__()
        self.num_outputs = num_outputs
        self.dim_output = dim_output

        # Encoder: 4 Linear layers with ReLU between them (no trailing ReLU).
        # Layers are created in the same order as a hand-written Sequential,
        # so parameter initialisation and state_dict keys (enc.0 ... enc.6)
        # are unchanged.
        encoder_layers = [nn.Linear(dim_input, dim_hidden)]
        for _ in range(3):
            encoder_layers.extend([nn.ReLU(), nn.Linear(dim_hidden, dim_hidden)])
        self.enc = nn.Sequential(*encoder_layers)

        # Decoder: 3 (Linear, ReLU) pairs, then a projection whose width is the
        # flattened size of all output vectors (dec.0 ... dec.6).
        decoder_layers = []
        for _ in range(3):
            decoder_layers.extend([nn.Linear(dim_hidden, dim_hidden), nn.ReLU()])
        decoder_layers.append(nn.Linear(dim_hidden, num_outputs * dim_output))
        self.dec = nn.Sequential(*decoder_layers)

    def forward(self, X):
        """Encode each element, average over axis -2, then decode.

        ``X`` is expected to end in ``(set_size, dim_input)``. Leading
        dimensions are flattened into the first axis of the result, which has
        shape ``(-1, num_outputs, dim_output)``.
        """
        pooled = self.enc(X).mean(-2)
        flat = self.dec(pooled)
        return flat.reshape(-1, self.num_outputs, self.dim_output)
class SetTransformer(nn.Module):
    """Set model built from attention blocks: an ISAB encoder and a PMA/SAB decoder.

    ISAB, PMA and SAB come from the ``modules`` star import. Their exact
    semantics are defined there, not here.

    Args:
        dim_input: feature size of each input set element.
        num_outputs: passed to PMA. Presumably the number of pooled output
            vectors (seeds); TODO confirm against ``modules.PMA``.
        dim_output: feature size produced by the final Linear projection.
        num_inds: passed to both ISAB blocks. Presumably the number of
            inducing points; TODO confirm.
        dim_hidden: hidden width shared by all attention blocks.
        num_heads: number of attention heads passed to every block.
        ln: forwarded as ``ln=`` to every block. Presumably toggles
            LayerNorm; TODO confirm.
    """

    def __init__(self, dim_input, num_outputs, dim_output,
                 num_inds=32, dim_hidden=128, num_heads=4, ln=False):
        super().__init__()
        # Two stacked ISABs: the first maps dim_input -> dim_hidden, the second
        # keeps dim_hidden. Construction order is kept so initialisation and
        # state_dict keys (enc.0, enc.1) are unchanged.
        first_isab = ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln)
        second_isab = ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln)
        self.enc = nn.Sequential(first_isab, second_isab)

        # Decoder order: PMA, two SABs, then a linear projection to dim_output
        # (state_dict keys dec.0 ... dec.3).
        self.dec = nn.Sequential(
            PMA(dim_hidden, num_heads, num_outputs, ln=ln),
            *(SAB(dim_hidden, dim_hidden, num_heads, ln=ln) for _ in range(2)),
            nn.Linear(dim_hidden, dim_output),
        )

    def forward(self, X):
        """Run the encoder on ``X``, then feed its result to the decoder."""
        encoded = self.enc(X)
        return self.dec(encoded)