-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
1,908 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Adaptive Graph Convolution for Point Cloud Analysis | ||
|
||
This repository contains the implementation of **AdaptConv** for point cloud analysis. | ||
|
||
Adaptive Graph Convolution (AdaptConv) is a point cloud convolution operator presented in our ICCV 2021 paper. If you find our work useful in your research, please cite our paper.
|
||
## Installation | ||
|
||
* The code has been tested on one configuration: | ||
- PyTorch 1.1.0, CUDA 10.1 | ||
|
||
* Install required packages: | ||
- numpy | ||
- h5py | ||
- scikit-learn | ||
- matplotlib | ||
|
||
## Classification | ||
|
||
[classification.md](./cls/classification.md) | ||
|
||
## Part Segmentation | ||
|
||
[part_segmentation.md](./part_seg/part_segmentation.md) | ||
|
||
## Indoor Segmentation | ||
|
||
coming soon | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
## Point Cloud Classification on ModelNet40 | ||
|
||
### Data | ||
|
||
First, download the ModelNet40 dataset from [here](https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip) and place it in `cls/data/modelnet40_ply_hdf5_2048`. We use the prepared data in HDF5 files for principle evaluation, where each object is already sampled to 2048 points. The experiments presented in the paper use 1024 points for training and testing.
|
||
### Usage | ||
|
||
To train a model for classification: | ||
|
||
python train.py | ||
|
||
Model and log files will be saved to `cls/models/train/` by default. After the training stage, you can test the model by running:
|
||
python train.py --eval 1 | ||
|
||
If you'd like to use your own data, you can modify `data.py` to change the data-loading path. | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
|
||
import os | ||
import sys | ||
import glob | ||
import h5py | ||
import numpy as np | ||
from torch.utils.data import Dataset | ||
|
||
|
||
def download():
    """Fetch and unpack the ModelNet40 HDF5 archive next to this module.

    The data lives in ``<directory of this file>/data``.  Nothing is done when
    ``data/modelnet40_ply_hdf5_2048`` already exists.  Otherwise the zip
    archive is downloaded into ``data/``, extracted there and removed again.

    Raises:
        urllib.error.URLError: if the archive cannot be downloaded.
        zipfile.BadZipFile: if the downloaded file is not a valid zip archive.
    """
    # Function-scope imports keep this block self-contained; only the
    # standard library is used, so no external `wget`/`unzip` is required.
    import urllib.request
    import zipfile

    base_dir = os.path.dirname(os.path.abspath(__file__))
    data_dir = os.path.join(base_dir, 'data')
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(data_dir, exist_ok=True)
    if os.path.exists(os.path.join(data_dir, 'modelnet40_ply_hdf5_2048')):
        return

    www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip'
    archive_path = os.path.join(data_dir, os.path.basename(www))
    try:
        urllib.request.urlretrieve(www, archive_path)
        # NOTE(review): assumes the archive unpacks into a top-level
        # `modelnet40_ply_hdf5_2048/` directory (the name of the zip without
        # its extension), so extracting inside data_dir yields the expected
        # layout -- confirm against the actual archive.
        with zipfile.ZipFile(archive_path) as archive:
            archive.extractall(data_dir)
    finally:
        # Never leave the (possibly partial) archive behind.
        if os.path.exists(archive_path):
            os.remove(archive_path)
|
||
|
||
def load_data(partition):
    """Load every ModelNet40 HDF5 shard of one split into memory.

    Args:
        partition: split name substituted into the shard file pattern
            ``ply_data_<partition>*.h5`` (the callers in this file use
            ``'train'`` and ``'test'``).

    Returns:
        A ``(data, label)`` tuple: the ``'data'`` datasets of all shards
        concatenated along axis 0 as ``float32`` and the ``'label'`` datasets
        concatenated along axis 0 as ``int64``.

    Raises:
        ValueError: from ``np.concatenate`` when no shard matches the pattern.
    """
    download()
    base_dir = os.path.dirname(os.path.abspath(__file__))
    pattern = os.path.join(base_dir, 'data', 'modelnet40_ply_hdf5_2048',
                           'ply_data_%s*.h5' % partition)
    all_data = []
    all_label = []
    # glob returns files in filesystem order; sort so the sample order is
    # reproducible across machines.
    for h5_name in sorted(glob.glob(pattern)):
        # Open explicitly read-only, and let the context manager close the
        # file even if reading one of the datasets fails.
        with h5py.File(h5_name, 'r') as f:
            all_data.append(f['data'][:].astype('float32'))
            all_label.append(f['label'][:].astype('int64'))
    all_data = np.concatenate(all_data, axis=0)
    all_label = np.concatenate(all_label, axis=0)
    return all_data, all_label
|
||
|
||
def translate_pointcloud(pointcloud):
    """Apply a random per-axis scale and shift to a point cloud.

    A scale factor per axis is drawn uniformly from [2/3, 3/2] and an offset
    per axis uniformly from [-0.2, 0.2]; the result ``pointcloud * scale +
    offset`` is returned as a new ``float32`` array.
    """
    # Draw order (scale first, then offset) matters for seeded reproducibility.
    scale = np.random.uniform(low=2./3., high=3./2., size=[3])
    offset = np.random.uniform(low=-0.2, high=0.2, size=[3])

    scaled = np.multiply(pointcloud, scale)
    shifted = np.add(scaled, offset)
    return shifted.astype('float32')
|
||
def normalize_pointcloud(pointcloud):
    """Center a point cloud on its mean and scale it into the unit sphere.

    The array is modified **in place** and also returned.  After the call the
    mean over axis 0 is (numerically) zero and the largest point norm is 1,
    unless every point coincides with the centroid, in which case the cloud
    is left at all zeros instead of being divided by zero.

    Args:
        pointcloud: 2-D floating-point array, one point per row.

    Returns:
        The same array object, normalized.
    """
    center = pointcloud.mean(axis=0)
    pointcloud -= center
    max_distance = np.linalg.norm(pointcloud, axis=1).max()
    # A degenerate cloud (all points identical) has max_distance == 0;
    # dividing would fill the array with NaN.
    if max_distance > 0:
        pointcloud /= max_distance
    return pointcloud
|
||
def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02):
    """Add clipped Gaussian noise to a 2-D point cloud, in place.

    Noise is ``sigma * N(0, 1)`` per coordinate, clipped to ``[-clip, clip]``.
    The input array is modified and returned.
    """
    rows, cols = pointcloud.shape
    noise = sigma * np.random.randn(rows, cols)
    noise = np.clip(noise, -1 * clip, clip)
    pointcloud += noise
    return pointcloud
|
||
# **********Dataset ModelNet40********** | ||
|
||
class ModelNet40(Dataset):
    """Dataset over the arrays returned by ``load_data`` for one split.

    Each item is ``(pointcloud, label)`` where the point cloud is the first
    ``num_points`` rows of the stored sample.  For the ``'train'`` partition
    the cloud is additionally passed through ``translate_pointcloud`` and its
    points are shuffled.
    """

    def __init__(self, num_points, partition='train'):
        self.data, self.label = load_data(partition)
        self.num_points = num_points
        self.partition = partition

    def __getitem__(self, item):
        cloud = self.data[item][:self.num_points]
        target = self.label[item]
        if self.partition != 'train':
            # Non-training samples are returned without augmentation.
            return cloud, target
        cloud = translate_pointcloud(cloud)
        # Shuffles rows of the freshly created augmented array, not self.data.
        np.random.shuffle(cloud)
        return cloud, target

    def __len__(self):
        return self.data.shape[0]
|
||
|
||
if __name__ == '__main__':
    # Smoke test: build both splits and print the shapes of every training
    # sample and its label.
    train = ModelNet40(num_points=1024)
    test = ModelNet40(num_points=1024, partition='test')
    for data, label in train:
        print(data.shape)
        print(label.shape)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
import os | ||
import sys | ||
import copy | ||
import math | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
|
||
def knn(x, k):
    """Return the indices of the k nearest points for every point.

    Args:
        x: tensor indexed as (batch, dims, num_points).
        k: number of neighbours to select.

    Returns:
        Index tensor of shape (batch_size, num_points, k), ordered from
        nearest to farthest.
    """
    gram = torch.matmul(x.transpose(2, 1), x)
    sq_norm = torch.sum(x**2, dim=1, keepdim=True)
    # Negative squared distance: -(|a|^2 - 2 a.b + |b|^2), so the largest
    # values returned by topk correspond to the nearest points.
    neg_sq_dist = -sq_norm + 2*gram - sq_norm.transpose(2, 1)

    _, idx = neg_sq_dist.topk(k=k, dim=-1)
    return idx
|
||
|
||
def get_graph_feature(x, k=20, idx=None):
    """Build edge features ``(neighbour - centre, centre)`` for every point.

    Args:
        x: tensor of shape (batch_size, num_dims, num_points).
        k: number of neighbours per point.
        idx: optional neighbour indices of shape (batch_size, num_points, k)
            holding per-batch point indices in ``[0, num_points)``.  When
            ``None`` they are computed with ``knn(x, k)``.

    Returns:
        ``(feature, idx)`` where ``feature`` has shape
        (batch_size, 2 * num_dims, num_points, k) and ``idx`` is the
        per-batch neighbour index tensor that was used.  It is returned
        without the internal batch offset so it can be passed straight back
        in as ``idx=`` for another call on the same batch.
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    if idx is None:
        idx = knn(x, k=k)  # (batch_size, num_points, k)
    device = x.device

    # Offset that turns per-batch indices into indices of the flattened
    # (batch_size * num_points) point list.
    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1)*num_points

    # Keep `idx` untouched: returning the offset/flattened version made a
    # second call with idx=<returned value> add the offset again (and
    # broadcast to batch_size times too many entries) for batch_size > 1.
    flat_idx = (idx + idx_base).view(-1)

    _, num_dims, _ = x.size()

    x = x.transpose(2, 1).contiguous()  # (batch_size, num_points, num_dims)
    feature = x.view(batch_size*num_points, -1)[flat_idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2).contiguous()

    return feature, idx
|
||
|
||
class AdaptiveConv(nn.Module):
    """Convolution whose per-edge kernel is generated from a second input.

    ``forward(x, y)`` maps ``y`` through two 1x1 convolutions to an
    ``out_channels x in_channels`` matrix for every (point, neighbour) pair
    and applies that matrix to the corresponding ``in_channels`` vector of
    ``x``, followed by batch norm and LeakyReLU(0.2).
    """

    def __init__(self, in_channels, out_channels, feat_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.feat_channels = feat_channels

        # Submodules are created in a fixed order so parameter
        # initialisation and state_dict key order stay stable.
        self.conv0 = nn.Conv2d(feat_channels, out_channels, kernel_size=1, bias=False)
        self.conv1 = nn.Conv2d(out_channels, out_channels*in_channels, kernel_size=1, bias=False)
        self.bn0 = nn.BatchNorm2d(out_channels)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x, y):
        # x: (bs, in_channels, num_points, k), y: (bs, feat_channels, num_points, k)
        batch_size, _, num_points, k = x.size()

        # Generate the kernel weights from y: (bs, in*out, num_points, k)
        kernel = self.conv1(self.leaky_relu(self.bn0(self.conv0(y))))
        # -> (bs, num_points, k, out, in): one matrix per edge
        kernel = kernel.permute(0, 2, 3, 1).view(
            batch_size, num_points, k, self.out_channels, self.in_channels)

        # (bs, num_points, k, in_channels, 1): one column vector per edge
        column = x.permute(0, 2, 3, 1).unsqueeze(4)
        out = torch.matmul(kernel, column).squeeze(4)  # (bs, num_points, k, out_channels)
        out = out.permute(0, 3, 1, 2).contiguous()  # (bs, out_channels, num_points, k)

        return self.leaky_relu(self.bn1(out))
|
||
class Net(nn.Module):
    """Classification network: two AdaptiveConv layers, two edge-conv layers,
    a point-wise embedding and a three-layer MLP head.

    Reads ``args.k`` (neighbours per point), ``args.emb_dims`` (width of the
    global embedding) and ``args.dropout`` (dropout probability of the head).
    ``forward`` returns a tensor with ``output_channels`` values per sample.
    """

    def __init__(self, args, output_channels=40):
        super().__init__()
        self.args = args
        self.k = args.k
        emb_dims = args.emb_dims

        # NOTE: submodule creation order is kept fixed; it determines
        # state_dict key order and the order of random initialisation.
        # bn1 and bn2 are not referenced by forward() in this class but are
        # registered submodules and therefore part of the state_dict.
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm1d(emb_dims)

        self.conv3 = nn.Sequential(
            nn.Conv2d(64*2, 128, kernel_size=1, bias=False),
            self.bn3,
            nn.LeakyReLU(negative_slope=0.2),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(128*2, 256, kernel_size=1, bias=False),
            self.bn4,
            nn.LeakyReLU(negative_slope=0.2),
        )
        # 512 input channels = 64 + 64 + 128 + 256 concatenated layer outputs.
        self.conv5 = nn.Sequential(
            nn.Conv1d(512, emb_dims, kernel_size=1, bias=False),
            self.bn5,
            nn.LeakyReLU(negative_slope=0.2),
        )
        # The head consumes max-pooled and average-pooled embeddings together.
        self.linear1 = nn.Linear(emb_dims*2, 512, bias=False)
        self.bn6 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=args.dropout)
        self.linear2 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dp2 = nn.Dropout(p=args.dropout)
        self.linear3 = nn.Linear(256, output_channels)

        self.adapt_conv1 = AdaptiveConv(6, 64, 6)
        self.adapt_conv2 = AdaptiveConv(6, 64, 64*2)

    def forward(self, x):
        batch_size = x.size(0)
        points = x

        # Layer 1: neighbours found on the input; kernel generated from the
        # edge features, applied to the edge features of `points`.
        edge_feat, idx = get_graph_feature(x, k=self.k)
        edge_pts, _ = get_graph_feature(points, k=self.k, idx=idx)
        x1 = self.adapt_conv1(edge_pts, edge_feat).max(dim=-1, keepdim=False)[0]

        # Layer 2: neighbours found on x1, same scheme.
        edge_feat, idx = get_graph_feature(x1, k=self.k)
        edge_pts, _ = get_graph_feature(points, k=self.k, idx=idx)
        x2 = self.adapt_conv2(edge_pts, edge_feat).max(dim=-1, keepdim=False)[0]

        # Layers 3 and 4: 1x1 convolutions over edge features, max over k.
        edge_feat, _ = get_graph_feature(x2, k=self.k)
        x3 = self.conv3(edge_feat).max(dim=-1, keepdim=False)[0]

        edge_feat, _ = get_graph_feature(x3, k=self.k)
        x4 = self.conv4(edge_feat).max(dim=-1, keepdim=False)[0]

        embedding = self.conv5(torch.cat((x1, x2, x3, x4), dim=1))

        pooled_max = F.adaptive_max_pool1d(embedding, 1).view(batch_size, -1)
        pooled_avg = F.adaptive_avg_pool1d(embedding, 1).view(batch_size, -1)
        out = torch.cat((pooled_max, pooled_avg), 1)

        out = self.dp1(F.leaky_relu(self.bn6(self.linear1(out)), negative_slope=0.2))
        out = self.dp2(F.leaky_relu(self.bn7(self.linear2(out)), negative_slope=0.2))
        return self.linear3(out)
Oops, something went wrong.