# ttw/models/continuous.py — extracted forward pass (excerpt)
def forward(self, msg, batch):
    """Ground a continuous (observation, action) message in the landmark map.

    Args:
        msg: dict with 'obs' — (B, C) observation message embedding — and
            'act' — action message, only read when ``self.apply_masc``.
        batch: dict with 'landmarks' — landmark tokens embedded by
            ``self.cbow_fn`` into (B, H, W, C) — and 'target' — (B, 2)
            integer (row, col) location of the tourist.

    Returns:
        dict with 'prob' (B, H*W) softmax over grid cells, 'loss' scalar
        cross-entropy, and 'acc' Python-float batch accuracy.
    """
    obs_msg, act_msg = msg['obs'], msg['act']

    # Embed landmarks and move channels first: (B, C, H, W).
    l_emb = self.cbow_fn.forward(batch['landmarks']).permute(0, 3, 1, 2)
    l_embs = [l_emb]
    if self.apply_masc:
        # One action-conditioned spatial mask per step, applied cumulatively.
        for j in range(self.T):
            act_mask = self.extract_fns[j](act_msg)
            out = self.masc_fn.forward(l_embs[-1], act_mask)
            l_embs.append(out)
    else:
        # NOTE(review): each step convolves the ORIGINAL l_emb rather than the
        # previous step's output (unlike the masc branch) — confirm intended.
        for j in range(self.T):
            out = self.masc_fn.forward(l_emb)
            l_embs.append(out)

    # Gated sum over the T+1 feature maps. torch.sigmoid replaces the
    # deprecated F.sigmoid (same function).
    landmarks = sum(torch.sigmoid(gate) * emb
                    for gate, emb in zip(self.landmark_write_gate, l_embs))

    # Flatten the H x W grid into H*W cells: (B, H*W, C). reshape replaces the
    # deprecated Tensor.resize and is identical for this contiguous result;
    # height/width replace the hard-coded 4x4 (=16) Talk-the-Walk grid so any
    # square or rectangular grid works.
    batch_size, emb_size, height, width = l_emb.size()
    landmarks = landmarks.reshape(batch_size, emb_size, height * width).transpose(1, 2)

    out = dict()
    # Dot each cell's embedding with the observation message: (B, H*W).
    logits = torch.bmm(landmarks, obs_msg.unsqueeze(-1)).squeeze(-1)
    out['prob'] = F.softmax(logits, dim=1)

    # Row-major flat index of the true cell (row*W + col; W was the literal 4).
    y_true = batch['target'][:, 0] * width + batch['target'][:, 1]
    out['loss'] = self.loss(logits, y_true)

    # Batch accuracy without the numpy round-trip; max(1)[1] keeps the
    # documented first-index tie-break of torch.max.
    preds = out['prob'].max(1)[1]
    out['acc'] = (preds == y_true).float().mean().item()
    return out