-
Notifications
You must be signed in to change notification settings - Fork 32
/
eval_diffusion.py
83 lines (68 loc) · 2.95 KB
/
eval_diffusion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import argparse
import os
import random
import socket
import yaml
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
import models
import datasets
import utils
from models import DenoisingDiffusion, DiffusiveRestoration
def parse_args_and_config(argv=None):
    """Parse command-line options and load the YAML config for evaluation.

    The config path given via ``--config`` is resolved relative to the
    ``configs`` directory, i.e. ``--config foo.yml`` reads ``configs/foo.yml``.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the normal CLI
            behaviour, so existing callers are unaffected.

    Returns:
        A ``(args, config)`` tuple where ``args`` is the parsed
        ``argparse.Namespace`` of CLI options and ``config`` is the YAML
        document converted to nested ``argparse.Namespace`` objects by
        ``dict2namespace``.

    Raises:
        SystemExit: if argument parsing fails (standard argparse behaviour).
        OSError: if the config file cannot be opened.
        ValueError: if the YAML document is not a mapping at top level.
    """
    parser = argparse.ArgumentParser(description='Restoring Weather with Patch-Based Denoising Diffusion Models')
    parser.add_argument("--config", type=str, required=True,
                        help="Path to the config file")
    parser.add_argument('--resume', default='', type=str,
                        help='Path for the diffusion model checkpoint to load for evaluation')
    parser.add_argument("--grid_r", type=int, default=16,
                        help="Grid cell width r that defines the overlap between patches")
    parser.add_argument("--sampling_timesteps", type=int, default=25,
                        help="Number of implicit sampling steps")
    parser.add_argument("--test_set", type=str, default='raindrop',
                        help="restoration test set options: ['raindrop', 'snow', 'rainfog']")
    parser.add_argument("--image_folder", default='results/images/', type=str,
                        help="Location to save restored images")
    parser.add_argument('--seed', default=61, type=int, metavar='N',
                        help='Seed for initializing random number generators (default: 61)')
    args = parser.parse_args(argv)

    config_path = os.path.join("configs", args.config)
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding; safe_load avoids building arbitrary objects.
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # An empty file yields None and a scalar/list document is not a mapping;
    # fail here with a clear message instead of an AttributeError later.
    if not isinstance(config, dict):
        raise ValueError(
            "Config file {!r} must contain a YAML mapping at top level, got {}".format(
                config_path, type(config).__name__))

    new_config = dict2namespace(config)
    return args, new_config
def dict2namespace(config):
    """Recursively turn a (possibly nested) dict into ``argparse.Namespace``.

    Each key becomes an attribute of the returned namespace. Values that are
    ``dict`` instances are converted recursively; every other value
    (including lists, even lists of dicts) is stored as-is.

    Args:
        config: Mapping to convert, e.g. the output of ``yaml.safe_load``.

    Returns:
        An ``argparse.Namespace`` mirroring the structure of ``config``.
    """
    result = argparse.Namespace()
    for name, item in config.items():
        converted = dict2namespace(item) if isinstance(item, dict) else item
        setattr(result, name, converted)
    return result
def main():
    """Run patch-based diffusion restoration on the selected test set.

    Steps: parse CLI arguments and the YAML config, choose a device, seed
    the random number generators, build the validation loader for
    ``args.test_set``, wrap a ``DenoisingDiffusion`` model in
    ``DiffusiveRestoration`` and call ``restore`` on the loader.

    Raises:
        ValueError: if ``config.data.dataset`` does not name an attribute of
            the ``datasets`` module.
    """
    args, config = parse_args_and_config()

    # setup device to run
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print("Using device: {}".format(device))
    config.device = device

    if torch.cuda.is_available():
        print('Note: Currently supports evaluations (restoration) when run only on a single GPU!')

    # set random seed -- seed every RNG this file imports, including Python's
    # own `random` module, which was previously left unseeded.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    # NOTE(review): benchmark=True lets cuDNN choose kernels by timing, which
    # can make outputs differ between runs even with a fixed seed. Kept for
    # speed; set it to False (and cudnn.deterministic = True) if exact
    # reproducibility of restored images is required.
    torch.backends.cudnn.benchmark = True

    # data loading
    print("=> using dataset '{}'".format(config.data.dataset))
    try:
        dataset_cls = datasets.__dict__[config.data.dataset]
    except KeyError:
        raise ValueError(
            "Unknown dataset {!r}: no such name in the 'datasets' module".format(
                config.data.dataset)) from None
    dataset = dataset_cls(config)
    _, val_loader = dataset.get_loaders(parse_patches=False, validation=args.test_set)

    # create model
    print("=> creating denoising-diffusion model with wrapper...")
    diffusion = DenoisingDiffusion(args, config)
    model = DiffusiveRestoration(diffusion, args, config)
    model.restore(val_loader, validation=args.test_set, r=args.grid_r)
if __name__ == '__main__':
    # Only start evaluation when executed as a script, never on import.
    main()