in ttw/train/predict_location_discrete.py [0:0]
def eval_epoch(loader, tourist, guide, cuda, t_opt=None, g_opt=None):
    """Run one pass over ``loader`` and return the size-weighted accuracy.

    Every batch is passed through the tourist; the tourist's messages
    (``t_out['comms']``) are then passed, together with the batch, to the
    guide. The guide's per-batch ``'acc'`` is weighted by
    ``len(batch['target'])`` and averaged over the whole pass.

    When *both* optimizers are supplied, each batch additionally triggers one
    update step: the guide is trained on its own ``'loss'``, and the tourist
    is trained on a policy-gradient loss (log-probability of the emitted
    message bits weighted by the advantage) plus a squared-error loss that
    fits its ``'baseline'`` to the reward.

    Args:
        loader: Iterable of batches. Each batch is indexed with ``'target'``
            and is forwarded unchanged to both models.
        tourist: Callable with an ``eval()`` method returning a dict with
            ``'comms'`` and, when training, ``'probs'`` and ``'baseline'``.
        guide: Callable with an ``eval()`` method returning a dict with
            ``'acc'`` and, when training, ``'loss'``.
        cuda: If true, tensors exchanged between the models are moved to GPU.
        t_opt: Optimizer for the tourist, or ``None``.
        g_opt: Optimizer for the guide, or ``None``. Training only happens
            when both ``t_opt`` and ``g_opt`` are given.

    Returns:
        ``sum(acc * batch_size) / sum(batch_size)``, or ``0.0`` when the
        loader produced no examples.
    """
    # NOTE(review): eval mode is set even when optimizers are passed in, so
    # the update steps below run with the modules in eval mode. Kept as in
    # the original -- confirm this is intended.
    tourist.eval()
    guide.eval()

    train = t_opt is not None and g_opt is not None
    eps = 1e-16  # keeps log() finite when an action probability is 0

    correct, total = 0, 0
    for batch in loader:
        # forward
        t_out = tourist(batch)
        if cuda:
            t_out['comms'] = [x.cuda() for x in t_out['comms']]
        g_out = guide(t_out['comms'], batch)

        # acc
        batch_size = len(batch['target'])
        correct += g_out['acc'] * batch_size
        total += batch_size

        if not train:
            continue

        # Tourist reward is the log likelihood of the correct answer, i.e.
        # the negated guide loss. It is detached so that no tourist gradient
        # flows back into the guide.
        rewards = (-g_out['loss'].unsqueeze(-1)).detach()

        advantage = rewards - t_out['baseline'].detach()
        if cuda:
            advantage = advantage.cuda()
        t_val_loss = ((t_out['baseline'] - rewards) ** 2).mean()  # mse

        t_rl_loss = 0.
        # 'comms' was already moved to the GPU above; only 'probs' may still
        # need moving.
        for action, prob in zip(t_out['comms'], t_out['probs']):
            if cuda:
                prob = prob.cuda()
            # Probability of the action actually taken: `prob` where the
            # action is 1, `1 - prob` where it is 0.
            action_prob = action * prob + (1.0 - action) * (1.0 - prob)
            t_rl_loss = t_rl_loss - (torch.log(action_prob + eps) * advantage).sum()

        # backward
        g_opt.zero_grad()
        t_opt.zero_grad()
        g_out['loss'].sum().backward()
        (t_rl_loss + t_val_loss).backward()
        # NOTE(review): newer PyTorch releases deprecate clip_grad_norm in
        # favour of clip_grad_norm_; kept unchanged for compatibility with
        # the PyTorch version this file targets.
        torch.nn.utils.clip_grad_norm(tourist.parameters(), 5)
        torch.nn.utils.clip_grad_norm(guide.parameters(), 5)
        g_opt.step()
        t_opt.step()

    # An empty loader would otherwise raise ZeroDivisionError.
    return correct / total if total else 0.0