-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathupscaler.py
42 lines (38 loc) · 1.04 KB
/
upscaler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import torch.nn as nn
import numpy as np
class LatentUpscaler(nn.Module):
    """Convolutional upscaler: conv head + nearest upsample, conv core, conv tail.

    Maps a batched ``(N, chan, H, W)`` tensor to ``(N, chan, H * fac, W * fac)``
    (``nn.Upsample`` floors the size for non-integer ``fac``). With the defaults
    it expects 4-channel input; the class name suggests diffusion-model latents,
    but that is not verified here -- confirm against callers.

    Layer order inside ``self.sequential``. The indices matter, because they
    form the ``state_dict`` keys of saved weights:

    * head: Conv(chan -> size), ReLU, Upsample(fac, "nearest"), ReLU
    * core: ``depth`` x [Conv(size -> size), ReLU]
    * tail: Conv(size -> chan), with no output activation

    Args:
        fac: Spatial scale factor passed to ``nn.Upsample``.
        depth: Number of Conv+ReLU pairs in the core.
        chan: Input/output channel count (default 4, the original fixed value).
        size: Hidden channel width of every intermediate conv (default 64).
        krn: Square kernel size for every conv (default 3). It must be a
            positive odd integer; padding is ``krn // 2`` so each conv
            preserves the spatial size.

    Raises:
        ValueError: If ``krn`` is not a positive odd integer.
    """

    def __init__(
        self,
        fac: float,
        depth: int = 16,
        *,
        chan: int = 4,
        size: int = 64,
        krn: int = 3,
    ) -> None:
        super().__init__()
        if krn < 1 or krn % 2 == 0:
            # With symmetric padding krn // 2, an even kernel grows H and W by
            # one pixel per conv, so the output would not be exactly fac x input.
            raise ValueError(f"krn must be a positive odd integer, got {krn}")
        self.size = size  # hidden channel width of the conv stack
        self.chan = chan  # in/out channels
        self.depth = depth  # no. of Conv+ReLU pairs in the core
        self.fac = fac  # spatial scale factor
        self.krn = krn  # kernel size
        self.pad = krn // 2  # "same" padding for odd kernels (1 for krn=3)
        # head/core/tail read the attributes above, so they must be set first.
        self.sequential = nn.Sequential(
            *self.head(),
            *self.core(),
            *self.tail(),
        )

    def head(self) -> list:
        """Return the input stage: project to ``size`` channels, then upsample."""
        return [
            nn.Conv2d(self.chan, self.size, kernel_size=self.krn, padding=self.pad),
            nn.ReLU(),
            nn.Upsample(scale_factor=self.fac, mode="nearest"),
            # No-op: values are already non-negative after the first ReLU, and
            # nearest upsampling only copies them. This layer is kept so the
            # Sequential indices, and hence the state_dict keys, do not shift.
            nn.ReLU(),
        ]

    def core(self) -> list:
        """Return ``depth`` Conv(size -> size) + ReLU pairs as a flat list."""
        layers = []
        for _ in range(self.depth):
            layers += [
                nn.Conv2d(self.size, self.size, kernel_size=self.krn, padding=self.pad),
                nn.ReLU(),
            ]
        return layers

    def tail(self) -> list:
        """Return the output stage: project back to ``chan`` channels (no activation)."""
        return [
            nn.Conv2d(self.size, self.chan, kernel_size=self.krn, padding=self.pad),
        ]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upscale ``x`` of shape ``(N, chan, H, W)`` by ``fac`` along H and W."""
        return self.sequential(x)