def dfs_assemble()

in python/dgllife/model/model_zoo/jtvae.py [0:0]


    def dfs_assemble(self, tree_mess, mol_vec, all_nodes, cur_mol, global_amap, fa_amap,
                     cur_node, fa_node, prob_decode):
        """Recursively attach the children of ``cur_node`` to ``cur_mol``.

        Candidate attachments of the children are enumerated with
        ``enum_assemble``, scored against ``mol_vec`` and tried in ranked
        (or sampled) order. The first candidate that yields a valid molecule
        and for which every non-leaf child can itself be assembled wins.

        Parameters
        ----------
        tree_mess
            Passed through unchanged to ``self.jtmpn`` and to recursive calls.
        mol_vec : torch.Tensor
            Vector the candidates are scored against; it is squeezed and must
            be 1-D afterwards, with ``mol_vec.device`` used for ``self.jtmpn``.
        all_nodes
            Passed through to ``self.jtmpn`` together with each candidate.
        cur_mol
            Molecule built so far; it is copied with ``Chem.RWMol`` before
            every attempt, so failed attempts do not leak into later ones.
        global_amap
            Indexable by node id, each entry mapping a node-local atom index
            to an atom index in ``cur_mol``. It is deep-copied per attempt
            and never modified in place.
        fa_amap
            Iterable of ``(nid, a1, a2)`` triples chosen for the father node.
        cur_node : dict
            Node being expanded; reads the keys ``'nid'`` and ``'neighbors'``
            (each neighbor providing ``'nid'``, ``'mol'`` and ``'is_leaf'``).
        fa_node : dict or None
            Father of ``cur_node``, or ``None`` at the root.
        prob_decode : bool
            If true, the order in which candidates are tried is sampled from
            the softmax of the scores; otherwise candidates are tried in
            descending score order.

        Returns
        -------
        The assembled molecule, or ``None`` if no candidate leads to a valid
        assembly of this subtree.
        """
        fa_nid = fa_node['nid'] if fa_node is not None else -1
        prev_nodes = [fa_node] if fa_node is not None else []

        # Every neighbor except the father is a child. Single-atom children go
        # first, followed by the remaining children from largest to smallest.
        children = [nei for nei in cur_node['neighbors'] if nei['nid'] != fa_nid]
        neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
        neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
        singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
        neighbors = singletons + neighbors

        # Re-express the father's choices that concern this node from this
        # node's point of view: the two atom indices swap places.
        cur_amap = [(fa_nid, a2, a1) for nid, a1, a2 in fa_amap if nid == cur_node['nid']]
        cands = enum_assemble(cur_node, neighbors, prev_nodes, cur_amap)
        if not cands:
            return None
        _, cand_mols, cand_amap = zip(*cands)

        cands = [(candmol, all_nodes, cur_node) for candmol in cand_mols]

        cand_vecs = self.jtmpn(cands, tree_mess, mol_vec.device)
        cand_vecs = self.G_mean(cand_vecs)
        mol_vec = mol_vec.squeeze()
        scores = torch.mv(cand_vecs, mol_vec) * 20

        if prob_decode:
            # ``torch.softmax`` requires an explicit ``dim``. ``scores`` is
            # already 1-D (output of ``torch.mv``), so normalising over dim 0
            # keeps ``probs`` 1-D even when there is a single candidate, which
            # is what ``torch.multinomial`` needs.
            probs = torch.softmax(scores, dim=0) + 1e-5  # prevent prob = 0
            # Drawing ``numel`` samples without replacement yields a random
            # ordering of all candidates.
            cand_idx = torch.multinomial(probs, probs.numel())
        else:
            _, cand_idx = torch.sort(scores, descending=True)

        backup_mol = Chem.RWMol(cur_mol)
        for idx in cand_idx.tolist():
            # Start every attempt from a pristine copy of the incoming state.
            cur_mol = Chem.RWMol(backup_mol)
            pred_amap = cand_amap[idx]
            new_global_amap = copy.deepcopy(global_amap)

            for nei_id, ctr_atom, nei_atom in pred_amap:
                if nei_id == fa_nid:
                    continue
                new_global_amap[nei_id][nei_atom] = new_global_amap[cur_node['nid']][ctr_atom]

            # father is already attached
            cur_mol = attach_mols(cur_mol, children, [], new_global_amap)
            new_mol = cur_mol.GetMol()
            # SMILES round trip: a ``None`` result means this candidate did
            # not produce a parsable molecule.
            new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))

            if new_mol is None:
                continue

            for nei_node in children:
                if nei_node['is_leaf']:
                    continue
                cur_mol = self.dfs_assemble(tree_mess, mol_vec, all_nodes, cur_mol,
                                            new_global_amap, pred_amap, nei_node,
                                            cur_node, prob_decode)
                if cur_mol is None:
                    # This child cannot be assembled under the current
                    # candidate; fall through to the next candidate.
                    break
            else:
                return cur_mol

        return None