def train_step()

in train_fns.py


def train_step(model, config, inputs, optimizer, batch_idx, logger=None):
    """Training step for the model."""

    # Detached copy of the inputs, used below to build the reconstruction targets
    outputs = inputs.clone().detach()

    # Forward pass
    (preds, priors, posteriors), stored_vars = model(inputs, config, False)

    # Select targets: drop the first n_ctx context steps
    targets = outputs[:, config['n_ctx']:]

    # Compute the reconstruction loss
    loss_rec = losses.reconstruction_loss(config, preds, targets)

    # Compute the prior loss
    if config['beta'] > 0:
        loss_prior = losses.kl_loss(config, priors, posteriors)
        loss = loss_rec + config['beta']*loss_prior
    else:
        loss_prior = 0.
        loss = loss_rec
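    # The objective is the beta-weighted form computed above:
    #   loss = loss_rec + beta * KL(posterior || prior)
    # with the KL term skipped entirely when beta == 0.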

    # Backward pass and optimizer step
    optimizer.zero_grad()
    if config['apex']:
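        # NOTE: amp.scale_loss assumes model and optimizer were already wrapped
        # with apex.amp.initialize() during setup (see the sketch below).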
        from apex import amp
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()
    optimizer.step()

    # Logging
    if logger is not None:
        logger.scalar('train_loss_rec', loss_rec.item())
        logger.scalar('train_loss', loss.item())
        if config['beta'] > 0:
            logger.scalar('train_loss_prior', loss_prior.item())

    return preds, targets, priors, posteriors, loss_rec, loss_prior, loss, stored_vars
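
For context, here is a minimal sketch of how train_step might be driven from a training loop. The config keys (n_ctx, beta, apex), the function signature, and its return values come from the code above; build_model, train_loader, logger, and the concrete config values are hypothetical placeholders, and the losses module is assumed to be imported at module level in train_fns.py.

import torch

config = {'n_ctx': 4, 'beta': 1.0, 'apex': False}  # illustrative values
model = build_model(config)                        # hypothetical constructor
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

if config['apex']:
    from apex import amp  # NVIDIA apex, optional dependency
    # amp.scale_loss inside train_step requires this one-time wrapping
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

for batch_idx, inputs in enumerate(train_loader):  # train_loader is assumed
    (preds, targets, priors, posteriors,
     loss_rec, loss_prior, loss, stored_vars) = train_step(
        model, config, inputs, optimizer, batch_idx, logger=logger)

With config['beta'] set to 0, the KL term and its logging are skipped, so the loop itself needs no changes for a pure reconstruction run.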