def main()

in eval/eval_WER.py


import argparse
import json
import os

import torch
from torch.utils.data import DataLoader

# Project-local helpers (load_cpc_features, LetterClassifier,
# CTCLetterCriterion, set_seed, parse_ctc_labels_from_root,
# SingleSequenceDataset, run, get_eval_args, eval_wer) are imported from the
# repository's own modules.


def main(args):
    description = ("Training and evaluation of a letter classifier on top of "
                   "a pre-trained CPC model. Please specify at least one "
                   "`path_wer` (to compute the WER) or both `path_train` and "
                   "`path_val` (for training).")

    parser = argparse.ArgumentParser(description=description)

    parser.add_argument('--path_checkpoint', type=str,
                        help="Path to the pre-trained CPC checkpoint")
    parser.add_argument('--path_train', default=None, type=str)
    parser.add_argument('--path_val', default=None, type=str)
    parser.add_argument('--n_epochs', type=int, default=30)
    parser.add_argument('--seed', type=int, default=7)
    parser.add_argument('--downsampling_factor', type=int, default=160)

    parser.add_argument('--lr', type=float, default=2e-04)
    parser.add_argument('--output', type=str, default='out',
                        help="Output directory")
    parser.add_argument('--p_dropout', type=float, default=0.0)
    parser.add_argument('--batch_size', type=int, default=32)

    parser.add_argument('--lm_weight', type=float, default=2.0,
                        help="Language model weight used when decoding")
    parser.add_argument('--path_wer',
                        help="For computing the WER on specific sequences",
                        action='append')
    parser.add_argument('--letters_path', type=str,
                        default='WER_data/letters.lst')

    args = parser.parse_args(args=args)

    if not args.path_wer and not (args.path_train and args.path_val):
        parser.error('Please specify at least one `path_wer` (to compute the '
                     'WER) or both `path_train` and `path_val` (for training).')

    os.makedirs(args.output, exist_ok=True)

    # Create the models before reading the datasets: the number of lines in
    # the letters file fixes the classifier's output size.
    with open(args.letters_path) as f:
        n_chars = len(f.readlines())

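    # Load the pre-trained CPC encoder; its output dimension sizes the
    # letter classifier built on top of it.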
    state_dict = torch.load(args.path_checkpoint)
    feature_maker = load_cpc_features(state_dict)
    feature_maker.cuda()
    hidden = feature_maker.get_output_dim()

    letter_classifier = LetterClassifier(
        feature_maker, hidden, n_chars, p_dropout=args.p_dropout)

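    # CTC criterion over the classifier's letter logits, wrapped in
    # DataParallel so training can use several GPUs.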
    criterion = CTCLetterCriterion(letter_classifier, n_chars)
    criterion.cuda()
    criterion = torch.nn.DataParallel(criterion)

    # Checkpoint file where the model should be saved
    path_checkpoint = os.path.join(args.output, 'checkpoint.pt')

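    # Training branch: only runs when both a training and a validation root
    # are given.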
    if args.path_train and args.path_val:
        set_seed(args.seed)

        char_labels_val, _, _ = parse_ctc_labels_from_root(
            args.path_val, letters_path=args.letters_path)
        print(f"Loading the validation dataset at {args.path_val}")
        dataset_val = SingleSequenceDataset(args.path_val, char_labels_val)
        val_loader = DataLoader(dataset_val, batch_size=args.batch_size,
                                shuffle=False)

        # train dataset
        char_labels_train, _, _ = parse_ctc_labels_from_root(
            args.path_train, letters_path=args.letters_path)

        print(f"Loading the training dataset at {args.path_train}")
        dataset_train = SingleSequenceDataset(
            args.path_train, char_labels_train)
        train_loader = DataLoader(dataset_train, batch_size=args.batch_size,
                                  shuffle=True)

        # Optimizer
        g_params = list(criterion.parameters())
        optimizer = torch.optim.Adam(g_params, lr=args.lr)

        args_path = os.path.join(args.output, "args_training.json")
        with open(args_path, 'w') as file:
            json.dump(vars(args), file, indent=2)

        run(train_loader, val_loader, criterion,
            optimizer, args.downsampling_factor, args.n_epochs, path_checkpoint)

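    # Evaluation branch: reload the trained classifier from `output` (it must
    # already exist there if the training step was skipped), unwrap
    # DataParallel and switch to eval mode.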
    if args.path_wer:
        args = get_eval_args(args)

        state_dict = torch.load(path_checkpoint)
        criterion.load_state_dict(state_dict)
        criterion = criterion.module
        criterion.eval()

        args_path = os.path.join(args.output, "args_validation.json")
        with open(args_path, 'w') as file:
            json.dump(vars(args), file, indent=2)

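        # Compute and report the WER separately for every --path_wer root.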
        for path_wer in args.path_wer:
            print(f"Loading the validation dataset at {path_wer}")

            char_labels_wer, _, (letter2index, index2letter) = parse_ctc_labels_from_root(
                path_wer, letters_path="./WER_data/letters.lst")
            dataset_eval = SingleSequenceDataset(path_wer, char_labels_wer)
            eval_loader = DataLoader(
                dataset_eval, batch_size=args.batch_size, shuffle=False)

            wer = eval_wer(eval_loader,
                           criterion,
                           args.lm_weight,
                           index2letter)
            print(f'WER: {wer}')
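
A minimal invocation sketch, assuming the script is driven through `main` as
above. Every path below is a hypothetical placeholder, not a path from the
repository:

    # Hypothetical paths, shown only to illustrate the expected arguments:
    # train on data/train, validate on data/val, then report the WER on
    # data/test using the checkpoint written to out/checkpoint.pt.
    main(['--path_checkpoint', 'checkpoints/cpc.pt',
          '--path_train', 'data/train',
          '--path_val', 'data/val',
          '--path_wer', 'data/test',
          '--output', 'out'])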