diff --git a/README.md b/README.md new file mode 100644 index 0000000..5b43259 --- /dev/null +++ b/README.md @@ -0,0 +1,29 @@ +# Adaptive Graph Convolution for Point Cloud Analysis + +This repository contains the implementation of **AdaptConv** for point cloud analysis. + +Adaptive Graph Convolution (AdaptConv) is a point cloud convolution operator presented in our ICCV2021 paper. If you find our work useful in your research, please cite our paper. + +## Installation + +* The code has been tested on one configuration: + - PyTorch 1.1.0, CUDA 10.1 + +* Install required packages: + - numpy + - h5py + - scikit-learn + - matplotlib + +## Classification + +[classification.md](./cls/classification.md) + +## Part Segmentation + +[part_segmentation.md](./part_seg/part_segmentation.md) + +## Indoor Segmentation + +coming soon + diff --git a/cls/classification.md b/cls/classification.md new file mode 100644 index 0000000..b21cb85 --- /dev/null +++ b/cls/classification.md @@ -0,0 +1,20 @@ +## Point Cloud Classification on ModelNet40 + +### Data + +First, you may download the ModelNet40 dataset from [here](https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip), and place it in `cls/data/modelnet40_ply_hdf5_2048`. We use the prepared data in HDF5 files for principle evaluation, where each object is already sampled to 2048 points. The experiments presented in the paper use 1024 points for training and testing. + +### Usage + +To train a model for classification: + + python train.py + +Model and log files will be saved to `cls/models/train/` by default. After the training stage, you can test the model by: + + python train.py --eval 1 + +If you'd like to use your own data, you can modify `data.py` to change the data-loading path. 
def load_data(partition):
    """Load every ModelNet40 HDF5 shard of one partition into memory.

    Calls ``download()`` first, then reads all files matching
    ``data/modelnet40_ply_hdf5_2048/ply_data_<partition>*.h5`` next to this
    module.

    Args:
        partition: file-name fragment selecting the split (the callers in this
            file pass ``'train'`` or ``'test'``).

    Returns:
        ``(all_data, all_label)``: the ``'data'`` datasets as float32 and the
        ``'label'`` datasets as int64, each concatenated along axis 0.

    Raises:
        FileNotFoundError: if no shard matches the pattern (previously this
            surfaced as an opaque ``np.concatenate`` error).
    """
    download()
    base_dir = os.path.dirname(os.path.abspath(__file__))
    pattern = os.path.join(base_dir, 'data', 'modelnet40_ply_hdf5_2048',
                           'ply_data_%s*.h5' % partition)

    # glob order depends on the filesystem; sorting makes sample indices
    # reproducible from run to run.
    h5_names = sorted(glob.glob(pattern))
    if not h5_names:
        raise FileNotFoundError('No HDF5 files match %s' % pattern)

    all_data = []
    all_label = []
    for h5_name in h5_names:
        # Explicit read-only mode, and the context manager closes the file
        # even if reading one of the datasets raises.
        with h5py.File(h5_name, 'r') as f:
            all_data.append(f['data'][:].astype('float32'))
            all_label.append(f['label'][:].astype('int64'))

    all_data = np.concatenate(all_data, axis=0)
    all_label = np.concatenate(all_label, axis=0)
    return all_data, all_label
def get_graph_feature(x, k=20, idx=None):
    """Gather k-neighbour edge features for every point.

    Args:
        x: tensor viewed as ``(batch_size, num_dims, num_points)``.
        k: number of neighbours per point.
        idx: neighbour indices. ``None`` computes them with ``knn(x, k)``.
            Otherwise either a ``(batch_size, num_points, k)`` tensor of
            per-cloud indices, or the flattened, batch-offset index tensor
            that a previous call to this function returned.

    Returns:
        ``(feature, idx)`` where ``feature`` has shape
        ``(batch_size, 2 * num_dims, num_points, k)`` (neighbour minus centre,
        concatenated with the centre) and ``idx`` is the flattened index into
        the ``(batch_size * num_points, num_dims)`` point table, suitable for
        passing back in as ``idx``.
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    if idx is None:
        idx = knn(x, k=k)   # (batch_size, num_points, k)

    if idx.dim() == 3:
        # Per-cloud indices: shift each cloud into its own slice of the
        # flattened point table. An index that is already flat (as returned
        # below) has had this offset applied, so it must not be shifted again.
        idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
        idx = idx + idx_base
    idx = idx.view(-1)

    _, num_dims, _ = x.size()

    # (batch_size, num_points, num_dims) -> (batch_size*num_points, num_dims)
    x = x.transpose(2, 1).contiguous()
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2).contiguous()

    return feature, idx
bias=False), + self.bn5, + nn.LeakyReLU(negative_slope=0.2)) + self.linear1 = nn.Linear(args.emb_dims*2, 512, bias=False) + self.bn6 = nn.BatchNorm1d(512) + self.dp1 = nn.Dropout(p=args.dropout) + self.linear2 = nn.Linear(512, 256) + self.bn7 = nn.BatchNorm1d(256) + self.dp2 = nn.Dropout(p=args.dropout) + self.linear3 = nn.Linear(256, output_channels) + + self.adapt_conv1 = AdaptiveConv(6, 64, 6) + self.adapt_conv2 = AdaptiveConv(6, 64, 64*2) + + def forward(self, x): + batch_size = x.size(0) + points = x + + x, idx = get_graph_feature(x, k=self.k) + p, _ = get_graph_feature(points, k=self.k, idx=idx) + x = self.adapt_conv1(p, x) + x1 = x.max(dim=-1, keepdim=False)[0] + + x, idx = get_graph_feature(x1, k=self.k) + p, _ = get_graph_feature(points, k=self.k, idx=idx) + x = self.adapt_conv2(p, x) + x2 = x.max(dim=-1, keepdim=False)[0] + + x, _ = get_graph_feature(x2, k=self.k) + x = self.conv3(x) + x3 = x.max(dim=-1, keepdim=False)[0] + + x, _ = get_graph_feature(x3, k=self.k) + x = self.conv4(x) + x4 = x.max(dim=-1, keepdim=False)[0] + + x = torch.cat((x1, x2, x3, x4), dim=1) + + x = self.conv5(x) + x1 = F.adaptive_max_pool1d(x, 1).view(batch_size, -1) + x2 = F.adaptive_avg_pool1d(x, 1).view(batch_size, -1) + x = torch.cat((x1, x2), 1) + + x = F.leaky_relu(self.bn6(self.linear1(x)), negative_slope=0.2) + x = self.dp1(x) + x = F.leaky_relu(self.bn7(self.linear2(x)), negative_slope=0.2) + x = self.dp2(x) + x = self.linear3(x) + return x diff --git a/cls/train.py b/cls/train.py new file mode 100644 index 0000000..5b744c5 --- /dev/null +++ b/cls/train.py @@ -0,0 +1,241 @@ +from __future__ import print_function +import os +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.optim.lr_scheduler import CosineAnnealingLR +from data import ModelNet40 +import numpy as np +from torch.utils.data import DataLoader +from util import cal_loss, IOStream +import sklearn.metrics as metrics +from importlib import 
def parse_arguments(argv=None):
    """Build the training/evaluation command-line parser and parse ``argv``.

    Args:
        argv: optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', type=str, default='', metavar='N',
                        help='Name of the experiment')
    # No `choices` here: the previous list ['pointnet', 'dgcnn'] did not
    # contain the default 'model_cls', so the default model could not be
    # selected explicitly on the command line.
    parser.add_argument('--model', type=str, default='model_cls', metavar='N',
                        help='Name of the Python module that defines Net')
    parser.add_argument('--dataset', type=str, default='modelnet40', metavar='N',
                        choices=['modelnet40'])
    parser.add_argument('--batch_size', type=int, default=32, metavar='batch_size',
                        help='Size of batch)')
    parser.add_argument('--test_batch_size', type=int, default=16, metavar='batch_size',
                        help='Size of batch)')
    parser.add_argument('--epochs', type=int, default=500, metavar='N',
                        help='number of episode to train ')
    parser.add_argument('--Tmax', type=int, default=250, metavar='N',
                        help='Max iteration number of scheduler. ')
    # Integer flags: pass 0/1 on the command line (e.g. `--eval 1`).
    parser.add_argument('--use_sgd', type=int, default=True,
                        help='Use SGD')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001, 0.1 if using sgd)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--eval', type=int, default=False,
                        help='evaluate the model')
    parser.add_argument('--num_points', type=int, default=1024,
                        help='num of points to use')
    parser.add_argument('--dropout', type=float, default=0.5,
                        help='dropout rate')
    parser.add_argument('--emb_dims', type=int, default=1024, metavar='N',
                        help='Dimension of embeddings')
    parser.add_argument('--k', type=int, default=20, metavar='N',
                        help='Num of nearest neighbors to use')
    parser.add_argument('--model_path', type=str, default='', metavar='N',
                        help='Pretrained model path')

    parser.add_argument('--gpu_idx', type=int, default=0, help='set < 0 to use CPU')

    args = parser.parse_args(argv)

    return args
eta_min=args.lr) + + criterion = cal_loss + + best_test_acc = 0 + for epoch in range(args.epochs): + if epoch < args.Tmax: + scheduler.step() + elif epoch == args.Tmax: + for group in opt.param_groups: + group['lr'] = 0.0001 + + learning_rate = opt.param_groups[0]['lr'] + #################### + # Train + #################### + train_loss = 0.0 + count = 0.0 + model.train() + train_pred = [] + train_true = [] + for data, label in train_loader: + data, label = data.to(device), label.to(device).squeeze() + data = data.permute(0, 2, 1) + batch_size = data.size()[0] + opt.zero_grad() + logits = model(data) + loss = criterion(logits, label) + loss.backward() + opt.step() + preds = logits.max(dim=1)[1] + count += batch_size + train_loss += loss.item() * batch_size + train_true.append(label.cpu().numpy()) + train_pred.append(preds.detach().cpu().numpy()) + train_true = np.concatenate(train_true) + train_pred = np.concatenate(train_pred) + outstr = 'Train %d, loss: %.6f, train acc: %.6f, train avg acc: %.6f' % (epoch, + train_loss*1.0/count, + metrics.accuracy_score( + train_true, train_pred), + metrics.balanced_accuracy_score( + train_true, train_pred)) + io.cprint('EPOCH #{} lr = {}'.format(epoch, learning_rate)) + io.cprint(outstr) + + #################### + # Test + #################### + test_loss = 0.0 + count = 0.0 + model.eval() + test_pred = [] + test_true = [] + for data, label in test_loader: + data, label = data.to(device), label.to(device).squeeze() + data = data.permute(0, 2, 1) + batch_size = data.size()[0] + with torch.no_grad(): + logits = model(data) + loss = criterion(logits, label) + preds = logits.max(dim=1)[1] + count += batch_size + test_loss += loss.item() * batch_size + test_true.append(label.cpu().numpy()) + test_pred.append(preds.detach().cpu().numpy()) + test_true = np.concatenate(test_true) + test_pred = np.concatenate(test_pred) + test_acc = metrics.accuracy_score(test_true, test_pred) + avg_per_class_acc = 
def test(args, io):
    """Evaluate a saved checkpoint on the ModelNet40 test split and log accuracy.

    Args:
        args: parsed command-line namespace; reads ``model``, ``gpu_idx``,
            ``num_points``, ``test_batch_size``, ``name`` and ``model_path``.
        io: logger object exposing ``cprint(text)``.
    """
    MODEL = import_module(args.model)
    device = torch.device('cpu' if args.gpu_idx < 0 else 'cuda:{}'.format(args.gpu_idx))

    # Sample order does not affect the metrics computed below, so keep the
    # evaluation pass deterministic.
    test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points),
                             batch_size=args.test_batch_size, shuffle=False, drop_last=False)

    io.cprint('********** TEST STAGE **********')
    io.cprint('Reload best epoch:')

    # Honour --model_path when it is given; otherwise fall back to the
    # checkpoint location that train() saves to.
    model_path = args.model_path or 'models/%s/models/model.t7' % args.name

    # Try to load models
    model = MODEL.Net(args).to(device)
    # map_location lets a checkpoint written on a GPU be restored on the CPU
    # (--gpu_idx < 0) or on a different GPU index.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    test_true = []
    test_pred = []
    with torch.no_grad():
        for data, label in test_loader:
            # reshape(-1) instead of squeeze(): squeeze() would turn a final
            # batch of one sample into a 0-d tensor and break np.concatenate.
            data, label = data.to(device), label.to(device).reshape(-1)
            data = data.permute(0, 2, 1)
            logits = model(data)
            preds = logits.max(dim=1)[1]
            test_true.append(label.cpu().numpy())
            test_pred.append(preds.cpu().numpy())

    test_true = np.concatenate(test_true)
    test_pred = np.concatenate(test_pred)
    test_acc = metrics.accuracy_score(test_true, test_pred)
    avg_per_class_acc = metrics.balanced_accuracy_score(test_true, test_pred)
    outstr = 'Test :: test acc: %.6f, test avg acc: %.6f'%(test_acc, avg_per_class_acc)
    io.cprint(outstr)
def cal_loss(pred, gold, smoothing=True, eps=0.2):
    """Calculate cross entropy loss, applying label smoothing if requested.

    Args:
        pred: logits; dimension 1 is the class dimension.
        gold: integer class indices; flattened to 1-D before use.
        smoothing: when true, the target distribution puts ``1 - eps`` on the
            true class and spreads ``eps`` evenly over the other classes;
            when false, plain ``F.cross_entropy`` is used.
        eps: smoothing mass moved off the true class. Defaults to 0.2, the
            value that was previously hard-coded.

    Returns:
        Scalar tensor: the loss averaged over samples.
    """
    gold = gold.contiguous().view(-1)

    if not smoothing:
        return F.cross_entropy(pred, gold, reduction='mean')

    n_class = pred.size(1)

    # Smoothed target: (1 - eps) on the gold class, eps / (n_class - 1) elsewhere.
    one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
    one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
    log_prb = F.log_softmax(pred, dim=1)

    return -(one_hot * log_prb).sum(dim=1).mean()
def normalize_pointcloud(pointcloud):
    """Centre a point cloud on its mean and scale it into the unit sphere.

    The input array is left untouched; a new array is returned. (The earlier
    in-place version modified the caller's array, which silently corrupts data
    when the argument is a view into a cached dataset.)

    Args:
        pointcloud: array whose rows are points (the mean is taken over axis 0
            and the norm over axis 1).

    Returns:
        The centred cloud divided by the largest point norm. If every point
        coincides with the mean (largest norm 0) the centred cloud is returned
        unscaled instead of dividing by zero.
    """
    centered = pointcloud - pointcloud.mean(axis=0)
    radius = np.linalg.norm(centered, axis=1).max()
    if radius > 0:
        centered = centered / radius
    return centered
with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f: + val_ids = set([str(d.split('/')[2]) for d in json.load(f)]) + with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f: + test_ids = set([str(d.split('/')[2]) for d in json.load(f)]) + for item in self.cat: + #print('category', item) + self.meta[item] = [] + dir_point = os.path.join(self.root, self.cat[item]) + fns = sorted(os.listdir(dir_point)) + #print(fns[0][0:-4]) + if split=='trainval': + fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))] + elif split=='train': + fns = [fn for fn in fns if fn[0:-4] in train_ids] + elif split=='val': + fns = [fn for fn in fns if fn[0:-4] in val_ids] + elif split=='test': + fns = [fn for fn in fns if fn[0:-4] in test_ids] + else: + print('Unknown split: %s. Exiting..'%(split)) + exit(-1) + + #print(os.path.basename(fns)) + for fn in fns: + token = (os.path.splitext(os.path.basename(fn))[0]) + self.meta[item].append(os.path.join(dir_point, token + '.txt')) + + self.datapath = [] + for item in self.cat: + for fn in self.meta[item]: + self.datapath.append((item, fn)) + + + self.classes = dict(zip(self.cat, range(len(self.cat)))) + # Mapping from category ('Chair') to a list of int [10,11,12,13] as segmentation labels + self.seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43], 'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]} + + for cat in sorted(self.seg_classes.keys()): + print(cat, self.seg_classes[cat]) + + self.cache = {} # from index to (point_set, cls, seg) tuple + self.cache_size = 20000 + + def __getitem__(self, index): + if index in self.cache: + point_set, normal, seg, cls 
def pc_normalize(self, pc):
    """Centre ``pc`` on its centroid and scale it by its largest point norm.

    Args:
        pc: array whose rows are points (centroid over axis 0, norm over
            axis 1). It is not modified.

    Returns:
        A new array: the centred points divided by the largest distance from
        the centroid. When all points coincide (largest distance 0) the
        centred points are returned unscaled rather than dividing by zero.
    """
    centered = pc - np.mean(pc, axis=0)
    radius = np.max(np.sqrt(np.sum(centered ** 2, axis=1)))
    if radius > 0:
        centered = centered / radius
    return centered
def indexing_neighbor(x: "(bs, dim, num_points0)", index: "(bs, num_points, k)" ):
    """Gather, for every entry of ``index``, the matching column of ``x``.

    Args:
        x: tensor of shape (bs, dim, num_points0).
        index: integer tensor of shape (bs, num_points, k) indexing the last
            axis of ``x`` within the same batch element.

    Return: (bs, dim, num_points, neighbor_num)
    """
    batch_size = index.size(0)

    # Build the batch selector on the same device as `index` so the gather
    # does not rely on an implicit cross-device copy of a CPU index tensor.
    id_0 = torch.arange(batch_size, device=index.device).view(-1, 1, 1)

    x = x.transpose(2, 1).contiguous()                  # (bs, num_points0, dim)
    feature = x[id_0, index]                            # (bs, num_points, k, dim)
    feature = feature.permute(0, 3, 1, 2).contiguous()  # (bs, dim, num_points, k)

    return feature
def farthest_point_sample(xyz, npoint):
    """Iterative farthest point sampling.

    Input:
        xyz: pointcloud data, [B, C, N] (channels first; it is transposed to
            [B, N, C] internally)
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    xyz = xyz.transpose(2, 1).contiguous()  # (B, N, C)
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # Running squared distance from every point to its nearest chosen centroid.
    distance = xyz.new_full((B, N), 1e10)
    # The random start is drawn on the CPU and then moved, as before, so seeded
    # runs keep using the same random stream.
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long, device=device)
    for i in range(npoint):
        centroids[:, i] = farthest
        # Use the actual channel count instead of a hard-coded 3.
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        distance = torch.min(distance, dist)
        farthest = torch.max(distance, -1)[1]
    return centroids
farthest_point_sample(vertices, new_npoints) #(bs, new_npoints) + vertices_pool = index_points(vertices, new_points_idx) # (bs, 3, new_npoints) + feature_map_pool = index_points(pooled_feature, new_points_idx) #(bs, num_dims, new_npoints) + + return vertices_pool, feature_map_pool + +class Pooling_strided(nn.Module): + def __init__(self, pooling_rate, neighbor_num, in_channels): + super().__init__() + self.pooling_rate = pooling_rate + self.neighbor_num = neighbor_num + self.conv_layer = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False), + nn.BatchNorm2d(in_channels), + nn.LeakyReLU(negative_slope=0.2)) + + + def forward(self, + vertices: "(bs, 3, vertice_num)", + feature_map: "(bs, channel_num, vertice_num)", + idx): + """ + Return: + vertices_pool: (bs, 3, pool_vertice_num), + x: (bs, new_npoints, num_dims) + """ + + bs, _, vertice_num = vertices.size() + neighbor_feature, _ = get_graph_feature(feature_map, k=self.neighbor_num, idx=idx) # (bs, num_dims, num_points, k) + #neighbor_feature = neighbor_feature.permute(0,2,1,3) # (bs, num_points, num_dims, k) + #pooled_feature = torch.max(neighbor_feature, dim=-1)[0] #(bs, num_dims, num_points) + + # downsample + new_npoints = int(vertice_num / self.pooling_rate) + new_points_idx = farthest_point_sample(vertices, new_npoints) #(bs, new_npoints) + x = index_feature(neighbor_feature, new_points_idx) # (bs, num_dims, new_npoints, k) + vertices_pool = index_points(vertices, new_points_idx) # (bs, 3, new_npoints) + + x = self.conv_layer(x) + x = x.max(dim=-1, keepdim=False)[0] # (bs, num_dims, new_npoints) + + return vertices_pool, x + +def test(): + import time + bs = 8 + v = 1024 + dim = 6 + n = 20 + device='cuda:0' + + pool = Pool_layer(pooling_rate= 4, neighbor_num= 20).to(device) + + points = torch.randn(bs, 3, v).to(device) + x = torch.randn(bs, dim, v).to(device) + _, neighbor_idx = get_graph_feature(points, k=20) + print('points: {}, x: {}, neighbor_idx: {}'.format(points.size(), 
def get_miou(pred: "tensor (point_num, )", target: "tensor (point_num, )", valid_labels: list):
    """Mean IoU of one object over the part labels in ``valid_labels``.

    For every label the IoU between the predicted and the ground-truth point
    sets is computed; a label that occurs in neither counts as IoU 1.
    Returns the mean of the per-label IoUs.
    """
    pred_np = pred.cpu().numpy()
    target_np = target.cpu().numpy()

    scores = []
    for label in valid_labels:
        predicted = pred_np == label
        expected = target_np == label
        union = np.logical_or(predicted, expected).sum()
        if union == 0:
            # Part absent from both prediction and ground truth.
            scores.append(1)
            continue
        intersection = np.logical_and(predicted, expected).sum()
        scores.append(intersection / union)
    return np.mean(scores)
def knn(x, k):
    """Indices of the k nearest neighbours of every point.

    x is indexed as (batch_size, num_dims, num_points). The negated squared
    Euclidean distance is built as -|a|^2 + 2<a,b> - |b|^2, so the top-k
    largest entries per row are the k closest points (each point is its own
    closest neighbour).

    Returns a long tensor of shape (batch_size, num_points, k).
    """
    sq_norm = torch.sum(x ** 2, dim=1, keepdim=True)            # (bs, 1, N)
    neg_two_gram = -2 * torch.matmul(x.transpose(2, 1), x)       # (bs, N, N)
    neg_sq_dist = -sq_norm - neg_two_gram - sq_norm.transpose(2, 1)

    _, neighbors = torch.topk(neg_sq_dist, k=k, dim=-1)          # (bs, N, k)
    return neighbors
num_points, k, num_dims) + x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1) + + feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2).contiguous() + + return feature, idx + + +class AdaptiveConv(nn.Module): + def __init__(self, k, in_channels, feat_channels, nhiddens, out_channels): + super(AdaptiveConv, self).__init__() + self.in_channels = in_channels + self.nhiddens = nhiddens + self.out_channels = out_channels + self.feat_channels = feat_channels + self.k = k + + self.conv0 = nn.Conv2d(feat_channels, nhiddens, kernel_size=1, bias=False) + self.conv1 = nn.Conv2d(nhiddens, nhiddens*in_channels, kernel_size=1, bias=False) + self.bn0 = nn.BatchNorm2d(nhiddens) + self.bn1 = nn.BatchNorm2d(nhiddens) + self.leaky_relu = nn.LeakyReLU(negative_slope=0.2) + self.residual_layer = nn.Sequential(nn.Conv2d(feat_channels, out_channels, kernel_size=1, bias=False), + nn.BatchNorm2d(out_channels), + ) + self.linear = nn.Sequential(nn.Conv2d(nhiddens, out_channels, kernel_size=1, bias=False), + nn.BatchNorm2d(out_channels)) + + def forward(self, points, feat, idx): + # points: (bs, in_channels, num_points), feat: (bs, feat_channels/2, num_points) + batch_size, _, num_points = points.size() + + x, _ = get_graph_feature(points, k=self.k, idx=idx) # (bs, in_channels, num_points, k) + y, _ = get_graph_feature(feat, k=self.k, idx=idx) # (bs, feat_channels, num_points, k) + + kernel = self.conv0(y) # (bs, nhiddens, num_points, k) + kernel = self.leaky_relu(self.bn0(kernel)) + kernel = self.conv1(kernel) # (bs, in*nhiddens, num_points, k) + kernel = kernel.permute(0, 2, 3, 1).view(batch_size, num_points, self.k, self.nhiddens, self.in_channels) # (bs, num_points, k, nhiddens, in) + + x = x.permute(0, 2, 3, 1).unsqueeze(4) # (bs, num_points, k, in_channels, 1) + x = torch.matmul(kernel, x).squeeze(4) # (bs, num_points, k, nhiddens) + x = x.permute(0, 3, 1, 2).contiguous() # (bs, nhiddens, num_points, k) + + # nhiddens -> out_channels + x = 
class ConvLayer(nn.Module):
    """Uniform wrapper around the three feature layers used by the network.

    ``para`` is a layer spec whose first entry selects the layer type and
    whose second entry is the output width:
        ['adapt', out_channels, nhiddens] -> AdaptiveConv
        ['graph', out_channels]           -> GraphConv
        ['conv1d', out_channels]          -> per-point Conv1d block followed
                                             by a max over all points
    Any other type raises ValueError.
    """

    def __init__(self, para, k, in_channels, feat_channels):
        super().__init__()
        kind, width = para[0], para[1]
        self.type = kind
        self.out_channels = width
        self.k = k

        if kind == 'conv1d':
            # feat_channels counts the doubled graph-feature width, while this
            # layer consumes the raw per-point features, hence the halving.
            self.layer = nn.Sequential(
                nn.Conv1d(int(feat_channels/2), width, kernel_size=1, bias=False),
                nn.BatchNorm1d(width),
                nn.LeakyReLU(negative_slope=0.2),
            )
        elif kind == 'graph':
            self.layer = GraphConv(feat_channels, width, k)
        elif kind == 'adapt':
            self.layer = AdaptiveConv(k, in_channels, feat_channels, nhiddens=para[2], out_channels=width)
        else:
            raise ValueError('Unknown convolution layer: {}'.format(kind))

    def forward(self, points, x, idx):
        # points: (bs, 3, num_points), x: (bs, feat_channels/2, num_points)
        if self.type == 'adapt':
            return self.layer(points, x, idx)
        if self.type == 'graph':
            return self.layer(x, idx)
        if self.type == 'conv1d':
            # Global descriptor: max over the point axis -> (bs, num_dims)
            return torch.max(self.layer(x), dim=-1)[0]
        return x
class Net(nn.Module):
    """Part-segmentation network: a stack of adaptive/graph convolutions with
    farthest-point pooling in between, followed by a per-point classifier over
    the concatenation of all layer outputs and an object-category embedding.

    args:      namespace providing at least ``k`` (neighbourhood size) and
               ``dropout`` (dropout probability of the classifier head).
    class_num: number of output channels of the final Conv1d (part classes).
    cat_num:   input channels of the one-hot category embedding.
    use_stn:   if True, a Transform_Net predicts a 3x3 matrix that is applied
               to the input before the first layer.
    """
    def __init__(self, args, class_num, cat_num, use_stn=True):
        super(Net, self).__init__()
        self.args = args
        self.k = args.k
        self.class_num = class_num
        self.cat_num = cat_num
        self.use_stn = use_stn

        # architecture
        # in_channels is forwarded to ConvLayer as the channel count of the
        # graph feature built from the 3-D point coordinates.
        self.in_channels = 6
        # Layer specs, consumed in order:
        #   ['adapt', out_channels, nhiddens], ['graph', out_channels],
        #   ['conv1d', out_channels] -> ConvLayer
        #   ['pool', pooling_rate]   -> gp.Pooling_fps
        self.forward_para = [['adapt', 64, 64],
                             ['adapt', 64, 64],
                             ['pool', 4],
                             ['adapt', 128, 64],
                             ['pool', 4],
                             ['adapt', 256, 64],
                             ['pool', 2],
                             ['graph', 512],
                             ['conv1d', 1024]]
        # Running total of the channels concatenated before the classifier head.
        self.agg_channels = 0

        # layers
        self.forward_layers = nn.ModuleList()
        # Graph features concatenate (neighbour - centre, centre), so a layer
        # sees twice the channel count of its input: 2 * 6 for the raw input.
        feat_channels = 12
        for i, para in enumerate(self.forward_para):
            if para[0] == 'pool':
                self.forward_layers.append(gp.Pooling_fps(pooling_rate=para[1], neighbor_num=self.k))
            else:
                self.forward_layers.append(ConvLayer(para, self.k, self.in_channels, feat_channels))
                # Every non-pool layer contributes its output to the aggregation.
                self.agg_channels += para[1]
                feat_channels = para[1]*2

        # + 64 channels for the embedded one-hot category vector (conv_onehot).
        self.agg_channels += 64

        self.conv_onehot = nn.Sequential(nn.Conv1d(cat_num, 64, kernel_size=1, bias=False),
                                         nn.BatchNorm1d(64),
                                         nn.LeakyReLU(negative_slope=0.2))

        # Per-point classifier head: agg_channels -> 512 -> 256 -> class_num.
        self.conv1d = nn.Sequential(
            nn.Conv1d(self.agg_channels, 512, kernel_size=1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(p=args.dropout),
            nn.Conv1d(512, 256, kernel_size=1),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(p=args.dropout),
            nn.Conv1d(256, class_num, kernel_size=1),
        )

        if self.use_stn:
            self.stn = Transform_Net(in_channels=12, out_channels=3)


    def forward(self, x, onehot):
        """Predict per-point class scores.

        x: (bs, num_points, 6); channels 0:3 are used as point coordinates,
           channels 3:6 are presumably normals -- TODO confirm against the
           dataset loader.
        onehot: (bs, cat_num)
        Returns: (bs, num_points, class_num)
        """
        # x: (bs, num_points, 6), onehot: (bs, cat_num)
        x = x.permute(0, 2, 1).contiguous()  # (bs, 6, num_points)
        batch_size = x.size(0)
        num_points = x.size(2)

        if self.use_stn:
            x0, _ = get_graph_feature(x, k=self.k)
            t = self.stn(x0)
            # The same predicted 3x3 matrix is applied to both channel triples.
            p1 = torch.bmm(x[:,0:3,:].transpose(2, 1), t)  # (bs, num_points, 3)
            p2 = torch.bmm(x[:,3:6,:].transpose(2, 1), t)
            x = torch.cat((p1, p2), dim=2).transpose(2, 1).contiguous()  # (bs, 6, num_points)
        points = x[:,0:3,:]  # (bs, 3, num_points)

        # forward
        # feat_forward collects one output per non-pool layer;
        # points_forward collects the point set at each resolution
        # (the input first, then one entry per pool layer).
        feat_forward = []
        points_forward = [points]
        _, idx = get_graph_feature(points, k=self.k)
        for i, block in enumerate(self.forward_layers):
            if self.forward_para[i][0] == 'pool':
                points, x = block(points, x, idx)
                points_forward.append(points)
                # The neighbourhood graph is rebuilt on the pooled point set.
                _, idx = get_graph_feature(points, k=self.k)
            elif self.forward_para[i][0] == 'conv1d':
                x = block(points, x, idx)
                # Global feature: broadcast back to every input point.
                x = x.unsqueeze(2).repeat(1, 1, num_points)
                feat_forward.append(x)
            else:
                x = block(points, x, idx)
                feat_forward.append(x)

        # onehot
        onehot = onehot.unsqueeze(2)
        onehot_expand = self.conv_onehot(onehot)
        onehot_expand = onehot_expand.repeat(1, 1, num_points)

        # aggregating features from all layers
        # The specs are replayed in order so each popped feature is paired with
        # the point set it was computed on ('points' advances at every pool).
        x_agg = []
        points0 = points_forward.pop(0)
        points = None
        for i, para in enumerate(self.forward_para):
            if para[0] == 'pool':
                points = points_forward.pop(0)
            else:
                x = feat_forward.pop(0)
                if x.size(2) == points0.size(2):
                    # Already at input resolution: use as is.
                    x_agg.append(x)
                    continue
                # Lower resolution: bring features back to the input points.
                # NOTE(review): presumably nearest-neighbour upsampling; the
                # gp helpers are defined outside this view -- confirm.
                idx = gp.get_nearest_index(points0, points)
                x_upsample = gp.indexing_neighbor(x, idx).squeeze(3)
                x_agg.append(x_upsample)
        x = torch.cat(x_agg, dim=1)
        x = torch.cat((x, onehot_expand), dim=1)
        x = self.conv1d(x)
        x = x.permute(0, 2, 1).contiguous()  # (bs, num_points, class_num)

        return x
class PartSegConfig():
    """Augmentation settings handed to ShapeNetDataset by main().

    The field names match the attributes read by augmentation_transform in
    util.py. NOTE(review): the dataset class is outside this file, so whether
    it forwards this object to that function should be confirmed.
    """

    ####################
    # Dataset parameters
    ####################

    # Augmentations
    # True: draw an independent random scale per axis; False: one shared scale.
    augment_scale_anisotropic = True
    # Per-axis switch for random mirroring (sign flip of the scale factor).
    augment_symmetries = [False, False, False]
    # True: rescale normals to follow an anisotropic scaling of the points.
    normal_scale = True
    # Half-width of the uniform random translation; None disables shifting.
    augment_shift = None
    # 'vertical' and 'all' enable random rotations; any other value (such as
    # 'none') leaves the rotation matrix at identity.
    augment_rotation = 'none'
    # Bounds of the uniform random scale factor.
    augment_scale_min = 0.8
    augment_scale_max = 1.25
    # Standard deviation of the Gaussian jitter added to every coordinate ...
    augment_noise = 0.002
    # ... and the absolute value the jitter is clipped to.
    augment_noise_clip = 0.05
    # Not read by augmentation_transform in util.py.
    augment_occlusion = 'none'
') + parser.add_argument('--mode', default= 'train', help= '[train/test]') + parser.add_argument('--epoch', type= int, default= 200, help= 'Epoch number') + parser.add_argument('--lr', type= float, default= 0.001, help= 'Learning rate') + parser.add_argument('--bs', type= int, default= 32, help= 'Batch size') + parser.add_argument('--dataset', type=str, default='data/shapenetcore_partanno_segmentation_benchmark_v0_normal', help= "Path to ShapeNetPart") + parser.add_argument('--load', help= 'Path to load model') + parser.add_argument('--save', type=str, default='model.pkl', help= 'Path to save model') + parser.add_argument('--record', type=str, default='record.log', help= 'Record file name (e.g. record.log)') + parser.add_argument('--interval', type= int, default=100, help= 'Record interval within an epoch') + parser.add_argument('--point', type= int, default= 2048, help= 'Point number per object') + parser.add_argument('--output', help= 'Folder for visualization images') + # Transform + parser.add_argument('--normal', dest= 'normal', action= 'store_true', help= 'Normalize objects (zero--mean, unit size)') + parser.set_defaults(normal= False) + parser.add_argument('--shift', type= float, help= 'Shift objects (original: 0.0)') + parser.add_argument('--scale', type= float, help= 'Enlarge/shrink objects (original: 1.0)') + parser.add_argument('--rotate', type= float, help= 'Rotate objects in degree (original: 0.0)') + parser.add_argument('--axis', type= int, default= 1, help= 'Rotation axis [0, 1, 2] (upward = 1)') # upward axis = 1 + parser.add_argument('--random', dest= 'random', action= 'store_true', help= 'Randomly transform in a given range') + parser.set_defaults(random= False) + args = parser.parse_args() + + if args.name == '': + args.name = TRAIN_NAME + + config = PartSegConfig() + + MODEL = import_module(args.model) + model = MODEL.Net(args=args, class_num=50, cat_num=16) + manager = Manager(model, args) + + if args.mode == "train": + print("Training ...") + 
class Manager():
    """Owns the model, optimizer, LR schedule and logging for part-segmentation
    training and evaluation.

    Construction side effects: creates ``models/<args.name>/`` and, when
    ``args.record`` is set, opens the record file there for writing.
    """

    def __init__(self, model, args):
        """
        model: network to train/evaluate (not yet wrapped in DataParallel).
        args:  parsed command-line namespace; reads gpu_idx, load, epoch, Tmax,
               lr, momentum, name, save, record, interval and output.
        """
        self.args_info = args.__str__()

        # The --gpu_idx help promises "set < 0 to use CPU": drop negative ids
        # instead of building an invalid 'cuda:-1' device.
        gpu_ids = [gpu for gpu in args.gpu_idx if gpu >= 0]
        self.device = torch.device('cuda:{}'.format(gpu_ids[0]) if gpu_ids else 'cpu')

        if args.load:
            # map_location lets a checkpoint written on another device load here.
            state = torch.load(args.load, map_location=self.device)
            model.load_state_dict(self._strip_parallel_prefix(state))
        self.model = model.to(self.device)
        if gpu_ids:
            self.model = nn.DataParallel(self.model, device_ids=gpu_ids)
        print('Now use {} GPUs: {}'.format(len(gpu_ids), gpu_ids))

        self.epoch = args.epoch
        self.Tmax = args.Tmax
        # SGD starts at 100x the nominal lr and is cosine-annealed down to it.
        self.optimizer = optim.SGD(self.model.parameters(), lr=args.lr*100, momentum=args.momentum, weight_decay=1e-4)
        self.lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=args.Tmax, eta_min=args.lr)
        self.loss_function = nn.CrossEntropyLoss()

        run_dir = os.path.join('models', args.name)
        self.save = os.path.join(run_dir, args.save)
        os.makedirs(run_dir, exist_ok=True)
        self.record_interval = args.interval
        self.record_file = None
        if args.record:
            # Kept open for the lifetime of the manager; record() flushes per line.
            self.record_file = open(os.path.join(run_dir, args.record), 'w')

        self.out_dir = args.output
        self.best = {"c_miou": 0, "i_miou": 0}

    @staticmethod
    def _strip_parallel_prefix(state):
        """Remove the 'module.' key prefix that nn.DataParallel adds.

        train() saves the state dict of the wrapped model, whose keys all start
        with 'module.'; the bare model passed to __init__ expects them without
        it. State dicts that are not uniformly prefixed are returned unchanged.
        """
        prefix = 'module.'
        if state and all(key.startswith(prefix) for key in state):
            return {key[len(prefix):]: value for key, value in state.items()}
        return state

    def update_best(self, c_miou, i_miou):
        """Keep the best category-mean and instance-mean mIoU seen so far."""
        self.best["c_miou"] = max(self.best["c_miou"], c_miou)
        self.best["i_miou"] = max(self.best["i_miou"], i_miou)

    def record(self, info):
        """Print ``info`` and, if a record file is open, append it there."""
        print(info)
        if self.record_file:
            self.record_file.write(info + '\n')
            self.record_file.flush()

    def calculate_save_mious(self, iou_table, category_names, labels, predictions):
        """Add the mIoU of every object in the batch to ``iou_table``."""
        for category, label, pred in zip(category_names, labels, predictions):
            valid_labels = get_valid_labels(category)
            miou = get_miou(pred, label, valid_labels)
            iou_table.add_obj_miou(category, miou)

    def save_visualizations(self, dir, category_names, object_ids, points, labels, predictions):
        """Write ground-truth and prediction images for every object in the batch
        to ``<dir>/<category>/<obj_id>_{gt,pred}.png``."""
        for cat, obj_id, point, label, pred in zip(category_names, object_ids, points, labels, predictions):
            valid_labels = get_valid_labels(cat)
            # Shift labels so the category's first part id becomes 0.
            shift = min(valid_labels) * (-1)
            point = point.to("cpu")
            label = label.to("cpu") + shift
            pred = pred.to("cpu") + shift

            cat_dir = os.path.join(dir, cat)
            os.makedirs(cat_dir, exist_ok=True)
            gt_fig_name = os.path.join(cat_dir, "{}_gt.png".format(obj_id))
            pred_fig_name = os.path.join(cat_dir, "{}_pred.png".format(obj_id))
            visualize(point, label, gt_fig_name)
            visualize(point, pred, pred_fig_name)

    def train(self, train_data, test_data):
        """Run ``self.epoch`` epochs; after each one evaluate on ``test_data``,
        step the LR schedule, save the model and log both mIoU tables."""
        self.record("*****************************************")
        self.record("Hyper-parameters: {}".format(self.args_info))
        self.record("Model parameter number: {}".format(parameter_number(self.model)))
        self.record("Model structure: \n{}".format(self.model.__str__()))
        self.record("*****************************************")

        for epoch in range(self.epoch):
            self.model.train()
            train_loss = 0
            num_batches = 0
            train_iou_table = IouTable()
            learning_rate = self.optimizer.param_groups[0]['lr']
            for i, (cat_name, obj_ids, points, labels, mask, onehot) in enumerate(train_data):
                points = points.to(self.device)
                labels = labels.to(self.device)
                onehot = onehot.to(self.device)
                out = self.model(points, onehot)

                self.optimizer.zero_grad()
                loss = self.loss_function(out.reshape(-1, out.size(-1)), labels.view(-1,))
                loss.backward()
                self.optimizer.step()
                train_loss += loss.item()
                num_batches += 1

                # Suppress the masked-out scores so they cannot win the argmax.
                out[mask == 0] = out.min()
                pred = torch.max(out, 2)[1]
                self.calculate_save_mious(train_iou_table, cat_name, labels, pred)

                if self.record_interval and ((i + 1) % self.record_interval == 0):
                    c_miou = train_iou_table.get_mean_category_miou()
                    i_miou = train_iou_table.get_mean_instance_miou()
                    self.record(' epoch {:3} step {:5} | avg loss: {:.3f} | miou(c): {:.3f} | miou(i): {:.3f}'.format(epoch+1, i+1, train_loss/(i + 1), c_miou, i_miou))

            # max(..., 1): an empty loader must not raise on the loop variable.
            train_loss /= max(num_batches, 1)
            train_table_str = train_iou_table.get_string()
            test_loss, test_table_str = self.test(test_data, self.out_dir)
            if epoch < self.Tmax:
                self.lr_scheduler.step()
            elif epoch == self.Tmax:
                # Once the cosine schedule is over, continue at a fixed rate.
                for group in self.optimizer.param_groups:
                    group['lr'] = 0.0001

            if self.save:
                torch.save(self.model.state_dict(), self.save)

            self.record("==== Epoch {:3} ====".format(epoch + 1))
            self.record("lr = {}".format(learning_rate))
            self.record("Training mIoU:")
            self.record(train_table_str)
            self.record("Testing mIoU:")
            self.record(test_table_str)
            self.record("* Best mIoU(c): {:.3f}, Best mIoU (i): {:.3f} \n".format(self.best["c_miou"], self.best["i_miou"]))

    def test(self, test_data, out_dir= None):
        """Evaluate on ``test_data``.

        When ``out_dir`` is given, per-object visualizations and ``miou.txt``
        are written there. Updates ``self.best`` and returns
        ``(mean loss per batch, formatted mIoU table)``.
        """
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        self.model.eval()
        test_loss = 0
        num_batches = 0
        test_iou_table = IouTable()

        for cat_name, obj_ids, points, labels, mask, onehot in test_data:
            points = points.to(self.device)
            labels = labels.to(self.device)
            onehot = onehot.to(self.device)
            with torch.no_grad():
                out = self.model(points, onehot)
                loss = self.loss_function(out.reshape(-1, out.size(-1)), labels.view(-1,))
            test_loss += loss.item()
            num_batches += 1

            # Suppress the masked-out scores so they cannot win the argmax.
            out[mask == 0] = out.min()
            pred = torch.max(out, 2)[1]
            self.calculate_save_mious(test_iou_table, cat_name, labels, pred)
            if out_dir:
                self.save_visualizations(out_dir, cat_name, obj_ids, points, labels, pred)

        test_loss /= max(num_batches, 1)
        c_miou = test_iou_table.get_mean_category_miou()
        i_miou = test_iou_table.get_mean_instance_miou()
        self.update_best(c_miou, i_miou)
        test_table_str = test_iou_table.get_string()

        if out_dir:
            # Context manager so the table is flushed and the handle released.
            with open(os.path.join(out_dir, "miou.txt"), "w") as miou_file:
                miou_file.write(test_table_str)

        return test_loss, test_table_str
dtype=np.float32) + + elif config.augment_rotation == 'all': + + # Choose two random angles for the first vector in polar coordinates + theta = np.random.rand() * 2 * np.pi + phi = (np.random.rand() - 0.5) * np.pi + + # Create the first vector in carthesian coordinates + u = np.array([np.cos(theta) * np.cos(phi), np.sin(theta) * np.cos(phi), np.sin(phi)]) + + # Choose a random rotation angle + alpha = np.random.rand() * 2 * np.pi + + # Create the rotation matrix with this vector and angle + R = create_3D_rotations(np.reshape(u, (1, -1)), np.reshape(alpha, (1, -1)))[0] + + R = R.astype(np.float32) + + ####### + # Scale + ####### + + # Choose random scales for each example + min_s = config.augment_scale_min + max_s = config.augment_scale_max + if config.augment_scale_anisotropic: + scale = np.random.uniform(min_s, max_s, points.shape[1]) + else: + scale = np.random.uniform(min_s, max_s) + + # Add random symmetries to the scale factor + symmetries = np.array(config.augment_symmetries).astype(np.int32) + symmetries *= np.random.randint(2, size=points.shape[1]) + scale = (scale * (1 - symmetries * 2)).astype(np.float32) + + ####### + # Noise + ####### + + noise = (np.random.randn(points.shape[0], points.shape[1]) * config.augment_noise).astype(np.float32) + noise = np.clip(noise, -1*config.augment_noise_clip, config.augment_noise_clip) + + ####### + # Shift + ####### + + if config.augment_shift: + shift = np.random.uniform(low=-config.augment_shift, high=config.augment_shift, size=[3]).astype(np.float32) + + ################## + # Apply transforms + ################## + + # Do not use np.dot because it is multi-threaded + #augmented_points = np.dot(points, R) * scale + noise + augmented_points = np.sum(np.expand_dims(points, 2) * R, axis=1) * scale + noise + if config.augment_shift: + augmented_points = np.add(augmented_points, shift) + + + if normals is None: + return augmented_points + else: + # Anisotropic scale of the normals thanks to cross product formula + if 
def normalize(points: "numpy array (vertice_num, 3)"):
    """Centre a point cloud on its mean and scale it into the unit sphere.

    points: array of shape (vertice_num, 3); it is not modified.
    Returns a new array whose centroid is the origin and whose farthest point
    has norm 1. If every point coincides with the centroid the centred (all
    zero) points are returned unscaled instead of dividing by zero.
    """
    center = np.mean(points, axis= 0)
    points = points - center
    # The largest squared norm equals the maximum of the full Gram matrix
    # (Cauchy-Schwarz), but costs O(N) instead of O(N^2) time and memory.
    max_d = np.sqrt(np.max(np.sum(points * points, axis=1)))
    if max_d > 0:
        points = points / max_d
    return points
str): + points = np.array(points) + labels = np.array(labels) + + points = normalize(points) + eye = np.eye(3) + bound_points = np.vstack((eye , eye * (-1))) + + x ,y ,z = points[:, 0], points[:, 1], points[:, 2] + fig = plt.figure() + ax = fig.add_subplot(projection= "3d") + ax.axis("off") + + colors = [COLORS[i % len(COLORS)] for i in labels] + ax.scatter(x ,z, y, s= 3, c= colors, marker= "o") + ax.scatter(bound_points[:, 0], bound_points[:, 1], bound_points[:, 2], s=0.01, c= "white") + plt.savefig(fig_name) + plt.close() + + +def test(): + pass + +if __name__ == "__main__": + test()