def train(args)

in eval/eval_PER.py
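
The body relies on the following imports. The standard-library and torch
imports are implied directly by the code; the repository-local helpers
(per_src, get_n_phones, parse_phone_labels, find_all_files, filter_seq,
load_cpc_features, run_training) come from elsewhere in the repo, so their
exact import paths are omitted here rather than guessed.

import json
import os
import random
import subprocess
import sys

import torch
from torch.utils.data import DataLoader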


def train(args):

    # Output Directory
    os.makedirs(args.output, exist_ok=True)

    name = f"_{args.name}" if args.command == "per" else ""
    pathLogs = os.path.join(args.output, f'logs_{args.command}{name}.txt')
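    # Duplicate stdout into the log file by piping it through a tee subprocess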
    tee = subprocess.Popen(["tee", pathLogs], stdin=subprocess.PIPE)
    os.dup2(tee.stdin.fileno(), sys.stdout.fileno())

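    # Load the phone inventory size, the per-sequence phone labels,
    # and the list of audio files to process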
    nPhones = get_n_phones(args.path_phone_converter)
    phoneLabels = parse_phone_labels(args.pathPhone)
    inSeqs = find_all_files(args.pathDB, args.file_extension)

    # Model
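    # The CPC encoder emits one feature frame per 160 input samples
    # (10 ms at a 16 kHz sampling rate)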
    downsamplingFactor = 160
    state_dict = torch.load(args.pathCheckpoint)
    featureMaker = load_cpc_features(state_dict)
    hiddenGar = featureMaker.get_output_dim()
    featureMaker.cuda()
    featureMaker = torch.nn.DataParallel(featureMaker)

    # Criterion
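    # CTC loss over the phone labels; args.LSTM selects a recurrent
    # classification head on top of the CPC features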
    phoneCriterion = per_src.CTCPhoneCriterion(hiddenGar, nPhones, args.LSTM,
                                               seqNorm=args.seqNorm,
                                               dropout=args.dropout,
                                               reduction=args.loss_reduction)
    phoneCriterion.cuda()
    phoneCriterion = torch.nn.DataParallel(phoneCriterion)

    # Datasets
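    # Use the pathTrain file list when training; otherwise train on every
    # discovered sequence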
    if args.command == 'train' and args.pathTrain is not None:
        seqTrain = filter_seq(args.pathTrain, inSeqs)
    else:
        seqTrain = inSeqs

    if args.pathVal is None:
        # No explicit validation split: hold out 10% of the training sequences
        random.shuffle(seqTrain)
        sizeTrain = int(0.9 * len(seqTrain))
        seqTrain, seqVal = seqTrain[:sizeTrain], seqTrain[sizeTrain:]
    else:
        seqVal = filter_seq(args.pathVal, inSeqs)
        print(f"{len(seqVal)} of {len(inSeqs)} sequences kept for validation "
              f"({args.pathVal})")

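    # Debug mode: evaluate on only 100 validation sequences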
    if args.debug:
        seqVal = seqVal[:100]

    print(f"Loading the validation dataset at {args.pathDB}")
    datasetVal = per_src.SingleSequenceDataset(args.pathDB, seqVal,
                                               phoneLabels, inDim=args.in_dim)

    valLoader = DataLoader(datasetVal, batch_size=args.batchSize,
                           shuffle=True)

    # Checkpoint file where the model should be saved
    pathCheckpoint = os.path.join(args.output, 'checkpoint.pt')

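    # By default the feature maker is fine-tuned jointly with the criterion;
    # with args.freeze it stays in eval mode and its gradients are disabled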
    featureMaker.optimize = True
    if args.freeze:
        featureMaker.eval()
        featureMaker.optimize = False
        for g in featureMaker.parameters():
            g.requires_grad = False

    if args.debug:
        # Debug mode: train on a small random subset (seqVal was already
        # truncated to 100 sequences above)
        print("Debug mode: truncating the training set to 1000 sequences")
        random.shuffle(seqTrain)
        seqTrain = seqTrain[:1000]

    print(f"Loading the training dataset at {args.pathDB}")
    datasetTrain = per_src.SingleSequenceDataset(args.pathDB, seqTrain,
                                                 phoneLabels,
                                                 inDim=args.in_dim)

    trainLoader = DataLoader(datasetTrain, batch_size=args.batchSize,
                             shuffle=True)

    # Optimizer
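    # Always train the criterion head; include the feature maker's parameters
    # unless it is frozen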
    g_params = list(phoneCriterion.parameters())
    if not args.freeze:
        print("Optimizing model")
        g_params += list(featureMaker.parameters())

    optimizer = torch.optim.Adam(g_params, lr=args.lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

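    # Save the full training configuration alongside the outputs for
    # reproducibility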
    pathArgs = os.path.join(args.output, "args_training.json")
    with open(pathArgs, 'w') as file:
        json.dump(vars(args), file, indent=2)

    run_training(trainLoader, valLoader, featureMaker, phoneCriterion,
                 optimizer, downsamplingFactor, args.nEpochs, pathCheckpoint)
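
For reference, here is a minimal sketch of driving train() programmatically.
Each attribute mirrors a field the function reads from args; the concrete
values and paths are illustrative placeholders, not repository defaults.

from argparse import Namespace

args = Namespace(
    command="train", name="", output="out_per",
    pathDB="data/LibriSpeech/train-clean-100",
    file_extension=".flac",
    pathCheckpoint="checkpoints/cpc_checkpoint.pt",
    path_phone_converter="phone_converter.json",
    pathPhone="phone_labels.txt",
    pathTrain=None, pathVal=None,  # pathVal=None triggers the 90/10 auto-split
    in_dim=1, batchSize=8, nEpochs=10,
    lr=2e-4, beta1=0.9, beta2=0.999, epsilon=1e-8,
    LSTM=False, seqNorm=False, dropout=0.0,
    loss_reduction="mean", freeze=False, debug=False,
)
train(args)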