def forward()

in c3dm/c3dpo.py


    def forward(self, kp_loc=None, kp_vis=None,
                class_mask=None, K=None, dense_basis=None,
                phi_out=None, dense_basis_mask=None,
                shape_coeff_in=None, **kwargs):
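        # NOTE (summary inferred from the asserts and downstream calls below):
        #   kp_loc         - 2D keypoint locations, shape (ba, 2, n_keypoints)
        #   kp_vis         - per-keypoint visibility, shape (ba, n_keypoints)
        #   class_mask     - optional mask forwarded to run_phi and
        #                    canonicalization_loss
        #   K              - camera intrinsics, used only when
        #                    projection_type == 'perspective'
        #   dense_basis    - optional dense basis; when given, the *_dense
        #                    outputs are produced as well
        #   phi_out        - optional precomputed shape-predictor output that
        #                    bypasses run_phi
        #   shape_coeff_in - optional shape coefficients forwarded to run_phi
        #   dense_basis_mask, **kwargs - accepted but unused in the body shown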

        # dictionary with the outputs of the forward pass
        preds = {}

        # input sizes ...
        ba, kp_dim, n_kp = kp_loc.shape
        dtype = kp_loc.type()

        assert kp_dim == 2, 'bad input keypoint dim'
        assert n_kp == self.n_keypoints, 'bad # of keypoints!'

        if self.projection_type == 'perspective':
            # map the keypoints to the calibrated frame using the intrinsics K
            kp_loc_cal = self.calibrate_keypoints(kp_loc, K)
        else:
            kp_loc_cal = kp_loc

        # normalize the keypoints; save the stats for later visualisations
        kp_loc_norm, kp_mean, kp_scale = self.normalize_keypoints(
            kp_loc_cal, kp_vis, rescale=self.keypoint_rescale)
        preds['kp_loc_norm'] = kp_loc_norm
        preds['kp_mean'], preds['kp_scale'] = kp_mean, kp_scale

        # run the shape predictor
        if phi_out is not None:  # bypass the predictor and use the given phi
            preds['phi'] = phi_out
        else:
            preds['phi'] = self.run_phi(kp_loc_norm, kp_vis,
                                        class_mask=class_mask,
                                        shape_coeff_in=shape_coeff_in)

        if self.canonicalization['use']:
            preds['l_canonicalization'], preds['psi'] = \
                self.canonicalization_loss(preds['phi'],
                                           class_mask=class_mask)
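            # NOTE: the canonicalization term is an auxiliary loss returned
            # together with the output of the canonicalization network psi;
            # both are stored in preds so that get_objective (called at the
            # end of forward) can presumably fold it into the final objective.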

        # 3D -> 2D: project the shape to the camera
        kp_reprojected, depth = self.camera_projection(
            preds['phi']['shape_camera_coord'])
        preds['kp_reprojected'] = kp_reprojected

        if dense_basis is not None:
            preds['phi_dense'] = self.run_phi_dense(dense_basis, preds['phi'])
            kp_reprojected_dense, depth_dense = self.camera_projection(
                preds['phi_dense']['shape_camera_coord_dense'])
            preds['kp_reprojected_dense'] = kp_reprojected_dense
            preds['depth_dense'] = depth_dense
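            # NOTE: the dense branch re-uses the shape and camera predicted
            # from the sparse keypoints to evaluate the (presumably per-pixel)
            # dense basis, exposing the projections and depths as the *_dense
            # entries of preds.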

        # compute the reprojection loss for backpropagation
        if self.loss_normalization == 'kp_count_per_image':
            preds['l_reprojection'] = avg_l2_dist(
                kp_reprojected,
                kp_loc_norm,
                mask=kp_vis,
                squared=self.squared_reprojection_loss)
        elif self.loss_normalization == 'kp_total_count':
            # flatten the batch so that the loss is normalized by the total
            # number of visible keypoints over the whole batch
            kp_reprojected_flatten = kp_reprojected.permute(
                1, 2, 0).contiguous().view(1, 2, self.n_keypoints * ba)
            kp_loc_norm_flatten = kp_loc_norm.permute(
                1, 2, 0).contiguous().view(1, 2, self.n_keypoints * ba)
            kp_vis_flatten = kp_vis.permute(
                1, 0).contiguous().view(1, self.n_keypoints * ba)

            if self.use_huber:
                preds['l_reprojection'] = avg_l2_huber(
                    kp_reprojected_flatten,
                    kp_loc_norm_flatten,
                    mask=kp_vis_flatten,
                    scaling=self.huber_scaling)
            else:
                # the non-Huber variant of this branch is disabled
                assert False
                preds['l_reprojection'] = avg_l2_dist(
                    kp_reprojected_flatten,
                    kp_loc_norm_flatten,
                    mask=kp_vis_flatten,
                    squared=self.squared_reprojection_loss)

        else:
            raise ValueError(
                'undefined loss normalization %s' % self.loss_normalization)

        if self.squared_reprojection_loss:
            # this configuration is disabled in the current code path
            assert False
            # compute the average reprojection distance
            #   = easier to interpret than the squared repro loss
            preds['dist_reprojection'] = avg_l2_dist(
                kp_reprojected,
                kp_loc_norm,
                mask=kp_vis,
                squared=False)

        # unnormalize the shape projections
        kp_reprojected_image = self.unnormalize_keypoints(
            kp_reprojected, kp_mean,
            rescale=self.keypoint_rescale, kp_scale=kp_scale)

        if dense_basis is not None:
            kp_reprojected_image_dense = self.unnormalize_keypoints(
                preds['kp_reprojected_dense'], kp_mean,
                rescale=self.keypoint_rescale, kp_scale=kp_scale)
            preds['kp_reprojected_image_dense'] = kp_reprojected_image_dense
            
        # projections in the image coordinate frame
        if self.replace_keypoints_with_input and not self.training:
            # at test time, keep the visible input keypoints and use the
            # reprojections only to fill in the unobserved ones
            kp_reprojected_image = \
                (1 - kp_vis[:, None, :]) * kp_reprojected_image + \
                kp_vis[:, None, :] * kp_loc_cal

        preds['kp_reprojected_image'] = kp_reprojected_image

        # projected 3D shape in the image space
        #   = unprojection of kp_reprojected_image
        shape_image_coord, depth_image_coord = self.camera_unprojection(
            kp_reprojected_image, depth,
            rescale=self.keypoint_rescale, kp_scale=kp_scale)
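        # NOTE: camera_unprojection presumably undoes the keypoint
        # normalisation (rescale / kp_scale) and attaches the predicted depth
        # to the 2D image-frame points, giving the 3D shape in image
        # coordinates.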
        
        if dense_basis is not None:
            shape_image_coord_dense, depth_image_coord_dense = \
                self.camera_unprojection(
                    kp_reprojected_image_dense, depth_dense,
                    rescale=self.keypoint_rescale, kp_scale=kp_scale)
            
        if self.projection_type == 'perspective':
            # re-apply the camera intrinsics K to move the calibrated
            # predictions back to the image coordinate frame
            preds['kp_reprojected_image_cal'] = kp_reprojected_image
            preds['shape_image_coord_cal'] = shape_image_coord
            preds['shape_image_coord'] = \
                self.uncalibrate_keypoints(shape_image_coord, K)
            preds['kp_reprojected_image'], _ = \
                self.camera_projection(preds['shape_image_coord'])
            if dense_basis is not None:
                preds['shape_image_coord_cal_dense'] = shape_image_coord_dense
                preds['shape_image_coord_dense'] = \
                    self.uncalibrate_keypoints(shape_image_coord_dense, K)
                preds['kp_reprojected_image_dense'], _ = \
                    self.camera_projection(preds['shape_image_coord_dense'])
                preds['depth_image_coord_dense'] = depth_image_coord_dense

        elif self.projection_type == 'orthographic':
            preds['shape_image_coord'] = shape_image_coord
            preds['depth_image_coord'] = depth_image_coord
            if dense_basis is not None:
                preds['shape_image_coord_dense'] = shape_image_coord_dense
                preds['depth_image_coord_dense'] = depth_image_coord_dense

        else:
            raise ValueError(
                'unknown projection type %s' % self.projection_type)

        # get the final loss and check that it is finite
        preds['objective'] = self.get_objective(preds)
        assert np.isfinite(
            preds['objective'].sum().data.cpu().numpy()), "nans!"
        
        return preds
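
For orientation, a minimal usage sketch follows. The class name C3DPO, its constructor arguments, and the random inputs are assumptions made for illustration (only forward() is shown above); the input shapes follow the asserts at the top of the function.

    import torch

    from c3dm.c3dpo import C3DPO  # class name and ctor signature assumed

    # hypothetical configuration: orthographic projection, 17 keypoints
    model = C3DPO(n_keypoints=17, projection_type='orthographic')

    ba, n_kp = 8, 17
    kp_loc = torch.rand(ba, 2, n_kp)               # 2D keypoint locations
    kp_vis = (torch.rand(ba, n_kp) > 0.2).float()  # per-keypoint visibility

    # assumes the model is a torch.nn.Module, so calling it runs forward()
    preds = model(kp_loc=kp_loc, kp_vis=kp_vis)
    preds['objective'].mean().backward()  # final loss assembled by get_objective

Besides 'objective', preds also carries the intermediate terms shown above (e.g. 'l_reprojection', 'kp_reprojected_image', 'shape_image_coord') for logging and visualisation.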