in python/dgllife/utils/jtvae/chemutils.py [0:0]
def tree_decomp(mol, mst_max_weight=100):
    """Tree decomposition of a molecule for junction tree construction.

    The molecule is first split into clusters:

    * every bond that is not in a ring becomes a 2-atom cluster, and
    * every ring returned by ``Chem.GetSymmSSSR`` becomes a cluster.

    Ring clusters that share more than two atoms are merged, in a single
    pass (see the note in the body). Some atoms shared by several clusters
    are then added as singleton (1-atom) clusters. Candidate edges between
    clusters that share an atom are scored, and a maximum-score spanning
    tree over those candidates gives the junction tree edges.

    A single-atom molecule is special-cased and returns ``([[0]], [])``.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        A molecule.
    mst_max_weight : int
        Offset used to turn edge scores into spanning-tree weights.
        A candidate edge with score ``s`` gets weight
        ``mst_max_weight - s``. A *minimum* spanning tree over these weights
        therefore maximizes the total score. Scores are 1 for edges to a
        singleton of branching bonds, the number of atoms shared by two
        clusters, or ``mst_max_weight - 1`` for edges to a singleton joining
        several large rings.
        NOTE(review): this presumably has to exceed every shared-atom
        score. A score equal to it gives weight 0, which scipy's
        ``minimum_spanning_tree`` treats as "no edge". Confirm before
        lowering it.

    Returns
    -------
    list
        Clusters. Each element is a list of int,
        representing the atoms that constitute the cluster. For merged ring
        clusters, the atom order comes from ``list(set(...))`` and is not
        sorted.
    list
        Edges between the clusters. Each element is a 2-tuple of cluster IDs.
        The list is empty when no two clusters share an atom. The IDs are
        read from ``junc_tree.nonzero()``, so they are likely NumPy integer
        scalars rather than Python ints.
    """
    n_atoms = mol.GetNumAtoms()
    if n_atoms == 1:
        # A lone atom is its own (only) cluster and has no edges
        return [[0]], []
    cliques = []
    # Every bond that is not part of a ring becomes its own 2-atom cluster
    for bond in mol.GetBonds():
        a1 = bond.GetBeginAtom().GetIdx()
        a2 = bond.GetEndAtom().GetIdx()
        if not bond.IsInRing():
            cliques.append([a1, a2])
    # Add one cluster per ring from RDKit's symmetrized SSSR, each
    # represented by a list of the IDs of the atoms in the ring
    ssr = [list(x) for x in Chem.GetSymmSSSR(mol)]
    cliques.extend(ssr)
    # nei_list[atom] holds the IDs of the clusters (non-ring bonds / rings)
    # that contain the atom. IDs are appended in increasing order.
    nei_list = [[] for _ in range(n_atoms)]
    for i in range(len(cliques)):
        for atom in cliques[i]:
            nei_list[atom].append(i)
    # Merge ring clusters that share more than two atoms. Clusters with
    # <= 2 atoms are non-ring bonds and are never merged. A ring merged into
    # cliques[i] is emptied in place, so cluster IDs stay valid during the
    # loop; empty clusters are removed afterwards.
    # NOTE: nei_list is not refreshed inside this loop, and the result
    # depends on statement order. The first merge for a given i extends the
    # very list that `for atom in cliques[i]` is iterating, so that merged
    # ring's atoms are also visited. Every merge then rebinds cliques[i] to
    # a new list, so atoms added by later merges are not visited. Reordering
    # these statements may change the resulting decompositions.
    for i in range(len(cliques)):
        if len(cliques[i]) <= 2:
            continue
        for atom in cliques[i]:
            for j in nei_list[atom]:
                if i >= j or len(cliques[j]) <= 2:
                    continue
                inter = set(cliques[i]) & set(cliques[j])
                if len(inter) > 2:
                    cliques[i].extend(cliques[j])
                    cliques[i] = list(set(cliques[i]))
                    cliques[j] = []
    # Remove merged-away (emptied) clusters; this renumbers the remaining ones
    cliques = [c for c in cliques if len(c) > 0]
    # Rebuild the atom -> cluster membership using the new cluster IDs
    nei_list = [[] for _ in range(n_atoms)]
    for i in range(len(cliques)):
        for atom in cliques[i]:
            nei_list[atom].append(i)
    # Score candidate edges between clusters that share an atom, adding
    # singleton clusters where needed. `edges` maps (c1, c2), with c1 < c2,
    # to a score. Higher scores are preferred by the spanning tree below.
    edges = defaultdict(int)
    for atom in range(n_atoms):
        # An atom that belongs to at most one cluster links nothing
        if len(nei_list[atom]) <= 1:
            continue
        cnei = nei_list[atom]
        # Clusters through this atom that are non-ring bonds (2 atoms)
        bonds = [c for c in cnei if len(cliques[c]) == 2]
        # Clusters through this atom with more than four atoms: rings of
        # size >= 5 or merged ring systems
        rings = [c for c in cnei if len(cliques[c]) > 4]
        # In general, if len(cnei) >= 3, a singleton should be added,
        # but 1 bond + 2 ring is currently not dealt with.
        if len(bonds) > 2 or (len(bonds) == 2 and len(cnei) > 2):
            # The atom joins more than two bonds, or two bonds plus another
            # cluster. Add it as a singleton cluster and link it to every
            # cluster containing the atom with score 1. The new ID c2 is
            # larger than every c1.
            cliques.append([atom])
            c2 = len(cliques) - 1
            for c1 in cnei:
                edges[(c1, c2)] = 1
        elif len(rings) > 2:
            # Multiple (n>2) complex rings: add a singleton hub and link it
            # with score mst_max_weight - 1, i.e. tree weight 1
            cliques.append([atom])
            c2 = len(cliques) - 1
            for c1 in cnei:
                edges[(c1, c2)] = mst_max_weight - 1
        else:
            # Link each pair of clusters through this atom directly. The
            # score is the number of atoms they share; the largest value seen
            # is kept.
            for i in range(len(cnei)):
                for j in range(i + 1, len(cnei)):
                    c1, c2 = cnei[i], cnei[j]
                    inter = set(cliques[c1]) & set(cliques[c2])
                    if edges[(c1, c2)] < len(inter):
                        # cnei[i] < cnei[j] by construction
                        edges[(c1, c2)] = len(inter)
    # Turn each score v into a weight (mst_max_weight - v), so that the
    # minimum spanning tree below maximizes the total score.
    # Each entry becomes (c1, c2, weight).
    edges = [u + (mst_max_weight - v,) for u, v in edges.items()]
    if len(edges) == 0:
        return cliques, edges
    # Compute Maximum Spanning Tree (a minimum spanning tree over the
    # inverted weights). All keys have c1 < c2, so the matrix only has
    # entries above the diagonal.
    row, col, data = zip(*edges)
    n_clique = len(cliques)
    clique_graph = csr_matrix((data, (row, col)), shape=(n_clique, n_clique))
    junc_tree = minimum_spanning_tree(clique_graph)
    # Each nonzero entry of the resulting tree is one junction tree edge
    row, col = junc_tree.nonzero()
    edges = [(row[i], col[i]) for i in range(len(row))]
    return (cliques, edges)