in data.py [0:0]
def _build_corpus(data_path, env_params, sort_dict):
    """Return the corpus for ``data_path``, caching it on disk for later runs.

    The cache file lives inside ``data_path`` and is named
    ``corpus_sorted.pt`` when ``sort_dict`` is truthy, ``corpus.pt``
    otherwise. If the cache exists it is loaded with ``torch.load``;
    otherwise a ``Corpus`` is built and saved with ``torch.save``.

    In distributed mode only rank 0 decides (from the filesystem) whether the
    corpus has to be built, and *every* rank takes part in exactly one
    broadcast from rank 0 before the other ranks load the file. Letting each
    rank decide on its own from ``os.path.exists`` is racy: a rank that sees
    the file while rank 0 is still writing it would load a partial file and
    skip the broadcast, leaving rank 0's collective unmatched.

    Args:
        data_path: Directory handed to ``Corpus`` and used to hold the cache.
        env_params: Mapping read for ``'distributed'`` and, when that is
            truthy, ``'rank'``.
        sort_dict: Forwarded to ``Corpus``; also selects the cache file name.

    Returns:
        The corpus object, either freshly built or loaded from the cache.
    """
    cache_name = 'corpus_sorted.pt' if sort_dict else 'corpus.pt'
    corpus_path = os.path.join(data_path, cache_name)

    def load_cached():
        print('Loading an existing corpus file from {}'.format(corpus_path))
        # NOTE(review): torch.load unpickles the file, so the cache should
        # only ever be a file written by this function.
        return torch.load(corpus_path)

    def create_and_cache():
        print('Creating a corpus file at {}'.format(corpus_path))
        built = Corpus(data_path, sort_dict)
        # Write to a temporary name and rename, so an interrupted save never
        # leaves a truncated file at the path later runs will try to load.
        tmp_path = '{}.tmp.{}'.format(corpus_path, os.getpid())
        try:
            torch.save(built, tmp_path)
            os.replace(tmp_path, corpus_path)
        finally:
            # Only still present if the save or the rename failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        return built

    def load_or_create():
        if os.path.exists(corpus_path):
            return load_cached()
        return create_and_cache()

    def sync_with_rank0():
        # A dummy broadcast from rank 0 is used as the synchronization point.
        flag = torch.zeros(1).cuda()
        torch.distributed.broadcast(flag, src=0)
        # Reading the value back makes the host wait for the collective to
        # finish; a CUDA collective may return as soon as it is enqueued.
        flag.item()

    if not env_params['distributed']:
        return load_or_create()

    if env_params['rank'] == 0:
        # only one process need to create a corpus file
        corpus = load_or_create()
        # Release the other ranks only once the file is fully in place.
        sync_with_rank0()
        return corpus

    # This existence check is for logging only; control flow on non-zero
    # ranks never depends on it.
    if not os.path.exists(corpus_path):
        print('Waiting rank0 to create a corpus file.')
    sync_with_rank0()
    return load_cached()