# eval/eval_WER.py
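# Imports for main() below: the standard-library and torch imports are certain
# from usage; the repo-local helpers are listed commented out because their
# module path is an assumption about this repository's layout.
import argparse
import json
import os
import sys

import torch
from torch.utils.data import DataLoader

# Repo-local helpers referenced below (module path assumed, adjust as needed):
# from wer_tools import (load_cpc_features, LetterClassifier,
#                        CTCLetterCriterion, SingleSequenceDataset,
#                        parse_ctc_labels_from_root, set_seed, run,
#                        get_eval_args, eval_wer)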
def main(args):
    description = ("Training and evaluation of a letter classifier on top "
                   "of a pre-trained CPC model. Please specify at least one "
                   "`path_wer` (to calculate WER) or `path_train` and "
                   "`path_val` (for training).")
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('--path_checkpoint', type=str,
                        help="Path to the pre-trained CPC checkpoint")
    parser.add_argument('--path_train', default=None, type=str,
                        help="Root directory of the training dataset")
    parser.add_argument('--path_val', default=None, type=str,
                        help="Root directory of the validation dataset")
    parser.add_argument('--n_epochs', type=int, default=30)
    parser.add_argument('--seed', type=int, default=7)
    parser.add_argument('--downsampling_factor', type=int, default=160)
    parser.add_argument('--lr', type=float, default=2e-04,
                        help="Learning rate")
    parser.add_argument('--output', type=str, default='out',
                        help="Output directory")
    parser.add_argument('--p_dropout', type=float, default=0.0)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--lm_weight', type=float, default=2.0)
    parser.add_argument('--path_wer', action='append',
                        help="For computing the WER on specific sequences "
                             "(can be given several times)")
    parser.add_argument('--letters_path', type=str,
                        default='WER_data/letters.lst',
                        help="List of the letters used for the CTC labels, "
                             "one per line")
    args = parser.parse_args(args=args)
    if not args.path_wer and not (args.path_train and args.path_val):
        print('Please specify at least one `path_wer` (to calculate WER) '
              'or `path_train` and `path_val` (for training).')
        return
    if not os.path.isdir(args.output):
        os.mkdir(args.output)
    # Create the models before loading the datasets.
    # Each line of the letters file is one output character of the classifier.
    with open(args.letters_path) as f:
        n_chars = len(f.readlines())
    state_dict = torch.load(args.path_checkpoint)
    feature_maker = load_cpc_features(state_dict)
    feature_maker.cuda()
    hidden = feature_maker.get_output_dim()
    letter_classifier = LetterClassifier(feature_maker, hidden, n_chars,
                                         p_dropout=args.p_dropout)
    criterion = CTCLetterCriterion(letter_classifier, n_chars)
    criterion.cuda()
    criterion = torch.nn.DataParallel(criterion)
    # Checkpoint file where the model will be saved
    path_checkpoint = os.path.join(args.output, 'checkpoint.pt')
    if args.path_train and args.path_val:
        set_seed(args.seed)
        # Validation dataset
        char_labels_val, _, _ = parse_ctc_labels_from_root(
            args.path_val, letters_path=args.letters_path)
        print(f"Loading the validation dataset at {args.path_val}")
        dataset_val = SingleSequenceDataset(args.path_val, char_labels_val)
        val_loader = DataLoader(dataset_val, batch_size=args.batch_size,
                                shuffle=False)
        # Training dataset
        char_labels_train, _, _ = parse_ctc_labels_from_root(
            args.path_train, letters_path=args.letters_path)
        print(f"Loading the training dataset at {args.path_train}")
        dataset_train = SingleSequenceDataset(args.path_train,
                                              char_labels_train)
        train_loader = DataLoader(dataset_train, batch_size=args.batch_size,
                                  shuffle=True)
        # Optimizer
        g_params = list(criterion.parameters())
        optimizer = torch.optim.Adam(g_params, lr=args.lr)
        # Save the training arguments for reproducibility
        args_path = os.path.join(args.output, "args_training.json")
        with open(args_path, 'w') as file:
            json.dump(vars(args), file, indent=2)
        run(train_loader, val_loader, criterion, optimizer,
            args.downsampling_factor, args.n_epochs, path_checkpoint)
    if args.path_wer:
        args = get_eval_args(args)
        # Reload the trained weights into the DataParallel wrapper, then
        # unwrap it for evaluation
        state_dict = torch.load(path_checkpoint)
        criterion.load_state_dict(state_dict)
        criterion = criterion.module
        criterion.eval()
        # Save the evaluation arguments for reproducibility
        args_path = os.path.join(args.output, "args_validation.json")
        with open(args_path, 'w') as file:
            json.dump(vars(args), file, indent=2)
        for path_wer in args.path_wer:
            print(f"Loading the evaluation dataset at {path_wer}")
            char_labels_wer, _, (letter2index, index2letter) = \
                parse_ctc_labels_from_root(path_wer,
                                           letters_path=args.letters_path)
            dataset_eval = SingleSequenceDataset(path_wer, char_labels_wer)
            eval_loader = DataLoader(dataset_eval, batch_size=args.batch_size,
                                     shuffle=False)
            wer = eval_wer(eval_loader, criterion, args.lm_weight,
                           index2letter)
            print(f'WER: {wer}')
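# Entry point -- a minimal sketch, assuming direct invocation; main() already
# takes sys.argv-style arguments via parse_args(args=args). The paths in the
# example are illustrative only:
#   python eval_WER.py --path_checkpoint checkpoint_cpc.pt \
#       --path_train db/train --path_val db/val --path_wer db/test
if __name__ == "__main__":
    main(sys.argv[1:])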