-
Notifications
You must be signed in to change notification settings - Fork 10
/
main.py
30 lines (22 loc) · 881 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import trainer as trainer_module
import data_loader
import matplotlib.pyplot as plt
import adversarial_perturbation
def main():
    """Run the train -> perturbation -> plotting pipeline end to end.

    Steps, in order:
      1. Build a trainer via ``trainer_module.trainer()`` and call its
         ``train`` method on the datasets from ``data_loader.load_data()``.
      2. Load the datasets again and pass them, together with the value
         returned by ``train`` and the trainer's ``net`` attribute, to
         ``adversarial_perturbation.generate``.
      3. Plot the fooling rates and the accuracies returned by ``generate``
         against the returned iteration values, one figure after the other.

    Takes no arguments and returns ``None``; the only visible output is the
    two matplotlib windows opened by ``plt.show()``.
    """
    model_trainer = trainer_module.trainer()
    train_data, test_data = data_loader.load_data()
    trained_accuracy = model_trainer.train(train_data, test_data)

    # The datasets are loaded a second time before the perturbation step.
    # NOTE(review): presumably to hand generate() fresh datasets after
    # training has consumed them -- confirm before removing.
    train_data, test_data = data_loader.load_data()
    generated = adversarial_perturbation.generate(
        trained_accuracy, train_data, test_data, model_trainer.net
    )
    # The first returned element is not used by this script.
    _v, fool_rates, accuracy_curve, iteration_axis = generated

    # (figure title, y-axis label, series) for each figure, in display order.
    figures = (
        (
            "Fooling Rates over Universal Iterations",
            "Fooling Rate on test data",
            fool_rates,
        ),
        (
            "Accuracy over Universal Iterations",
            "Accuracy on Test data",
            accuracy_curve,
        ),
    )
    for figure_title, y_label, series in figures:
        plt.title(figure_title)
        plt.xlabel("Universal Algorithm Iter")
        plt.ylabel(y_label)
        plt.plot(iteration_axis, series)
        plt.show()
# Run the pipeline only when this file is executed as a script, so that
# importing the module does not trigger training or plotting.
if __name__ == "__main__":
    main()