-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathrun_demo.py
105 lines (84 loc) · 3.65 KB
/
run_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import csv
import argparse
import numpy as np
from datetime import datetime
import model as model_config
from data_utils import load as load_data, extract_features
from adversarial_tools import ForwardGradWrapper, adversarial_paraphrase, \
_stats_probability_shifts
def _str2bool(value):
    """Convert a command-line string to a real boolean.

    argparse's ``type=bool`` is a well-known trap: any non-empty string
    (including ``'False'``) is truthy, so ``--use_typos False`` would
    silently enable typos. This converter accepts the common spellings
    of true/false explicitly and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'Boolean value expected, got {!r}'.format(value))


# Command-line interface for the adversarial-example demo.
parser = argparse.ArgumentParser(
    description='Craft adversarial examples for a text classifier.')
parser.add_argument('--model_path',
                    help='Path to model weights',
                    default='./data/model.dat')
parser.add_argument('--adversarial_texts_path',
                    help='Path where results will be saved',
                    default='./data/adversarial_texts.csv')
parser.add_argument('--test_samples_cap',
                    help='Amount of test samples to use',
                    type=int, default=2000)
parser.add_argument('--use_typos',
                    help='Whether to use typos for paraphrases',
                    type=_str2bool, default=False)
def clean(text):
    """Return *text* (coerced to str) with every non-ASCII character
    replaced by a single space."""
    ascii_chars = []
    for ch in str(text):
        if ord(ch) < 128:
            ascii_chars.append(ch)
        else:
            ascii_chars.append(' ')
    return ''.join(ascii_chars)
if __name__ == '__main__':
    args = parser.parse_args()
    test_samples_cap = args.test_samples_cap

    # Load Twitter gender data (label 0 = female, per the crafting
    # condition below).
    (_, _, X_test, y_test), (docs_train, docs_test, _) = \
        load_data('twitter_gender_data', from_cache=False)

    # Load model from weights
    model = model_config.build_model()
    model.load_weights(args.model_path)

    # Initialize the class that computes forward derivatives
    grad_guide = ForwardGradWrapper(model)

    # Calculate accuracy on test examples
    preds = model.predict_classes(X_test[:test_samples_cap, ]).squeeze()
    accuracy = np.mean(preds == y_test[:test_samples_cap])
    print('Model accuracy on test:', accuracy)

    print('Crafting adversarial examples...')
    successful_perturbations = 0
    failed_perturbations = 0
    adversarial_text_data = []
    adversarial_preds = np.array(preds)
    for index, doc in enumerate(docs_test[:test_samples_cap]):
        if y_test[index] == 0 and preds[index] == 0:
            # If model prediction is correct, and the true class is female,
            # craft adversarial text targeting the opposite class.
            adv_doc, (y, adv_y) = adversarial_paraphrase(
                doc, grad_guide, target=1, use_typos=args.use_typos)
            pred = np.round(adv_y)
            if pred != preds[index]:
                successful_perturbations += 1
                print('{}. Successful example crafted.'.format(index))
            else:
                failed_perturbations += 1
                print('{}. Failure.'.format(index))
            adversarial_preds[index] = pred
            adversarial_text_data.append({
                'index': index,
                'doc': clean(doc),
                'adv': clean(adv_doc),
                'success': pred != preds[index],
                'confidence': y,
                'adv_confidence': adv_y
            })

    print('Model accuracy on adversarial examples:',
          np.mean(adversarial_preds == y_test[:test_samples_cap]))

    # Guard against ZeroDivisionError when no sample satisfied the
    # crafting condition (no correctly-classified female tweets).
    total_attempts = successful_perturbations + failed_perturbations
    if total_attempts > 0:
        print('Fooling success rate:',
              successful_perturbations / total_attempts)
    else:
        print('Fooling success rate: n/a (no examples crafted)')
    print('Average probability shift:', np.mean(
        np.array(_stats_probability_shifts)))

    # Save resulting docs in a CSV file. Skip when empty: indexing
    # adversarial_text_data[0] would raise IndexError. newline='' is
    # required by the csv module to avoid blank rows on Windows.
    if adversarial_text_data:
        with open(args.adversarial_texts_path, 'w', newline='') as handle:
            writer = csv.DictWriter(handle,
                                    fieldnames=adversarial_text_data[0].keys())
            writer.writeheader()
            for item in adversarial_text_data:
                writer.writerow(item)