# ttw/models/language.py [0:0]
def forward(self, batch, add_rl_loss=False):
    """Predict the tourist's location on a 4x4 grid from an utterance.

    Encodes the utterance, samples the number of extraction/MASC steps T
    from a learned distribution, aggregates gated observation messages and
    (optionally action-conditioned) landmark embeddings, and scores each of
    the 16 map locations against the aggregated message.

    Args:
        batch: dict with
            'utterance'      (batch, seq) long token indices,
            'utterance_mask' (batch, seq) 0/1 mask of real tokens,
            'landmarks'      input consumed by ``self.cbow_fn``
                             (assumed to yield (batch, H, W, C) — TODO confirm),
            'target'         (batch, 2) grid coordinates in [0, 3].
        add_rl_loss: when True, also compute a REINFORCE loss for the
            sampled number of steps.

    Returns:
        dict with 'prob' (batch, 16) location distribution, 'sl_loss'
        (batch, 1) negative log-likelihood of the true cell, 'acc' scalar
        accuracy, and — when ``add_rl_loss`` — 'rl_loss' (batch, 1).
    """
    batch_size = batch['utterance'].size(0)

    # Encode the utterance; keep the hidden state at the last real token.
    input_emb = self.embed_fn(batch['utterance'])
    hidden_states, _ = self.encoder_fn(input_emb)
    last_state_indices = batch['utterance_mask'].sum(1).long() - 1
    last_hidden_states = hidden_states[torch.arange(batch_size).long(), last_state_indices, :]

    # Sample the number of steps T per example.
    # fix: pass dim explicitly — implicit-dim F.softmax is deprecated and
    # ambiguous; the distribution is over the T+1 step options.
    T_dist = F.softmax(self.T_prediction_fn(last_hidden_states), dim=-1)
    sampled_Ts = T_dist.multinomial(1).squeeze(-1)

    # Extract one observation message per step (T+1 of them).
    obs_msgs = list()
    feat_controller = self.feat_control_emb.unsqueeze(0).repeat(batch_size, 1)
    for step in range(self.T + 1):
        extracted_msg, feat_controller = self.feat_control_step_fn(
            hidden_states, batch['utterance_mask'], feat_controller)
        obs_msgs.append(extracted_msg)

    # Gate each message and zero out messages beyond each sample's T.
    tourist_obs_msg = []
    for i, (gate, emb) in enumerate(zip(self.obs_write_gate, obs_msgs)):
        include = (i <= sampled_Ts).float().unsqueeze(-1)
        # fix: F.sigmoid is deprecated; torch.sigmoid is the supported form.
        tourist_obs_msg.append(include * torch.sigmoid(gate) * emb)
    tourist_obs_msg = sum(tourist_obs_msg)

    # Landmark embeddings: (batch, H, W, C) -> (batch, C, H, W).
    landmark_emb = self.cbow_fn(batch['landmarks']).permute(0, 3, 1, 2)
    landmark_embs = [landmark_emb]
    if self.apply_masc:
        # Predict an action per step and let MASC transform the landmark
        # map; steps beyond the sampled T are handled inside masc_fn via Ts.
        act_controller = self.act_control_emb.unsqueeze(0).repeat(batch_size, 1)
        for step in range(self.T):
            extracted_msg, act_controller = self.act_control_step_fn(
                hidden_states, batch['utterance_mask'], act_controller)
            action_out = self.action_linear_fn(extracted_msg)
            out = self.masc_fn.forward(landmark_embs[-1], action_out,
                                       current_step=step, Ts=sampled_Ts)
            landmark_embs.append(out)
    else:
        for step in range(self.T):
            landmark_embs.append(self.masc_fn.forward(landmark_embs[-1]))
    landmarks = sum([torch.sigmoid(gate) * emb
                     for gate, emb in zip(self.landmark_write_gate, landmark_embs)])
    # fix: Tensor.resize is deprecated and unsafe on non-contiguous input;
    # reshape to (batch, C, 16), then put the 16 locations first.
    landmarks = landmarks.reshape(batch_size, landmarks.size(1), 16).transpose(1, 2)

    out = dict()
    # Score each of the 16 locations against the aggregated message.
    logits = torch.bmm(landmarks, tourist_obs_msg.unsqueeze(-1)).squeeze(-1)
    out['prob'] = F.softmax(logits, dim=1)

    # Supervised NLL of the true cell (row * 4 + col on the 4x4 grid).
    y_true = (batch['target'][:, 0] * 4 + batch['target'][:, 1])
    out['sl_loss'] = -torch.log(torch.gather(out['prob'], 1, y_true.unsqueeze(-1)) + 1e-8)

    # add RL loss
    if add_rl_loss:
        # REINFORCE on the sampled T: reward is the negated, mean-centered
        # supervised loss; detached so the advantage acts as a constant.
        advantage = -(out['sl_loss'] - out['sl_loss'].mean()).detach()
        log_prob = torch.log(torch.gather(T_dist, 1, sampled_Ts.unsqueeze(-1)) + 1e-8)
        out['rl_loss'] = log_prob * advantage

    out['acc'] = sum([1.0 for pred, target in zip(out['prob'].max(1)[1].data.cpu().numpy(),
                                                  y_true.data.cpu().numpy())
                      if pred == target]) / batch_size
    return out