Skip to content

Commit

Permalink
Fixed quantization process
Browse files (browse the repository at this point in the history)
  • Loading branch information
rosinality committed Oct 7, 2020
1 parent 806d725 commit 97081ff
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 32 deletions.
16 changes: 9 additions & 7 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self, in_channel, logdet=True):
self.loc = nn.Parameter(torch.zeros(1, in_channel, 1, 1))
self.scale = nn.Parameter(torch.ones(1, in_channel, 1, 1))

self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
self.logdet = logdet

def initialize(self, input):
Expand Down Expand Up @@ -102,11 +102,11 @@ def __init__(self, in_channel):
w_s = torch.from_numpy(w_s)
w_u = torch.from_numpy(w_u)

self.register_buffer('w_p', w_p)
self.register_buffer('u_mask', torch.from_numpy(u_mask))
self.register_buffer('l_mask', torch.from_numpy(l_mask))
self.register_buffer('s_sign', torch.sign(w_s))
self.register_buffer('l_eye', torch.eye(l_mask.shape[0]))
self.register_buffer("w_p", w_p)
self.register_buffer("u_mask", torch.from_numpy(u_mask))
self.register_buffer("l_mask", torch.from_numpy(l_mask))
self.register_buffer("s_sign", torch.sign(w_s))
self.register_buffer("l_eye", torch.eye(l_mask.shape[0]))
self.w_l = nn.Parameter(w_l)
self.w_s = nn.Parameter(logabs(w_s))
self.w_u = nn.Parameter(w_u)
Expand Down Expand Up @@ -333,7 +333,9 @@ def reverse(self, output, eps=None, reconstruct=False):


class Glow(nn.Module):
def __init__(self, in_channel, n_flow, n_block, affine=True, conv_lu=True):
def __init__(
self, in_channel, n_flow, n_block, affine=True, conv_lu=True
):
super().__init__()

self.blocks = nn.ModuleList()
Expand Down
58 changes: 33 additions & 25 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,29 @@

from model import Glow

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

parser = argparse.ArgumentParser(description='Glow trainer')
parser.add_argument('--batch', default=16, type=int, help='batch size')
parser.add_argument('--iter', default=200000, type=int, help='maximum iterations')
parser = argparse.ArgumentParser(description="Glow trainer")
parser.add_argument("--batch", default=16, type=int, help="batch size")
parser.add_argument("--iter", default=200000, type=int, help="maximum iterations")
parser.add_argument(
'--n_flow', default=32, type=int, help='number of flows in each block'
"--n_flow", default=32, type=int, help="number of flows in each block"
)
parser.add_argument('--n_block', default=4, type=int, help='number of blocks')
parser.add_argument("--n_block", default=4, type=int, help="number of blocks")
parser.add_argument(
'--no_lu',
action='store_true',
help='use plain convolution instead of LU decomposed version',
"--no_lu",
action="store_true",
help="use plain convolution instead of LU decomposed version",
)
parser.add_argument(
'--affine', action='store_true', help='use affine coupling instead of additive'
"--affine", action="store_true", help="use affine coupling instead of additive"
)
parser.add_argument('--n_bits', default=5, type=int, help='number of bits')
parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
parser.add_argument('--img_size', default=64, type=int, help='image size')
parser.add_argument('--temp', default=0.7, type=float, help='temperature of sampling')
parser.add_argument('--n_sample', default=20, type=int, help='number of samples')
parser.add_argument('path', metavar='PATH', type=str, help='Path to image directory')
parser.add_argument("--n_bits", default=5, type=int, help="number of bits")
parser.add_argument("--lr", default=1e-4, type=float, help="learning rate")
parser.add_argument("--img_size", default=64, type=int, help="image size")
parser.add_argument("--temp", default=0.7, type=float, help="temperature of sampling")
parser.add_argument("--n_sample", default=20, type=int, help="number of samples")
parser.add_argument("path", metavar="PATH", type=str, help="Path to image directory")


def sample_data(path, batch_size, image_size):
Expand All @@ -45,7 +45,6 @@ def sample_data(path, batch_size, image_size):
transforms.CenterCrop(image_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (1, 1, 1)),
]
)

Expand Down Expand Up @@ -96,7 +95,7 @@ def calc_loss(log_p, logdet, image_size, n_bins):

def train(args, model, optimizer):
dataset = iter(sample_data(args.path, args.batch, args.img_size))
n_bins = 2. ** args.n_bits
n_bins = 2.0 ** args.n_bits

z_sample = []
z_shapes = calc_z_shapes(3, args.img_size, args.n_flow, args.n_block)
Expand All @@ -109,9 +108,18 @@ def train(args, model, optimizer):
image, _ = next(dataset)
image = image.to(device)

image = image * 255

if args.n_bits < 8:
image = torch.floor(image / 2 ** (8 - args.n_bits))

image = image / n_bins - 0.5

if i == 0:
with torch.no_grad():
log_p, logdet, _ = model.module(image + torch.rand_like(image) / n_bins)
log_p, logdet, _ = model.module(
image + torch.rand_like(image) / n_bins
)

continue

Expand All @@ -125,33 +133,33 @@ def train(args, model, optimizer):
loss.backward()
# warmup_lr = args.lr * min(1, i * batch_size / (50000 * 10))
warmup_lr = args.lr
optimizer.param_groups[0]['lr'] = warmup_lr
optimizer.param_groups[0]["lr"] = warmup_lr
optimizer.step()

pbar.set_description(
f'Loss: {loss.item():.5f}; logP: {log_p.item():.5f}; logdet: {log_det.item():.5f}; lr: {warmup_lr:.7f}'
f"Loss: {loss.item():.5f}; logP: {log_p.item():.5f}; logdet: {log_det.item():.5f}; lr: {warmup_lr:.7f}"
)

if i % 100 == 0:
with torch.no_grad():
utils.save_image(
model_single.reverse(z_sample).cpu().data,
f'sample/{str(i + 1).zfill(6)}.png',
f"sample/{str(i + 1).zfill(6)}.png",
normalize=True,
nrow=10,
range=(-0.5, 0.5),
)

if i % 10000 == 0:
torch.save(
model.state_dict(), f'checkpoint/model_{str(i + 1).zfill(6)}.pt'
model.state_dict(), f"checkpoint/model_{str(i + 1).zfill(6)}.pt"
)
torch.save(
optimizer.state_dict(), f'checkpoint/optim_{str(i + 1).zfill(6)}.pt'
optimizer.state_dict(), f"checkpoint/optim_{str(i + 1).zfill(6)}.pt"
)


if __name__ == '__main__':
if __name__ == "__main__":
args = parser.parse_args()
print(args)

Expand Down

0 comments on commit 97081ff

Please sign in to comment.