def epoch()

in ttw/train/predict_location_generated.py [0:0]


def epoch(loader, tourist, guide, g_opt=None, t_opt=None,
          decoding_strategy='greedy', beam_width=4, on_the_fly=False):
    """Run one pass over ``loader``, optionally training guide and tourist.

    For every batch the guide is evaluated and, if ``g_opt`` is given,
    updated with the sum of its ``'sl_loss'``. If ``t_opt`` is given, the
    tourist is updated with REINFORCE: the reward is the negated guide loss
    (treated as a constant), and the batch-mean reward is the baseline.

    Args:
        loader: iterable of batch dicts. ``batch['landmarks'].size(0)`` is
            used as the batch size; the dicts must also hold whatever the
            guide/tourist read.
        tourist: model whose ``forward`` returns a dict with ``'utterance'``,
            ``'utterance_mask'`` and ``'probs'``.
        guide: model whose ``forward`` returns a dict with ``'sl_loss'`` and
            ``'acc'``.
        g_opt: optimizer for the guide, or ``None`` to skip guide updates.
        t_opt: optimizer for the tourist, or ``None`` to skip tourist
            updates. Requires ``on_the_fly=True``.
        decoding_strategy: forwarded to ``tourist.forward``.
        beam_width: forwarded to ``tourist.forward``.
        on_the_fly: if True, regenerate each batch's utterance with the
            tourist (overwriting ``batch['utterance']`` and
            ``batch['utterance_mask']`` in place) before the guide sees it.

    Returns:
        The guide's ``'acc'`` averaged over batches, weighted by batch size.
        Returns 0.0 if the loader yielded no batches.

    Raises:
        ValueError: if ``t_opt`` is given while ``on_the_fly`` is False. The
            REINFORCE update needs the tourist's outputs for each batch.
    """
    # Fail before touching any data or optimizer instead of crashing
    # mid-epoch on an unbound tourist output.
    if t_opt is not None and not on_the_fly:
        raise ValueError(
            "t_opt requires on_the_fly=True: the tourist update needs the "
            "tourist's outputs for each batch")

    accuracy, total = 0.0, 0.0

    for batch in loader:
        if on_the_fly:
            t_out = tourist.forward(batch,
                                    decoding_strategy=decoding_strategy,
                                    beam_width=beam_width,
                                    train=False)
            batch['utterance'] = t_out['utterance']
            batch['utterance_mask'] = t_out['utterance_mask']

        g_out = guide.forward(batch)

        batch_size = batch['landmarks'].size(0)
        total += batch_size
        # NOTE(review): assumes 'acc' is a per-batch mean, hence the
        # re-weighting by batch size -- confirm against the guide.
        accuracy += g_out['acc'] * batch_size

        if g_opt is not None:
            g_loss = g_out['sl_loss'].sum()
            g_opt.zero_grad()
            g_loss.backward()
            g_opt.step()

        if t_opt is not None:
            # REINFORCE. The reward must be detached for two reasons.
            # (1) It is a constant w.r.t. the tourist, so no gradient should
            #     flow into the guide.
            # (2) The guide's backward() above has already freed the graph
            #     it hangs off.
            reward = -g_out['sl_loss'].detach().squeeze()
            # Mean-reward baseline; identical for every step, so compute once.
            advantage = reward - reward.mean()

            # presumably probs is (batch, steps, vocab) and utterance/mask are
            # (batch, steps) -- TODO confirm against tourist.forward
            probs = t_out['probs']
            mask = t_out['utterance_mask']
            sampled_ind = t_out['utterance']

            t_loss = 0.0
            for k in range(probs.size(1)):
                selected_prob = torch.gather(
                    probs[:, k, :], 1, sampled_ind[:, k].unsqueeze(-1))
                # Epsilon guards against log(0) for zero-probability tokens.
                log_prob = torch.log(selected_prob + 1e-8).squeeze(-1)
                t_loss -= (mask[:, k] * log_prob * advantage).sum()

            t_opt.zero_grad()
            t_loss.backward()
            t_opt.step()

    if total == 0:
        return 0.0
    return accuracy / total