in models/trunks/spconv_unet.py [0:0]
def forward(self, x, out_feat_keys=None):
    """Encode a voxelized point cloud and return per-sample global features.

    Args:
        x: dict containing at least
            'voxels': (num_voxels, max_points_per_voxel, C) point features
                grouped per voxel. All slots are summed, so padding slots are
                assumed to be zero -- TODO confirm against the voxelizer.
            'voxel_num_points': (num_voxels,) number of valid points per
                voxel; used (clamped to >= 1) as the divisor of the mean.
            'voxel_coords': (num_voxels, 4) voxel coordinates whose first
                column is the sample (batch) index; presumably
                [batch_idx, z_idx, y_idx, x_idx].
        out_feat_keys: names of the feature maps to return. Each key ``k``
            is looked up as ``k + "_features"``; only ``"conv4"`` is
            produced here. Defaults to ``["conv4"]`` when None.

    Returns:
        list with one tensor per entry of ``out_feat_keys`` (same order,
        duplicates allowed). Each tensor is (batch_size, C_total): for every
        sample, the four decoder outputs are max-pooled over that sample's
        voxels and concatenated along channels, then passed through
        ``self.head`` when ``self.use_mlp`` is set.
    """
    if out_feat_keys is None:
        # Previously `len(None)` raised; "conv4" is the only key built below.
        out_feat_keys = ["conv4"]

    # Pre-processing: mean of the points inside each voxel.
    voxels = x['voxels']
    normalizer = torch.clamp_min(x['voxel_num_points'].view(-1, 1), min=1.0).type_as(voxels)
    voxel_features = (voxels.sum(dim=1) / normalizer).contiguous()

    voxel_coords = x['voxel_coords']
    # spconv needs batch_size > every batch index, i.e. max index + 1.
    # Counting distinct indices under-sizes it when indices are not
    # contiguous, and also avoided copying all coordinates to host numpy.
    if voxel_coords.shape[0] > 0:
        batch_size = int(voxel_coords[:, 0].max().item()) + 1
    else:
        batch_size = 0

    sparse_input = spconv.SparseConvTensor(
        features=voxel_features.float(),
        indices=voxel_coords.int(),
        spatial_shape=self.sparse_shape,
        batch_size=batch_size,
    )

    # Encoder.
    x_in = self.conv_input(sparse_input)
    x_conv1 = self.conv1(x_in)
    x_conv2 = self.conv2(x_conv1)
    x_conv3 = self.conv3(x_conv2)
    x_conv4 = self.conv4(x_conv3)

    if self.conv_out is not None:
        # Detection-head branch: [200, 176, 5] -> [200, 176, 2].
        # Its output is not returned. The call is kept so conv_out still runs
        # during forward, as in the original (e.g. any normalization
        # statistics inside it keep updating).
        self.conv_out(x_conv4)

    # Decoder (U-Net up path).
    # [400, 352, 11] <- [200, 176, 5]
    x_up4 = self.UR_block_forward(x_conv4, x_conv4, self.conv_up_t4, self.conv_up_m4, self.inv_conv4)
    # [800, 704, 21] <- [400, 352, 11]
    x_up3 = self.UR_block_forward(x_conv3, x_up4, self.conv_up_t3, self.conv_up_m3, self.inv_conv3)
    # [1600, 1408, 41] <- [800, 704, 21]
    x_up2 = self.UR_block_forward(x_conv2, x_up3, self.conv_up_t2, self.conv_up_m2, self.inv_conv2)
    # [1600, 1408, 41] <- [1600, 1408, 41]
    x_up1 = self.UR_block_forward(x_conv1, x_up2, self.conv_up_t1, self.conv_up_m1, self.conv5)

    decoder_outputs = [x_up4, x_up3, x_up2, x_up1]
    end_points = {
        'conv4_features': [t.features for t in decoder_outputs],
        'indice': [t.indices for t in decoder_outputs],
    }

    out_feats = []
    for key in out_feat_keys:
        # Pool the features actually requested (was hard-coded to conv4).
        scale_feats = end_points[key + "_features"]
        per_sample = []
        for b in range(batch_size):
            pooled = []
            for feats, indices in zip(scale_feats, end_points['indice']):
                # Global max-pool over sample b's active voxels: (n_b, C) -> (1, C).
                pooled.append(feats[indices[:, 0] == b].max(dim=0, keepdim=True).values)
            per_sample.append(torch.cat(pooled, dim=-1))
        feat = torch.cat(per_sample, dim=0)
        if self.use_mlp:
            feat = self.head(feat)
        # append (not index-of-first-match) so duplicate keys are all filled.
        out_feats.append(feat)
    return out_feats