in GraphAutoEncoder/graphVAESSW.py [0:0]
def forward(self, in_pc, raw_w_weights, is_final_layer=False, b_max_pool=False):
    """Apply one neighbourhood-convolution layer to a batch of point features.

    For every output point the features of its neighbours (listed in
    ``self.neighbor_id_lstlst``) are mixed with the per-neighbour coefficients
    ``raw_w_weights`` into ``weight_num`` intermediate features. Each of those
    is mapped through its own (out_channel x in_channel) slice of
    ``self.weights``, and the results are reduced over the ``weight_num`` axis
    (sum, or max when ``b_max_pool``). ``self.bias`` is added and
    ``self.relu`` is applied unless ``is_final_layer``. When
    ``self.residual_rate`` is non-zero, the result is blended with a residual
    path built from the input.

    Neighbour index ``in_pn`` marks an empty (padded) neighbour slot: it gets
    mask 0 and selects an appended all-zero feature row.

    NOTE(review): assumes ``index_selection_nd(t, ids, dim)`` gathers ``t``
    along ``dim`` at ``ids`` and reshapes the result to ``ids``' shape plus
    the remaining dims of ``t`` -- confirm against its definition.

    Args:
        in_pc: Input features, (batch, in_pn, in_channel).
        raw_w_weights: Per-neighbour mixing coefficients,
            (out_pn, max_neighbor_num, weight_num). Padded neighbour slots
            are zeroed here by the neighbour mask.
        is_final_layer: If True, ``self.relu`` is not applied.
        b_max_pool: If True, reduce over the weight_num axis with max
            instead of sum.

    Returns:
        Output features, (batch, out_pn, out_channel).
    """
    batch = in_pc.shape[0]
    device = in_pc.device
    in_channel = self.in_channel
    out_channel = self.out_channel
    in_pn = self.in_point_num
    out_pn = self.out_point_num
    weight_num = self.weight_num  # M
    max_neighbor_num = self.max_neighbor_num  # N
    neighbor_id_lstlst = self.neighbor_id_lstlst

    # Index in_pn is the padding slot: mask it out with 0.
    pc_mask = torch.ones(in_pn + 1, dtype=torch.float32, device=device)
    pc_mask[in_pn] = 0
    # out_pn*max_neighbor_num entries: 1 for a real neighbour, 0 for padding.
    neighbor_mask_lst = index_selection_nd(pc_mask, neighbor_id_lstlst, 0).contiguous()

    # (out_pn, max_neighbor_num, weight_num) with padded slots zeroed.
    w_weights = raw_w_weights * neighbor_mask_lst.view(out_pn, max_neighbor_num, 1)

    # Append one zero row at index in_pn so padded neighbour ids gather zeros.
    # new_zeros keeps in_pc's dtype/device so the concatenation never mixes dtypes.
    # Result: (batch, in_pn + 1, in_channel).
    in_pc_pad = torch.cat((in_pc, in_pc.new_zeros(batch, 1, in_channel)), 1)
    in_neighbors = index_selection_nd(in_pc_pad, neighbor_id_lstlst, 1)

    # Mix neighbours into weight_num intermediate features:
    # (batch, out_pn, weight_num, in_channel).
    fuse_neighbors = torch.einsum('pnm,bpni->bpmi', [w_weights, in_neighbors])

    # One (out_channel x in_channel) matrix per intermediate feature; this is
    # a reshape only, no normalisation is applied.
    basis_weights = self.weights.view(weight_num, out_channel, in_channel)
    # (batch, out_pn, weight_num, out_channel).
    out_basis = torch.einsum('moi,bpmi->bpmo', [basis_weights, fuse_neighbors])

    # Reduce over the weight_num axis -> (batch, out_pn, out_channel).
    if b_max_pool:
        # Tensor.max(dim) returns a (values, indices) pair; keep only the values.
        out_pc = out_basis.max(2)[0]
    else:
        out_pc = out_basis.sum(2)
    out_pc = out_pc + self.bias

    if not is_final_layer:
        out_pc = self.relu(out_pc)  # self.relu is defined in the init function

    if self.residual_rate == 0:
        return out_pc

    # Residual path. First project channels if they differ; the projection is
    # linear without bias, so the zero padding row stays zero.
    if in_channel != out_channel:
        in_pc_pad = torch.einsum(
            'oi,bpi->bpo',
            [self.weight_res.view(out_channel, in_channel), in_pc_pad],
        )

    if in_pn == out_pn:
        # Same point count: use the (projected) input points directly,
        # dropping the padding row. Consumed only by the out-of-place add below.
        out_pc_res = in_pc_pad[:, :in_pn]
    else:
        res_neighbors = index_selection_nd(in_pc_pad, neighbor_id_lstlst, 1)
        # Non-negative learned per-neighbour weights, zeroed on padded slots and
        # normalised to sum to 1 per output point (1e-8 guards against 0/0).
        p_neighbors = torch.abs(self.p_neighbors) * neighbor_mask_lst
        p_neighbors_sum = p_neighbors.sum(1) + 1e-8  # (out_pn,)
        p_neighbors = p_neighbors / p_neighbors_sum.view(out_pn, 1)
        out_pc_res = torch.einsum('pn,bpno->bpo', [p_neighbors, res_neighbors])

    # Blend the conv output and the residual; the squared blend weights sum to 1.
    out_pc = out_pc * np.sqrt(1 - self.residual_rate) + out_pc_res * np.sqrt(self.residual_rate)
    return out_pc