def decode()

in python/dgllife/model/model_zoo/jtvae.py [0:0]


    def decode(self, mol_vec, prob_decode):
        device = mol_vec.device
        stack = []
        init_hidden = torch.zeros(1, self.hidden_size).to(device)
        zero_pad = torch.zeros(1, 1, self.hidden_size).to(device)

        # Root Prediction
        root_hidden = torch.cat([init_hidden, mol_vec], dim=1)
        root_hidden = F.relu(self.W(root_hidden))
        root_score = self.W_o(root_hidden)
        _, root_wid = torch.max(root_score, dim=1)
        root_wid = root_wid.item()

        root = mol_tree_node(smiles=self.vocab.get_smiles(root_wid), wid=root_wid, idx=0)
        stack.append((root, self.vocab.get_slots(root['wid'])))

        all_nodes = [root]
        h = {}
        for step in range(MAX_DECODE_LEN):
            node_x, fa_slot = stack[-1]
            cur_h_nei = [h[(node_y['idx'], node_x['idx'])] for node_y in node_x['neighbors']]
            if len(cur_h_nei) > 0:
                cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
            else:
                cur_h_nei = zero_pad

            cur_x = torch.LongTensor([node_x['wid']]).to(device)
            cur_x = self.embedding(cur_x)

            # Predict stop
            cur_h = cur_h_nei.sum(dim=1)
            stop_hidden = torch.cat([cur_x, cur_h, mol_vec], dim=1)
            stop_hidden = F.relu(self.U(stop_hidden))
            stop_score = torch.sigmoid(self.U_s(stop_hidden) * 20).squeeze()

            if prob_decode:
                backtrack = (torch.bernoulli(1.0 - stop_score.data)[0] == 1)
            else:
                backtrack = (stop_score.item() < 0.5)

            if not backtrack:  # Forward: Predict next clique
                new_h = gru_functional(cur_x, cur_h_nei, self.gru_update.W_z, self.W_r,
                                       self.gru_message.U_r, self.gru_update.W_h)
                pred_hidden = torch.cat([new_h, mol_vec], dim=1)
                pred_hidden = F.relu(self.W(pred_hidden))
                pred_score = torch.softmax(self.W_o(pred_hidden) * 20, dim=1)
                if prob_decode:
                    sort_wid = torch.multinomial(pred_score.data.squeeze(), 5)
                else:
                    _, sort_wid = torch.sort(pred_score, dim=1, descending=True)
                    sort_wid = sort_wid.data.squeeze()

                next_wid = None
                for wid in sort_wid[:5]:
                    slots = self.vocab.get_slots(wid)
                    node_y = mol_tree_node(smiles=self.vocab.get_smiles(wid))
                    if have_slots(fa_slot, slots) and can_assemble(node_x, node_y):
                        next_wid = wid
                        next_slots = slots
                        break

                if next_wid is None:
                    backtrack = True  # No more children can be added
                else:
                    node_y = mol_tree_node(smiles=self.vocab.get_smiles(next_wid),
                                           wid=next_wid, idx=step + 1, nbrs=[node_x])
                    h[(node_x['idx'], node_y['idx'])] = new_h[0]
                    stack.append((node_y, next_slots))
                    all_nodes.append(node_y)

            if backtrack:  # Backtrack, use if instead of else
                if len(stack) == 1:
                    break  # At root, terminate

                node_fa, _ = stack[-2]
                cur_h_nei = [h[(node_y['idx'], node_x['idx'])] for node_y in node_x['neighbors']
                             if node_y['idx'] != node_fa['idx']]
                if len(cur_h_nei) > 0:
                    cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
                else:
                    cur_h_nei = zero_pad

                new_h = gru_functional(cur_x, cur_h_nei, self.gru_update.W_z, self.W_r,
                                       self.gru_message.U_r, self.gru_update.W_h)
                h[(node_x['idx'], node_fa['idx'])] = new_h[0]
                node_fa['neighbors'].append(node_x)
                stack.pop()

        return root, all_nodes