in python/dgllife/data/uspto.py [0:0]
def pre_process_one_reaction(info, num_candidate_bond_changes, max_num_bond_changes,
                             max_num_change_combos, mode):
    """Pre-process one reaction for candidate ranking.

    The candidate bond changes are first filtered: a candidate is dropped when
    its change type equals the value already recorded for that atom pair in
    ``reactant_info['pair_to_bond_val']``. At most
    ``num_candidate_bond_changes`` candidates are kept. Then all combinations of
    1 to ``max_num_bond_changes`` kept candidates are enumerated. A combo is
    retained when its changes are connected (changes are adjacent if they share
    an atom; connectivity is decided by ``is_connected_change_combo``) and it
    passes ``is_valid_combo``.

    In training mode, the retained combos are shuffled, and the combo matching
    the real bond changes is moved to the front. If no combo matches, the real
    bond changes are inserted at the front with score 0.0. If that first combo
    yields a non-empty product SMILES, later combos whose product SMILES is
    empty or already seen are dropped.

    Parameters
    ----------
    info : 4-tuple
        * candidate_bond_changes : list of 4-tuples
          Candidate bond changes ``(atom1, atom2, change_type, score)`` for
          the reaction.
        * real_bond_changes : list of 3-tuples
          The real bond changes ``(atom1, atom2, change_type)`` for the reaction.
        * reactant_mol : rdkit.Chem.rdchem.Mol
          RDKit molecule instance for reactants.
        * product_mol : rdkit.Chem.rdchem.Mol
          RDKit molecule instance for product. Only used when ``mode`` is
          ``'train'``.
    num_candidate_bond_changes : int
        Maximum number of candidate bond changes kept after filtering.
    max_num_bond_changes : int
        Maximum number of bond changes per combo.
    max_num_change_combos : int
        Maximum number of bond change combos returned for the reaction.
    mode : str
        Whether the dataset is to be used for training ('train'), validation
        ('val') or test ('test').

    Returns
    -------
    valid_candidate_combos : list
        valid_candidate_combos[i] gives a list of 4-tuples, which is the i-th
        valid combo of candidate bond changes for the reaction. In training
        mode, the first combo corresponds to the real bond changes.
    candidate_bond_changes : list of 4-tuples
        Refined candidate bond changes considered for combos.
    reactant_info : dict
        Reaction-related information of reactants, as returned by
        ``bookkeep_reactant``.
    """
    assert mode in ['train', 'val', 'test'], \
        "Expect mode to be 'train' or 'val' or 'test', got {}".format(mode)
    candidate_bond_changes_, real_bond_changes, reactant_mol, product_mol = info
    candidate_pairs = [(atom1, atom2) for (atom1, atom2, _, _)
                       in candidate_bond_changes_]
    reactant_info = bookkeep_reactant(reactant_mol, candidate_pairs)
    is_train = mode == 'train'
    # Product bookkeeping is only needed for the train-only pruning below.
    product_info = bookkeep_product(product_mol) if is_train else None

    # Drop candidates whose change type equals the value already recorded for
    # that pair in the reactants, keeping at most num_candidate_bond_changes.
    pair_to_bond_val = reactant_info['pair_to_bond_val']
    candidate_bond_changes = []
    for (atom1, atom2, change_type, score) in candidate_bond_changes_:
        pair = (atom1, atom2)
        if pair in pair_to_bond_val and pair_to_bond_val[pair] == change_type:
            continue
        candidate_bond_changes.append((atom1, atom2, change_type, score))
        if len(candidate_bond_changes) == num_candidate_bond_changes:
            break

    # Two bond changes are adjacent if they have an atom in common.
    num_candidates = len(candidate_bond_changes)
    cand_change_adj = np.eye(num_candidates, dtype=bool)
    for i in range(num_candidates):
        atom1_1, atom1_2, _, _ = candidate_bond_changes[i]
        for j in range(i + 1, num_candidates):
            atom2_1, atom2_2, _, _ = candidate_bond_changes[j]
            if atom1_1 in (atom2_1, atom2_2) or atom1_2 in (atom2_1, atom2_2):
                cand_change_adj[i, j] = cand_change_adj[j, i] = True

    # Enumerate combinations of k candidate bond changes and record
    # those that are connected and chemically valid.
    valid_candidate_combos = []
    for k in range(1, max_num_bond_changes + 1):
        for combo_ids in combinations(range(num_candidates), k):
            if not is_connected_change_combo(combo_ids, cand_change_adj):
                continue
            combo_changes = [candidate_bond_changes[j] for j in combo_ids]
            if is_valid_combo(combo_changes, reactant_info):
                valid_candidate_combos.append(combo_changes)

    if is_train:
        random.shuffle(valid_candidate_combos)

        # Locate the combo equivalent to the real bond changes (scores ignored).
        real_changes = set(real_bond_changes)
        real_combo_id = next(
            (j for j, combo_changes in enumerate(valid_candidate_combos)
             if {(atom1, atom2, change_type)
                 for (atom1, atom2, change_type, _) in combo_changes} == real_changes),
            -1)
        if real_combo_id == -1:
            # No candidate combo matches: make the real changes the first entry.
            valid_candidate_combos.insert(
                0, [(atom1, atom2, change_type, 0.0)
                    for (atom1, atom2, change_type) in real_bond_changes])
        else:
            valid_candidate_combos[0], valid_candidate_combos[real_combo_id] = \
                valid_candidate_combos[real_combo_id], valid_candidate_combos[0]

        first_smiles = get_product_smiles(
            reactant_mol, valid_candidate_combos[0], product_info)
        if first_smiles:
            # Remove combos yielding empty or duplicate products.
            seen_smiles = {first_smiles}
            new_candidate_combos = [valid_candidate_combos[0]]
            for combo in valid_candidate_combos[1:]:
                # Stop once enough combos are kept; anything beyond this is cut
                # by the final truncation, so computing its product is wasted.
                if len(new_candidate_combos) == max_num_change_combos:
                    break
                smiles = get_product_smiles(reactant_mol, combo, product_info)
                if not smiles or smiles in seen_smiles:
                    continue
                seen_smiles.add(smiles)
                new_candidate_combos.append(combo)
            valid_candidate_combos = new_candidate_combos

    valid_candidate_combos = valid_candidate_combos[:max_num_change_combos]
    return valid_candidate_combos, candidate_bond_changes, reactant_info