in sing/train.py
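Training entry point for SING: it parses the command-line arguments, builds the NSynth datasets and losses, trains the autoencoder, then the sequence generator on top of it, fine-tunes the assembled SING model end-to-end, and saves the result to args.output. The excerpt relies on torch, torch.nn as nn, and the package's dsp, utils, nsynth and datasets modules, as well as get_parser, train_autoencoder, train_sequence_generator and fine_tune_sing, all imported or defined elsewhere in sing/train.py.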
def main():
    args = get_parser().parse_args()
    if args.debug:
        args.ae_epochs = 1
        args.seq_epochs = 1
        args.sing_epochs = 1
    if args.debug_fast:
        args.ae_channels = 128
        args.ae_dimension = 16
        args.ae_rewrite = 1
        args.seq_hidden_size = 128
        args.seq_layers = 1
    if args.checkpoint:
        args.checkpoint.mkdir(exist_ok=True, parents=True)
    if not args.data.exists():
        utils.fatal("Could not find the nsynth dataset. "
                    "To download it, follow the instructions at "
                    "https://github.com/facebookresearch/SING")
    nsynth_dataset = nsynth.NSynthDataset(args.data, pad=args.pad)
    cardinalities = nsynth_dataset.metadata.cardinalities
    train_dataset, valid, test = nsynth.make_datasets(nsynth_dataset)
    if args.debug:
        train_dataset = datasets.RandomSubset(train_dataset, size=100)
    eval_train = datasets.RandomSubset(train_dataset, size=10000)
    if args.debug:
        eval_datasets = {
            'eval_train': eval_train,
        }
    else:
        eval_datasets = {
            'eval_train': eval_train,
            'valid': valid,
            'test': test,
        }
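    # Training loss: L1 or MSE, applied either directly on the waveform or on
    # spectrograms via dsp.SpectralLoss, depending on args.wav. All four
    # waveform/spectral combinations are also tracked as evaluation metrics.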
    base_loss = nn.L1Loss() if args.l1 else nn.MSELoss()
    train_loss = base_loss if args.wav else dsp.SpectralLoss(
        base_loss, epsilon=args.epsilon)
    eval_losses = {
        'wav_l1': nn.L1Loss(),
        'wav_mse': nn.MSELoss(),
        'spec_l1': dsp.SpectralLoss(nn.L1Loss(), epsilon=args.epsilon),
        'spec_mse': dsp.SpectralLoss(nn.MSELoss(), epsilon=args.epsilon),
    }
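    # Shared configuration for the three training stages: the autoencoder, the
    # sequence generator trained on top of it (conditioned via the label
    # cardinalities), and the end-to-end fine-tuning of the assembled SING
    # model, which is then saved to args.output.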
    kwargs = {
        'train_dataset': train_dataset,
        'eval_datasets': eval_datasets,
        'train_loss': train_loss,
        'eval_losses': eval_losses,
        'batch_size': args.batch_size,
        'lr': args.lr,
        'cuda': args.cuda,
        'parallel': args.parallel,
    }
    autoencoder = train_autoencoder(args, **kwargs)
    sequence_generator = train_sequence_generator(args, autoencoder,
                                                  cardinalities, **kwargs)
    sing = fine_tune_sing(args, sequence_generator, autoencoder.decoder,
                          **kwargs)
    torch.save(sing.cpu(), str(args.output))
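A minimal sketch of the usual entry-point guard and an example invocation; the guard is standard practice for a module run with python3 -m, and the flag spellings below are assumptions derived from the attribute names read off args above, not copied from get_parser().

if __name__ == "__main__":
    main()

# Hypothetical example run (check get_parser() or the repository README for
# the authoritative flags):
#   python3 -m sing.train --cuda --parallel --data /path/to/nsynth --output sing.th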