From 0f628ede6070157d8479375b6d3e0b854e805c0d Mon Sep 17 00:00:00 2001 From: Ming-Yu Liu Date: Wed, 2 May 2018 17:13:47 -0700 Subject: [PATCH] Fix test.py bug --- test.py | 114 ++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 70 insertions(+), 44 deletions(-) diff --git a/test.py b/test.py index e07bf7508..d0c558ff5 100644 --- a/test.py +++ b/test.py @@ -4,8 +4,8 @@ """ from __future__ import print_function from utils import get_config, get_data_loader_folder -from trainer import MUNIT_Trainer -from optparse import OptionParser +from trainer import MUNIT_Trainer, UNIT_Trainer +import argparse from torch.autograd import Variable import torchvision.utils as vutils import sys @@ -14,50 +14,70 @@ from torchvision import transforms from PIL import Image -parser = OptionParser() -parser.add_option('--config', type=str, help="net configuration") -parser.add_option('--input', type=str, help="input image path") -parser.add_option('--output_folder', type=str, help="output image path") -parser.add_option('--checkpoint', type=str, help="checkpoint of autoencoders") -parser.add_option('--style', type=str, default='', help="style image path") -parser.add_option('--a2b', type=int, default=1, help="1 for a2b and others for b2a") -parser.add_option('--seed', type=int, default=10, help="random seed") -parser.add_option('--num_style',type=int, default=10, help="number of styles to sample") -parser.add_option('--synchronized', action='store_true', help="whether use synchronized style code or not") -parser.add_option('--output_only', action='store_true', help="whether use synchronized style code or not") +parser = argparse.ArgumentParser() +parser.add_argument('--config', type=str, help="net configuration") +parser.add_argument('--input', type=str, help="input image path") +parser.add_argument('--output_folder', type=str, help="output image path") +parser.add_argument('--checkpoint', type=str, help="checkpoint of autoencoders") +parser.add_argument('--style', type=str, default='', help="style image path") +parser.add_argument('--a2b', type=int, default=1, help="1 for a2b and others for b2a") +parser.add_argument('--seed', type=int, default=10, help="random seed") +parser.add_argument('--num_style',type=int, default=10, help="number of styles to sample") +parser.add_argument('--synchronized', action='store_true', help="whether use synchronized style code or not") +parser.add_argument('--output_only', action='store_true', help="whether use synchronized style code or not") +parser.add_argument('--output_path', type=str, default='.', help="path for logs, checkpoints, and VGG model weight") +parser.add_argument('--trainer', type=str, default='MUNIT', help="MUNIT|UNIT") -def main(argv): - (opts, args) = parser.parse_args(argv) - torch.manual_seed(opts.seed) - torch.cuda.manual_seed(opts.seed) - if not os.path.exists(opts.output_folder): - os.makedirs(opts.output_folder) - # Load experiment setting - config = get_config(opts.config) - style_dim = config['gen']['style_dim'] - opts.num_style = 1 if opts.style != '' else opts.num_style +opts = parser.parse_args() + + +torch.manual_seed(opts.seed) +torch.cuda.manual_seed(opts.seed) +if not os.path.exists(opts.output_folder): + os.makedirs(opts.output_folder) - # Setup model and data loader +# Load experiment setting +config = get_config(opts.config) +opts.num_style = 1 if opts.style != '' else opts.num_style + +# Setup model and data loader +config['vgg_model_path'] = opts.output_path +if opts.trainer == 'MUNIT': + style_dim = config['gen']['style_dim'] trainer = MUNIT_Trainer(config) - state_dict = torch.load(opts.checkpoint) - trainer.gen_a.load_state_dict(state_dict['a']) - trainer.gen_b.load_state_dict(state_dict['b']) - trainer.cuda() - trainer.eval() - encode = trainer.gen_a.encode if opts.a2b else trainer.gen_b.encode # encode function - style_encode = trainer.gen_b.encode if opts.a2b else trainer.gen_a.encode # encode function - decode = trainer.gen_b.decode if opts.a2b else trainer.gen_a.decode # decode function +elif opts.trainer == 'UNIT': + trainer = UNIT_Trainer(config) +else: + sys.exit("Only support MUNIT|UNIT") +state_dict = torch.load(opts.checkpoint) +trainer.gen_a.load_state_dict(state_dict['a']) +trainer.gen_b.load_state_dict(state_dict['b']) +trainer.cuda() +trainer.eval() +encode = trainer.gen_a.encode if opts.a2b else trainer.gen_b.encode # encode function +style_encode = trainer.gen_b.encode if opts.a2b else trainer.gen_a.encode # encode function +decode = trainer.gen_b.decode if opts.a2b else trainer.gen_a.decode # decode function - transform = transforms.Compose([transforms.Resize(config['new_size']), - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) - image = Variable(transform(Image.open(opts.input).convert('RGB')).unsqueeze(0).cuda(), volatile=True) - style_image = Variable(transform(Image.open(opts.style).convert('RGB')).unsqueeze(0).cuda(), volatile=True) if opts.style != '' else None +if 'new_size' in config: + new_size = config['new_size'] +else: + if opts.a2b==1: + new_size = config['new_size_a'] + else: + new_size = config['new_size_b'] + +transform = transforms.Compose([transforms.Resize(new_size), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) +image = Variable(transform(Image.open(opts.input).convert('RGB')).unsqueeze(0).cuda(), volatile=True) +style_image = Variable(transform(Image.open(opts.style).convert('RGB')).unsqueeze(0).cuda(), volatile=True) if opts.style != '' else None - # Start testing +# Start testing +content, _ = encode(image) + +if opts.trainer == 'MUNIT': style_rand = Variable(torch.randn(opts.num_style, style_dim, 1, 1).cuda(), volatile=True) - content, _ = encode(image) if opts.style != '': _, style = style_encode(style_image) else: @@ -68,9 +88,15 @@ def main(argv): outputs = (outputs + 1) / 2. path = os.path.join(opts.output_folder, 'output{:03d}.jpg'.format(j)) vutils.save_image(outputs.data, path, padding=0, normalize=True) - if not opts.output_only: - # also save input images - vutils.save_image(image.data, os.path.join(opts.output_folder, 'input.jpg'), padding=0, normalize=True) +elif opts.trainer == 'UNIT': + outputs = decode(content) + outputs = (outputs + 1) / 2. + path = os.path.join(opts.output_folder, 'output.jpg') + vutils.save_image(outputs.data, path, padding=0, normalize=True) +else: + pass + +if not opts.output_only: + # also save input images + vutils.save_image(image.data, os.path.join(opts.output_folder, 'input.jpg'), padding=0, normalize=True) -if __name__ == '__main__': - main(sys.argv)