def decode()

in python/dgllife/model/model_zoo/jtvae.py [0:0]


    def decode(self, tree_vec, mol_vec, prob_decode):
        """Decode a molecule from a tree vector and a molecule vector.

        The junction tree is first predicted by ``self.decoder``, then message
        passing is run over it with ``self.jtnn`` and the molecule is assembled
        by ``self.dfs_assemble`` starting from the predicted root. When
        ``self.use_stereo`` is set, the stereo candidate whose embedding is most
        cosine-similar to ``mol_vec`` is returned.

        Parameters
        ----------
        tree_vec : torch.Tensor
            Input for the tree decoder; its ``device`` is used for every tensor
            and graph created here. NOTE(review): expected shape is not visible
            in this block -- confirm against the caller.
        mol_vec : torch.Tensor
            Forwarded to ``self.dfs_assemble`` and compared against the stereo
            candidate embeddings. NOTE(review): shape not visible here.
        prob_decode
            Forwarded unchanged to ``self.decoder.decode`` and
            ``self.dfs_assemble``; presumably a flag selecting stochastic
            decoding -- confirm against those methods.

        Returns
        -------
        str or None
            A SMILES string, or None when assembly fails or the assembled
            molecule does not survive the RDKit SMILES round trip.
        """
        device = tree_vec.device
        pred_root, pred_nodes = self.decoder.decode(tree_vec, prob_decode)

        # Annotate every predicted node with a 1-based id and a leaf flag;
        # only nodes with more than one neighbor get atom-map numbers.
        for nid, node in enumerate(pred_nodes, start=1):
            num_nbrs = len(node['neighbors'])
            node['nid'] = nid
            node['is_leaf'] = num_nbrs == 1
            if num_nbrs > 1:
                set_atommap(node['mol'], nid)

        # One directed edge per (node, neighbor) pair, in node order.
        edge_src, edge_dst = [], []
        for node in pred_nodes:
            node_idx = node['idx']
            nbr_indices = [nbr['idx'] for nbr in node['neighbors']]
            edge_src.extend([node_idx] * len(nbr_indices))
            edge_dst.extend(nbr_indices)

        graph_kwargs = {'idtype': torch.int32, 'device': device}
        if not edge_src:
            # An empty edge list carries no node ids, so the node count has to
            # be passed explicitly.
            graph_kwargs['num_nodes'] = max(node['idx'] + 1 for node in pred_nodes)
        tree_graph = dgl.graph((edge_src, edge_dst), **graph_kwargs)

        # Scatter the word embeddings into the rows given by each node's idx.
        node_ids = torch.LongTensor([node['idx'] for node in pred_nodes]).to(device)
        node_wid = torch.LongTensor([node['wid'] for node in pred_nodes]).to(device)
        node_feats = torch.zeros(tree_graph.num_nodes(), self.hidden_size, device=device)
        node_feats[node_ids] = self.embedding(node_wid)
        tree_graph.ndata['x'] = node_feats

        edge_messages = self.jtnn(tree_graph)[0]
        tree_mess = self.edata_to_dict(tree_graph, edge_messages)

        # Start assembly from the root; slot 0 of the atom map stays empty
        # because node ids are 1-based, and the root (nid 1) maps onto itself.
        cur_mol = copy_edit_mol(pred_root['mol'])
        global_amap = [{} for _ in range(len(pred_nodes) + 1)]
        root_atom_ids = (atom.GetIdx() for atom in cur_mol.GetAtoms())
        global_amap[1] = {atom_id: atom_id for atom_id in root_atom_ids}

        cur_mol = self.dfs_assemble(tree_mess, mol_vec, pred_nodes, cur_mol, global_amap, [],
                                    pred_root, None, prob_decode)
        if cur_mol is None:
            return None

        # Finalize the molecule and round-trip it through SMILES.
        cur_mol = cur_mol.GetMol()
        set_atommap(cur_mol)
        cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
        if cur_mol is None:
            return None

        smiles2D = Chem.MolToSmiles(cur_mol)
        if not self.use_stereo:
            return smiles2D

        stereo_cands = decode_stereo(smiles2D)
        if len(stereo_cands) == 1:
            return stereo_cands[0]

        # Featurize each stereo candidate; edge features become the source
        # node features concatenated with the original edge features.
        cand_graphs = []
        for cand_smiles in stereo_cands:
            cand_mol = get_mol(cand_smiles)
            cand_graph = mol_to_bigraph(cand_mol, node_featurizer=self.atom_featurizer,
                                        edge_featurizer=self.bond_featurizer,
                                        canonical_atom_order=False)
            cand_graph.apply_edges(fn.copy_u('x', 'src'))
            src_feats = cand_graph.edata.pop('src')
            cand_graph.edata['x'] = torch.cat([src_feats, cand_graph.edata['x']], dim=1)
            cand_graphs.append(cand_graph)
        batched_cands = dgl.batch(cand_graphs).to(device)

        # Rank candidates by cosine similarity to the target molecule vector.
        cand_vecs = self.G_mean(self.mpn(batched_cands))
        scores = nn.CosineSimilarity()(cand_vecs, mol_vec)
        best_id = scores.max(dim=0)[1]
        return stereo_cands[best_id.item()]