def main()

in sing/train.py


def main():
    args = get_parser().parse_args()

    # Debug mode: run each of the three training stages for a single epoch.
    if args.debug:
        args.ae_epochs = 1
        args.seq_epochs = 1
        args.sing_epochs = 1

    # Fast debug mode: shrink the autoencoder and the sequence generator.
    if args.debug_fast:
        args.ae_channels = 128
        args.ae_dimension = 16
        args.ae_rewrite = 1
        args.seq_hidden_size = 128
        args.seq_layers = 1

    if args.checkpoint:
        args.checkpoint.mkdir(exist_ok=True, parents=True)

    if not args.data.exists():
        utils.fatal("Could not find the nsynth dataset. "
                    "To download it, follow the instructions at "
                    "https://github.com/facebookresearch/SING")

    # Load NSynth and the cardinalities of its categorical metadata,
    # which are passed to the sequence generator below.
    nsynth_dataset = nsynth.NSynthDataset(args.data, pad=args.pad)
    cardinalities = nsynth_dataset.metadata.cardinalities

    # Split into train/valid/test. In debug mode, train on a small random
    # subset; eval_train is a fixed-size random subset of the train set
    # used only for evaluation.
    train_dataset, valid, test = nsynth.make_datasets(nsynth_dataset)
    if args.debug:
        train_dataset = datasets.RandomSubset(train_dataset, size=100)
    eval_train = datasets.RandomSubset(train_dataset, size=10000)

    # In debug mode, only evaluate on the train subset.
    if args.debug:
        eval_datasets = {
            'eval_train': eval_train,
        }
    else:
        eval_datasets = {
            'eval_train': eval_train,
            'valid': valid,
            'test': test,
        }

    # Train either directly on the waveform (args.wav) or on a spectral
    # representation of it, with L1 (args.l1) or MSE as the underlying loss.
    base_loss = nn.L1Loss() if args.l1 else nn.MSELoss()
    train_loss = base_loss if args.wav else dsp.SpectralLoss(
        base_loss, epsilon=args.epsilon)
    # Report both waveform-domain and spectral variants of L1 and MSE
    # at evaluation time.
    eval_losses = {
        'wav_l1': nn.L1Loss(),
        'wav_mse': nn.MSELoss(),
        'spec_l1': dsp.SpectralLoss(nn.L1Loss(), epsilon=args.epsilon),
        'spec_mse': dsp.SpectralLoss(nn.MSELoss(), epsilon=args.epsilon),
    }

    # Keyword arguments shared by all three training stages.
    kwargs = {
        'train_dataset': train_dataset,
        'eval_datasets': eval_datasets,
        'train_loss': train_loss,
        'eval_losses': eval_losses,
        'batch_size': args.batch_size,
        'lr': args.lr,
        'cuda': args.cuda,
        'parallel': args.parallel,
    }

    # Three-stage pipeline: train the autoencoder, train the sequence
    # generator to predict its embeddings from the NSynth metadata, then
    # fine-tune the combined SING model end to end and save it.
    autoencoder = train_autoencoder(args, **kwargs)
    sequence_generator = train_sequence_generator(args, autoencoder,
                                                  cardinalities, **kwargs)
    sing = fine_tune_sing(args, sequence_generator, autoencoder.decoder,
                          **kwargs)
    torch.save(sing.cpu(), str(args.output))
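
The spectral losses above are built by wrapping a plain waveform loss in dsp.SpectralLoss with an epsilon offset. As a rough illustration of what such a wrapper can look like, the sketch below applies the base loss to log-power spectrograms computed with torch.stft; the class name SpectralLossSketch, the FFT size, hop length, and Hann window are illustrative assumptions, not the repository's actual implementation.

import torch
from torch import nn


class SpectralLossSketch(nn.Module):
    # Sketch only: applies `base_loss` to log-power spectrograms of the
    # estimate and the reference. FFT size, hop length and window are
    # assumed values, not those of dsp.SpectralLoss.

    def __init__(self, base_loss, epsilon=1, n_fft=1024, hop_length=256):
        super().__init__()
        self.base_loss = base_loss
        self.epsilon = epsilon
        self.n_fft = n_fft
        self.hop_length = hop_length

    def _log_spectrogram(self, wav):
        # wav: [batch, time] float waveform tensor.
        window = torch.hann_window(self.n_fft, device=wav.device)
        stft = torch.stft(wav, n_fft=self.n_fft, hop_length=self.hop_length,
                          window=window, return_complex=True)
        return torch.log(self.epsilon + stft.abs() ** 2)

    def forward(self, estimate, reference):
        return self.base_loss(self._log_spectrogram(estimate),
                              self._log_spectrogram(reference))

Under these assumptions, SpectralLossSketch(nn.L1Loss(), epsilon=1) plays the same role as the 'spec_l1' entry in eval_losses, up to the exact STFT parameters and normalization used by dsp.SpectralLoss.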