diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..5792c9f --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 zgojcic + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..149d3ff --- /dev/null +++ b/README.md @@ -0,0 +1,185 @@ +# Learning Multiview 3D Point Cloud Registration repository +This repository provides code and data to train and evaluate the LMPCR, the first end-to-end algorithm for multiview registration of raw point clouds in a globally consistent manner. It represents the official implementation of the paper: + +### [Learning Multiview 3D Point Cloud Registration (CVPR 2020).](https://arxiv.org/pdf/2001.05119.pdf) +\*[Zan Gojcic](https://www.ethz.ch/content/specialinterest/baug/institute-igp/geosensors-and-engineering-geodesy/en/people/scientific-assistance/zan-gojcic.html),\* [Caifa Zhou](https://ch.linkedin.com/in/caifa-zhou-7a461510b), [Jan D. Wegner](http://www.prs.igp.ethz.ch/content/specialinterest/baug/institute-igp/photogrammetry-and-remote-sensing/en/group/people/person-detail.html?persid=186562), [Leonidas J. Guibas](https://geometry.stanford.edu/member/guibas/), [Tolga Birdal](http://tbirdal.me/)\ +|[EcoVision Lab ETH Zurich](https://prs.igp.ethz.ch/ecovision.html) | [Guibas Lab Stanford University](https://geometry.stanford.edu/index.html)|\ +\* Equal contribution + +We present a novel, end-to-end learnable, multiview 3D point cloud registration algorithm. Registration of multiple scans typically follows a two-stage pipeline: the initial pairwise alignment and the globally consistent refinement. The former is often ambiguous due to the low overlap of +neighboring point clouds, symmetries and repetitive scene parts. Therefore, the latter global refinement aims at establishing the cyclic consistency across multiple scans and helps in resolving the ambiguous cases. In this paper we propose, to the best of our knowledge, the first end-to-end algorithm for joint learning of both parts of this two-stage problem. Experimental evaluation on well accepted benchmark datasets shows that our approach outperforms the state-of-the-art by a significant margin, while being end-to-end trainable and computationally less costly. Moreover, we present detailed analysis and an ablation study that validate +the novel components of our approach. 
+
+![LM3DPCR](figures/LM3DPCR.jpg?raw=true)
+
+### Citation
+
+If you find this code useful for your work or use it in your project, please consider citing:
+
+```shell
+@inproceedings{gojcic2020LearningMultiview,
+  title={Learning Multiview 3D Point Cloud Registration},
+  author={Gojcic, Zan and Zhou, Caifa and Wegner, Jan D and Guibas, Leonidas J and Birdal, Tolga},
+  booktitle={International conference on computer vision and pattern recognition (CVPR)},
+  year={2020}
+}
+```
+
+### Contact
+If you have any questions or find any bugs, please let us know: Zan Gojcic {firstname.lastname@geod.baug.ethz.ch}
+
+## Current state of the repository
+The repository currently contains only part of the code connected to the above-mentioned publication and will be updated continuously over the coming days. The complete code will be available in the coming weeks.
+
+**NOTE**: The published model is not the same as the model used in the CVPR paper. The results can therefore differ slightly from the ones in the paper. The models will be updated in the following days (with the fully converged ones).
+
+## Instructions
+The code was tested on Ubuntu 18.04 with Python 3.6, PyTorch 1.5, CUDA 10.1.243, and GCC 7.
+
+### Requirements
+After cloning this repository, you can start by creating a virtual environment and installing the requirements by running:
+
+```bash
+conda create --name lmpr python=3.6
+source activate lmpr
+conda config --append channels conda-forge
+conda install --file requirements.txt
+conda install -c open3d-admin open3d=0.9.0.0
+conda install -c intel scikit-learn
+conda install pytorch torchvision cudatoolkit=10.1 -c pytorch
+```
+
+Our network uses the [FCGF](https://github.com/chrischoy/FCGF) feature descriptor, which is based on the [MinkowskiEngine](https://github.com/StanfordVL/MinkowskiEngine) library for sparse tensors. In order to install the Minkowski Engine, run:
+
+```bash
+source activate lmpr
+git clone https://github.com/StanfordVL/MinkowskiEngine.git
+cd MinkowskiEngine
+conda install numpy mkl-include
+export CXX=g++-7; python setup.py install
+cd ../
+```
+
+Finally, our network supports furthest point sampling when sampling the interest points. To this end we require the [PointNet++](https://github.com/erikwijmans/Pointnet2_PyTorch/tree/master/pointnet2/models) library, which can be installed as follows:
+
+```bash
+source activate lmpr
+git clone https://github.com/erikwijmans/Pointnet2_PyTorch.git
+cd Pointnet2_PyTorch
+pip install -r requirements.txt
+cd ../
+```
+
+## Pretrained models
+We provide pretrained models for the [FCGF](https://github.com/chrischoy/FCGF) feature descriptor, our pairwise registration block, and the jointly trained pairwise registration model. They can be downloaded using:
+
+```bash
+bash scripts/download_pretrained_models.sh
+```
+
+## Datasets
+### Pairwise registration
+In order to train the pairwise registration model, you have to download the full [3DMatch](http://3dmatch.cs.princeton.edu/) dataset. To (pre)-train the registration blocks you can either download the preprocessed dataset (~160GB) using
+
+```bash
+bash scripts/download_3DMatch_train.sh preprocessed
+```
+
+This dataset contains the pointwise correspondences established in the [FCGF](https://github.com/chrischoy/FCGF) feature space for all point cloud pairs of the 3DMatch dataset.
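+
+Each training example is stored as a compressed NumPy archive (`.npz`). As a quick sanity check you can inspect one example as sketched below; the file name is hypothetical, and the field names follow the precomputed data loader in `lib/data.py`:
+
+```python
+import numpy as np
+
+# Hypothetical example file produced by the preprocessing step
+data = np.load('./data/train_data/3d_match/example_pair.npz')
+
+xs = data['x']               # putative correspondences [n, 6]: source and target coordinates side by side
+ys = data['y']               # per-correspondence distance used to derive the inlier/outlier labels
+R, t = data['R'], data['t']  # ground truth rotation and translation
+print(xs.shape, ys.shape, float(data['overlap']), float(data['inlier_ratio']))
+```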
+ +We also provide the raw 3DMatch data (~4.5GB) and a script to generate the preprocess training data `./scripts/extract_data.py` that can be used either with the raw 3DMatch data or your personal dataset. The raw dataset can be downloaded using + +```bash +bash scripts/download_3DMatch_train.sh raw +``` + +And then processed (extract FCGF feature descriptors, establish the correspondences, and save training data) using: + +```bash +source activate lmpr +python ./scripts/extract_data.py \ + --source_path ./data/train_data/ \ + --target_path ./data/train_data/ \ + --dataset 3d_match \ + --model ./pretrained/fcgf/model_best.pth \ + --voxel_size 0.025 \ + --n_correspondences 20000 \ + --inlier_threshold 0.05 \ + --extract_features \ + --extract_correspondences \ + --extract_precomputed_training_data \ + --with_cuda \ + +bash scripts/download_preprocessed_3DMatch.sh +``` +## Training +Current release supports only training the pairwise registration network. The repository will be further updated in the next weeks. In order to train the pairwise registration network from scratch using the precomputed data run +```bash +source activate lmpr +python train.py ./configs/pairwise_registration/OANet.yaml + +``` +The training parameters can be set in `./configs/pairwise_registration/OANet.yaml`. + +In order to fine tune the pairwise registration network in an end-to-end manner (including the FCGF block), the raw data has to be used. The code to sample the batches will be released in the next weeks. + +## Evaluation +We provide the scripts for the automatic evaluation of our method on the 3DMatch and Redwood dataset. The results for our method can differ slightly from the results in the CVPR paper as we have retrained the model. Due to the different implementation the results for RANSAC might also differ slightly from +the results of the [official evaluation script](https://github.com/andyzeng/3dmatch-toolbox/tree/master/evaluation/geometric-registration). + +### 3DMatch +To evaluate on 3DMatch you can either download the raw evaluation data (~450MB) using +```bash +bash scripts/download_3DMatch_eval.sh raw +``` +or the processed data together with the results for our method and RANSAC (~2.7GB) using +```bash +bash scripts/download_3DMatch_eval.sh preprocessed +``` + +If you download the raw data you first have to process it (extract features and correspondences) + +```bash +source activate lmpr +python ./scripts/extract_data.py \ + --source_path ./data/eval_data/ \ + --target_path ./data/eval_data/ \ + --dataset 3d_match \ + --model ./pretrained/fcgf/model_best.pth \ + --voxel_size 0.025 \ + --n_correspondences 5000 \ + --extract_features \ + --extract_correspondences \ + --with_cuda \ + +bash scripts/download_preprocessed_3DMatch.sh +``` + +Then you can run the pairwise registration evaluation using our RegBlock (with all points) as + +```bash +source activate lmpr +python ./scripts/benchmark_pairwise_registration.py \ + --source ./data/eval_data/3d_match/ \ + --dataset 3d_match \ + --method RegBlock \ + --model ./pretrained/RegBlock/model_best.pt \ + --only_gt_overlap \ +``` + +This script assumes that the correspondences were already estimated using `./scripts/extract_data.py` and only benchmarks the registration algorithms. To improve efficiency the registration parameters will only be computed for the ground truth overlapping pairs, when the flag `--only_gt_overlap` is set. This does not change the registration recall, which is used as the primary evaluation metric. 
In order to run the estimation on all n choose 2 pairs, simply omit the `--only_gt_overlap` flag (this results in ~10 times longer computation time on the 3DMatch dataset).
+
+
+### Redwood
+To evaluate the generalization performance of our method on Redwood, you can again either download the raw evaluation data (~1.3GB) using
+```bash
+bash scripts/download_redwood_eval.sh raw
+```
+or the processed data together with the results for our method and RANSAC (~2.9GB) using
+```bash
+bash scripts/download_redwood_eval.sh preprocessed
+```
+
+The rest of the evaluation follows the same procedure as for 3DMatch; you simply have to replace the dataset argument with `--dataset redwood`.
+
+**NOTE**: With the currently provided model the performance is slightly worse than reported in the paper. The model will be updated in the following days.
\ No newline at end of file
diff --git a/Readme.md b/Readme.md
deleted file mode 100644
index 313ec79..0000000
--- a/Readme.md
+++ /dev/null
@@ -1,2 +0,0 @@
-This repository will contain the source code and pretrained models from the paper [Learning multiview 3D point cloud registration](https://arxiv.org/abs/2001.05119).
-Coming soon!
diff --git a/configs/pairwise_registration/OANet.yaml b/configs/pairwise_registration/OANet.yaml
new file mode 100644
index 0000000..9c1ade5
--- /dev/null
+++ b/configs/pairwise_registration/OANet.yaml
@@ -0,0 +1,69 @@
+method:
+  task: pairwise # Name of the task, one of [pairwise, multiview]
+  descriptor_module: null # Descriptor method to be used. If null, precomputed correspondences are used for training.
+  filter_module: oanet # Filtering method to be used.
+
+misc:
+  run_mode: train # Mode to run the network in
+  net_depth: 12 # Number of layers
+  clusters: 500 # Number of clusters
+  iter_num: 1 # Number of iterations in the iterative network
+  net_channel: 128 # Number of channels in a layer
+  trainer: 'PairwiseTrainer' # Which class of trainer to use. Can be used if multiple different trainers are defined.
+  use_gpu: True # If GPU should be used or not
+  best_validation_metric: loss # Which validation metric to use. 0 Transformation loss, 1 Classification loss
+  log_dir: ./logs/pairwise_registration # Path to the folder where the models and logs will be saved
+  normalize_weights: True # If the inferred per point weights should be normalized to sum to 1 before SVD
+  trans_loss_margin: 0.1 # Value for clipping the transformation loss
+  inlier_weight_threshold: 0.5 # Threshold for determining the inlier/outlier class
+
+data:
+  dataset: Precomputed3DMatchExample
+  root: ./data/train_data/3d_match/
+  dist_th: 0.05 # Distance threshold for ground truth labels
+  shuffle_examples: True # Shuffle training examples in the data loader
+  augment_data: True # If data should be augmented by being transformed with random transformation parameters
+  jitter: True # If jitter should be applied to input point clouds
+  use_mutuals: False # Use only mutual nearest neighbors or all correspondences. 0 do not use, 1 use as filter, 2 use as side info
+  max_num_points: 2000 # Number of keypoints to use per training example
+
+loss:
+  trans_loss_type: 3 # Type of transformation loss: 0 Eucl. distance to correspondence, 1 Fro. norm, 2 Eucl. distance to GT, 3 L1 distance to GT
+  loss_class: 1 # Weight of the classification loss
+  loss_trans: 0.2 # Weight of the transformation loss
+  loss_desc: 0.0 # Weight of the descriptor loss (only used if descriptor module is specified)
+  trans_loss_iter: 15000 # Iteration at which the transformation and confidence loss are added
+  inlier_threshold: 0.05 # GT inlier threshold (if the GT inlier ratio is lower than this threshold the transformation loss will not be backpropagated)
+
+train:
+  batch_size: 8 # Training batch size
+  num_workers: 8 # Number of workers used for the data loader
+  max_epoch: 500 # Max number of training epochs
+  stat_interval: 5 # Interval at which the stats are printed out and saved for the tensorboard (if positive it denotes iterations, if negative epochs)
+  chkpt_interval: 100 # Interval at which the model is saved (if positive it denotes iterations, if negative epochs)
+  val_interval: 300 # Interval at which the validation is performed (if positive it denotes iterations, if negative epochs)
+  model_selection_metric: loss # Metric used to determine the best model on the validation dataset
+  model_selection_mode: minimize # Whether the model selection metric should be minimized or maximized
+  samp_type: rand # Sampler type, one of [rand, fps]
+  corr_type: soft # Feature matching type, one of [soft, hard, gumbel_soft]
+  st_grad_flag: True # If true and soft or gumbel_soft corr are selected, the gradients are propagated straight through (https://arxiv.org/abs/1308.3432)
+  compute_precision: True # Compute precision and recall of the per correspondence classification
+
+val:
+  batch_size: 8 # Validation batch size
+  num_workers: 1 # Number of workers for the validation data set
+
+test:
+  results_path: '' # Path to where to save the test results
+  batch_size: 8 # Test batch size
+  num_workers: 1 # Number of workers to use for the test data set
+
+optimizer:
+  alg: Adam # Which optimizer to use
+  learning_rate: 0.0001 # Initial learning rate
+  weight_decay: 0.0 # Weight decay coefficient
+  momentum: 0.9 # Momentum
+
+model:
+  init_from: null # Path to the pretrained model
+  init_file_name: model_best.pt # Pretrained model filename
diff --git a/figures/LM3DPCR.jpg b/figures/LM3DPCR.jpg
new file mode 100644
index 0000000..e2eaa3c
Binary files /dev/null and b/figures/LM3DPCR.jpg differ
diff --git a/lib/checkpoints.py b/lib/checkpoints.py
new file mode 100644
index 0000000..a362b19
--- /dev/null
+++ b/lib/checkpoints.py
@@ -0,0 +1,127 @@
+"""
+A collection of functions used for saving and loading the model weights.
+Based on the Occupancy Networks repository: https://github.com/autonomousvision/occupancy_networks
+
+If you use it in your project, please consider also citing: http://www.cvlibs.net/publications/Mescheder2019CVPR.pdf
+"""
+
+import os
+import urllib
+import torch
+from torch.utils import model_zoo
+
+class CheckpointIO(object):
+    ''' CheckpointIO class.
+    It handles saving and loading checkpoints.
+    Args:
+        checkpoint_dir (str): path where checkpoints are saved
+    '''
+    def __init__(self, checkpoint_dir='./chkpts', initialize_from=None,
+                 initialization_file_name='model_best.pt', **kwargs):
+        self.module_dict = kwargs
+        self.checkpoint_dir = checkpoint_dir
+        self.initialize_from = initialize_from
+        self.initialization_file_name = initialization_file_name
+        if not os.path.exists(checkpoint_dir):
+            os.makedirs(checkpoint_dir)
+
+    def register_modules(self, **kwargs):
+        ''' Registers modules in current module dictionary.
+ ''' + self.module_dict.update(kwargs) + + def save(self, filename, **kwargs): + ''' Saves the current module dictionary. + Args: + filename (str): name of output file + ''' + if not os.path.isabs(filename): + filename = os.path.join(self.checkpoint_dir, filename) + + outdict = kwargs + for k, v in self.module_dict.items(): + outdict[k] = v.state_dict() + torch.save(outdict, filename) + + def load(self, filename): + '''Loads a module dictionary from local file or url. + Args: + filename (str): name of saved module dictionary + ''' + if is_url(filename): + return self.load_url(filename) + else: + return self.load_file(filename) + + def load_file(self, filename): + '''Loads a module dictionary from file. + Args: + filename (str): name of saved module dictionary + ''' + + if not os.path.isabs(filename): + filename = os.path.join(self.checkpoint_dir, filename) + + if os.path.exists(filename): + print(filename) + print('=> Loading checkpoint from local file...') + state_dict = torch.load(filename) + scalars = self.parse_state_dict(state_dict) + return scalars + else: + if self.initialize_from is not None: + self.initialize_weights() + raise FileExistsError + + def load_url(self, url): + '''Load a module dictionary from url. + Args: + url (str): url to saved model + ''' + print(url) + print('=> Loading checkpoint from url...') + state_dict = model_zoo.load_url(url, progress=True) + scalars = self.parse_state_dict(state_dict) + return scalars + + def parse_state_dict(self, state_dict): + '''Parse state_dict of model and return scalars. + Args: + state_dict (dict): State dict of model + ''' + + for k, v in self.module_dict.items(): + if k in state_dict: + v.load_state_dict(state_dict[k]) + else: + print('Warning: Could not find %s in checkpoint!' % k) + scalars = {k: v for k, v in state_dict.items() + if k not in self.module_dict} + return scalars + + def initialize_weights(self): + ''' Initializes the model weights from another model file. + ''' + + print('Intializing weights from model %s' % self.initialize_from) + filename_in = os.path.join( + self.initialize_from, self.initialization_file_name) + + model_state_dict = self.module_dict.get('model').state_dict() + model_dict = self.module_dict.get('model').state_dict() + model_keys = set([k for (k, v) in model_dict.items()]) + + init_model_dict = torch.load(filename_in)['model'] + init_model_k = set([k for (k, v) in init_model_dict.items()]) + + for k in model_keys: + if ((k in init_model_k) and (model_state_dict[k].shape == + init_model_dict[k].shape)): + model_state_dict[k] = init_model_dict[k] + self.module_dict.get('model').load_state_dict(model_state_dict) + + +def is_url(url): + ''' Checks if input is url.''' + scheme = urllib.parse.urlparse(url).scheme + return scheme in ('http', 'https') \ No newline at end of file diff --git a/lib/config.py b/lib/config.py new file mode 100644 index 0000000..e43cf05 --- /dev/null +++ b/lib/config.py @@ -0,0 +1,48 @@ + +from lib import pairwise, multiview +import torch + +method_dict = { + 'pairwise': pairwise, + 'multiview': multiview, +} + + +def get_model(cfg): + ''' + Gets the model instance based on the input paramters. 
+ + Args: + cfg (dict): config dictionary + + Returns: + model (nn.Module): torch model initialized with the input params + ''' + + method = cfg['method']['task'] + device = torch.device('cuda' if (torch.cuda.is_available() and cfg['misc']['use_gpu']) else 'cpu') + + model = method_dict[method].config.get_model(cfg, device=device) + + return model + +def get_trainer(cfg, model, optimizer, logger): + ''' + Returns a trainer instance. + + Args: + cfg (dict): config dictionary + model (nn.Module): the model used for training + optimizer (optimizer): pytorch optimizer + logger (logger instance): logger used to output info to the consol + + Returns: + trainer (trainer instance): trainer instance used to train the network + ''' + + method = cfg['method']['task'] + device = torch.device('cuda' if (torch.cuda.is_available() and cfg['misc']['use_gpu']) else 'cpu') + + trainer = method_dict[method].config.get_trainer(cfg, model, optimizer, logger, device) + + return trainer diff --git a/lib/data.py b/lib/data.py new file mode 100644 index 0000000..2823249 --- /dev/null +++ b/lib/data.py @@ -0,0 +1,390 @@ +import numpy as np +import glob +import logging +import os +import torch.utils.data as data +from lib.utils import augment_precomputed_data, add_jitter, get_file_list, get_folder_list, read_trajectory +import torch + +def collate_fn(batch): + data = {} + + def to_tensor(x): + """ + Maps the numpy arrays to torch tenors. In torch tenor is used as input + it simply returns it. + + Args: + x (numpy array): numpy array input data + + Returns: + x (torch tensor): input converted to a torch tensor + """ + if isinstance(x, torch.Tensor): + return x + elif isinstance(x, np.ndarray): + return torch.from_numpy(x).float() + else: + raise ValueError(f'Can not convert to torch tensor, {x}') + + + for key in batch[0]: + data[key] = [] + + for sample in batch: + for key in sample: + if isinstance(sample[key], list): + data[key].append(sample[key]) + else: + data[key].append(to_tensor(sample[key])) + + for key in data: + if isinstance(data[key][0], torch.Tensor): + data[key] = torch.stack(data[key]) + + return data + + +class PrecomputedIndoorDataset(data.Dataset): + """ + Dataset class for precomputed FCGF descrptors and established correspondences. Used to train the Registration blocks, + confidence block, and transformation synchronization layer. If using this dataset the method descriptor argument and + train_desc flag in the arguments have to be set to null and False respectively. 
Eases the training of the second part + as the overlaping pairs that build a graph can be presampled + + """ + + def __init__(self, phase, config): + + self.files = [] + self.random_shuffle = config['data']['shuffle_examples'] + self.root = config['data']['root'] + self.config = config + self.data = None + self.randng = np.random.RandomState() + self.use_mutuals = config['data']['use_mutuals'] + self.dist_th = config['data']['dist_th'] + self.max_num_points = config['data']['max_num_points'] + self.augment_data = config['data']['augment_data'] + self.jitter = config['data']['jitter'] + + self.device = torch.device('cuda' if (torch.cuda.is_available() and config['misc']['use_gpu']) else 'cpu') + + logging.info("Loading the subset {} from {}!".format(phase,self.root)) + + subset_names = open(self.DATA_FILES[phase]).read().split() + + for name in subset_names: + self.files.append(name) + + def __getitem__(self, idx): + + file = os.path.join(self.root,self.files[idx]) + data = np.load(file) + + file_name = file.replace(os.sep,'/').split('/')[-1] + + xs = data['x'] + ys = data['y'] + Rs = data['R'] + ts = np.expand_dims(data['t'], -1) + + mutuals = data['mutuals'] + inlier_ratio = data['inlier_ratio'] # Thresholded at 5cm deviation + inlier_ratio_mutuals = data['inlier_ratio_mutuals'] # Thresholded at 5cm deviation + overlap = data['overlap'] # Max overlap + + + # Shuffle the examples + if self.random_shuffle: + if xs.shape[0] >= self.max_num_points: + sample_idx = np.random.choice(xs.shape[0], self.max_num_points, replace=False) + else: + sample_idx = np.concatenate((np.arange(xs.shape[0]), + np.random.choice(xs.shape[0], self.max_num_points-xs.shape[0], replace=True)),axis=-1) + + xs = xs[sample_idx,:] + ys = ys[sample_idx] + mutuals = mutuals[sample_idx] + + + # Check if the the mutuals or the ratios should be used + side = [] + if self.use_mutuals == 0: + pass + elif self.use_mutuals == 1: + mask = mutuals.reshape(-1).astype(bool) + xs = xs[mask,:] + ys = ys[mask] + elif self.use_mutuals == 2: + side.append(mutuals.reshape(-1,1)) + side = np.concatenate(side,axis=-1) + else: + raise NotImplementedError + + # Augment the data augmentation + if self.augment_data: + xs, Rs, ts = augment_precomputed_data(xs, Rs, ts) + + if self.jitter: + xs, ys = add_jitter(xs, Rs, ts) + + + # Threshold ys based on the distance threshol + ys_binary = (ys < self.dist_th).astype(xs.dtype) + + if not side: + side = np.array([0]) + + # Prepare data + xs = np.expand_dims(xs,0) + ys = np.expand_dims(ys_binary,-1) + + + return {'R': Rs, + 't': ts, + 'xs': xs, + 'ys': ys, + 'side': side, + 'overlap': overlap, + 'inlier_ratio': inlier_ratio} + + def __len__(self): + return len(self.files) + + def reset_seed(self,seed=41): + logging.info('Resetting the data loader seed to {}'.format(seed)) + self.randng.seed(seed) + + + + + + +class PrecomputedPairwiseEvalDataset(data.Dataset): + """ + Dataset class for evaluating the pairwise registration based on the precomputed feature correspondences + + """ + def __init__(self, args): + + self.files = [] + self.root = args.source_path + self.use_mutuals = args.mutuals + save_path = os.path.join(self.root, 'results', args.method) + save_path += '/mutuals/' if args.mutuals else '/all/' + + + logging.info("Loading the eval data from {}!".format(self.root)) + + scene_names = get_folder_list(os.path.join(self.root,'correspondences')) + + for folder in scene_names: + curr_scene_name = folder.split('/')[-1] + if os.path.exists(os.path.join(save_path,curr_scene_name,'traj.txt')) and not 
args.overwrite: + logging.info('Trajectory for scene {} already exists and will not be recomputed.'.format(curr_scene_name)) + else: + if args.only_gt_overlaping: + gt_pairs, gt_traj = read_trajectory(os.path.join(self.root,'raw_data', curr_scene_name, "gt.log")) + for idx_1, idx_2, _ in gt_pairs: + self.files.append(os.path.join(folder,curr_scene_name + '_{}_{}.npz'.format(str(idx_1).zfill(3),str(idx_2).zfill(3)))) + + else: + corr_files = get_file_list(folder) + for corr in corr_files: + self.files.append(corr) + + def __getitem__(self, idx): + + curr_file = os.path.join(self.files[idx]) + data = np.load(curr_file) + + idx_1 = str(curr_file.split('_')[-2]) + idx_2 = str(curr_file.split('_')[-1].split('.')[0]) + curr_scene_name = curr_file.split('/')[-2] + metadata = [curr_scene_name, idx_1, idx_2] + xs = data['x'] + + xyz1 = np.load(os.path.join(self.root,'features',curr_scene_name, curr_scene_name + '_{}.npz'.format(idx_1)))['xyz'] + xyz2 = np.load(os.path.join(self.root,'features',curr_scene_name, curr_scene_name + '_{}.npz'.format(idx_2)))['xyz'] + + if self.use_mutuals == 1: + mutuals = data['mutuals'] + xs = xs[mutuals.astype(bool).reshape(-1), :] + + return {'xs': np.expand_dims(xs,0), + 'metadata':metadata, + 'idx':np.array(idx), + 'xyz1': [xyz1], + 'xyz2': [xyz2]} + + def __len__(self): + test = len(self.files) + return len(self.files) + + def reset_seed(self,seed=41): + logging.info('Resetting the data loader seed to {}'.format(seed)) + self.randng.seed(seed) + + + + + + +### NEEDS TO BE IMPLEMENTED ### +class RawIndoorDataset(data.Dataset): + def __init__(self, phase, config): + + self.files = [] + self.random_shuffle = config['data']['shuffle_examples'] + self.root = config['data']['root'] + self.config = config + self.data = None + self.randng = np.random.RandomState() + self.use_ratio = config['misc']['use_ratio'] + self.use_ratio_tf = config['misc']['use_ratio_th'] + self.use_mutuals = config['misc']['use_mutuals'] + self.dist_th = config['misc']['dist_th'] + self.max_num_points = config['misc']['max_num_points'] + + self.device = torch.device('cuda' if (torch.cuda.is_available() and config['misc']['use_gpu']) else 'cpu') + + logging.info("Loading the subset {} from {}!".format(phase,self.root)) + + + subset_names = open(self.DATA_FILES[phase]).read().split() + + for name in subset_names: + self.files.append(name) + + def __getitem__(self, idx): + + file = os.path.join(self.root,self.files[idx]) + data = np.load(file) + + file_name = file.replace(os.sep,'/').split('/')[-1] + + xs = data['x'] + ys = data['y'] + Rs = data['R'] + ts = data['t'] + ratios = data['ratios'] + mutuals = data['mutuals'] + inlier_ratio = data['inlier_ratio'] # Thresholded at 5cm deviation + inlier_ratio_mutuals = data['inlier_ratio_mutuals'] # Thresholded at 5cm deviation + overlap = data['overlap'] # Max overlap + + + # Shuffle the examples + if self.random_shuffle: + if xs.shape[0] >= self.max_num_points: + sample_idx = np.random.choice(xs.shape[0], self.max_num_points, replace=False) + else: + sample_idx = np.concatenate((np.arange(xs.shape[0]), + np.random.choice(xs.shape[0], self.max_num_points-xs.shape[0], replace=True)),axis=-1) + + xs = xs[sample_idx,:] + ys = ys[sample_idx] + ratios = ratios[sample_idx] + mutuals = mutuals[sample_idx] + + + # Check if the the mutuals or the ratios should be used + side = [] + if self.use_ratio == 0 and self.use_mutuals == 0: + pass + elif self.use_ratio == 1 and self.use_mutuals == 0: + mask = ratios.reshape(-1) < self.use_ratio_tf + xs = xs[mask,:] + ys = 
ys[mask] + elif self.use_ratio == 0 and self.use_mutuals == 1: + mask = mutuals.reshape(-1).astype(bool) + xs = xs[mask,:] + ys = ys[mask] + elif self.use_ratio == 2 and self.use_mutuals == 2: + side.append(ratios.reshape(-1,1)) + side.append(mutuals.reshape(-1,1)) + side = np.concatenate(side,axis=-1) + else: + raise NotImplementedError + + # Threshold ys based on the distance threshol + ys_binary = (ys < self.dist_th).astype(xs.dtype) + + if not side: + side = np.array([0]) + + + return {'R': Rs, + 't': np.expand_dims(ts, -1), + 'xs': np.expand_dims(xs, 0), + 'ys': np.expand_dims(ys_binary, -1), + 'side': side, + 'overlap': overlap, + 'inlier_ratio': inlier_ratio} + + + def __len__(self): + return len(self.files) + + def reset_seed(self,seed=41): + logging.info('Resetting the data loader seed to {}'.format(seed)) + self.randng.seed(seed) + +class Precomputed3DMatch(PrecomputedIndoorDataset): + # 3D Match dataset all files + DATA_FILES = { + 'train': './configs/3DMatch/3DMatch_all_train.txt', + 'val': './configs/3DMatch/3DMatch_all_valid.txt', + 'test': './configs/3DMatch/test_all.txt' + } + +class Precomputed3DMatchFiltered(PrecomputedIndoorDataset): + # 3D Match dataset with only overlaping point cloud and with examples + # that have more than 5% inliers (see dataset readme for more info) + + DATA_FILES = { + 'train': './configs/3DMatch_new/3DMatch_filtered_train.txt', + 'val': './configs/3DMatch_new/3DMatch_filtered_valid.txt', + 'test': './configs/3DMatch/test_all.txt' + } + + +# Map the datasets to string names +ALL_DATASETS = [Precomputed3DMatch, Precomputed3DMatchFiltered] +dataset_str_mapping = {d.__name__: d for d in ALL_DATASETS} + + +def make_data_loader(config, phase, shuffle_dataset=None): + """ + Defines the data loader based on the parameters specified in the config file + Args: + config (dict): dictionary of the arguments + phase (str): phase for which the data loader should be initialized in [train,val,test] + shuffle_dataset (bool): shuffle the dataset or not + + Returns: + loader (torch data loader): data loader that handles loading the data to the model + """ + + assert config['misc']['run_mode'] in ['train','val','test'] + + if shuffle_dataset is None: + shuffle_dataset = shuffle_dataset != 'test' + + # Select the defined dataset + Dataset = dataset_str_mapping[config['data']['dataset']] + + dset = Dataset(phase, config=config) + + loader = torch.utils.data.DataLoader( + dset, + batch_size=config[phase]['batch_size'], + shuffle=shuffle_dataset, + num_workers=config[phase]['num_workers'], + collate_fn=collate_fn, + pin_memory=False, + drop_last=True + ) + + return loader diff --git a/lib/descriptor/__init__.py b/lib/descriptor/__init__.py new file mode 100644 index 0000000..94acab5 --- /dev/null +++ b/lib/descriptor/__init__.py @@ -0,0 +1,6 @@ +from lib.descriptor import fcgf + +descriptor_dict = { + 'fcgf': fcgf.FCGFNet, +} + diff --git a/lib/descriptor/common.py b/lib/descriptor/common.py new file mode 100644 index 0000000..a8aead3 --- /dev/null +++ b/lib/descriptor/common.py @@ -0,0 +1,10 @@ +import MinkowskiEngine as ME + + +def get_norm(norm_type, num_feats, bn_momentum=0.05, D=-1): + if norm_type == 'BN': + return ME.MinkowskiBatchNorm(num_feats, momentum=bn_momentum) + elif norm_type == 'IN': + return ME.MinkowskiInstanceNorm(num_feats, dimension=D) + else: + raise ValueError(f'Type {norm_type}, not defined') diff --git a/lib/descriptor/fcgf.py b/lib/descriptor/fcgf.py new file mode 100644 index 0000000..8baf89b --- /dev/null +++ b/lib/descriptor/fcgf.py @@ -0,0 
+1,284 @@ +""" +Source code of the Fully Convolutional Feature Descriptor (ICCV 2019), based on the FCGF repository: https://github.com/chrischoy/FCGF/ + +If you use in your project please consider also citing: https://node1.chrischoy.org/data/publications/fcgf/fcgf.pdf + +""" + +# -*- coding: future_fstrings -*- +import torch +import torch.nn as nn +import MinkowskiEngine as ME +import MinkowskiEngine.MinkowskiFunctional as MEF + +def get_norm(norm_type, num_feats, bn_momentum=0.05, D=-1): + if norm_type == 'BN': + return ME.MinkowskiBatchNorm(num_feats, momentum=bn_momentum) + elif norm_type == 'IN': + return ME.MinkowskiInstanceNorm(num_feats, dimension=D) + else: + raise ValueError(f'Type {norm_type}, not defined') + + +class BasicBlockBase(nn.Module): + expansion = 1 + NORM_TYPE = 'BN' + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + bn_momentum=0.1, + D=3): + super(BasicBlockBase, self).__init__() + + self.conv1 = ME.MinkowskiConvolution( + inplanes, planes, kernel_size=3, stride=stride, dimension=D) + self.norm1 = get_norm(self.NORM_TYPE, planes, bn_momentum=bn_momentum, D=D) + self.conv2 = ME.MinkowskiConvolution( + planes, + planes, + kernel_size=3, + stride=1, + dilation=dilation, + has_bias=False, + dimension=D) + self.norm2 = get_norm(self.NORM_TYPE, planes, bn_momentum=bn_momentum, D=D) + self.downsample = downsample + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.norm1(out) + out = MEF.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = MEF.relu(out) + + return out + + +class BasicBlockBN(BasicBlockBase): + NORM_TYPE = 'BN' + + +class BasicBlockIN(BasicBlockBase): + NORM_TYPE = 'IN' + + +def get_block(norm_type, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + bn_momentum=0.1, + D=3): + if norm_type == 'BN': + return BasicBlockBN(inplanes, planes, stride, dilation, downsample, bn_momentum, D) + elif norm_type == 'IN': + return BasicBlockIN(inplanes, planes, stride, dilation, downsample, bn_momentum, D) + else: + raise ValueError(f'Type {norm_type}, not defined') + + + + + +class FCGFNet(ME.MinkowskiNetwork): + NORM_TYPE = 'BN' + BLOCK_NORM_TYPE = 'BN' + CHANNELS = [None, 32, 64, 128, 256] + TR_CHANNELS = [None, 64, 64, 64, 128] + + # To use the model, must call initialize_coords before forward pass. 
+ # Once data is processed, call clear to reset the model before calling initialize_coords + def __init__(self, + in_channels=3, + out_channels=32, + bn_momentum=0.1, + normalize_feature=True, + conv1_kernel_size=7, + D=3): + + ME.MinkowskiNetwork.__init__(self, D) + NORM_TYPE = self.NORM_TYPE + BLOCK_NORM_TYPE = self.BLOCK_NORM_TYPE + CHANNELS = self.CHANNELS + TR_CHANNELS = self.TR_CHANNELS + self.normalize_feature = normalize_feature + self.conv1 = ME.MinkowskiConvolution( + in_channels=in_channels, + out_channels=CHANNELS[1], + kernel_size=conv1_kernel_size, + stride=1, + dilation=1, + has_bias=False, + dimension=D) + + self.norm1 = get_norm(NORM_TYPE, CHANNELS[1], bn_momentum=bn_momentum, D=D) + + self.block1 = get_block( + BLOCK_NORM_TYPE, CHANNELS[1], CHANNELS[1], bn_momentum=bn_momentum, D=D) + + self.conv2 = ME.MinkowskiConvolution( + in_channels=CHANNELS[1], + out_channels=CHANNELS[2], + kernel_size=3, + stride=2, + dilation=1, + has_bias=False, + dimension=D) + self.norm2 = get_norm(NORM_TYPE, CHANNELS[2], bn_momentum=bn_momentum, D=D) + + self.block2 = get_block( + BLOCK_NORM_TYPE, CHANNELS[2], CHANNELS[2], bn_momentum=bn_momentum, D=D) + + self.conv3 = ME.MinkowskiConvolution( + in_channels=CHANNELS[2], + out_channels=CHANNELS[3], + kernel_size=3, + stride=2, + dilation=1, + has_bias=False, + dimension=D) + self.norm3 = get_norm(NORM_TYPE, CHANNELS[3], bn_momentum=bn_momentum, D=D) + + self.block3 = get_block( + BLOCK_NORM_TYPE, CHANNELS[3], CHANNELS[3], bn_momentum=bn_momentum, D=D) + + self.conv4 = ME.MinkowskiConvolution( + in_channels=CHANNELS[3], + out_channels=CHANNELS[4], + kernel_size=3, + stride=2, + dilation=1, + has_bias=False, + dimension=D) + self.norm4 = get_norm(NORM_TYPE, CHANNELS[4], bn_momentum=bn_momentum, D=D) + + self.block4 = get_block( + BLOCK_NORM_TYPE, CHANNELS[4], CHANNELS[4], bn_momentum=bn_momentum, D=D) + + self.conv4_tr = ME.MinkowskiConvolutionTranspose( + in_channels=CHANNELS[4], + out_channels=TR_CHANNELS[4], + kernel_size=3, + stride=2, + dilation=1, + has_bias=False, + dimension=D) + self.norm4_tr = get_norm(NORM_TYPE, TR_CHANNELS[4], bn_momentum=bn_momentum, D=D) + + self.block4_tr = get_block( + BLOCK_NORM_TYPE, TR_CHANNELS[4], TR_CHANNELS[4], bn_momentum=bn_momentum, D=D) + + self.conv3_tr = ME.MinkowskiConvolutionTranspose( + in_channels=CHANNELS[3] + TR_CHANNELS[4], + out_channels=TR_CHANNELS[3], + kernel_size=3, + stride=2, + dilation=1, + has_bias=False, + dimension=D) + self.norm3_tr = get_norm(NORM_TYPE, TR_CHANNELS[3], bn_momentum=bn_momentum, D=D) + + self.block3_tr = get_block( + BLOCK_NORM_TYPE, TR_CHANNELS[3], TR_CHANNELS[3], bn_momentum=bn_momentum, D=D) + + self.conv2_tr = ME.MinkowskiConvolutionTranspose( + in_channels=CHANNELS[2] + TR_CHANNELS[3], + out_channels=TR_CHANNELS[2], + kernel_size=3, + stride=2, + dilation=1, + has_bias=False, + dimension=D) + self.norm2_tr = get_norm(NORM_TYPE, TR_CHANNELS[2], bn_momentum=bn_momentum, D=D) + + self.block2_tr = get_block( + BLOCK_NORM_TYPE, TR_CHANNELS[2], TR_CHANNELS[2], bn_momentum=bn_momentum, D=D) + + self.conv1_tr = ME.MinkowskiConvolution( + in_channels=CHANNELS[1] + TR_CHANNELS[2], + out_channels=TR_CHANNELS[1], + kernel_size=1, + stride=1, + dilation=1, + has_bias=False, + dimension=D) + + # self.block1_tr = BasicBlockBN(TR_CHANNELS[1], TR_CHANNELS[1], bn_momentum=bn_momentum, D=D) + + self.final = ME.MinkowskiConvolution( + in_channels=TR_CHANNELS[1], + out_channels=out_channels, + kernel_size=1, + stride=1, + dilation=1, + has_bias=True, + dimension=D) + + def 
forward(self, x): + out_s1 = self.conv1(x) + out_s1 = self.norm1(out_s1) + out_s1 = self.block1(out_s1) + out = MEF.relu(out_s1) + + out_s2 = self.conv2(out) + out_s2 = self.norm2(out_s2) + out_s2 = self.block2(out_s2) + out = MEF.relu(out_s2) + + out_s4 = self.conv3(out) + out_s4 = self.norm3(out_s4) + out_s4 = self.block3(out_s4) + out = MEF.relu(out_s4) + + out_s8 = self.conv4(out) + out_s8 = self.norm4(out_s8) + out_s8 = self.block4(out_s8) + out = MEF.relu(out_s8) + + out = self.conv4_tr(out) + out = self.norm4_tr(out) + out = self.block4_tr(out) + out_s4_tr = MEF.relu(out) + + out = ME.cat((out_s4_tr, out_s4)) + + out = self.conv3_tr(out) + out = self.norm3_tr(out) + out = self.block3_tr(out) + out_s2_tr = MEF.relu(out) + + out = ME.cat((out_s2_tr, out_s2)) + + out = self.conv2_tr(out) + out = self.norm2_tr(out) + out = self.block2_tr(out) + out_s1_tr = MEF.relu(out) + + out = ME.cat((out_s1_tr, out_s1)) + out = self.conv1_tr(out) + out = MEF.relu(out) + out = self.final(out) + + if self.normalize_feature: + return ME.SparseTensor( + out.F / torch.norm(out.F, p=2, dim=1, keepdim=True), + coords_key=out.coords_key, + coords_manager=out.coords_man) + else: + + return out + diff --git a/lib/descriptor/residual_block.py b/lib/descriptor/residual_block.py new file mode 100644 index 0000000..45d275e --- /dev/null +++ b/lib/descriptor/residual_block.py @@ -0,0 +1,77 @@ +import torch.nn as nn + +from fcgf_lib.model.common import get_norm + +import MinkowskiEngine as ME +import MinkowskiEngine.MinkowskiFunctional as MEF + + +class BasicBlockBase(nn.Module): + expansion = 1 + NORM_TYPE = 'BN' + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + bn_momentum=0.1, + D=3): + super(BasicBlockBase, self).__init__() + + self.conv1 = ME.MinkowskiConvolution( + inplanes, planes, kernel_size=3, stride=stride, dimension=D) + self.norm1 = get_norm(self.NORM_TYPE, planes, bn_momentum=bn_momentum, D=D) + self.conv2 = ME.MinkowskiConvolution( + planes, + planes, + kernel_size=3, + stride=1, + dilation=dilation, + has_bias=False, + dimension=D) + self.norm2 = get_norm(self.NORM_TYPE, planes, bn_momentum=bn_momentum, D=D) + self.downsample = downsample + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.norm1(out) + out = MEF.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = MEF.relu(out) + + return out + + +class BasicBlockBN(BasicBlockBase): + NORM_TYPE = 'BN' + + +class BasicBlockIN(BasicBlockBase): + NORM_TYPE = 'IN' + + +def get_block(norm_type, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + bn_momentum=0.1, + D=3): + if norm_type == 'BN': + return BasicBlockBN(inplanes, planes, stride, dilation, downsample, bn_momentum, D) + elif norm_type == 'IN': + return BasicBlockIN(inplanes, planes, stride, dilation, downsample, bn_momentum, D) + else: + raise ValueError(f'Type {norm_type}, not defined') diff --git a/lib/filtering/__init__.py b/lib/filtering/__init__.py new file mode 100644 index 0000000..22b6854 --- /dev/null +++ b/lib/filtering/__init__.py @@ -0,0 +1,6 @@ +from lib.filtering import , oanet + + +filtering_dict = { + 'oanet': oanet.OANet +} diff --git a/lib/filtering/oanet.py b/lib/filtering/oanet.py new file mode 100644 index 0000000..9c511b6 --- /dev/null +++ b/lib/filtering/oanet.py @@ -0,0 +1,265 @@ +""" +Extension of the filtering network proposed in Learning Two-View Correspondences and Geometry Using 
Order-Aware Network (ICCV 2019), +to 3D correspondence filtering. Source coude based on the OANet repository: https://github.com/zjhthu/OANet. + +If you use in your project pelase consider also citing: https://arxiv.org/pdf/1908.04964.pdf + +""" + +import torch +import torch.nn as nn +from lib.utils import kabsch_transformation_estimation +import logging + +# If the BN stat should be tracked and used in the inference mode +BN_TRACK_STATS = True + + +class PointCN(nn.Module): + def __init__(self, channels, out_channels=None): + nn.Module.__init__(self) + if not out_channels: + out_channels = channels + self.shot_cut = None + if out_channels != channels: + self.shot_cut = nn.Conv2d(channels, out_channels, kernel_size=1) + self.conv = nn.Sequential( + nn.InstanceNorm2d(channels), + nn.BatchNorm2d(channels, track_running_stats=BN_TRACK_STATS), + nn.ReLU(), + nn.Conv2d(channels, out_channels, kernel_size=1), + nn.InstanceNorm2d(out_channels), + nn.BatchNorm2d(out_channels, track_running_stats=BN_TRACK_STATS), + nn.ReLU(), + nn.Conv2d(out_channels, out_channels, kernel_size=1) + ) + + def forward(self, x): + out = self.conv(x) + if self.shot_cut: + out = out + self.shot_cut(x) + else: + out = out + x + return out + + +class trans(nn.Module): + def __init__(self, dim1, dim2): + nn.Module.__init__(self) + self.dim1 = dim1 + self.dim2 = dim2 + + def forward(self, x): + return x.transpose(self.dim1, self.dim2) + + +class OAFilter(nn.Module): + def __init__(self, channels, points, out_channels=None): + nn.Module.__init__(self) + if not out_channels: + out_channels = channels + self.shot_cut = None + if out_channels != channels: + self.shot_cut = nn.Conv2d(channels, out_channels, kernel_size=1) + self.conv1 = nn.Sequential( + nn.InstanceNorm2d(channels, eps=1e-3), + nn.BatchNorm2d(channels, track_running_stats=BN_TRACK_STATS), + nn.ReLU(), + nn.Conv2d(channels, out_channels, kernel_size=1), # b*c*n*1 + trans(1, 2)) + + # Spatial Correlation Layer + self.conv2 = nn.Sequential( + nn.BatchNorm2d(points, track_running_stats=BN_TRACK_STATS), + nn.ReLU(), + nn.Conv2d(points, points, kernel_size=1) + ) + self.conv3 = nn.Sequential( + trans(1, 2), + nn.InstanceNorm2d(out_channels, eps=1e-3), + nn.BatchNorm2d(out_channels, track_running_stats=BN_TRACK_STATS), + nn.ReLU(), + nn.Conv2d(out_channels, out_channels, kernel_size=1) + ) + + def forward(self, x): + out = self.conv1(x) + out = out + self.conv2(out) + out = self.conv3(out) + if self.shot_cut: + out = out + self.shot_cut(x) + else: + out = out + x + return out + + +class diff_pool(nn.Module): + def __init__(self, in_channel, output_points): + nn.Module.__init__(self) + self.output_points = output_points + self.conv = nn.Sequential( + nn.InstanceNorm2d(in_channel, eps=1e-3), + nn.BatchNorm2d(in_channel, track_running_stats=BN_TRACK_STATS), + nn.ReLU(), + nn.Conv2d(in_channel, output_points, kernel_size=1)) + + def forward(self, x): + embed = self.conv(x) # b*k*n*1 + S = torch.softmax(embed, dim=2).squeeze(3) + out = torch.matmul(x.squeeze(3), S.transpose(1, 2)).unsqueeze(3) + return out + + +class diff_unpool(nn.Module): + def __init__(self, in_channel, output_points): + nn.Module.__init__(self) + self.output_points = output_points + self.conv = nn.Sequential( + nn.InstanceNorm2d(in_channel, eps=1e-3), + nn.BatchNorm2d(in_channel, track_running_stats=BN_TRACK_STATS), + nn.ReLU(), + nn.Conv2d(in_channel, output_points, kernel_size=1)) + + def forward(self, x_up, x_down): + #x_up: b*c*n*1 + #x_down: b*c*k*1 + embed = self.conv(x_up) # b*k*n*1 + S = 
torch.softmax(embed, dim=1).squeeze(3) # b*k*n + out = torch.matmul(x_down.squeeze(3), S).unsqueeze(3) + return out + + +class OANBlock(nn.Module): + def __init__(self, net_channels, input_channel, depth, clusters, normalize_w): + nn.Module.__init__(self) + channels = net_channels + self.layer_num = depth + logging.info('OANET: channels:' + str(channels) + ', layer_num:' + str(self.layer_num)) + self.conv1 = nn.Conv2d(input_channel, channels, kernel_size=1) + + l2_nums = clusters + + self.l1_1 = [] + for _ in range(self.layer_num//2): + self.l1_1.append(PointCN(channels)) + + self.down1 = diff_pool(channels, l2_nums) + + self.l2 = [] + for _ in range(self.layer_num//2): + self.l2.append(OAFilter(channels, l2_nums)) + + self.up1 = diff_unpool(channels, l2_nums) + + self.l1_2 = [] + self.l1_2.append(PointCN(2*channels, channels)) + for _ in range(self.layer_num//2-1): + self.l1_2.append(PointCN(channels)) + + self.l1_1 = nn.Sequential(*self.l1_1) + self.l1_2 = nn.Sequential(*self.l1_2) + self.l2 = nn.Sequential(*self.l2) + + self.output = nn.Conv2d(channels, 1, kernel_size=1) + + def forward(self, data, xs): + #data: b*c*n*1 + x1_1 = self.conv1(data) + x1_1 = self.l1_1(x1_1) + x_down = self.down1(x1_1) + x2 = self.l2(x_down) + x_up = self.up1(x1_1, x2) + out = self.l1_2(torch.cat([x1_1, x_up], dim=1)) + + logits = torch.squeeze(torch.squeeze(self.output(out), 3), 1) + weights = torch.relu(torch.tanh(logits)) + + if torch.any(torch.sum(weights, dim=1) == 0.0): + weights = weights + 1/weights.shape[1] + + x1, x2 = xs[:, 0, :, :3], xs[:, 0, :, 3:] + + rotation_est, translation_est, residuals, gradient_not_valid = kabsch_transformation_estimation( + x1, x2, weights) + + return logits, weights, rotation_est, translation_est, residuals, out, gradient_not_valid + + +class OANet(nn.Module): + """ + OANet filtering class. Build an OAnet object that represent the extension of the filtering network proposed in + (https://arxiv.org/abs/1908.04964) to the problem of 3D correspondence filtering. + + The local context is aggregated using Context normalization and Order-Aware clustering blocks. + + Args: + cfg (dict): configuration parameter + + """ + def __init__(self, cfg): + nn.Module.__init__(self + ) + + self.iter_num = cfg['misc']['iter_num'] + depth_each_stage = cfg['misc']['net_depth']//(cfg['misc']['iter_num']+1) + self.side_channel = (cfg['data']['use_mutuals'] == 2) + + self.reg_init = OANBlock(cfg['misc']['net_channel'], 6 + self.side_channel, + depth_each_stage, cfg['misc']['clusters'], cfg['misc']['normalize_weights']) + + self.reg_iter = [OANBlock(cfg['misc']['net_channel'], 8 + self.side_channel, depth_each_stage, cfg['misc']['clusters'], cfg['misc']['normalize_weights']) + for _ in range(self.iter_num)] + + self.reg_iter = nn.Sequential(*self.reg_iter) + + self.device = torch.device('cuda' if (torch.cuda.is_available() and cfg['misc']['use_gpu']) else 'cpu') + + + def forward(self, data): + """ + For each of the putative correspondences infers a weight [0,1] denoting if the correspondence is an inlier (1) + or an outlier (0). Based on the weighted Kabsch algorithm it additionally estimates the pairwise rotation matrix + and translation parameters. 
+ + Args: + data (dict): dictionariy of torch tensors representing the input data + + Returns: + output (duct): dictionary of output data + + """ + + assert data['xs'].dim() == 4 and data['xs'].shape[1] == 1 + #data: b*1*n*c + input_data = data['xs'].transpose(1, 3).to(self.device) + + res_logits, res_scores, res_rot_est, res_trans_est = [], [], [], [] + + # First pass through the network + logits, scores, rot_est, trans_est, residuals, latent_features, gradient_not_valid = self.reg_init( + input_data, data['xs'].to(self.device)) + + res_logits.append(logits), res_scores.append(scores), res_rot_est.append(rot_est), res_trans_est.append(trans_est) + + # If iterative approach then append residuals and scores and perform additional passes + for i in range(self.iter_num): + logits, scores, rot_est, trans_est, residuals, latent_features, temp_gradient_not_valid = self.reg_iter[i]( + torch.cat([input_data, residuals.detach().unsqueeze(1).unsqueeze(3), + scores.unsqueeze(1).unsqueeze(3)], dim=1), data['xs'].to(self.device)) + + gradient_not_valid = (temp_gradient_not_valid or gradient_not_valid) + + res_logits.append(logits), res_scores.append( + scores), res_rot_est.append(rot_est), res_trans_est.append(trans_est) + + + # Construct the output + output = {} + output['logits'] = res_logits + output['scores'] = res_scores + output['rot_est'] = res_rot_est + output['trans_est'] = res_trans_est + output['latent features'] = latent_features + output['gradient_flag'] = gradient_not_valid + + return output diff --git a/lib/layers.py b/lib/layers.py new file mode 100644 index 0000000..c22dd62 --- /dev/null +++ b/lib/layers.py @@ -0,0 +1,202 @@ +import torch +import torch.nn.functional as F +import numpy as np +import time +from sklearn.neighbors import NearestNeighbors +from lib.utils import extract_mutuals, pairwise_distance, knn_point +#from Pointnet2_PyTorch.pointnet2_ops_lib.pointnet2_ops import pointnet2_utils + + +class Soft_NN(torch.nn.Module): + """ Nearest neighbor class. Constructs either a stochastic (differentiable) or hard nearest neighbors layer. + + Args: + corr_type (string): type of the NN search + st (bool): if straight through gradient propagation should be used (biased) (https://arxiv.org/abs/1308.3432) + inv_temp (float): initial value for the inverse temperature used in softmax and gumbel_softmax + + """ + + def __init__(self, corr_type='soft', st=True, inv_temp=10): + super().__init__() + + assert corr_type in ['soft', 'hard', 'soft_gumbel'], 'Wrong correspondence type selected. Must be one of [soft, soft_gumbel, hard]' + + if corr_type == 'hard': + print('Gradients cannot be backpropagated to the feature descriptor because hard NN search is selected.') + + self.temp_inv = torch.nn.Parameter(torch.tensor([inv_temp], requires_grad=True, dtype=torch.float)) + + self.corr_type = corr_type + self.st = st + + + + def forward(self, x_f, y_f, y_c): + """ Computes the correspondences in the feature space based on the selected parameters. + + Args: + x_f (torch.tensor): infered features of points x [b,n,c] + y_f (torch.tensor): infered features of points y [b,m,c] + y_c (torch.tensor): coordinates of point y [b,m,3] + + Returns: + x_corr (torch.tensor): coordinates of the feature based correspondences of points x [b,n,3] + + """ + + dist = pairwise_distance(x_f,y_f) + #dist_min = torch.min(dist, dim=2,keepdim=True).values + #dist = dist - dist_min + + if self.corr_type == 'soft': + + y_soft = torch.softmax(-dist*self.temp_inv, dim=2) + + if self.st: + # Straight through. 
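+                # Straight-through estimator: the lines below take the hard argmax assignment
+                # in the forward pass, but route gradients through the soft weights by
+                # returning y_hard - y_soft.detach() + y_soft (see the reference in the class docstring).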
+ index = y_soft.max(dim=2, keepdim=True)[1] + y_hard = torch.zeros_like(y_soft).scatter_(dim=2, index=index, value=1.0) + ret = y_hard - y_soft.detach() + y_soft + + else: + ret = y_soft + + elif self.corr_type == 'soft_gumbel': + + if self.st: + # Straight through. + ret = F.gumbel_softmax(-dist, tau=1.0/self.temp_inv, hard=True) + else: + ret = F.gumbel_softmax(-dist, tau=1.0/self.temp_inv, hard=False) + + else: + index = dist.min(dim=2, keepdim=True)[1] + ret = torch.zeros_like(dist).scatter_(dim=2, index=index, value=1.0) + + + # Compute corresponding coordinates + x_corr = torch.matmul(ret, y_c) + + return x_corr + +class Sampler(torch.nn.Module): + """ Sampler class. Constructs a layer used to sample the points either based on their metric distance (FPS) or by randomly selecting them. + + Args: + samp_type (string): type of the sampling to be used + st (bool): if straight through gradient propagation should be used (biased) (https://arxiv.org/abs/1308.3432) + inv_temp (float): initial value for the inverse temperature used in softmax and gumbel_softmax + + """ + def __init__(self, samp_type='fps', targeted_num_points=2000): + super().__init__() + assert samp_type in ['fps', 'rand'], 'Wrong sampling type selected. Must be one of [fps, rand]' + + + self.samp_type = samp_type + self.targeted_num_points = targeted_num_points + + + def forward(self, input_C, input_F, pts_list): + """ Samples the predifined points from the input point cloud and the corresponding feature descriptors. + + Args: + input_C (torch.tensor): coordinates of the points [~b*n,3] + input_F (torch.tensor): infered features [~b*n,c] + pts_list (list): list with the number of points of each point cloud in the batch + + Returns: + sampled_C (torch tensor): coordinates of the sampled points [b,m,3] + sampled_F (torch tensor): features of the sampled points [b,m,c] + + """ + + # Sample the data + idx_temp = [] + sampled_F = [] + sampled_C = [] + + # Final number of points to be sampled is the min of the desired number of points and smallest number of point in the batch + num_points = min(self.targeted_num_points, min(pts_list)) + + for i in range(len(pts_list)): + + pcd_range = np.arange(sum(pts_list[:i]), sum(pts_list[:(i + 1)]), 1) + + if self.samp_type == 'fps': + temp_pcd = torch.index_select(input_C, dim=0, index=torch.from_numpy(pcd_range).to(input_C).long()) + + # Perform farthest point sampling on the current point cloud + idxs = pointnet2_utils.furthest_point_sample(temp_pcd, num_points) + + # Move the indeces to the start of this point cloud + idxs += pcd_range[0] + + elif self.samp_type == 'rand': + # Randomly select the indices to keep + idxs = torch.from_numpy(np.random.choice(pcd_range,num_points, replace=False)).to(input_C) + + sampled_F.append(torch.index_select(input_F, dim=0, index=idxs.long())) + sampled_C.append(torch.index_select(input_C, dim=0, index=idxs.long())) + + return torch.stack(sampled_C, dim=0), torch.stack(sampled_F, dim=0) + + + + +if __name__ == "__main__": + + test = torch.rand((3,10,10)) + test_1 = torch.rand((3,10,10)) + test_2 = torch.rand((3,10,3)) + + soft_nn_1 = Soft_NN(corr_type='soft') + soft_nn_2 = Soft_NN(corr_type='soft_gumbel') + soft_nn_3 = Soft_NN(corr_type='hard') + + # Iterrative + neigh = NearestNeighbors() + ret_iter = [] + array_input = test_1[0,:,:] + for i in range(test.shape[0]): + neigh.fit(test_1[i,:,:].cpu().numpy()) + idx = neigh.kneighbors(test[i,:,:].cpu().numpy(), n_neighbors=1, return_distance=False) + + ret_iter.append(test_2[i,idx.reshape(-1,),:]) + + 
ret_iter = torch.stack(ret_iter) + ret_1 = soft_nn_1(test,test_1,test_2) + ret_2 = soft_nn_2(test,test_1,test_2) + ret_3 = soft_nn_3(test,test_1,test_2) + + diff = ret_1 - ret_2 + diff_2 = ret_2 - ret_3 + diff_3 = ret_1 - ret_3 + diff_4 = ret_1 - ret_iter + + + + + # Test the mutuals + pc_1 = torch.rand((5,2000,3)).cuda() + pc_2 = torch.rand((5,2000,3)).cuda() + pc_1_soft_c = torch.rand((5,2000,3)).cuda() + pc_2_soft_c = torch.rand((5,2000,3)).cuda() + + test_mutuals = extract_mutuals(pc_1, pc_2, pc_1_soft_c, pc_2_soft_c) + + + # Test the sampler + test_C = torch.rand(3000,3).float() + test_F = torch.rand(3000,32).float() + + pts_list = [300,700,1000,400,600] + + # Test random sampling + sampler = Sampler(targeted_num_points=100,samp_type='rand') + sampled_C, sampled_F = sampler(test_C,test_F,pts_list) + + # Test fps + sampler_fps = Sampler(targeted_num_points=100, samp_type='fps') + diff --git a/lib/logger.py b/lib/logger.py new file mode 100644 index 0000000..515d452 --- /dev/null +++ b/lib/logger.py @@ -0,0 +1,81 @@ +import os +import sys +import numpy as np +import logging +from datetime import datetime +import coloredlogs +import git +import subprocess + + +_logger = logging.getLogger() + + +def print_info(cfg, log_dir=None): + """ Logs source code configuration + + Code adapted from RPMNet repository: https://github.com/yewzijian/RPMNet/ + """ + _logger.info('Command: {}'.format(' '.join(sys.argv))) + + # Print commit ID + try: + repo = git.Repo(search_parent_directories=True) + git_sha = repo.head.object.hexsha + git_date = datetime.fromtimestamp(repo.head.object.committed_date).strftime('%Y-%m-%d') + git_message = repo.head.object.message + _logger.info('Source is from Commit {} ({}): {}'.format(git_sha[:8], git_date, git_message.strip())) + + # Also create diff file in the log directory + if log_dir is not None: + with open(os.path.join(log_dir, 'compareHead.diff'), 'w') as fid: + subprocess.run(['git', 'diff'], stdout=fid) + + except git.exc.InvalidGitRepositoryError: + pass + + # Arguments + arg_str = [] + + for k_id, k_val in cfg.items(): + for key in k_val: + arg_str.append("{}_{}: {}".format(k_id, key, k_val[key])) + + arg_str = ', '.join(arg_str) + _logger.info('Arguments: {}'.format(arg_str)) + + +def prepare_logger(cfg, log_path = None): + """Creates logging directory, and installs colorlogs + Args: + cfg (dict): config parmaters + log_path (str): Logging path (optional). 
This serves to overwrite the settings in cfg + + Returns: + logger (logging.Logger): logger instance + log_path (str): Logging directory + + Code borrowed from RPMNet repository: https://github.com/yewzijian/RPMNet/ + """ + + logdir = cfg['misc']['log_dir'] + + if log_path is None: + datetime_str = datetime.now().strftime('%y%m%d_%H%M%S') + log_path = os.path.join(logdir, cfg['method']['descriptor_module'] if cfg['method']['descriptor_module'] else 'No_Desc' + '_' + + cfg['method']['filter_module'] if cfg['method']['filter_module'] else 'No_Filter', datetime_str) + + os.makedirs(log_path, exist_ok=True) + + logger = logging.getLogger() + coloredlogs.install(level='INFO', logger=logger) + file_handler = logging.FileHandler('{}/log.txt'.format(log_path)) + log_formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s - %(message)s') + file_handler.setFormatter(log_formatter) + logger.addHandler(file_handler) + print_info(cfg, log_path) + logger.info('Output and logs will be saved to {}'.format(log_path)) + + return logger, log_path + + diff --git a/lib/loss.py b/lib/loss.py new file mode 100644 index 0000000..425e0f4 --- /dev/null +++ b/lib/loss.py @@ -0,0 +1,330 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math +from sklearn.metrics import precision_recall_fscore_support +from lib.utils import rotation_error, transformation_residuals + +def _hash(arr, M): + if isinstance(arr, np.ndarray): + N, D = arr.shape + else: + N, D = len(arr[0]), len(arr) + + hash_vec = np.zeros(N, dtype=np.int64) + for d in range(D): + if isinstance(arr, np.ndarray): + hash_vec += arr[:, d] * M**d + else: + hash_vec += arr[d] * M**d + + return hash_vec + + + +class DescriptorLoss(): + """ + Descriptor loss class. Creates a DescriptorLoss object that is used to train the FCGF feature descriptor. The loss is defined the same as in + the original FCGF paper. + + Args: + cfg (dict): configuration parameters + + """ + def __init__(self, cfg): + self.w_desc_loss = cfg['loss']['loss_desc'] + + # For FCGF loss we keep the parameters the same as in the original paper + self.pos_thresh = 1.4 + self.neg_thresh = 0.1 + self.batch_size = cfg['train']['batch_size'] + self.num_pos_per_batch = 1024 + self.num_hn_samples_per_batch = 256 + + def contrastive_hardest_negative_loss(self, + F0, + F1, + positive_pairs, + num_pos=5192, + num_hn_samples=2048, + pos_thresh=None, + neg_thresh=None): + """ + Computes the harderst contrastive loss as defined in the Fully Convolutional Geometric Features (Choy et al. ICCV 2019) paper + + https://node1.chrischoy.org/data/publications/fcgf/fcgf.pdf. 
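Editorial aside: the `_hash` helper above turns index pairs into scalar keys so that mined negatives can be checked against the positive set with a single `np.isin` call. A small sketch of why the encoding `i + j * M`, with `M = max(N0, N1)`, is collision-free for valid indices (all values below are made up):

```python
import numpy as np

positive_pairs = np.array([[0, 3], [1, 1], [2, 0]])  # ground-truth correspondences
mined_pairs    = np.array([[0, 3], [2, 2]])          # candidate hardest negatives
M = 4  # assume both point clouds have at most 4 points in this toy example

pos_keys   = positive_pairs[:, 0] + positive_pairs[:, 1] * M
mined_keys = mined_pairs[:, 0] + mined_pairs[:, 1] * M

# A mined "negative" that is actually a positive must be masked out of the loss.
is_false_negative = np.isin(mined_keys, pos_keys)
print(is_false_negative)  # [ True False ] -> (0, 3) would be excluded
```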
+ + Args: + F0 (torch tensor) + F1 (torch tensor) + positive_pairs (torch tensor): indices of positive pairs + num_pos (int): maximum number of positive pairs to be used + num_hn_samples (int): Number of harderst negative samples to be used + pos_thresh (): margain for positive pairs + neg_thresh (): margain for negative pairs + Returns: + pos_loss (torch tensor): loss based on the positive examples + neg_loss (torch tensor): loss based on the negative examples + """ + + N0, N1 = len(F0), len(F1) + N_pos_pairs = len(positive_pairs) + hash_seed = max(N0, N1) + sel0 = np.random.choice(N0, min(N0, num_hn_samples), replace=False) + sel1 = np.random.choice(N1, min(N1, num_hn_samples), replace=False) + + if N_pos_pairs > num_pos: + pos_sel = np.random.choice(N_pos_pairs, num_pos, replace=False) + sample_pos_pairs = positive_pairs[pos_sel] + else: + sample_pos_pairs = positive_pairs + + # Find negatives for all F1[positive_pairs[:, 1]] + subF0, subF1 = F0[sel0], F1[sel1] + + pos_ind0 = sample_pos_pairs[:, 0].long() + pos_ind1 = sample_pos_pairs[:, 1].long() + posF0, posF1 = F0[pos_ind0], F1[pos_ind1] + + D01 = torch.sum((posF0.unsqueeze(1) - posF1.unsqueeze(0)).pow(2), 2) + D10 = torch.sum((posF1.unsqueeze(1) - posF0.unsqueeze(0)).pow(2), 2) + + D01min, D01ind = D01.min(1) + D10min, D10ind = D10.min(1) + + if not isinstance(positive_pairs, np.ndarray): + positive_pairs = np.array(positive_pairs, dtype=np.int64) + + pos_keys = _hash(positive_pairs, hash_seed) + + D01ind = sel1[D01ind.cpu().numpy()] + D10ind = sel0[D10ind.cpu().numpy()] + neg_keys0 = _hash([pos_ind0.numpy(), D01ind], hash_seed) + neg_keys1 = _hash([D10ind, pos_ind1.numpy()], hash_seed) + + mask0 = torch.from_numpy( + np.logical_not(np.isin(neg_keys0, pos_keys, assume_unique=False))) + mask1 = torch.from_numpy( + np.logical_not(np.isin(neg_keys1, pos_keys, assume_unique=False))) + pos_loss = F.relu((posF0 - posF1).pow(2).sum(1) - pos_thresh) + neg_loss0 = F.relu(neg_thresh - D01min[mask0]).pow(2) + neg_loss1 = F.relu(neg_thresh - D10min[mask1]).pow(2) + + return pos_loss.mean(), (neg_loss0.mean() + neg_loss1.mean()) / 2 + + + def evaluate(self, F0, F1, pos_pairs): + """ + Evaluates the hardest contrastive FCGF loss given current data + + Args: + F0 (torch tensor): features of the source points [~b*n,c] + F1 (torch tensor): features of the target points [~b*n,c] + pos_pairs (torch tensor): indices of the positive pairs + + Returns: + loss (torch tensor): mean value of the HC FCGF loss + + """ + pos_loss, neg_loss = self.contrastive_hardest_negative_loss(F0, F1, pos_pairs, + num_pos=self.num_pos_per_batch * + self.batch_size, + num_hn_samples=self.num_hn_samples_per_batch * + self.batch_size, + pos_thresh=self.pos_thresh, + neg_thresh=self.neg_thresh) + + loss = pos_loss + neg_loss + + return loss + + +class ClassificationLoss(): + """ + Classification loss class. Creates a ClassificationLoss object that is used to supervise the inlier/outlier classification of the putative correspondences. + + Args: + cfg (dict): configuration parameters + + """ + def __init__(self, cfg): + self.w_class = cfg['loss']['loss_class'] + self.w_class = cfg['loss']['loss_class'] + self.compute_stats = cfg['train']['compute_precision'] + self.device = torch.device('cuda' if (torch.cuda.is_available() and cfg['misc']['use_gpu']) else 'cpu') + + + def class_loss(self, predicted, target): + """ + Binary classification loss per putative correspondence. 
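Editorial aside: for readers who want the gist of the hardest-negative mining above without the pair bookkeeping, here is a compact toy variant. The margin values and the use of plain (non-squared) distances are illustrative simplifications, not the exact settings of the loss above:

```python
import torch
import torch.nn.functional as F

def hardest_contrastive(f0, f1, pos_margin=0.1, neg_margin=1.4):
    """f0, f1: [n, c] features of n corresponding (positive) pairs."""
    d = torch.cdist(f0, f1)                  # [n, n] all pairwise feature distances
    pos = d.diagonal()                       # distances of the matching pairs
    neg = d + torch.eye(len(d)) * 1e6        # mask out the positives on the diagonal
    hardest_neg = neg.min(dim=1).values      # closest wrong match per anchor
    pos_loss = F.relu(pos - pos_margin).pow(2).mean()   # pull positives together
    neg_loss = F.relu(neg_margin - hardest_neg).pow(2).mean()  # push hardest negatives apart
    return pos_loss + neg_loss

print(hardest_contrastive(torch.rand(8, 32), torch.rand(8, 32)))
```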
+ + Args: + predicted (torch tensor): predicted weight per correspondence [b,n,1] + target (torch tensor): ground truth label per correspondence (0 - outlier, 1 - inlier) [b,n,1] + + Return: + class_loss (torch tensor): binary cross entropy loss [b] + """ + + loss = nn.BCELoss(reduction='none') # Binary Cross Entropy loss, expects that the input was passed through the sigmoid + sigmoid = nn.Sigmoid() + + predicted_labels = sigmoid(predicted).flatten().to(self.device) + + class_loss = loss(predicted_labels, target.flatten()).reshape(predicted.shape[0],-1) + + # Computing weights for compensating the class imbalance + + is_pos = (target.squeeze(-1) < 0.5).type(target.type()) + is_neg = (target.squeeze(-1) > 0.5).type(target.type()) + + num_pos = torch.relu(torch.sum(is_pos, dim=1) - 1.0) + 1.0 + num_neg = torch.relu(torch.sum(is_neg, dim=1) - 1.0) + 1.0 + class_loss_p = torch.sum(class_loss * is_pos, dim=1) + class_loss_n = torch.sum(class_loss * is_neg, dim=1) + class_loss = class_loss_p * 0.5 / num_pos + class_loss_n * 0.5 / num_neg + + return class_loss + + def evaluate(self, predicted, target, scores=None): + """ + Evaluates the binary cross entropy classification loss + + Args: + predicted (torch tensor): predicted logits per correspondence [b,n] + target (torch tensor): ground truth label per correspondence (0 - outlier, 1 - inlier) [b,n,1] + scores (torch tensor): predicted score (weight) per correspondence (0 - outlier, 1 - inlier) [b,n] + + Return: + loss (torch tensor): mean binary cross entropy loss + precision (numpy array): Mean classification precision (inliers) + recall (numpy array): Mean classification recall (inliers) + """ + predicted = predicted.to(self.device) + target = target.to(self.device) + + class_loss = self.class_loss(predicted, target) + + loss = torch.tensor([0.]).to(self.device) + + if self.w_class > 0: + loss += torch.mean(self.w_class * class_loss) + + + if self.compute_stats: + assert scores != None, "If precision and recall should be computed, scores cannot be None!" + + y_predicted = scores.detach().cpu().numpy().reshape(-1) + y_gt = target.detach().cpu().numpy().reshape(-1) + + precision, recall, f_measure, _ = precision_recall_fscore_support(y_gt, y_predicted.round(), average='binary') + + return loss, precision, recall + + else: + return loss, None, None + + +class TransformationLoss(): + """ + Transformation loss class. Creates a TransformationLoss object that is used to supervise the rotation and translation estimation part of the network. + + Args: + cfg (dict): configuration parameters + + """ + + def __init__(self, cfg): + self.trans_loss_type = cfg['loss']['trans_loss_type'] + self.trans_loss_iter = cfg['loss']['trans_loss_iter'] + self.w_trans = cfg['loss']['loss_trans'] + self.device = torch.device('cuda' if (torch.cuda.is_available() and cfg['misc']['use_gpu']) else 'cpu') + self.trans_loss_margin = cfg['misc']['trans_loss_margin'] + self.inlier_threshold = cfg['loss']['inlier_threshold'] + + def trans_loss(self, x_in, rot_est, trans_est, gt_rot_mat, gt_t_vec): + """ + Loss function on the transformation parameter. 
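Editorial aside: the weighting at the end of `class_loss` above balances inliers and outliers so that each class contributes half of the loss regardless of how imbalanced the batch is. Below is a self-contained sketch of the same idea, using `binary_cross_entropy_with_logits` instead of a separate sigmoid. As a side observation, `is_pos`/`is_neg` above appear to be named the opposite way round; because the two halves are weighted symmetrically by 0.5, the resulting value is unaffected.

```python
import torch

def balanced_bce(logits, labels):
    """logits, labels: [b, n]; labels are 1 for inliers, 0 for outliers."""
    per_corr = torch.nn.functional.binary_cross_entropy_with_logits(
        logits, labels, reduction='none')
    is_pos = (labels > 0.5).float()
    is_neg = 1.0 - is_pos
    num_pos = is_pos.sum(dim=1).clamp(min=1.0)   # avoid division by zero
    num_neg = is_neg.sum(dim=1).clamp(min=1.0)
    return 0.5 * (per_corr * is_pos).sum(1) / num_pos \
         + 0.5 * (per_corr * is_neg).sum(1) / num_neg

print(balanced_bce(torch.randn(2, 100), (torch.rand(2, 100) > 0.9).float()))
```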
Based on the selected type of the loss computes either: + 0 - Vector distance between the point reconstructed using the EST transformation paramaters and the putative correspondence + 1 - Frobenius norm on the rotation matrix and L2 norm on the translation vector + 2 - L2 distance between the points reconstructed using the estimated and GT transformation paramaters + 3 - L1 distance between the points reconstructed using the estimated and GT transformation paramaters + + Args: + x_in (torch tensor): coordinates of the input point [b,1,n,6] + rot_est (torch tensor): currently estimated rotation matrices [b,3,3] + trans_est (torch tensor): currently estimated translation vectors [b,3,1] + gt_rot_mat (torch tensor): ground truth rotation matrices [b,3,3] + gt_t_vec (torch tensor): ground truth translation vectors [b,3,1] + + Return: + r_loss (torch tensor): transformation loss if type 0 or 2 else Frobenius norm of the rotation matrices [b,1] + t_loss (torch tensor): 0 if type 0, 2 or 3 else L2 norm of the translation vectors [b,1] + """ + if self.trans_loss_type == 0: + + x2_reconstruct = torch.matmul(rot_est, x_in[:, 0, :, 0:3].transpose(1, 2)) + trans_est + r_loss = torch.mean(torch.mean(torch.norm(x2_reconstruct.transpose(1,2) - x_in[:, :, :, 3:6], dim=(1)), dim=1)) + t_loss = torch.zeros_like(r_loss) + + elif self.trans_loss_type == 1: + r_loss = torch.norm(gt_rot_mat - rot_est, dim=(1, 2)) + t_loss = torch.norm(trans_est - gt_t_vec,dim=1) # Torch norm already does sqrt (p=1 for no sqrt) + + elif self.trans_loss_type == 2: + x2_reconstruct_estimated = torch.matmul(rot_est, x_in[:, 0, :, 0:3].transpose(1, 2)) + trans_est + x2_reconstruct_gt = torch.matmul(gt_rot_mat, x_in[:, 0, :, 0:3].transpose(1, 2)) + gt_t_vec + + r_loss = torch.mean(torch.norm(x2_reconstruct_estimated - x2_reconstruct_gt, dim=1), dim=1) + t_loss = torch.zeros_like(r_loss) + + elif self.trans_loss_type == 3: + x2_reconstruct_estimated = torch.matmul(rot_est, x_in[:, 0, :, 0:3].transpose(1, 2)) + trans_est + x2_reconstruct_gt = torch.matmul(gt_rot_mat, x_in[:, 0, :, 0:3].transpose(1, 2)) + gt_t_vec + + r_loss = torch.mean(torch.sum(torch.abs(x2_reconstruct_estimated - x2_reconstruct_gt), dim=1), dim=1) + t_loss = torch.zeros_like(r_loss) + + return r_loss, t_loss + + def evaluate(self, global_step, data, rot_est, trans_est): + """ + Evaluates the pairwise loss function based on the current values + + Args: + global_step (int): current training iteration (used for controling which parts of the loss are used in the current iter) [1] + data (dict): input data of the current batch + rot_est (torch tensor): rotation matrices estimated based on the current scores [b,3,3] + trans_est (torch tensor): translation vectors estimated based on the current scores [b,3,1] + + Return: + loss (torch tensor): mean transformation loss of the current iteration over the batch + loss_raw (torch tensor): mean transformation loss of the current iteration (return value for tenbsorboard before the trans loss is plugged in ) + """ + + # Extract the current data + x_in, gt_R, gt_t = data['xs'].to(self.device), data['R'].to(self.device), data['t'].to(self.device) + gt_inlier_ratio = data['inlier_ratio'].to(self.device) + + # Compute the transformation loss + r_loss, t_loss = self.trans_loss(x_in, rot_est, trans_est, gt_R, gt_t) + + # Extract indices of pairs with a minimum inlier ratio (do not propagate Transformation loss if point clouds do not overlap) + idx_inlier_ratio = gt_inlier_ratio > self.inlier_threshold + inlier_ratio_mask = 
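Editorial aside: loss type 1 above compares the transformations directly rather than through reconstructed points. A toy sketch of that branch, i.e. the Frobenius norm of the rotation difference plus the L2 norm of the translation difference:

```python
import torch

def frobenius_trans_loss(rot_est, trans_est, rot_gt, trans_gt):
    """rot_*: [b,3,3], trans_*: [b,3,1]."""
    r_loss = torch.norm(rot_gt - rot_est, dim=(1, 2))  # Frobenius norm per batch item
    t_loss = torch.norm(trans_est - trans_gt, dim=1)   # L2 norm per batch item
    return r_loss, t_loss

R = torch.eye(3).unsqueeze(0)
t = torch.zeros(1, 3, 1)
print(frobenius_trans_loss(R, t, R, t))  # both terms are zero for identical transforms
```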
torch.zeros_like(r_loss) + inlier_ratio_mask[idx_inlier_ratio] = 1 + + loss_raw = torch.tensor([0.]).to(self.device) + + + if self.w_trans > 0: + r_loss *= inlier_ratio_mask + t_loss *= inlier_ratio_mask + + loss_raw += torch.mean(torch.min(self.w_trans * (r_loss + t_loss), self.trans_loss_margin * torch.ones_like(t_loss))) + + # Check global_step and add essential loss + loss = loss_raw if global_step >= self.trans_loss_iter else torch.tensor([0.]).to(self.device) + + return loss, loss_raw \ No newline at end of file diff --git a/lib/pairwise/__init__.py b/lib/pairwise/__init__.py new file mode 100644 index 0000000..a113d05 --- /dev/null +++ b/lib/pairwise/__init__.py @@ -0,0 +1,141 @@ +import torch +import torch.nn as nn +import MinkowskiEngine as ME +from lib.layers import Soft_NN, Sampler +from lib.utils import extract_overlaping_pairs, extract_mutuals, construct_filtering_input_data +from lib.pairwise import ( + config, training +) + +__all__ = [ + config, training +] + + +class PairwiseReg(nn.Module): + """ + Pairwise registration class. + + It cobmines a feature descriptor with a filtering network and differentiable Kabsch algorithm to estimate + the transformation parameters of two point clouds. + + Args: + descriptor_module (nn.Module): feature descriptor network + filtering_module (nn.Module): filtering (outlier detection) network + corr_type (string): type of the correspondences to be used (hard, soft, Gumble-Softmax) + device (device): torch device + mutuals_flag (bool): if mutual nearest neighbors should be used + + Returns: + + + """ + + def __init__(self, descriptor_module, + filtering_module, device, samp_type='fps', + corr_type = 'soft', mutuals_flag=False, + connectivity_info=None, tgt_num_points=2000, + straight_through_gradient=True): + super().__init__() + + self.device = device + self.samp_type = samp_type + self.corr_type = corr_type + + self.mutuals_flag = mutuals_flag + + + self.connectivity_info = connectivity_info + + self.descriptor_module = descriptor_module + + # If the descriptor module is not specified, precomputed descriptor data should be used + if self.descriptor_module: + self.sampler = Sampler(samp_type=self.samp_type, targeted_num_points=tgt_num_points) + self.feature_matching = Soft_NN(corr_type=self.corr_type, st=straight_through_gradient) + self.precomputed_desc = False + else: + self.precomputed_desc = True + + self.filtering_module = filtering_module + + def forward(self, data): + + filtering_input, f_0, f_1 = self.compute_descriptors(input_dict=data) + + registration_outputs = self.filter_correspondences(filtering_input) + + return filtering_input, f_0, f_1, registration_outputs + + + + + def compute_descriptors(self, input_dict): + ''' + If not precomputed it infers the feature descriptors and returns the established correspondences + together with the ground truth transformation parameters and inlier labels. + + Args: + input_dict (dict): input data + + ''' + + if not self.precomputed_desc: + + xyz_down = input_dict['sinput0_C'] + + sinput0 = ME.SparseTensor( + input_dict['sinput0_F'], coords=input_dict['sinput0_C']).to(self.device) + + F0 = self.descriptor_module(sinput0).F + + # If the FCGF descriptor should be trained with the FCGF loss (need also corresponding desc.) 
+ if self.train_descriptor: + sinput1 = ME.SparseTensor( + input_dict['sinput1_F'], coords=input_dict['sinput1_C']).to(self.device) + + F1 = self.descriptor_module(sinput1).F + else: + F1 = torch.empty(F0.shape[0], 0).to(self.device) + + + # Sample the points + xyz_batch, f_batch = self.sampler(xyz_down, F0, input_dict['pts_list']) + + # Build point cloud pairs for the inference + xyz_s, xyz_t, f_s, f_t = extract_overlaping_pairs(xyz_batch, f_batch, self.connectivity_info) + + # Compute nearest neighbors in feature space + nn_C_s_t = self.feature_matching(f_s, f_t, xyz_t) # NNs of the source points in the target point cloud + nn_C_t_s = self.feature_matching(f_t, f_s, xyz_s) # NNs of the target points in the source point cloud + + + if self.mutuals_flag: + mutuals = extract_mutuals(xyz_s, xyz_t, nn_C_s_t, nn_C_t_s) + else: + mutuals = None + + # Prepare the input for the filtering block + filtering_input = construct_filtering_input_data(xyz_s, xyz_t, input_dict, self.mutuals) + + else: + filtering_input = input_dict + F0 = None + F1 = None + + return filtering_input, F0, F1 + + + + def filter_correspondences(self, input_dict): + ''' + Return the infered weights together with the pairwise rotation matrices nad translation vectors. + + Args: + input_dict (dict): input data + + ''' + + registration_outputs = self.filtering_module(input_dict) + + return registration_outputs \ No newline at end of file diff --git a/lib/pairwise/__pycache__/__init__.cpython-36.pyc b/lib/pairwise/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000..e682d2c Binary files /dev/null and b/lib/pairwise/__pycache__/__init__.cpython-36.pyc differ diff --git a/lib/pairwise/__pycache__/config.cpython-36.pyc b/lib/pairwise/__pycache__/config.cpython-36.pyc new file mode 100644 index 0000000..4c1d349 Binary files /dev/null and b/lib/pairwise/__pycache__/config.cpython-36.pyc differ diff --git a/lib/pairwise/__pycache__/training.cpython-36.pyc b/lib/pairwise/__pycache__/training.cpython-36.pyc new file mode 100644 index 0000000..6013895 Binary files /dev/null and b/lib/pairwise/__pycache__/training.cpython-36.pyc differ diff --git a/lib/pairwise/config.py b/lib/pairwise/config.py new file mode 100644 index 0000000..ae1a206 --- /dev/null +++ b/lib/pairwise/config.py @@ -0,0 +1,80 @@ +import torch +from lib.descriptor import descriptor_dict +from lib.filtering import filtering_dict +from lib.pairwise import training +from lib import pairwise + +def get_model(cfg, device): + ''' + Returns the model instance. 
+ + Args: + cfg (dict): config dictionary + device (device): pytorch device + + Returns: + model (nn.Module): instance of the selected model class initialized based on the paramaters + ''' + + # Shortcuts + sampling_type = cfg['train']['samp_type'] + correspondence_type = cfg['train']['corr_type'] + connectivity_info = None + tgt_num_points = cfg['data']['max_num_points'] + st_grad = cfg['train']['st_grad_flag'] + + # Get individual components + filtering_module = get_filter(cfg, device) + descriptor_module = get_descriptor(cfg, device) + + + model = pairwise.PairwiseReg(descriptor_module=descriptor_module, filtering_module=filtering_module, + device=device, samp_type=sampling_type, corr_type = correspondence_type, + connectivity_info=connectivity_info,tgt_num_points=tgt_num_points, + straight_through_gradient=st_grad) + + return model + + +def get_descriptor(cfg, device): + descriptor_module = cfg['method']['descriptor_module'] + + if descriptor_module: + # We always keep the default parameters of FCGF + descriptor_module = descriptor_dict[descriptor_module]().to(device) + else: + descriptor_module = None + + return descriptor_module + + +def get_filter(cfg, device): + filter_module = cfg['method']['filter_module'] + + if filter_module: + filter_module = filtering_dict[filter_module](cfg).to(device) + else: + filter_module = None + + return filter_module + + +def get_trainer(cfg, model, optimizer, logger, device): + ''' + Returns a pairwise registration trainer instance. + + Args: + cfg (dict): configuration paramters + model (nn.Module): PairwiseReg model + optimizer (optimizer): PyTorch optimizer + logger (logger instance): logger used to output info to the consol + device (device): PyTorch device + + Return + trainer (trainer instace): Trainer used to train the pairwise registration model + ''' + + trainer = training.Trainer(cfg, model, optimizer, logger, device) + + return trainer + diff --git a/lib/pairwise/training.py b/lib/pairwise/training.py new file mode 100644 index 0000000..b2a9a9c --- /dev/null +++ b/lib/pairwise/training.py @@ -0,0 +1,226 @@ +import torch +import os +import torch.utils.data as data +import numpy as np +from tqdm import tqdm +from collections import defaultdict + +from lib.loss import DescriptorLoss, TransformationLoss, ClassificationLoss + + +class Trainer(): + ''' + Trainer class of the pairwise registration network. + + Args: + cfg (dict): configuration parameters + model (nn.Module): PairwiseReg model + optimizer (optimizer): PyTorch optimizer + tboard_logger (tensorboardx instance): TensorboardX logger used to track train and val stats + device (pytorch device) + ''' + + def __init__(self, cfg, model, optimizer, tboard_logger, device): + + self.model = model + self.optimizer = optimizer + self.tboard_logger = tboard_logger + self.device = device + self.loss_desc = cfg['loss']['loss_desc'] + + # Initialize the loss classes based on the input paramaters + self.DescriptorLoss = DescriptorLoss(cfg) + self.ClassificationLoss = ClassificationLoss(cfg) + self.TransformationLoss = TransformationLoss(cfg) + + + + def train_step(self, data, global_step): + ''' + Performs a single training step. 
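Editorial aside: `get_model` above only reads a handful of configuration entries. The snippet below sketches the minimal dictionary structure it expects; the module names and values are placeholders for illustration (the actual options and full YAML configs live elsewhere in the repository), which is why the construction call is left commented out:

```python
import torch
from lib.pairwise import config as pairwise_config  # module defined above

cfg = {
    'train':  {'samp_type': 'fps', 'corr_type': 'soft', 'st_grad_flag': True},
    'data':   {'max_num_points': 2000},
    'method': {'descriptor_module': 'fcgf', 'filter_module': 'filtering_net'},  # hypothetical names
}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = pairwise_config.get_model(cfg, device)  # needs the descriptor/filter registries and full cfg
```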
+ + Args: + data (dict): data dictionary + global_step (int): current training iteration + + ''' + + self.model.train() + self.optimizer.zero_grad() + backprop_flag = False + + loss, gradient_flag = self.compute_loss(data, global_step) + loss.backward() + + # Only update the parameters if there were no problems in the forward pass (mostly SVD) + # Check if any of the gradients is NaN + for name, param in self.model.named_parameters(): + if param.grad is not None: + if torch.any(torch.isnan(param.grad)): + print('Gradients include NaN values. Parameters will not be updated.') + backprop_flag = True + break + + if not (backprop_flag or gradient_flag): + self.optimizer.step() + + return loss.item() + + + def eval_step(self, data, global_step): + ''' + Performs a single evaluation step. + + Args: + data (dict): data dictionary + global_step (int): current training iteration + + Return? + eval_dict (dict): evaluation data of the current val batch + + ''' + + self.model.eval() + eval_dict = {} + + with torch.no_grad(): + # Extract the feature descriptors and correspondences + filtering_input, F0, F1 = self.model.compute_descriptors(data) + # Filter the correspondences and estimate the pairwise transformation parameters + filtered_output = self.model.filter_correspondences(filtering_input) + + # Losses + # Descriptor loss + desc_loss = torch.tensor([0.]).to(self.device) + if self.loss_desc and F0 is not None: + desc_loss = self.DescriptorLoss.evaluate(F0, F1, data['correspondences']) + eval_dict['desc_loss'] = desc_loss + + # Classification and transformation loss + class_loss = [] + trans_loss = [] + trans_loss_raw = [] + + for i in range(len(filtered_output['rot_est'])): + # Classification loss + temp_class_loss, precision, recall = self.ClassificationLoss.evaluate(filtered_output['logits'][i], filtering_input['ys'], filtered_output['scores'][i]) + class_loss.append(temp_class_loss) + + # Transformation loss + temp_trans_loss, temp_trans_loss_raw = self.TransformationLoss.evaluate(global_step, filtering_input, filtered_output['rot_est'][i], filtered_output['trans_est'][i]) + trans_loss.append(temp_trans_loss) + trans_loss_raw.append(temp_trans_loss_raw) + + + trans_loss_raw = torch.mean(torch.stack(trans_loss_raw)) + class_loss = torch.mean(torch.stack(class_loss)) + trans_loss = torch.mean(torch.stack(trans_loss)) + + loss = desc_loss + class_loss + trans_loss + + + eval_dict['class_loss'] = class_loss.item() + eval_dict['trans_loss'] = trans_loss_raw.item() + eval_dict['loss'] = loss.item() + + # If precision and recall stats are computed add them to the stats + if precision: + eval_dict['precision'] = precision + eval_dict['recall'] = recall + + return eval_dict + + + + def compute_loss(self, data, global_step): + ''' + Computes the combined loss (descriptor, classification, and transformation). 
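Editorial aside: the gradient check in `train_step` above guards against occasional numerical failures of the differentiable SVD in the Kabsch step. The same pattern in isolation, as a sketch:

```python
import torch

def safe_step(model, optimizer):
    """Apply the optimizer step only if all gradients are finite."""
    grads_ok = all(
        torch.isfinite(p.grad).all()
        for p in model.parameters() if p.grad is not None
    )
    if grads_ok:
        optimizer.step()
    return grads_ok

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.rand(8, 4)).pow(2).mean()
loss.backward()
print(safe_step(model, opt))  # True for this well-behaved toy loss
```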
+ + Args: + data (dict): data dictionary + global_step (int): current training iteration + + Return: + loss (torch tensor): combined loss values of the current batch + gradient_flag (bool): flag denoting if the SVD estimation had any problem + + ''' + + # Extract the feature descriptors and correspondences + filtering_input, F0, F1 = self.model.compute_descriptors(data) + + # Filter the correspondences and estimate the pairwise transformation parameters + filtered_output = self.model.filter_correspondences(filtering_input) + + # Losses + # Descriptor loss + desc_loss = torch.tensor([0.]).to(self.device) + if self.loss_desc and F0 is not None: + desc_loss = self.DescriptorLoss.evaluate(F0, F1, data['correspondences']) + self.tboard_logger.add_scalar('train/desc_loss', desc_loss, global_step) + + # Classification and transformation loss + class_loss_iter = [] + trans_loss_iter = [] + trans_loss_raw_iter = [] + precision_iter = [] + recall_iter = [] + + for i in range(len(filtered_output['rot_est'])): + # Classification loss + temp_class_loss, precision, recall = self.ClassificationLoss.evaluate(filtered_output['logits'][i], filtering_input['ys'], filtered_output['scores'][i]) + class_loss_iter.append(temp_class_loss) + precision_iter.append(precision) + recall_iter.append(recall) + + # Transformation loss + temp_trans_loss, temp_trans_loss_raw = self.TransformationLoss.evaluate(global_step, filtering_input, filtered_output['rot_est'][i], filtered_output['trans_est'][i]) + trans_loss_iter.append(temp_trans_loss) + trans_loss_raw_iter.append(temp_trans_loss_raw) + + trans_loss_raw = torch.mean(torch.stack(trans_loss_raw_iter)) + class_loss = torch.mean(torch.stack(class_loss_iter)) + trans_loss = torch.mean(torch.stack(trans_loss_iter)) + + loss = desc_loss + class_loss + trans_loss + + # Print out the stats + for i in range(len(class_loss_iter)): + self.tboard_logger.add_scalar('train/class_loss_{}'.format(i), class_loss_iter[i].item(), global_step) + self.tboard_logger.add_scalar('train/trans_loss_{}'.format(i), trans_loss_iter[i].item(), global_step) + self.tboard_logger.add_scalar('train/trans_loss_raw_{}'.format(i), trans_loss_raw_iter[i].item(), global_step) + + # If precision and recall are computed, log them + if precision: + self.tboard_logger.add_scalar('train/precision_{}'.format(i), precision_iter[i], global_step) + self.tboard_logger.add_scalar('train/recall_{}'.format(i), recall_iter[i], global_step) + + return loss, filtered_output['gradient_flag'] + + + + + def evaluate(self, val_loader, global_step): + ''' + Performs the evaluation over the whole evaluation dataset. 
+ Args: + val_loader (Pytorch dataloader): dataloader of the validation dataset + global_step (int): current iteration + + Returns: + eval_dict (defaultdict): evaluation values for the current validation epoch + ''' + + + eval_list = defaultdict(list) + + for data in tqdm(val_loader): + eval_step_dict = self.eval_step(data, global_step) + + for k, v in eval_step_dict.items(): + eval_list[k].append(v) + + + eval_dict = {k: np.mean(v) for k, v in eval_list.items()} + + return eval_dict diff --git a/lib/utils.py b/lib/utils.py new file mode 100644 index 0000000..5d3e763 --- /dev/null +++ b/lib/utils.py @@ -0,0 +1,1041 @@ +import math +import os +import copy +import torch +import re +import numpy as np +import torch.nn.functional as F +import open3d as o3d +import nibabel.quaternions as nq +import logging +import yaml +import time + +from torch import optim +from itertools import combinations +from torch.distributions import normal +from sklearn.neighbors import NearestNeighbors + +def load_config(path): + """ + Loads config file: + + Args: + path (str): path to the config file + + Returns: + config (dict): dictionary of the configuration parameters + + """ + with open(path,'r') as f: + cfg = yaml.safe_load(f) + + return cfg + + +def load_point_cloud(file, data_type='numpy'): + """ + Loads the point cloud coordinates from the '*.ply' file. + + Args: + file (str): path to the '*.ply' file + data_type (str): data type to be reurned (default: numpy) + + Returns: + ae (torch tensor): Rotation error in angular degreees [b,1] + + """ + temp_pc = o3d.io.read_point_cloud(file) + + assert data_type in ['numpy', 'open3d'], 'Wrong data type selected when loading the ply file.' + + if data_type == 'numpy': + return np.asarray(temp_pc.points) + else: + return temp_pc + +def sorted_alphanum(file_list_ordered): + """ + Sorts the list alphanumerically + + Args: + file_list_ordered (list): list of files to be sorted + + Return: + sorted_list (list): input list sorted alphanumerically + """ + def convert(text): + return int(text) if text.isdigit() else text + + def alphanum_key(key): + return [convert(c) for c in re.split('([0-9]+)', key)] + + sorted_list = sorted(file_list_ordered, key=alphanum_key) + + return sorted_list + + + +def get_file_list(path, extension=None): + """ + Build a list of all the files in the provided path + + Args: + path (str): path to the directory + extension (str): only return files with this extension + + Return: + file_list (list): list of all the files (with the provided extension) sorted alphanumerically + """ + if extension is None: + file_list = [os.path.join(path, f) for f in os.listdir(path) if os.path.isfile(os.path.join(path, f))] + else: + file_list = [ + os.path.join(path, f) + for f in os.listdir(path) + if os.path.isfile(os.path.join(path, f)) and os.path.splitext(f)[1] == extension + ] + file_list = sorted_alphanum(file_list) + + return file_list + + +def get_folder_list(path): + """ + Build a list of all the files in the provided path + + Args: + path (str): path to the directory + extension (str): only return files with this extension + + Returns: + file_list (list): list of all the files (with the provided extension) sorted alphanumerically + """ + folder_list = [os.path.join(path, f) for f in os.listdir(path) if os.path.isdir(os.path.join(path, f))] + folder_list = sorted_alphanum(folder_list) + + return folder_list + + + + +def rotation_error(R1, R2): + """ + Torch batch implementation of the rotation error between the estimated and the ground truth rotatiom 
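Editorial aside: a one-liner that shows why `sorted_alphanum` above exists; plain lexicographic sorting orders numbered files incorrectly.

```python
files = ['cloud_10.ply', 'cloud_2.ply', 'cloud_1.ply']
print(sorted(files))      # ['cloud_1.ply', 'cloud_10.ply', 'cloud_2.ply']
# sorted_alphanum(files)  # ['cloud_1.ply', 'cloud_2.ply', 'cloud_10.ply']
```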
matrix. + Rotation error is defined as r_e = \arccos(\frac{Trace(\mathbf{R}_{ij}^{T}\mathbf{R}_{ij}^{\mathrm{GT}) - 1}{2}) + + Args: + R1 (torch tensor): Estimated rotation matrices [b,3,3] + R2 (torch tensor): Ground truth rotation matrices [b,3,3] + + Returns: + ae (torch tensor): Rotation error in angular degreees [b,1] + + """ + R_ = torch.matmul(R1.transpose(1,2), R2) + e = torch.stack([(torch.trace(R_[_, :, :]) - 1) / 2 for _ in range(R_.shape[0])], dim=0).unsqueeze(1) + + # Clamp the errors to the valid range (otherwise torch.acos() is nan) + e = torch.clamp(e, -1, 1, out=None) + + ae = torch.acos(e) + pi = torch.Tensor([math.pi]) + ae = 180. * ae / pi.to(ae.device).type(ae.dtype) + + return ae + + +def translation_error(t1, t2): + """ + Torch batch implementation of the rotation error between the estimated and the ground truth rotatiom matrix. + Rotation error is defined as r_e = \arccos(\frac{Trace(\mathbf{R}_{ij}^{T}\mathbf{R}_{ij}^{\mathrm{GT}) - 1}{2}) + + Args: + t1 (torch tensor): Estimated translation vectors [b,3,1] + t2 (torch tensor): Ground truth translation vectors [b,3,1] + + Returns: + te (torch tensor): translation error in meters [b,1] + + """ + return torch.norm(t1-t2, dim=(1, 2)) + + +def kabsch_transformation_estimation(x1, x2, weights=None, normalize_w = True, eps = 1e-7, best_k = 0, w_threshold = 0): + """ + Torch differentiable implementation of the weighted Kabsch algorithm (https://en.wikipedia.org/wiki/Kabsch_algorithm). Based on the correspondences and weights calculates + the optimal rotation matrix in the sense of the Frobenius norm (RMSD), based on the estimate rotation matrix is then estimates the translation vector hence solving + the Procrustes problem. This implementation supports batch inputs. + + Args: + x1 (torch array): points of the first point cloud [b,n,3] + x2 (torch array): correspondences for the PC1 established in the feature space [b,n,3] + weights (torch array): weights denoting if the coorespondence is an inlier (~1) or an outlier (~0) [b,n] + normalize_w (bool) : flag for normalizing the weights to sum to 1 + best_k (int) : number of correspondences with highest weights to be used (if 0 all are used) + w_threshold (float) : only use weights higher than this w_threshold (if 0 all are used) + Returns: + rot_matrices (torch array): estimated rotation matrices [b,3,3] + trans_vectors (torch array): estimated translation vectors [b,3,1] + res (torch array): pointwise residuals (Eucledean distance) [b,n] + valid_gradient (bool): Flag denoting if the SVD computation converged (gradient is valid) + + """ + if weights is None: + weights = torch.ones(x1.shape[0],x1.shape[1]).type_as(x1).to(x1.device) + + if normalize_w: + sum_weights = torch.sum(weights,dim=1,keepdim=True) + eps + weights = (weights/sum_weights) + + weights = weights.unsqueeze(2) + + if best_k > 0: + indices = np.argpartition(weights.cpu().numpy(), -best_k, axis=1)[0,-best_k:,0] + weights = weights[:,indices,:] + x1 = x1[:,indices,:] + x2 = x2[:,indices,:] + + if w_threshold > 0: + weights[weights < w_threshold] = 0 + + + x1_mean = torch.matmul(weights.transpose(1,2), x1) / (torch.sum(weights, dim=1).unsqueeze(1) + eps) + x2_mean = torch.matmul(weights.transpose(1,2), x2) / (torch.sum(weights, dim=1).unsqueeze(1) + eps) + + x1_centered = x1 - x1_mean + x2_centered = x2 - x2_mean + + weight_matrix = torch.diag_embed(weights.squeeze(2)) + + cov_mat = torch.matmul(x1_centered.transpose(1, 2), + torch.matmul(weight_matrix, x2_centered)) + + try: + u, s, v = torch.svd(cov_mat) + 
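Editorial aside: a quick way to sanity-check `rotation_error` above, and why the clamp before `acos` matters (floating-point noise can push the trace term marginally outside [-1, 1] and turn the result into NaN). The helper below exists only for this check:

```python
import math
import torch

def small_rotation(angle_deg):
    """Rotation about the z-axis by angle_deg, shaped [1,3,3]."""
    a = math.radians(angle_deg)
    return torch.tensor([[[math.cos(a), -math.sin(a), 0.0],
                          [math.sin(a),  math.cos(a), 0.0],
                          [0.0,          0.0,         1.0]]])

R1 = small_rotation(0.0)
R2 = small_rotation(5.0)
# rotation_error(R1, R2) should be close to tensor([[5.]]) degrees,
# and exactly 0 (not NaN) when R1 == R2 thanks to the clamp.
```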
except Exception as e: + r = torch.eye(3,device=x1.device) + r = r.repeat(x1_mean.shape[0],1,1) + t = torch.zeros((x1_mean.shape[0],3,1), device=x1.device) + + res = transformation_residuals(x1, x2, r, t) + + return r, t, res, True + + tm_determinant = torch.det(torch.matmul(v.transpose(1, 2), u.transpose(1, 2))) + + determinant_matrix = torch.diag_embed(torch.cat((torch.ones((tm_determinant.shape[0],2),device=x1.device), tm_determinant.unsqueeze(1)), 1)) + + rotation_matrix = torch.matmul(v,torch.matmul(determinant_matrix,u.transpose(1,2))) + + # translation vector + translation_matrix = x2_mean.transpose(1,2) - torch.matmul(rotation_matrix,x1_mean.transpose(1,2)) + + # Residuals + res = transformation_residuals(x1, x2, rotation_matrix, translation_matrix) + + return rotation_matrix, translation_matrix, res, False + + +def transformation_residuals(x1, x2, R, t): + """ + Computer the pointwise residuals based on the estimated transformation paramaters + + Args: + x1 (torch array): points of the first point cloud [b,n,3] + x2 (torch array): points of the second point cloud [b,n,3] + R (torch array): estimated rotation matrice [b,3,3] + t (torch array): estimated translation vectors [b,3,1] + Returns: + res (torch array): pointwise residuals (Eucledean distance) [b,n,1] + """ + x2_reconstruct = torch.matmul(R, x1.transpose(1, 2)) + t + + res = torch.norm(x2_reconstruct.transpose(1, 2) - x2, dim=2) + + return res + +def transform_point_cloud(x1, R, t): + """ + Transforms the point cloud using the giver transformation paramaters + + Args: + x1 (np array): points of the point cloud [b,n,3] + R (np array): estimated rotation matrice [b,3,3] + t (np array): estimated translation vectors [b,3,1] + Returns: + x1_t (np array): points of the transformed point clouds [b,n,3] + """ + x1_t = (torch.matmul(R, x1.transpose(0,2,1)) + t).transpose(0,2,1) + + return x1_t + + +def knn_point(k, pos1, pos2): + ''' + Performs the k nearest neighbors search with CUDA support + + Args: + k (int): number of k in k-nn search + pos1: (torch tensor) input points [b,n,c] + pos2: (torch tensor) float32 array, query points [b,m,c] + + Returns: + val: (torch tensor) squared L2 distances [b,m,k] + idx: (torch tensor) indices of the k nearest points [b,m,1] + + ''' + + B, N, C = pos1.shape + M = pos2.shape[1] + + pos1 = pos1.view(B,1,N,-1).repeat(1,M,1,1) + pos2 = pos2.view(B,M,1,-1).repeat(1,1,N,1) + + dist = torch.sum(-(pos1 - pos2)**2, -1) + + val, idx = dist.topk(k=k,dim = -1) + + return -val, idx + +def axis_angle_to_rot_mat(axes, thetas): + """ + Computer a rotation matrix from the axis-angle representation using the Rodrigues formula. 
+ \mathbf{R} = \mathbf{I} + (sin(\theta)\mathbf{K} + (1 - cos(\theta)\mathbf{K}^2), where K = \mathbf{I} \cross \frac{\mathbf{K}}{||\mathbf{K}||} + + Args: + axes (numpy array): array of axes used to compute the rotation matrices [b,3] + thetas (numpy array): array of angles used to compute the rotation matrices [b,1] + + Returns: + rot_matrices (numpy array): array of the rotation matrices computed from the angle, axis representation [b,3,3] + + """ + + R = [] + for k in range(axes.shape[0]): + K = np.cross(np.eye(3), axes[k,:]/np.linalg.norm(axes[k,:])) + R.append( np.eye(3) + np.sin(thetas[k])*K + (1 - np.cos(thetas[k])) * np.matmul(K,K)) + + rot_matrices = np.stack(R) + return rot_matrices + + +def sample_random_trans(pcd, randg=None, rotation_range=360): + """ + Samples random transformation paramaters with the rotaitons limited to the rotation range + + Args: + pcd (numpy array): numpy array of coordinates for which the transformation paramaters are sampled [n,3] + randg (numpy random generator): numpy random generator + + Returns: + T (numpy array): sampled transformation paramaters [4,4] + + """ + if randg == None: + randg = np.random.default_rng(41) + + # Create 3D identity matrix + T = np.zeros((4,4)) + idx = np.arange(4) + T[idx,idx] = 1 + + axes = np.random.rand(1,3) - 0.5 + + angles = rotation_range * np.pi / 180.0 * (np.random.rand(1,1) - 0.5) + + R = axis_angle_to_rot_mat(axes, angles) + + T[:3, :3] = R + T[:3, 3] = np.matmul(R,-np.mean(pcd, axis=0)) + + return T + + +def augment_precomputed_data(x,R,t, max_angle=360.0): + """ + Function used for data augmention (random transformation) in the training process. It transforms the point from PC1 with a randomly sampled + transformation matrix and updates the ground truth rotation and translation, respectively. + + Args: + x (np.array): coordinates of the correspondences [n,6] + R (np.array): gt rotation matrix [3,3] + t (np.array): gt translation vector [3,1] + max_angle (float): maximum angle that should be used to sample the rotation matrix + + Returns: + t_data (numpy array): augmented coordinates of the correspondences [n, 6] + t_rs (numpy array): augmented rotation matrix [3,3] + t_ts (numpy array): augmented translation vector [3,1] + """ + + # Sample random transformation matrix for each example in the batch + T_rand = sample_random_trans(x[:, 0:3], np.random.RandomState(), max_angle) + + # Compute the updated ground truth transformation paramaters R_n = R_gt*R_s^-1, t_n = t_gt - R_gt*R_s^-1*t_s + rotation_matrix_inv = T_rand[:3,:3].transpose() + t_rs = np.matmul(R,rotation_matrix_inv) + t_ts = t - np.matmul(R, np.matmul(rotation_matrix_inv, T_rand[:3,2:3].reshape(-1,1))) + + # Transform the coordinates of the first point cloud with the sampled transformation parmaters + t_xs = (np.matmul(T_rand[:3,:3], x[:, 0:3].transpose()) + T_rand[:3,2:3].reshape(-1,1)).transpose() + + t_data = np.concatenate((t_xs,x[:, 3:6]),axis=-1) + + + return t_data, t_rs, t_ts + + +def add_jitter(x, R, t, std=0.01, clip=0.025): + """ + Function used to add jitter to the coordinates of the correspondences in the training process. 
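Editorial aside: the docstring formula of `axis_angle_to_rot_mat` above is hard to read in this diff; the construction is the standard Rodrigues formula R = I + sin(theta) K + (1 - cos(theta)) K^2, with K the cross-product matrix of the unit axis. A short check that the result is a proper rotation:

```python
import numpy as np

axis = np.array([0.0, 0.0, 1.0])
theta = np.pi / 4
K = np.cross(np.eye(3), axis / np.linalg.norm(axis))  # skew-symmetric cross-product matrix
R = np.eye(3) + np.sin(theta) * K + (1 - np.cos(theta)) * (K @ K)

print(np.allclose(R @ R.T, np.eye(3)))    # True: orthonormal
print(np.isclose(np.linalg.det(R), 1.0))  # True: proper rotation (determinant +1)
```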
+ + Args: + x (np.array): coordinates of the correspondences [n,6] + R (np.array): gt rotation matrix [3,3] + t (np.array): gt translation vector [3,1] + std (float): standard deviation of the normal distribution used to sample the jitter + clip (float): cut-off value for the jitter + + Returns: + x (np.array): coordinates of the correspondences with added jitter [n,6] + y (np.array): gt residuals of the correspondences aftter jitter [n] + """ + + jitter = np.clip(np.random.normal(0.0, scale=std, size=(x.shape[0], x.shape[1])), + a_min=-clip, a_max=clip) + + x += jitter # Add noise to xyz + + # Compute new ys + temp_x = (np.matmul(R,x[:,0:3].transpose()) + t).transpose() + y = np.sqrt(np.sum((x[:,3:6]-temp_x)**2,1)) + + return x, y + + +class ClippedStepLR(optim.lr_scheduler._LRScheduler): + def __init__(self, optimizer, step_size, min_lr, gamma=0.1, last_epoch=-1): + self.step_size = step_size + self.min_lr = min_lr + self.gamma = gamma + super(ClippedStepLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + return [max(base_lr * self.gamma ** (self.last_epoch // self.step_size), self.min_lr) + for base_lr in self.base_lrs] + + + +def ensure_dir(path): + """ + Creates the directory specigied by the input if it does not yet exist. + """ + if not os.path.exists(path): + os.makedirs(path, mode=0o755) + +def read_trajectory(filename, dim=4): + """ + Function that reads a trajectory saved in the 3DMatch/Redwood format to a numpy array. + Format specification can be found at http://redwood-data.org/indoor/fileformat.html + + Args: + filename (str): path to the '.txt' file containing the trajectory data + dim (int): dimension of the transformation matrix (4x4 for 3D data) + + Returns: + final_keys (dict): indices of pairs with more than 30% overlap (only this ones are included in the gt file) + traj (numpy array): gt pairwise transformation matrices for n pairs[n,dim, dim] + """ + + with open(filename) as f: + lines = f.readlines() + + # Extract the point cloud pairs + keys = lines[0::(dim+1)] + temp_keys = [] + for i in range(len(keys)): + temp_keys.append(keys[i].split('\t')[0:3]) + + final_keys = [] + for i in range(len(temp_keys)): + final_keys.append([temp_keys[i][0].strip(), temp_keys[i][1].strip(), temp_keys[i][2].strip()]) + + + traj = [] + for i in range(len(lines)): + if i % 5 != 0: + curr_line = '\t'.join(lines[i].split()) + traj.append(curr_line.split('\t')[0:dim]) + + traj = np.asarray(traj, dtype=np.float).reshape(-1,dim,dim) + + final_keys = np.asarray(final_keys) + + return final_keys, traj + + + +def write_trajectory(traj,metadata, filename, dim=4): + """ + Writes the trajectory into a '.txt' file in 3DMatch/Redwood format. 
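Editorial aside: for reference, `read_trajectory`/`write_trajectory` above use the 3DMatch/Redwood log format, where each record is a header line `id0 id1 num_fragments` followed by the rows of the 4x4 transformation matrix; the numbers below are made up. Note that the parsing loop above skips header lines with `i % 5`, which implicitly assumes `dim == 4`.

```
0	1	57
0.958	-0.285	0.023	0.410
0.285	0.958	0.011	-0.230
-0.025	-0.004	0.999	0.120
0.000	0.000	0.000	1.000
```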
+ Format specification can be found at http://redwood-data.org/indoor/fileformat.html + + Args: + traj (numpy array): trajectory for n pairs[n,dim, dim] + metadata (numpy array): file containing metadata about fragment numbers [n,3] + filename (str): path where to save the '.txt' file containing trajectory data + dim (int): dimension of the transformation matrix (4x4 for 3D data) + """ + + with open(filename, 'w') as f: + for idx in range(traj.shape[0]): + # Only save the transfromation parameters for which the overlap threshold was satisfied + if metadata[idx][2]: + p = traj[idx,:,:].tolist() + f.write('\t'.join(map(str, metadata[idx])) + '\n') + f.write('\n'.join('\t'.join(map('{0:.12f}'.format, p[i])) for i in range(dim))) + f.write('\n') + + +def read_trajectory_info(filename, dim=6): + """ + Function that reads the trajectory information saved in the 3DMatch/Redwood format to a numpy array. + Information file contains the variance-covariance matrix of the transformation paramaters. + Format specification can be found at http://redwood-data.org/indoor/fileformat.html + + Args: + filename (str): path to the '.txt' file containing the trajectory information data + dim (int): dimension of the transformation matrix (4x4 for 3D data) + + Returns: + n_frame (int): number of fragments in the scene + cov_matrix (numpy array): covariance matrix of the transformation matrices for n pairs[n,dim, dim] + """ + + with open(filename) as fid: + contents = fid.readlines() + n_pairs = len(contents) // 7 + assert (len(contents) == 7 * n_pairs) + info_list = [] + n_frame = 0 + + for i in range(n_pairs): + frame_idx0, frame_idx1, n_frame = [int(item) for item in contents[i * 7].strip().split()] + info_matrix = np.concatenate( + [np.fromstring(item, sep='\t').reshape(1, -1) for item in contents[i * 7 + 1:i * 7 + 7]], axis=0) + info_list.append(info_matrix) + + cov_matrix = np.asarray(info_list, dtype=np.float).reshape(-1,dim,dim) + + return n_frame, cov_matrix + +def extract_corresponding_trajectors(est_pairs,gt_pairs, est_traj, gt_traj): + """ + Extract only those transformation matrices from the estimated trajectory that are also in the GT trajectory. + + Args: + est_pairs (numpy array): indices of point cloud pairs with enough estimated overlap [m, 3] + gt_pairs (numpy array): indices of gt overlaping point cloud pairs [n,3] + est_traj (numpy array): 3d array of the estimated transformation parameters [m,4,4] + gt_traj (numpy array): 3d array of the gt transformation parameters [n,4,4] + + Returns: + ext_traj_est (numpy array): extracted est transformation parameters for the point cloud pairs from [k,4,4] + ext_traj_gt (numpy array): extracted gt transformation parameters for the point cloud pairs from est_pairs [k,4,4] + """ + ext_traj_est = [] + ext_traj_gt = [] + + est_pairs = est_pairs[:,0:2] + gt_pairs = gt_pairs[:,0:2] + + for gt_idx, pair in enumerate(gt_pairs): + est_idx = np.where((est_pairs == pair).all(axis=1))[0] + if est_idx.size: + ext_traj_gt.append(gt_traj[gt_idx,:,:]) + ext_traj_est.append(est_traj[est_idx[0],:,:]) + + return np.stack(ext_traj_est, axis=0), np.stack(ext_traj_gt, axis=0) + +def computeTransformationErr(trans, info): + """ + Computer the transformation error as an approximation of the RMSE of corresponding points. 
+ More informaiton at http://redwood-data.org/indoor/registration.html + + Args: + trans (numpy array): transformation matrices [n,4,4] + info (numpy array): covariance matrices of the gt transformation paramaters [n,4,4] + + Returns: + p (float): transformation error + """ + + t = trans[:3, 3] + r = trans[:3, :3] + q = nq.mat2quat(r) + er = np.concatenate([t, q[1:]], axis=0) + p = er.reshape(1, 6) @ info @ er.reshape(6, 1) / info[0, 0] + + return p.item() + + +def evaluate_registration(num_fragment, result, result_pairs, gt_pairs, gt, gt_info, err2=0.2): + """ + Evaluates the performance of the registration algorithm according to the evaluation protocol defined + by the 3DMatch/Redwood datasets. The evaluation protocol can be found at http://redwood-data.org/indoor/registration.html + + Args: + num_fragment (int): path to the '.txt' file containing the trajectory information data + result (numpy array): estimated transformation matrices [n,4,4] + result_pairs (numpy array): indices of the point cloud for which the transformation matrix was estimated (m,3) + gt_pairs (numpy array): indices of the ground truth overlapping point cloud pairs (n,3) + gt (numpy array): ground truth transformation matrices [n,4,4] + gt_cov (numpy array): covariance matrix of the ground truth transfromation parameters [n,6,6] + err2 (float): threshold for the RMSE of the gt correspondences (default: 0.2m) + + Returns: + precision (float): mean registration precision over the scene (not so important because it can be increased see papers) + recall (float): mean registration recall over the scene (deciding parameter for the performance of the algorithm) + """ + + err2 = err2 ** 2 + gt_mask = np.zeros((num_fragment, num_fragment), dtype=np.int) + + + + for idx in range(gt_pairs.shape[0]): + i = int(gt_pairs[idx,0]) + j = int(gt_pairs[idx,1]) + + # Only non consecutive pairs are tested + if j - i > 1: + gt_mask[i, j] = idx + + n_gt = np.sum(gt_mask > 0) + + good = 0 + n_res = 0 + for idx in range(result_pairs.shape[0]): + i = int(result_pairs[idx,0]) + j = int(result_pairs[idx,1]) + pose = result[idx,:,:] + + if j - i > 1: + n_res += 1 + if gt_mask[i, j] > 0: + gt_idx = gt_mask[i, j] + p = computeTransformationErr(np.linalg.inv(gt[gt_idx,:,:]) @ pose, gt_info[gt_idx,:,:]) + if p <= err2: + good += 1 + if n_res == 0: + n_res += 1e6 + precision = good * 1.0 / n_res + recall = good * 1.0 / n_gt + + return precision, recall + + + + +def do_single_pair_RANSAC_reg(xyz_i, xyz_j, pc_i, pc_j, voxel_size=0.025,method='3DMatch'): + """ + Runs a RANSAC registration pipeline for a single pair of point clouds. 
+ + Args: + xyz_i (numpy array): coordinates of the correspondences from the first point cloud [n,3] + xyz_j (numpy array): coordinates of the correspondences from the second point cloud [n,3] + pc_i (numpy array): coordinates of all the points from the first point cloud [N,3] + pc_j (numpy array): coordinates of all the points from the second point cloud [N,3] + method (str): name of the method used for the overlap computation [3DMatch, FCGF] + + Returns: + overlap_flag (bool): flag denoting if overlap of the point cloud after aplying the estimated trans paramaters if more than a threshold + trans (numpy array): transformation parameters that trnasform point of point cloud 2 to the coordinate system of point cloud 1 + """ + + + trans = run_ransac(xyz_j, xyz_i) + + + ratio = compute_overlap_ratio(pc_i, pc_j, trans, method, voxel_size) + + overlap_flag = True if ratio > 0.3 else False + + + return [overlap_flag, trans] + + + +def run_ransac(xyz_i, xyz_j): + """ + Ransac based estimation of the transformation paramaters of the congurency transformation. Estimates the + transformation parameters thtat map xyz0 to xyz1. Implementation is based on the open3d library + (http://www.open3d.org/docs/release/python_api/open3d.registration.registration_ransac_based_on_correspondence.html) + + Args: + xyz_i (numpy array): coordinates of the correspondences from the first point cloud [n,3] + xyz_j (numpy array): coordinates of the correspondences from the second point cloud [n,3] + + Returns: + trans_param (float): mean registration precision over the scene (not so important because it can be increased see papers) + recall (float): mean registration recall over the scene (deciding parameter for the performance of the algorithm) + """ + + # Distance threshold as specificed by 3DMatch dataset + distance_threshold = 0.05 + + # Convert the point to an open3d PointCloud object + xyz0 = o3d.geometry.PointCloud() + xyz1 = o3d.geometry.PointCloud() + + xyz0.points = o3d.utility.Vector3dVector(xyz_i) + xyz1.points = o3d.utility.Vector3dVector(xyz_j) + + # Correspondences are already sorted + corr_idx = np.tile(np.expand_dims(np.arange(len(xyz0.points)),1),(1,2)) + corrs = o3d.utility.Vector2iVector(corr_idx) + + result_ransac = o3d.registration.registration_ransac_based_on_correspondence( + source=xyz0, target=xyz1,corres=corrs, + max_correspondence_distance=distance_threshold, + estimation_method=o3d.registration.TransformationEstimationPointToPoint(False), + ransac_n=4, + criteria=o3d.registration.RANSACConvergenceCriteria(50000, 2500)) + + trans_param = result_ransac.transformation + + return trans_param + + + +def compute_overlap_ratio(pc_i, pc_j, trans, method = '3DMatch', voxel_size=0.025): + """ + Computes the overlap percentage of the two point clouds using the estimateted transformation paramaters and based on the selected method. + Available methods are 3DMatch/Redwood as defined in the oficial dataset and the faster FCGF method that first downsamples the point clouds. + Method 3DMatch slightly deviates from the original implementation such that we take the max of the overlaps to check if it is above the threshold + where as in the original implementation only the overlap relative to PC1 is used. 
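Editorial aside: a minimal way to exercise `run_ransac` above on synthetic, index-aligned correspondences. The call itself is left commented out because it needs the legacy `open3d.registration` API used in this file; `run_ransac(src, dst)` estimates the transform that maps its first argument onto its second.

```python
import numpy as np

rng = np.random.default_rng(0)
xyz_j = rng.random((500, 3))
R = np.array([[0.0, -1.0, 0.0],
              [1.0,  0.0, 0.0],
              [0.0,  0.0, 1.0]])
t = np.array([0.1, -0.2, 0.3])
xyz_i = (R @ xyz_j.T).T + t          # xyz_i = R * xyz_j + t, row by row

# trans = run_ransac(xyz_j, xyz_i)   # expected to be close to [R | t] in homogeneous form
```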
+ + Args: + pc_i (numpy array): coordinates of all the points from the first point cloud [N,3] + pc_j (numpy array): coordinates of all the points from the second point cloud [N,3] + trans (numpy array): estimated transformation paramaters [4,4] + method (str): name of the method for overlap computation to be used ['3DMatch', 'FCGF'] + voxel size (float): voxel size used to downsample the point clouds when 'FCGF' method is selected + + Returns: + overlap (float): max of the computed overlap ratios relative to the PC1 and PC2 + + """ + neigh = NearestNeighbors(n_neighbors=1,algorithm='kd_tree') + trans_inv = np.linalg.inv(trans) + + if method == '3DMatch': + pc_i_t = (np.matmul(trans_inv[0:3, 0:3], pc_i.transpose()) + trans_inv[0:3, 3].reshape(-1, 1)).transpose() + pc_j_t = (np.matmul(trans[0:3, 0:3], pc_j.transpose()) + trans[0:3, 3].reshape(-1, 1)).transpose() + + neigh.fit(pc_j_t) + dist01, _ = neigh.kneighbors(pc_i, return_distance=True) + matching01 = np.where(dist01 < 0.05)[0].shape[0] + + neigh.fit(pc_i_t) + dist10, _ = neigh.kneighbors(pc_j, return_distance=True) + matching10 = np.where(dist10 < 0.05)[0].shape[0] + + overlap0 = matching01 / pc_i.shape[0] + overlap1 = matching10 / pc_j.shape[0] + + elif method == 'FCGF': + # Convert the point to an open3d PointCloud object + pcd0 = o3d.geometry.PointCloud() + pcd1 = o3d.geometry.PointCloud() + pcd0.points = o3d.utility.Vector3dVector(pc_i) + pcd1.points = o3d.utility.Vector3dVector(pc_j) + + pcd0_down = pcd0.voxel_down_sample(voxel_size) + pcd1_down = pcd1.voxel_down_sample(voxel_size) + + pc_i = np.array(pcd0_down.points) + pc_j = np.array(pcd1_down.points) + + pc_i_t = (np.matmul(trans_inv[0:3, 0:3], pc_i.transpose()) + trans_inv[0:3, 3].reshape(-1, 1)).transpose() + pc_j_t = (np.matmul(trans[0:3, 0:3], pc_j.transpose()) + trans[0:3, 3].reshape(-1, 1)).transpose() + + neigh.fit(pc_j_t) + dist01, _ = neigh.kneighbors(pc_i, return_distance=True) + matching01 = np.where(dist01 < 3*voxel_size)[0].shape[0] + + neigh.fit(pc_i_t) + dist10, _ = neigh.kneighbors(pc_j, return_distance=True) + matching10 = np.where(dist10 < 3*voxel_size)[0].shape[0] + + overlap0 = matching01 / pc_i.shape[0] + overlap1 = matching10 / pc_j.shape[0] + + # matching01 = get_matching_indices(pcd0_down, pcd1_down, np.linalg.inv(trans), search_voxel_size = 3*voxel_size, K=1) + # matching10 = get_matching_indices(pcd1_down, pcd0_down, trans, + # search_voxel_size = 3*voxel_size, K=1) + # overlap0 = len(matching01) / len(pcd0_down.points) + # overlap1 = len(matching10) / len(pcd1_down.points) + + else: + logging.error("Wrong overlap computation method was selected.") + + + return max(overlap0, overlap1) + + +def get_matching_indices(pc_i, pc_j, trans, search_voxel_size=0.025, K=None, method = 'FCGF'): + """ + Helper function for the point cloud overlap computation. Based on the estimated transformation parameters transforms the point cloud + and searches for the neares neighbor in the other point cloud. 
+ + Args: + pc_i (numpy array): coordinates of all the points from the first point cloud [N,3] + pc_j (numpy array): coordinates of all the points from the second point cloud [N,3] + trans (numpy array): estimated transformation paramaters [4,4] + search_voxel_size (float): threshold used to determine if a point has a correspondence given the estimated trans parameters + K (int): number of nearest neighbors to be returned + + Returns: + match_inds (list): indices of points that have a correspondence withing the search_voxel_size + + """ + + pc_i_copy = copy.deepcopy(pc_i) + pc_j_copy = copy.deepcopy(pc_j) + pc_i_copy.transform(trans) + pcd_tree = o3d.geometry.KDTreeFlann(pc_j_copy) + + match_inds = [] + for i, point in enumerate(pc_i_copy.points): + [_, idx, _] = pcd_tree.search_radius_vector_3d(point, search_voxel_size) + if K is not None: + idx = idx[:K] + + for j in idx: + match_inds.append((i, j)) + + return match_inds + +def extract_mutuals(x1, x2, x1_soft_matches, x2_soft_matches, threshold=0.05): + ''' + Returns a flag if the two point are mutual nearest neighbors in the feature space. + In a softNN formulation a distance threshold has to be used. + + Args: + x1 (torch tensor): source point cloud [b,n,3] + x2 (torch tensor): target point cloud [b,n,3] + x1_soft_matches (torch tensor): coordinates of the (soft) correspondences for points x1 in x2 [b,n,3] + x2_soft_matches (torch tensor): coordinates of the (soft) correspondences for points x2 in x1 [b,n,3] + + Returns: + mutuals (torch tensor): mutual nearest neighbors flag (1 if mutual NN otherwise 0) [b,n] + + ''' + + B, N, C = x1.shape + + _, idx = knn_point(k=1, pos1=x2, pos2=x1_soft_matches) + + delta = x1 - torch.gather(x2_soft_matches,index=idx.expand(-1,-1,C),dim=1) + dist = torch.pow(delta,2).sum(dim=2) + + mutuals = torch.zeros((B,N)) + mutuals[dist < threshold**2] = 1 + + return mutuals + +def extract_overlaping_pairs(xyz, feat, conectivity_info=None): + """ + Build the point cloud pairs based either on the provided conectivity information or sample + all n choose 2 posibilities + + Args: + xyz (torch tensor): coordinates of the sampled points [b,n,3] + feat (torch tensor): features of the sampled points [b,n,c] + conectivity_info (torch tensor): conectivity information (indices of overlapping pairs) [m,2] + + Returns: + xyz_s (torch tensor): coordinates of the points in the source point clouds [B, n, 3] (B != b) + xyz_t (torch tensor): coordinates of the points in the target point clouds [B, n, 3] (B != b) + f_s (torch tensor): features of the points in the source point clouds [B, n, 3] (B != b) + f_t (torch tensor): features of the points in the target point clouds [B, n, 3] (B != b) + + """ + + if not conectivity_info: + + pairs = [] + + # If no conectivity information is provided sample n choose 2 pairs + for comb in list(combinations(range(xyz.shape[0]), 2)): + pairs.append(torch.tensor([int(comb[0]), int(comb[1])])) + + conectivity_info = torch.stack(pairs, dim=0).to(xyz) + + # Build the point cloud pairs based on the conectivity information + xyz_s = torch.index_select(xyz, dim=0, index=conectivity_info[:, 0]) + xyz_t = torch.index_select(xyz, dim=0, index=conectivity_info[:, 1]) + + f_s = torch.index_select(feat, dim=0, index=conectivity_info[:, 0]) + f_t = torch.index_select(feat, dim=0, index=conectivity_info[:, 1]) + + return xyz_s, xyz_t, f_s, f_t + + +def construct_filtering_input_data(xyz_s, xyz_t, data, overlapped_pair_tensors, dist_th=0.05, mutuals_flag=None): + """ + Prepares the input dictionary for the 
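Editorial aside: `extract_mutuals` above flags a correspondence when following the match from source to target and back lands, within a threshold, on the starting point. The same round-trip test for a single, un-batched pair, as a sketch:

```python
import torch

def mutual_nn_mask(x1, x2, matches_12, matches_21, threshold=0.05):
    """x1: [n,3], x2: [m,3]; matches_12: [n,3] match coordinates of x1 in x2;
    matches_21: [m,3] match coordinates of x2 in x1. Returns a bool mask [n]."""
    idx = torch.cdist(matches_12, x2).argmin(dim=1)  # nearest x2 point to each match
    round_trip = matches_21[idx]                     # where that x2 point maps back to
    return (x1 - round_trip).norm(dim=1) < threshold

x1 = torch.rand(100, 3)
x2 = x1 + 1e-3 * torch.randn(100, 3)                 # near-perfectly matched toy pair
print(mutual_nn_mask(x1, x2, x2, x1).all())          # tensor(True)
```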
filtering network + + Args: + xyz_s (torch tensor): coordinates of the sampled points in the source point cloud [b,n,3] + xyz_t (torch tensor): coordinates of the correspondences from the traget point cloud [b,n,3] + data (dict): input data from the data loader + dist_th (float): distance threshold to determine if the correspondence is an inlier or an outlier + mutuals (torch tensor): torch tensor of the mutually nearest neighbors (can be used as side information to the filtering network) + + Returns: + filtering_data (dict): input data for the filtering network + + """ + + filtering_data = {} + Rs, ts = extract_transformation_matrices(data['T_global_0'], overlapped_pair_tensors) + + + ys = transformation_residuals(xyz_s, xyz_t, Rs, ts) + + xs = torch.cat((xyz_s,xyz_t),dim=-1) # [b, n, 6] + + if mutuals_flag is not None: + xs = torch.cat((xs,mutuals_flag.reshape(-1,1)), dim=-1) # [b, n, 7] + + + # Threshold ys based on the distance threshol + ys_binary = (ys < dist_th).type(xs.type()) + + + # Construct the data dictionary + filtering_data['xs'] = xs + filtering_data['ys'] = ys + filtering_data['ts'] = ts + filtering_data['Rs'] = Rs + + + return filtering_data + + +def extract_transformation_matrices(T0, indices): + """ + Compute the relative transformation matrices for the overlaping pairs from their global transformation matrices + + Args: + T0 (torch tensor): global transformation matrices [4*b,4] + indices (torch tensor): indices of the overlaping point couds [B,2] + + Returns: + rots (torch tensor): pairwise rotation matrices [B,3,3] (B!=b) + trans (torch tensor): pairwise translation parameters [B,3,1] (B!=b) + + """ + + indices = indices.detach().cpu().numpy() + T0 = T0.detach().cpu().numpy() + + rot_matrices = [] + trans_vectors = [] + + for row in indices: + + temp_trans_matrix = T0[4*row[0]:4*(row[0]+1), :] @ np.linalg.inv(T0[4*row[1]:4*(row[1]+1), :]) + + rot_matrices.append(temp_trans_matrix[0:3,0:3]) + trans_vectors.append(temp_trans_matrix[0:3,3]) + + rots = torch.from_numpy(np.asarray(rot_matrices)).to(T0) + trans = torch.from_numpy(np.asarray(trans_vectors)).unsqueeze(-1).to(T0) + + return rots, trans + + +def pairwise_distance(src, dst, normalized_feature=True): + """Calculate Euclidean distance between each two points. + + Args: + src (torch tensor): source data, [b, n, c] + dst (torch tensor): target data, [b, m, c] + normalized_feature (bool): distance computation can be more efficient + + Returns: + dist (torch tensor): per-point square distance, [b, n, m] + """ + + B, N, _ = src.shape + _, M, _ = dst.shape + + dist = torch.matmul(src, dst.permute(0, 2, 1)) + + # If features are normalized the distance is related to inner product + if not normalized_feature: + dist = -2 * dists + dist += torch.sum(src ** 2, dim=-1)[:, :, None] + dist += torch.sum(dst ** 2, dim=-1)[:, None, :] + + return torch.sqrt(dist) + + + + +if __name__ == '__main__': + + # Weighted Kabsch algorithm example + + pc_1 = torch.rand(1,5000,3) + + T = sample_random_trans(pc_1) + + R = T[:,:3,:3] + t = T[:,3:4,3] + + pc_2_t = torch.matmul(R,pc_1.transpose(1,2)).transpose(1,2) + t + + rotation_matrix, translation_vector, res, _ = kabsch_transformation_estimation(pc_1, pc_2_t) + + print('Input rotation matrix: {}'.format(R)) + print('Estimated rotation matrix: {}'.format(rotation_matrix)) + + print('Input translation vector: {}'.format(t)) + print('Estimated translation vector: {}'.format(translation_vector)) + + +class Timer(object): + """A simple timer.""" + + def __init__(self): + self.total_time = 0. 
+ self.calls = 0 + self.start_time = 0. + self.diff = 0. + self.avg = 0. + + def reset(self): + self.total_time = 0 + self.calls = 0 + self.start_time = 0 + self.diff = 0 + self.avg = 0 + + def tic(self): + self.start_time = time.time() + + def toc(self, average=True): + self.diff = time.time() - self.start_time + self.total_time += self.diff + self.calls += 1 + self.avg = self.total_time / self.calls + if average: + return self.avg + else: + return self.diff \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..52c7a4f --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +coloredlogs +tqdm +nibabel +pyyaml +gitpython +tensorboardx +easydict +multiprocess diff --git a/scripts/benchmark_pairwise_registration.py b/scripts/benchmark_pairwise_registration.py new file mode 100644 index 0000000..fc106aa --- /dev/null +++ b/scripts/benchmark_pairwise_registration.py @@ -0,0 +1,430 @@ +""" +Script for benchmarking the pairwise registration algorithm on 3DMatch and Redwood datasets. +The script expects that the correspondences (any feature descriptor) are precomputed +and benchmarks the transformation estimation algorithms. Other datasets can easily be added +provided that they are expressed in the same data formats. + +NOTE: The results might deviate little from the official benchmarking code that is implemented +in matlab (https://github.com/andyzeng/3dmatch-toolbox). The reason being different RANSAC +implementation and overlap estimation (official one is also implemented here but is slower). + +Code is partially borrowed from the Chris Choy's FCGF repository (https://github.com/chrischoy/FCGF) + +Author: Zan Gojcic +""" + +import os +import glob +import sys +import numpy as np +import argparse +import logging +import open3d as o3d +from collections import defaultdict +import torch +import coloredlogs +from matplotlib import pyplot as plt +cwd = os.getcwd() +sys.path.append(cwd) + +from lib.utils import ensure_dir, read_trajectory, write_trajectory, read_trajectory_info, get_folder_list, \ + kabsch_transformation_estimation, Timer, run_ransac, rotation_error, load_config, \ + translation_error, evaluate_registration, compute_overlap_ratio, extract_corresponding_trajectors + +from scripts.utils import make_pairwise_eval_data_loader +from lib.checkpoints import CheckpointIO +import lib.config as config + +o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Error) + +SHORT_NAMES = {} +SHORT_NAMES['3d_match'] = {'kitchen':'Kitchen', + 'sun3d-home_at-home_at_scan1_2013_jan_1':'Home 1', + 'sun3d-home_md-home_md_scan9_2012_sep_30':'Home 2', + 'sun3d-hotel_uc-scan3':'Hotel 1', + 'sun3d-hotel_umd-maryland_hotel1':'Hotel 2', + 'sun3d-hotel_umd-maryland_hotel3':'Hotel 3', + 'sun3d-mit_76_studyroom-76-1studyroom2':'Study', + 'sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika':'MIT Lab'} + +SHORT_NAMES['redwood'] = {'iclnuim-livingroom1':'livingroom1', + 'iclnuim-livingroom2':'livingroom2', + 'iclnuim-office1':'office1', + 'iclnuim-office2':'office2'} + + +def estimate_trans_param_RANSAC(eval_data, source_path, dataset, scene_info, method, mutuals, save_data=False, overlap_method='FCGF'): + """ + Estimates the pairwise transformation parameters from the provided correspondences using RANSAC and + saves the results in the trajectory file that can be used to estimate the registration precision and recall. 
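+ + Illustrative example (a sketch only; eval_data and scene_info are assumed to come from make_pairwise_eval_data_loader and the data path is a placeholder): + + estimate_trans_param_RANSAC(eval_data, './data/eval_data/3d_match', '3d_match', scene_info, 'RANSAC', mutuals=False, save_data=False)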
+ + Args: + eval_data (torch dataloader): dataloader with the evaluation data + source_path (numpy array): coordinates of the correspondences from the second point cloud [n,3] + dataset (str): name of the dataset + scene_info (dict): metadata of the individual scenes from the dataset + method (str): method used for estimating the pairwise transformation paramaters + mutuals (bool): if True only mutually closest neighbors are used + save_data (bool): if True transformation parameters are saved in npz files + overlap_method (str): method for overlap computation [FCGF, 3DMatch] + + """ + + # Initialize the transformation matrix + num_pairs = scene_info['nr_examples'] + est_trans_param = np.tile(np.eye(4),reps=[num_pairs,1]) + + # Configure the base save path + save_path = os.path.join(source_path,'results', method) + save_path += '/mutuals/' if mutuals else '/all/' + ensure_dir(save_path) + + reg_metadata = [] + + logging.info('Starting RANSAC based registration estimation for {} pairs!'.format(num_pairs)) + avg_timer, full_timer, overlap_timer = Timer(), Timer(), Timer() + full_timer.tic() + + overlap_threshold = 0.3 if dataset =='3d_match' else 0.23 + + for batch in eval_data: + for idx in range(batch['xs'].shape[0]): + + data = batch['xs'][idx,0,:,:].numpy() + pc_1 = batch['xyz1'][idx][0] + pc_2 = batch['xyz2'][idx][0] + meta = batch['metadata'][idx] + pair_idx = int(batch['idx'][idx].numpy().item()) + + avg_timer.tic() + T_est = run_ransac(data[:,0:3], data[:,3:]) + avg_time = avg_timer.toc() + T_est = np.linalg.inv(T_est) + + overlap_timer.tic() + overlap_ratio = compute_overlap_ratio(pc_1, pc_2, T_est, method=overlap_method) + avg_time_overlap = overlap_timer.toc() + + overlap_flag = True if overlap_ratio >= overlap_threshold else False + est_trans_param[4*pair_idx: 4*pair_idx + 4, :] = T_est + + reg_metadata.append([str(int(meta[1])), str(int(meta[2])), overlap_flag]) + + if save_data: + np.savez_compressed( + os.path.join(save_path, meta[0], 'cloud_{}_{}.npz'.format(str(int(meta[1])), str(int(meta[2])))), + t_est=T_est[0:3,3], + R_est=T_est[0:3,0:3], + overlap=overlap_flag) + + + if len(eval_data) != 0: + logging.info('RANSAC based registration estimation is complete!') + logging.info('{} pairwise registration parameters estimated in {:.3f}s'.format(num_pairs,full_timer.toc(average=False))) + logging.info('Transformation estimation run time {:.4f}s per pair'.format(avg_time)) + logging.info('Overlap computation run time {:.4f}s per pair'.format(avg_time_overlap)) + + # Loop through the transformation matrix and save results to trajectory files + for key in scene_info: + if key != 'nr_examples': + scene_idx = scene_info[key] + ensure_dir(os.path.join(save_path, key)) + trans_par = est_trans_param[scene_idx[0]:scene_idx[1],:].reshape(-1,4,4) + write_trajectory(trans_par,reg_metadata[scene_idx[0]//4:scene_idx[1]//4], os.path.join(save_path, key, 'traj.txt')) + + + +def infer_transformation_parameters(eval_data, source_path, dataset, scene_info, method, model_path, mutuals, save_data=False, overlap_method='FCGF', refine=False): + """ + Estimates the pairwise transformation parameters from the provided correspondences using a deep learning model + and saves the results in the trajectory file that can be used to estimate the registration precision and recall. 
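+ + Illustrative example (a sketch only; the checkpoint path is a placeholder and a matching config file under ./configs/pairwise_registration/eval/ is assumed to exist): + + infer_transformation_parameters(eval_data, './data/eval_data/3d_match', '3d_match', scene_info, 'RegBlock', './pretrained/RegBlock/model_best.pt', mutuals=False)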
+ + Args: + eval_data (torch dataloader): dataloader with the evaluation data + source_path (numpy array): coordinates of the correspondences from the second point cloud [n,3] + dataset (str): name of the dataset + scene_info (dict): metadata of the individual scenes from the dataset + method (str): method used for estimating the pairwise transformation paramaters + model_path (str): path to the model + mutuals (bool): if True only mutually closest neighbors are used + save_data (bool): if True transformation parameters are saved in npz files + overlap_method (str): method for overlap computation [FCGF, 3DMatch] + refine (bool): if the RANSAC should be applied after the network filtering on the inliers (similar to 2D filtering networks) + + """ + # Model initialization + logging.info("Using the method {}".format(method)) + + # Load config file + cfg = load_config(os.path.join('./configs/pairwise_registration/eval', method + '.yaml')) + model = config.get_model(cfg) + + # Load pre-trained model + model_name = model_path.split('/')[-1] + model_path = '/'.join(model_path.split('/')[0:-1]) + + kwargs = {'model': model} + + checkpoint_io = CheckpointIO(model_path, initialize_from=None, + initialization_file_name=None, **kwargs) + + load_dict = checkpoint_io.load(model_name) + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model.to(device) + + # Initialize the transformation matrix + num_pairs = scene_info['nr_examples'] + est_trans_param = np.tile(np.eye(4),reps=[num_pairs,1]) + + save_path = os.path.join(source_path,'results', method) + save_path += '/mutuals/' if args.mutuals else '/all/' + + ensure_dir(save_path) + reg_metadata = [] + + logging.info('Starting {} based registration estimation for {} pairs!'.format(method, num_pairs)) + avg_timer, full_timer = Timer(), Timer() + full_timer.tic() + T_est = np.eye(4) + + overlap_threshold = 0.3 if dataset =='3d_match' else 0.23 + logging.info('Using overlap threshold {} for the dataset {}.'.format(overlap_threshold, dataset)) + for batch in eval_data: + + # Filter the correspondences and estimate the pairwise transformation parameters + avg_timer.tic() + filtered_output = model.filter_correspondences(batch) + + + avg_time = avg_timer.toc() + + # We still have to loop through the batch for the overlap estimation (point cloud not the same size) + for idx in range(batch['xs'].shape[0]): + + pc_1 = batch['xyz1'][idx][0] + pc_2 = batch['xyz2'][idx][0] + meta = batch['metadata'][idx] + pair_idx = int(batch['idx'][idx].numpy().item()) + if refine: + data = batch['xs'][idx,0,:,:].numpy() + inliers = (filtered_output['scores'][-1][idx].cpu().numpy() > 0.5) + T_est = run_ransac(data[inliers,0:3], data[inliers,3:]) + + else: + T_est[0:3,0:3] = filtered_output['rot_est'][-1][idx].cpu().numpy() + T_est[0:3,3] = filtered_output['trans_est'][-1][idx].cpu().numpy().reshape(-1) + T_est = np.linalg.inv(T_est) + + overlap_ratio = compute_overlap_ratio(pc_1, pc_2, T_est, method=overlap_method) + overlap_flag = True if overlap_ratio >= overlap_threshold else False + + est_trans_param[4*pair_idx: 4*pair_idx + 4, :] = T_est + + reg_metadata.append([str(int(meta[1])), str(int(meta[2])), overlap_flag]) + + if save_data: + np.savez_compressed( + os.path.join(save_path, meta[0], 'cloud_{}_{}.npz'.format(str(int(meta[1])), str(int(meta[2])))), + t_est=T_est[0:3,3], + R_est=T_est[0:3,0:3], + overlap=overlap_flag) + + + if len(eval_data) != 0: + logging.info('RANSAC based registration estimation is complete!') + logging.info('{} pairwise 
registration parameters estimated in {:.3f}s'.format(num_pairs, full_timer.toc(average=False))) + logging.info('Pure run time {:.4f}s per pair'.format(avg_time/eval_data.batch_size)) + + # Loop through the transformation matrices and save the results to trajectory files + for key in scene_info: + if key != 'nr_examples': + scene_idx = scene_info[key] + ensure_dir(os.path.join(save_path, key)) + trans_par = est_trans_param[scene_idx[0]:scene_idx[1],:].reshape(-1,4,4) + write_trajectory(trans_par, reg_metadata[scene_idx[0]//4:scene_idx[1]//4], os.path.join(save_path, key, 'traj.txt')) + + + + +def evaluate_registration_performance(eval_data, source_path, dataset, scene_info, method, model, mutuals=False, save_data=False, overlap_method='FCGF', refine=False): + """ + Evaluates the pairwise registration performance of the selected method on the selected dataset. + + Args: + eval_data (torch dataloader): dataloader with the evaluation data + source_path (str): path to the evaluation data of the selected dataset + dataset (str): name of the dataset + scene_info (dict): metadata of the individual scenes from the dataset + method (str): method used for estimating the pairwise transformation parameters + model (str): path to the pretrained model checkpoint + mutuals (bool): if True only mutually closest neighbors are used + save_data (bool): if True transformation parameters are saved in npz files + overlap_method (str): method for overlap computation [FCGF, 3DMatch] + refine (bool): if True, RANSAC is applied to the inliers after the network filtering (similar to 2D filtering networks) + + """ + # Prepare the variables + + re_per_scene = defaultdict(list) + te_per_scene = defaultdict(list) + re_all, te_all, precision, recall = [], [], [], [] + re_medians, te_medians = [], [] + + # Estimate the transformation parameters (if the trajectory files do not exist yet) + if method == 'RANSAC': + estimate_trans_param_RANSAC(eval_data, source_path, dataset, scene_info, method, mutuals, save_data, overlap_method) + else: + infer_transformation_parameters(eval_data, source_path, dataset, scene_info, method, model, mutuals, save_data, overlap_method, refine) + + logging.info("Results of {} on {} dataset!".format(method, dataset)) + logging.info("--------------------------------------------") + logging.info("{:<12} ¦ prec. ¦ rec. 
¦ re ¦ te ¦".format('Scene')) + logging.info("--------------------------------------------") + scenes = get_folder_list(os.path.join(source_path,'correspondences')) + scenes = [scene.split('/')[-1] for scene in scenes] + + for idx, scene in enumerate(scenes): + # Extract the values from the gt trajectory and trajectory information files + gt_pairs, gt_traj = read_trajectory(os.path.join(source_path,'raw_data', scene, "gt.log")) + n_fragments, gt_traj_cov = read_trajectory_info(os.path.join(source_path,'raw_data', scene, "gt.info")) + assert gt_traj.shape[0] > 0, "Empty trajectory file" + + # Extract the estimated transformation matrices + if mutuals: + if method == 'RANSAC': + est_pairs, est_traj = read_trajectory(os.path.join(source_path, 'results', + method, 'mutuals', scene, "traj.txt")) + else: + est_pairs, est_traj = read_trajectory(os.path.join(source_path, 'results', + method, 'mutuals', scene, "traj.txt")) + else: + if method == 'RANSAC': + est_pairs, est_traj = read_trajectory(os.path.join(source_path, 'results', + method, 'all', scene, "traj.txt")) + else: + est_pairs, est_traj = read_trajectory(os.path.join(source_path, 'results', + method, 'all', scene, "traj.txt")) + + + temp_precision, temp_recall = evaluate_registration(n_fragments, est_traj, est_pairs, gt_pairs, gt_traj, gt_traj_cov) + + # Filter out only the transformation matrices that are in the GT and EST + ext_traj_est, ext_traj_gt = extract_corresponding_trajectors(est_pairs,gt_pairs,est_traj, gt_traj) + + re = rotation_error(torch.from_numpy(ext_traj_gt[:,0:3,0:3]), torch.from_numpy(ext_traj_est[:,0:3,0:3])).cpu().numpy() + te = translation_error(torch.from_numpy(ext_traj_gt[:,0:3,3:4]), torch.from_numpy(ext_traj_est[:,0:3,3:4])).cpu().numpy() + + re_per_scene['mean'].append(np.mean(re)) + re_per_scene['median'].append(np.median(re)) + re_per_scene['min'].append(np.min(re)) + re_per_scene['max'].append(np.max(re)) + + + te_per_scene['mean'].append(np.mean(te)) + te_per_scene['median'].append(np.median(te)) + te_per_scene['min'].append(np.min(te)) + te_per_scene['max'].append(np.max(te)) + + + re_all.extend(re.reshape(-1).tolist()) + te_all.extend(te.reshape(-1).tolist()) + + precision.append(temp_precision) + recall.append(temp_recall) + re_medians.append(np.median(re)) + te_medians.append(np.median(te)) + + logging.info("{:<12} ¦ {:.3f} ¦ {:.3f} ¦ {:.3f} ¦ {:.3f} ¦".format(SHORT_NAMES[dataset][scene], temp_precision, temp_recall, np.median(re), np.median(te))) + + logging.info("--------------------------------------------") + logging.info("Mean precision: {:.3f} +- {:.3f}".format(np.mean(precision),np.std(precision))) + logging.info("Mean recall: {:.3f} +- {:.3f}".format(np.mean(recall),np.std(recall))) + logging.info("Mean ae: {:.3f} +- {:.3f} [deg]".format(np.mean(re_medians),np.std(re_medians))) + logging.info("Mean te: {:.3f} +- {:.3f} [m]".format(np.mean(te_medians),np.std(te_medians))) + logging.info("--------------------------------------------") + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--source_path', default='./data/eval_data/', type=str, help='path to dataset') + parser.add_argument( + '--dataset', default='3d_match', type=str, help='path to dataset') + parser.add_argument( + '--method', default='OANet', type=str, help='Which method should be used [RANSAC, RegBlock, Joint]') + parser.add_argument( + '--model', + default=None, + type=str, + help='path to latest checkpoint (default: None)') + + parser.add_argument( + '--batch_size', + type=int, + 
default=32, + help='Batch size (if mutuals are selected batch size will be 1).') + + parser.add_argument( + '--mutuals', + action='store_true', + help='If only mutually closest NN should be used (reciprocal matching).') + + parser.add_argument( + '--save_data', + action='store_true', + help='If the intermediate data should be saved to npz files.') + + parser.add_argument( + '--overwrite', + action='store_true', + help='If results for this method and dataset exist they will be overwritten, otherwised they will be loaded and used') + + parser.add_argument( + '--overlap_method', + type=str, + default='FCGF', + help='Method to compute the overlap ratio (FCGF or 3DMatch) FCGF is slightly faster than official 3DMatch') + + parser.add_argument( + '--only_gt_overlaping', + action='store_true', + help='Transformation matrices will be computed only for the GT overlaping pairs. Does not change \ + registration recall that is typically reported in the papers but it is almost 10x faster.') + + parser.add_argument( + '--refine', + action='store_true', + help='The results of the deep methods are refined by the subsequent RANSAC using only the inliers') + + + args = parser.parse_args() + + # Prepare the logger + logger = logging.getLogger() + coloredlogs.install(level='INFO', logger=logger) + log_formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s - %(message)s') + + + # Ensure that the source and the target folders were provided + assert args.source_path is not None + + # Adapt source path + args.source_path = os.path.join(args.source_path, args.dataset) + + # Prepare the data loader + eval_data, scene_info = make_pairwise_eval_data_loader(args) + + + with torch.no_grad(): + evaluate_registration_performance(eval_data, + source_path=args.source_path, + dataset=args.dataset, + scene_info=scene_info, + method=args.method, + model=args.model, + mutuals=args.mutuals, + save_data=args.save_data, + overlap_method=args.overlap_method, + refine=args.refine) \ No newline at end of file diff --git a/scripts/download_3DMatch_eval.sh b/scripts/download_3DMatch_eval.sh new file mode 100644 index 0000000..4167e73 --- /dev/null +++ b/scripts/download_3DMatch_eval.sh @@ -0,0 +1,56 @@ +#!/usr/bin/env bash + +DATA_TYPE=$1 + +function download() { + if [ ! -d "data" ]; then + mkdir -p "data" + fi + cd data + + if [ ! -d "eval_data" ]; then + mkdir -p "eval_data" + fi + cd eval_data + + if [ ! -d "3d_match" ]; then + mkdir -p "3d_match" + fi + cd 3d_match + + url="https://share.phys.ethz.ch/~gsg/LMPR/data/" + + if [ "$DATA_TYPE" == "raw" ] + then + data_set="3d_match_eval_raw.zip" + echo $url$data_set + else + data_set="3d_match_eval_preprocessed.zip" + echo $url$data_set + fi + + + wget --no-check-certificate --show-progress "$url$data_set" + unzip $data_set + rm $data_set + cd /../../.. + + +} + +function main() { + if [ -z "$DATA_TYPE" ]; then + echo "Data type has to be selected! One of [raw, preprocessed]" + exit 1 + fi + + echo $DATA_TYPE + if [ "$DATA_TYPE" == "raw" ] || [ $DATA_TYPE == "preprocessed" ] + then + download + else + echo "Wrong data type selected must be on of [raw, preprocessed]." +fi +} + +main; diff --git a/scripts/download_3DMatch_train.sh b/scripts/download_3DMatch_train.sh new file mode 100644 index 0000000..ffa1caf --- /dev/null +++ b/scripts/download_3DMatch_train.sh @@ -0,0 +1,56 @@ +#!/usr/bin/env bash + +DATA_TYPE=$1 + +function download() { + if [ ! -d "data" ]; then + mkdir -p "data" + fi + cd data + + if [ ! 
-d "training_data" ]; then + mkdir -p "training_data" + fi + cd training_data + + if [ ! -d "3d_match" ]; then + mkdir -p "3d_match" + fi + cd 3d_match + + url="https://share.phys.ethz.ch/~gsg/LMPR/data/" + + if [ "$DATA_TYPE" == "raw" ] + then + data_set="3d_match_raw.zip" + echo $url$data_set + else + data_set="3d_match_preprocessed.zip" + echo $url$data_set + fi + + + wget --no-check-certificate --show-progress "$url$data_set" + unzip $data_set + rm $data_set + cd /../.. + + +} + +function main() { + if [ -z "$DATA_TYPE" ]; then + echo "Data type has to be selected! One of [raw, preprocessed]" + exit 1 + fi + + echo $DATA_TYPE + if [ "$DATA_TYPE" == "raw" ] || [ $DATA_TYPE == "preprocessed" ] + then + download + else + echo "Wrong data type selected must be on of [raw, preprocessed]." +fi +} + +main; diff --git a/scripts/download_pretrained_models.sh b/scripts/download_pretrained_models.sh new file mode 100644 index 0000000..79242f7 --- /dev/null +++ b/scripts/download_pretrained_models.sh @@ -0,0 +1,6 @@ +mkdir pretrained +cd pretrained +wget --no-check-certificate --show-progress https://share.phys.ethz.ch/~gsg/LMPR/pretrained_models/pretrained_models.zip +unzip pretrained_models.zip +rm pretrained_models.zip +cd .. diff --git a/scripts/download_redwood_eval.sh b/scripts/download_redwood_eval.sh new file mode 100644 index 0000000..e426e5b --- /dev/null +++ b/scripts/download_redwood_eval.sh @@ -0,0 +1,56 @@ +#!/usr/bin/env bash + +DATA_TYPE=$1 + +function download() { + if [ ! -d "data" ]; then + mkdir -p "data" + fi + cd data + + if [ ! -d "eval_data" ]; then + mkdir -p "eval_data" + fi + cd eval_data + + if [ ! -d "redwood" ]; then + mkdir -p "redwood" + fi + cd 3d_match + + url="https://share.phys.ethz.ch/~gsg/LMPR/data/" + + if [ "$DATA_TYPE" == "raw" ] + then + data_set="redwood_eval_raw.zip" + echo $url$data_set + else + data_set="redwood_eval_preprocessed.zip" + echo $url$data_set + fi + + + wget --no-check-certificate --show-progress "$url$data_set" + unzip $data_set + rm $data_set + cd /../../.. + + +} + +function main() { + if [ -z "$DATA_TYPE" ]; then + echo "Data type has to be selected! One of [raw, preprocessed]" + exit 1 + fi + + echo $DATA_TYPE + if [ "$DATA_TYPE" == "raw" ] || [ $DATA_TYPE == "preprocessed" ] + then + download + else + echo "Wrong data type selected must be on of [raw, preprocessed]." +fi +} + +main; diff --git a/scripts/extract_data.py b/scripts/extract_data.py new file mode 100644 index 0000000..fe805ab --- /dev/null +++ b/scripts/extract_data.py @@ -0,0 +1,376 @@ +""" +Source code used to extract FCGF features, pairwise correspondences and training that can be used to train our network +without computing the FCGF descriptors on the fly. This way of training greatly eases the sampling of the point cloud pairs and +can at least be used to pretrain the filtering network, confidence estimation block and transformation synchronization. 
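+ +Example invocation (illustrative; the paths and the checkpoint name are placeholders, the flags are the ones defined in the argument parser below): + + python scripts/extract_data.py --source_path ./data/training_data/ --target_path ./data/training_data/ --dataset 3d_match --model ./pretrained/fcgf.pth --extract_features --extract_correspondences --with_cuda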
+ +""" + +import argparse +import os +import open3d as o3d +import torch +import logging +import numpy as np +import sys +import coloredlogs +import torch +import easydict +import multiprocessing as mp +import sys + +cwd = os.getcwd() +sys.path.append(cwd) + +from sklearn.neighbors import NearestNeighbors +from functools import partial + +from lib.descriptor.fcgf import FCGFNet +from scripts.utils import extract_features, transform_point_cloud +from lib.utils import compute_overlap_ratio, load_point_cloud, ensure_dir, get_file_list, get_folder_list + +o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Error) + + + +def extract_features_batch(model, source_path, target_path, dataset, voxel_size, device): + """ + Extracts the per point features in the FCGF feature space and saves them to the predefined path + + Args: + model (FCGF model instance): model used to inferr the descriptors + source_path (str): path to the raw files + target path (str): where to save the extracted data + dataset (float): name of the dataset + voxel_size (float): voxel sized used to create the sparse tensor + device (pytorch device): cuda or cpu + """ + + source_path = os.path.join(source_path,dataset,'raw_data') + target_path = os.path.join(target_path,dataset,'features') + + ensure_dir(target_path) + + folders = get_folder_list(source_path) + + assert len(folders) > 0, 'Could not find {} folders under {}'.format(dataset, source_path) + + logging.info(folders) + list_file = os.path.join(target_path, 'list.txt') + f = open(list_file, 'w') + model.eval() + + for fo in folders: + scene_name = fo.split() + files = get_file_list(fo, '.ply') + fo_base = os.path.basename(fo) + ensure_dir(os.path.join(target_path, fo_base)) + + f.write('%s %d\n' % (fo_base, len(files))) + for i, fi in enumerate(files): + save_fn = '%s_%03d' % (fo_base, i) + if os.path.exists(os.path.join(target_path, fo_base, save_fn + '.npz')): + print('Correspondence file already exits moving to the next example.') + else: + # Extract features from a file + pcd = o3d.io.read_point_cloud(fi) + + if i % 100 == 0: + logging.info(f'{i} / {len(files)}: {save_fn}') + + xyz_down, feature = extract_features( + model, + xyz=np.array(pcd.points), + rgb=None, + normal=None, + voxel_size=voxel_size, + device=device, + skip_check=True) + + np.savez_compressed( + os.path.join(target_path, fo_base, save_fn), + points=np.array(pcd.points), + xyz=xyz_down, + feature=feature.detach().cpu().numpy()) + + f.close() + +def extract_correspondences(dataset,source_path,target_path, n_correspondences): + """ + Prepares the arguments and runs the correspondence extration in parallel mode + + Args: + dataset (str): name of the dataset + source_path (str): path to the raw files + target_path (str): path to where the extracted data will be saved + n_correspondences (int): number of points to sample + + """ + source_path = os.path.join(source_path,dataset,'raw_data') + + + scene_paths = get_folder_list(source_path) + idx = list(range(len(scene_paths))) + + + pool = mp.Pool(processes=6) + func = partial(run_correspondence_extraction, dataset, source_path, target_path, n_correspondences) + pool.map(func, idx) + pool.close() + pool.join() + + +def run_correspondence_extraction(dataset,source_path, target_path, n_correspondences, idx): + """ + Computes the correspondences in the FCGF space together with the mutuals and ratios side information + + Args: + dataset (str): name of the dataset + source_path (str): path to the raw data + target_path (str): path to where the extracted 
data will be saved + n_correspondences (int): number of points to sample + idx (int): index of the scene, used for parallel processing + + """ + + # Initialize all the paths + features_path = os.path.join(target_path,dataset,'features') + target_path = os.path.join(target_path,dataset,'correspondences') + + fo = get_folder_list(source_path)[idx] + fo_base = os.path.basename(fo) + files = get_file_list(os.path.join(features_path, fo_base), '.npz') + + ensure_dir(os.path.join(target_path, fo_base)) + + # Loop over all fragment pairs and compute the training data + for idx_1 in range(len(files)): + for idx_2 in range(idx_1+1, len(files)): + if os.path.exists(os.path.join(target_path, fo_base,'{}_{}_{}.npz'.format(fo_base,str(idx_1).zfill(3), str(idx_2).zfill(3)))): + logging.info('Correspondence file already exits moving to the next example.') + + else: + pc_1_data = np.load(os.path.join(features_path, fo_base, fo_base + '_{}.npz'.format(str(idx_1).zfill(3)))) + pc_1_features = pc_1_data['feature'] + pc_1_keypoints = pc_1_data['xyz'] + + pc_2_data = np.load(os.path.join(features_path, fo_base, fo_base + '_{}.npz'.format(str(idx_2).zfill(3)))) + pc_2_features = pc_2_data['feature'] + pc_2_keypoints = pc_2_data['xyz'] + + # Sample with replacement if less then n_correspondences points are in the point cloud + if pc_1_features.shape[0] >= n_correspondences: + inds_1 = np.random.choice(pc_1_features.shape[0], n_correspondences, replace=False) + else: + inds_1 = np.random.choice(pc_1_features.shape[0], n_correspondences, replace=True) + + if pc_2_features.shape[0] >= n_correspondences: + inds_2 = np.random.choice(pc_2_features.shape[0], n_correspondences, replace=False) + else: + inds_2 = np.random.choice(pc_2_features.shape[0], n_correspondences, replace=True) + + + pc_1_features = pc_1_features[inds_1,:] + pc_2_features = pc_2_features[inds_2, :] + pc_1_key = pc_1_keypoints[inds_1,:] + pc_2_key = pc_2_keypoints[inds_2,:] + + # find the correspondence using nearest neighbor search in the feature space (two way) + nn_search = NearestNeighbors(n_neighbors=1, metric='minkowski', p=2) + nn_search.fit(pc_2_features) + nn_dists, nn_indices = nn_search.kneighbors(X=pc_1_features, n_neighbors=2, return_distance=True) + + + nn_search.fit(pc_1_features) + nn_dists_1, nn_indices_1 = nn_search.kneighbors(X=pc_2_features, n_neighbors=2, return_distance=True) + + ol_nn_ids = np.where((nn_indices[nn_indices_1[:, 0], 0] - np.arange(pc_1_features.shape[0])) == 0)[0] + + # Initialize mutuals and ratios + mutuals = np.zeros((n_correspondences, 1)) + mutuals[ol_nn_ids] = 1 + ratios = nn_dists[:, 0] / nn_dists[:, 1] + + # Concatenate the correspondence coordinates + xs = np.concatenate((pc_1_key[nn_indices_1[:, 0], :], pc_2_key), axis=1) + + np.savez_compressed( + os.path.join(target_path, fo_base, '{}_{}_{}.npz'.format(fo_base, str(idx_1).zfill(3), str(idx_2).zfill(3))), + x=xs, + mutuals=mutuals, + ratios=ratios) + + +def extract_precomputed_training_data(dataset, source_path, target_path, voxel_size, inlier_threshold): + """ + Prepares the data for training the filtering networks with precomputed correspondences (without FCGF descriptor) + + Args: + dataset (str): name of the dataset + source_path (str): path to the raw data + target_path (str): path to where the extracted data will be saved + voxel_size (float): voxel size that was used to compute the features + inlier_threshold (float): threshold to determine if a correspondence is an inlier or outlier + """ + source_path = 
os.path.join(source_path,dataset,'raw_data') + features_path = os.path.join(target_path,dataset,'features') + correspondence_path = os.path.join(target_path,dataset,'correspondences') + target_path = os.path.join(target_path,dataset,'training_data') + + + ensure_dir(target_path) + + # Check that the GT global transformation matrices are available and that the FCGF features are computed + folders = get_folder_list(source_path) + + assert len(folders) > 0, 'Could not find {} folders under {}'.format(dataset, source_path) + + logging.info('Found {} scenes from the {} dataset!'.format(len(folders),dataset)) + + for fo in folders: + + scene_name = fo.split() + fo_base = os.path.basename(fo) + ensure_dir(os.path.join(target_path, fo_base)) + + pc_files = get_file_list(fo, '.ply') + trans_files = get_file_list(fo, '.txt') + assert len(pc_files) <= len(trans_files), 'The number of point cloud files does not equal the number of GT trans parameters!' + + feat_files = get_file_list(os.path.join(features_path,fo_base), '.npz') + assert len(pc_files) == len(feat_files), 'Features for scene {} are either not computed or some are missing!'.format(fo_base) + + coor_files = get_file_list(os.path.join(correspondence_path,fo_base), '.npz') + + assert len(coor_files) == int((len(feat_files) * (len(feat_files)-1))/2), 'Correspondence files for the scene {} are missing. First run the correspondence extraction!'.format(fo_base) + + # Loop over all fragment pairs and compute the training data + for idx_1 in range(len(pc_files)): + for idx_2 in range(idx_1+1, len(pc_files)): + if os.path.exists(os.path.join(target_path, fo_base,'{}_{}_{}.npz'.format(fo_base,str(idx_1).zfill(3), str(idx_2).zfill(3)))): + logging.info('Training file already exits moving to the next example.') + + + data = np.load(os.path.join(correspondence_path, fo_base,'{}_{}_{}.npz'.format(fo_base,str(idx_1).zfill(3), str(idx_2).zfill(3)))) + xs = data['xs'] + mutuals = data['mutuals'] + ratios = data['ratios'] + + # Get the GT transformation parameters + t_1 = np.genfromtxt(os.path.join(source_path, fo_base, 'cloud_bin_{}.info.txt'.format(idx_1)), skip_header=1) + t_2 = np.genfromtxt(os.path.join(source_path, fo_base, 'cloud_bin_{}.info.txt'.format(idx_2)), skip_header=1) + + + # Get the GT transformation parameters + pc_1 = load_point_cloud(os.path.join(source_path,fo_base, 'cloud_bin_{}.ply'.format(idx_1)), data_type='numpy') + pc_2 = load_point_cloud(os.path.join(source_path,fo_base, 'cloud_bin_{}.ply'.format(idx_2)), data_type='numpy') + + pc_1_tr = transform_point_cloud(pc_1, t_1[0:3,0:3], t_1[0:3,3].reshape(-1,1)) + pc_2_tr = transform_point_cloud(pc_2, t_2[0:3,0:3], t_2[0:3,3].reshape(-1,1)) + + overlap_ratio = compute_overlap_ratio(pc_1_tr, pc_2_tr, np.eye(4), method = 'FCGF', voxel_size=voxel_size) + + # Estimate pairwise transformation parameters + t_3 = np.matmul(np.linalg.inv(t_2), t_1) + + r_matrix = t_3[0:3, 0:3] + t_vector = t_3[0:3, 3] + + # Transform the keypoints of the first point cloud + pc_1_key_tr = transform_point_cloud(xs[:,0:3], r_matrix, t_vector.reshape(-1,1)) + + # Compute the residuals after the transformation + y_s = np.sqrt(np.sum(np.square(pc_1_key_tr - xs[:,3:6]), axis=1)) + + # Inlier percentage + inlier_ratio = np.where(y_s < inlier_threshold)[0].shape[0] / y_s.shape[0] + inlier_ratio_mutuals = np.where(y_s[mutuals.astype(bool).reshape(-1)] < inlier_threshold)[0].shape[0] / np.sum(mutuals) + + np.savez_compressed(os.path.join(target_path,fo_base,'cloud_{}_{}.npz'.format(str(idx_1).zfill(3), 
str(idx_2).zfill(3))), + R=r_matrix, t=t_vector, x=xs, y=y_s, mutuals=mutuals, inlier_ratio=inlier_ratio, inlier_ratio_mutuals=inlier_ratio_mutuals, + ratios=ratios, overlap=overlap_ratio) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--source_path', default=None, type=str, help='path to the raw files') + parser.add_argument( + '--target_path', default=None, type=str, help='path to where the extracted data will be saved') + parser.add_argument( + '--dataset', default=None, type=str, help='name of the dataset') + parser.add_argument( + '-m', + '--model', + default=None, + type=str, + help='path to latest checkpoint (default: None)') + parser.add_argument( + '--voxel_size', + default=0.025, + type=float, + help='voxel size to preprocess point cloud') + parser.add_argument( + '--n_correspondences', + default=10000, + type=int, + help='number of points to be sampled in the correspondence estimation') + parser.add_argument( + '--inlier_threshold', + default=0.05, + type=float, + help='threshold to determine if the correspondence is an inlier or outlier') + parser.add_argument('--extract_features', action='store_true') + parser.add_argument('--extract_correspondences', action='store_true') + parser.add_argument('--extract_precomputed_training_data', action='store_true') + parser.add_argument('--refine', action='store_true') + parser.add_argument('--with_cuda', action='store_true') + + + args = parser.parse_args() + + # Prepare the logger + logger = logging.getLogger() + coloredlogs.install(level='INFO', logger=logger) + log_formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s - %(message)s') + + device = torch.device('cuda' if args.with_cuda and torch.cuda.is_available() else 'cpu') + + if args.extract_features: + assert args.model is not None + + checkpoint = torch.load(args.model) + config = checkpoint['config'] + + num_feats = 1 + model = FCGFNet( + num_feats, + config.model_n_out, + bn_momentum=0.05, + normalize_feature=config.normalize_feature, + conv1_kernel_size=config.conv1_kernel_size, + D=3) + + model.load_state_dict(checkpoint['state_dict']) + model.eval() + + model = model.to(device) + + with torch.no_grad(): + if args.extract_features: + logger.info('Starting feature extraction') + extract_features_batch(model, args.source_path, args.target_path, args.dataset, config.voxel_size, + device) + logger.info('Feature extraction completed') + + if args.extract_correspondences: + logger.info('Starting establishing pointwise correspondences in the feature space') + extract_correspondences(args.dataset,args.source_path, args.target_path, args.n_correspondences) + logger.info('Pointwise correspondences in the features space established') + + if args.extract_precomputed_training_data: + logger.info('Starting establishing pointwise correspondences in the feature space') + extract_precomputed_training_data(args.dataset,args.source_path, + args.target_path, args.inlier_threshold, args.voxel_size) + logger.info('Pointwise correspondences in the features space established') + + \ No newline at end of file diff --git a/scripts/utils.py b/scripts/utils.py new file mode 100644 index 0000000..b2211e1 --- /dev/null +++ b/scripts/utils.py @@ -0,0 +1,199 @@ +""" +Contains the util function used in the provided scripts. 
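+ +For example, extract_features can be used roughly as follows (an illustrative sketch; model is assumed to be a loaded FCGF network and pcd an open3d point cloud): + + xyz_down, features = extract_features(model, xyz=np.array(pcd.points), voxel_size=0.025, skip_check=True)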
+Some of the functions are borrowed from FCGF repository (https://github.com/chrischoy/FCGF/tree/440034846e9c27e4faba44346885e4cca51e9753) +""" + +import os +import re +from os import listdir +from os.path import isfile, join, isdir, splitext +import numpy as np +import torch +import MinkowskiEngine as ME +from lib.data import PrecomputedPairwiseEvalDataset, collate_fn, get_folder_list,get_file_list +from lib.utils import read_trajectory + +def read_txt(path): + """ + Reads the text file into lines. + + Args: + path (str): path to the file + + Returns: + lines (list): list of the lines from the input text file + """ + with open(path) as f: + lines = f.readlines() + lines = [x.strip() for x in lines] + + return lines + + +def ensure_dir(path): + """ + Creates dir if it does not exist. + + Args: + path (str): path to the folder + """ + if not os.path.exists(path): + os.makedirs(path, mode=0o755) + + +def extract_features(model, + xyz, + rgb=None, + normal=None, + voxel_size=0.05, + device=None, + skip_check=False, + is_eval=True): + """ + Extracts FCGF features. + + Args: + model (FCGF model instance): model used to inferr the features + xyz (torch tensor): coordinates of the point clouds [N,3] + rgb (torch tensor): colors, must be in range (0,1) [N,3] + normal (torch tensor): normal vectors, must be in range (-1,1) [N,3] + voxel_size (float): voxel size for the generation of the saprase tensor + device (torch device): which device to use, cuda or cpu + skip_check (bool): if true skip rigorous check (to speed up) + is_eval (bool): flag for evaluation mode + + Returns: + return_coords (torch tensor): return coordinates of the points after the voxelization [m,3] (m<=n) + features (torch tensor): per point FCGF features [m,c] + """ + + + if is_eval: + model.eval() + + if not skip_check: + assert xyz.shape[1] == 3 + + N = xyz.shape[0] + if rgb is not None: + assert N == len(rgb) + assert rgb.shape[1] == 3 + if np.any(rgb > 1): + raise ValueError('Invalid color. Color must range from [0, 1]') + + if normal is not None: + assert N == len(normal) + assert normal.shape[1] == 3 + if np.any(normal > 1): + raise ValueError('Invalid normal. 
Normal must range from [-1, 1]') + + if device is None: + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + feats = [] + if rgb is not None: + # [0, 1] + feats.append(rgb - 0.5) + + if normal is not None: + # [-1, 1] + feats.append(normal / 2) + + if rgb is None and normal is None: + feats.append(np.ones((len(xyz), 1))) + + feats = np.hstack(feats) + + # Voxelize xyz and feats + coords = np.floor(xyz / voxel_size) + inds = ME.utils.sparse_quantize(coords, return_index=True) + coords = coords[inds] + # Convert to batched coords compatible with ME + coords = ME.utils.batched_coordinates([coords]) + return_coords = xyz[inds] + + feats = feats[inds] + + feats = torch.tensor(feats, dtype=torch.float32) + coords = torch.tensor(coords, dtype=torch.int32) + + stensor = ME.SparseTensor(coords=coords, feats=feats).to(device) + + return return_coords, model(stensor).F + + +def transform_point_cloud(x1, R, t, data_type='numpy'): + """ + Transforms the point cloud using the giver transformation paramaters + + Args: + x1 (np array): points of the point cloud [n,3] + R (np array): estimated rotation matrice [3,3] + t (np array): estimated translation vectors [3,1] + Returns: + x1_t (np array): points of the transformed point clouds [n,3] + """ + assert data_type in ['numpy', 'torch'] + + if data_type == 'numpy': + x1_t = (np.matmul(R, x1.transpose()) + t).transpose() + + elif data_type =='torch': + x1_t = (torch.matmul(R, x1.transpose(1,0)) + t).transpose(1,0) + + + return x1_t + + +def make_pairwise_eval_data_loader(args): + """ + Prepares the data loader for the pairwise evaluation + + Args: + args (dict): configuration parameters + + Returns: + loader (torch data loader): data loader for the evaluation data + scene_info (dict): metadate of the scenes + """ + + dset = PrecomputedPairwiseEvalDataset(args) + + batch_size = 1 if args.mutuals else args.batch_size + + # Extract the number of examples per scene + scene_names = get_folder_list(os.path.join(args.source_path,'correspondences')) + scene_info = {} + nr_examples = 0 + save_path = os.path.join(args.source_path, 'results', args.method) + save_path += '/mutuals/' if args.mutuals else '/all/' + + for folder in scene_names: + curr_scene_name = folder.split('/')[-1] + if os.path.exists(os.path.join(save_path, curr_scene_name, 'traj.txt')) and not args.overwrite: + pass + else: + if args.only_gt_overlaping: + gt_pairs, gt_traj = read_trajectory(os.path.join(args.source_path,'raw_data', curr_scene_name, "gt.log")) + examples_per_scene = len(gt_pairs) + scene_info[curr_scene_name] = [nr_examples * 4, (nr_examples + examples_per_scene) * 4 ] + nr_examples += examples_per_scene + + else: + examples_per_scene = len(get_file_list(folder)) + scene_info[curr_scene_name] = [nr_examples * 4, (nr_examples + examples_per_scene) * 4 ] + nr_examples += examples_per_scene + + # Save the total number of examples + scene_info['nr_examples'] = nr_examples + + loader = torch.utils.data.DataLoader( + dset, + batch_size=batch_size, + shuffle=False, + num_workers=4, + collate_fn=collate_fn, + pin_memory=False, + drop_last=False) + + return loader, scene_info \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..85749b6 --- /dev/null +++ b/train.py @@ -0,0 +1,158 @@ +import sys +import os +import logging +import torch +import time +import argparse +import numpy as np +import torch.optim as optim +from tensorboardX import SummaryWriter + +import lib.config as config +from lib.utils import load_config +from 
lib.data import make_data_loader +from lib.checkpoints import CheckpointIO +from lib.logger import prepare_logger + + +# Set the random seeds for reproducibility +np.random.seed(41) +torch.manual_seed(41) +if torch.cuda.is_available(): + torch.cuda.manual_seed(41) + + +def main(cfg, logger): + """ + Main function of this software. After preparing the data loaders, model, optimizer, and trainer, + start with the training and evaluation process. + + Args: + cfg (dict): current configuration paramaters + """ + + # Initialize parameters + model_selection_metric = cfg['train']['model_selection_metric'] + + if cfg['train']['model_selection_mode'] == 'maximize': + model_selection_sign = 1 + elif cfg['train']['model_selection_mode'] == 'minimize': + model_selection_sign = -1 + else: + raise ValueError('model_selection_mode must be either maximize or minimize.') + + # Get data loader + train_loader = make_data_loader(cfg, phase='train') + val_loader = make_data_loader(cfg, phase='val') + + # Set up tensorboard logger + tboard_logger = SummaryWriter(os.path.join(cfg['misc']['log_dir'], 'logs')) + + # Get model + model = config.get_model(cfg) + + # Get optimizer and trainer + optimizer = getattr(optim, cfg['optimizer']['alg'])(model.parameters(), lr=cfg['optimizer']['learning_rate'], + weight_decay=cfg['optimizer']['weight_decay']) + + trainer = config.get_trainer(cfg, model, optimizer, tboard_logger) + + # Load pre-trained model if existing + kwargs = { + 'model': model, + 'optimizer': optimizer, + } + + checkpoint_io = CheckpointIO(cfg['misc']['log_dir'], initialize_from=cfg['model']['init_from'], + initialization_file_name=cfg['model']['init_file_name'], **kwargs) + + try: + load_dict = checkpoint_io.load('model.pt') + except FileExistsError: + load_dict = dict() + + epoch_it = load_dict.get('epoch_it', -1) + it = load_dict.get('it', -1) + + metric_val_best = load_dict.get('loss_val_best', -model_selection_sign * np.inf) + + if metric_val_best == np.inf or metric_val_best == -np.inf: + metric_val_best = -model_selection_sign * np.inf + + logger.info('Current best validation metric ({}): {:.5f}'.format(model_selection_metric, metric_val_best)) + + # Training parameters + stat_interval = cfg['train']['stat_interval'] + stat_interval = stat_interval if stat_interval > 0 else abs(stat_interval* len(train_loader)) + + chkpt_interval = cfg['train']['chkpt_interval'] + chkpt_interval = chkpt_interval if chkpt_interval > 0 else abs(chkpt_interval* len(train_loader)) + + val_interval = cfg['train']['val_interval'] + val_interval = val_interval if val_interval > 0 else abs(val_interval* len(train_loader)) + + # Print model parameters and model graph + nparameters = sum(p.numel() for p in model.parameters()) + #print(model) + logger.info('Total number of parameters: {}'.format(nparameters)) + + # Training loop + while epoch_it < cfg['train']['max_epoch']: + epoch_it += 1 + + for batch in train_loader: + it += 1 + loss = trainer.train_step(batch, it) + tboard_logger.add_scalar('train/loss', loss, it) + + # Print output + if stat_interval != 0 and (it % stat_interval) == 0 and it != 0: + logger.info('[Epoch {}] it={}, loss={:.4f}'.format(epoch_it, it, loss)) + + # Save checkpoint + if (chkpt_interval != 0 and (it % chkpt_interval) == 0) and it != 0: + logger.info('Saving checkpoint') + checkpoint_io.save('model.pt', epoch_it=epoch_it, it=it, + loss_val_best=metric_val_best) + + # Run validation + if val_interval != 0 and (it % val_interval) == 0 and it != 0: + eval_dict = trainer.evaluate(val_loader,it) + + 
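+                # Pick the scalar that drives model selection from the evaluation dictionary returned by the trainer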
metric_val = eval_dict[model_selection_metric] + logger.info('Validation metric ({}): {:.4f}'.format(model_selection_metric, metric_val)) + + for k, v in eval_dict.items(): + tboard_logger.add_scalar('val/{}'.format(k), v, it) + + if model_selection_sign * (metric_val - metric_val_best) > 0: + metric_val_best = metric_val + logger.info('New best model (validation metric {:.4f})'.format(metric_val_best)) + checkpoint_io.save('model_best.pt', epoch_it=epoch_it, it=it, + loss_val_best=metric_val_best) + + + # Quit after the maximum number of epochs is reached + logger.info('Training completed after {} epochs ({} it) with best val metric ({})={}'.format(epoch_it, it, model_selection_metric, metric_val_best)) + +if __name__ == "__main__": + logger = logging.getLogger() + + + parser = argparse.ArgumentParser() + parser.add_argument('config', type=str, help='Path to the config file.') + args = parser.parse_args() + + cfg = load_config(args.config) + + # Create the output dir if it does not exist + if not os.path.exists(cfg['misc']['log_dir']): + os.makedirs(cfg['misc']['log_dir']) + + logger, checkpoint_dir = prepare_logger(cfg) + + cfg['misc']['log_dir'] = checkpoint_dir + + logger.info('Torch version: {}'.format(torch.__version__)) + + main(cfg, logger)
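+ +# Example invocation (illustrative; the config path is a placeholder for one of the provided yaml files): +# python train.py ./configs/<your_config>.yaml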