in GraphAutoEncoder/graphVAESSW.py [0:0]
def __init__(self, in_channel, out_channel, weight_num, in_point_num,
             connection_info, b_Perpt_bias=True, residual_rate=0.0):
    """Create the buffers and learnable parameters of the layer.

    Args:
        in_channel: Input feature width; stored and used to size the
            weight tensors.
        out_channel: Output feature width; stored and used to size the
            weight, bias and residual tensors.
        weight_num: Number of rows of the ``weights`` parameter.
        in_point_num: Number of input points; stored and compared with
            the number of output points.
        connection_info: NumPy array with one row per output point.
            Column 0 is read as that point's neighbor count. The
            remaining columns are reshaped into pairs and the first
            entry of each pair is kept as a neighbor index (the second
            entry is presumably a distance, going by the original
            variable names -- confirm against the producer).
        b_Perpt_bias: If true, ``bias`` has shape
            ``(out_point_num, out_channel)``; otherwise ``(out_channel,)``.
        residual_rate: Stored as-is. When greater than 0, additional
            residual parameters are created where point counts or
            channel widths differ.

    Prints a one-line summary of the layer sizes as a side effect.
    """
    super(LASMConvssw, self).__init__()
    self.relu = nn.ELU()

    # One row of connection_info per output point.
    num_out_points = connection_info.shape[0]

    self.in_channel = in_channel
    self.out_channel = out_channel
    self.weight_num = weight_num
    self.in_point_num = in_point_num
    self.out_point_num = num_out_points

    # ---- connectivity buffers ---------------------------------------
    # Column 0: neighbor count per output point, kept as float32.
    counts_np = connection_info[:, 0].astype(np.float32)
    neighbor_counts = torch.from_numpy(counts_np).float()
    self.register_buffer("neighbor_num_lst", neighbor_counts)

    # Remaining columns form pairs; only the first entry of each pair
    # (the neighbor index) is kept.
    pair_table = connection_info[:, 1:].reshape((num_out_points, -1, 2))
    neighbor_ids = torch.from_numpy(pair_table[:, :, 0]).long()
    self.register_buffer("neighbor_id_lstlst", neighbor_ids)

    neighbor_slots = neighbor_ids.shape[1]
    self.max_neighbor_num = neighbor_slots

    mean_neighbors = round(neighbor_counts.mean().item())
    self.avg_neighbor_num = mean_neighbors

    # ---- convolution parameters -------------------------------------
    self.register_parameter(
        "weights",
        nn.Parameter(torch.randn(weight_num, out_channel * in_channel)),
    )

    # Zero-initialised bias: per output point when requested, else shared.
    bias_shape = (num_out_points, out_channel) if b_Perpt_bias else (out_channel,)
    self.register_parameter("bias", nn.Parameter(torch.zeros(*bias_shape)))

    # ---- residual parameters ----------------------------------------
    self.residual_rate = residual_rate
    if self.residual_rate > 0:
        if num_out_points != in_point_num:
            # Per-neighbor weights, scaled by the rounded mean neighbor
            # count.
            neighbor_init = torch.randn(num_out_points, neighbor_slots) / mean_neighbors
            self.register_parameter("p_neighbors", nn.Parameter(neighbor_init))

        # NOTE(review): assumed to be gated by residual_rate > 0 like
        # the block above -- confirm the nesting.
        if out_channel != in_channel:
            # Flattened out_channel * in_channel weights for the
            # residual path, scaled by 1 / out_channel.
            res_init = torch.randn(1, out_channel * in_channel) / out_channel
            self.register_parameter("weight_res", nn.Parameter(res_init))

    print("in_channel", in_channel,
          "out_channel", out_channel,
          "in_point_num", in_point_num,
          "out_point_num", num_out_points,
          "weight_num", weight_num,
          "max_neighbor_num", neighbor_slots)