in sing/sequence/trainer.py [0:0]
def _train_batch(self, batch):
    """Run one training pass over ``batch``, window by window.

    The last axis of ``batch.tensors['embeddings']`` is cut into
    consecutive windows of ``self.truncated_gradient`` positions (the
    final window may be shorter). When ``self.truncated_gradient`` is
    falsy, the whole axis is processed as a single window. For every
    window the model is run through ``self.parallel.forward``, the loss
    against the matching slice of the embeddings is back-propagated and
    one optimizer step is taken.

    Args:
        batch: object exposing a ``tensors`` mapping that contains an
            ``'embeddings'`` entry. Every entry of the mapping is
            forwarded to ``self.parallel.forward`` as a keyword argument.

    Returns:
        The sum over windows of ``loss.item() / number_of_windows``,
        i.e. the unweighted mean of the per-window losses; a shorter
        final window counts as much as a full one. The integer ``0`` is
        returned when there is no window to process.
    """
    embeddings = batch.tensors['embeddings']
    total_length = self.model.length
    assert embeddings.size(-1) == total_length

    # A falsy `truncated_gradient` means "one window spanning everything".
    window = self.truncated_gradient or total_length
    window_starts = range(0, total_length, window)
    window_count = len(window_starts)

    state = None
    running_loss = 0
    for begin in window_starts:
        # Clamp the end so the last window stops at the sequence boundary.
        end = min(begin + window, total_length)
        target = embeddings[..., begin:end]
        rebuilt, state = self.parallel.forward(
            start=begin,
            length=end - begin,
            hidden=state,
            **batch.tensors)
        # Detach the carried state so the next window's backward pass
        # does not reach into this window's graph.
        state = tuple(item.detach() for item in state)

        self.optimizer.zero_grad()
        loss = self.train_loss(rebuilt, target)
        loss.backward()
        self.optimizer.step()
        running_loss += loss.item() / window_count
    return running_loss