in python/dgllife/data/uspto.py [0:0]
def __init__(self,
             raw_file_path,
             mol_graph_path,
             mol_to_graph=mol_to_bigraph,
             node_featurizer=default_node_featurizer_center,
             edge_featurizer=default_edge_featurizer_center,
             atom_pair_featurizer=default_atom_pair_featurizer,
             load=True,
             num_processes=1,
             check_reaction_validity=True,
             reaction_validity_result_prefix='',
             cache=True,
             **kwargs):
    """Load reactions and (optionally) pre-build reactant graphs.

    Parameters
    ----------
    raw_file_path : str
        Path to the raw reaction file. The processed reactions are read from
        ``raw_file_path + '.proc'``; unless ``built_in=True`` is passed,
        ``process_file(raw_file_path, num_processes)`` is called first
        (presumably to produce that ``.proc`` file -- confirm in ``process_file``).
    mol_graph_path : str
        Graph cache file. Only used when ``cache`` is True: graphs are loaded
        from it when ``load`` is True and the file exists, otherwise they are
        built from scratch and saved to it.
    mol_to_graph : callable
        Called as ``mol_to_graph(mol, node_featurizer=..., edge_featurizer=...,
        canonical_atom_order=False)`` to turn a reactant molecule into a graph.
    node_featurizer, edge_featurizer : callable
        Forwarded to ``mol_to_graph``.
    atom_pair_featurizer : callable
        Stored as ``self._atom_pair_featurizer``; not used in this method.
    load : bool
        Whether to reuse an existing graph cache at ``mol_graph_path``.
    num_processes : int
        Worker count for preprocessing, reaction loading and graph
        construction. With 1, graphs are built serially in this process.
    check_reaction_validity : bool
        If True, run ``reaction_validity_full_check`` and continue with the
        valid reactions only. Valid and invalid reactions are written to
        ``reaction_validity_result_prefix + '_valid_reactions.proc'`` and
        ``reaction_validity_result_prefix + '_invalid_reactions.proc'``.
    reaction_validity_result_prefix : str
        Prefix for the two validity output files. With the default ``''``
        they are written to the current working directory.
    cache : bool
        If True, build (or load) all reactant graphs now. If False, keep
        ``mol_to_graph`` and the featurizers on the instance instead
        (presumably for on-demand construction elsewhere -- not visible here).
    **kwargs
        Only ``built_in`` (default False) is read; when True the raw-file
        preprocessing step is skipped.
    """
    import time

    super(WLNCenterDataset, self).__init__()

    self._atom_pair_featurizer = atom_pair_featurizer
    self.atom_pair_features = []
    self.atom_pair_labels = []
    # Map number of nodes to a corresponding complete graph
    self.complete_graphs = {}
    self.cache = cache

    path_to_reaction_file = raw_file_path + '.proc'
    built_in = kwargs.get('built_in', False)
    if not built_in:
        print('Pre-processing graph edits from reaction data')
        process_file(raw_file_path, num_processes)

    if check_reaction_validity:
        # From here on only the reactions that passed the check are loaded.
        path_to_reaction_file = self._split_valid_reactions(
            path_to_reaction_file, reaction_validity_result_prefix)

    # perf_counter is monotonic, so the reported duration cannot jump with
    # wall-clock adjustments the way time.time() can.
    t0 = time.perf_counter()
    full_mols, full_reactions, full_graph_edits = \
        self.load_reaction_data(path_to_reaction_file, num_processes)
    self.mols = full_mols
    self.reactions = full_reactions
    self.graph_edits = full_graph_edits
    print('Time spent', time.perf_counter() - t0)

    if self.cache:
        graphs = None
        if load and os.path.isfile(mol_graph_path):
            print('Loading previously saved graphs...')
            graphs, _ = load_graphs(mol_graph_path)
            # Graph i must correspond to molecule i. A cache written for a
            # different reaction file (e.g. another validity filter) would
            # otherwise silently misalign graphs and reactions.
            if len(graphs) != len(full_mols):
                print('Number of saved graphs ({:d}) does not match number of '
                      'reactions ({:d}); ignoring the saved file.'.format(
                          len(graphs), len(full_mols)))
                graphs = None
        if graphs is None:
            print('Constructing graphs from scratch...')
            graphs = self._featurize_reactant_mols(
                full_mols, mol_to_graph, node_featurizer, edge_featurizer,
                num_processes)
            save_graphs(mol_graph_path, graphs)
        self.reactant_mol_graphs = graphs
    else:
        # No graphs are built here; keep what is needed to build them later.
        self.mol_to_graph = mol_to_graph
        self.node_featurizer = node_featurizer
        self.edge_featurizer = edge_featurizer

    # One placeholder slot per reaction; these are never filled in this method.
    num_reactions = len(self.mols)
    self.atom_pair_features.extend([None] * num_reactions)
    self.atom_pair_labels.extend([None] * num_reactions)

@staticmethod
def _split_valid_reactions(path_to_reaction_file, result_prefix):
    """Run the reaction validity check and write valid/invalid reactions to disk.

    Parameters
    ----------
    path_to_reaction_file : str
        Processed reaction file passed to ``reaction_validity_full_check``.
    result_prefix : str
        Prefix of the output files ``<prefix>_valid_reactions.proc`` and
        ``<prefix>_invalid_reactions.proc``.

    Returns
    -------
    str
        Path of the file holding the valid reactions.
    """
    print('Start checking validity of input reactions for modeling...')
    valid_reactions, invalid_reactions = \
        reaction_validity_full_check(path_to_reaction_file)
    print('# valid reactions {:d}'.format(len(valid_reactions)))
    print('# invalid reactions {:d}'.format(len(invalid_reactions)))
    path_to_valid_reactions = result_prefix + '_valid_reactions.proc'
    path_to_invalid_reactions = result_prefix + '_invalid_reactions.proc'
    # writelines adds no separators: each entry is written verbatim, exactly
    # as returned by the validity check.
    with open(path_to_valid_reactions, 'w') as f:
        f.writelines(valid_reactions)
    with open(path_to_invalid_reactions, 'w') as f:
        f.writelines(invalid_reactions)
    return path_to_valid_reactions

@staticmethod
def _featurize_reactant_mols(mols, mol_to_graph, node_featurizer,
                             edge_featurizer, num_processes):
    """Convert each molecule into a graph, serially or with a process pool.

    Returns a list of graphs in the same order as ``mols``.
    """
    to_graph = partial(mol_to_graph, node_featurizer=node_featurizer,
                       edge_featurizer=edge_featurizer,
                       canonical_atom_order=False)
    if num_processes == 1:
        return [to_graph(mol) for mol in mols]
    # Process-wide setting for how torch tensors are shared between processes.
    torch.multiprocessing.set_sharing_strategy('file_system')
    with Pool(processes=num_processes) as pool:
        return pool.map(to_graph, mols)