data/wsj/prepare.py (267 lines of code) (raw):

""" Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. This source code is licensed under the MIT-style license found in the LICENSE file in the root directory of this source tree. ---------- Script to package original WSJ datasets into a form readable in wav2letter++ pipelines Please install `sph2pipe` on your own - see https://www.ldc.upenn.edu/language-resources/tools/sphere-conversion-tools \ with commands : wget https://www.ldc.upenn.edu/sites/www.ldc.upenn.edu/files/ctools/sph2pipe_v2.5.tar.gz tar -xzf sph2pipe_v2.5.tar.gz && cd sph2pipe_v2.5 gcc -o sph2pipe *.c -lm Command : python3 prepare.py --wsj0 [...]/WSJ0/media \ --wsj1 [...]/WSJ1/media --dst [...] --sph2pipe [...]/sph2pipe_v2.5/sph2pipe Replace [...] with appropriate paths """ from __future__ import absolute_import, division, print_function, unicode_literals import argparse import os import re import subprocess from multiprocessing import Pool import numpy from tqdm import tqdm from utils import convert_to_flac, find_transcripts, ndx_to_samples, preprocess_word if __name__ == "__main__": parser = argparse.ArgumentParser(description="WSJ Dataset creation.") parser.add_argument("--wsj0", help="top level directory containing all WSJ0 discs") parser.add_argument("--wsj1", help="top level directory containing all WSJ1 discs") parser.add_argument("--dst", help="destination directory", default="./wsj") parser.add_argument( "--wsj1_type", help="if you are using larger corpus LDC94S13A, set parameter to `LDC94S13A`", default="LDC94S13B", ) parser.add_argument( "--sph2pipe", help="path to sph2pipe executable", default="./sph2pipe_v2.5/sph2pipe", ) parser.add_argument( "-p", "--process", help="# of process for Multiprocessing", default=8, type=int ) args = parser.parse_args() wsj1_sep = "-" if args.wsj1_type == "LDC94S13A" else "_" assert os.path.isdir(str(args.wsj0)), "WSJ0 directory is not found - '{d}'".format( d=args.wsj0 ) assert os.path.isdir(str(args.wsj1)), "WSJ1 directory is not found - '{d}'".format( d=args.wsj1 ) assert args.wsj0 != args.wsj1, "WSJ0 and WSJ1 directories can't be the same" assert os.path.exists(args.sph2pipe), "sph2pipe not found '{d}'".format( d=args.sph2pipe ) # Prepare audio data transcripts = find_transcripts([args.wsj0, args.wsj1]) subsets = dict() subsets["si84"] = ndx_to_samples( args.wsj0, "11-13.1/wsj0/doc/indices/train/tr_s_wv1.ndx", transcripts, lambda line: None if "11_2_1:wsj0/si_tr_s/401" in line else line, ) assert len(subsets["si84"]) == 7138, "Incorrect number of samples in si84 part:" " should be 7138, but fould #{}.".format(len(subsets["si84"])) subsets["si284"] = ndx_to_samples( args.wsj0, "11-13.1/wsj0/doc/indices/train/tr_s_wv1.ndx", transcripts, lambda line: None if "11_2_1:wsj0/si_tr_s/401" in line else line, ) subsets["si284"] = subsets["si284"] + ndx_to_samples( args.wsj1, "13{}34.1/wsj1/doc/indices/si_tr_s.ndx".format(wsj1_sep), transcripts, None, wsj1_sep, ) assert len(subsets["si284"]) == 37416, "Incorrect number of samples in si284 part: " "should be 37416, but fould {}.".format(len(subsets["si284"])) subsets["nov92"] = ndx_to_samples( args.wsj0, "11-13.1/wsj0/doc/indices/test/nvp/si_et_20.ndx", transcripts, lambda line: line + ".wv1", ) assert ( len(subsets["nov92"]) == 333 ), "Incorrect number of samples in si284 part: should be 333, but fould {}.".format( len(subsets["nov92"]) ) subsets["nov92_5k"] = ndx_to_samples( args.wsj0, "11-13.1/wsj0/doc/indices/test/nvp/si_et_05.ndx", transcripts, lambda line: line + ".wv1", ) assert ( len(subsets["nov92_5k"]) == 330 ), "Incorrect number of samples in si284 part: should be 330, but fould {}.".format( len(subsets["nov92_5k"]) ) subsets["nov93"] = ndx_to_samples( args.wsj1, "13{}32.1/wsj1/doc/indices/wsj1/eval/h1_p0.ndx".format(wsj1_sep), transcripts, lambda line: line.replace("13_32_1", "13_33_1"), wsj1_sep, ) assert ( len(subsets["nov93"]) == 213 ), "Incorrect number of samples in si284 part: should be 213, but fould {}.".format( len(subsets["nov93"]) ) subsets["nov93_5k"] = ndx_to_samples( args.wsj1, "13{}32.1/wsj1/doc/indices/wsj1/eval/h2_p0.ndx".format(wsj1_sep), transcripts, lambda line: line.replace("13_32_1", "13_33_1"), wsj1_sep, ) assert ( len(subsets["nov93_5k"]) == 215 ), "Incorrect number of samples in si284 part: should be 215, but fould {}.".format( len(subsets["nov93_5k"]) ) subsets["nov93dev"] = ndx_to_samples( args.wsj1, "13{}34.1/wsj1/doc/indices/h1_p0.ndx".format(wsj1_sep), transcripts, None, wsj1_sep, ) assert ( len(subsets["nov93dev"]) == 503 ), "Incorrect number of samples in si284 part: should be 503, but fould {}.".format( len(subsets["nov93dev"]) ) subsets["nov93dev_5k"] = ndx_to_samples( args.wsj1, "13{}34.1/wsj1/doc/indices/h2_p0.ndx".format(wsj1_sep), transcripts, None, wsj1_sep, ) assert ( len(subsets["nov93dev_5k"]) == 513 ), "Incorrect number of samples in si284 part: should be 513, but fould {}.".format( len(subsets["nov93dev_5k"]) ) audio_path = os.path.join(args.dst, "audio") text_path = os.path.join(args.dst, "text") lists_path = os.path.join(args.dst, "lists") os.makedirs(audio_path, exist_ok=True) os.makedirs(text_path, exist_ok=True) os.makedirs(lists_path, exist_ok=True) transcription_words = set() for set_name, samples in subsets.items(): n_samples = len(samples) print( "Writing {s} with {n} samples\n".format(s=set_name, n=n_samples), flush=True ) data_dst = os.path.join(audio_path, set_name) if os.path.exists(data_dst): print( """The folder {} exists, existing flac for this folder will be skipped for generation. Please remove the folder if you want to regenerate the data""".format( data_dst ), flush=True, ) with Pool(args.process) as p: os.makedirs(data_dst, exist_ok=True) samples_info = list( tqdm( p.imap( convert_to_flac, zip( samples, numpy.arange(n_samples), [data_dst] * n_samples, [args.sph2pipe] * n_samples, ), ), total=n_samples, ) ) list_dst = os.path.join(lists_path, set_name + ".lst") if not os.path.exists(list_dst): with open(list_dst, "w") as f_list: for sample_info in samples_info: f_list.write(" ".join(sample_info) + "\n") else: print( "List {} already exists, skip its generation." " Please remove it if you want to regenerate the list".format( list_dst ), flush=True, ) for sample_info in samples_info: transcription_words.update(sample_info[3].lower().split(" ")) # Prepare text data text_dst = os.path.join(text_path, set_name + ".txt") if not os.path.exists(text_dst): with open(text_dst, "w") as f_text: for sample_info in samples_info: f_text.write(sample_info[3] + "\n") else: print( "Transcript text file {} already exists, skip its generation." " Please remove it if you want to regenerate the list".format(text_dst), flush=True, ) # Prepare text data (for language model) lm_paths = [ "13{}32.1/wsj1/doc/lng_modl/lm_train/np_data/87".format(wsj1_sep), "13{}32.1/wsj1/doc/lng_modl/lm_train/np_data/88".format(wsj1_sep), "13{}32.1/wsj1/doc/lng_modl/lm_train/np_data/89".format(wsj1_sep), ] if not os.path.exists(os.path.join(text_path, "cmudict.0.7a")): url = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict.0.7a" cmd = "cd {} && wget {}".format(text_path, url) os.system(cmd) else: print("CMU dict already exists, skip its downloading", flush=True) allowed_words = [] with open(os.path.join(text_path, "cmudict.0.7a"), "r") as f_cmu: for line in f_cmu: line = line.strip() if line.startswith(";;;"): continue allowed_words.append(line.split(" ")[0].lower()) lm_file = os.path.join(text_path, "lm.txt") # define valid words for correct splitting into sentences with "." existed_words = set.union(set(allowed_words), transcription_words) existed_words = existed_words - {"prof."} # for reproducibility from lua code if os.path.exists(lm_file): print( "LM data already exist, skip its generation." " Please remove the file {} to regenerate it".format(lm_file), flush=True, ) else: with open(lm_file, "w") as f_lm: for path in lm_paths: path = os.path.join(args.wsj1, path) for filename in os.listdir(path): if not filename.endswith(".z"): continue # Get text from zip files filename = os.path.join(path, filename) process = subprocess.Popen( ["zcat", filename], stdout=subprocess.PIPE ) out, _ = process.communicate() assert process.returncode == 0, "Error during zcat" text_data = out.decode("utf-8") text_data = text_data.lower() # split several sentences into sequence (split if word contains # dot only at the end and this word is absent # in the existed words set) text_data = " ".join( [ word[:-1] + "\n" if len(word) > 2 and word[-1] == "." and "." not in word[:-1] and word not in existed_words else word for word in text_data.split() ] ) text_data = re.sub("<s[^>]+>", "<s>", text_data) text_data = re.sub("<s>", "{", text_data) text_data = re.sub("</s>", "}", text_data) part_data = re.finditer( r"\{(.*?)\}", text_data, re.MULTILINE | re.DOTALL ) # take the internal of {...} for lines in part_data: lines = lines.group(1).strip() lines = re.sub(" +", " ", lines) for line in lines.split("\n"): sentence = [] for raw_word in line.split(" "): word = preprocess_word(raw_word) if len(word) > 0: sentence.append(word) if len(sentence) > 0: f_lm.write(" ".join(sentence) + "\n") print("Done!", flush=True)