in pretraining/preprocess.py [0:0]
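# NOTE: module-level imports reconstructed for readability. The standard-library
# imports below are what main() clearly needs; BertDictionary, Tokenizer,
# tokenize_line, indexed_dataset, dictionary, binarize, dataset_dest_prefix and
# dataset_dest_file are assumed to be provided elsewhere in this repo / fairseq.
import os
import shutil
from collections import Counter
from itertools import zip_longest
from multiprocessing import Pool
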
def main(args):
    print(args)
    os.makedirs(args.destdir, exist_ok=True)
    target = not args.only_source

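    # Helpers for building dictionaries and composing output paths.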
    def build_dictionary(filenames):
        d = BertDictionary()
        for filename in filenames:
            Tokenizer.add_file_to_dictionary(filename, d, tokenize_line, args.workers)
        return d

    def train_path(lang):
        return '{}{}'.format(args.trainpref, ('.' + lang) if lang else '')

    def file_name(prefix, lang):
        fname = prefix
        if lang is not None:
            fname += f'.{lang}'
        return fname

    def dest_path(prefix, lang):
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        return dest_path('dict', lang) + '.txt'

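    # Build or load the vocabularies; --joined-dictionary shares a single
    # dictionary between the source and target sides.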
    if args.joined_dictionary:
        assert not args.srcdict, 'cannot combine --srcdict and --joined-dictionary'
        assert not args.tgtdict, 'cannot combine --tgtdict and --joined-dictionary'
        src_dict = build_dictionary(set([
            train_path(lang)
            for lang in [args.source_lang, args.target_lang]
        ]))
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = BertDictionary.load(args.srcdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)])
        if target:
            if args.tgtdict:
                tgt_dict = BertDictionary.load(args.tgtdict)
            else:
                assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary([train_path(args.target_lang)])

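    # Apply frequency threshold / vocabulary-size limits, then write dict.<lang>.txt.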
    src_dict.finalize(
        threshold=args.thresholdsrc,
        nwords=args.nwordssrc,
        padding_factor=args.padding_factor,
    )
    src_dict.save(dict_path(args.source_lang))
    if target:
        if not args.joined_dictionary:
            tgt_dict.finalize(
                threshold=args.thresholdtgt,
                nwords=args.nwordstgt,
                padding_factor=args.padding_factor,
            )
        tgt_dict.save(dict_path(args.target_lang))

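    # Binarize one split for one language, sharding the input file across
    # ``num_workers`` processes.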
    def make_binary_dataset(input_prefix, output_prefix, lang, num_workers):
        dict = BertDictionary.load(dict_path(lang))
        print('| [{}] Dictionary: {} types'.format(lang, len(dict)))
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            replaced.update(worker_result['replaced'])
            n_seq_tok[0] += worker_result['nseq']
            n_seq_tok[1] += worker_result['ntok']

        input_file = '{}{}'.format(input_prefix, ('.' + lang) if lang is not None else '')
        offsets = Tokenizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(binarize, (args, input_file, dict, prefix, lang,
                                            offsets[worker_id],
                                            offsets[worker_id + 1]), callback=merge_result)
            pool.close()

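        # The main process binarizes the first shard and owns the output builder.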
        ds = indexed_dataset.IndexedDatasetBuilder(dataset_dest_file(args, output_prefix, lang, 'bin'))
        merge_result(Tokenizer.binarize(input_file, dict, lambda t: ds.add_item(t),
                                        offset=0, end=offsets[1]))
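        # Append the shards written by the other workers, then remove their temp files.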
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, 'idx'))

        print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}'.format(
            lang, input_file, n_seq_tok[0], n_seq_tok[1],
            100 * sum(replaced.values()) / n_seq_tok[1], dict.unk_word))

    def make_dataset(input_prefix, output_prefix, lang, num_workers=1):
        if args.output_format == 'binary':
            make_binary_dataset(input_prefix, output_prefix, lang, num_workers)
        elif args.output_format == 'raw':
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix + '.{}-{}'.format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)

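    # Preprocess every split; comma-separated --validpref/--testpref values
    # produce valid, valid1, ... / test, test1, ... outputs.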
    def make_all(lang):
        if args.trainpref:
            make_dataset(args.trainpref, 'train', lang, num_workers=args.workers)
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(',')):
                outprefix = 'valid{}'.format(k) if k > 0 else 'valid'
                make_dataset(validpref, outprefix, lang)
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(',')):
                outprefix = 'test{}'.format(k) if k > 0 else 'test'
                make_dataset(testpref, outprefix, lang)

    make_all(args.source_lang)
    if target:
        make_all(args.target_lang)

    print('| Wrote preprocessed data to {}'.format(args.destdir))

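    # Optionally build a word-alignment dictionary: for every source token,
    # record the target token it is most frequently aligned to.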
    if args.alignfile:
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
        src_file_name = train_path(args.source_lang)
        tgt_file_name = train_path(args.target_lang)
        src_dict = dictionary.Dictionary.load(dict_path(args.source_lang))
        tgt_dict = dictionary.Dictionary.load(dict_path(args.target_lang))
        freq_map = {}
        with open(args.alignfile, 'r') as align_file:
            with open(src_file_name, 'r') as src_file:
                with open(tgt_file_name, 'r') as tgt_file:
                    for a, s, t in zip_longest(align_file, src_file, tgt_file):
                        si = Tokenizer.tokenize(s, src_dict, add_if_not_exist=False)
                        ti = Tokenizer.tokenize(t, tgt_dict, add_if_not_exist=False)
                        ai = list(map(lambda x: tuple(x.split('-')), a.split()))
                        for sai, tai in ai:
                            srcidx = si[int(sai)]
                            tgtidx = ti[int(tai)]
                            if srcidx != src_dict.unk() and tgtidx != tgt_dict.unk():
                                assert srcidx != src_dict.pad()
                                assert srcidx != src_dict.eos()
                                assert tgtidx != tgt_dict.pad()
                                assert tgtidx != tgt_dict.eos()
                                if srcidx not in freq_map:
                                    freq_map[srcidx] = {}
                                if tgtidx not in freq_map[srcidx]:
                                    freq_map[srcidx][tgtidx] = 1
                                else:
                                    freq_map[srcidx][tgtidx] += 1

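        # Keep, for each source token id, the most frequently aligned target token id.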
        align_dict = {}
        for srcidx in freq_map.keys():
            align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get)

        with open(os.path.join(args.destdir, 'alignment.{}-{}.txt'.format(
                args.source_lang, args.target_lang)), 'w') as f:
            for k, v in align_dict.items():
                print('{} {}'.format(src_dict[k], tgt_dict[v]), file=f)