This repository has been archived by the owner on Nov 2, 2022. It is now read-only.
Commit 6b07132 (1 parent: cf495bb)
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 10 changed files with 944 additions and 47 deletions.
@@ -0,0 +1,161 @@
import urllib.request
import shutil
from os import listdir, makedirs, remove
from os.path import exists, join
from zipfile import ZipFile

import h5py
import numpy as np
from torch.utils.data import Dataset

from utils.pcutil import rand_rotation_matrix

all_classes = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle',
               'bowl', 'car', 'chair', 'cone', 'cup', 'curtain', 'desk', 'door',
               'dresser', 'flower_pot', 'glass_box', 'guitar', 'keyboard',
               'lamp', 'laptop', 'mantel', 'monitor', 'night_stand', 'person',
               'piano', 'plant', 'radio', 'range_hood', 'sink', 'sofa',
               'stairs', 'stool', 'table', 'tent', 'toilet', 'tv_stand', 'vase',
               'wardrobe', 'xbox']

number_to_category = {i: c for i, c in enumerate(all_classes)}
category_to_number = {c: i for i, c in enumerate(all_classes)}


class ModelNet40(Dataset):
    def __init__(self, root_dir='/home/datasets/modelnet40', classes=(),
                 transform=(), split='train', valid_percent=10,
                 percent_supervised=0.0):
        """
        Args:
            root_dir (string): Directory with all the point clouds.
            classes (iterable of string, optional): Category names to load;
                an empty value loads all 40 categories.
            transform (iterable of string, optional): Optional transforms to be
                applied on a sample (currently only 'rotate' is recognised).
            split (string): `train`, `valid` or `test`.
            valid_percent (int): Percent of train (from the end) to use as valid set.
            percent_supervised (float): Fraction (0-1) of each class kept as an
                additional supervised subset returned alongside every sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.split = split.lower()
        self.valid_percent = valid_percent
        self.percent_supervised = percent_supervised

        self._maybe_download_data()

        if self.split in ('train', 'valid'):
            self.files_list = join(self.root_dir, 'train_files.txt')
        elif self.split == 'test':
            self.files_list = join(self.root_dir, 'test_files.txt')
        else:
            raise ValueError('Incorrect split')

        data, labels = self._load_files()

        if classes:
            if classes[0] in all_classes:
                classes = np.asarray([category_to_number[c] for c in classes])
            mask = [label in classes for label in labels]
            data = data[mask]
            labels = labels[mask]
        else:
            classes = np.arange(len(all_classes))

        if self.split in ('train', 'valid'):
            new_data, new_labels = [], []
            if self.percent_supervised > 0.0:
                data_sup, labels_sup = [], []
            for c in classes:
                pc_in_class = sum(labels.flatten() == c)

                # The last `valid_percent` percent of each class goes to the
                # validation split; the rest stays in the training split.
                if self.split == 'train':
                    portion = slice(0, int(pc_in_class * (1 - (self.valid_percent / 100))))
                else:
                    portion = slice(int(pc_in_class * (1 - (self.valid_percent / 100))), pc_in_class)

                new_data.append(data[labels.flatten() == c][portion])
                new_labels.append(labels[labels.flatten() == c][portion])

                if self.percent_supervised > 0.0:
                    n_max = int(self.percent_supervised * (portion.stop - 1))
                    data_sup.append(data[labels.flatten() == c][:n_max])
                    labels_sup.append(labels[labels.flatten() == c][:n_max])
            data = np.vstack(new_data)
            labels = np.vstack(new_labels)
            if self.percent_supervised > 0.0:
                self.data_sup = np.vstack(data_sup)
                self.labels_sup = np.vstack(labels_sup)
        self.data = data
        self.labels = labels

    def _load_files(self):
        """Read every HDF5 file listed in `files_list` and return (data, labels) arrays."""
        with open(self.files_list) as f:
            files = [join(self.root_dir, line.rstrip().rsplit('/', 1)[1]) for line in f]

        data, labels = [], []
        for file in files:
            with h5py.File(file, mode='r') as f:
                data.extend(f['data'][:])
                labels.extend(f['label'][:])

        return np.asarray(data), np.asarray(labels)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Scale to [-0.5, 0.5] range without modifying the cached array in place
        sample = self.data[idx] / 2
        label = self.labels[idx]

        if 'rotate' in self.transform:
            # Random rotation restricted to the up (z) axis: zero out the
            # entries that would mix the z coordinate with x and y.
            r_rotation = rand_rotation_matrix()
            r_rotation[0, 2] = 0
            r_rotation[2, 0] = 0
            r_rotation[1, 2] = 0
            r_rotation[2, 1] = 0
            r_rotation[2, 2] = 1

            sample = sample.dot(r_rotation).astype(np.float32)
        if self.percent_supervised > 0.0:
            id_sup = np.random.randint(self.data_sup.shape[0])
            sample_sup = self.data_sup[id_sup] / 2
            label_sup = self.labels_sup[id_sup]
            return sample, label, sample_sup, label_sup
        else:
            return sample, label

    def _maybe_download_data(self):
        if exists(self.root_dir):
            return

        print(f'ModelNet40 doesn\'t exist in root directory {self.root_dir}. '
              f'Downloading...')
        makedirs(self.root_dir)

        url = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip'

        data = urllib.request.urlopen(url)
        filename = url.rpartition('/')[2][:-5]
        file_path = join(self.root_dir, filename)
        with open(file_path, mode='wb') as f:
            d = data.read()
            f.write(d)

        print('Extracting...')
        with ZipFile(file_path, mode='r') as zip_f:
            zip_f.extractall(self.root_dir)

        remove(file_path)

        extracted_dir = join(self.root_dir, 'modelnet40_ply_hdf5_2048')
        for d in listdir(extracted_dir):
            shutil.move(src=join(extracted_dir, d),
                        dst=self.root_dir)

        shutil.rmtree(extracted_dir)


if __name__ == '__main__':
    ModelNet40()
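
For reference, a minimal usage sketch of the dataset class above; the root directory, class list and batch size are illustrative placeholders, not values taken from this commit:

from torch.utils.data import DataLoader

# Hypothetical values for illustration only; adjust root_dir to your setup.
train_set = ModelNet40(root_dir='/home/datasets/modelnet40',
                       classes=['chair', 'table'],   # empty value loads all 40 classes
                       transform=['rotate'],         # random rotation around the up axis
                       split='train', valid_percent=10)

loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)
# Assuming the standard 2048-point HDF5 files, a batch is roughly:
# points: (32, 2048, 3), labels: (32, 1)
points, labels = next(iter(loader))
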
@@ -0,0 +1,168 @@
import argparse
import json
import logging
import random
import re
from datetime import datetime
from importlib import import_module
from os import listdir
from os.path import join

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

from datasets.shapenet import ShapeNetDataset
from metrics.jsd import jsd_between_point_cloud_sets
from utils.util import cuda_setup, setup_logging


def _get_epochs_by_regex(path, regex):
    reg = re.compile(regex)
    return {int(w[:5]) for w in listdir(path) if reg.match(w)}


def main(eval_config):
    # Load hyperparameters as they were during training
    train_results_path = join(eval_config['results_root'], eval_config['arch'],
                              eval_config['experiment_name'])
    with open(join(train_results_path, 'config.json')) as f:
        train_config = json.load(f)

    random.seed(train_config['seed'])
    torch.manual_seed(train_config['seed'])
    torch.cuda.manual_seed_all(train_config['seed'])

    setup_logging(join(train_results_path, 'results'))
    log = logging.getLogger(__name__)

    log.debug('Evaluating Jensen-Shannon divergences on the validation set for '
              'all saved epochs.')

    weights_path = join(train_results_path, 'weights')

    # Find all epochs that have saved model weights
    e_epochs = _get_epochs_by_regex(weights_path, r'(?P<epoch>\d{5})_E\.pth')
    g_epochs = _get_epochs_by_regex(weights_path, r'(?P<epoch>\d{5})_G\.pth')
    epochs = sorted(e_epochs.intersection(g_epochs))
    log.debug(f'Testing epochs: {epochs}')

    device = cuda_setup(eval_config['cuda'], eval_config['gpu'])
    log.debug(f'Device variable: {device}')
    if device.type == 'cuda':
        log.debug(f'Current CUDA device: {torch.cuda.current_device()}')

    #
    # Dataset
    #
    dataset_name = train_config['dataset'].lower()
    if dataset_name == 'shapenet':
        dataset = ShapeNetDataset(root_dir=train_config['data_dir'],
                                  classes=train_config['classes'], split='valid')
    elif dataset_name == 'faust':
        from datasets.dfaust import DFaustDataset
        dataset = DFaustDataset(root_dir=train_config['data_dir'],
                                classes=train_config['classes'], split='valid')
    elif dataset_name == 'mcgill':
        from datasets.mcgill import McGillDataset
        dataset = McGillDataset(root_dir=train_config['data_dir'],
                                classes=train_config['classes'], split='valid')
    else:
        raise ValueError(f'Invalid dataset name. Expected `shapenet`, `faust` '
                         f'or `mcgill`. Got: `{dataset_name}`')
    classes_selected = ('all' if not train_config['classes']
                        else ','.join(train_config['classes']))
    log.debug(f'Selected {classes_selected} classes. Loaded {len(dataset)} '
              f'samples.')

    if 'distribution' in train_config:
        distribution = train_config['distribution']
    elif 'distribution' in eval_config:
        distribution = eval_config['distribution']
    else:
        log.warning('No distribution type specified. Assuming normal: N(0, 0.2)')
        distribution = 'normal'

    #
    # Models
    #
    arch = import_module(f"model.architectures.{train_config['arch']}")
    E = arch.Encoder(train_config).to(device)
    G = arch.Generator(train_config).to(device)

    E.eval()
    G.eval()

    num_samples = len(dataset.point_clouds_names_valid)
    data_loader = DataLoader(dataset, batch_size=num_samples,
                             shuffle=False, num_workers=4,
                             drop_last=False, pin_memory=True)

    # We take 3 times as many samples as there are in the validation data in
    # order to perform the JSD calculation in the same manner as in the
    # reference publication.
    noise = torch.FloatTensor(3 * num_samples, train_config['z_size'], 1)
    noise = noise.to(device)

    X, _ = next(iter(data_loader))
    X = X.to(device)

    results = {}

    for epoch in reversed(epochs):
        try:
            E.load_state_dict(torch.load(
                join(weights_path, f'{epoch:05}_E.pth')))
            G.load_state_dict(torch.load(
                join(weights_path, f'{epoch:05}_G.pth')))

            start_clock = datetime.now()

            # We average the JSD computation over 3 independent trials.
            js_results = []
            for _ in range(3):
                if distribution == 'normal':
                    noise.normal_(0, 0.2)
                elif distribution == 'beta':
                    noise_np = np.random.beta(train_config['z_beta_a'],
                                              train_config['z_beta_b'],
                                              noise.shape)
                    noise = torch.tensor(noise_np).float().round().to(device)

                with torch.no_grad():
                    X_g = G(noise)
                if X_g.shape[-2:] == (3, 2048):
                    X_g.transpose_(1, 2)

                jsd = jsd_between_point_cloud_sets(X, X_g, voxels=28)
                js_results.append(jsd)

            js_result = np.mean(js_results)
            log.debug(f'Epoch: {epoch} JSD: {js_result: .6f} '
                      f'Time: {datetime.now() - start_clock}')
            results[epoch] = js_result
        except KeyboardInterrupt:
            log.debug(f'Interrupted during epoch: {epoch}')
            break

    results = pd.DataFrame.from_dict(results, orient='index', columns=['jsd'])
    log.debug(f"Minimum JSD at epoch {results.idxmin()['jsd']}: "
              f"{results.min()['jsd']: .6f}")


if __name__ == '__main__':
    logger = logging.getLogger()

    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', default=None, type=str,
                        help='File path for evaluation config')
    args = parser.parse_args()

    evaluation_config = None
    if args.config is not None and args.config.endswith('.json'):
        with open(args.config) as f:
            evaluation_config = json.load(f)
    assert evaluation_config is not None

    main(evaluation_config)
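
The script expects the `--config` argument to point at a JSON file. A minimal sketch of such a config, based only on the keys this script reads (`results_root`, `arch`, `experiment_name`, `cuda`, `gpu`, and optionally `distribution`); all values below are placeholders, not taken from this commit:

import json

# Hypothetical evaluation config; key names are the ones read by main() above,
# values are placeholders for illustration.
eval_config = {
    'results_root': '/path/to/results',   # root of the training results tree
    'arch': 'aae',                        # architecture module name used during training (placeholder)
    'experiment_name': 'my_experiment',   # subdirectory holding config.json and weights/
    'cuda': True,
    'gpu': 0,
    'distribution': 'normal',             # optional; otherwise taken from the train config
}

with open('eval_config.json', 'w') as f:
    json.dump(eval_config, f, indent=2)

# Then run:  python evaluate.py --config eval_config.json
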