in GraphAutoEncoder/graphVAESSW.py [0:0]
def forward2(self, in_pc, raw_w_weights, is_final_layer=False, b_max_pool=False):  # layer_info
    """Apply one neighbourhood graph convolution, then an optional residual mix.

    For each output vertex, every neighbour slot gets an
    ``out_channel x in_channel`` matrix. That matrix is a mix of the
    ``weight_num`` rows of ``self.weights``, using the coefficients in
    ``raw_w_weights``. The slot matrices are applied to the gathered neighbour
    features, and the results are summed (or max-pooled) over the slots.

    ``self.neighbor_id_lstlst`` (``out_point_num x max_neighbor_num``) holds
    input vertex indices. Index ``in_point_num`` marks an empty slot: it is
    masked out of the coefficients and gathers an all-zero feature row.

    Args:
        in_pc: input features, ``batch x in_point_num x in_channel``.
        raw_w_weights: mixing coefficients,
            ``out_point_num x max_neighbor_num x weight_num``.
        is_final_layer: if False, ``self.relu`` is applied to the convolution
            output (before the residual mix).
        b_max_pool: if True, reduce over neighbour slots with max instead of sum.

    Returns:
        Tensor of shape ``batch x out_point_num x out_channel``. When
        ``self.residual_rate`` is non-zero, this is
        ``sqrt(1 - r) * conv + sqrt(r) * residual``.
    """
    batch = in_pc.shape[0]
    device = in_pc.device
    in_channel = self.in_channel
    out_channel = self.out_channel
    in_pn = self.in_point_num
    out_pn = self.out_point_num
    max_neighbor_num = self.max_neighbor_num
    neighbor_id_lstlst = self.neighbor_id_lstlst

    # Mask is 1 for real neighbours and 0 for the padding index in_pn.
    pc_mask = torch.ones(in_pn + 1).float().to(device)
    pc_mask[in_pn] = 0
    neighbor_mask_lst = pc_mask[neighbor_id_lstlst].contiguous()  # out_pn*max_neighbor_num

    # Zero the mixing coefficients of padded slots (broadcast over weight_num).
    w_weights = raw_w_weights * neighbor_mask_lst.view(out_pn, max_neighbor_num, 1)  # out_pn*max_neighbor_num*weight_num
    weights = torch.einsum('pmw,wc->pmc', [w_weights, self.weights])  # out_pn*max_neighbor_num*(out_channel*in_channel)
    weights = weights.view(out_pn, max_neighbor_num, out_channel, in_channel)

    # Append a zero row so that gathering the padding index yields zeros.
    in_pc_pad = torch.cat((in_pc, torch.zeros(batch, 1, in_channel).float().to(device)), 1)  # batch*(in_pn+1)*in_channel
    in_neighbors = in_pc_pad[:, neighbor_id_lstlst]  # batch*out_pn*max_neighbor_num*in_channel
    out_neighbors = torch.einsum('pmoi,bpmi->bpmo', [weights, in_neighbors])  # batch*out_pn*max_neighbor_num*out_channel

    if b_max_pool:
        # Tensor.max(dim) returns (values, indices); keep only the values.
        # NOTE(review): padded slots contribute exact zeros and take part in
        # the max; confirm this is intended for all-negative neighbourhoods.
        out_pc = out_neighbors.max(2)[0]
    else:
        out_pc = out_neighbors.sum(2)
    out_pc = out_pc + self.bias
    if not is_final_layer:
        out_pc = self.relu(out_pc)

    if self.residual_rate == 0:
        return out_pc

    if in_channel != out_channel:
        # Project the residual input to out_channel with self.weight_res.
        in_pc_pad = torch.einsum('oi,bpi->bpo', [self.weight_res.view(out_channel, in_channel), in_pc_pad])

    if in_pn == out_pn:
        # Same vertex count: the residual is the (projected) input itself.
        out_pc_res = in_pc_pad[:, :in_pn]
    else:
        # Different vertex count: the residual is a weighted average of the
        # (projected) real neighbours. The weights are |p_neighbors|, masked
        # and normalised per output vertex; 1e-8 guards against division by 0.
        in_neighbors = in_pc_pad[:, neighbor_id_lstlst]  # batch*out_pn*max_neighbor_num*out_channel
        p_neighbors = torch.abs(self.p_neighbors) * neighbor_mask_lst
        p_neighbors_sum = p_neighbors.sum(1) + 1e-8  # out_pn
        p_neighbors = p_neighbors / p_neighbors_sum.view(out_pn, 1)
        out_pc_res = torch.einsum('pm,bpmo->bpo', [p_neighbors, in_neighbors])

    # The mixing coefficients are square roots, so their squares sum to 1.
    out_pc = out_pc * np.sqrt(1 - self.residual_rate) + out_pc_res * np.sqrt(self.residual_rate)
    return out_pc