# recipes/self_training/pseudo_labeling/generate_synthetic_data.py
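"""Convert decoder output into a train-ready, list-style pseudo-labeled dataset.

Reads a decoder log of (reference, prediction) transcript pairs, optionally
filters out low-quality predictions (EOS warnings, repeated n-grams), and
writes the surviving predictions over the transcripts of an existing
list-file dataset.
"""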
from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import os
import sys
from dataset_utils import (
create_transcript_dict_from_listfile,
write_transcript_list_to_file,
)
def eprint(*args, **kwargs):
print(*args, file=sys.stderr, **kwargs)
def pair_transcripts_with_existing_list(transcript_list, listpath):
transcripts = create_transcript_dict_from_listfile(listpath)
    # Keep only samples that still have a prediction: transcripts whose
    # predictions were filtered out upstream are implicitly dropped here.
    # Every predicted sid is expected to exist in the list file.
    merged = {}
    for pred in transcript_list:
        merged[pred.sid] = transcripts[pred.sid]
        merged[pred.sid].transcript = pred.prediction
    return merged
def compute_ngrams(inp, size):
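    """Return all contiguous n-grams of `size` items from `inp`.

    Illustrative example:
        compute_ngrams(["the", "cat", "sat"], 2) -> [["the", "cat"], ["cat", "sat"]]
    """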
return [inp[i : i + size] for i in range(len(inp) - (size - 1))]
def filter_transcripts(transcript_list, args):
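    """Drop predictions that fail any of the enabled quality filters.

    Illustrative example: with --filter --ngram and the defaults
    (ngram_size=2, ngram_appearance_threshold=4), a looping prediction such
    as "go go go go go go go go go go" contains the bigram "go go" five
    times by non-overlapping substring count, so it would be filtered out.
    """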
# fastpath
if not args.filter:
return transcript_list
filtered_transcripts = []
for transcript in transcript_list:
good = True
        # Skip transcripts flagged with a decoder warning.
        if args.warnings and transcript.warning:
            good = False
            if args.print_filtered_results:
                eprint(
                    "Filtering predicted transcript (warning) "
                    + transcript.sid
                    + ": "
                    + transcript.prediction
                )
            continue
if args.ngram:
plist = transcript.prediction.split(" ")
            # Look for n-grams that repeat suspiciously often, a common
            # looping failure mode of decoders. Note that str.count() counts
            # non-overlapping substring occurrences.
            ngrams = [" ".join(c) for c in compute_ngrams(plist, args.ngram_size)]
for gram in ngrams:
if transcript.prediction.count(gram) > args.ngram_appearance_threshold:
good = False
if args.print_filtered_results:
eprint(
"Filtering predicted transcript (ngram fail) "
+ transcript.sid
+ ": "
+ transcript.prediction
)
break
# passes all checks
if good:
filtered_transcripts.append(transcript)
return filtered_transcripts
class TranscriptPrediction(object):
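    """A single decoded sample: sample id, predicted and reference transcripts."""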
def __init__(self, sid, prediction, transcript, warning=False):
self.sid = sid
self.prediction = prediction
self.transcript = transcript
self.warning = warning
def create_transcript_set(inpath, viterbi=False, distributed_decoding=False):
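    """Parse the decoder log at `inpath` into a list of TranscriptPrediction objects."""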
with open(inpath, "r") as f:
        if not distributed_decoding:
            # The first line is the Chronos job header; skip it.
            f.readline()
predictions = []
while True:
            # Each record in the decoder log contains:
            # - the actual (reference) transcript
            # - the predicted transcript
            # - the actual word pieces
            # - the predicted word pieces
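            # An illustrative record (inferred from the parsing below; the
            # real decoder log format may differ):
            #   |T|: <reference transcript>
            #   |P|: <predicted transcript>
            #   <actual word pieces>
            #   <predicted word pieces>
            #   <info> <sid>, ...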
transcript = f.readline()
# check if EOF
if not transcript:
break
            # Each record is a fixed set of lines; a [WARNING] line adds one extra.
warning = False
if "[WARNING]" in transcript:
transcript = f.readline() # read an extra line to compensate
warning = True
transcript = transcript[
transcript.find("|T|: ") + len("|T|: ") :
] # remove |T|:
predicted = f.readline() # predicted transcript
predicted = predicted[
predicted.find("|P|: ") + len("|P|: ") :
] # remove |P|:
            if viterbi:
                # Viterbi output is word pieces: drop the piece separators and
                # turn the "_" word-boundary markers back into spaces.
                predicted = predicted.replace(" ", "").replace("_", " ")
                transcript = transcript.replace(" ", "").replace("_", " ")
            # Disabled handling for distributed decoding:
            #   if distributed_decoding:
            #       predicted = predicted[1:].replace("_", " ")
            # Skip the word-piece lines.
            f.readline()
            f.readline()
sample_info = f.readline()
if not sample_info.strip():
continue
            # The sid is the second whitespace-delimited token, minus its
            # trailing delimiter character.
            sid = sample_info.split(" ")[1][:-1]
predictions.append(
TranscriptPrediction(sid, predicted, transcript, warning)
)
return predictions
def run():
parser = argparse.ArgumentParser(
description="Converts decoder output into train-ready list-style"
" dataset formats"
)
parser.add_argument(
"-i",
"--input",
type=str,
required=True,
help="Path to decoder output containing transcripts",
)
parser.add_argument(
"-p",
"--listpath",
type=str,
required=True,
help="Path of existing list file dataset or which to replace transcripts",
)
parser.add_argument(
"-w",
"--warnings",
action="store_true",
help="Remove transcripts with EOS warnings by default",
)
parser.add_argument(
"-g",
"--ngram",
action="store_true",
help="Remove transcripts with ngram issues",
)
parser.add_argument(
"-n",
"--ngram_appearance_threshold",
type=int,
required=False,
default=4,
help="The number of identical n-grams that must appear in a "
"prediction for it to be thrown out",
)
parser.add_argument(
"-s",
"--ngram_size",
type=int,
required=False,
default=2,
help="The size of n-gram which will be used when searching for duplicates",
)
parser.add_argument(
"-f", "--filter", action="store_true", help="Run some filtering criteria"
)
parser.add_argument(
"-o", "--output", type=str, required=True, help="Output filepath"
)
parser.add_argument(
"-d",
"--distributed_decoding",
action="store_true",
help="Processing a combined transcript with distributed decoding",
)
    parser.add_argument(
        "-v",
        "--print_filtered_results",
        action="store_true",
        help="Print transcripts that are filtered based on filter criteria to stderr",
    )
parser.add_argument(
"-q",
"--viterbi",
action="store_true",
help="Expects a transcript format that is consistent with a Viterbi run",
)
args = parser.parse_args()
if not os.path.isfile(args.input):
raise Exception("'" + args.input + "' - input file doesn't exist")
if not os.path.isfile(args.listpath):
raise Exception("'" + args.input + "' - listpath file doesn't exist")
transcripts_predictions = create_transcript_set(
args.input, args.viterbi, args.distributed_decoding
)
filtered_transcripts = filter_transcripts(transcripts_predictions, args)
final_transcript_dict = pair_transcripts_with_existing_list(
filtered_transcripts, args.listpath
)
write_transcript_list_to_file(final_transcript_dict, args.output)
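
# Example invocation (hypothetical paths):
#   python generate_synthetic_data.py -i decode.log -p train.lst -o pseudo.lst \
#       --filter --warnings --ngram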
if __name__ == "__main__":
run()