-
Notifications
You must be signed in to change notification settings - Fork 70
/
visual_fv.py
executable file
·129 lines (108 loc) · 4.41 KB
/
visual_fv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import sys
import os
import yaml
sys.path.append("../")
sys.path.append("../../")
sys.path.append(os.getcwd())
from PIL import Image
from utils.defense_utils.dbd.model.model import SelfModel, LinearModel
from utils.defense_utils.dbd.model.utils import (
get_network_dbd,
load_state,
get_criterion,
get_optimizer,
get_scheduler,
)
from utils.save_load_attack import load_attack_result
from utils.aggregate_block.model_trainer_generate import generate_cls_model
from utils.aggregate_block.fix_random import fix_random
from utils.aggregate_block.dataset_and_transform_generate import (
get_transform,
get_dataset_denormalization,
)
from visual_utils import *
import torch
import numpy as np
import torchvision.transforms as transforms
from omnixai.explainers.vision.specific.feature_visualization.visualizer import FeatureVisualizer
# Basic setting: args
# Parse the command line first. The YAML file named by --yaml_path supplies
# the defaults, and every CLI value that is not None overrides the matching
# YAML entry. The merged dict then becomes the namespace of `args`.
args = get_args()
with open(args.yaml_path, "r") as stream:
    config = yaml.safe_load(stream)
for key, value in args.__dict__.items():
    if value is not None:
        config[key] = value
args.__dict__ = config
args = preprocess_args(args)
fix_random(int(args.random_seed))
# Attack artifacts live under ./record/<result_file_attack>; by default the
# visualizations are written to its "visual" subfolder.
save_path_attack = "./record/" + args.result_file_attack
visual_save_path = save_path_attack + "/visual"
# Load result: prototype runs have a dedicated loader; otherwise read the
# serialized attack result stored in the attack record folder.
result_attack = (
    load_prototype_result(args, save_path_attack)
    if args.prototype
    else load_attack_result(save_path_attack + "/attack_result.pt")
)
# Load model: build the classifier architecture, then restore its weights from
# the defense result if one was given, otherwise from the attack result.
model_visual = generate_cls_model(args.model, args.num_classes)
# "No defense" may arrive as the string "None" (as written in the YAML config)
# or as a real None (e.g. YAML `null`). Treat both the same instead of crashing
# on "./record/" + None.
if args.result_file_defense not in (None, "None"):
    save_path_defense = "./record/" + args.result_file_defense
    # Visualizations of a defended model are stored next to the defense result.
    visual_save_path = save_path_defense + "/visual"
    result_defense = load_attack_result(
        save_path_defense + "/defense_result.pt")
    # The last path component of the defense record names the defense method.
    defense_method = args.result_file_defense.split('/')[-1]
    if defense_method == 'fp':
        # 'fp' checkpoints have a narrower layer4[1].conv2 and linear head
        # (512 - result_defense['index'] channels). Rebuild both layers with
        # matching shapes before loading the state dict.
        # NOTE(review): assumes a ResNet-style model with 512 channels in
        # layer4 and a `linear` head — confirm for other architectures.
        kept_channels = 512 - result_defense['index']
        model_visual.layer4[1].conv2 = torch.nn.Conv2d(
            512, kept_channels, (3, 3), stride=1, padding=1, bias=False)
        model_visual.linear = torch.nn.Linear(kept_channels, args.num_classes)
    elif defense_method == 'dbd':
        # DBD checkpoints use their own backbone wrapped by a LinearModel head.
        backbone = get_network_dbd(args)
        model_visual = LinearModel(
            backbone, backbone.feature_dim, args.num_classes)
    model_visual.load_state_dict(result_defense["model"])
    print(f"Load model {args.model} from {args.result_file_defense}")
else:
    model_visual.load_state_dict(result_attack["model"])
    print(f"Load model {args.model} from {args.result_file_attack}")
model_visual.to(args.device)
# !!! Important to set eval mode !!!
# eval() puts layers such as BatchNorm/Dropout into inference behavior, so the
# visualization optimizes against the deployed model rather than training mode.
model_visual.eval()
# Create the output directory, including any missing parents. exist_ok=True
# avoids both the check-then-create race and the error os.mkdir raises when a
# parent directory does not exist yet.
os.makedirs(visual_save_path, exist_ok=True)
# Resolve the layer to visualize by its fully-qualified module name
# (e.g. "layer4.1.conv2"), as produced by named_modules().
module_dict = dict(model_visual.named_modules())
try:
    target_layer = module_dict[args.target_layer_name]
except KeyError:
    # Keep the KeyError type, but tell the user which names are valid.
    raise KeyError(
        f"Layer {args.target_layer_name!r} not found in model {args.model}; "
        f"available layers: {list(module_dict)}"
    ) from None
print(f'Choose layer {args.target_layer_name} from model {args.model}')
# Enable training transform to enhance transform robustness:
# the training-time pipeline (train=True) is passed to the visualizer below.
tran = get_transform(args.dataset, args.input_height,
                     args.input_width, train=True)
# Derive the de-normalization matching the pipeline's Normalize step
# (if several Normalize steps exist, the last one wins).
for step in tran.transforms:
    if not isinstance(step, transforms.Normalize):
        continue
    denormalizer = get_dataset_denormalization(step)
############## Feature Visualization ##################
print("Plotting Feature Visualization")
# One "channel" objective covering every output channel of the target layer,
# so one image is optimized per channel.
channel_objective = {
    "layer": target_layer,
    "type": "channel",
    "index": list(range(target_layer.out_channels)),
}
optimizer = FeatureVisualizer(
    model=model_visual,
    objectives=[channel_objective],
    transformers=tran,
)
# The regularizers below (L1, L2 and total-variation weights) are hand-tuned
# to produce cleaner images; adjust them for your own setting. Be aware that
# strong regularization can keep the optimizer from recovering triggers,
# especially irregular trigger patterns.
explanations = optimizer.explain(
    num_iterations=300,
    image_shape=(args.input_height, args.input_width),
    regularizers=[("l1", 0.15), ("l2", 0), ("tv", 0.25)],
    use_fft=True,
)
# Lay out the optimized images (one per channel) in a grid of 16 per row
# and save the grid to <visual_save_path>/feature_visual.png.
images = explanations.explanations[0]['image']
num_cnn = len(images)
num_col = 16
num_row = int(np.ceil(num_cnn / num_col))
# squeeze=False keeps `axes` 2-D even when there is only one row
# (num_cnn <= 16). Without it, plt.subplots returns a 1-D array and
# row/column indexing raises IndexError.
fig, axes = plt.subplots(nrows=num_row, ncols=num_col,
                         figsize=(4*num_col, 5*num_row), squeeze=False)
# ravel() walks the grid row-major, matching (i // num_col, i % num_col).
for cnn_i, ax in enumerate(axes.ravel()):
    if cnn_i < num_cnn:
        ax.imshow(images[cnn_i])
        ax.set_title(f'Kernel {cnn_i}')
    else:
        # Hide the unused cells of a partially filled last row.
        ax.axis('off')
plt.tight_layout()
save_file = visual_save_path + "/feature_visual.png"
plt.savefig(save_file)
# Release the (potentially very large) figure once it has been written.
plt.close(fig)
print(f'Save to {save_file}')