def forward()

in python/dgllife/model/model_zoo/jtvae.py [0:0]

Teacher-forced decoding of a batch of junction trees: each ground-truth tree is
traversed in DFS order, and at every step the decoder predicts the label of the
next clique and whether to keep expanding the current node. The method returns
the clique-prediction loss, the stop-prediction loss, and their accuracies.


    def forward(self, tree_graphs, tree_vec):
        device = tree_vec.device
        batch_size = tree_graphs.batch_size

        # ID of the root node of each tree in the batch
        root_ids = get_root_ids(tree_graphs)

        # Node features are clique-word embeddings; each edge keeps a copy of its
        # source node's features ('src_x') and records its destination word ID
        # ('dst_wid'), which later serves as the label-prediction target.
        if 'x' not in tree_graphs.ndata:
            tree_graphs.ndata['x'] = self.embedding(tree_graphs.ndata['wid'])
        if 'src_x' not in tree_graphs.edata:
            tree_graphs.apply_edges(fn.copy_u('x', 'src_x'))
        tree_graphs = tree_graphs.local_var()
        tree_graphs.apply_edges(func=lambda edges: {'dst_wid': edges.dst['wid']})

        # Each node of the line graph stands for one directed edge (message) of the
        # tree, so the GRU-style message passing below runs on the line graph
        # (demonstrated in the sketches after this listing).
        line_tree_graphs = dgl.line_graph(tree_graphs, backtracking=False, shared=True)
        line_num_nodes = line_tree_graphs.num_nodes()
        line_tree_graphs.ndata.update({
            'src_x_r': self.W_r(line_tree_graphs.ndata['src_x']),
            # Exploit the fact that the reduce function is a sum of incoming messages,
            # and uncomputed messages are zero vectors.
            'h': torch.zeros(line_num_nodes, self.hidden_size).to(device),
            # Broadcast each tree's latent vector to all of its edges
            'vec': dgl.broadcast_edges(tree_graphs, tree_vec),
            'sum_h': torch.zeros(line_num_nodes, self.hidden_size).to(device),
            'sum_gated_h': torch.zeros(line_num_nodes, self.hidden_size).to(device)
        })

        # input tensors for stop prediction (p) and label prediction (q)
        pred_hiddens, pred_mol_vecs, pred_targets = [], [], []
        stop_hiddens, stop_targets = [], []

        # Predict the root clique; the root has no incoming message, so a zero hidden state is used
        pred_hiddens.append(torch.zeros(batch_size, self.hidden_size).to(device))
        pred_targets.append(tree_graphs.ndata['wid'][root_ids.to(device)])
        pred_mol_vecs.append(tree_vec)

        # Traverse the tree and predict on children
        for eid, p in dfs_order(tree_graphs, root_ids.to(dtype=tree_graphs.idtype)):
            eid = eid.to(device)
            p = p.to(device=device, dtype=tree_graphs.idtype)

            # Message passing excluding the target
            line_tree_graphs.pull(v=eid, message_func=fn.copy_u('h', 'h_nei'),
                                  reduce_func=fn.sum('h_nei', 'sum_h'))
            line_tree_graphs.pull(v=eid, message_func=self.gru_message,
                                  reduce_func=fn.sum('m', 'sum_gated_h'))
            line_tree_graphs.apply_nodes(self.gru_update, v=eid)

            # Node aggregation including the target
            # By construction, the edges of the raw graph follow the order of
            # (i1, j1), (j1, i1), (i2, j2), (j2, i2), ... The order of the nodes
            # in the line graph corresponds to the order of the edges in the raw graph,
            # so XOR-ing an edge ID with 1 gives the ID of its reverse edge (see the
            # sketches after this listing).
            eid = eid.long()
            reverse_eid = torch.bitwise_xor(eid, torch.tensor(1).to(device))
            cur_o = line_tree_graphs.ndata['sum_h'][eid] + \
                    line_tree_graphs.ndata['h'][reverse_eid]

            # Gather targets: p == 0 marks a forward (child-expanding) edge, for which
            # a clique label is predicted; the stop target is 1 when expanding and 0
            # when backtracking.
            mask = (p == torch.tensor(0).to(device))
            pred_list = eid[mask]
            stop_target = torch.tensor(1).to(device) - p

            # Hidden states for stop prediction
            stop_hidden = torch.cat([line_tree_graphs.ndata['src_x'][eid],
                                     cur_o, line_tree_graphs.ndata['vec'][eid]], dim=1)
            stop_hiddens.append(stop_hidden)
            stop_targets.extend(stop_target)

            # Hidden states for clique prediction
            if len(pred_list) > 0:
                pred_mol_vecs.append(line_tree_graphs.ndata['vec'][pred_list])
                pred_hiddens.append(line_tree_graphs.ndata['h'][pred_list])
                pred_targets.append(line_tree_graphs.ndata['dst_wid'][pred_list])

        # Last stop at root
        root_ids = root_ids.to(device)
        cur_x = tree_graphs.ndata['x'][root_ids]
        tree_graphs.edata['h'] = line_tree_graphs.ndata['h']
        tree_graphs.pull(v=root_ids.to(dtype=tree_graphs.idtype),
                         message_func=fn.copy_e('h', 'm'), reduce_func=fn.sum('m', 'cur_o'))
        stop_hidden = torch.cat([cur_x, tree_graphs.ndata['cur_o'][root_ids], tree_vec], dim=1)
        stop_hiddens.append(stop_hidden)
        stop_targets.extend(torch.zeros(batch_size).to(device))

        # Predict next clique
        pred_hiddens = torch.cat(pred_hiddens, dim=0)
        pred_mol_vecs = torch.cat(pred_mol_vecs, dim=0)
        pred_vecs = torch.cat([pred_hiddens, pred_mol_vecs], dim=1)
        pred_vecs = F.relu(self.W(pred_vecs))
        pred_scores = self.W_o(pred_vecs)
        pred_targets = torch.cat(pred_targets, dim=0)

        pred_loss = self.pred_loss(pred_scores, pred_targets) / batch_size
        _, preds = torch.max(pred_scores, dim=1)
        pred_acc = torch.eq(preds, pred_targets).float()
        pred_acc = torch.sum(pred_acc) / pred_targets.nelement()

        # Predict stop
        stop_hiddens = torch.cat(stop_hiddens, dim=0)
        stop_vecs = F.relu(self.U(stop_hiddens))
        stop_scores = self.U_s(stop_vecs).squeeze()
        stop_targets = torch.Tensor(stop_targets).to(device)

        stop_loss = self.stop_loss(stop_scores, stop_targets) / batch_size
        stops = torch.ge(stop_scores, 0).float()
        stop_acc = torch.eq(stops, stop_targets).float()
        stop_acc = torch.sum(stop_acc) / stop_targets.nelement()

        return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()
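
Two DGL idioms in the listing above are easy to miss. dgl.line_graph(...,
backtracking=False, shared=True) turns every directed tree edge into a line-graph
node while dropping the connection from an edge to its own reverse, and
dgl.broadcast_edges repeats each tree's latent vector once per edge so it can be
stored as a line-graph node feature. Below is a minimal sketch of both calls on a
made-up two-tree batch (the toy graphs and the hidden size of 8 are illustrative
assumptions, not part of the model):

    import dgl
    import torch

    # Two toy trees whose edges are added in inverse pairs, then batched.
    t1 = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 0])))              # 2 edges
    t2 = dgl.graph((torch.tensor([0, 1, 1, 2]), torch.tensor([1, 0, 2, 1])))  # 4 edges
    bg = dgl.batch([t1, t2])

    # Each line-graph node stands for one directed edge of the batched tree graph.
    lg = dgl.line_graph(bg, backtracking=False, shared=True)
    assert lg.num_nodes() == bg.num_edges()

    # broadcast_edges repeats a per-graph vector once for every edge of that graph,
    # which is how tree_vec becomes the 'vec' node feature of the line graph.
    tree_vec = torch.randn(2, 8)                  # one latent vector per tree
    edge_vec = dgl.broadcast_edges(bg, tree_vec)  # shape: (total number of edges, 8)
    assert edge_vec.shape == (bg.num_edges(), 8)
    assert torch.equal(edge_vec[:2], tree_vec[0].expand(2, 8))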
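
Likewise, the reverse_eid lookup works only because of the edge-pairing convention
noted in the comments: edges are added as consecutive inverse pairs, so flipping the
lowest bit of an edge ID yields the ID of its reverse edge. A small sketch of that
invariant on a hypothetical toy tree:

    import dgl
    import torch

    # Edges added in inverse pairs: (0,1), (1,0), (1,2), (2,1) -> edge IDs 0, 1, 2, 3
    src = torch.tensor([0, 1, 1, 2])
    dst = torch.tensor([1, 0, 2, 1])
    g = dgl.graph((src, dst))

    eid = torch.arange(g.num_edges())
    reverse_eid = torch.bitwise_xor(eid, torch.tensor(1))  # pairs 0<->1, 2<->3, ...

    # The source of each reversed edge is the destination of the original edge.
    u, v = g.edges()
    assert torch.equal(u[reverse_eid], v)
    assert torch.equal(v[reverse_eid], u)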