in example/speech_recognition/main.py [0:0]
def load_data(args):
mode = args.config.get('common', 'mode')
batch_size = args.config.getint('common', 'batch_size')
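    # WHCS bundles the width/height/channel/stride geometry of the input feature frames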
whcs = WHCS()
whcs.width = args.config.getint('data', 'width')
whcs.height = args.config.getint('data', 'height')
whcs.channel = args.config.getint('data', 'channel')
whcs.stride = args.config.getint('data', 'stride')
save_dir = 'checkpoints'
model_name = args.config.get('common', 'prefix')
    is_bi_graphemes = args.config.getboolean('common', 'is_bi_graphemes')
    overwrite_meta_files = args.config.getboolean('train', 'overwrite_meta_files')
    language = args.config.get('data', 'language')
labelUtil = LabelUtil.getInstance()
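    # load the unicode map that defines the output alphabet (single characters or bi-graphemes)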
if language == "en":
if is_bi_graphemes:
try:
labelUtil.load_unicode_set("resources/unicodemap_en_baidu_bi_graphemes.csv")
            except IOError:
                raise Exception("resources/unicodemap_en_baidu_bi_graphemes.csv does not exist."
                                " Set overwrite_meta_files to True in the train section to regenerate it.")
else:
labelUtil.load_unicode_set("resources/unicodemap_en_baidu.csv")
else:
raise Exception("Error: Language Type: %s" % language)
args.config.set('arch', 'n_classes', str(labelUtil.get_count()))
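    # predict mode loads the test set and reuses the feature mean/std saved at training time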
if mode == 'predict':
test_json = args.config.get('data', 'test_json')
datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
datagen.load_train_data(test_json)
datagen.get_meta_from_file(np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
elif mode =="train" or mode == "load":
data_json = args.config.get('data', 'train_json')
val_json = args.config.get('data', 'val_json')
datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
datagen.load_train_data(data_json)
        # regenerate the bi-grapheme dictionary from the training transcripts
if overwrite_meta_files and is_bi_graphemes:
generate_bi_graphemes_dictionary(datagen.train_texts)
args.config.set('arch', 'n_classes', str(labelUtil.get_count()))
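        # train mode (re)computes the feature mean/std; load mode reads them back from disk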
if mode == "train":
if overwrite_meta_files:
normalize_target_k = args.config.getint('train', 'normalize_target_k')
datagen.sample_normalize(normalize_target_k, True)
else:
datagen.get_meta_from_file(np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
datagen.load_validation_data(val_json)
elif mode == "load":
            # load feats_mean and feats_std from disk to normalize the dataset
datagen.get_meta_from_file(np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
datagen.load_validation_data(val_json)
else:
        raise Exception(
            'mode in the cfg file must be one of train, predict, or load.')
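    # batch normalization cannot estimate meaningful statistics from a single example per batch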
is_batchnorm = args.config.getboolean('arch', 'is_batchnorm')
if batch_size == 1 and is_batchnorm:
        raise Warning('batch_size of 1 is too small when is_batchnorm is enabled')
    # sort file paths by duration in ascending order to implement SortaGrad
if mode == "train" or mode == "load":
max_t_count = datagen.get_max_seq_length(partition="train")
max_label_length = datagen.get_max_label_length(partition="train",is_bi_graphemes=is_bi_graphemes)
elif mode == "predict":
max_t_count = datagen.get_max_seq_length(partition="test")
max_label_length = datagen.get_max_label_length(partition="test",is_bi_graphemes=is_bi_graphemes)
else:
        raise Exception(
            'mode in the cfg file must be one of train, predict, or load.')
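    # expose the longest utterance and label lengths to the architecture config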
args.config.set('arch', 'max_t_count', str(max_t_count))
args.config.set('arch', 'max_label_length', str(max_label_length))
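    # the architecture module named by arch_file supplies prepare_data(args),
    # whose result is passed to the iterators as init_states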
from importlib import import_module
prepare_data_template = import_module(args.config.get('arch', 'arch_file'))
init_states = prepare_data_template.prepare_data(args)
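    # SortaGrad: present utterances shortest-first during training only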
if mode == "train":
sort_by_duration = True
else:
sort_by_duration = False
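    # in predict mode the test utterances were loaded into the "train" partition above,
    # so the same iterator covers both cases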
data_loaded = STTIter(partition="train",
count=datagen.count,
datagen=datagen,
batch_size=batch_size,
num_label=max_label_length,
init_states=init_states,
seq_length=max_t_count,
width=whcs.width,
height=whcs.height,
sort_by_duration=sort_by_duration,
is_bi_graphemes=is_bi_graphemes)
if mode == 'predict':
return data_loaded, args
else:
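        # validation iterator; kept in corpus order (sort_by_duration=False)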
validation_loaded = STTIter(partition="validation",
count=datagen.val_count,
datagen=datagen,
batch_size=batch_size,
num_label=max_label_length,
init_states=init_states,
seq_length=max_t_count,
width=whcs.width,
height=whcs.height,
sort_by_duration=False,
is_bi_graphemes=is_bi_graphemes)
return data_loaded, validation_loaded, args
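
# --- Hypothetical caller sketch (not part of main.py) ----------------------
# load_data() only needs an object whose .config behaves like a ConfigParser
# exposing the sections and keys read above ('common', 'data', 'train', 'arch').
# The entry-point wiring below is an illustrative assumption, not the example's
# actual argument handling.
import argparse
import configparser


def _demo_load(config_path):
    # parse the .cfg file and wrap it the way load_data() expects
    config = configparser.ConfigParser()
    config.read(config_path)
    args = argparse.Namespace(config=config)
    if config.get('common', 'mode') == 'predict':
        # predict mode returns (data_loaded, args)
        data_loaded, args = load_data(args)
        return data_loaded
    # train/load modes also return a validation iterator
    data_loaded, validation_loaded, args = load_data(args)
    return data_loaded, validation_loaded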