-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathtd3d_instance_segmentor.py
141 lines (121 loc) · 5.15 KB
/
td3d_instance_segmentor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
try:
import MinkowskiEngine as ME
except ImportError:
import warnings
warnings.warn(
'Please follow `getting_started.md` to install MinkowskiEngine.`')
from mmdet3d.models import DETECTORS, build_backbone, build_neck, build_head
from .base import Base3DDetector
import torch
@DETECTORS.register_module()
class TD3DInstanceSegmentor(Base3DDetector):
    r"""Two-stage instance segmentor based on MinkowskiEngine.

    The first stage is a bbox detector. The second stage is a two-class
    (foreground/background) pointwise segmentor.

    Args:
        backbone (dict): Config of the backbone.
        neck (dict): Config of the neck.
        head (dict): Config of the head. It is not modified in place;
            ``train_cfg`` and ``test_cfg`` are injected into a copy.
        voxel_size (float): Voxel size in meters.
        train_cfg (dict, optional): Config for train stage. Defaults to None.
        test_cfg (dict, optional): Config for test stage. Defaults to None.
        init_cfg (dict, optional): Config for weight initialization.
            Defaults to None.
        pretrained (str, optional): Deprecated initialization parameter.
            It is ignored (a warning is emitted when it is set).
            Defaults to None.
    """

    def __init__(self,
                 backbone,
                 neck,
                 head,
                 voxel_size,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None,
                 pretrained=None):
        super().__init__(init_cfg)
        if pretrained is not None:
            # Local import: the module-level ``warnings`` import only runs
            # when MinkowskiEngine fails to import.
            import warnings
            warnings.warn(
                'DeprecationWarning: `pretrained` is deprecated and ignored '
                'by TD3DInstanceSegmentor, please use `init_cfg` instead.')
        self.backbone = build_backbone(backbone)
        self.neck = build_neck(neck)
        # Inject train/test cfgs into a shallow copy so that building the
        # model does not mutate the caller's config as a side effect.
        head = dict(head, train_cfg=train_cfg, test_cfg=test_cfg)
        self.head = build_head(head)
        self.voxel_size = voxel_size
        self.init_weights()

    def extract_feat(self, points):
        """Extract features from voxelized points.

        Args:
            points (ME.SparseTensor): Voxelized input; both ``forward_train``
                and ``simple_test`` pass a sparse tensor here.

        Returns:
            Output of the neck applied to the backbone output (exact
            structure depends on the configured neck).
        """
        x = self.backbone(points)
        x = self.neck(x)
        return x

    def collate(self, points, quantization_mode):
        """Batch raw point clouds into a single MinkowskiEngine field.

        Args:
            points (list[torch.Tensor]): Per-sample point tensors. Columns
                ``[:3]`` are xyz coordinates; columns ``[3:]`` are used as
                per-point features.
            quantization_mode (ME.SparseTensorQuantizationMode): Strategy
                used when several points fall into the same voxel.

        Returns:
            ME.TensorField: Batched field whose coordinates are the xyz
            coordinates divided by ``self.voxel_size``, on the device of
            ``points[0]``.
        """
        coordinates, features = ME.utils.batch_sparse_collate(
            [(p[:, :3] / self.voxel_size, p[:, 3:]) for p in points],
            dtype=points[0].dtype,
            device=points[0].device)
        return ME.TensorField(
            features=features,
            coordinates=coordinates,
            quantization_mode=quantization_mode,
            minkowski_algorithm=ME.MinkowskiAlgorithm.SPEED_OPTIMIZED,
            device=points[0].device,
        )

    def forward_train(self, points, gt_bboxes_3d, gt_labels_3d,
                      pts_semantic_mask, pts_instance_mask, img_metas):
        """Forward of training.

        Args:
            points (list[Tensor]): Raw point clouds (xyz followed by
                per-point features).
            gt_bboxes_3d (list[BaseInstance3DBoxes]): Ground truth
                bboxes of each sample.
            gt_labels_3d (list[torch.Tensor]): Labels of each sample.
            pts_semantic_mask (list[torch.Tensor]): Per point semantic labels
                of each sample.
            pts_instance_mask (list[torch.Tensor]): Per point instance labels
                of each sample.
            img_metas (list[dict]): Contains scene meta infos.

        Returns:
            dict: Loss values.
        """
        # Number of feature channels after xyz. Derived from the input (not
        # hard-coded) so training feeds the backbone the same channels that
        # ``simple_test`` does via ``collate``.
        n_feats = points[0].shape[1] - 3
        # Append instance and semantic labels as two extra feature columns so
        # they are voxelized together with the real features. The explicit
        # cast avoids relying on implicit dtype promotion in ``torch.cat``.
        points = [
            torch.cat([
                p,
                inst.unsqueeze(1).to(p.dtype),
                sem.unsqueeze(1).to(p.dtype),
            ], dim=1)
            for p, inst, sem in zip(points, pts_instance_mask,
                                    pts_semantic_mask)
        ]
        field = self.collate(
            points, ME.SparseTensorQuantizationMode.RANDOM_SUBSAMPLE)
        x = field.sparse()
        # targets[:, 0] is the instance label, targets[:, 1] the semantic
        # label (order of the concatenation above); round() before the
        # integer cast guards against float representation noise.
        targets = x.features[:, n_feats:].round().long()
        # Strip the label columns so only real features reach the backbone.
        x = ME.SparseTensor(
            x.features[:, :n_feats],
            coordinate_map_key=x.coordinate_map_key,
            coordinate_manager=x.coordinate_manager,
        )
        x = self.extract_feat(x)
        losses = self.head.forward_train(
            x, targets, field, gt_bboxes_3d, gt_labels_3d,
            pts_semantic_mask, pts_instance_mask, img_metas)
        return losses

    def simple_test(self, points, img_metas, *args, **kwargs):
        """Test without augmentations.

        Args:
            points (list[torch.Tensor]): Points of each sample.
            img_metas (list[dict]): Contains scene meta infos.

        Returns:
            list[dict]: One dict per instance tuple yielded by the head,
            with keys ``instance_mask``, ``instance_label`` and
            ``instance_score`` (all moved to CPU).
        """
        field = self.collate(
            points, ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)
        x = self.extract_feat(field.sparse())
        instances = self.head.forward_test(x, field, img_metas)
        return [
            dict(
                instance_mask=mask.cpu(),
                instance_label=label.cpu(),
                instance_score=score.cpu())
            for mask, label, score in instances
        ]

    def aug_test(self, points, img_metas, **kwargs):
        """Test with augmentations (not supported).

        Args:
            points (list[list[torch.Tensor]]): Points of each sample.
            img_metas (list[dict]): Contains scene meta infos.

        Raises:
            NotImplementedError: Always; test-time augmentation is not
                implemented for this segmentor.
        """
        raise NotImplementedError