in python/dgllife/model/model_zoo/jtvae.py [0:0]
def decode(self, tree_vec, mol_vec, prob_decode):
    """Decode a pair of latent vectors into a molecule.

    The junction tree is decoded first. Its nodes (substructures) are then
    assembled into a molecule by ``self.dfs_assemble``. When ``self.use_stereo``
    is set, the stereo candidate whose encoding is most cosine-similar to
    ``mol_vec`` is returned.

    Parameters
    ----------
    tree_vec : torch.Tensor
        Latent vector for the junction tree. It is passed to
        ``self.decoder.decode``, and its ``device`` is where all intermediate
        graphs and tensors are placed.
    mol_vec : torch.Tensor
        Latent vector for the molecular graph. It is passed to
        ``self.dfs_assemble`` and compared against the stereo candidate
        encodings.
    prob_decode
        Forwarded to both the tree decoder and ``self.dfs_assemble``.
        Presumably a bool toggling sampling vs. greedy decoding (TODO confirm).

    Returns
    -------
    str or None
        If stereo is disabled, the SMILES produced by ``Chem.MolToSmiles``.
        Otherwise, one of the candidates returned by ``decode_stereo``
        (presumably isomeric SMILES). ``None`` if assembly fails or the
        assembled molecule cannot be re-parsed by RDKit.
    """
    device = tree_vec.device
    pred_root, pred_nodes = self.decoder.decode(tree_vec, prob_decode)

    # Mark nid & is_leaf & atommap. nids are 1-based so that they line up with
    # ``global_amap`` below, whose slot 0 is an unused placeholder.
    for i, node in enumerate(pred_nodes):
        node['nid'] = i + 1
        node['is_leaf'] = (len(node['neighbors']) == 1)
        # Only non-leaf nodes get their atoms tagged with the node id.
        if len(node['neighbors']) > 1:
            set_atommap(node['mol'], node['nid'])

    # One directed edge node -> neighbor for every adjacency-list entry.
    src = []
    dst = []
    for node in pred_nodes:
        for nbr in node['neighbors']:
            src.append(node['idx'])
            dst.append(nbr['idx'])

    # Size the graph explicitly, for two reasons:
    # - Every node index gets a row in the node features below. This covers
    #   the single-node tree with no edges, and any node not touched by an edge.
    # - Including the edge endpoints keeps the count at least as large as what
    #   DGL would infer from the edges alone.
    num_nodes = max([node['idx'] for node in pred_nodes] + src + dst) + 1
    tree_graph = dgl.graph((src, dst), idtype=torch.int32, device=device,
                           num_nodes=num_nodes)

    node_ids = torch.tensor([node['idx'] for node in pred_nodes],
                            dtype=torch.long, device=device)
    node_wid = torch.tensor([node['wid'] for node in pred_nodes],
                            dtype=torch.long, device=device)
    # Rows for indices not listed in pred_nodes stay zero.
    tree_graph_x = torch.zeros(tree_graph.num_nodes(), self.hidden_size, device=device)
    tree_graph_x[node_ids] = self.embedding(node_wid)
    tree_graph.ndata['x'] = tree_graph_x
    tree_mess = self.jtnn(tree_graph)[0]
    tree_mess = self.edata_to_dict(tree_graph, tree_mess)

    # Assemble the molecule, starting from a copy of the root's substructure.
    cur_mol = copy_edit_mol(pred_root['mol'])
    global_amap = [{}] + [{} for _ in pred_nodes]
    # Slot 1 is seeded with an identity map over the root's atoms. This
    # presumably relies on pred_root being pred_nodes[0], i.e. nid 1
    # (TODO confirm).
    global_amap[1] = {atom.GetIdx(): atom.GetIdx() for atom in cur_mol.GetAtoms()}
    cur_mol = self.dfs_assemble(tree_mess, mol_vec, pred_nodes, cur_mol, global_amap, [],
                                pred_root, None, prob_decode)
    if cur_mol is None:
        return None

    cur_mol = cur_mol.GetMol()
    # Called without a node id. Presumably this resets the atom-map numbers
    # so they do not leak into the output SMILES (TODO confirm default).
    set_atommap(cur_mol)
    # Round-trip through SMILES. Chem.MolFromSmiles returns None for a
    # molecule RDKit cannot parse, which is treated as a decoding failure.
    cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
    if cur_mol is None:
        return None

    smiles2D = Chem.MolToSmiles(cur_mol)
    if not self.use_stereo:
        return smiles2D

    stereo_cands = decode_stereo(smiles2D)
    if len(stereo_cands) == 1:
        return stereo_cands[0]

    # Encode every stereo candidate and keep the one closest to mol_vec.
    stereo_cand_graphs = []
    for cand in stereo_cands:
        cand_mol = get_mol(cand)
        cg = mol_to_bigraph(cand_mol, node_featurizer=self.atom_featurizer,
                            edge_featurizer=self.bond_featurizer,
                            canonical_atom_order=False)
        # Prepend each edge's source-node features to its own edge features.
        cg.apply_edges(fn.copy_u('x', 'src'))
        cg.edata['x'] = torch.cat([cg.edata.pop('src'), cg.edata['x']], dim=1)
        stereo_cand_graphs.append(cg)
    stereo_cand_graphs = dgl.batch(stereo_cand_graphs).to(device)
    stereo_vecs = self.G_mean(self.mpn(stereo_cand_graphs))
    scores = nn.CosineSimilarity()(stereo_vecs, mol_vec)
    _, max_id = scores.max(dim=0)
    return stereo_cands[max_id.item()]