in python/dgllife/model/model_zoo/jtvae.py [0:0]
def forward(self, tree_graphs, tree_vec):
    """Compute teacher-forced decoding losses and accuracies for a batch of trees.

    The trees are traversed in the edge order produced by ``dfs_order``. At each
    step, gated (GRU-style) message passing runs on the line graph of the trees.
    Two sets of training examples are collected along the way:

    * label prediction: which clique (``wid``) sits at the head of an edge;
    * stop prediction: a binary target per traversal step, plus one final
      "stop" (target 0) at every root.

    Parameters
    ----------
    tree_graphs : DGLGraph
        Batched trees. ``ndata['wid']`` must be present. If ``ndata['x']`` or
        ``edata['src_x']`` are missing, they are computed and stored on this
        graph *before* it is made local, so the caller's graph is mutated.
    tree_vec : torch.Tensor
        Tree-level latent vectors with one row per tree in the batch. They are
        broadcast to edges with ``dgl.broadcast_edges`` and concatenated
        row-wise with ``batch_size``-row tensors below.

    Returns
    -------
    pred_loss : torch.Tensor
        ``self.pred_loss`` over all label-prediction examples, divided by the
        batch size.
    stop_loss : torch.Tensor
        ``self.stop_loss`` over all stop-prediction examples, divided by the
        batch size.
    pred_acc : float
        Fraction of label predictions whose argmax matches the target.
    stop_acc : float
        Fraction of stop predictions (logit >= 0) matching the target.
    """
    device = tree_vec.device
    batch_size = tree_graphs.batch_size
    hidden_size = self.hidden_size

    root_ids = get_root_ids(tree_graphs)
    root_ids_dev = root_ids.to(device)

    # These features are cached on the caller's graph (before local_var()).
    if 'x' not in tree_graphs.ndata:
        tree_graphs.ndata['x'] = self.embedding(tree_graphs.ndata['wid'])
    if 'src_x' not in tree_graphs.edata:
        tree_graphs.apply_edges(fn.copy_u('x', 'src_x'))
    tree_graphs = tree_graphs.local_var()
    tree_graphs.apply_edges(func=lambda edges: {'dst_wid': edges.dst['wid']})

    # shared=True: edge features of tree_graphs become node features of the
    # line graph; writes to the line graph do not flow back (see 'h' below).
    line_tree_graphs = dgl.line_graph(tree_graphs, backtracking=False, shared=True)
    line_num_nodes = line_tree_graphs.num_nodes()
    line_tree_graphs.ndata.update({
        'src_x_r': self.W_r(line_tree_graphs.ndata['src_x']),
        # Exploit the fact that the reduce function is a sum of incoming messages,
        # and uncomputed messages are zero vectors.
        'h': torch.zeros(line_num_nodes, hidden_size, device=device),
        'vec': dgl.broadcast_edges(tree_graphs, tree_vec),
        'sum_h': torch.zeros(line_num_nodes, hidden_size, device=device),
        'sum_gated_h': torch.zeros(line_num_nodes, hidden_size, device=device),
    })

    # Inputs for label prediction (pred_*) and stop prediction (stop_*).
    pred_hiddens, pred_mol_vecs, pred_targets = [], [], []
    stop_hiddens, stop_targets = [], []

    # Predict the root clique from a zero hidden state.
    pred_hiddens.append(torch.zeros(batch_size, hidden_size, device=device))
    pred_targets.append(tree_graphs.ndata['wid'][root_ids_dev])
    pred_mol_vecs.append(tree_vec)

    # Traverse the trees and predict on children.
    for eid, p in dfs_order(tree_graphs, root_ids.to(dtype=tree_graphs.idtype)):
        eid = eid.to(device)
        p = p.to(device=device, dtype=tree_graphs.idtype)

        # Message passing excluding the target.
        line_tree_graphs.pull(v=eid, message_func=fn.copy_u('h', 'h_nei'),
                              reduce_func=fn.sum('h_nei', 'sum_h'))
        line_tree_graphs.pull(v=eid, message_func=self.gru_message,
                              reduce_func=fn.sum('m', 'sum_gated_h'))
        line_tree_graphs.apply_nodes(self.gru_update, v=eid)

        # Node aggregation including the target.
        # By construction, the edges of the raw graph follow the order of
        # (i1, j1), (j1, i1), (i2, j2), (j2, i2), ... and line-graph node ids
        # follow raw edge ids, so XOR-ing with 1 yields the reverse edge.
        eid = eid.long()
        reverse_eid = eid ^ 1
        line_ndata = line_tree_graphs.ndata
        cur_o = line_ndata['sum_h'][eid] + line_ndata['h'][reverse_eid]

        # Stop prediction: one example per traversed edge, target 1 - p.
        stop_hiddens.append(torch.cat(
            [line_ndata['src_x'][eid], cur_o, line_ndata['vec'][eid]], dim=1))
        stop_targets.append((1 - p).float())

        # Label prediction only on steps with p == 0.
        pred_list = eid[p == 0]
        if pred_list.numel() > 0:
            pred_mol_vecs.append(line_ndata['vec'][pred_list])
            pred_hiddens.append(line_ndata['h'][pred_list])
            pred_targets.append(line_ndata['dst_wid'][pred_list])

    # Last stop at each root (target 0).
    cur_x = tree_graphs.ndata['x'][root_ids_dev]
    # Copy line-graph hidden states back to the raw edges explicitly.
    tree_graphs.edata['h'] = line_tree_graphs.ndata['h']
    tree_graphs.pull(v=root_ids_dev.to(dtype=tree_graphs.idtype),
                     message_func=fn.copy_e('h', 'm'), reduce_func=fn.sum('m', 'cur_o'))
    stop_hiddens.append(torch.cat(
        [cur_x, tree_graphs.ndata['cur_o'][root_ids_dev], tree_vec], dim=1))
    stop_targets.append(torch.zeros(batch_size, device=device))

    # Predict next clique.
    pred_hiddens = torch.cat(pred_hiddens, dim=0)
    pred_mol_vecs = torch.cat(pred_mol_vecs, dim=0)
    pred_vecs = F.relu(self.W(torch.cat([pred_hiddens, pred_mol_vecs], dim=1)))
    pred_scores = self.W_o(pred_vecs)
    pred_targets = torch.cat(pred_targets, dim=0)
    pred_loss = self.pred_loss(pred_scores, pred_targets) / batch_size
    preds = pred_scores.argmax(dim=1)
    pred_acc = (preds == pred_targets).float().mean()

    # Predict stop.
    stop_hiddens = torch.cat(stop_hiddens, dim=0)
    stop_vecs = F.relu(self.U(stop_hiddens))
    # squeeze(-1), not squeeze(): with a single stop example (e.g. a batch
    # holding one single-node tree), squeeze() would produce a 0-d tensor
    # whose shape no longer matches the 1-d targets, and losses such as
    # BCEWithLogitsLoss reject mismatched input/target sizes.
    stop_scores = self.U_s(stop_vecs).squeeze(-1)
    # Concatenate once instead of converting a list of 0-d tensors element by
    # element; match the logits' dtype for the loss.
    stop_targets = torch.cat(stop_targets, dim=0).to(dtype=stop_scores.dtype)
    stop_loss = self.stop_loss(stop_scores, stop_targets) / batch_size
    stops = (stop_scores >= 0).float()
    stop_acc = (stops == stop_targets).float().mean()

    return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()