-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmegvii_nano_2_mmdet.py
122 lines (112 loc) · 6.3 KB
/
megvii_nano_2_mmdet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Convert a Megvii YOLOX-Nano checkpoint into MMDetection-style parameter names.
#
# Running this file as a script:
#   1. loads ``sd`` (``./yolox_nano.pth``) onto the CPU and unwraps a nested
#      ``"state_dict"`` and/or ``"model"`` entry if present;
#   2. renames every key from the YOLOX layout (``backbone.backbone.dark2...``,
#      ``backbone.C3_p4...``, ``head.cls_preds...``) to the MMDetection layout
#      (``backbone.stage1...``, ``neck.top_down_blocks.0...``,
#      ``bbox_head.multi_level_conv_cls...``);
#   3. optionally keeps only the first ``num_classes`` rows (dim 0) of the
#      classification predictor tensors;
#   4. saves ``{'state_dict': new_dict}`` to ``pth`` (``./yolox_nano_mmdet.pth``)
#      and prints each converted key with its tensor size.
import torch

sd = "./yolox_nano.pth"
pth = "./yolox_nano_mmdet.pth"

# Number of class channels to keep in the classification predictor
# (``bbox_head.multi_level_conv_cls.*``). ``None`` keeps every channel, since
# ``v[:None]`` is a full slice. Set an int to keep only the first N channels,
# presumably for fine-tuning on a class subset — TODO confirm the kept
# channels correspond to the intended class ids.
num_classes = None


def _csp_rules(prefix):
    """Return rules renaming a CSP block's conv1/2/3 under *prefix* to main/short/final_conv."""
    return [
        (prefix + "conv1.", prefix + "main_conv."),
        (prefix + "conv2.", prefix + "short_conv."),
        (prefix + "conv3.", prefix + "final_conv."),
    ]


# Ordered (old, new) substring rewrites applied to every key.
# ORDER MATTERS: later rules match names produced by earlier ones. For example,
# "backbone.dark2." must become "backbone.stage1." before the
# "backbone.stage1.1.conv1." rule can fire, and "backbone.C3_p4." must become
# "neck.top_down_blocks.0." before that block's conv1/2/3 rules apply.
_KEY_RULES = (
    [
        ("backbone.backbone.", "backbone."),
        ("backbone.dark2.", "backbone.stage1."),
        ("backbone.dark3.", "backbone.stage2."),
        ("backbone.dark4.", "backbone.stage3."),
        ("backbone.dark5.", "backbone.stage4."),
        ("dconv.", "depthwise_conv."),
        ("pconv.", "pointwise_conv."),
    ]
    + _csp_rules("backbone.stage1.1.")
    + [(".m.", ".blocks.")]
    + _csp_rules("backbone.stage2.1.")
    + _csp_rules("backbone.stage3.1.")
    + _csp_rules("backbone.stage4.2.")
    + [
        ("backbone.lateral_conv0.", "neck.reduce_layers.0."),
        ("backbone.reduce_conv1.", "neck.reduce_layers.1."),
        ("backbone.C3_p4.", "neck.top_down_blocks.0."),
    ]
    + _csp_rules("neck.top_down_blocks.0.")
    + [("backbone.C3_p3.", "neck.top_down_blocks.1.")]
    + _csp_rules("neck.top_down_blocks.1.")
    + [
        ("backbone.bu_conv2.", "neck.downsamples.0."),
        ("backbone.bu_conv1.", "neck.downsamples.1."),
        ("backbone.C3_n3.", "neck.bottom_up_blocks.0."),
    ]
    + _csp_rules("neck.bottom_up_blocks.0.")
    + [("backbone.C3_n4.", "neck.bottom_up_blocks.1.")]
    + _csp_rules("neck.bottom_up_blocks.1.")
    + [
        ("head.stems.", "neck.out_convs."),
        ("head.cls_convs.", "bbox_head.multi_level_cls_convs."),
        ("head.reg_convs.", "bbox_head.multi_level_reg_convs."),
        ("head.cls_preds.", "bbox_head.multi_level_conv_cls."),
        ("head.reg_preds.", "bbox_head.multi_level_conv_reg."),
        ("head.obj_preds.", "bbox_head.multi_level_conv_obj."),
    ]
)


def convert_key(k):
    """Return the MMDetection-style name for YOLOX parameter key *k*.

    Applies every rule in ``_KEY_RULES`` in order. ``str.replace`` is a
    no-op when the substring is absent, so no membership test is needed.
    """
    new_k = k
    for old, new in _KEY_RULES:
        new_k = new_k.replace(old, new)
    return new_k


def convert_state_dict(model_dict, num_classes=None):
    """Return a new dict with every key of *model_dict* renamed via ``convert_key``.

    Args:
        model_dict: mapping of YOLOX parameter names to tensors.
        num_classes: if not None, tensors whose converted key contains
            ``"bbox_head.multi_level_conv_cls."`` are sliced to their first
            ``num_classes`` entries along dim 0. ``None`` keeps them whole.

    Returns:
        A new ``dict``. *model_dict* itself is not modified.
    """
    new_dict = {}
    for k, v in model_dict.items():
        new_k = convert_key(k)
        if "bbox_head.multi_level_conv_cls." in new_k:
            v = v[:num_classes, ...]
        new_dict[new_k] = v
    return new_dict


def load_yolox_weights(path):
    """Load the checkpoint at *path* on CPU and unwrap ``state_dict`` / ``model`` nesting."""
    model_dict = torch.load(path, map_location=torch.device("cpu"))
    if "state_dict" in model_dict:
        model_dict = model_dict["state_dict"]
    if "model" in model_dict:
        model_dict = model_dict["model"]
    return model_dict


def main():
    """Convert ``sd`` into ``pth`` and print the converted keys and tensor sizes."""
    model_dict = load_yolox_weights(sd)
    new_dict = convert_state_dict(model_dict, num_classes)
    # Save the converted weights as a .pth file under the 'state_dict' key.
    torch.save({"state_dict": new_dict}, pth)
    for k, v in new_dict.items():
        print(k, "-", v.size())


if __name__ == "__main__":
    main()