def train_sequence_generator()

in sing/train.py [0:0]


def train_sequence_generator(args, autoencoder, cardinalities, train_dataset,
                             eval_datasets, train_loss, eval_losses, **kwargs):
    """Build a ``SequenceGenerator`` and optionally train it.

    The generator's ``length`` comes from the first training example. The
    last dimension of its ``'wav'`` tensor, minus
    ``2 * autoencoder.decoder.strip``, is passed to
    ``autoencoder.decoder.embedding_length``.

    When ``args.seq_epochs`` is truthy, two steps run in order:

    1. Encoder outputs are precomputed for ``train_dataset`` only, via
       ``generate_embeddings_dataset``.
    2. A ``SequenceGeneratorTrainer`` runs for ``args.seq_epochs`` epochs
       with an ``nn.MSELoss()`` training objective.

    Otherwise the freshly constructed model is returned untouched.

    Args:
        args: option namespace. Reads ``checkpoint``, ``time_dim``,
            ``ae_dimension``, ``seq_hidden_size``, ``seq_layers``,
            ``seq_epochs``, ``seq_truncated``, ``batch_size``, ``cuda``,
            ``parallel`` and ``velocity_dim`` / ``instrument_dim`` /
            ``pitch_dim``.
        autoencoder: object exposing ``encoder`` (used to precompute
            embeddings) and ``decoder`` (queried for sizes and handed to the
            trainer).
        cardinalities: mapping indexed by ``'velocity'``, ``'instrument'``
            and ``'pitch'``. Each value is paired with the matching
            ``args.<name>_dim`` as an embedding spec. Presumably these are
            the number of distinct values per field; confirm against the
            caller.
        train_dataset: indexable dataset whose items expose
            ``.tensors['wav']``.
        eval_datasets: forwarded to the trainer as-is (NOT converted to
            embeddings).
        train_loss: unused; the trainer always receives ``nn.MSELoss()``.
        eval_losses: forwarded to the trainer.
        **kwargs: forwarded verbatim to ``SequenceGeneratorTrainer``.

    Returns:
        The ``SequenceGenerator`` instance, after training if
        ``args.seq_epochs`` is truthy.
    """
    # Checkpoints go to <args.checkpoint>/seq.torch; no checkpointing
    # when args.checkpoint is falsy.
    if args.checkpoint:
        checkpoint_path = args.checkpoint / "seq.torch"
    else:
        checkpoint_path = None

    decoder = autoencoder.decoder
    first_wav = train_dataset[0].tensors['wav']
    # NOTE(review): `strip` is presumably samples trimmed from each end by
    # the decoder, hence the factor 2 — confirm in the decoder class.
    usable_length = first_wav.size(-1) - 2 * decoder.strip
    embedding_length = decoder.embedding_length(usable_length)

    # One (cardinality, dimension) spec per conditioning field, in fixed
    # order.
    embeddings = {}
    for field in ('velocity', 'instrument', 'pitch'):
        cardinality = cardinalities[field]
        embeddings[field] = (cardinality, getattr(args, field + '_dim'))

    model = SequenceGenerator(
        embeddings=embeddings,
        length=embedding_length,
        time_dimension=args.time_dim,
        output_dimension=args.ae_dimension,
        hidden_size=args.seq_hidden_size,
        num_layers=args.seq_layers)

    # Zero / unset epoch count: skip precomputation and training entirely.
    if not args.seq_epochs:
        return model

    print("Precomputing embeddings for all datasets")
    # NOTE(review): despite the message above, only the training set is
    # encoded here; eval_datasets are passed to the trainer unchanged.
    train_embeddings = generate_embeddings_dataset(
        train_dataset,
        encoder=autoencoder.encoder,
        batch_size=args.batch_size,
        cuda=args.cuda,
        parallel=args.parallel)

    print("Training sequence generator")
    trainer = SequenceGeneratorTrainer(
        suffix="_seq",
        model=model,
        decoder=decoder,
        epochs=args.seq_epochs,
        # The caller-supplied `train_loss` is deliberately not used here.
        train_loss=nn.MSELoss(),
        eval_losses=eval_losses,
        train_dataset=train_embeddings,
        eval_datasets=eval_datasets,
        truncated_gradient=args.seq_truncated,
        checkpoint_path=checkpoint_path,
        **kwargs)
    trainer.train()
    return model