in sing/train.py [0:0]
import functools

from torch import nn

# The names SequenceGenerator, SequenceGeneratorTrainer, and
# generate_embeddings_dataset are defined elsewhere in the sing package;
# their exact import paths are not shown in this extract.


def train_sequence_generator(args, autoencoder, cardinalities, train_dataset,
                             eval_datasets, train_loss, eval_losses, **kwargs):
    """Build and train the SequenceGenerator that predicts the autoencoder's
    embedding sequence from (velocity, instrument, pitch) labels."""
    # Checkpoint under args.checkpoint (a pathlib.Path) when one is given.
    checkpoint_path = (args.checkpoint / "seq.torch"
                       if args.checkpoint else None)

    # Length of the embedding sequence the decoder produces for one example,
    # accounting for the samples stripped from both ends of the waveform.
    wav_length = train_dataset[0].tensors['wav'].size(-1)
    embedding_length = autoencoder.decoder.embedding_length(
        wav_length - 2 * autoencoder.decoder.strip)

    # One (cardinality, dimension) pair per categorical input; the dimension
    # comes from the matching CLI flag, e.g. args.pitch_dim for 'pitch'.
    embeddings = {
        name: (cardinalities[name], getattr(args, '{}_dim'.format(name)))
        for name in ['velocity', 'instrument', 'pitch']
    }
    model = SequenceGenerator(
        embeddings=embeddings,
        length=embedding_length,
        time_dimension=args.time_dim,
        output_dimension=args.ae_dimension,
        hidden_size=args.seq_hidden_size,
        num_layers=args.seq_layers)

    if args.seq_epochs:
        # Encode the training waveforms once up front so the trainer can
        # regress generated embeddings against precomputed targets (only the
        # training set is transformed here).
        print("Precomputing embeddings for all datasets")
        generate_embeddings = functools.partial(
            generate_embeddings_dataset,
            encoder=autoencoder.encoder,
            batch_size=args.batch_size,
            cuda=args.cuda,
            parallel=args.parallel)
        train_dataset = generate_embeddings(train_dataset)

        print("Training sequence generator")
        # NB: the trainer is given a fixed MSE loss on the embeddings; the
        # train_loss argument is unused in this function.
        SequenceGeneratorTrainer(
            suffix="_seq",
            model=model,
            decoder=autoencoder.decoder,
            epochs=args.seq_epochs,
            train_loss=nn.MSELoss(),
            eval_losses=eval_losses,
            train_dataset=train_dataset,
            eval_datasets=eval_datasets,
            truncated_gradient=args.seq_truncated,
            checkpoint_path=checkpoint_path,
            **kwargs).train()
    return model
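
A minimal invocation sketch, assuming an argparse-style namespace: each attribute below mirrors a flag this function reads, but the concrete values, and the autoencoder, cardinalities, dataset, and loss objects, are hypothetical stand-ins for what the rest of sing/train.py constructs.

from pathlib import Path
from types import SimpleNamespace

# Hypothetical values only; the real defaults live in the project's CLI parser.
args = SimpleNamespace(
    checkpoint=Path("checkpoints"),   # checkpoints/seq.torch will be written
    velocity_dim=2, instrument_dim=16, pitch_dim=8,
    time_dim=250,                     # dimension of the time embedding
    ae_dimension=128,                 # autoencoder embedding dimension
    seq_hidden_size=1024,             # recurrent hidden size
    seq_layers=3,                     # number of recurrent layers
    seq_epochs=50,
    seq_truncated=32,                 # truncated-backprop window
    batch_size=64, cuda=True, parallel=False)

model = train_sequence_generator(
    args, autoencoder, cardinalities,  # trained autoencoder + label counts
    train_dataset, eval_datasets,      # datasets exposing a 'wav' tensor
    train_loss, eval_losses)           # loss functions (train_loss unused)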