in eval/eval_PER.py [0:0]
def train(args):
    """Train a CTC phone criterion on top of (optionally frozen) CPC features.

    Side effects: creates ``args.output``, mirrors stdout to a log file via a
    ``tee`` subprocess, writes ``args_training.json`` and ``checkpoint.pt``
    into ``args.output``, and moves models to CUDA (requires a GPU).

    Args:
        args: parsed command-line namespace; expected attributes include
            output, name, command, path_phone_converter, pathPhone, pathDB,
            file_extension, pathCheckpoint, LSTM, seqNorm, dropout,
            loss_reduction, pathTrain, pathVal, debug, in_dim, batchSize,
            freeze, lr, beta1, beta2, epsilon, nEpochs.
    """
    # Output directory — makedirs with exist_ok avoids the isdir/mkdir race
    # and also works when intermediate directories are missing.
    os.makedirs(args.output, exist_ok=True)

    # Duplicate stdout into a log file by piping through `tee`, so console
    # output and the on-disk log stay in sync for the whole run.
    name = f"_{args.name}" if args.command == "per" else ""
    pathLogs = os.path.join(args.output, f'logs_{args.command}{name}.txt')
    tee = subprocess.Popen(["tee", pathLogs], stdin=subprocess.PIPE)
    os.dup2(tee.stdin.fileno(), sys.stdout.fileno())

    nPhones = get_n_phones(args.path_phone_converter)
    phoneLabels = parse_phone_labels(args.pathPhone)
    inSeqs = find_all_files(args.pathDB, args.file_extension)

    # Model: load CPC features from checkpoint and wrap for multi-GPU.
    # NOTE(review): 160 appears to be the CPC encoder's sample-to-frame
    # downsampling factor — confirm against the feature maker's config.
    downsamplingFactor = 160
    state_dict = torch.load(args.pathCheckpoint)
    featureMaker = load_cpc_features(state_dict)
    hiddenGar = featureMaker.get_output_dim()
    featureMaker.cuda()
    featureMaker = torch.nn.DataParallel(featureMaker)

    # Criterion: CTC phone classifier fed by the feature maker's output dim.
    phoneCriterion = per_src.CTCPhoneCriterion(hiddenGar, nPhones, args.LSTM,
                                               seqNorm=args.seqNorm,
                                               dropout=args.dropout,
                                               reduction=args.loss_reduction)
    phoneCriterion.cuda()
    phoneCriterion = torch.nn.DataParallel(phoneCriterion)

    # Datasets: restrict to the training split when one is given.
    if args.command == 'train' and args.pathTrain is not None:
        seqTrain = filter_seq(args.pathTrain, inSeqs)
    else:
        seqTrain = inSeqs

    if args.pathVal is None:
        # No explicit validation split: carve 10% out of the training set.
        random.shuffle(seqTrain)
        sizeTrain = int(0.9 * len(seqTrain))
        seqTrain, seqVal = seqTrain[:sizeTrain], seqTrain[sizeTrain:]
    else:
        # (was `elif args.pathVal is not None` — exact negation of the `if`,
        # so a plain else is equivalent)
        seqVal = filter_seq(args.pathVal, inSeqs)
        print(len(seqVal), len(inSeqs), args.pathVal)
        if args.debug:
            seqVal = seqVal[:100]

    print(f"Loading the validation dataset at {args.pathDB}")
    datasetVal = per_src.SingleSequenceDataset(args.pathDB, seqVal,
                                               phoneLabels, inDim=args.in_dim)
    valLoader = DataLoader(datasetVal, batch_size=args.batchSize,
                           shuffle=True)

    # Checkpoint file where the model should be saved
    pathCheckpoint = os.path.join(args.output, 'checkpoint.pt')

    featureMaker.optimize = True
    if args.freeze:
        # Frozen feature maker: eval mode and no gradients through CPC.
        featureMaker.eval()
        featureMaker.optimize = False
        for g in featureMaker.parameters():
            g.requires_grad = False

    if args.debug:
        # Debug mode: shrink both splits for a quick smoke-test run.
        print("debug")
        random.shuffle(seqTrain)
        seqTrain = seqTrain[:1000]
        seqVal = seqVal[:100]

    print(f"Loading the training dataset at {args.pathDB}")
    datasetTrain = per_src.SingleSequenceDataset(args.pathDB, seqTrain,
                                                 phoneLabels,
                                                 inDim=args.in_dim)
    trainLoader = DataLoader(datasetTrain, batch_size=args.batchSize,
                             shuffle=True)

    # Optimizer: always train the criterion; include the feature maker's
    # parameters only when it is not frozen.
    g_params = list(phoneCriterion.parameters())
    if not args.freeze:
        print("Optimizing model")
        g_params += list(featureMaker.parameters())

    optimizer = torch.optim.Adam(g_params, lr=args.lr,
                                 betas=(args.beta1, args.beta2),
                                 eps=args.epsilon)

    # Persist the full argument namespace for reproducibility.
    pathArgs = os.path.join(args.output, "args_training.json")
    with open(pathArgs, 'w') as file:
        json.dump(vars(args), file, indent=2)

    run_training(trainLoader, valLoader, featureMaker, phoneCriterion,
                 optimizer, downsamplingFactor, args.nEpochs, pathCheckpoint)