def __init__()

in python/dgllife/data/uspto.py [0:0]


    def __init__(self,
                 raw_file_path,
                 mol_graph_path,
                 mol_to_graph=mol_to_bigraph,
                 node_featurizer=default_node_featurizer_center,
                 edge_featurizer=default_edge_featurizer_center,
                 atom_pair_featurizer=default_atom_pair_featurizer,
                 load=True,
                 num_processes=1,
                 check_reaction_validity=True,
                 reaction_validity_result_prefix='',
                 cache=True,
                 **kwargs):
        """Pre-process the reaction file, load reactions and prepare molecular graphs.

        Parameters
        ----------
        raw_file_path : str
            Path to the raw reaction file. The processed reactions are read from
            ``raw_file_path + '.proc'``.
        mol_graph_path : str
            File the reactant graphs are loaded from or saved to. Only used when
            ``cache`` is True.
        mol_to_graph : callable
            Called as ``mol_to_graph(mol, node_featurizer=..., edge_featurizer=...,
            canonical_atom_order=False)`` for every loaded molecule.
        node_featurizer : callable
            Forwarded to ``mol_to_graph``.
        edge_featurizer : callable
            Forwarded to ``mol_to_graph``.
        atom_pair_featurizer : callable
            Stored on the instance as ``_atom_pair_featurizer``; it is not invoked
            in the constructor.
        load : bool
            With ``cache`` enabled, reuse the graphs stored in ``mol_graph_path``
            if that file exists instead of constructing them again.
        num_processes : int
            Forwarded to ``process_file`` and ``load_reaction_data``. When it is 1
            the graphs are constructed in the current process, otherwise a pool
            with ``num_processes`` workers is used.
        check_reaction_validity : bool
            If True, run ``reaction_validity_full_check`` on the processed file,
            write the valid and invalid reactions to two separate files and load
            the reactions from the file of valid ones only.
        reaction_validity_result_prefix : str
            Prefix prepended to ``'_valid_reactions.proc'`` and
            ``'_invalid_reactions.proc'`` to form the two output paths.
        cache : bool
            If True, all reactant graphs are loaded or constructed (and saved) in
            the constructor. If False, only ``mol_to_graph`` and the two
            featurizers are kept on the instance -- presumably so that graphs can
            be built on demand elsewhere in the class (verify against the
            item-access code).
        kwargs
            Only ``built_in`` is read here; a truthy value skips the call to
            ``process_file``.
        """
        super(WLNCenterDataset, self).__init__()

        self.cache = cache
        self._atom_pair_featurizer = atom_pair_featurizer
        # Map number of nodes to a corresponding complete graph
        self.complete_graphs = {}
        self.atom_pair_features = []
        self.atom_pair_labels = []

        reaction_file = raw_file_path + '.proc'
        if not kwargs.get('built_in', False):
            print('Pre-processing graph edits from reaction data')
            process_file(raw_file_path, num_processes)

        if check_reaction_validity:
            print('Start checking validity of input reactions for modeling...')
            valid_reactions, invalid_reactions = \
                reaction_validity_full_check(reaction_file)
            print('# valid reactions {:d}'.format(len(valid_reactions)))
            print('# invalid reactions {:d}'.format(len(invalid_reactions)))
            valid_path = reaction_validity_result_prefix + '_valid_reactions.proc'
            invalid_path = reaction_validity_result_prefix + '_invalid_reactions.proc'
            with open(valid_path, 'w') as out_file:
                out_file.writelines(valid_reactions)
            with open(invalid_path, 'w') as out_file:
                out_file.writelines(invalid_reactions)
            # From here on only the reactions that passed the check are used.
            reaction_file = valid_path

        import time
        start = time.time()
        loaded = self.load_reaction_data(reaction_file, num_processes)
        self.mols, self.reactions, self.graph_edits = loaded
        print('Time spent', time.time() - start)

        if not self.cache:
            # No graphs are built now; keep what is needed to build them later.
            self.mol_to_graph = mol_to_graph
            self.node_featurizer = node_featurizer
            self.edge_featurizer = edge_featurizer
        elif load and os.path.isfile(mol_graph_path):
            print('Loading previously saved graphs...')
            self.reactant_mol_graphs, _ = load_graphs(mol_graph_path)
        else:
            print('Constructing graphs from scratch...')
            # Same call for every molecule, whether run serially or in a pool.
            build_graph = partial(mol_to_graph,
                                  node_featurizer=node_featurizer,
                                  edge_featurizer=edge_featurizer,
                                  canonical_atom_order=False)
            if num_processes == 1:
                self.reactant_mol_graphs = [build_graph(mol) for mol in self.mols]
            else:
                # NOTE(review): presumably selected to avoid exhausting file
                # descriptors when workers send many graphs back -- confirm.
                torch.multiprocessing.set_sharing_strategy('file_system')
                with Pool(processes=num_processes) as pool:
                    self.reactant_mol_graphs = pool.map(build_graph, self.mols)

            save_graphs(mol_graph_path, self.reactant_mol_graphs)

        # One placeholder per reaction in both lists.
        num_reactions = len(self.mols)
        self.atom_pair_features.extend([None] * num_reactions)
        self.atom_pair_labels.extend([None] * num_reactions)