forked from facebookresearch/FixRes
-
Notifications
You must be signed in to change notification settings - Fork 0
/
hubconf.py
143 lines (126 loc) · 5.15 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from tqdm import tqdm
import torch
import hashlib
import os
import re
import shutil
import sys
import tempfile
try:
from requests.utils import urlparse
from requests import get as urlopen
requests_available = True
except ImportError:
requests_available = False
if sys.version_info[0] == 2:
from urlparse import urlparse # noqa f811
from urllib2 import urlopen # noqa f811
else:
from urllib.request import urlopen
from urllib.parse import urlparse
# Packages required by the entrypoints below; torch.hub reads this
# module-level list from hubconf.py (see torch.hub documentation).
dependencies = ["torch", "torchvision"]
from torchvision.models.resnet import ResNet, Bottleneck
def _download_url_to_file(url, dst, hash_prefix, progress):
    r"""Download ``url`` to the local path ``dst``, optionally checking a SHA-256 prefix.

    Adapted from https://pytorch.org/docs/stable/model_zoo.html

    The payload is streamed into a temporary file created in the same
    directory as ``dst``. It is moved onto ``dst`` only after the download
    (and the optional hash check) has completed, so the move is a same-
    filesystem rename and a failed or interrupted download does not leave a
    partial file at ``dst``.

    Args:
        url (str): URL to download.
        dst (str): destination path for the downloaded file.
        hash_prefix (str or None): if not None, the hex SHA-256 digest of the
            downloaded bytes must start with this string.
        progress (bool): if True, display a tqdm progress bar.

    Raises:
        RuntimeError: if ``hash_prefix`` is given and the digest does not match.
        HTTP errors raised by requests (``raise_for_status``) or urllib when
        the server answers with an error status.
    """
    if requests_available:
        response = urlopen(url, stream=True)
    else:
        response = urlopen(url)
    try:
        if requests_available:
            # requests does not raise on 4xx/5xx by itself; without this the
            # error page would be written to ``dst`` and later reused as a
            # cached checkpoint.
            response.raise_for_status()
            content_length = response.headers.get("Content-Length")
            u = response.raw
        else:
            u = response
            meta = response.info()
            if hasattr(meta, 'getheaders'):
                # Python 2 message objects: returns a (possibly empty) list.
                values = meta.getheaders("Content-Length")
            else:
                # Python 3 HTTPMessage: returns None when the header is absent.
                values = meta.get_all("Content-Length")
            content_length = values[0] if values else None
        # Content-Length may be absent (e.g. chunked transfer encoding); tqdm
        # accepts total=None and then shows an open-ended progress bar.
        file_size = int(content_length) if content_length is not None else None

        f = tempfile.NamedTemporaryFile(
            delete=False, dir=os.path.dirname(os.path.abspath(dst)))
        try:
            sha256 = hashlib.sha256() if hash_prefix is not None else None
            with tqdm(total=file_size, disable=not progress) as pbar:
                while True:
                    buffer = u.read(8192)
                    if len(buffer) == 0:
                        break
                    f.write(buffer)
                    if sha256 is not None:
                        sha256.update(buffer)
                    pbar.update(len(buffer))
            f.close()
            if sha256 is not None:
                digest = sha256.hexdigest()
                if digest[:len(hash_prefix)] != hash_prefix:
                    raise RuntimeError('invalid hash value (expected "{}", got "{}")'
                                       .format(hash_prefix, digest))
            shutil.move(f.name, dst)
        finally:
            f.close()
            # After a successful move the temp path no longer exists; on any
            # failure this removes the partial download.
            if os.path.exists(f.name):
                os.remove(f.name)
    finally:
        response.close()
def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True):
    r"""Load the object serialized at ``url``, downloading it into a local cache if needed.

    Adapted from https://pytorch.org/docs/stable/model_zoo.html

    Args:
        url (str): URL of the file. The basename of its path (query string
            excluded) is used as the cache file name.
        model_dir (str, optional): cache directory. When None, uses
            ``$TORCH_MODEL_ZOO`` if set, otherwise ``<TORCH_HOME>/models``,
            where ``TORCH_HOME`` defaults to ``~/.torch``.
        map_location: forwarded unchanged to ``torch.load``.
        progress (bool): if True, display a progress bar while downloading.

    Returns:
        Whatever ``torch.load`` returns for the cached file.
    """
    if model_dir is None:
        torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch'))
        model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models'))
    # Create the cache directory without a check-then-create race: several
    # processes (e.g. distributed workers) may reach this point at once.
    # os.makedirs(exist_ok=True) is avoided because the import fallbacks at
    # the top of this file still cater for Python 2.
    try:
        os.makedirs(model_dir)
    except OSError:
        if not os.path.isdir(model_dir):
            raise
    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        # Integrity checking is disabled: no expected hash is available.
        hash_prefix = None
        _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    # NOTE(review): torch.load unpickles the file, so it must only be pointed
    # at trusted URLs.
    return torch.load(cached_file, map_location=map_location)
# Download locations of the pretrained FixRes checkpoints, keyed by the
# ``arch`` name that the entrypoints pass to ``_fixmodel``.
model_urls = {
    # ResNet-50
    'FixResNet50':
        'https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Pretrained_Models/ResNet50_v2.pth',
    # ResNet-50, CutMix variant
    'FixResNet50CutMix':
        'https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Pretrained_Models/ResNet50_CutMix_v2.pth',
    # ResNeXt-101 32x48d
    'FixResNeXt101_32x48d':
        'https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Pretrained_Models/ResNext101_32x48d_v2.pth',
}
def _fixmodel(arch, block, layers, pretrained, progress, **kwargs):
    """Build a torchvision ``ResNet`` and, if requested, load FixRes weights into it.

    Args:
        arch (str): key into ``model_urls`` selecting the checkpoint.
        block: residual block class handed to ``ResNet`` (e.g. ``Bottleneck``).
        layers (list of int): number of blocks in each of the four stages.
        pretrained (bool): if True, download the checkpoint and load its
            weights; if False, return the freshly constructed network.
        progress (bool): if True, display a download progress bar.
        **kwargs: forwarded to ``ResNet`` (e.g. ``groups``, ``width_per_group``).

    Returns:
        The constructed ``ResNet`` model.

    Raises:
        RuntimeError: if the checkpoint lacks an entry for any key of the
            model's ``state_dict``.
    """
    model = ResNet(block, layers, **kwargs)
    if not pretrained:
        return model
    checkpoint = load_state_dict_from_url(model_urls[arch], progress=progress, map_location='cpu')
    # Weights are stored under 'model' with every key prefixed by 'module.'
    # (presumably saved from a DataParallel-wrapped model).
    pretrained_dict = checkpoint['model']
    model_dict = model.state_dict()
    missing = [k for k in model_dict if 'module.' + k not in pretrained_dict]
    if missing:
        # An explicit raise (instead of `assert`) keeps this check active under
        # `python -O`; otherwise missing entries would silently stay random.
        raise RuntimeError(
            "model loading error: {} key(s) missing from checkpoint, e.g. {}"
            .format(len(missing), missing[:5]))
    for k in model_dict:
        model_dict[k] = pretrained_dict['module.' + k]
    model.load_state_dict(model_dict)
    return model
def fixresnet_50(progress=True, **kwargs):
    """Return a ResNet-50 initialised with the FixRes pretrained weights.

    Reference: `"Fixing the train-test resolution discrepancy" <https://arxiv.org/abs/1906.06423>`_

    Args:
        progress (bool): If True, displays a progress bar of the download to stderr.
        **kwargs: extra keyword arguments forwarded to torchvision's ``ResNet``.
    """
    # Bottleneck blocks per stage for the 50-layer variant.
    stage_depths = [3, 4, 6, 3]
    return _fixmodel('FixResNet50', Bottleneck, stage_depths, True, progress, **kwargs)
def fixresnet_50_CutMix(progress=True, **kwargs):
    """Return a ResNet-50 initialised with the FixRes CutMix pretrained weights.

    Reference: `"Fixing the train-test resolution discrepancy" <https://arxiv.org/abs/1906.06423>`_

    Args:
        progress (bool): If True, displays a progress bar of the download to stderr.
        **kwargs: extra keyword arguments forwarded to torchvision's ``ResNet``.
    """
    # Same ResNet-50 topology as fixresnet_50; only the checkpoint differs.
    stage_depths = [3, 4, 6, 3]
    return _fixmodel('FixResNet50CutMix', Bottleneck, stage_depths, True, progress, **kwargs)
def fixresnext101_32x48d(progress=True, **kwargs):
    """Return a ResNeXt-101 32x48d initialised with the FixRes pretrained weights.

    Reference: `"Fixing the train-test resolution discrepancy" <https://arxiv.org/abs/1906.06423>`_

    Args:
        progress (bool): If True, displays a progress bar of the download to stderr.
        **kwargs: extra keyword arguments forwarded to torchvision's ``ResNet``;
            any ``groups`` / ``width_per_group`` given here are overridden.
    """
    # 32 groups of width 48 define the 32x48d variant; these always win
    # over caller-supplied values because the checkpoint requires them.
    kwargs.update(groups=32, width_per_group=48)
    stage_depths = [3, 4, 23, 3]
    return _fixmodel('FixResNeXt101_32x48d', Bottleneck, stage_depths, True, progress, **kwargs)