def forward2()

in GraphAutoEncoder/graphVAESSW.py [0:0]


    def forward2(self, in_pc, raw_w_weights, is_final_layer=False, b_max_pool=False):
        """Graph convolution whose per-neighbour mixing coefficients are passed in.

        Each output point p gathers its neighbours ``self.neighbor_id_lstlst[p]``.
        The index ``self.in_point_num`` marks an empty (padding) slot. For each
        (point, slot) pair it builds a kernel as a linear combination of the
        ``weight_num`` basis kernels in ``self.weights``, using ``raw_w_weights``
        as coefficients. It then aggregates the transformed neighbour features by
        sum (default) or max.

        Args:
            in_pc: input features, batch x in_point_num x in_channel.
            raw_w_weights: mixing coefficients, out_point_num x max_neighbor_num
                x weight_num. Coefficients of padding slots are zeroed here.
            is_final_layer: when falsy, ``self.relu`` is applied to the
                convolution output before the residual blend.
            b_max_pool: aggregate neighbours with a max instead of a sum. Padding
                slots are ignored. A point with no real neighbour yields 0.

        Returns:
            Tensor of shape batch x out_point_num x out_channel.
        """
        batch = in_pc.shape[0]
        device = in_pc.device

        in_channel = self.in_channel
        out_channel = self.out_channel
        in_pn = self.in_point_num
        out_pn = self.out_point_num
        weight_num = self.weight_num
        max_neighbor_num = self.max_neighbor_num
        neighbor_id_lstlst = self.neighbor_id_lstlst  # out_pn x max_neighbor_num, in_pn = padding

        # 1 for a real neighbour slot, 0 for a slot pointing at the padding index in_pn.
        pc_mask = torch.ones(in_pn + 1, dtype=torch.float32, device=device)
        pc_mask[in_pn] = 0
        neighbor_mask_lst = pc_mask[neighbor_id_lstlst].contiguous()  # out_pn x max_neighbor_num

        # Zero the mixing coefficients of padding slots.
        w_weights = raw_w_weights * neighbor_mask_lst.view(out_pn, max_neighbor_num, 1).repeat(1, 1, weight_num)

        # Per-(point, slot) kernel = sum_w w_weights[p, m, w] * self.weights[w].
        weights = torch.einsum('pmw,wc->pmc', [w_weights, self.weights])  # out_pn x max_neighbor_num x (out_channel*in_channel)
        weights = weights.view(out_pn, max_neighbor_num, out_channel, in_channel)

        # Append a zero feature row so that padding indices gather zeros.
        pad_row = torch.zeros(batch, 1, in_channel, dtype=torch.float32, device=device)
        in_pc_pad = torch.cat((in_pc, pad_row), 1)  # batch x (in_pn+1) x in_channel

        in_neighbors = in_pc_pad[:, neighbor_id_lstlst]  # batch x out_pn x max_neighbor_num x in_channel
        out_neighbors = torch.einsum('pmoi,bpmi->bpmo', [weights, in_neighbors])  # batch x out_pn x max_neighbor_num x out_channel

        if b_max_pool:
            # Exclude padding slots: they hold zeros and would otherwise win the
            # max whenever every real neighbour response is negative.
            valid = (neighbor_mask_lst > 0).view(1, out_pn, max_neighbor_num, 1)
            # Tensor.max(dim) returns a (values, indices) pair; keep the values.
            out_pc = out_neighbors.masked_fill(~valid, float('-inf')).max(2).values
            # A point with no real neighbour at all gets 0, like the sum path.
            out_pc = out_pc.masked_fill(~valid.any(2), 0.0)
        else:
            out_pc = out_neighbors.sum(2)  # batch x out_pn x out_channel

        out_pc = out_pc + self.bias

        if not is_final_layer:
            out_pc = self.relu(out_pc)

        if self.residual_rate == 0:
            return out_pc

        # Residual branch: project channels first when they differ.
        if in_channel != out_channel:
            in_pc_pad = torch.einsum('oi,bpi->bpo', [self.weight_res.view(out_channel, in_channel), in_pc_pad])

        if in_pn == out_pn:
            # Same point count: identity residual (drop the padding row).
            out_pc_res = in_pc_pad[:, :in_pn]
        else:
            # Otherwise: average over real neighbours, weighted by the normalised
            # |self.p_neighbors| (padding slots get weight 0).
            p_neighbors = torch.abs(self.p_neighbors) * neighbor_mask_lst  # out_pn x max_neighbor_num
            p_neighbors = p_neighbors / (p_neighbors.sum(1, keepdim=True) + 1e-8)
            res_neighbors = in_pc_pad[:, neighbor_id_lstlst]  # batch x out_pn x max_neighbor_num x out_channel
            out_pc_res = torch.einsum('pm,bpmo->bpo', [p_neighbors, res_neighbors])

        # Blend coefficients sqrt(1 - r) and sqrt(r); their squares sum to 1.
        out_pc = out_pc * np.sqrt(1 - self.residual_rate) + out_pc_res * np.sqrt(self.residual_rate)

        return out_pc