-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsecond_option.py
36 lines (27 loc) · 1.37 KB
/
second_option.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch.nn
from data import MultiModalDataset, PermutedMultiModalDataset
from torch.utils.data import DataLoader
from nn_utils import TwoLayersFCNetwork, evaluate_risk
def _evaluate_on(dataset_cls, model, loss_fn, *, n_samples, first_sample_size,
                 second_sample_size, batch_size=64, permute_first=True):
    """Build a dataset of type ``dataset_cls`` and return ``evaluate_risk`` on it.

    Both datasets compared by this script are constructed with the same
    keyword arguments and evaluated with the same loader, model and loss, so
    that shared sequence lives here instead of being written out twice.

    Args:
        dataset_cls: Dataset class accepting ``n_samples``,
            ``first_sample_size`` and ``second_sample_size`` keyword args.
        model: Network passed through to ``evaluate_risk``.
        loss_fn: Loss passed through to ``evaluate_risk``.
        n_samples: Number of samples requested from the dataset.
        first_sample_size: Size argument for the first modality.
        second_sample_size: Size argument for the second modality.
        batch_size: DataLoader batch size (the script uses 64).
        permute_first: Forwarded unchanged to ``evaluate_risk``.

    Returns:
        Whatever ``evaluate_risk`` returns (presumably a scalar risk —
        confirm in ``nn_utils``).
    """
    dataset = dataset_cls(n_samples=n_samples,
                          first_sample_size=first_sample_size,
                          second_sample_size=second_sample_size)
    # Only evaluation happens here, so this is not a "training" loader.
    # shuffle=True is kept as in the original script.
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return evaluate_risk(data_loader=data_loader,
                         model=model,
                         loss_fn=loss_fn,
                         permute_first=permute_first)


def main() -> None:
    """Print the risk difference between a regular and a permuted multimodal dataset.

    A single ``TwoLayersFCNetwork`` is evaluated with ``evaluate_risk`` on a
    ``MultiModalDataset`` and then on a ``PermutedMultiModalDataset`` built
    with identical sizes. The script prints ``risk - permuted_risk``.
    """
    n_samples: int = 1000
    first_sample_size: int = 5
    second_sample_size: int = 7
    sizes = dict(n_samples=n_samples,
                 first_sample_size=first_sample_size,
                 second_sample_size=second_sample_size)

    # input_size is the sum of both sample sizes; presumably the two
    # modalities are concatenated per sample — confirm in data/nn_utils.
    # NOTE(review): the model is evaluated exactly as constructed; no training
    # step appears in this script. Confirm that comparing risks of an
    # untrained network is intended.
    model = TwoLayersFCNetwork(input_size=first_sample_size + second_sample_size)
    loss_fn = torch.nn.CrossEntropyLoss()

    # NOTE(review): both evaluations pass permute_first=True, even though the
    # second dataset is already a "Permuted" variant. Verify that this matches
    # the intended comparison.
    risk = _evaluate_on(MultiModalDataset, model, loss_fn,
                        permute_first=True, **sizes)
    permuted_risk = _evaluate_on(PermutedMultiModalDataset, model, loss_fn,
                                 permute_first=True, **sizes)
    print(risk - permuted_risk)


if __name__ == '__main__':
    main()