def forward()

in models/trunks/spconv_unet.py [0:0]
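The forward pass mean-pools the raw points inside each voxel, runs the result through a sparse-convolution UNet (four downsampling stages followed by four upsampling blocks), and returns a globally max-pooled per-sample feature vector for each requested key in `out_feat_keys`.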


    def forward(self, x, out_feat_keys=None):
        """
        Args:
            x: dict with the voxelized input:
                voxels: (num_voxels, max_points_per_voxel, C) raw point features per voxel
                voxel_num_points: (num_voxels,) number of valid points in each voxel
                voxel_coords: (num_voxels, 4), [batch_idx, z_idx, y_idx, x_idx]
            out_feat_keys: list of feature names to return, e.g. ['conv4']
        Returns:
            out_feats: list with one (batch_size, D) pooled feature tensor per requested key
        """
        # Pre-processing: reduce each voxel to the mean of its points (a simple mean VFE).
        voxel_features, voxel_num_points = x['voxels'], x['voxel_num_points']
        points_mean = voxel_features.sum(dim=1, keepdim=False)
        normalizer = torch.clamp_min(voxel_num_points.view(-1, 1), min=1.0).type_as(voxel_features)
        voxel_features = (points_mean / normalizer).contiguous()

        # Infer the batch size from the batch-index column of the voxel coordinates.
        voxel_coords = x['voxel_coords']
        batch_size = len(torch.unique(voxel_coords[:, 0]))

        # Wrap the pooled voxel features and integer voxel coordinates in a sparse tensor.
        input_sp_tensor = spconv.SparseConvTensor(
            features=voxel_features.float(),
            indices=voxel_coords.int(),
            spatial_shape=self.sparse_shape,
            batch_size=batch_size
        )
        x = self.conv_input(input_sp_tensor)

        # Sparse 3D encoder: four stages that progressively downsample the grid.
        x_conv1 = self.conv1(x)
        x_conv2 = self.conv2(x_conv1)
        x_conv3 = self.conv3(x_conv2)
        x_conv4 = self.conv4(x_conv3)

        if self.conv_out is not None:
            # Extra compression along z for a detection head; `out` is not used
            # further in this forward pass.
            # [200, 176, 5] -> [200, 176, 2]
            out = self.conv_out(x_conv4)
            # batch_dict['encoded_spconv_tensor'] = out
            # batch_dict['encoded_spconv_tensor_stride'] = 8

        # for segmentation head
        # [400, 352, 11] <- [200, 176, 5]
        x_up4 = self.UR_block_forward(x_conv4, x_conv4, self.conv_up_t4, self.conv_up_m4, self.inv_conv4)
        # [800, 704, 21] <- [400, 352, 11]
        x_up3 = self.UR_block_forward(x_conv3, x_up4, self.conv_up_t3, self.conv_up_m3, self.inv_conv3)
        # [1600, 1408, 41] <- [800, 704, 21]
        x_up2 = self.UR_block_forward(x_conv2, x_up3, self.conv_up_t2, self.conv_up_m2, self.inv_conv2)
        # [1600, 1408, 41] <- [1600, 1408, 41]
        x_up1 = self.UR_block_forward(x_conv1, x_up2, self.conv_up_t1, self.conv_up_m1, self.conv5)
        
        end_points = {}

        # Collect the features and voxel indices of all four decoder stages.
        end_points['conv4_features'] = [x_up4.features, x_up3.features, x_up2.features, x_up1.features]
        end_points['indice'] = [x_up4.indices, x_up3.indices, x_up2.indices, x_up1.indices]
        
        out_feats = [None] * len(out_feat_keys)

        for key in out_feat_keys:
            feat = end_points[key + "_features"]  # looked up here, overwritten by the pooled result below

            # Global max-pool each decoder stage's features per sample, then
            # concatenate the pooled vectors into a single descriptor per sample.
            featlist = []
            for i in range(batch_size):
                tempfeat = []
                for idx in range(len(end_points['indice'])):
                    # Select the voxels that belong to sample i (batch index is column 0).
                    temp_idx = end_points['indice'][idx][:, 0] == i
                    temp_f = end_points['conv4_features'][idx][temp_idx].unsqueeze(0).permute(0, 2, 1).contiguous()
                    tempfeat.append(F.max_pool1d(temp_f, temp_f.shape[-1]).squeeze(-1))
                featlist.append(torch.cat(tempfeat, -1))
            feat = torch.cat(featlist, 0)
            if self.use_mlp:
                # Optional projection head on top of the pooled descriptor.
                feat = self.head(feat)
            out_feats[out_feat_keys.index(key)] = feat
        return out_feats
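
For orientation, below is a minimal calling sketch rather than code from this repository. It assumes the trunk class is exported from models/trunks/spconv_unet.py under the hypothetical name SpConvUNet with a no-argument constructor, that voxelization has already produced the three dict keys described in the docstring, and that 'conv4' is a valid out_feat_keys entry; the coordinate ranges are placeholders for whatever sparse_shape the real config uses.

    import torch

    # Hypothetical import; substitute the actual class name defined in
    # models/trunks/spconv_unet.py and its real constructor arguments.
    from models.trunks.spconv_unet import SpConvUNet

    num_voxels, max_pts, point_dim = 1000, 5, 4
    x = {
        # (num_voxels, max_points_per_voxel, C) raw point features per voxel
        'voxels': torch.randn(num_voxels, max_pts, point_dim),
        # (num_voxels,) number of valid points in each voxel
        'voxel_num_points': torch.randint(1, max_pts + 1, (num_voxels,)),
        # (num_voxels, 4) as [batch_idx, z_idx, y_idx, x_idx]; index ranges are placeholders
        'voxel_coords': torch.cat(
            [torch.randint(0, 2, (num_voxels, 1)),      # batch index (2 samples)
             torch.randint(0, 40, (num_voxels, 1)),     # z index
             torch.randint(0, 1408, (num_voxels, 1)),   # y index
             torch.randint(0, 1600, (num_voxels, 1))],  # x index
            dim=1),
    }

    model = SpConvUNet()  # constructor arguments omitted (assumption)
    model.eval()
    with torch.no_grad():
        out_feats = model(x, out_feat_keys=['conv4'])
    print(out_feats[0].shape)  # (batch_size, D) pooled descriptor for 'conv4'

Note that spconv's sparse convolutions normally run on CUDA, so in practice the model and every tensor in x would be moved to the GPU before the call.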