in ttw/train/predict_location_generated.py [0:0]
def epoch(loader, tourist, guide, g_opt=None, t_opt=None,
          decoding_strategy='greedy', beam_width=4, on_the_fly=False):
    """Run one pass over ``loader`` and return the guide's mean accuracy.

    For every batch the guide is run via ``guide.forward(batch)``.  When
    ``on_the_fly`` is true, ``tourist.forward`` is called first and its
    ``'utterance'`` / ``'utterance_mask'`` outputs overwrite the matching
    entries of ``batch`` (the batch dict is mutated in place).

    Args:
        loader: iterable of batches; each batch is indexed by string keys and
            must provide ``'landmarks'`` (its ``size(0)`` is used as the
            batch size).
        tourist: object exposing ``forward(batch, decoding_strategy=...,
            beam_width=..., train=False)`` returning a mapping with
            ``'utterance'``, ``'utterance_mask'`` and, for REINFORCE,
            ``'probs'``.
        guide: object exposing ``forward(batch)`` returning a mapping with
            ``'sl_loss'`` and ``'acc'``.
        g_opt: optional optimizer; when given, the guide is updated on
            ``sl_loss.sum()``.
        t_opt: optional optimizer; when given, the tourist is updated with a
            REINFORCE loss whose reward is ``-sl_loss``.  Requires
            ``on_the_fly=True`` because the update needs the tourist output.
        decoding_strategy: forwarded unchanged to ``tourist.forward``.
        beam_width: forwarded unchanged to ``tourist.forward``.
        on_the_fly: whether utterances are generated by the tourist for each
            batch instead of being read from the batch.

    Returns:
        The sum of ``g_out['acc'] * batch_size`` divided by the total number
        of examples, or ``0.0`` if ``loader`` yields no examples.

    Raises:
        ValueError: if ``t_opt`` is given while ``on_the_fly`` is false.
    """
    if t_opt is not None and not on_the_fly:
        # The REINFORCE branch reads the tourist output, which only exists
        # when utterances are generated on the fly.
        raise ValueError('t_opt requires on_the_fly=True: the tourist update '
                         'needs the tourist output for the current batch')

    accuracy, total = 0.0, 0.0
    for batch in loader:
        if on_the_fly:
            t_out = tourist.forward(batch,
                                    decoding_strategy=decoding_strategy,
                                    beam_width=beam_width,
                                    train=False)
            batch['utterance'] = t_out['utterance']
            batch['utterance_mask'] = t_out['utterance_mask']
        g_out = guide.forward(batch)

        # NOTE(review): sl_loss is presumably one loss value per example
        # (it is squeezed and used as a per-example reward below) -- confirm.
        sl_loss = g_out['sl_loss']
        batch_size = batch['landmarks'].size(0)
        total += batch_size
        accuracy += g_out['acc'] * batch_size

        if g_opt is not None:
            g_opt.zero_grad()
            sl_loss.sum().backward()
            g_opt.step()

        if t_opt is not None:
            # reinforce
            # The reward is a constant w.r.t. the policy: detach it so the
            # policy-gradient backward pass neither re-enters the guide's
            # (possibly already freed) graph nor leaks gradients into the
            # guide's parameters.
            reward = -sl_loss.detach().squeeze()
            # Mean-reward baseline; identical for every step, so computed once.
            advantage = reward - reward.mean()

            # probs is indexed as probs[:, k, :] and gathered along dim 1 of
            # each slice, i.e. (batch, step, vocabulary).
            probs = t_out['probs']
            mask = t_out['utterance_mask']
            sampled_ind = t_out['utterance']

            n_steps = probs.size(1)
            loss = 0.0
            for k in range(n_steps):
                prob = probs[:, k, :]
                ind = sampled_ind[:, k]
                selected_prob = torch.gather(prob, 1, ind.unsqueeze(-1))
                # 1e-8 guards against log(0).
                log_prob = torch.log(selected_prob + 1e-8).squeeze(-1)
                loss -= (mask[:, k] * log_prob * advantage).sum()

            # With zero decoding steps there is nothing to differentiate.
            if n_steps > 0:
                t_opt.zero_grad()
                loss.backward()
                t_opt.step()

    if not total:
        # Empty loader: avoid 0.0 / 0.0.
        return 0.0
    return accuracy / total