-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathmetrics.py
152 lines (125 loc) · 7.07 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import argparse
import os
import json
from tqdm import tqdm
import numpy as np
import torch
from gan_training.config import load_config
from seeded_sampler import SeededSampler
# Command-line interface. The summary text is passed as `description`: the
# first positional parameter of ArgumentParser is `prog`, so passing the text
# positionally would make the whole paragraph show up as the program name in
# the usage and error output.
parser = argparse.ArgumentParser(
    description='Computes numbers used in paper and caches them to a result files. Examples include FID, IS, reverse-KL, # modes, FSD, cluster NMI, Purity.'
)
parser.add_argument('paths', nargs='+', type=str, help='list of configs for each experiment')
parser.add_argument('--it', type=int, default=-1, help='If set, computes numbers only for that iteration')
parser.add_argument('--every', type=int, default=-1, help='skips some checkpoints and only computes those whose iteration number are divisible by every')
parser.add_argument('--fid', action='store_true', help='compute FID metric')
parser.add_argument('--inception', action='store_true', help='compute IS metric')
parser.add_argument('--modes', action='store_true', help='compute # modes and reverse-KL metric')
parser.add_argument('--fsd', action='store_true', help='compute FSD metric')
parser.add_argument('--cluster_metrics', action='store_true', help='compute clustering metrics (NMI, purity)')
parser.add_argument('--device', type=int, default=1, help='device to run the metrics on (can run into OOM issues if same as main device)')
args = parser.parse_args()

# Index exported as CUDA_VISIBLE_DEVICES for the metric subprocesses.
device = args.device
# Work list of paths still to visit; directories are expanded by the main loop.
dirs = list(args.paths)

# Number of samples kept per checkpoint and the batch size requested from the sampler.
N = 50000
BS = 100

# Dataset names searched for (in this order) as substrings of a config path.
datasets = ['imagenet', 'cifar', 'stacked_mnist', 'places']
# Ground-truth image archives passed to the FSD computation; only these two
# datasets have an entry.
dataset_to_img = {
    'places': 'output/places_gt_imgs.npz',
    'imagenet': 'output/imagenet_gt_imgs.npz',
}
def load_results(results_dir):
    '''Load the cached metric results stored under `results_dir`.

    Any results file that does not exist yet is first created holding an empty
    JSON object. Returns a list with the parsed contents of the FID, IS, KL,
    number-of-modes, FSD and cluster-metric files, in that order.
    '''
    filenames = (
        'fid_results.json',
        'is_results.json',
        'kl_results.json',
        'nmodes_results.json',
        'fsd_results.json',
        'cluster_metrics.json',
    )
    loaded = []
    for filename in filenames:
        full_path = os.path.join(results_dir, filename)
        if not os.path.exists(full_path):
            # start the cache off as an empty JSON object so the read below succeeds
            with open(full_path, 'w') as handle:
                handle.write(json.dumps({}))
        with open(full_path) as handle:
            loaded.append(json.load(handle))
    return loaded
def get_dataset_from_path(path):
    '''Return the first name in `datasets` that occurs as a substring of `path`.

    The inferred name is printed on a match. Returns None when no known
    dataset name appears in the path.
    '''
    found = next((candidate for candidate in datasets if candidate in path), None)
    if found is None:
        return None
    print('Inferred dataset:', found)
    return found
def pt_to_np(imgs):
    '''Rescale a pytorch image batch from [-1, 1] to [0, 255] and return it as a numpy array.

    Axis 1 is moved to the last position first (presumably NCHW -> NHWC;
    confirm against callers). Every arithmetic step is in-place on that
    permuted view, so the tensor passed in is overwritten with the rescaled
    values as well.
    '''
    moved = imgs.permute(0, 2, 3, 1)
    moved.mul_(0.5)
    moved.add_(0.5)
    moved.mul_(255)
    # out-of-range inputs are saturated rather than wrapped
    moved.clamp_(0, 255)
    return moved.numpy()
def sample(sampler):
    '''Draw images from `sampler` in batches of BS and return the first N as a numpy array.

    The first element returned by `sampler.sample(BS)` is taken to be the image
    batch (assumption -- confirm against SeededSampler). The stacked result is
    converted with `pt_to_np`.
    '''
    num_batches = N // BS + 1
    collected = []
    with torch.no_grad():
        for _ in tqdm(range(num_batches)):
            batch = sampler.sample(BS)[0].detach().cpu()
            collected.extend(img.detach().cpu() for img in batch)
        # drop whatever was generated beyond the N samples that are needed
        stacked = torch.stack(collected[:N], dim=0)
        return pt_to_np(stacked)
root = './'

# Walk every path given on the command line. `dirs` is used as a stack:
# directories are expanded into their entries, '.yaml' files are treated as
# experiment configs, and every other file is ignored.
while dirs:
    path = dirs.pop()
    if os.path.isdir(path):
        # search down tree for config files
        dirs.extend(os.path.join(path, d1) for d1 in os.listdir(path))
        continue
    if not path.endswith('.yaml'):
        continue

    config = load_config(path, default_path='configs/default.yaml')
    outdir = config['training']['out_dir']
    if not os.path.exists(outdir) and config['pretrained'] == {}:
        # no output directory and no pretrained entry in the config
        print('Skipping', path, 'outdir', outdir)
        continue

    results_dir = os.path.join(outdir, 'results')
    checkpoint_dir = os.path.join(outdir, 'chkpts')
    os.makedirs(results_dir, exist_ok=True)

    # Cached results, keyed by iteration (as a string) in the membership tests below.
    fid_results, is_results, kl_results, nmodes_results, fsd_results, cluster_results = load_results(results_dir)

    checkpoint_files = os.listdir(checkpoint_dir) if os.path.exists(checkpoint_dir) else []
    if config['pretrained'] != {}:
        checkpoint_files = checkpoint_files + ['pretrained']

    for checkpoint in checkpoint_files:
        is_candidate = (checkpoint.endswith('.pt') and checkpoint != 'model.pt') or checkpoint == 'pretrained'
        if not is_candidate:
            continue
        print('Computing for', checkpoint)

        if 'model' in checkpoint:
            # infer iteration number from checkpoint file w/o loading it
            if 'model_' not in checkpoint:
                continue
            it = int(checkpoint.split('model_')[1].split('.pt')[0])
            if args.every != 0 and it % args.every != 0:
                continue
            # iteration 0 is often useless, skip it
            if it == 0 or (args.it != -1 and it != args.it):
                continue
        elif checkpoint == 'pretrained':
            it = 'pretrained'
        else:
            # A '.pt' file that is neither a model checkpoint nor the
            # 'pretrained' marker has no iteration number. Without this branch
            # `it` would be undefined (NameError) or left over from the
            # previous checkpoint, caching metrics under the wrong iteration.
            print('Skipping unrecognized checkpoint', checkpoint)
            continue
        it = str(it)

        clusterer_path = os.path.join(root, checkpoint_dir, f'clusterer{it}.pkl')
        # don't save samples for each iteration for disk space
        samples_path = os.path.join(outdir, 'results', 'samples.npz')

        # Caches that must already hold this iteration for sampling to be
        # skipped. Cluster metrics are not included: they are computed from
        # the config path and do not read the samples file.
        targets = []
        if args.inception:
            targets.append(is_results)
        if args.fid:
            targets.append(fid_results)
        if args.modes:
            targets.extend([kl_results, nmodes_results])
        if args.fsd:
            targets.append(fsd_results)

        if all(it in result for result in targets):
            print('Already generated', it, path)
        else:
            sampler = SeededSampler(path,
                                    model_path=os.path.join(root, checkpoint_dir, checkpoint),
                                    clusterer_path=clusterer_path,
                                    pretrained=config['pretrained'])
            samples = sample(sampler)
            dataset_name = get_dataset_from_path(path)
            # 'real' stores only the inferred dataset name, not images
            np.savez(samples_path, fake=samples, real=dataset_name)

        # Each metric runs in its own subprocess pinned to `device`.
        arguments = f'--samples {samples_path} --it {it} --results_dir {results_dir}'
        if args.fid and it not in fid_results:
            os.system(f'CUDA_VISIBLE_DEVICES={device} python gan_training/metrics/fid.py {arguments}')
        if args.inception and it not in is_results:
            os.system(f'CUDA_VISIBLE_DEVICES={device} python gan_training/metrics/tf_is/inception_score.py {arguments}')
        if args.modes and (it not in kl_results or it not in nmodes_results):
            os.system(f'CUDA_VISIBLE_DEVICES={device} python utils/get_empirical_distribution.py {arguments} --dataset {dataset_name}')
        if args.cluster_metrics and it not in cluster_results:
            os.system(f'CUDA_VISIBLE_DEVICES={device} python cluster_metrics.py {path} --model_it {it}')
        if args.fsd and it not in fsd_results:
            # raises KeyError for datasets without an entry in dataset_to_img
            gt_path = dataset_to_img[dataset_name]
            os.system(f'CUDA_VISIBLE_DEVICES={device} python -m seeing.fsd {gt_path} {samples_path} --it {it} --results_dir {results_dir}')