Skip to content
This repository has been archived by the owner on Nov 2, 2022. It is now read-only.

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
MaciejZamorski committed May 7, 2019
1 parent cf495bb commit 6b07132
Show file tree
Hide file tree
Showing 10 changed files with 944 additions and 47 deletions.
161 changes: 161 additions & 0 deletions datasets/modelnet40.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
import urllib
import shutil
from os import listdir, makedirs, remove
from os.path import exists, join
from zipfile import ZipFile

import h5py
import numpy as np
import pandas as pd
from torch.utils.data import Dataset

from utils.pcutil import rand_rotation_matrix

# The 40 ModelNet40 category names; a category's position in this list is the
# integer label used by the lookup tables below.
all_classes = [
    'airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl',
    'car', 'chair', 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser',
    'flower_pot', 'glass_box', 'guitar', 'keyboard', 'lamp', 'laptop',
    'mantel', 'monitor', 'night_stand', 'person', 'piano', 'plant', 'radio',
    'range_hood', 'sink', 'sofa', 'stairs', 'stool', 'table', 'tent',
    'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox',
]

# Integer label -> category name.
number_to_category = dict(enumerate(all_classes))
# Category name -> integer label (inverse of the mapping above).
category_to_number = {
    name: number for number, name in number_to_category.items()
}


class ModelNet40(Dataset):
    """Point clouds read from the `modelnet40_ply_hdf5_2048` HDF5 archive.

    The archive is downloaded and unpacked into `root_dir` when that
    directory does not exist yet. Samples come from the `data` and `label`
    datasets of every HDF5 file listed in `train_files.txt` / `test_files.txt`.
    """

    def __init__(self, root_dir='/home/datasets/modelnet40', classes=None,
                 transform=None, split='train', valid_percent=10,
                 percent_supervised=0.0):
        """
        Args:
            root_dir (string): Directory with all the point clouds. It is
                created and populated by a download when it does not exist.
            classes (sequence, optional): Classes to keep, given either as
                names from `all_classes` or as integer labels. Empty or
                `None` keeps every class.
            transform (container of str, optional): Names of transforms to
                apply in `__getitem__`; only `'rotate'` is recognised.
            split (string): `train`, `valid` or `test` (case-insensitive).
            valid_percent (int): Percent of train (from the end of every
                class) to use as valid set.
            percent_supervised (float): Fraction used to size the per-class
                supervised subset. When greater than zero, `__getitem__`
                additionally returns a random supervised sample. The subset
                is only built for the `train` and `valid` splits.

        Raises:
            ValueError: If `split` is not one of the supported values.
        """
        self.root_dir = root_dir
        # Avoid sharing one mutable default list between instances.
        self.transform = [] if transform is None else transform
        self.split = split.lower()
        self.valid_percent = valid_percent
        self.percent_supervised = percent_supervised

        self._maybe_download_data()

        if self.split in ('train', 'valid'):
            self.files_list = join(self.root_dir, 'train_files.txt')
        elif self.split == 'test':
            self.files_list = join(self.root_dir, 'test_files.txt')
        else:
            raise ValueError('Incorrect split')

        data, labels = self._load_files()

        if classes:
            # Class names are translated to their integer labels; integer
            # labels are used as given.
            if classes[0] in all_classes:
                classes = np.asarray([category_to_number[c] for c in classes])
            selected = np.isin(labels.flatten(), classes)
            data = data[selected]
            labels = labels[selected]
        else:
            classes = np.arange(len(all_classes))

        if self.split in ('train', 'valid'):
            flat_labels = labels.flatten()
            train_fraction = 1 - (self.valid_percent / 100)
            use_supervised = self.percent_supervised > 0.0

            new_data, new_labels = [], []
            data_sup, labels_sup = [], []
            for c in classes:
                # Compute the class mask once and reuse the selections.
                in_class = flat_labels == c
                class_data = data[in_class]
                class_labels = labels[in_class]
                pc_in_class = int(in_class.sum())

                # The first part of every class is train, the rest is valid.
                n_train = int(pc_in_class * train_fraction)
                if self.split == 'train':
                    portion = slice(0, n_train)
                else:
                    portion = slice(n_train, pc_in_class)

                new_data.append(class_data[portion])
                new_labels.append(class_labels[portion])

                if use_supervised:
                    # NOTE(review): the supervised subset is always taken
                    # from the start of the class, so for `valid` it overlaps
                    # the train portion - confirm this is intended.
                    n_max = int(self.percent_supervised * (portion.stop - 1))
                    data_sup.append(class_data[:n_max])
                    labels_sup.append(class_labels[:n_max])

            data = np.vstack(new_data)
            labels = np.vstack(new_labels)
            if use_supervised:
                self.data_sup = np.vstack(data_sup)
                self.labels_sup = np.vstack(labels_sup)

        self.data = data
        self.labels = labels

    def _load_files(self) -> tuple:
        """Load the `data` and `label` arrays of all listed HDF5 files.

        Each line of `self.files_list` names one file; only the part after
        the last `/` is used and it is looked up inside `self.root_dir`.

        Returns:
            tuple: `(data, labels)` as two NumPy arrays holding the entries
            of all files in listing order.
        """
        with open(self.files_list) as f:
            # Blank lines are skipped and `[-1]` also accepts bare file names.
            files = [join(self.root_dir, line.strip().rsplit('/', 1)[-1])
                     for line in f if line.strip()]

        data, labels = [], []
        for file in files:
            # Read-only access is all that is needed here.
            with h5py.File(file, 'r') as f:
                data.extend(f['data'][:])
                labels.extend(f['label'][:])

        return np.asarray(data), np.asarray(labels)

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the scaled (and optionally rotated) sample at `idx`.

        Returns:
            tuple: `(sample, label)`, or
            `(sample, label, sample_sup, label_sup)` when
            `percent_supervised > 0`, where the supervised pair is drawn
            uniformly at random from the supervised subset.
        """
        # Divide out of place: an in-place `/=` would write through the view
        # into `self.data` and halve the stored cloud on every access.
        # Scale to [-0.5, 0.5] range (assumes stored coordinates are in
        # [-1, 1] - TODO confirm).
        sample = self.data[idx] / 2
        label = self.labels[idx]

        if 'rotate' in self.transform:
            r_rotation = rand_rotation_matrix()
            # Decouple the third axis from the other two so the third
            # coordinate of every point is left unchanged by the product.
            r_rotation[0, 2] = 0
            r_rotation[2, 0] = 0
            r_rotation[1, 2] = 0
            r_rotation[2, 1] = 0
            r_rotation[2, 2] = 1

            sample = sample.dot(r_rotation).astype(np.float32)

        if self.percent_supervised > 0.0:
            id_sup = np.random.randint(self.data_sup.shape[0])
            # Out of place for the same reason as above.
            sample_sup = self.data_sup[id_sup] / 2
            label_sup = self.labels_sup[id_sup]
            return sample, label, sample_sup, label_sup

        return sample, label

    def _maybe_download_data(self):
        """Download and unpack the dataset unless `root_dir` already exists.

        On any failure the freshly created `root_dir` is removed again so a
        later run retries the download instead of finding a partial dataset.
        """
        if exists(self.root_dir):
            return

        # `import urllib` alone does not load the `request` submodule.
        import urllib.request

        print(f'ModelNet40 doesn\'t exist in root directory {self.root_dir}. '
              f'Downloading...')
        makedirs(self.root_dir)

        url = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip'

        # NOTE(review): `[:-5]` yields 'modelnet40_ply_hdf5_204', which keeps
        # the temporary archive from clashing with the extracted
        # 'modelnet40_ply_hdf5_2048' directory - possibly intentional.
        filename = url.rpartition('/')[2][:-5]
        file_path = join(self.root_dir, filename)

        try:
            # Stream to disk and close the response deterministically.
            with urllib.request.urlopen(url) as response, \
                    open(file_path, mode='wb') as f:
                shutil.copyfileobj(response, f)

            print('Extracting...')
            with ZipFile(file_path, mode='r') as zip_f:
                zip_f.extractall(self.root_dir)

            remove(file_path)

            # Flatten the archive's top-level directory into `root_dir`.
            extracted_dir = join(self.root_dir, 'modelnet40_ply_hdf5_2048')
            for d in listdir(extracted_dir):
                shutil.move(src=join(extracted_dir, d),
                            dst=self.root_dir)

            shutil.rmtree(extracted_dir)
        except BaseException:
            shutil.rmtree(self.root_dir, ignore_errors=True)
            raise

# Running the module directly just instantiates the dataset with its default
# arguments, which downloads the archive when the default root is missing.
if __name__ == "__main__":
    ModelNet40()
168 changes: 168 additions & 0 deletions evaluation/find_best_epoch_on_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import argparse
import json
import logging
import random
import re
from datetime import datetime
from importlib import import_module
from os import listdir
from os.path import join

import numpy as np
import pandas as pd
import torch
from torch.distributions.beta import Beta
from torch.utils.data import DataLoader

from datasets.shapenet import ShapeNetDataset
from metrics.jsd import jsd_between_point_cloud_sets
from utils.util import cuda_setup, setup_logging


def _get_epochs_by_regex(path, regex):
    """Collect epoch numbers of the entries in `path` whose name matches `regex`.

    The pattern is matched from the start of each entry name, and the epoch
    number is parsed from the first five characters of that name.

    Returns:
        set: Epoch numbers as integers.
    """
    pattern = re.compile(regex)
    found = set()
    for entry in listdir(path):
        if pattern.match(entry) is None:
            continue
        found.add(int(entry[:5]))
    return found


def main(eval_config):
    """Compute the validation-set JSD of every saved epoch and log the best.

    For each epoch that has both encoder (`*_E.pth`) and generator
    (`*_G.pth`) weights, point clouds are generated from sampled noise and
    compared with the validation clouds via `jsd_between_point_cloud_sets`;
    the value is averaged over three trials. Epochs are visited from the
    latest to the earliest, and a `KeyboardInterrupt` stops the sweep while
    still reporting the epochs evaluated so far.

    Args:
        eval_config (dict): Evaluation settings. Keys read here are
            `results_root`, `arch`, `experiment_name`, `cuda`, `gpu` and,
            optionally, `distribution`. The training `config.json` found
            under the results directory supplies `seed`, `dataset`,
            `data_dir`, `classes`, `arch`, `z_size` and, optionally,
            `distribution`, `z_beta_a` and `z_beta_b`.

    Raises:
        ValueError: If the dataset name or the noise distribution is not
            supported.
    """
    # Load hyperparameters as they were during training
    train_results_path = join(eval_config['results_root'], eval_config['arch'],
                              eval_config['experiment_name'])
    with open(join(train_results_path, 'config.json')) as f:
        train_config = json.load(f)

    # NOTE(review): NumPy's global RNG, used for the beta noise below, is not
    # seeded here - confirm whether that is intended.
    random.seed(train_config['seed'])
    torch.manual_seed(train_config['seed'])
    torch.cuda.manual_seed_all(train_config['seed'])

    setup_logging(join(train_results_path, 'results'))
    log = logging.getLogger(__name__)

    log.debug('Evaluating JensenShannon divergences on validation set on all '
              'saved epochs.')

    weights_path = join(train_results_path, 'weights')

    # Find all epochs that have saved model weights
    e_epochs = _get_epochs_by_regex(weights_path, r'(?P<epoch>\d{5})_E\.pth')
    g_epochs = _get_epochs_by_regex(weights_path, r'(?P<epoch>\d{5})_G\.pth')
    epochs = sorted(e_epochs.intersection(g_epochs))
    log.debug(f'Testing epochs: {epochs}')

    device = cuda_setup(eval_config['cuda'], eval_config['gpu'])
    log.debug(f'Device variable: {device}')
    if device.type == 'cuda':
        log.debug(f'Current CUDA device: {torch.cuda.current_device()}')

    #
    # Dataset
    #
    dataset_name = train_config['dataset'].lower()
    if dataset_name == 'shapenet':
        dataset = ShapeNetDataset(root_dir=train_config['data_dir'],
                                  classes=train_config['classes'], split='valid')
    elif dataset_name == 'faust':
        from datasets.dfaust import DFaustDataset
        dataset = DFaustDataset(root_dir=train_config['data_dir'],
                                classes=train_config['classes'], split='valid')
    elif dataset_name == 'mcgill':
        from datasets.mcgill import McGillDataset
        dataset = McGillDataset(root_dir=train_config['data_dir'],
                                classes=train_config['classes'], split='valid')
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet`, '
                         f'`faust` or `mcgill`. Got: `{dataset_name}`')
    classes_selected = ('all' if not train_config['classes']
                        else ','.join(train_config['classes']))
    log.debug(f'Selected {classes_selected} classes. Loaded {len(dataset)} '
              f'samples.')

    # The training config takes precedence over the evaluation config.
    if 'distribution' in train_config:
        distribution = train_config['distribution']
    elif 'distribution' in eval_config:
        distribution = eval_config['distribution']
    else:
        log.warning('No distribution type specified. Assumed normal = N(0, 0.2)')
        distribution = 'normal'

    # Fail early: with any other value the noise tensor below would never be
    # filled and the generator would be fed uninitialised memory.
    if distribution not in ('normal', 'beta'):
        raise ValueError(f'Invalid distribution. Expected `normal` or '
                         f'`beta`. Got: `{distribution}`')

    #
    # Models
    #
    arch = import_module(f"model.architectures.{train_config['arch']}")
    E = arch.Encoder(train_config).to(device)
    G = arch.Generator(train_config).to(device)

    E.eval()
    G.eval()

    # NOTE(review): relies on every supported dataset class exposing
    # `point_clouds_names_valid` - confirm for the non-ShapeNet datasets.
    num_samples = len(dataset.point_clouds_names_valid)
    # A single batch holding the whole validation set.
    data_loader = DataLoader(dataset, batch_size=num_samples,
                             shuffle=False, num_workers=4,
                             drop_last=False, pin_memory=True)

    # We take 3 times as many samples as there are in test data in order to
    # perform JSD calculation in the same manner as in the reference publication
    noise = torch.FloatTensor(3 * num_samples, train_config['z_size'], 1)
    noise = noise.to(device)

    X, _ = next(iter(data_loader))
    X = X.to(device)

    results = {}

    for epoch in reversed(epochs):
        try:
            # `map_location` lets checkpoints saved on another device (for
            # example a GPU) load on the device selected for this run.
            E.load_state_dict(torch.load(
                join(weights_path, f'{epoch:05}_E.pth'),
                map_location=device))
            G.load_state_dict(torch.load(
                join(weights_path, f'{epoch:05}_G.pth'),
                map_location=device))

            start_clock = datetime.now()

            # We average JSD computation from 3 independent trials.
            js_results = []
            for _ in range(3):
                if distribution == 'normal':
                    noise.normal_(0, 0.2)
                elif distribution == 'beta':
                    noise_np = np.random.beta(train_config['z_beta_a'],
                                              train_config['z_beta_b'],
                                              noise.shape)
                    noise = torch.tensor(noise_np).float().round().to(device)

                with torch.no_grad():
                    X_g = G(noise)
                # Swap the last two axes when the generator output ends in
                # (3, 2048).
                if X_g.shape[-2:] == (3, 2048):
                    X_g.transpose_(1, 2)

                jsd = jsd_between_point_cloud_sets(X, X_g, voxels=28)
                js_results.append(jsd)

            js_result = np.mean(js_results)
            log.debug(f'Epoch: {epoch} JSD: {js_result: .6f} '
                      f'Time: {datetime.now() - start_clock}')
            results[epoch] = js_result
        except KeyboardInterrupt:
            log.debug(f'Interrupted during epoch: {epoch}')
            break

    # `idxmin()` raises on an empty frame, so bail out when nothing was
    # evaluated (no matching weights, or interrupted during the first epoch).
    if not results:
        log.warning('No epochs were evaluated; nothing to report.')
        return

    results = pd.DataFrame.from_dict(results, orient='index', columns=['jsd'])
    log.debug(f"Minimum JSD at epoch {results.idxmin()['jsd']}: "
              f"{results.min()['jsd']: .6f}")


# Script entry point: read the evaluation config given with -c/--config and
# run the validation sweep.
if __name__ == '__main__':
    logger = logging.getLogger()

    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', default=None, type=str,
                        help='File path for evaluation config')
    args = parser.parse_args()

    # Validate explicitly instead of with `assert`, which is stripped when
    # Python runs with -O and would then let `None` reach `main`.
    if args.config is None or not args.config.endswith('.json'):
        parser.error('a path to a .json evaluation config is required '
                     '(-c/--config)')

    with open(args.config) as f:
        evaluation_config = json.load(f)
    # A file containing only `null` parses to None.
    if evaluation_config is None:
        parser.error(f'evaluation config {args.config} is empty')

    main(evaluation_config)
Loading

0 comments on commit 6b07132

Please sign in to comment.