diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..bb5cba0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,527 @@ +# Created by .ignore support plugin (hsz.mobi) +.idea + +### Linux template +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* +### Vim template +# Swap +[._]*.s[a-v][a-z] +[._]*.sw[a-p] +[._]s[a-rt-v][a-z] +[._]ss[a-gi-z] +[._]sw[a-p] + +# Session +Session.vim + +# Temporary +.netrwhist +*~ +# Auto-generated tag files +tags +# Persistent undo +[._]*.un~ +### macOS template +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk +### SublimeText template +# Cache files for Sublime Text +*.tmlanguage.cache +*.tmPreferences.cache +*.stTheme.cache + +# Workspace files are user-specific +*.sublime-workspace + +# Project files should be checked into the repository, unless a significant +# proportion of contributors will probably not be using Sublime Text +# *.sublime-project + +# SFTP configuration file +sftp-config.json + +# Package control specific files +Package Control.last-run +Package Control.ca-list +Package Control.ca-bundle +Package Control.system-ca-bundle +Package Control.cache/ +Package Control.ca-certs/ +Package Control.merged-ca-bundle +Package Control.user-ca-bundle +oscrypto-ca-bundle.crt +bh_unicode_properties.cache + +# Sublime-github package stores a github token in this file +# https://packagecontrol.io/packages/sublime-github +GitHub.sublime-settings +### TextMate template +*.tmproj +*.tmproject +tmtags +### Python template +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +### VirtualEnv template +# Virtualenv +# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ +.Python +[Bb]in +[Ii]nclude +[Ll]ib +[Ll]ib64 +[Ll]ocal +[Ss]cripts +pyvenv.cfg +.venv +pip-selfcheck.json +### JetBrains template +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff +.idea +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/**/usage.statistics.xml +.idea/**/dictionaries +.idea/**/shelf + +# Sensitive or high-churn files +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml +.idea/**/dbnavigator.xml + +# Gradle +.idea/**/gradle.xml +.idea/**/libraries + +# Gradle and Maven with auto-import +# When using Gradle or Maven with auto-import, you should exclude module files, +# since they will be recreated, and may cause churn. Uncomment if using +# auto-import. +# .idea/architectures.xml +# .idea/*.iml +# .idea/architectures + +# CMake +cmake-build-*/ + +# Mongo Explorer plugin +.idea/**/mongoSettings.xml + +# File-based project format +*.iws + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties + +# Editor-based Rest Client +.idea/httpRequests +### TeX template +## Core latex/pdflatex auxiliary files: +*.aux +*.lof +*.log +*.lot +*.fls +*.out +*.toc +*.fmt +*.fot +*.cb +*.cb2 +.*.lb + +## Intermediate documents: +*.dvi +*.xdv +*-converted-to.* +# these rules might exclude image files for figures etc. 
+# *.ps +# *.eps +# *.pdf + +## Generated if empty string is given at "Please type another file name for output:" +.pdf + +## Bibliography auxiliary files (bibtex/biblatex/biber): +*.bbl +*.bcf +*.blg +*-blx.aux +*-blx.bib +*.run.xml + +## Build tool auxiliary files: +*.fdb_latexmk +*.synctex +*.synctex(busy) +*.synctex.gz +*.synctex.gz(busy) +*.pdfsync + +## Build tool directories for auxiliary files +# latexrun +latex.out/ + +## Auxiliary and intermediate files from other packages: +# algorithms +*.alg +*.loa + +# achemso +acs-*.bib + +# amsthm +*.thm + +# beamer +*.nav +*.pre +*.snm +*.vrb + +# changes +*.soc + +# cprotect +*.cpt + +# elsarticle (documentclass of Elsevier journals) +*.spl + +# endnotes +*.ent + +# fixme +*.lox + +# feynmf/feynmp +*.mf +*.mp +*.t[1-9] +*.t[1-9][0-9] +*.tfm + +#(r)(e)ledmac/(r)(e)ledpar +*.end +*.?end +*.[1-9] +*.[1-9][0-9] +*.[1-9][0-9][0-9] +*.[1-9]R +*.[1-9][0-9]R +*.[1-9][0-9][0-9]R +*.eledsec[1-9] +*.eledsec[1-9]R +*.eledsec[1-9][0-9] +*.eledsec[1-9][0-9]R +*.eledsec[1-9][0-9][0-9] +*.eledsec[1-9][0-9][0-9]R + +# glossaries +*.acn +*.acr +*.glg +*.glo +*.gls +*.glsdefs + +# gnuplottex +*-gnuplottex-* + +# gregoriotex +*.gaux +*.gtex + +# htlatex +*.4ct +*.4tc +*.idv +*.lg +*.trc +*.xref + +# hyperref +*.brf + +# knitr +*-concordance.tex +# TODO Comment the next line if you want to keep your tikz graphics files +*.tikz +*-tikzDictionary + +# listings +*.lol + +# makeidx +*.idx +*.ilg +*.ind +*.ist + +# minitoc +*.maf +*.mlf +*.mlt +*.mtc[0-9]* +*.slf[0-9]* +*.slt[0-9]* +*.stc[0-9]* + +# minted +_minted* +*.pyg + +# morewrites +*.mw + +# nomencl +*.nlg +*.nlo +*.nls + +# pax +*.pax + +# pdfpcnotes +*.pdfpc + +# sagetex +*.sagetex.sage +*.sagetex.py +*.sagetex.scmd + +# scrwfile +*.wrt + +# sympy +*.sout +*.sympy +sympy-plots-for-*.tex/ + +# pdfcomment +*.upa +*.upb + +# pythontex +*.pytxcode +pythontex-files-*/ + +# thmtools +*.loe + +# TikZ & PGF +*.dpth +*.md5 +*.auxlock + +# todonotes +*.tdo + +# easy-todo +*.lod + +# xmpincl +*.xmpi + +# xindy +*.xdy + +# xypic precompiled matrices +*.xyc + +# endfloat +*.ttt +*.fff + +# Latexian +TSWLatexianTemp* + +## Editors: +# WinEdt +*.bak +*.sav + +# Texpad +.texpadtmp + +# LyX +*.lyx~ + +# Kile +*.backup + +# KBibTeX +*~[0-9]* + +# auto folder when using emacs and auctex +./auto/* +*.el + +# expex forward references with \gathertags +*-tags.tex + +# standalone packages +*.sta + +# linter configuration +setup.cfg diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 diff --git a/datasets/__init__.py b/datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/datasets/shapenet.py b/datasets/shapenet.py new file mode 100644 index 0000000..98ac83d --- /dev/null +++ b/datasets/shapenet.py @@ -0,0 +1,134 @@ +import urllib +import shutil +from os import listdir, makedirs, remove +from os.path import exists, join +from zipfile import ZipFile + +import pandas as pd +from torch.utils.data import Dataset + +from utils.plyfile import load_ply + +synth_id_to_category = { + '02691156': 'airplane', '02773838': 'bag', '02801938': 'basket', + '02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench', + '02834778': 'bicycle', '02843684': 'birdhouse', '02871439': 'bookshelf', + '02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus', + '02933112': 'cabinet', '02747177': 'can', '02942699': 'camera', + '02954340': 'cap', '02958343': 'car', '03001627': 'chair', + '03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor', + '04379243': 'table', '04401088': 
'telephone', '02946921': 'tin_can', + '04460130': 'tower', '04468005': 'train', '03085013': 'keyboard', + '03261776': 'earphone', '03325088': 'faucet', '03337140': 'file', + '03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar', + '03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop', + '03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone', + '03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug', + '03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol', + '03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control', + '04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard', + '04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel', + '04554684': 'washer', '02858304': 'boat', '02992529': 'cellphone' +} + +category_to_synth_id = {v: k for k, v in synth_id_to_category.items()} +synth_id_to_number = {k: i for i, k in enumerate(synth_id_to_category.keys())} + + +class ShapeNetDataset(Dataset): + def __init__(self, root_dir='/home/datasets/shapenet', classes=[], + transform=None, split='train'): + """ + Args: + root_dir (string): Directory with all the point clouds. + transform (callable, optional): Optional transform to be applied + on a sample. + """ + self.root_dir = root_dir + self.transform = transform + self.split = split + + self._maybe_download_data() + + pc_df = self._get_names() + if classes: + if classes[0] not in synth_id_to_category.keys(): + classes = [category_to_synth_id[c] for c in classes] + pc_df = pc_df[pc_df.category.isin(classes)].reset_index(drop=True) + else: + classes = synth_id_to_category.keys() + + self.point_clouds_names_train = pd.concat([pc_df[pc_df['category'] == c][:int(0.85*len(pc_df[pc_df['category'] == c]))].reset_index(drop=True) for c in classes]) + self.point_clouds_names_valid = pd.concat([pc_df[pc_df['category'] == c][int(0.85*len(pc_df[pc_df['category'] == c])):int(0.9*len(pc_df[pc_df['category'] == c]))].reset_index(drop=True) for c in classes]) + self.point_clouds_names_test = pd.concat([pc_df[pc_df['category'] == c][int(0.9*len(pc_df[pc_df['category'] == c])):].reset_index(drop=True) for c in classes]) + + def __len__(self): + if self.split == 'train': + pc_names = self.point_clouds_names_train + elif self.split == 'valid': + pc_names = self.point_clouds_names_valid + elif self.split == 'test': + pc_names = self.point_clouds_names_test + else: + raise ValueError('Invalid split. Should be train, valid or test.') + return len(pc_names) + + def __getitem__(self, idx): + if self.split == 'train': + pc_names = self.point_clouds_names_train + elif self.split == 'valid': + pc_names = self.point_clouds_names_valid + elif self.split == 'test': + pc_names = self.point_clouds_names_test + else: + raise ValueError('Invalid split. Should be train, valid or test.') + + pc_category, pc_filename = pc_names.iloc[idx].values + + pc_filepath = join(self.root_dir, pc_category, pc_filename) + sample = load_ply(pc_filepath) + + if self.transform: + sample = self.transform(sample) + + return sample, synth_id_to_number[pc_category] + + def _get_names(self) -> pd.DataFrame: + filenames = [] + for category_id in synth_id_to_category.keys(): + for f in listdir(join(self.root_dir, category_id)): + if f not in ['.DS_Store']: + filenames.append((category_id, f)) + return pd.DataFrame(filenames, columns=['category', 'filename']) + + def _maybe_download_data(self): + if exists(self.root_dir): + return + + print(f'ShapeNet doesn\'t exist in root directory {self.root_dir}. 
' + f'Downloading...') + makedirs(self.root_dir) + + url = 'https://www.dropbox.com/s/vmsdrae6x5xws1v/shape_net_core_uniform_samples_2048.zip?dl=1' + + data = urllib.request.urlopen(url) + filename = url.rpartition('/')[2][:-5] + file_path = join(self.root_dir, filename) + with open(file_path, mode='wb') as f: + d = data.read() + f.write(d) + + print('Extracting...') + with ZipFile(file_path, mode='r') as zip_f: + zip_f.extractall(self.root_dir) + + remove(file_path) + + extracted_dir = join(self.root_dir, + 'shape_net_core_uniform_samples_2048') + for d in listdir(extracted_dir): + shutil.move(src=join(extracted_dir, d), + dst=self.root_dir) + + shutil.rmtree(extracted_dir) + diff --git a/experiments/train_aae.py b/experiments/train_aae.py new file mode 100644 index 0000000..bd77d8f --- /dev/null +++ b/experiments/train_aae.py @@ -0,0 +1,283 @@ +import argparse +import json +import logging +import random +from datetime import datetime +from importlib import import_module +from itertools import chain +from os.path import join, exists + +import matplotlib.pyplot as plt +import torch +from torch.autograd import grad +import torch.backends.cudnn as cudnn +import torch.nn.parallel +import torch.optim as optim +import torch.utils.data +from torch.utils.data import DataLoader + +from utils.pcutil import plot_3d_point_cloud +from utils.util import find_latest_epoch, prepare_results_dir, cuda_setup, setup_logging + +cudnn.benchmark = True + + +def weights_init(m): + classname = m.__class__.__name__ + if classname in ('Conv1d', 'Linear'): + torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu') + if m.bias is not None: + torch.nn.init.constant_(m.bias, 0) + + +def main(config): + random.seed(config['seed']) + torch.manual_seed(config['seed']) + torch.cuda.manual_seed_all(config['seed']) + + results_dir = prepare_results_dir(config) + starting_epoch = find_latest_epoch(results_dir) + 1 + + if not exists(join(results_dir, 'config.json')): + with open(join(results_dir, 'config.json'), mode='w') as f: + json.dump(config, f) + + setup_logging(results_dir) + log = logging.getLogger(__name__) + + device = cuda_setup(config['cuda'], config['gpu']) + log.debug(f'Device variable: {device}') + if device.type == 'cuda': + log.debug(f'Current CUDA device: {torch.cuda.current_device()}') + + weights_path = join(results_dir, 'weights') + + # + # Dataset + # + dataset_name = config['dataset'].lower() + if dataset_name == 'shapenet': + from datasets.shapenet import ShapeNetDataset + dataset = ShapeNetDataset(root_dir=config['data_dir'], + classes=config['classes']) + elif dataset_name == 'faust': + from datasets.dfaust import DFaustDataset + dataset = DFaustDataset(root_dir=config['data_dir'], + classes=config['classes']) + else: + raise ValueError(f'Invalid dataset name. Expected `shapenet` or ' + f'`faust`. Got: `{dataset_name}`') + log.debug("Selected {} classes. 
Loaded {} samples.".format( + 'all' if not config['classes'] else ','.join(config['classes']), + len(dataset))) + + points_dataloader = DataLoader(dataset, batch_size=config['batch_size'], + shuffle=config['shuffle'], + num_workers=config['num_workers'], + drop_last=True, pin_memory=True) + + # + # Models + # + arch = import_module(f"models.{config['arch']}") + G = arch.Generator(config).to(device) + E = arch.Encoder(config).to(device) + D = arch.Discriminator(config).to(device) + + G.apply(weights_init) + E.apply(weights_init) + D.apply(weights_init) + + if config['reconstruction_loss'].lower() == 'chamfer': + from losses.champfer_loss import ChamferLoss + reconstruction_loss = ChamferLoss().to(device) + elif config['reconstruction_loss'].lower() == 'earth_mover': + from losses.earth_mover_distance import EMD + reconstruction_loss = EMD().to(device) + else: + raise ValueError(f'Invalid reconstruction loss. Accepted `chamfer` or ' + f'`earth_mover`, got: {config["reconstruction_loss"]}') + # + # Float Tensors + # + fixed_noise = torch.FloatTensor(config['batch_size'], config['z_size'], 1) + fixed_noise.normal_(mean=config['normal_mu'], std=config['normal_std']) + noise = torch.FloatTensor(config['batch_size'], config['z_size']) + + fixed_noise = fixed_noise.to(device) + noise = noise.to(device) + + # + # Optimizers + # + EG_optim = getattr(optim, config['optimizer']['EG']['type']) + EG_optim = EG_optim(chain(E.parameters(), G.parameters()), + **config['optimizer']['EG']['hyperparams']) + + D_optim = getattr(optim, config['optimizer']['D']['type']) + D_optim = D_optim(D.parameters(), + **config['optimizer']['D']['hyperparams']) + + if starting_epoch > 1: + G.load_state_dict(torch.load( + join(weights_path, f'{starting_epoch-1:05}_G.pth'))) + E.load_state_dict(torch.load( + join(weights_path, f'{starting_epoch-1:05}_E.pth'))) + D.load_state_dict(torch.load( + join(weights_path, f'{starting_epoch-1:05}_D.pth'))) + + D_optim.load_state_dict(torch.load( + join(weights_path, f'{starting_epoch-1:05}_Do.pth'))) + + EG_optim.load_state_dict(torch.load( + join(weights_path, f'{starting_epoch-1:05}_EGo.pth'))) + + for epoch in range(starting_epoch, config['max_epochs'] + 1): + start_epoch_time = datetime.now() + + G.train() + E.train() + D.train() + + total_loss_d = 0.0 + total_loss_eg = 0.0 + for i, point_data in enumerate(points_dataloader, 1): + log.debug('-' * 20) + + X, _ = point_data + X = X.to(device) + + # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS] + if X.size(-1) == 3: + X.transpose_(X.dim() - 2, X.dim() - 1) + + codes, _, _ = E(X) + noise.normal_(mean=config['normal_mu'], std=config['normal_std']) + synth_logit = D(codes) + real_logit = D(noise) + loss_d = torch.mean(synth_logit) - torch.mean(real_logit) + + alpha = torch.rand(config['batch_size'], 1).to(device) + differences = codes - noise + interpolates = noise + alpha * differences + disc_interpolates = D(interpolates) + + gradients = grad( + outputs=disc_interpolates, + inputs=interpolates, + grad_outputs=torch.ones_like(disc_interpolates).to(device), + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + slopes = torch.sqrt(torch.sum(gradients ** 2, dim=1)) + gradient_penalty = ((slopes - 1) ** 2).mean() + loss_gp = config['gp_lambda'] * gradient_penalty + ### + loss_d += loss_gp + + D_optim.zero_grad() + D.zero_grad() + + loss_d.backward(retain_graph=True) + total_loss_d += loss_d.item() + D_optim.step() + + # EG part of training + X_rec = G(codes) + + loss_e = torch.mean( + 
config['reconstruction_coef'] * + reconstruction_loss(X.permute(0, 2, 1) + 0.5, + X_rec.permute(0, 2, 1) + 0.5)) + + synth_logit = D(codes) + + loss_g = -torch.mean(synth_logit) + + loss_eg = loss_e + loss_g + EG_optim.zero_grad() + E.zero_grad() + G.zero_grad() + + loss_eg.backward() + total_loss_eg += loss_eg.item() + EG_optim.step() + + log.debug(f'[{epoch}: ({i})] ' + f'Loss_D: {loss_d.item():.4f} ' + f'(GP: {loss_gp.item(): .4f}) ' + f'Loss_EG: {loss_eg.item():.4f} ' + f'(REC: {loss_e.item(): .4f}) ' + f'Time: {datetime.now() - start_epoch_time}') + + log.debug( + f'[{epoch}/{config["max_epochs"]}] ' + f'Loss_D: {total_loss_d / i:.4f} ' + f'Loss_EG: {total_loss_eg / i:.4f} ' + f'Time: {datetime.now() - start_epoch_time}' + ) + + # + # Save intermediate results + # + G.eval() + E.eval() + D.eval() + with torch.no_grad(): + fake = G(fixed_noise).data.cpu().numpy() + codes, _, _ = E(X) + X_rec = G(codes).data.cpu().numpy() + + for k in range(5): + fig = plot_3d_point_cloud(X[k][0], X[k][1], X[k][2], + in_u_sphere=True, show=False, + title=str(epoch)) + fig.savefig( + join(results_dir, 'samples', f'{epoch:05}_{k}_real.png')) + plt.close(fig) + + for k in range(5): + fig = plot_3d_point_cloud(fake[k][0], fake[k][1], fake[k][2], + in_u_sphere=True, show=False, + title=str(epoch)) + fig.savefig( + join(results_dir, 'samples', f'{epoch:05}_{k}_fixed.png')) + plt.close(fig) + + for k in range(5): + fig = plot_3d_point_cloud(X_rec[k][0], + X_rec[k][1], + X_rec[k][2], + in_u_sphere=True, show=False, + title=str(epoch)) + fig.savefig(join(results_dir, 'samples', + f'{epoch:05}_{k}_reconstructed.png')) + plt.close(fig) + + if epoch % config['save_frequency'] == 0: + torch.save(G.state_dict(), join(weights_path, f'{epoch:05}_G.pth')) + torch.save(D.state_dict(), join(weights_path, f'{epoch:05}_D.pth')) + torch.save(E.state_dict(), join(weights_path, f'{epoch:05}_E.pth')) + + torch.save(EG_optim.state_dict(), + join(weights_path, f'{epoch:05}_EGo.pth')) + + torch.save(D_optim.state_dict(), + join(weights_path, f'{epoch:05}_Do.pth')) + + +if __name__ == '__main__': + logger = logging.getLogger() + + parser = argparse.ArgumentParser() + parser.add_argument('-c', '--config', default=None, type=str, + help='config file path') + args = parser.parse_args() + + config = None + if args.config is not None and args.config.endswith('.json'): + with open(args.config) as f: + config = json.load(f) + assert config is not None + + main(config) diff --git a/experiments/train_aae_binary.py b/experiments/train_aae_binary.py new file mode 100644 index 0000000..77fdc2b --- /dev/null +++ b/experiments/train_aae_binary.py @@ -0,0 +1,313 @@ +import argparse +import json +import logging +import random +from datetime import datetime +from importlib import import_module +from itertools import chain +from os.path import join, exists + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.backends.cudnn as cudnn +import torch.nn.parallel +import torch.optim as optim +import torch.utils.data +from torch.autograd import grad +from torch.distributions import Bernoulli +from torch.utils.data import DataLoader + +from utils.pcutil import plot_3d_point_cloud +from utils.util import find_latest_epoch, prepare_results_dir, cuda_setup, setup_logging + +cudnn.benchmark = True + + +def weights_init(m): + classname = m.__class__.__name__ + if classname in ('Conv1d', 'Linear'): + torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu') + if m.bias is not None: + torch.nn.init.constant_(m.bias, 0) + + +def 
main(config): + random.seed(config['seed']) + torch.manual_seed(config['seed']) + torch.cuda.manual_seed_all(config['seed']) + + results_dir = prepare_results_dir(config) + starting_epoch = find_latest_epoch(results_dir) + 1 + + if not exists(join(results_dir, 'config.json')): + with open(join(results_dir, 'config.json'), mode='w') as f: + json.dump(config, f) + + setup_logging(results_dir) + log = logging.getLogger(__name__) + + device = cuda_setup(config['cuda'], config['gpu']) + log.debug(f'Device variable: {device}') + if device.type == 'cuda': + log.debug(f'Current CUDA device: {torch.cuda.current_device()}') + + weights_path = join(results_dir, 'weights') + + # + # Dataset + # + dataset_name = config['dataset'].lower() + if dataset_name == 'shapenet': + from datasets.shapenet import ShapeNetDataset + dataset = ShapeNetDataset(root_dir=config['data_dir'], + classes=config['classes']) + else: + raise ValueError(f'Invalid dataset name. Expected `shapenet` or ' + f'`faust`. Got: `{dataset_name}`') + log.debug("Selected {} classes. Loaded {} samples.".format( + 'all' if not config['classes'] else ','.join(config['classes']), + len(dataset))) + + points_dataloader = DataLoader(dataset, batch_size=config['batch_size'], + shuffle=config['shuffle'], + num_workers=config['num_workers'], + drop_last=True, pin_memory=True) + + # + # Models + # + arch = import_module(f"models.{config['arch']}") + G = arch.Generator(config).to(device) + E = arch.Encoder(config).to(device) + D = arch.Discriminator(config).to(device) + + G.apply(weights_init) + E.apply(weights_init) + D.apply(weights_init) + + if config['reconstruction_loss'].lower() == 'chamfer': + from losses.champfer_loss import ChamferLoss + reconstruction_loss = ChamferLoss().to(device) + elif config['reconstruction_loss'].lower() == 'earth_mover': + from losses.earth_mover_distance import EMD + reconstruction_loss = EMD().to(device) + else: + raise ValueError(f'Invalid reconstruction loss. 
Accepted `chamfer` or ' + f'`earth_mover`, got: {config["reconstruction_loss"]}') + # + # Float Tensors + # + distribution = config['distribution'].lower() + if distribution == 'bernoulli': + p = torch.tensor(config['p']).to(device) + sampler = Bernoulli(probs=p) + fixed_noise = sampler.sample(torch.Size([config['batch_size'], + config['z_size']])) + elif distribution == 'beta': + fixed_noise_np = np.random.beta(config['z_beta_a'], + config['z_beta_b'], + size=(config['batch_size'], + config['z_size'])) + fixed_noise = torch.tensor(fixed_noise_np).float().to(device) + else: + raise ValueError('Invalid distribution for binaray model.') + + # + # Optimizers + # + EG_optim = getattr(optim, config['optimizer']['EG']['type']) + EG_optim = EG_optim(chain(E.parameters(), G.parameters()), + **config['optimizer']['EG']['hyperparams']) + + D_optim = getattr(optim, config['optimizer']['D']['type']) + D_optim = D_optim(D.parameters(), + **config['optimizer']['D']['hyperparams']) + + if starting_epoch > 1: + G.load_state_dict(torch.load( + join(weights_path, f'{starting_epoch-1:05}_G.pth'))) + E.load_state_dict(torch.load( + join(weights_path, f'{starting_epoch-1:05}_E.pth'))) + D.load_state_dict(torch.load( + join(weights_path, f'{starting_epoch-1:05}_D.pth'))) + + D_optim.load_state_dict(torch.load( + join(weights_path, f'{starting_epoch-1:05}_Do.pth'))) + + EG_optim.load_state_dict(torch.load( + join(weights_path, f'{starting_epoch-1:05}_EGo.pth'))) + + loss_d_tot, loss_gp_tot, loss_e_tot, loss_g_tot = [], [], [], [] + for epoch in range(starting_epoch, config['max_epochs'] + 1): + start_epoch_time = datetime.now() + + G.train() + E.train() + D.train() + + total_loss_eg = 0.0 + total_loss_d = 0.0 + for i, point_data in enumerate(points_dataloader, 1): + log.debug('-' * 20) + + X, _ = point_data + X = X.to(device) + + # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS] + if X.size(-1) == 3: + X.transpose_(X.dim() - 2, X.dim() - 1) + + codes = E(X) + if distribution == 'bernoulli': + noise = sampler.sample(fixed_noise.size()) + elif distribution == 'beta': + noise_np = np.random.beta(config['z_beta_a'], + config['z_beta_b'], + size=(config['batch_size'], + config['z_size'])) + noise = torch.tensor(noise_np).float().to(device) + synth_logit = D(codes) + real_logit = D(noise) + loss_d = torch.mean(synth_logit) - torch.mean(real_logit) + loss_d_tot.append(loss_d) + + # Gradient Penalty + alpha = torch.rand(config['batch_size'], 1).to(device) + differences = codes - noise + interpolates = noise + alpha * differences + disc_interpolates = D(interpolates) + + gradients = grad( + outputs=disc_interpolates, + inputs=interpolates, + grad_outputs=torch.ones_like(disc_interpolates).to(device), + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + slopes = torch.sqrt(torch.sum(gradients ** 2, dim=1)) + gradient_penalty = ((slopes - 1) ** 2).mean() + loss_gp = config['gp_lambda'] * gradient_penalty + loss_gp_tot.append(loss_gp) + ### + + loss_d += loss_gp + + D_optim.zero_grad() + D.zero_grad() + + loss_d.backward(retain_graph=True) + total_loss_d += loss_d.item() + D_optim.step() + + # EG part of training + X_rec = G(codes) + + loss_e = torch.mean( + config['reconstruction_coef'] * + reconstruction_loss(X.permute(0, 2, 1) + 0.5, + X_rec.permute(0, 2, 1) + 0.5)) + loss_e_tot.append(loss_e) + + synth_logit = D(codes) + + loss_g = -torch.mean(synth_logit) + loss_g_tot.append(loss_g) + + loss_eg = loss_e + loss_g + EG_optim.zero_grad() + E.zero_grad() + G.zero_grad() + + loss_eg.backward() 
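+            # The objective just backpropagated combines two terms: loss_e, the
+            # scaled reconstruction error (Chamfer or EMD between X and G(E(X)),
+            # both shifted by +0.5), and loss_g = -mean(D(E(X))), the adversarial
+            # term pulling the encoded codes toward the Bernoulli/Beta prior
+            # sampled above. backward() accumulates gradients for both E and G;
+            # EG_optim.step() below applies them.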
+ total_loss_eg += loss_eg.item() + EG_optim.step() + + log.debug(f'[{epoch}: ({i})] ' + f'Loss_D: {loss_d.item():.4f} ' + f'(GP: {loss_gp.item(): .4f}) ' + f'Loss_EG: {loss_eg.item():.4f} ' + f'(REC: {loss_e.item(): .4f}) ' + f'Time: {datetime.now() - start_epoch_time}') + + log.debug( + f'[{epoch}/{config["max_epochs"]}] ' + f'Loss_D: {total_loss_d / i:.4f} ' + f'Loss_EG: {total_loss_eg / i:.4f} ' + f'Time: {datetime.now() - start_epoch_time}' + ) + + # + # Save intermediate results + # + G.eval() + E.eval() + D.eval() + with torch.no_grad(): + fake = G(fixed_noise).data.cpu().numpy() + X_rec = G(E(X)).data.cpu().numpy() + X = X.data.cpu().numpy() + + plt.figure(figsize=(16, 9)) + plt.plot(loss_d_tot, 'r-', label="loss_d") + plt.plot(loss_gp_tot, 'g-', label="loss_gp") + plt.plot(loss_e_tot, 'b-', label="loss_e") + plt.plot(loss_g_tot, 'k-', label="loss_g") + plt.legend() + plt.xlabel("Batch number") + plt.xlabel("Loss value") + plt.savefig( + join(results_dir, 'samples', f'loss_plot.png')) + plt.close() + + for k in range(5): + fig = plot_3d_point_cloud(X[k][0], X[k][1], X[k][2], + in_u_sphere=True, show=False, + title=str(epoch)) + fig.savefig( + join(results_dir, 'samples', f'{epoch:05}_{k}_real.png')) + plt.close(fig) + + for k in range(5): + fig = plot_3d_point_cloud(fake[k][0], fake[k][1], fake[k][2], + in_u_sphere=True, show=False, + title=str(epoch)) + fig.savefig( + join(results_dir, 'samples', f'{epoch:05}_{k}_fixed.png')) + plt.close(fig) + + for k in range(5): + fig = plot_3d_point_cloud(X_rec[k][0], X_rec[k][1], X_rec[k][2], + in_u_sphere=True, show=False, + title=str(epoch)) + fig.savefig(join(results_dir, 'samples', + f'{epoch:05}_{k}_reconstructed.png')) + plt.close(fig) + + if epoch % config['save_frequency'] == 0: + torch.save(G.state_dict(), join(weights_path, f'{epoch:05}_G.pth')) + torch.save(D.state_dict(), join(weights_path, f'{epoch:05}_D.pth')) + torch.save(E.state_dict(), join(weights_path, f'{epoch:05}_E.pth')) + + torch.save(EG_optim.state_dict(), + join(weights_path, f'{epoch:05}_EGo.pth')) + + torch.save(D_optim.state_dict(), + join(weights_path, f'{epoch:05}_Do.pth')) + + +if __name__ == '__main__': + logger = logging.getLogger() + + parser = argparse.ArgumentParser() + parser.add_argument('-c', '--config', default=None, type=str, + help='config file path') + args = parser.parse_args() + + config = None + if args.config is not None and args.config.endswith('.json'): + with open(args.config) as f: + config = json.load(f) + assert config is not None + + main(config) diff --git a/experiments/train_vae.py b/experiments/train_vae.py new file mode 100644 index 0000000..14079c1 --- /dev/null +++ b/experiments/train_vae.py @@ -0,0 +1,235 @@ +import argparse +import json +import logging +import random +from datetime import datetime +from importlib import import_module +from itertools import chain +from os.path import join, exists + +import matplotlib.pyplot as plt +import torch +import torch.backends.cudnn as cudnn +import torch.nn.parallel +import torch.optim as optim +import torch.utils.data +from torch.utils.data import DataLoader + +from utils.pcutil import plot_3d_point_cloud +from utils.util import find_latest_epoch, prepare_results_dir, cuda_setup, setup_logging + +cudnn.benchmark = True + + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + gain = torch.nn.init.calculate_gain('relu') + torch.nn.init.xavier_uniform_(m.weight, gain) + if m.bias is not None: + torch.nn.init.constant_(m.bias, 0) + elif 
classname.find('BatchNorm') != -1: + torch.nn.init.constant_(m.weight, 1) + torch.nn.init.constant_(m.bias, 0) + elif classname.find('Linear') != -1: + gain = torch.nn.init.calculate_gain('relu') + torch.nn.init.xavier_uniform_(m.weight, gain) + if m.bias is not None: + torch.nn.init.constant_(m.bias, 0) + + +def main(config): + random.seed(config['seed']) + torch.manual_seed(config['seed']) + torch.cuda.manual_seed_all(config['seed']) + + results_dir = prepare_results_dir(config) + starting_epoch = find_latest_epoch(results_dir) + 1 + + if not exists(join(results_dir, 'config.json')): + with open(join(results_dir, 'config.json'), mode='w') as f: + json.dump(config, f) + + setup_logging(results_dir) + log = logging.getLogger(__name__) + + device = cuda_setup(config['cuda'], config['gpu']) + log.debug(f'Device variable: {device}') + if device.type == 'cuda': + log.debug(f'Current CUDA device: {torch.cuda.current_device()}') + + weights_path = join(results_dir, 'weights') + + # + # Dataset + # + dataset_name = config['dataset'].lower() + if dataset_name == 'shapenet': + from datasets.shapenet import ShapeNetDataset + dataset = ShapeNetDataset(root_dir=config['data_dir'], + classes=config['classes']) + else: + raise ValueError(f'Invalid dataset name. Expected `shapenet` or ' + f'`faust`. Got: `{dataset_name}`') + + log.debug("Selected {} classes. Loaded {} samples.".format( + 'all' if not config['classes'] else ','.join(config['classes']), + len(dataset))) + + points_dataloader = DataLoader(dataset, batch_size=config['batch_size'], + shuffle=config['shuffle'], + num_workers=config['num_workers'], + drop_last=True, pin_memory=True) + + # + # Models + # + arch = import_module(f"models.{config['arch']}") + G = arch.Generator(config).to(device) + E = arch.Encoder(config).to(device) + + G.apply(weights_init) + E.apply(weights_init) + + if config['reconstruction_loss'].lower() == 'chamfer': + from losses.champfer_loss import ChamferLoss + reconstruction_loss = ChamferLoss().to(device) + elif config['reconstruction_loss'].lower() == 'earth_mover': + from losses.earth_mover_distance import EMD + reconstruction_loss = EMD().to(device) + else: + raise ValueError(f'Invalid reconstruction loss. 
Accepted `chamfer` or ' + f'`earth_mover`, got: {config["reconstruction_loss"]}') + # + # Float Tensors + # + fixed_noise = torch.FloatTensor(config['batch_size'], config['z_size'], 1) + fixed_noise.normal_(mean=0, std=0.2) + std_assumed = torch.tensor(0.2) + + fixed_noise = fixed_noise.to(device) + std_assumed = std_assumed.to(device) + + # + # Optimizers + # + EG_optim = getattr(optim, config['optimizer']['EG']['type']) + EG_optim = EG_optim(chain(E.parameters(), G.parameters()), + **config['optimizer']['EG']['hyperparams']) + + if starting_epoch > 1: + G.load_state_dict(torch.load( + join(weights_path, f'{starting_epoch-1:05}_G.pth'))) + E.load_state_dict(torch.load( + join(weights_path, f'{starting_epoch-1:05}_E.pth'))) + + EG_optim.load_state_dict(torch.load( + join(weights_path, f'{starting_epoch-1:05}_EGo.pth'))) + + for epoch in range(starting_epoch, config['max_epochs'] + 1): + start_epoch_time = datetime.now() + + G.train() + E.train() + + total_loss = 0.0 + for i, point_data in enumerate(points_dataloader, 1): + log.debug('-' * 20) + + X, _ = point_data + X = X.to(device) + + # Change dim [BATCH, N_POINTS, N_DIM] -> [BATCH, N_DIM, N_POINTS] + if X.size(-1) == 3: + X.transpose_(X.dim() - 2, X.dim() - 1) + + codes, mu, logvar = E(X) + X_rec = G(codes) + + loss_e = torch.mean( + config['reconstruction_coef'] * + reconstruction_loss(X.permute(0, 2, 1) + 0.5, + X_rec.permute(0, 2, 1) + 0.5)) + + loss_kld = -0.5 * torch.mean( + 1 - 2.0 * torch.log(std_assumed) + logvar - + (mu.pow(2) + logvar.exp()) / torch.pow(std_assumed, 2)) + + loss_eg = loss_e + loss_kld + EG_optim.zero_grad() + E.zero_grad() + G.zero_grad() + + loss_eg.backward() + total_loss += loss_eg.item() + EG_optim.step() + + log.debug(f'[{epoch}: ({i})] ' + f'Loss_EG: {loss_eg.item():.4f} ' + f'(REC: {loss_e.item(): .4f}' + f' KLD: {loss_kld.item(): .4f})' + f' Time: {datetime.now() - start_epoch_time}') + + log.debug( + f'[{epoch}/{config["max_epochs"]}] ' + f'Loss_G: {total_loss / i:.4f} ' + f'Time: {datetime.now() - start_epoch_time}' + ) + + # + # Save intermediate results + # + G.eval() + E.eval() + with torch.no_grad(): + fake = G(fixed_noise).data.cpu().numpy() + codes, _, _ = E(X) + X_rec = G(codes).data.cpu().numpy() + + for k in range(5): + fig = plot_3d_point_cloud(X[k][0], X[k][1], X[k][2], + in_u_sphere=True, show=False) + fig.savefig( + join(results_dir, 'samples', f'{epoch}_{k}_real.png')) + plt.close(fig) + + for k in range(5): + fig = plot_3d_point_cloud(fake[k][0], fake[k][1], fake[k][2], + in_u_sphere=True, show=False, + title=str(epoch)) + fig.savefig( + join(results_dir, 'samples', f'{epoch:05}_{k}_fixed.png')) + plt.close(fig) + + for k in range(5): + fig = plot_3d_point_cloud(X_rec[k][0], + X_rec[k][1], + X_rec[k][2], + in_u_sphere=True, show=False) + fig.savefig(join(results_dir, 'samples', + f'{epoch}_{k}_reconstructed.png')) + plt.close(fig) + + if epoch % config['save_frequency'] == 0: + torch.save(G.state_dict(), join(weights_path, f'{epoch:05}_G.pth')) + torch.save(E.state_dict(), join(weights_path, f'{epoch:05}_E.pth')) + + torch.save(EG_optim.state_dict(), + join(weights_path, f'{epoch:05}_EGo.pth')) + + +if __name__ == '__main__': + logger = logging.getLogger() + + parser = argparse.ArgumentParser() + parser.add_argument('-c', '--config', default=None, type=str, + help='config file path') + args = parser.parse_args() + + config = None + if args.config is not None and args.config.endswith('.json'): + with open(args.config) as f: + config = json.load(f) + assert config is not None + + 
main(config) diff --git a/losses/champfer_loss.py b/losses/champfer_loss.py new file mode 100644 index 0000000..b4e5305 --- /dev/null +++ b/losses/champfer_loss.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn + + +class ChamferLoss(nn.Module): + + def __init__(self): + super(ChamferLoss, self).__init__() + self.use_cuda = torch.cuda.is_available() + + def forward(self, preds, gts): + P = self.batch_pairwise_dist(gts, preds) + mins, _ = torch.min(P, 1) + loss_1 = torch.sum(mins) + mins, _ = torch.min(P, 2) + loss_2 = torch.sum(mins) + return loss_1 + loss_2 + + def batch_pairwise_dist(self, x, y): + bs, num_points_x, points_dim = x.size() + _, num_points_y, _ = y.size() + xx = torch.bmm(x, x.transpose(2, 1)) + yy = torch.bmm(y, y.transpose(2, 1)) + zz = torch.bmm(x, y.transpose(2, 1)) + if self.use_cuda: + dtype = torch.cuda.LongTensor + else: + dtype = torch.LongTensor + diag_ind_x = torch.arange(0, num_points_x).type(dtype) + diag_ind_y = torch.arange(0, num_points_y).type(dtype) + rx = xx[:, diag_ind_x, diag_ind_x].unsqueeze(1).expand_as( + zz.transpose(2, 1)) + ry = yy[:, diag_ind_y, diag_ind_y].unsqueeze(1).expand_as(zz) + P = rx.transpose(2, 1) + ry - 2 * zz + return P diff --git a/losses/earth_mover_distance.py b/losses/earth_mover_distance.py new file mode 100644 index 0000000..8c5bb3c --- /dev/null +++ b/losses/earth_mover_distance.py @@ -0,0 +1,385 @@ +import torch +from pyinn.utils import Stream, load_kernel +from torch.autograd import Function +from torch.nn.modules.module import Module + + +class EMD(Module): + def __init__(self): + super().__init__() + self.emd_function = EMDFunction.apply + + def forward(self, input1, input2): + return self.emd_function(input1, input2) + + +class EMDFunction(Function): + @staticmethod + def forward(ctx, xyz1, xyz2): + assert xyz1.dim() == 3 and xyz1.is_cuda and xyz2.is_cuda + assert xyz1.shape[-1] == 3 # as done by Panos + batch_size, num_pts, pt_dim = xyz1.size() + _, m, _ = xyz2.size() + + match = torch.zeros(batch_size, m, num_pts).cuda() + cost = torch.zeros(batch_size, ).cuda() + temp = torch.zeros(batch_size, 2 * (m + num_pts)).cuda() + + with torch.cuda.device_of(xyz1): + # 1) get matching + f = load_kernel('approxmatch', approxmatch_kernel) + f(block=(512, 1, 1), # (CUDA_NUM_THREADS,1,1), + grid=(32, 1, 1), # GET_BLOCKS(n),1,1), + args=[batch_size, num_pts, m, xyz1.data_ptr(), xyz2.data_ptr(), + match.data_ptr(), temp.data_ptr()], + stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) + + # 2) calculate matching cost + g = load_kernel('matchcost', matchcost_kernel) + g(block=(512, 1, 1), # (CUDA_NUM_THREADS, 1, 1), + grid=(32, 1, 1), # (GET_BLOCKS(n), 1, 1), + args=[batch_size, num_pts, m, xyz1.data_ptr(), xyz2.data_ptr(), + match.data_ptr(), cost.data_ptr()], + stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) + + ctx.save_for_backward(xyz1, xyz2, match) + del temp + + return cost + + @staticmethod + def backward(ctx, grad_cost): + xyz1, xyz2, match = ctx.saved_tensors + + batch_size, num_pts, _ = xyz1.size() + _, m, _ = xyz2.size() + + grad1 = torch.zeros_like(xyz1).cuda() + grad2 = torch.zeros_like(xyz2).cuda() + + with torch.cuda.device_of(grad_cost): + if xyz1.requires_grad: + f = load_kernel('matchcostgrad1', matchcostgrad1_kernel) + f(block=(512, 1, 1), # (CUDA_NUM_THREADS, 1, 1), + grid=(32, 1, 1), # (GET_BLOCKS(xyz1.numel()), 1, 1), + args=[batch_size, num_pts, m, xyz1.data_ptr(), + xyz2.data_ptr(), match.data_ptr(), grad1.data_ptr()], + stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) + + 
if xyz2.requires_grad: + g = load_kernel('matchcostgrad2', matchcostgrad2_kernel) + g(block=(256, 1, 1), # (CUDA_NUM_THREADS, 1, 1), + grid=(32, 32, 1), # (GET_BLOCKS(xyz2.numel()), 1, 1), + args=[batch_size, num_pts, m, xyz1.data_ptr(), + xyz2.data_ptr(), match.data_ptr(), grad2.data_ptr()], + stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) + + return grad1 * grad_cost.view(-1, 1, 1), grad2 * grad_cost.view(-1, 1, + 1) + + +approxmatch_kernel = ''' +extern "C" +__global__ void approxmatch(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,float * __restrict__ match,float * temp){ + float * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n; + float multiL,multiR; + if (n>=m){ + multiL=1; + multiR=n/m; + }else{ + multiL=m/n; + multiR=1; + } + const int Block=1024; + __shared__ float buf[Block*4]; + for (int i=blockIdx.x;i=-2;j--){ + float level=-powf(4.0f,j); + if (j==-2){ + level=0; + } + for (int k0=0;k0>>(b,n,m,xyz1,xyz2,match,temp); +# } + + + diff --git a/models/avae.py b/models/avae.py new file mode 100644 index 0000000..9b26604 --- /dev/null +++ b/models/avae.py @@ -0,0 +1,117 @@ +import torch +import torch.nn as nn + +_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + +class Generator(nn.Module): + def __init__(self, config): + super().__init__() + + self.z_size = config['z_size'] + self.use_bias = config['model']['G']['use_bias'] + self.relu_slope = config['model']['G']['relu_slope'] + self.model = nn.Sequential( + nn.Linear(in_features=self.z_size, out_features=64, bias=self.use_bias), + nn.ReLU(inplace=True), + + nn.Linear(in_features=64, out_features=128, bias=self.use_bias), + nn.ReLU(inplace=True), + + nn.Linear(in_features=128, out_features=512, bias=self.use_bias), + nn.ReLU(inplace=True), + + nn.Linear(in_features=512, out_features=1024, bias=self.use_bias), + nn.ReLU(inplace=True), + + nn.Linear(in_features=1024, out_features=2048 * 3, bias=self.use_bias), + ) + + def forward(self, input): + output = self.model(input.squeeze()) + output = output.view(-1, 3, 2048) + return output + + +class Discriminator(nn.Module): + def __init__(self, config): + super().__init__() + + self.z_size = config['z_size'] + self.use_bias = config['model']['D']['use_bias'] + self.relu_slope = config['model']['D']['relu_slope'] + self.dropout = config['model']['D']['dropout'] + + self.pc_discriminator_fc = nn.Sequential( + + nn.Linear(self.z_size, 512, bias=True), + nn.ReLU(inplace=True), + + nn.Linear(512, 512, bias=True), + nn.ReLU(inplace=True), + + nn.Linear(512, 128, bias=True), + nn.ReLU(inplace=True), + + nn.Linear(128, 64, bias=True), + nn.ReLU(inplace=True), + + nn.Linear(64, 1, bias=True) + ) + + def forward(self, x): + logit = self.pc_discriminator_fc(x) + return logit + + +class Encoder(nn.Module): + def __init__(self, config): + super().__init__() + + self.z_size = config['z_size'] + self.use_bias = config['model']['E']['use_bias'] + self.relu_slope = config['model']['E']['relu_slope'] + + self.pc_discriminator_conv = nn.Sequential( + nn.Conv1d(in_channels=3, out_channels=64, kernel_size=1, + bias=self.use_bias), + nn.ReLU(inplace=True), + + nn.Conv1d(in_channels=64, out_channels=128, kernel_size=1, + bias=self.use_bias), + nn.ReLU(inplace=True), + + nn.Conv1d(in_channels=128, out_channels=256, kernel_size=1, + bias=self.use_bias), + nn.ReLU(inplace=True), + + nn.Conv1d(in_channels=256, out_channels=256, kernel_size=1, + 
bias=self.use_bias), + nn.ReLU(inplace=True), + + nn.Conv1d(in_channels=256, out_channels=512, kernel_size=1, + bias=self.use_bias), + ) + + self.pc_discriminator_fc = nn.Sequential( + nn.Linear(512, 256, bias=True), + nn.ReLU(inplace=True) + ) + + self.mu_layer = nn.Linear(256, self.z_size, bias=True) + self.std_layer = nn.Linear(256, self.z_size, bias=True) + + def reparameterize(self, mu, logvar): + std = torch.exp(0.5*logvar) + eps = torch.randn_like(std) + return eps.mul(std).add_(mu) + + def forward(self, x): + output = self.pc_discriminator_conv(x) + output2 = output.max(dim=2)[0] + logit = self.pc_discriminator_fc(output2) + mu = self.mu_layer(logit) + logvar = self.std_layer(logit) + z = self.reparameterize(mu, logvar) + return z, mu, logvar + diff --git a/models/avae_bin.py b/models/avae_bin.py new file mode 100644 index 0000000..193b168 --- /dev/null +++ b/models/avae_bin.py @@ -0,0 +1,113 @@ +import torch +import torch.nn as nn + +_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + +class Generator(nn.Module): + def __init__(self, config): + super().__init__() + + self.z_size = config['z_size'] + use_bias = config['model']['G']['use_bias'] + self.sigmoid = nn.Sigmoid() + self.model = nn.Sequential( + nn.Linear(in_features=self.z_size, out_features=64, bias=use_bias), + nn.ReLU(inplace=True), + + nn.Linear(in_features=64, out_features=128, bias=use_bias), + nn.ReLU(inplace=True), + + nn.Linear(in_features=128, out_features=512, bias=use_bias), + nn.ReLU(inplace=True), + + nn.Linear(in_features=512, out_features=1024, bias=use_bias), + nn.ReLU(inplace=True), + + nn.Linear(in_features=1024, out_features=2048 * 3, bias=use_bias), + ) + + # self.model = nn.DataParallel(self.model) + + def forward(self, input): + output = self.model(input.squeeze()) + output = output.view(-1, 3, 2048) + return output + + +class Discriminator(nn.Module): + def __init__(self, config): + super().__init__() + + self.z_size = config['z_size'] + self.use_bias = config['model']['D']['use_bias'] + self.relu_slope = config['model']['D']['relu_slope'] + self.dropout = config['model']['D']['dropout'] + + self.model = nn.Sequential( + + nn.Linear(self.z_size, 512, bias=True), + nn.ReLU(inplace=True), + + nn.Linear(512, 512, bias=True), + nn.ReLU(inplace=True), + + nn.Linear(512, 128, bias=True), + nn.ReLU(inplace=True), + + nn.Linear(128, 64, bias=True), + nn.ReLU(inplace=True), + + nn.Linear(64, 1, bias=True) + ) + + # self.model = nn.DataParallel(self.model) + + def forward(self, x): + logit = self.model(x) + return logit + + +class Encoder(nn.Module): + def __init__(self, config): + super().__init__() + + self.z_size = config['z_size'] + self.use_bias = config['model']['E']['use_bias'] + self.relu_slope = config['model']['E']['relu_slope'] + + self.pc_encoder_conv = nn.Sequential( + nn.Conv1d(in_channels=3, out_channels=64, kernel_size=1, + bias=self.use_bias), + nn.ReLU(inplace=True), + + nn.Conv1d(in_channels=64, out_channels=128, kernel_size=1, + bias=self.use_bias), + nn.ReLU(inplace=True), + + nn.Conv1d(in_channels=128, out_channels=256, kernel_size=1, + bias=self.use_bias), + nn.ReLU(inplace=True), + + nn.Conv1d(in_channels=256, out_channels=256, kernel_size=1, + bias=self.use_bias), + nn.ReLU(inplace=True), + + nn.Conv1d(in_channels=256, out_channels=512, kernel_size=1, + bias=self.use_bias), + ) + + self.pc_encoder_fc = nn.Sequential( + nn.Linear(512, 256, bias=True), + nn.ReLU(inplace=True), + nn.Linear(256, self.z_size, bias=True) + ) + + self.sigmoid = nn.Sigmoid() + + def 
forward(self, x): + output = self.pc_encoder_conv(x) + output2 = output.max(dim=2)[0] + logit = self.pc_encoder_fc(output2) + z = self.sigmoid(logit) + return z diff --git a/models/avae_triplet.py b/models/avae_triplet.py new file mode 100644 index 0000000..794648f --- /dev/null +++ b/models/avae_triplet.py @@ -0,0 +1,48 @@ +import torch +import torch.nn as nn + +_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + +class Encoder(nn.Module): + def __init__(self, config): + super().__init__() + + self.z_size = config['z_size'] + self.use_bias = config['model']['E']['use_bias'] + self.relu_slope = config['model']['E']['relu_slope'] + + self.pc_encoder_conv = nn.Sequential( + nn.Conv1d(in_channels=3, out_channels=64, kernel_size=1, + bias=self.use_bias), + nn.ReLU(inplace=True), + + nn.Conv1d(in_channels=64, out_channels=128, kernel_size=1, + bias=self.use_bias), + nn.ReLU(inplace=True), + + nn.Conv1d(in_channels=128, out_channels=256, kernel_size=1, + bias=self.use_bias), + nn.ReLU(inplace=True), + + nn.Conv1d(in_channels=256, out_channels=256, kernel_size=1, + bias=self.use_bias), + nn.ReLU(inplace=True), + + nn.Conv1d(in_channels=256, out_channels=512, kernel_size=1, + bias=self.use_bias), + ) + + self.pc_encoder_fc = nn.Sequential( + nn.Linear(512, 256, bias=True), + nn.ReLU(inplace=True), + nn.Linear(256, self.z_size, bias=True) + ) + + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + output = self.pc_encoder_conv(x) + output2 = output.max(dim=2)[0] + logit = self.pc_encoder_fc(output2) + return logit diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..e0a4503 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +h5py +matplotlib +numpy +pandas +git+https://github.com/szagoruyko/pyinn.git@master +torch==0.4.1 \ No newline at end of file diff --git a/settings/hyperparams.json b/settings/hyperparams.json new file mode 100644 index 0000000..8c455f1 --- /dev/null +++ b/settings/hyperparams.json @@ -0,0 +1,76 @@ +{ + "experiment_name": "experiment", + "results_root": "/Users/maciek/test-results", + "clean_results_dir": false, + + "cuda": true, + "gpu": 3, + + "reconstruction_loss": "chamfer", + + "metrics": [ + ], + + "dataset": "shapenet", + "data_dir": "/Users/maciek/shapenet", + "classes": [], + "shuffle": true, + "transforms": ["rotate"], + "num_workers": 8, + "n_points": 2048, + + "max_epochs": 2000, + "batch_size": 50, + "gp_lambda": 10, + "reconstruction_coef": 0.05, + "z_size": 2048, + "distribution": "bernoulli", + + "p": 0.2, + "z_beta_a": 0.01, + "z_beta_b": 0.01, + + "normal_mu": 0.0, + "normal_std": 0.2, + + "seed": 2018, + "save_frequency": 5, + "epsilon": 0.001, + + "arch": "avae", + "model": { + "D": { + "dropout": 0.5, + "use_bias": true, + "relu_slope": 0.2 + }, + "G": { + "use_bias": true, + "relu_slope": 0.2 + }, + "E": { + "use_bias": true, + "relu_slope": 0.2 + } + }, + "optimizer": { + "D": { + "type": "Adam", + "hyperparams": { + "lr": 0.0005, + "weight_decay": 0, + "betas": [0.9, 0.999], + "amsgrad": false + } + }, + "EG": { + "type": "Adam", + "hyperparams": { + "lr": 0.0005, + "weight_decay": 0, + "betas": [0.9, 0.999], + "amsgrad": false + } + } + } +} \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/data.py b/utils/data.py new file mode 100644 index 0000000..cda704b --- /dev/null +++ b/utils/data.py @@ -0,0 +1,140 @@ +import math +import numpy as np +import os +import pandas as pd +import pickle + +from decimal 
import Decimal +from itertools import accumulate, tee, chain +from typing import List, Tuple, Dict, Optional, Any, Set + +from utils.plyfile import load_ply + +READERS = { + '.ply': load_ply, + '.np': lambda file_path: pickle.load(open(file_path, 'rb')), +} + + +def load_file(file_path): + _, ext = os.path.splitext(file_path) + return READERS[ext](file_path) + + +def add_float(a, b): + return float(Decimal(str(a)) + Decimal(str(b))) + + +def ranges(values: List[float]) -> List[Tuple[float]]: + lower, upper = tee(accumulate(values, add_float)) + lower = chain([0], lower) + + return zip(lower, upper) + + +def make_slices(values: List[float], N: int): + slices = [slice(int(N * s), int(N * e)) for s, e in ranges(values)] + return slices + + +def make_splits( + data: pd.DataFrame, + splits: Dict[str, float], + seed: Optional[int] = None): + + # assert correctness + if not math.isclose(sum(splits.values()), 1.0): + values = " ".join([f"{k} : {v}" for k, v in splits.items()]) + raise ValueError(f"{values} should sum up to 1") + + # shuffle with random seed + data = data.iloc[np.random.permutation(len(data))] + slices = make_slices(list(splits.values()), len(data)) + + return { + name: data[idxs].reset_index(drop=True) for name, idxs in zip(splits.keys(), slices) + } + + +def sample_other_than(black_list: Set[int], x: np.ndarray) -> int: + res = np.random.randint(0, len(x)) + while res in black_list: + res = np.random.randint(0, len(x)) + + return res + + +def clip_cloud(p: np.ndarray) -> np.ndarray: + # create list of extreme points + black_list = set(np.hstack([ + np.argmax(p, axis=0), np.argmin(p, axis=0) + ])) + + # swap any other point + for idx in black_list: + p[idx] = p[sample_other_than(black_list, p)] + + return p + + +def find_extrema(xs, n_cols: int=3, clip: bool=True) -> Dict[Any, List[float]]: + from collections import defaultdict + + mins = defaultdict(lambda: [np.inf for _ in range(n_cols)]) + maxs = defaultdict(lambda: [-np.inf for _ in range(n_cols)]) + + for x, c in xs: + x = clip_cloud(x) if clip else x + mins[c] = [min(old, new) for old, new in zip(mins[c], np.min(x, axis=0))] + maxs[c] = [max(old, new) for old, new in zip(maxs[c], np.max(x, axis=0))] + + return mins, maxs + + +def merge_dicts( + dict_old: Dict[Any, List[float]], + dict_new: Dict[Any, List[float]], op=min) -> Dict[Any, List[float]]: + ''' + Simply takes values on List of floats for given key + ''' + d_out = {** dict_old} + for k, v in dict_new.items(): + if k in dict_old: + d_out[k] = [op(new, old) for old, new in zip(dict_new[k], dict_old[k])] + else: + d_out[k] = dict_new[k] + + return d_out + + +def save_extrema(clazz, root_dir, splits=('train', 'test', 'valid')): + ''' + Maybe this should be class dependent normalization? 
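+    Collects per-class coordinate extrema over the given splits: each split is
+    loaded via clazz(root_dir=root_dir, split=split, remap=False), per-class
+    minima and maxima are merged across splits with merge_dicts, converted to
+    numpy arrays, and pickled to root_dir/extrema.np as (min_dict, max_dict).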
+ ''' + min_dict, max_dict = {}, {} + for split in splits: + data = clazz(root_dir=root_dir, split=split, remap=False) + mins, maxs = find_extrema(data) + min_dict = merge_dicts(min_dict, mins, min) + max_dict = merge_dicts(max_dict, maxs, max) + + # vectorzie values + for d in (min_dict, max_dict): + for k in d: + d[k] = np.array(d[k]) + + with open(os.path.join(root_dir, 'extrema.np'), 'wb') as f: + pickle.dump((min_dict, max_dict), f) + + +def remap(old_value: np.ndarray, + old_min: np.ndarray, old_max: np.ndarray, + new_min: float = -0.5, new_max: float = 0.5) -> np.ndarray: + ''' + Remap reange + ''' + old_range = (old_max - old_min) + new_range = (new_max - new_min) + new_value = (((old_value - old_min) * new_range) / old_range) + new_min + + return new_value diff --git a/utils/pcutil.py b/utils/pcutil.py new file mode 100644 index 0000000..908192d --- /dev/null +++ b/utils/pcutil.py @@ -0,0 +1,160 @@ +import matplotlib.pyplot as plt +import numpy as np +from numpy.linalg import norm + +# Don't delete this line, even if PyCharm says it's an unused import. +# It is required for projection='3d' in add_subplot() +from mpl_toolkits.mplot3d import Axes3D + + +def rand_rotation_matrix(deflection=1.0, seed=None): + """Creates a random rotation matrix. + + Args: + deflection: the magnitude of the rotation. For 0, no rotation; for 1, + completely random rotation. Small deflection => small + perturbation. + + DOI: http://www.realtimerendering.com/resources/GraphicsGems/gemsiii/rand_rotation.c + http://blog.lostinmyterminal.com/python/2015/05/12/random-rotation-matrix.html + """ + if seed is not None: + np.random.seed(seed) + + theta, phi, z = np.random.uniform(size=(3,)) + + theta = theta * 2.0 * deflection * np.pi # Rotation about the pole (Z). + phi = phi * 2.0 * np.pi # For direction of pole deflection. + z = z * 2.0 * deflection # For magnitude of pole deflection. + + # Compute a vector V used for distributing points over the sphere + # via the reflection I - V Transpose(V). This formulation of V + # will guarantee that if x[1] and x[2] are uniformly distributed, + # the reflected points will be uniform on the sphere. Note that V + # has length sqrt(2) to eliminate the 2 in the Householder matrix. + + r = np.sqrt(z) + V = (np.sin(phi) * r, + np.cos(phi) * r, + np.sqrt(2.0 - z)) + + st = np.sin(theta) + ct = np.cos(theta) + + R = np.array(((ct, st, 0), (-st, ct, 0), (0, 0, 1))) + + # Construct the rotation matrix ( V Transpose(V) - I ) R. 
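+    # With V of length sqrt(2), (V V^T - I) is minus the Householder reflection
+    # about V, which in 3D is itself a rotation; composing it with the random
+    # Z-rotation R gives the deflection-controlled random rotation from the
+    # Graphics Gems III rand_rotation.c routine cited in the docstring above.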
+ M = (np.outer(V, V) - np.eye(3)).dot(R) + return M + + +def add_gaussian_noise_to_pcloud(pcloud, mu=0, sigma=1): + gnoise = np.random.normal(mu, sigma, pcloud.shape[0]) + gnoise = np.tile(gnoise, (3, 1)).T + pcloud += gnoise + return pcloud + + +def add_rotation_to_pcloud(pcloud): + r_rotation = rand_rotation_matrix() + + if len(pcloud.shape) == 2: + return pcloud.dot(r_rotation) + else: + return np.asarray([e.dot(r_rotation) for e in pcloud]) + + +def apply_augmentations(batch, conf): + if conf.gauss_augment is not None or conf.z_rotate: + batch = batch.copy() + + if conf.gauss_augment is not None: + mu = conf.gauss_augment['mu'] + sigma = conf.gauss_augment['sigma'] + batch += np.random.normal(mu, sigma, batch.shape) + + if conf.z_rotate: + r_rotation = rand_rotation_matrix() + r_rotation[0, 2] = 0 + r_rotation[2, 0] = 0 + r_rotation[1, 2] = 0 + r_rotation[2, 1] = 0 + r_rotation[2, 2] = 1 + batch = batch.dot(r_rotation) + return batch + + +def unit_cube_grid_point_cloud(resolution, clip_sphere=False): + """Returns the center coordinates of each cell of a 3D grid with + resolution^3 cells, that is placed in the unit-cube. + If clip_sphere it True it drops the "corner" cells that lie outside + the unit-sphere. + """ + grid = np.ndarray((resolution, resolution, resolution, 3), np.float32) + spacing = 1.0 / float(resolution - 1) + for i in range(resolution): + for j in range(resolution): + for k in range(resolution): + grid[i, j, k, 0] = i * spacing - 0.5 + grid[i, j, k, 1] = j * spacing - 0.5 + grid[i, j, k, 2] = k * spacing - 0.5 + + if clip_sphere: + grid = grid.reshape(-1, 3) + grid = grid[norm(grid, axis=1) <= 0.5] + + return grid, spacing + + +def plot_3d_point_cloud(x, y, z, show=True, show_axis=True, in_u_sphere=False, + marker='.', s=8, alpha=.8, figsize=(5, 5), elev=10, + azim=240, axis=None, title=None, *args, **kwargs): + plt.switch_backend('agg') + if axis is None: + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111, projection='3d') + else: + ax = axis + fig = axis + + if title is not None: + plt.title(title) + + sc = ax.scatter(x, y, z, marker=marker, s=s, alpha=alpha, *args, **kwargs) + ax.view_init(elev=elev, azim=azim) + + if in_u_sphere: + ax.set_xlim3d(-0.5, 0.5) + ax.set_ylim3d(-0.5, 0.5) + ax.set_zlim3d(-0.5, 0.5) + else: + # Multiply with 0.7 to squeeze free-space. + miv = 0.7 * np.min([np.min(x), np.min(y), np.min(z)]) + mav = 0.7 * np.max([np.max(x), np.max(y), np.max(z)]) + ax.set_xlim(miv, mav) + ax.set_ylim(miv, mav) + ax.set_zlim(miv, mav) + plt.tight_layout() + + if not show_axis: + plt.axis('off') + + if 'c' in kwargs: + plt.colorbar(sc) + + if show: + plt.show() + + return fig + + +def transform_point_clouds(X, only_z_rotation=False, deflection=1.0): + r_rotation = rand_rotation_matrix(deflection) + if only_z_rotation: + r_rotation[0, 2] = 0 + r_rotation[2, 0] = 0 + r_rotation[1, 2] = 0 + r_rotation[2, 1] = 0 + r_rotation[2, 2] = 1 + X = X.dot(r_rotation).astype(np.float32) + return X diff --git a/utils/plyfile.py b/utils/plyfile.py new file mode 100755 index 0000000..21fb641 --- /dev/null +++ b/utils/plyfile.py @@ -0,0 +1,941 @@ +# Copyright 2014 Darsh Ranjan +# +# This file is part of python-plyfile. +# +# python-plyfile is free software: you can redistribute it and/or +# modify it under the terms of the GNU General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. 
+# +# python-plyfile is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with python-plyfile. If not, see +# . + +from itertools import islice as _islice + +import numpy as _np +from sys import byteorder as _byteorder + + +try: + _range = xrange +except NameError: + _range = range + + +# Many-many relation +_data_type_relation = [ + ('int8', 'i1'), + ('char', 'i1'), + ('uint8', 'u1'), + ('uchar', 'b1'), + ('uchar', 'u1'), + ('int16', 'i2'), + ('short', 'i2'), + ('uint16', 'u2'), + ('ushort', 'u2'), + ('int32', 'i4'), + ('int', 'i4'), + ('uint32', 'u4'), + ('uint', 'u4'), + ('float32', 'f4'), + ('float', 'f4'), + ('float64', 'f8'), + ('double', 'f8') +] + +_data_types = dict(_data_type_relation) +_data_type_reverse = dict((b, a) for (a, b) in _data_type_relation) + +_types_list = [] +_types_set = set() +for (_a, _b) in _data_type_relation: + if _a not in _types_set: + _types_list.append(_a) + _types_set.add(_a) + if _b not in _types_set: + _types_list.append(_b) + _types_set.add(_b) + + +_byte_order_map = { + 'ascii': '=', + 'binary_little_endian': '<', + 'binary_big_endian': '>' +} + +_byte_order_reverse = { + '<': 'binary_little_endian', + '>': 'binary_big_endian' +} + +_native_byte_order = {'little': '<', 'big': '>'}[_byteorder] + + +def _lookup_type(type_str): + if type_str not in _data_type_reverse: + try: + type_str = _data_types[type_str] + except KeyError: + raise ValueError("field type %r not in %r" % + (type_str, _types_list)) + + return _data_type_reverse[type_str] + + +def _split_line(line, n): + fields = line.split(None, n) + if len(fields) == n: + fields.append('') + + assert len(fields) == n + 1 + + return fields + + +def make2d(array, cols=None, dtype=None): + """ + Make a 2D array from an array of arrays. The `cols' and `dtype' + arguments can be omitted if the array is not empty. + + """ + if (cols is None or dtype is None) and not len(array): + raise RuntimeError("cols and dtype must be specified for empty " + "array") + + if cols is None: + cols = len(array[0]) + + if dtype is None: + dtype = array[0].dtype + + return _np.fromiter(array, [('_', dtype, (cols,))], + count=len(array))['_'] + + +class PlyParseError(Exception): + + """ + Raised when a PLY file cannot be parsed. + + The attributes `element', `row', `property', and `message' give + additional information. + + """ + + def __init__(self, message, element=None, row=None, prop=None): + self.message = message + self.element = element + self.row = row + self.prop = prop + + s = '' + if self.element: + s += 'element %r: ' % self.element.name + if self.row is not None: + s += 'row %d: ' % self.row + if self.prop: + s += 'property %r: ' % self.prop.name + s += self.message + + Exception.__init__(self, s) + + def __repr__(self): + return ('PlyParseError(%r, element=%r, row=%r, prop=%r)' % + self.message, self.element, self.row, self.prop) + + +class PlyData(object): + + """ + PLY file header and data. + + A PlyData instance is created in one of two ways: by the static + method PlyData.read (to read a PLY file), or directly from __init__ + given a sequence of elements (which can then be written to a PLY + file). + + """ + + def __init__(self, elements=[], text=False, byte_order='=', + comments=[], obj_info=[]): + """ + elements: sequence of PlyElement instances. 
+ + text: whether the resulting PLY file will be text (True) or + binary (False). + + byte_order: '<' for little-endian, '>' for big-endian, or '=' + for native. This is only relevant if `text' is False. + + comments: sequence of strings that will be placed in the header + between the 'ply' and 'format ...' lines. + + obj_info: like comments, but will be placed in the header with + "obj_info ..." instead of "comment ...". + + """ + if byte_order == '=' and not text: + byte_order = _native_byte_order + + self.byte_order = byte_order + self.text = text + + self.comments = list(comments) + self.obj_info = list(obj_info) + self.elements = elements + + def _get_elements(self): + return self._elements + + def _set_elements(self, elements): + self._elements = tuple(elements) + self._index() + + elements = property(_get_elements, _set_elements) + + def _get_byte_order(self): + return self._byte_order + + def _set_byte_order(self, byte_order): + if byte_order not in ['<', '>', '=']: + raise ValueError("byte order must be '<', '>', or '='") + + self._byte_order = byte_order + + byte_order = property(_get_byte_order, _set_byte_order) + + def _index(self): + self._element_lookup = dict((elt.name, elt) for elt in + self._elements) + if len(self._element_lookup) != len(self._elements): + raise ValueError("two elements with same name") + + @staticmethod + def _parse_header(stream): + """ + Parse a PLY header from a readable file-like stream. + + """ + lines = [] + comments = {'comment': [], 'obj_info': []} + while True: + line = stream.readline().decode('ascii').strip() + fields = _split_line(line, 1) + + if fields[0] == 'end_header': + break + + elif fields[0] in comments.keys(): + lines.append(fields) + else: + lines.append(line.split()) + + a = 0 + if lines[a] != ['ply']: + raise PlyParseError("expected 'ply'") + + a += 1 + while lines[a][0] in comments.keys(): + comments[lines[a][0]].append(lines[a][1]) + a += 1 + + if lines[a][0] != 'format': + raise PlyParseError("expected 'format'") + + if lines[a][2] != '1.0': + raise PlyParseError("expected version '1.0'") + + if len(lines[a]) != 3: + raise PlyParseError("too many fields after 'format'") + + fmt = lines[a][1] + + if fmt not in _byte_order_map: + raise PlyParseError("don't understand format %r" % fmt) + + byte_order = _byte_order_map[fmt] + text = fmt == 'ascii' + + a += 1 + while a < len(lines) and lines[a][0] in comments.keys(): + comments[lines[a][0]].append(lines[a][1]) + a += 1 + + return PlyData(PlyElement._parse_multi(lines[a:]), + text, byte_order, + comments['comment'], comments['obj_info']) + + @staticmethod + def read(stream): + """ + Read PLY data from a readable file-like object or filename. + + """ + (must_close, stream) = _open_stream(stream, 'read') + try: + data = PlyData._parse_header(stream) + for elt in data: + elt._read(stream, data.text, data.byte_order) + finally: + if must_close: + stream.close() + + return data + + def write(self, stream): + """ + Write PLY data to a writeable file-like object or filename. + + """ + (must_close, stream) = _open_stream(stream, 'write') + try: + stream.write(self.header.encode('ascii')) + stream.write(b'\r\n') + for elt in self: + elt._write(stream, self.text, self.byte_order) + finally: + if must_close: + stream.close() + + @property + def header(self): + """ + Provide PLY-formatted metadata for the instance. 
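+ 
+ For example, a text-format instance with a single 'vertex' element of
+ 8 rows and float properties x, y and z renders (with no comments) as:
+ 
+     ply
+     format ascii 1.0
+     element vertex 8
+     property float x
+     property float y
+     property float z
+     end_header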
+ + """ + lines = ['ply'] + + if self.text: + lines.append('format ascii 1.0') + else: + lines.append('format ' + + _byte_order_reverse[self.byte_order] + + ' 1.0') + + # Some information is lost here, since all comments are placed + # between the 'format' line and the first element. + for c in self.comments: + lines.append('comment ' + c) + + for c in self.obj_info: + lines.append('obj_info ' + c) + + lines.extend(elt.header for elt in self.elements) + lines.append('end_header') + return '\r\n'.join(lines) + + def __iter__(self): + return iter(self.elements) + + def __len__(self): + return len(self.elements) + + def __contains__(self, name): + return name in self._element_lookup + + def __getitem__(self, name): + return self._element_lookup[name] + + def __str__(self): + return self.header + + def __repr__(self): + return ('PlyData(%r, text=%r, byte_order=%r, ' + 'comments=%r, obj_info=%r)' % + (self.elements, self.text, self.byte_order, + self.comments, self.obj_info)) + + +def _open_stream(stream, read_or_write): + if hasattr(stream, read_or_write): + return (False, stream) + try: + return (True, open(stream, read_or_write[0] + 'b')) + except TypeError: + raise RuntimeError("expected open file or filename") + + +class PlyElement(object): + + """ + PLY file element. + + A client of this library doesn't normally need to instantiate this + directly, so the following is only for the sake of documenting the + internals. + + Creating a PlyElement instance is generally done in one of two ways: + as a byproduct of PlyData.read (when reading a PLY file) and by + PlyElement.describe (before writing a PLY file). + + """ + + def __init__(self, name, properties, count, comments=[]): + """ + This is not part of the public interface. The preferred methods + of obtaining PlyElement instances are PlyData.read (to read from + a file) and PlyElement.describe (to construct from a numpy + array). + + """ + self._name = str(name) + self._check_name() + self._count = count + + self._properties = tuple(properties) + self._index() + + self.comments = list(comments) + + self._have_list = any(isinstance(p, PlyListProperty) + for p in self.properties) + + @property + def count(self): + return self._count + + def _get_data(self): + return self._data + + def _set_data(self, data): + self._data = data + self._count = len(data) + self._check_sanity() + + data = property(_get_data, _set_data) + + def _check_sanity(self): + for prop in self.properties: + if prop.name not in self._data.dtype.fields: + raise ValueError("dangling property %r" % prop.name) + + def _get_properties(self): + return self._properties + + def _set_properties(self, properties): + self._properties = tuple(properties) + self._check_sanity() + self._index() + + properties = property(_get_properties, _set_properties) + + def _index(self): + self._property_lookup = dict((prop.name, prop) + for prop in self._properties) + if len(self._property_lookup) != len(self._properties): + raise ValueError("two properties with same name") + + def ply_property(self, name): + return self._property_lookup[name] + + @property + def name(self): + return self._name + + def _check_name(self): + if any(c.isspace() for c in self._name): + msg = "element name %r contains spaces" % self._name + raise ValueError(msg) + + def dtype(self, byte_order='='): + """ + Return the numpy dtype of the in-memory representation of the + data. (If there are no list properties, and the PLY format is + binary, then this also accurately describes the on-disk + representation of the element.) 
+ + """ + return [(prop.name, prop.dtype(byte_order)) + for prop in self.properties] + + @staticmethod + def _parse_multi(header_lines): + """ + Parse a list of PLY element definitions. + + """ + elements = [] + while header_lines: + (elt, header_lines) = PlyElement._parse_one(header_lines) + elements.append(elt) + + return elements + + @staticmethod + def _parse_one(lines): + """ + Consume one element definition. The unconsumed input is + returned along with a PlyElement instance. + + """ + a = 0 + line = lines[a] + + if line[0] != 'element': + raise PlyParseError("expected 'element'") + if len(line) > 3: + raise PlyParseError("too many fields after 'element'") + if len(line) < 3: + raise PlyParseError("too few fields after 'element'") + + (name, count) = (line[1], int(line[2])) + + comments = [] + properties = [] + while True: + a += 1 + if a >= len(lines): + break + + if lines[a][0] == 'comment': + comments.append(lines[a][1]) + elif lines[a][0] == 'property': + properties.append(PlyProperty._parse_one(lines[a])) + else: + break + + return (PlyElement(name, properties, count, comments), + lines[a:]) + + @staticmethod + def describe(data, name, len_types={}, val_types={}, + comments=[]): + """ + Construct a PlyElement from an array's metadata. + + len_types and val_types can be given as mappings from list + property names to type strings (like 'u1', 'f4', etc., or + 'int8', 'float32', etc.). These can be used to define the length + and value types of list properties. List property lengths + always default to type 'u1' (8-bit unsigned integer), and value + types default to 'i4' (32-bit integer). + + """ + if not isinstance(data, _np.ndarray): + raise TypeError("only numpy arrays are supported") + + if len(data.shape) != 1: + raise ValueError("only one-dimensional arrays are " + "supported") + + count = len(data) + + properties = [] + descr = data.dtype.descr + + for t in descr: + if not isinstance(t[1], str): + raise ValueError("nested records not supported") + + if not t[0]: + raise ValueError("field with empty name") + + if len(t) != 2 or t[1][1] == 'O': + # non-scalar field, which corresponds to a list + # property in PLY. + + if t[1][1] == 'O': + if len(t) != 2: + raise ValueError("non-scalar object fields not " + "supported") + + len_str = _data_type_reverse[len_types.get(t[0], 'u1')] + if t[1][1] == 'O': + val_type = val_types.get(t[0], 'i4') + val_str = _lookup_type(val_type) + else: + val_str = _lookup_type(t[1][1:]) + + prop = PlyListProperty(t[0], len_str, val_str) + else: + val_str = _lookup_type(t[1][1:]) + prop = PlyProperty(t[0], val_str) + + properties.append(prop) + + elt = PlyElement(name, properties, count, comments) + elt.data = data + + return elt + + def _read(self, stream, text, byte_order): + """ + Read the actual data from a PLY file. + + """ + if text: + self._read_txt(stream) + else: + if self._have_list: + # There are list properties, so a simple load is + # impossible. + self._read_bin(stream, byte_order) + else: + # There are no list properties, so loading the data is + # much more straightforward. + self._data = _np.fromfile(stream, + self.dtype(byte_order), + self.count) + + if len(self._data) < self.count: + k = len(self._data) + del self._data + raise PlyParseError("early end-of-file", self, k) + + self._check_sanity() + + def _write(self, stream, text, byte_order): + """ + Write the data to a PLY file. + + """ + if text: + self._write_txt(stream) + else: + if self._have_list: + # There are list properties, so serialization is + # slightly complicated. 
+ self._write_bin(stream, byte_order) + else: + # no list properties, so serialization is + # straightforward. + self.data.astype(self.dtype(byte_order), + copy=False).tofile(stream) + + def _read_txt(self, stream): + """ + Load a PLY element from an ASCII-format PLY file. The element + may contain list properties. + + """ + self._data = _np.empty(self.count, dtype=self.dtype()) + + k = 0 + for line in _islice(iter(stream.readline, b''), self.count): + fields = iter(line.strip().split()) + for prop in self.properties: + try: + self._data[prop.name][k] = prop._from_fields(fields) + except StopIteration: + raise PlyParseError("early end-of-line", + self, k, prop) + except ValueError: + raise PlyParseError("malformed input", + self, k, prop) + try: + next(fields) + except StopIteration: + pass + else: + raise PlyParseError("expected end-of-line", self, k) + k += 1 + + if k < self.count: + del self._data + raise PlyParseError("early end-of-file", self, k) + + def _write_txt(self, stream): + """ + Save a PLY element to an ASCII-format PLY file. The element may + contain list properties. + + """ + for rec in self.data: + fields = [] + for prop in self.properties: + fields.extend(prop._to_fields(rec[prop.name])) + + _np.savetxt(stream, [fields], '%.18g', newline='\r\n') + + def _read_bin(self, stream, byte_order): + """ + Load a PLY element from a binary PLY file. The element may + contain list properties. + + """ + self._data = _np.empty(self.count, dtype=self.dtype(byte_order)) + + for k in _range(self.count): + for prop in self.properties: + try: + self._data[prop.name][k] = \ + prop._read_bin(stream, byte_order) + except StopIteration: + raise PlyParseError("early end-of-file", + self, k, prop) + + def _write_bin(self, stream, byte_order): + """ + Save a PLY element to a binary PLY file. The element may + contain list properties. + + """ + for rec in self.data: + for prop in self.properties: + prop._write_bin(rec[prop.name], stream, byte_order) + + @property + def header(self): + """ + Format this element's metadata as it would appear in a PLY + header. + + """ + lines = ['element %s %d' % (self.name, self.count)] + + # Some information is lost here, since all comments are placed + # between the 'element' line and the first property definition. + for c in self.comments: + lines.append('comment ' + c) + + lines.extend(list(map(str, self.properties))) + + return '\r\n'.join(lines) + + def __getitem__(self, key): + return self.data[key] + + def __setitem__(self, key, value): + self.data[key] = value + + def __str__(self): + return self.header + + def __repr__(self): + return ('PlyElement(%r, %r, count=%d, comments=%r)' % + (self.name, self.properties, self.count, + self.comments)) + + +class PlyProperty(object): + + """ + PLY property description. This class is pure metadata; the data + itself is contained in PlyElement instances. 
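+ 
+ For instance, PlyProperty('x', 'float32') describes a scalar float
+ property named 'x' and renders in a header as "property float x".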
+ + """ + + def __init__(self, name, val_dtype): + self._name = str(name) + self._check_name() + self.val_dtype = val_dtype + + def _get_val_dtype(self): + return self._val_dtype + + def _set_val_dtype(self, val_dtype): + self._val_dtype = _data_types[_lookup_type(val_dtype)] + + val_dtype = property(_get_val_dtype, _set_val_dtype) + + @property + def name(self): + return self._name + + def _check_name(self): + if any(c.isspace() for c in self._name): + msg = "Error: property name %r contains spaces" % self._name + raise RuntimeError(msg) + + @staticmethod + def _parse_one(line): + assert line[0] == 'property' + + if line[1] == 'list': + if len(line) > 5: + raise PlyParseError("too many fields after " + "'property list'") + if len(line) < 5: + raise PlyParseError("too few fields after " + "'property list'") + + return PlyListProperty(line[4], line[2], line[3]) + + else: + if len(line) > 3: + raise PlyParseError("too many fields after " + "'property'") + if len(line) < 3: + raise PlyParseError("too few fields after " + "'property'") + + return PlyProperty(line[2], line[1]) + + def dtype(self, byte_order='='): + """ + Return the numpy dtype description for this property (as a tuple + of strings). + + """ + return byte_order + self.val_dtype + + def _from_fields(self, fields): + """ + Parse from generator. Raise StopIteration if the property could + not be read. + + """ + return _np.dtype(self.dtype()).type(next(fields)) + + def _to_fields(self, data): + """ + Return generator over one item. + + """ + yield _np.dtype(self.dtype()).type(data) + + def _read_bin(self, stream, byte_order): + """ + Read data from a binary stream. Raise StopIteration if the + property could not be read. + + """ + try: + return _np.fromfile(stream, self.dtype(byte_order), 1)[0] + except IndexError: + raise StopIteration + + def _write_bin(self, data, stream, byte_order): + """ + Write data to a binary stream. + + """ + _np.dtype(self.dtype(byte_order)).type(data).tofile(stream) + + def __str__(self): + val_str = _data_type_reverse[self.val_dtype] + return 'property %s %s' % (val_str, self.name) + + def __repr__(self): + return 'PlyProperty(%r, %r)' % (self.name, + _lookup_type(self.val_dtype)) + + +class PlyListProperty(PlyProperty): + + """ + PLY list property description. + + """ + + def __init__(self, name, len_dtype, val_dtype): + PlyProperty.__init__(self, name, val_dtype) + + self.len_dtype = len_dtype + + def _get_len_dtype(self): + return self._len_dtype + + def _set_len_dtype(self, len_dtype): + self._len_dtype = _data_types[_lookup_type(len_dtype)] + + len_dtype = property(_get_len_dtype, _set_len_dtype) + + def dtype(self, byte_order='='): + """ + List properties always have a numpy dtype of "object". + + """ + return '|O' + + def list_dtype(self, byte_order='='): + """ + Return the pair (len_dtype, val_dtype) (both numpy-friendly + strings). + + """ + return (byte_order + self.len_dtype, + byte_order + self.val_dtype) + + def _from_fields(self, fields): + (len_t, val_t) = self.list_dtype() + + n = int(_np.dtype(len_t).type(next(fields))) + + data = _np.loadtxt(list(_islice(fields, n)), val_t, ndmin=1) + if len(data) < n: + raise StopIteration + + return data + + def _to_fields(self, data): + """ + Return generator over the (numerical) PLY representation of the + list data (length followed by actual data). 
+ + """ + (len_t, val_t) = self.list_dtype() + + data = _np.asarray(data, dtype=val_t).ravel() + + yield _np.dtype(len_t).type(data.size) + for x in data: + yield x + + def _read_bin(self, stream, byte_order): + (len_t, val_t) = self.list_dtype(byte_order) + + try: + n = _np.fromfile(stream, len_t, 1)[0] + except IndexError: + raise StopIteration + + data = _np.fromfile(stream, val_t, n) + if len(data) < n: + raise StopIteration + + return data + + def _write_bin(self, data, stream, byte_order): + """ + Write data to a binary stream. + + """ + (len_t, val_t) = self.list_dtype(byte_order) + + data = _np.asarray(data, dtype=val_t).ravel() + + _np.array(data.size, dtype=len_t).tofile(stream) + data.tofile(stream) + + def __str__(self): + len_str = _data_type_reverse[self.len_dtype] + val_str = _data_type_reverse[self.val_dtype] + return 'property list %s %s %s' % (len_str, val_str, self.name) + + def __repr__(self): + return ('PlyListProperty(%r, %r, %r)' % + (self.name, + _lookup_type(self.len_dtype), + _lookup_type(self.val_dtype))) + + +def load_ply(file_name: str, + with_faces: bool = False, + with_color: bool = False) -> _np.ndarray: + ply_data = PlyData.read(file_name) + points = ply_data['vertex'] + points = _np.vstack([points['x'], points['y'], points['z']]).T + ret_val = [points] + + if with_faces: + faces = _np.vstack(ply_data['face']['vertex_indices']) + ret_val.append(faces) + + if with_color: + r = _np.vstack(ply_data['vertex']['red']) + g = _np.vstack(ply_data['vertex']['green']) + b = _np.vstack(ply_data['vertex']['blue']) + color = _np.hstack((r, g, b)) + ret_val.append(color) + + if len(ret_val) == 1: # Unwrap the list + ret_val = ret_val[0] + + return ret_val diff --git a/utils/util.py b/utils/util.py new file mode 100755 index 0000000..10d82d8 --- /dev/null +++ b/utils/util.py @@ -0,0 +1,69 @@ +import logging +import re +from os import listdir, makedirs +from os.path import join, exists +from shutil import rmtree +from time import sleep + +import torch + + +def setup_logging(log_dir): + makedirs(log_dir, exist_ok=True) + + logpath = join(log_dir, 'log.txt') + filemode = 'a' if exists(logpath) else 'w' + + # set up logging to file - see previous section for more details + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(message)s', + datefmt='%m-%d %H:%M:%S', + filename=logpath, + filemode=filemode) + # define a Handler which writes INFO messages or higher to the sys.stderr + console = logging.StreamHandler() + console.setLevel(logging.DEBUG) + # set a format which is simpler for console use + formatter = logging.Formatter('%(asctime)s: %(levelname)-8s %(message)s') + # tell the handler to use this format + console.setFormatter(formatter) + # add the handler to the root logger + logging.getLogger('').addHandler(console) + + +def prepare_results_dir(config): + output_dir = join(config['results_root'], config['arch'], + config['experiment_name']) + if config['clean_results_dir']: + if exists(output_dir): + print('Attention! 
Cleaning results directory in 10 seconds!') + sleep(10) + rmtree(output_dir, ignore_errors=True) + makedirs(output_dir, exist_ok=True) + makedirs(join(output_dir, 'weights'), exist_ok=True) + makedirs(join(output_dir, 'samples'), exist_ok=True) + makedirs(join(output_dir, 'results'), exist_ok=True) + return output_dir + + +def find_latest_epoch(dirpath): + # Files with weights are in format ddddd_{D,E,G}.pth + epoch_regex = re.compile(r'^(?P<n_epoch>\d+)_[DEG]\.pth$') + epochs_completed = [] + if exists(join(dirpath, 'weights')): + dirpath = join(dirpath, 'weights') + for f in listdir(dirpath): + m = epoch_regex.match(f) + if m: + epochs_completed.append(int(m.group('n_epoch'))) + return max(epochs_completed) if epochs_completed else 0 + + +def cuda_setup(cuda=False, gpu_idx=0): + if cuda and torch.cuda.is_available(): + device = torch.device('cuda') + torch.cuda.set_device(gpu_idx) + else: + device = torch.device('cpu') + return device +
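Reviewer note: a minimal usage sketch (not part of the patch) showing how the helpers added in utils/util.py and utils/pcutil.py could fit together in a training script. The 'results'/'aae'/'demo' config values are placeholders, and the random point cloud stands in for real data.

    from os.path import join

    import numpy as np

    from utils.pcutil import plot_3d_point_cloud
    from utils.util import (cuda_setup, find_latest_epoch,
                            prepare_results_dir, setup_logging)

    config = {
        'results_root': 'results',   # output lands in results/<arch>/<experiment_name>
        'arch': 'aae',               # hypothetical architecture name
        'experiment_name': 'demo',   # hypothetical experiment name
        'clean_results_dir': False,
    }

    results_dir = prepare_results_dir(config)   # creates weights/, samples/, results/
    setup_logging(results_dir)                  # logs to <results_dir>/log.txt and stderr
    device = cuda_setup(cuda=False, gpu_idx=0)  # falls back to CPU when CUDA is absent

    # Weight files are named like 00042_G.pth, so this returns 0 on a
    # fresh run and the newest completed epoch otherwise.
    starting_epoch = find_latest_epoch(results_dir)

    # Plot a (here random) point cloud and save it under samples/.
    cloud = np.random.uniform(-0.5, 0.5, size=(2048, 3))
    fig = plot_3d_point_cloud(cloud[:, 0], cloud[:, 1], cloud[:, 2],
                              in_u_sphere=True, show=False)
    fig.savefig(join(results_dir, 'samples', 'random_cloud.png'))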