demo.py
import argparse
import os

from PIL import Image
import yaml
import torch
from torchvision import transforms

import models
from utils import make_coord
from test import batched_predict
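
# Example invocation (the paths and scale below are placeholders, not files shipped with the
# repo; --model must point to a checkpoint containing a 'model' entry, as loaded further down):
#   python demo.py --input input.png --model <path/to/checkpoint.pth> --scale 4 --output output.png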

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', default='input.png')
    parser.add_argument('--model')
    parser.add_argument('--scale')
    parser.add_argument('--output', default='output.png')
    parser.add_argument('--gpu', default='0')
    # Set fast to True for LMF, False for the original LIIF/LTE/CiaoSR.
    # Note: the boolean defaults below are Python values; values passed on the command line
    # arrive as strings, so edit the defaults here if you need to change them.
    parser.add_argument('--fast', default=True)
    parser.add_argument('--cmsr', default=False)
    parser.add_argument('--cmsr_mse', default=0.00002)
    parser.add_argument('--cmsr_path')
    args = parser.parse_args()

    # Select the GPU before querying device availability so CUDA_VISIBLE_DEVICES takes effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

    # Maximum scale factor during training
    scale_max = 4
    if args.cmsr:
        try:
            # Test with CMSR: load the s2m tables from the YAML file at --cmsr_path.
            with open(args.cmsr_path, 'r') as f:
                s2m_tables = yaml.load(f, Loader=yaml.FullLoader)
            cmsr_spec = {
                "mse_threshold": float(args.cmsr_mse),
                "path": args.cmsr_path,
                "s2m_tables": s2m_tables,
                "log": False,
            }
        except (FileNotFoundError, TypeError):
            # Fall back to plain inference when the table file is missing or --cmsr_path was not given.
            cmsr_spec = None
    else:
        cmsr_spec = None

    # Load the checkpoint and inject the CMSR spec into the model arguments before building it.
    model_spec = torch.load(args.model)['model']
    model_spec["args"]["cmsr_spec"] = cmsr_spec
    model = models.make(model_spec, load_sd=True).to(DEVICE)
    model.eval()

    img = transforms.ToTensor()(Image.open(args.input).convert('RGB')).to(DEVICE)

    # Target resolution and a coordinate grid covering it in [-1, 1].
    h = int(img.shape[-2] * int(args.scale))
    w = int(img.shape[-1] * int(args.scale))
    scale = h / img.shape[-2]
    coord = make_coord((h, w)).to(DEVICE)

    # Each query cell spans 2/h x 2/w of the normalized coordinate range.
    cell = torch.ones_like(coord)
    cell[:, 0] *= 2 / h
    cell[:, 1] *= 2 / w
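    # For scales beyond scale_max, cell_factor (> 1) enlarges the query cell so the decoder
    # is never asked about cells smaller than those seen during training (the usual
    # cell-clamping trick in LIIF-style models); at or below scale_max it is a no-op.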
    cell_factor = max(scale / scale_max, 1)
    if args.fast:
        # Single forward pass over all query coordinates (the --fast path, intended for LMF).
        with torch.no_grad():
            pred = model(((img - 0.5) / 0.5).unsqueeze(0),
                         coord.unsqueeze(0), cell_factor * cell.unsqueeze(0))[0]
    else:
        # Query the coordinates in chunks of bsize to bound memory for LIIF/LTE/CiaoSR.
        pred = batched_predict(model, ((img - 0.5) / 0.5).unsqueeze(0),
                               coord.unsqueeze(0), cell_factor * cell.unsqueeze(0), bsize=30000)[0]

    # Undo the [-1, 1] normalization, reshape to an image, and save.
    pred = (pred * 0.5 + 0.5).clamp(0, 1).view(h, w, 3).permute(2, 0, 1).cpu()
    transforms.ToPILImage()(pred).save(args.output)