in python/dgllife/model/model_zoo/jtvae.py [0:0]
def dfs_assemble(self, tree_mess, mol_vec, all_nodes, cur_mol, global_amap, fa_amap,
                 cur_node, fa_node, prob_decode):
    """Recursively attach the subtree rooted at ``cur_node`` to ``cur_mol``.

    Candidate attachments of ``cur_node`` to its children are enumerated with
    ``enum_assemble``, scored against ``mol_vec``, and tried one after another
    in ranked (or sampled) order. The first candidate that yields a molecule
    surviving a SMILES round trip and whose non-leaf children can all be
    assembled recursively is returned.

    Parameters
    ----------
    tree_mess
        Forwarded unchanged to ``self.jtmpn`` (presumably the junction-tree
        messages -- confirm against the caller).
    mol_vec : torch.Tensor
        Latent vector the candidates are scored against. It is squeezed and
        used as the vector operand of ``torch.mv``, so it must reduce to 1-D.
    all_nodes
        Forwarded to ``self.jtmpn`` together with every candidate molecule.
    cur_mol
        Molecule assembled so far; copied with ``Chem.RWMol`` before each
        attempt so a failed attempt does not leak into the next one.
    global_amap
        Atom map indexed as ``global_amap[node_id][atom_idx]``. It is
        deep-copied per candidate and never modified in place.
    fa_amap
        Iterable of ``(nid, a1, a2)`` triples chosen by the parent call; only
        entries whose ``nid`` equals ``cur_node['nid']`` are used.
    cur_node : dict
        Node being expanded. Reads the keys ``'nid'`` and ``'neighbors'``;
        each neighbor must provide ``'nid'``, ``'mol'`` and ``'is_leaf'``.
    fa_node : dict or None
        Parent of ``cur_node`` in the traversal, or None at the root.
    prob_decode : bool
        If True, the candidate order is sampled without replacement from the
        softmax of the scores; otherwise candidates are tried by descending
        score.

    Returns
    -------
    The assembled molecule as returned by ``attach_mols`` / the recursive
    calls, or None if no candidate leads to a valid assembly.
    """
    fa_nid = fa_node['nid'] if fa_node is not None else -1
    prev_nodes = [fa_node] if fa_node is not None else []

    # Every neighbor except the parent; single-atom neighbors are handed to
    # enum_assemble first, followed by the others from largest to smallest.
    children = [nei for nei in cur_node['neighbors'] if nei['nid'] != fa_nid]
    neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
    neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
    singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
    neighbors = singletons + neighbors

    # Re-express the parent's atom-map entries for this node from this node's
    # side: the parent id comes first and the two atom indices are swapped.
    cur_amap = [(fa_nid, a2, a1) for nid, a1, a2 in fa_amap if nid == cur_node['nid']]
    cands = enum_assemble(cur_node, neighbors, prev_nodes, cur_amap)
    if len(cands) == 0:
        return None

    _, cand_mols, cand_amap = zip(*cands)
    cands = [(candmol, all_nodes, cur_node) for candmol in cand_mols]

    cand_vecs = self.jtmpn(cands, tree_mess, mol_vec.device)
    cand_vecs = self.G_mean(cand_vecs)
    mol_vec = mol_vec.squeeze()
    scores = torch.mv(cand_vecs, mol_vec) * 20

    if prob_decode:
        # torch.softmax requires an explicit dim. Normalising the 1-D score
        # vector directly also keeps `probs` 1-D when there is only a single
        # candidate, which torch.multinomial needs (it rejects 0-dim input).
        probs = torch.softmax(scores, dim=0) + 1e-5  # prevent prob = 0
        cand_idx = torch.multinomial(probs, probs.numel())
    else:
        _, cand_idx = torch.sort(scores, descending=True)

    backup_mol = Chem.RWMol(cur_mol)
    for idx in cand_idx.tolist():
        # Start every attempt from a fresh copy of the incoming molecule.
        cur_mol = Chem.RWMol(backup_mol)
        pred_amap = cand_amap[idx]
        new_global_amap = copy.deepcopy(global_amap)

        for nei_id, ctr_atom, nei_atom in pred_amap:
            if nei_id == fa_nid:
                continue
            new_global_amap[nei_id][nei_atom] = new_global_amap[cur_node['nid']][ctr_atom]

        # father is already attached
        cur_mol = attach_mols(cur_mol, children, [], new_global_amap)
        new_mol = cur_mol.GetMol()
        new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))
        if new_mol is None:
            # The SMILES round trip failed; try the next candidate.
            continue

        for nei_node in children:
            if nei_node['is_leaf']:
                continue
            cur_mol = self.dfs_assemble(tree_mess, mol_vec, all_nodes, cur_mol,
                                        new_global_amap, pred_amap, nei_node,
                                        cur_node, prob_decode)
            if cur_mol is None:
                # A child subtree could not be assembled under this
                # candidate; abandon it and move on to the next one.
                break
        else:
            # No break: every non-leaf child was assembled successfully.
            return cur_mol

    return None