forked from qinzheng93/GeoTransformer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
backbone.py
85 lines (67 loc) · 3.88 KB
/
backbone.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
import torch.nn as nn
from geotransformer.modules.kpconv import ConvBlock, ResidualBlock, UnaryBlock, LastUnaryBlock, nearest_upsample
class KPConvFPN(nn.Module):
    """KPConv encoder followed by a two-level feature-pyramid decoder.

    Encoder (``c = init_dim``):

    * Stage 1 runs on point level 0. A ``ConvBlock`` maps
      ``input_dim -> c``, then a ``ResidualBlock`` maps ``c -> 2c``.
    * Stages 2, 3 and 4 each start with a ``strided=True`` residual block
      that carries features from point level ``i - 1`` onto level ``i``.
      Two residual blocks at level ``i`` follow. Widths grow
      ``2c -> 4c -> 8c -> 16c``.
    * Radius and sigma double from one level to the next. The strided block
      of each stage still uses the radius/sigma of the level it reads from.

    Decoder: the deepest features are upsampled and concatenated (dim 1)
    with the encoder skip features of the same level, then projected.
    This gives ``16c + 8c -> 8c`` at level 2, then ``8c + 4c -> output_dim``
    at level 1.

    Args:
        input_dim: channel count of the input features.
        output_dim: channel count of the finest returned feature map.
        init_dim: base channel width ``c``.
        kernel_size: forwarded unchanged to every encoder block.
        init_radius: radius passed to the level-0 blocks (doubled per level).
        init_sigma: sigma passed to the level-0 blocks (doubled per level).
        group_norm: forwarded to every block except the final
            ``LastUnaryBlock``.
    """

    def __init__(self, input_dim, output_dim, init_dim, kernel_size, init_radius, init_sigma, group_norm):
        super().__init__()

        c, k, gn = init_dim, kernel_size, group_norm
        # Per-point-level radius / sigma: the base value doubled at each level.
        r = [init_radius, init_radius * 2, init_radius * 4, init_radius * 8]
        s = [init_sigma, init_sigma * 2, init_sigma * 4, init_sigma * 8]

        # NOTE: attribute names and assignment order are kept as-is; they
        # determine the state_dict keys and the parameter registration order.

        # Stage 1: level 0, c -> 2c.
        self.encoder1_1 = ConvBlock(input_dim, c, k, r[0], s[0], gn)
        self.encoder1_2 = ResidualBlock(c, c * 2, k, r[0], s[0], gn)

        # Stage 2: level 0 -> 1, 2c -> 4c.
        self.encoder2_1 = ResidualBlock(c * 2, c * 2, k, r[0], s[0], gn, strided=True)
        self.encoder2_2 = ResidualBlock(c * 2, c * 4, k, r[1], s[1], gn)
        self.encoder2_3 = ResidualBlock(c * 4, c * 4, k, r[1], s[1], gn)

        # Stage 3: level 1 -> 2, 4c -> 8c.
        self.encoder3_1 = ResidualBlock(c * 4, c * 4, k, r[1], s[1], gn, strided=True)
        self.encoder3_2 = ResidualBlock(c * 4, c * 8, k, r[2], s[2], gn)
        self.encoder3_3 = ResidualBlock(c * 8, c * 8, k, r[2], s[2], gn)

        # Stage 4: level 2 -> 3, 8c -> 16c.
        self.encoder4_1 = ResidualBlock(c * 8, c * 8, k, r[2], s[2], gn, strided=True)
        self.encoder4_2 = ResidualBlock(c * 8, c * 16, k, r[3], s[3], gn)
        self.encoder4_3 = ResidualBlock(c * 16, c * 16, k, r[3], s[3], gn)

        # Decoder: input width = upsampled width + skip width.
        self.decoder3 = UnaryBlock(c * 24, c * 8, gn)            # 16c + 8c -> 8c
        self.decoder2 = LastUnaryBlock(c * 12, output_dim)       # 8c + 4c -> output_dim

    def forward(self, feats, data_dict):
        """Encode ``feats`` and decode them into a three-level feature list.

        Args:
            feats: features of the level-0 points. Features are concatenated
                along dim 1 below, so presumably shaped
                ``(num_points, channels)`` -- TODO confirm.
            data_dict: mapping holding per-level sequences. ``'points'`` and
                ``'neighbors'`` are read at indices 0-3, ``'subsampling'``
                at indices 0-2, and ``'upsampling'`` at indices 1-2.

        Returns:
            list: finest level first --
            ``[level-1 features (output_dim channels),
            level-2 features (8 * init_dim channels),
            level-3 features (16 * init_dim channels)]``.
        """
        points = data_dict['points']
        neighbors = data_dict['neighbors']
        subsampling = data_dict['subsampling']
        upsampling = data_dict['upsampling']

        # Stage 1: both blocks operate within level 0.
        x = feats
        for block in (self.encoder1_1, self.encoder1_2):
            x = block(x, points[0], points[0], neighbors[0])

        # Stages 2-4. The strided block queries level `level` against
        # level `level - 1` using the subsampling indices. The remaining two
        # blocks then refine within level `level`.
        deeper_stages = (
            (self.encoder2_1, self.encoder2_2, self.encoder2_3),
            (self.encoder3_1, self.encoder3_2, self.encoder3_3),
            (self.encoder4_1, self.encoder4_2, self.encoder4_3),
        )
        skips = []
        for level, (down, *refine) in enumerate(deeper_stages, start=1):
            x = down(x, points[level], points[level - 1], subsampling[level - 1])
            for block in refine:
                x = block(x, points[level], points[level], neighbors[level])
            skips.append(x)
        skip_l1, skip_l2, deepest = skips

        # Decoder: upsample, concatenate with the same-level skip, project.
        up_l2 = nearest_upsample(deepest, upsampling[2])
        fused_l2 = self.decoder3(torch.cat([up_l2, skip_l2], dim=1))

        up_l1 = nearest_upsample(fused_l2, upsampling[1])
        fused_l1 = self.decoder2(torch.cat([up_l1, skip_l1], dim=1))

        return [fused_l1, fused_l2, deepest]