-
Notifications
You must be signed in to change notification settings - Fork 12
/
utils.py
299 lines (268 loc) · 14.3 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import os
import torch
import torch.distributed as dist
from timm.utils.model import unwrap_model, get_state_dict
def load_checkpoint(config, model, optimizer, lr_scheduler, logger):
    """Restore training state from the checkpoint named by ``config.MODEL.RESUME``.

    The model weights are always loaded (non-strictly). The optimizer and LR
    scheduler states, the start epoch and the recorded best accuracies are only
    restored when ``config.EVAL_MODE`` is false and the checkpoint carries
    ``'optimizer'``, ``'lr_scheduler'`` and ``'epoch'`` entries.

    Args:
        config: Config object exposing ``MODEL.RESUME``, ``EVAL_MODE``,
            ``TRAIN.START_EPOCH`` and ``defrost()`` / ``freeze()``.
        model: Module whose ``load_state_dict`` receives ``checkpoint['model']``.
        optimizer: Optimizer to restore when resuming training.
        lr_scheduler: LR scheduler to restore when resuming training.
        logger: Logger used for progress messages.

    Returns:
        Tuple ``(max_accuracy, max_accuracy_e)``; each is ``0.0`` unless the
        checkpoint provided a value for it.
    """
    logger.info(f"==============> Resuming form {config.MODEL.RESUME}....................")
    if config.MODEL.RESUME.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            config.MODEL.RESUME, map_location='cpu', check_hash=True)
    else:
        # NOTE(review): torch.load unpickles the file; only resume from trusted checkpoints.
        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(msg)

    max_accuracy = 0.0
    max_accuracy_e = 0.0
    if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        # The config is frozen by default; unlock it just long enough to set the start epoch.
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch']
        config.freeze()
        logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']
            # Checkpoints written without 'max_accuracy_e' must not abort the
            # resume with a KeyError; keep the default instead.
            max_accuracy_e = checkpoint.get('max_accuracy_e', max_accuracy_e)

    # Drop the (potentially large) checkpoint before training continues.
    del checkpoint
    torch.cuda.empty_cache()
    return max_accuracy, max_accuracy_e
def load_pretrained(ckpt_path, model, logger):
    """Load pretrained weights from ``ckpt_path`` into ``model``, adapting mismatched tensors.

    The checkpoint's ``'state_dict_ema'`` entry is used when present; otherwise
    the loaded object itself is treated as the state dict. Before the final
    non-strict ``load_state_dict``:

    * ``relative_position_index``, ``relative_coords_table`` and ``attn_mask``
      entries are dropped (they are always re-initialised).
    * Agent bias tensors, ``patch_embed`` projection weights, relative position
      bias tables and ``pos_embed`` tensors are interpolated to the sizes used
      by ``model`` when they differ.
    * The classifier head is remapped from 21841 to 1000 classes through
      ``data/map22kto1k.txt``, or zero-initialised for any other class-count
      mismatch.

    Entries whose head/channel counts cannot be reconciled only trigger a
    warning; they are left in the state dict unchanged.

    Args:
        ckpt_path: Path handed to ``torch.load``.
        model: Target module; must expose ``head.weight`` / ``head.bias``.
        logger: Logger used for progress and warning messages.

    Raises:
        KeyError: If the state dict has no ``'head.bias'`` entry, or contains an
            adapted key that the model's state dict lacks.
    """
    logger.info(f"==============> Loading pretrained form {ckpt_path}....................")
    # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    state_dict = checkpoint['state_dict_ema'] if 'state_dict_ema' in checkpoint.keys() else checkpoint
    # Snapshot the target tensors once instead of rebuilding the dict for every key.
    current_state = model.state_dict()

    # delete relative_position_index, relative_coords_table and attn_mask
    # since we always re-init them
    for marker in ("relative_position_index", "relative_coords_table", "attn_mask"):
        for k in [k for k in state_dict.keys() if marker in k]:
            del state_dict[k]

    # linear interpolate agent bias if h/w not match, bicubic interpolate agent bias if agent_num not match
    agent_bias_keys = [k for k in state_dict.keys() if ("ah_bias" in k) or ("aw_bias" in k)
                       or ("ha_bias" in k) or ("wa_bias" in k)]
    for k in agent_bias_keys:
        # Each bias family has its singleton axis in a different position; the
        # "ha"/"wa" variants are additionally permuted so that all four are
        # handled below in one common 3-D layout.
        if "ah_bias" in k:
            squeeze_dim, permute = -1, False
        elif "aw_bias" in k:
            squeeze_dim, permute = -2, False
        elif "ha_bias" in k:
            squeeze_dim, permute = -2, True
        else:
            squeeze_dim, permute = -3, True
        agent_bias_pretrained = state_dict[k].squeeze(dim=0).squeeze(dim=squeeze_dim)
        agent_bias_current = current_state[k].squeeze(dim=0).squeeze(dim=squeeze_dim)
        if permute:
            agent_bias_pretrained = agent_bias_pretrained.permute(0, 2, 1)
            agent_bias_current = agent_bias_current.permute(0, 2, 1)
        num_heads1, agent_num1, hw1 = agent_bias_pretrained.size()
        num_heads2, agent_num2, hw2 = agent_bias_current.size()
        if num_heads1 != num_heads2:
            logger.warning(f"Error in loading {k}, passing......")
        else:
            if agent_num1 != agent_num2:
                # Treat the agent axis as an a1 x a1 grid and resize it bicubically.
                a1 = int(agent_num1 ** 0.5)
                a2 = int(agent_num2 ** 0.5)
                agent_bias_pretrained_resized = agent_bias_pretrained.permute(0, 2, 1).reshape(num_heads1, hw1, a1, a1)
                agent_bias_pretrained_resized = torch.nn.functional.interpolate(
                    agent_bias_pretrained_resized, size=(a2, a2), mode='bicubic').flatten(2).permute(0, 2, 1)
                # Keep the resized tensor so the h/w interpolation below starts from it.
                agent_bias_pretrained = agent_bias_pretrained_resized
                if permute:
                    agent_bias_pretrained_resized = agent_bias_pretrained_resized.permute(0, 2, 1)
                state_dict[k] = agent_bias_pretrained_resized.unsqueeze(dim=0).unsqueeze(dim=squeeze_dim)
            if hw1 != hw2:
                # linear interpolate agent bias if not match
                agent_bias_pretrained_resized = torch.nn.functional.interpolate(
                    agent_bias_pretrained, size=hw2, mode='linear')
                if permute:
                    agent_bias_pretrained_resized = agent_bias_pretrained_resized.permute(0, 2, 1)
                state_dict[k] = agent_bias_pretrained_resized.unsqueeze(dim=0).unsqueeze(dim=squeeze_dim)

    agent_bias_keys = [k for k in state_dict.keys() if ("an_bias" in k) or ("na_bias" in k)]
    for k in agent_bias_keys:
        agent_bias_pretrained = state_dict[k]
        agent_bias_current = current_state[k]
        num_heads1, agent_num1, h1, w1 = agent_bias_pretrained.size()
        num_heads2, agent_num2, h2, w2 = agent_bias_current.size()
        if (num_heads1 != num_heads2) or (h1 != h2) or (w1 != w2):
            # Report through the logger like every other mismatch in this function.
            logger.warning(f"Error in loading {k}, passing......")
        else:
            if agent_num1 != agent_num2:
                a1 = int(agent_num1 ** 0.5)
                a2 = int(agent_num2 ** 0.5)
                agent_bias_pretrained_resized = agent_bias_pretrained.flatten(2).permute(0, 2, 1).reshape(num_heads1, -1, a1, a1)
                agent_bias_pretrained_resized = torch.nn.functional.interpolate(
                    agent_bias_pretrained_resized, size=(a2, a2), mode='bicubic').flatten(2).permute(0, 2, 1)
                state_dict[k] = agent_bias_pretrained_resized.reshape(num_heads2, agent_num2, h2, w2)

    # bicubic interpolate patch_embed.proj if not match
    patch_embed_keys = [k for k in state_dict.keys() if ("patch_embed" in k) and (".proj.weight" in k)]
    for k in patch_embed_keys:
        patch_embed_pretrained = state_dict[k]
        patch_embed_current = current_state[k]
        out1, in1, h1, w1 = patch_embed_pretrained.size()
        out2, in2, h2, w2 = patch_embed_current.size()
        if (out1 != out2) or (in1 != in2):
            logger.warning(f"Error in loading {k}, passing......")
        else:
            if (h1 != h2) or (w1 != w2):
                # bicubic interpolate patch_embed.proj if not match
                patch_embed_pretrained_resized = torch.nn.functional.interpolate(
                    patch_embed_pretrained, size=(h2, w2), mode='bicubic')
                state_dict[k] = patch_embed_pretrained_resized

    # bicubic interpolate relative_position_bias_table if not match
    relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
    for k in relative_position_bias_table_keys:
        relative_position_bias_table_pretrained = state_dict[k]
        relative_position_bias_table_current = current_state[k]
        L1, nH1 = relative_position_bias_table_pretrained.size()
        L2, nH2 = relative_position_bias_table_current.size()
        if nH1 != nH2:
            logger.warning(f"Error in loading {k}, passing......")
        else:
            if L1 != L2:
                # bicubic interpolate relative_position_bias_table if not match
                S1 = int(L1 ** 0.5)
                S2 = int(L2 ** 0.5)
                relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
                    relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2),
                    mode='bicubic')
                state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)

    # bicubic interpolate absolute_pos_embed if not match
    absolute_pos_embed_keys = [k for k in state_dict.keys() if "pos_embed" in k]
    for k in absolute_pos_embed_keys:
        # dpe
        absolute_pos_embed_pretrained = state_dict[k]
        absolute_pos_embed_current = current_state[k]
        _, L1, C1 = absolute_pos_embed_pretrained.size()
        _, L2, C2 = absolute_pos_embed_current.size()
        # Compare the pretrained channel count with the model's (the previous
        # self-comparison could never be true, so the warning never fired).
        if C1 != C2:
            logger.warning(f"Error in loading {k}, passing......")
        else:
            if L1 != L2:
                S1 = int(L1 ** 0.5)
                S2 = int(L2 ** 0.5)
                # i / j: tokens left over once the largest square grid is removed.
                i, j = L1 - S1 ** 2, L2 - S2 ** 2
                absolute_pos_embed_pretrained_ = absolute_pos_embed_pretrained[:, i:, :].reshape(-1, S1, S1, C1)
                absolute_pos_embed_pretrained_ = absolute_pos_embed_pretrained_.permute(0, 3, 1, 2)
                absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
                    absolute_pos_embed_pretrained_, size=(S2, S2), mode='bicubic')
                absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1)
                absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2)
                state_dict[k] = torch.cat([absolute_pos_embed_pretrained[:, :j, :],
                                           absolute_pos_embed_pretrained_resized], dim=1)

    # check classifier, if not match, then re-init classifier to zero
    head_bias_pretrained = state_dict['head.bias']
    Nc1 = head_bias_pretrained.shape[0]
    Nc2 = model.head.bias.shape[0]
    if (Nc1 != Nc2):
        if Nc1 == 21841 and Nc2 == 1000:
            logger.info("loading ImageNet-22K weight to ImageNet-1K ......")
            map22kto1k_path = f'data/map22kto1k.txt'
            with open(map22kto1k_path) as f:
                map22kto1k = f.readlines()
            map22kto1k = [int(id22k.strip()) for id22k in map22kto1k]
            state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :]
            state_dict['head.bias'] = state_dict['head.bias'][map22kto1k]
        else:
            torch.nn.init.constant_(model.head.bias, 0.)
            torch.nn.init.constant_(model.head.weight, 0.)
            del state_dict['head.weight']
            del state_dict['head.bias']
            logger.warning(f"Error in loading classifier head, re-init classifier head to 0")

    msg = model.load_state_dict(state_dict, strict=False)
    logger.warning(msg)

    logger.info(f"=> loaded successfully '{ckpt_path}'")

    del checkpoint
    torch.cuda.empty_cache()
def save_checkpoint(config, epoch, model, max_accuracy, max_accuracy_e, optimizer, lr_scheduler, logger):
    """Write the current training state to ``<config.OUTPUT>/ckpt_epoch_<epoch>.pth``.

    The saved dict holds the model, optimizer and LR-scheduler state dicts, the
    two accuracy values, the epoch number and the config object itself.

    Args:
        config: Config object; ``config.OUTPUT`` is the target directory.
        epoch: Epoch number used in the file name and stored in the checkpoint.
        model: Module whose ``state_dict()`` is saved under ``'model'``.
        max_accuracy: Value stored under ``'max_accuracy'``.
        max_accuracy_e: Value stored under ``'max_accuracy_e'``.
        optimizer: Object whose ``state_dict()`` is saved under ``'optimizer'``.
        lr_scheduler: Object whose ``state_dict()`` is saved under ``'lr_scheduler'``.
        logger: Logger used for progress messages.
    """
    save_state = dict(
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        lr_scheduler=lr_scheduler.state_dict(),
        max_accuracy=max_accuracy,
        max_accuracy_e=max_accuracy_e,
        epoch=epoch,
        config=config,
    )

    target = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
    logger.info(f"{target} saving......")
    torch.save(save_state, target)
    logger.info(f"{target} saved !!!")
def save_checkpoint_ema(config, epoch, model, model_ema, max_accuracy, max_accuracy_e, optimizer, lr_scheduler, logger):
    """Write the training state plus EMA weights to ``<config.OUTPUT>/ckpt_epoch_<epoch>.pth``.

    Same contents as a regular checkpoint (model / optimizer / LR-scheduler
    state dicts, both accuracy values, epoch, config) with one extra entry,
    ``'state_dict_ema'``, obtained via ``get_state_dict(model_ema, unwrap_model)``.

    Args:
        config: Config object; ``config.OUTPUT`` is the target directory.
        epoch: Epoch number used in the file name and stored in the checkpoint.
        model: Module whose ``state_dict()`` is saved under ``'model'``.
        model_ema: EMA model passed to ``get_state_dict``.
        max_accuracy: Value stored under ``'max_accuracy'``.
        max_accuracy_e: Value stored under ``'max_accuracy_e'``.
        optimizer: Object whose ``state_dict()`` is saved under ``'optimizer'``.
        lr_scheduler: Object whose ``state_dict()`` is saved under ``'lr_scheduler'``.
        logger: Logger used for progress messages.
    """
    save_state = dict(
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        lr_scheduler=lr_scheduler.state_dict(),
        max_accuracy=max_accuracy,
        max_accuracy_e=max_accuracy_e,
        epoch=epoch,
        config=config,
    )
    save_state.update(state_dict_ema=get_state_dict(model_ema, unwrap_model))

    target = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
    logger.info(f"{target} saving......")
    torch.save(save_state, target)
    logger.info(f"{target} saved !!!")
def save_checkpoint_ema_new(config, epoch, model, model_ema, max_accuracy, max_accuracy_e, optimizer, lr_scheduler, logger, name=None):
    """Save the training state plus EMA weights, optionally under a fixed name.

    With ``name=None`` the file is ``<config.OUTPUT>/ckpt_epoch_<epoch>.pth`` and
    the checkpoint written three epochs earlier (``ckpt_epoch_<epoch-3>.pth``)
    is deleted first if it exists. With a ``name`` the file is
    ``<config.OUTPUT>/<name>.pth`` and no older checkpoint is removed.

    Args:
        config: Config object; ``config.OUTPUT`` is the target directory.
        epoch: Epoch number stored in the checkpoint (and used in the default file name).
        model: Module whose ``state_dict()`` is saved under ``'model'``.
        model_ema: EMA model passed to ``get_state_dict``; stored under ``'state_dict_ema'``.
        max_accuracy: Value stored under ``'max_accuracy'``.
        max_accuracy_e: Value stored under ``'max_accuracy_e'``.
        optimizer: Object whose ``state_dict()`` is saved under ``'optimizer'``.
        lr_scheduler: Object whose ``state_dict()`` is saved under ``'lr_scheduler'``.
        logger: Logger used for progress messages.
        name: Optional file stem overriding the per-epoch file name.
    """
    save_state = {'model': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'lr_scheduler': lr_scheduler.state_dict(),
                  'max_accuracy': max_accuracy,
                  'max_accuracy_e': max_accuracy_e,
                  'epoch': epoch,
                  'config': config}
    save_state['state_dict_ema'] = get_state_dict(model_ema, unwrap_model)

    if name is None:
        # Rolling window: drop the per-epoch checkpoint from three epochs ago.
        old_ckpt = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch-3}.pth')
        try:
            os.remove(old_ckpt)
        except FileNotFoundError:
            # Nothing to clean up (early epochs, or already removed).
            pass
        save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
    else:
        save_path = os.path.join(config.OUTPUT, f'{name}.pth')

    logger.info(f"{save_path} saving......")
    torch.save(save_state, save_path)
    logger.info(f"{save_path} saved !!!")
def get_grad_norm(parameters, norm_type=2):
    """Return the total gradient norm over ``parameters`` as a Python float.

    Parameters whose ``.grad`` is ``None`` are ignored. For a finite
    ``norm_type`` p the result is ``(sum_i ||grad_i||_p ** p) ** (1 / p)``; for
    ``norm_type=float('inf')`` it is the largest absolute gradient entry.

    Args:
        parameters: A single tensor or an iterable of tensors.
        norm_type: Order of the norm (converted with ``float``); default 2.

    Returns:
        The combined norm; ``0.0`` when no parameter has a gradient.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)

    if norm_type == float('inf'):
        # Accumulating p-th powers is meaningless for p = inf (the final
        # exponent 1/p collapses to 0), so take the maximum directly.
        return max((g.abs().max().item() for g in grads if g.numel() > 0), default=0.0)

    total_norm = 0.0
    for grad in grads:
        total_norm += grad.norm(norm_type).item() ** norm_type
    return total_norm ** (1. / norm_type)
def auto_resume_helper(output_dir):
    """Return the most recently modified checkpoint file in ``output_dir``.

    Directory entries whose names end with ``'pth'`` are candidates; the one
    with the latest modification time wins.

    Args:
        output_dir: Directory to scan.

    Returns:
        The joined path of the newest candidate, or ``None`` if there is none.
    """
    checkpoints = [entry for entry in os.listdir(output_dir) if entry.endswith('pth')]
    print(f"All checkpoints founded in {output_dir}: {checkpoints}")
    if not checkpoints:
        return None

    candidates = [os.path.join(output_dir, entry) for entry in checkpoints]
    latest_checkpoint = max(candidates, key=os.path.getmtime)
    print(f"The latest checkpoint founded: {latest_checkpoint}")
    return latest_checkpoint
def reduce_tensor(tensor):
    """Return a copy of ``tensor`` summed across processes and divided by the world size.

    The input tensor is not modified; the all-reduce runs on a clone.
    Uses ``torch.distributed`` collectives.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    world_size = dist.get_world_size()
    averaged.div_(world_size)
    return averaged