def forward()

in GraphAutoEncoder/graphVAESSW.py [0:0]


    def forward(self, in_pc, raw_w_weights, is_final_layer=False, b_max_pool = False):
        """Apply one graph convolution, optionally blended with a residual branch.

        Each output point gathers its neighbors from ``in_pc`` through
        ``self.neighbor_id_lstlst``. Index ``self.in_point_num`` is the padding
        slot and gathers an all-zero row. The gathered neighbors are mixed with
        ``raw_w_weights`` into ``self.weight_num`` intermediate features. Each
        intermediate feature is mapped by its own ``out_channel x in_channel``
        matrix taken from ``self.weights``. The result is then reduced (sum, or
        max if ``b_max_pool``) over the ``weight_num`` axis.

        Args:
            in_pc: input features, 3-D with last dim ``self.in_channel``;
                presumably (batch, in_point_num, in_channel).
            raw_w_weights: neighbor mixing weights of shape
                (out_point_num, max_neighbor_num, weight_num).
            is_final_layer: when falsy, ``self.relu`` is applied to the
                convolution output (before the residual blend).
            b_max_pool: if True, reduce over the ``weight_num`` axis with max
                instead of sum.

        Returns:
            Tensor of shape (batch, out_point_num, out_channel).
        """
        batch = in_pc.shape[0]
        device = in_pc.device

        in_channel = self.in_channel
        out_channel = self.out_channel
        in_pn = self.in_point_num
        out_pn = self.out_point_num
        weight_num = self.weight_num  # M
        max_neighbor_num = self.max_neighbor_num  # N
        neighbor_id_lstlst = self.neighbor_id_lstlst

        # 1 for each real input point, 0 for the padding slot at index in_pn.
        # Gathering it through the neighbor table flags which neighbor slots are real.
        # Allocated directly on the target device to avoid a host->device copy per call.
        pc_mask = torch.ones(in_pn + 1, dtype=torch.float32, device=device)
        pc_mask[in_pn] = 0
        neighbor_mask_lst = index_selection_nd(pc_mask, neighbor_id_lstlst, 0).contiguous()

        # Zero the mixing weights on padded neighbor slots: (out_pn, N, M).
        w_weights = raw_w_weights * neighbor_mask_lst.view(out_pn, max_neighbor_num, 1)

        # Append one all-zero point so the padding index gathers zeros.
        # new_zeros matches in_pc's dtype and device: (batch, in_pn+1, in_channel).
        in_pc_pad = torch.cat((in_pc, in_pc.new_zeros(batch, 1, in_channel)), 1)
        # (batch, out_pn, N, in_channel)
        in_neighbors = index_selection_nd(in_pc_pad, neighbor_id_lstlst, 1)
        # Mix the N neighbors into M features: (batch, out_pn, M, in_channel).
        fuse_neighbors = torch.einsum('pnm,bpni->bpmi', [w_weights, in_neighbors])

        # One (out_channel x in_channel) matrix per mixed feature.
        weight_mats = self.weights.view(weight_num, out_channel, in_channel)
        # (batch, out_pn, M, out_channel)
        out_neighbors = torch.einsum('moi,bpmi->bpmo', [weight_mats, fuse_neighbors])

        # Reduce over the M (weight_num) axis.
        if b_max_pool:
            # Tensor.max(dim) returns (values, indices); keep only the values.
            out_pc = out_neighbors.max(2)[0]
        else:
            out_pc = out_neighbors.sum(2)

        out_pc = out_pc + self.bias  # (batch, out_pn, out_channel)

        if not is_final_layer:
            out_pc = self.relu(out_pc)  # self.relu is defined in the init function

        if self.residual_rate == 0:
            return out_pc

        projected = in_channel != out_channel
        if projected:
            # Linear projection to out_channel; the zero padding row stays zero.
            in_pc_pad = torch.einsum('oi,bpi->bpo', [self.weight_res.view(out_channel, in_channel), in_pc_pad])

        if in_pn == out_pn:
            # Same point set: the residual is the (possibly projected) input itself.
            # No clone needed; the tensor is only read below.
            out_pc_res = in_pc_pad[:, 0:in_pn]
        else:
            if projected:
                # Only re-gather when in_pc_pad changed; otherwise the earlier gather is identical.
                in_neighbors = index_selection_nd(in_pc_pad, neighbor_id_lstlst, 1)

            # Non-negative per-neighbor weights restricted to real neighbors,
            # normalized so each output point's weights sum to ~1 (1e-8 avoids /0).
            p_neighbors = torch.abs(self.p_neighbors) * neighbor_mask_lst
            p_neighbors_sum = p_neighbors.sum(1) + 1e-8  # (out_pn,)
            p_neighbors = p_neighbors / p_neighbors_sum.view(out_pn, 1)  # broadcasts over N

            # Weighted average of neighbors: (batch, out_pn, out_channel).
            out_pc_res = torch.einsum('pn,bpno->bpo', [p_neighbors, in_neighbors])

        # Blend coefficients whose squares sum to 1.
        out_pc = out_pc * np.sqrt(1 - self.residual_rate) + out_pc_res * np.sqrt(self.residual_rate)

        return out_pc