recipes/self_training/pseudo_labeling/generate_synthetic_lexicon.py (138 lines of code) (raw):
from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import operator
import os
from synthetic_lexicon_utils import (
LexiconEntry,
read_spellings_from_file,
write_spellings_to_file,
)
def generate_wp_selling(wp_list):
    """Group a flat sequence of word pieces into one spelling per word.

    A piece containing "_" opens a new group; a piece without it is added
    to the group currently being built. Pieces that precede the first
    "_"-bearing piece form a group of their own.

    Args:
        wp_list: iterable of word-piece strings, in order.

    Returns:
        A list of lists of word pieces, one inner list per group.
    """
    groups = []
    for piece in wp_list:
        # Open a group on every word-initial piece, and also on the very
        # first piece so that leading continuation pieces are kept.
        if "_" in piece or not groups:
            groups.append([piece])
        else:
            groups[-1].append(piece)
    return groups
def generate(infile):
    """Count word-piece spellings for every word found in a decoder log.

    The file is read line by line:

    * A line containing "|P|" supplies the prediction: the text after
      "|P|: " with every space removed and every "_" turned into a space.
    * A line containing "|p|" supplies the word-piece sequence. It is
      grouped into one spelling per word with ``generate_wp_selling`` and
      paired positionally with the words of the most recent prediction.
    * Lines containing "|T|", "|t|", "sample", "WARNING", "CHRONOS" or
      "---" are skipped.

    Args:
        infile: path of the decoder output file to read.

    Returns:
        A dict mapping each word to a dict that maps a space-joined
        word-piece spelling to the number of times it was seen.

    Raises:
        Exception: if a line matches none of the known markers, or if a
            "|p|" line appears before any "|P|" line.
    """
    # Lines carrying any of these markers hold no spelling information.
    skipped_markers = ("|T|", "|t|", "sample", "WARNING", "CHRONOS", "---")
    # maps word --> dict mapping wp spellings to the number of
    # times that spelling appears
    lexicon = {}
    with open(infile, "r") as f:
        prediction = None
        for line in f:
            if "|P|" in line:
                # format is "|P|: _[wp]..."
                prediction = (
                    line[line.find("|P|: ") + len("|P|: ") :]
                    .replace(" ", "")
                    .replace("_", " ")
                )
                continue
            if "|p|" not in line:
                if any(marker in line for marker in skipped_markers):
                    continue
                raise Exception("Format invalid; extraneous line: " + line)
            if prediction is None:
                # Report malformed input explicitly rather than failing
                # further down with an AttributeError on None.
                raise Exception(
                    "Format invalid; '|p|' line found before any '|P|' line: "
                    + line
                )
            wp_spelling_raw = line[line.find("|p|:") + len("|p|: ") :]
            transcription = prediction.strip().split(" ")
            wp_spelling = [e.strip() for e in wp_spelling_raw.strip().split(" ") if e]
            wp_spelling = generate_wp_selling(wp_spelling)
            # zip stops at the shorter sequence, so words or spellings
            # without a positional partner are not counted.
            for transcription_word, wp_spelling_word in zip(transcription, wp_spelling):
                wp_key = " ".join(wp_spelling_word)
                counts = lexicon.setdefault(transcription_word, {})
                counts[wp_key] = counts.get(wp_key, 0) + 1
    return lexicon
def order_lexicon(lexicon):
    """Rank each word's spellings from most to least frequent.

    Args:
        lexicon: dict mapping a word to a dict of
            space-joined spelling -> occurrence count.

    Returns:
        A dict mapping each word to a list of spellings, each spelling
        being the list obtained by splitting its key on single spaces.
        Spellings are ordered by descending count; the sort is stable, so
        equal counts keep the order in which the inner dict yields them.
    """
    by_count = operator.itemgetter(1)
    return {
        word: [
            wp_key.split(" ")
            for wp_key, _count in sorted(counts.items(), key=by_count, reverse=True)
        ]
        for word, counts in lexicon.items()
    }
def create_spellings(spellings):
    """Wrap ordered spellings in ``LexiconEntry`` objects, keyed by word.

    Words are visited in sorted order. An entry is only created once a
    word yields its first spelling, so a word with no spellings is absent
    from the result.

    Args:
        spellings: dict mapping a word to an iterable of its spellings.

    Returns:
        A dict mapping each word that has at least one spelling to a
        ``LexiconEntry`` built as ``LexiconEntry(word, [])`` and then fed
        every spelling through ``add_spelling``, in the given order.
    """
    entries = {}
    for word in sorted(spellings):
        for candidate in spellings[word]:
            # Create the entry lazily, on the first spelling only.
            if word not in entries:
                entries[word] = LexiconEntry(word, [])
            entry = entries[word]
            entry.add_spelling(candidate)
    return entries
def run():
    """Command-line entry point: merge hypothesis spellings into a lexicon.

    Parses ``--inputhyp``, ``--inputlexicon`` and ``--output``, builds
    lexicon entries from the hypothesis file, then merges them into the
    entries read from the existing lexicon: for a word already present,
    only spellings it does not yet list are appended (at the end); a word
    not present is added with its new entry as-is. All entries, sorted by
    word, are handed to ``write_spellings_to_file`` with the output path.

    Raises:
        Exception: if either input path is not an existing file.
    """
    parser = argparse.ArgumentParser(
        description="Converts decoder output into train-ready lexicon format"
    )
    argument_specs = (
        (
            ("-i", "--inputhyp"),
            "Path to decoder output using --usewordpiece=false file",
        ),
        (
            ("-l", "--inputlexicon"),
            "Path to the existing lexicon with which to merge a lexicon from the hyp",
        ),
        (("-o", "--output"), "Path to output lexicon file"),
    )
    for flags, help_text in argument_specs:
        parser.add_argument(*flags, type=str, required=True, help=help_text)
    args = parser.parse_args()

    # Both inputs must exist before any work is done; the hypothesis file
    # is checked first.
    for input_path in (args.inputhyp, args.inputlexicon):
        if not os.path.isfile(input_path):
            raise Exception("'" + input_path + "' - input file doesn't exist")

    hyp_entries = create_spellings(order_lexicon(generate(args.inputhyp)))
    new_lexicon = [hyp_entries[word] for word in sorted(hyp_entries)]

    # Index the existing lexicon by word; a later duplicate replaces an
    # earlier one.
    merged = {}
    for known_entry in read_spellings_from_file(args.inputlexicon):
        merged[known_entry.word] = known_entry

    for processed, entry in enumerate(new_lexicon, 1):
        if processed % 1000 == 0:
            print("Processed " + str(processed) + " entries in new lexicon.")
        if entry.word not in merged:
            # OOV case: create a new lexicon entry with these spellings
            merged[entry.word] = entry
            continue
        # entry in lexicon: only add spellings we don't already have,
        # appending them to the end
        existing = merged[entry.word]
        for spelling in entry.sorted_spellings:
            if spelling not in existing.sorted_spellings:
                existing.sorted_spellings.append(spelling)

    # sort the final spellings
    final = [merged[word] for word in sorted(merged)]
    write_spellings_to_file(final, args.output)
# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    run()