in train_fns.py [0:0]
def train_step(model, config, inputs, optimizer, batch_idx, logger=None):
    """Run a single forward/backward/optimizer step on one batch.

    Args:
        model: Callable invoked as ``model(inputs, config, False)``; it must
            return ``((preds, priors, posteriors), stored_vars)``. The meaning
            of the third positional ``False`` is defined by the model — TODO
            confirm (looks like a train/eval or generation flag).
        config: Mapping read for ``'n_ctx'`` (number of leading positions
            along dim 1 excluded from the targets), ``'beta'`` (weight of
            the prior/KL term; the term is only computed when ``> 0``) and
            ``'apex'`` (truthy -> backward through ``apex.amp.scale_loss``).
        inputs: Batch tensor fed to the model. It must be at least 2-D,
            since targets are sliced as ``[:, n_ctx:]``.
        optimizer: Optimizer whose ``zero_grad`` and ``step`` are called once.
        batch_idx: Index of the current batch; not used by this function.
        logger: Optional object with a ``scalar(name, value)`` method.

    Returns:
        Tuple ``(preds, targets, priors, posteriors, loss_rec, loss_prior,
        loss, stored_vars)``. ``loss_prior`` is the plain float ``0.`` when
        ``config['beta'] <= 0``; otherwise it is the value returned by
        ``losses.kl_loss``.
    """
    # Copy the batch *before* the forward pass so the targets are unaffected
    # by anything the model does to ``inputs`` and carry no autograd history.
    reference = inputs.clone().detach()

    (preds, priors, posteriors), stored_vars = model(inputs, config, False)
    targets = reference[:, config['n_ctx']:]

    # Objective: reconstruction term, optionally plus beta-weighted prior term.
    loss_rec = losses.reconstruction_loss(config, preds, targets)
    loss_prior = 0.
    loss = loss_rec
    if config['beta'] > 0:
        loss_prior = losses.kl_loss(config, priors, posteriors)
        loss = loss_rec + config['beta'] * loss_prior

    # Parameter update; apex is only imported when mixed precision is enabled.
    optimizer.zero_grad()
    if not config['apex']:
        loss.backward()
    else:
        from apex import amp
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    optimizer.step()

    if logger is not None:
        tracked = [('train_loss_rec', loss_rec), ('train_loss', loss)]
        # The prior term is only a tensor (and only meaningful) when beta > 0.
        if config['beta'] > 0:
            tracked.append(('train_loss_prior', loss_prior))
        for tag, value in tracked:
            logger.scalar(tag, value.item())

    return preds, targets, priors, posteriors, loss_rec, loss_prior, loss, stored_vars