def pre_process_one_reaction()

in python/dgllife/data/uspto.py [0:0]


def pre_process_one_reaction(info, num_candidate_bond_changes, max_num_bond_changes,
                             max_num_change_combos, mode):
    """Pre-process one reaction for candidate ranking.

    Candidate bond changes that restate a bond already present in the reactants
    (same atom pair, same bond value) are dropped. The remaining candidates are
    enumerated into combos that are connected (the changed bonds share atoms)
    and chemically valid. In training mode, the gold combo is additionally moved
    to the front and combos yielding duplicate products are removed.

    Parameters
    ----------
    info : 4-tuple
        * candidate_bond_changes : list of 4-tuples
            The candidate bond changes for the reaction, each of the form
            ``(atom1, atom2, change_type, score)``.
        * real_bond_changes : list of 3-tuples
            The real bond changes for the reaction, each of the form
            ``(atom1, atom2, change_type)``.
        * reactant_mol : rdkit.Chem.rdchem.Mol
            RDKit molecule instance for reactants
        * product_mol : rdkit.Chem.rdchem.Mol
            RDKit molecule instance for product. Only used when ``mode`` is
            ``'train'``.
    num_candidate_bond_changes : int
        Maximum number of candidate bond changes to keep after filtering out
        changes that match a bond already present in the reactants. Values
        below 1 impose no limit.
    max_num_bond_changes : int
        Maximum number of bond changes in a single combo; combos of size
        1 through ``max_num_bond_changes`` are enumerated.
    max_num_change_combos : int
        Maximum number of bond change combos returned for the reaction.
    mode : str
        Whether the dataset is to be used for training, validation or test,
        i.e. one of ``'train'``, ``'val'`` or ``'test'``.

    Returns
    -------
    valid_candidate_combos : list
        valid_candidate_combos[i] gives a list of 4-tuples, which is the i-th valid
        combo of candidate bond changes for the reaction. In training mode the
        combos are shuffled and the gold combo is placed first. If the gold combo
        was not among the enumerated candidates, it is built from
        ``real_bond_changes`` with a score of 0.0. When the gold combo yields a
        non-empty product SMILES, combos whose product is empty or duplicates an
        earlier one are removed.
    candidate_bond_changes : list of 4-tuples
        Refined candidate bond changes considered for combos.
    reactant_info : dict
        Reaction-related information of reactants.

    Raises
    ------
    AssertionError
        If ``mode`` is not one of ``'train'``, ``'val'`` or ``'test'``.
    """
    assert mode in ['train', 'val', 'test'], \
        "Expect mode to be 'train' or 'val' or 'test', got {}".format(mode)
    candidate_bond_changes_, real_bond_changes, reactant_mol, product_mol = info
    candidate_pairs = [(atom1, atom2) for (atom1, atom2, _, _)
                       in candidate_bond_changes_]
    reactant_info = bookkeep_reactant(reactant_mol, candidate_pairs)
    if mode == 'train':
        product_info = bookkeep_product(product_mol)

    # Filter out candidate new bonds already in reactants
    pair_to_bond_val = reactant_info['pair_to_bond_val']
    candidate_bond_changes = []
    for (atom1, atom2, change_type, score) in candidate_bond_changes_:
        if ((atom1, atom2) not in pair_to_bond_val) or \
                (pair_to_bond_val[(atom1, atom2)] != change_type):
            candidate_bond_changes.append((atom1, atom2, change_type, score))
            if len(candidate_bond_changes) == num_candidate_bond_changes:
                break

    # cand_change_adj[i, j] is True when candidate changes i and j share an atom
    num_candidates = len(candidate_bond_changes)
    cand_change_adj = np.eye(num_candidates, dtype=bool)
    for i in range(num_candidates):
        atom1_1, atom1_2, _, _ = candidate_bond_changes[i]
        for j in range(i + 1, num_candidates):
            atom2_1, atom2_2, _, _ = candidate_bond_changes[j]
            if atom1_1 == atom2_1 or atom1_1 == atom2_2 or \
                    atom1_2 == atom2_1 or atom1_2 == atom2_2:
                cand_change_adj[i, j] = cand_change_adj[j, i] = True

    # Enumerate combinations of k candidate bond changes and record
    # those that are connected and chemically valid
    valid_candidate_combos = []
    cand_change_ids = range(num_candidates)
    for k in range(1, max_num_bond_changes + 1):
        for combo_ids in combinations(cand_change_ids, k):
            # Check if the changed bonds form a connected component
            if not is_connected_change_combo(combo_ids, cand_change_adj):
                continue
            combo_changes = [candidate_bond_changes[j] for j in combo_ids]
            # Check if the combo is chemically valid
            if is_valid_combo(combo_changes, reactant_info):
                valid_candidate_combos.append(combo_changes)

    if mode == 'train':
        random.shuffle(valid_candidate_combos)
        # Index of the combo of candidate bond changes that is equivalent to
        # the gold combo (scores are ignored for the comparison), or None
        real_changes = set(real_bond_changes)
        real_combo_id = next(
            (j for j, combo_changes in enumerate(valid_candidate_combos)
             if {(atom1, atom2, change_type)
                 for (atom1, atom2, change_type, _) in combo_changes} == real_changes),
            None)

        # If we fail to find the real combo, make it the first entry
        if real_combo_id is None:
            valid_candidate_combos = \
                [[(atom1, atom2, change_type, 0.0)
                  for (atom1, atom2, change_type) in real_bond_changes]] + \
                valid_candidate_combos
        else:
            valid_candidate_combos[0], valid_candidate_combos[real_combo_id] = \
                valid_candidate_combos[real_combo_id], valid_candidate_combos[0]

        gold_smiles = get_product_smiles(
            reactant_mol, valid_candidate_combos[0], product_info)
        if len(gold_smiles) > 0:
            # Remove combos yielding empty or duplicate products
            seen_smiles = {gold_smiles}
            new_candidate_combos = [valid_candidate_combos[0]]
            for combo in valid_candidate_combos[1:]:
                # Stop once the cap is reached: any further combo would be cut
                # by the final truncation anyway, so computing its product
                # SMILES is wasted work.
                if len(new_candidate_combos) >= max_num_change_combos:
                    break
                smiles = get_product_smiles(reactant_mol, combo, product_info)
                if smiles in seen_smiles or len(smiles) == 0:
                    continue
                seen_smiles.add(smiles)
                new_candidate_combos.append(combo)
            valid_candidate_combos = new_candidate_combos
    valid_candidate_combos = valid_candidate_combos[:max_num_change_combos]

    return valid_candidate_combos, candidate_bond_changes, reactant_info