merge_save_p6_v2.py
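"""Merge a trained YOLOX P6-v2 checkpoint for deployment.

Builds the custom P6 backbone/head, loads the training checkpoint,
fuses Conv+BN layers with fuse_model, and saves the fused state_dict
to a new checkpoint file.
"""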
import torch
import torch.nn as nn
from yolox.models import YOLOXCustomP6v2, YOLOPAFPNCustomP6v2, YOLOXHeadCustom
from yolox.utils.model_utils import fuse_model
# from submit.yolox_infer.models import YOLOX as YOLOXInfer


def merge(depth: float, width: float, ckpt: str, out_ckpt: str):
    # Build the custom P6 model (extra stride-64 level).
    backbone = YOLOPAFPNCustomP6v2(depth=depth, width=width, act="silu")
    head = YOLOXHeadCustom(num_classes=80, width=width,
                           strides=(8, 16, 32, 64),
                           in_channels=(256, 512, 768, 1024), act="silu")
    model = YOLOXCustomP6v2(backbone, head)

    for m in model.modules():
        if isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            m.eps = 1e-3  # CRITICAL: eps is folded into the fused conv weights below

    model.eval().cuda()
    model.load_state_dict(torch.load(ckpt, map_location="cuda")["model"], strict=True)
    model.head.decode_in_inference = False

    # Fuse Conv+BN layers and save the merged weights.
    fused_model = fuse_model(model)
    torch.save(fused_model.state_dict(), out_ckpt)

    # Optional sanity check against the standalone inference model:
    # infer_model = YOLOXInfer(depth=depth, width=width)
    # infer_model.eval().cuda()
    # infer_model.load_state_dict(fused_model.state_dict(), strict=True)
    #
    # with torch.no_grad():
    #     dummy_input = torch.empty((1, 3, 640, 640), dtype=torch.float32).uniform_(0.0, 255.0)
    #     dummy_input = dummy_input.cuda()
    #
    #     out_model = model(dummy_input)
    #     print(out_model.shape)
    #
    #     out_infer_reg, out_infer_obj, out_infer_cls = infer_model(dummy_input)
    #     out_infer = torch.cat([out_infer_reg, out_infer_obj.sigmoid(), out_infer_cls.sigmoid()], dim=2)
    #     print(out_infer.shape)
    #
    #     diff = torch.abs(out_model - out_infer)
    #     print(diff.min(), diff.max(), diff.mean())

    return model


if __name__ == '__main__':
    # YOLOX-M-P6-V2
    merge(0.67, 0.75, "yolox_m_p6_v2_kh.pth", "yolox_m_p6_v2_kh_merged.pth")