def forward()

in GraphAutoEncoder/graphVAE_train.py [0:0]


    def forward(self, in_pc_batch, iteration, frame, t_idx, t_nor, bDebug=False):
        """Encode a batch, sample a latent code, decode it and return the losses.

        Args:
            in_pc_batch: Input vertex tensor. The original author's note gives
                the layout as "B N 3" (mesh vertices) -- TODO confirm.
            iteration: Training iteration; only used to name and tag the debug
                error file.
            frame: Indexable; ``int(frame[0])`` is printed in debug mode.
            t_idx: Not read by this method; kept so existing callers still work.
            t_nor: Normals multiplied element-wise with the per-face mean
                displacement. The original note says "B N 3" (mesh normals),
                but the product requires it to broadcast against the output of
                ``index_selection_nd`` -- TODO confirm the expected shape.
            bDebug: When true, and the input lives on device index 0, write
                error metrics for the first sample to
                ``{self.write_tmp_folder}/err{iteration}.txt``.

        Returns:
            Tuple ``(loss, loss_pose_l1, loss_laplace, klloss, loss_normal)``.
            Every entry is indexed with ``[None]`` so it carries an extra
            leading dimension of size 1 (presumably so multi-GPU wrappers can
            gather per-replica values -- verify against the caller).
        """
        # Encoder yields the mean and log-std of the latent distribution.
        # Original author's note: "in in mm and out in dm".
        t_mu, t_logstd = self.net_geoenc(in_pc_batch, self.mcvcoeffsenc)
        t_std = t_logstd.exp()

        # Reparameterisation: z = mu + std * eps with eps drawn from N(0, 1).
        t_eps = torch.randn_like(t_std)
        t_z = t_mu + t_std * t_eps

        # Mean over all elements of -0.5 - log(std) + 0.5*mu^2 + 0.5*std^2.
        klloss = torch.mean(
            -0.5 - t_logstd + 0.5 * t_mu ** 2 + 0.5 * torch.exp(2 * t_logstd)
        )

        # Only the first three decoder channels are used as positions.
        out_pc_batchfull = self.net_geodec(t_z, self.mcvcoeffsdec)
        out_pc_batch = out_pc_batchfull[:, :, 0:3]

        # Displacement between prediction and input, selected along dim 1 with
        # each of the three columns of self.t_facedata (presumably the vertex
        # indices of every triangle -- TODO confirm).
        dif_pos = out_pc_batch - in_pc_batch
        vet0 = index_selection_nd(dif_pos, self.t_facedata[:, 0], 1)
        vet1 = index_selection_nd(dif_pos, self.t_facedata[:, 1], 1)
        vet2 = index_selection_nd(dif_pos, self.t_facedata[:, 2], 1)

        # Squared projection of the mean selected displacement onto t_nor.
        face_mean_dif = (vet0 + vet1 + vet2) / 3.0
        loss_normal = (face_mean_dif * t_nor).sum(2).pow(2).mean()

        in_xyz = in_pc_batch[:, :, 0:3]
        loss_pose_l1 = self.net_loss.compute_geometric_loss_l1(in_xyz, out_pc_batch)
        # NOTE(review): the original local was named ``loss_laplace_l1`` but the
        # helper called is the ``_l2`` variant -- confirm which one is intended.
        loss_laplace = self.net_loss.compute_laplace_loss_l2(in_xyz, out_pc_batch)

        loss = (
            loss_pose_l1 * self.w_pose
            + loss_laplace * self.w_laplace
            + klloss * self.klweight
            + loss_normal * self.w_nor
        )

        # Tensor.get_device() returns -1 for CPU tensors, so this debug dump
        # only fires when the input is on CUDA device 0.
        if bDebug and in_pc_batch.get_device() == 0:
            print(int(frame[0]))

            # Metrics are computed on the first sample of the batch only.
            first_in = in_pc_batch[:1, :, 0:3]
            first_out = out_pc_batch[:1]
            err_pose_l2 = self.net_loss.compute_geometric_mean_euclidean_dist_error(
                first_in, first_out
            )
            err_laplace_l1 = self.net_loss.compute_laplace_loss_l1(first_in, first_out)

            # The context manager closes the file; no explicit close is needed.
            with open(f'{self.write_tmp_folder}/err{iteration}.txt', 'w') as f:
                f.write('{0:d} {1:f} {2:f} {3:f}\n'.format(
                    iteration, loss_pose_l1.mean(), err_pose_l2, err_laplace_l1))

        return (
            loss[None],
            loss_pose_l1[None],
            loss_laplace[None],
            klloss[None],
            loss_normal[None],
        )