# Excerpt: forward() method from ttw/models/discrete.py


    def forward(self, message, batch):
        """Score every map cell against an observation message and locate the target.

        Args:
            message: sequence where ``message[0]`` is the observation message
                (embedded via ``self.obs_emb_fn``) and ``message[1]`` is the
                action message, read only when ``self.apply_masc`` is True.
            batch: dict with ``'landmarks'`` — landmark features embedded by
                ``self.emb_map``, laid out batch x H x W x channels before the
                permute below (assumed from the permute; confirm against the
                caller) — and ``'target'`` — (batch, 2) integer row/col of the
                true cell.

        Returns:
            dict with:
                'prob': (batch, H*W) softmax distribution over grid cells,
                'loss': ``self.loss`` applied to the logits,
                'acc':  Python float, fraction of correct argmax predictions.
        """
        msg_obs = self.obs_emb_fn(message[0])
        batch_size = message[0].size(0)

        # (B, H, W, C) -> (B, C, H, W) for the MASC map-update steps.
        landmark_emb = self.emb_map.forward(batch['landmarks']).permute(0, 3, 1, 2)
        landmark_embs = [landmark_emb]

        # Run T update steps; with MASC the embedded action message steers each step.
        if self.apply_masc:
            for step in range(self.T):
                action_out = self.action_emb[step](message[1])
                landmark_embs.append(
                    self.masc_fn.forward(landmark_embs[-1], action_out, current_step=step))
        else:
            for step in range(self.T):
                landmark_embs.append(self.masc_fn.forward(landmark_embs[-1]))

        # Gated sum over all intermediate maps. torch.sigmoid replaces the
        # deprecated F.sigmoid; behavior is identical.
        landmarks = sum(torch.sigmoid(gate) * emb
                        for gate, emb in zip(self.landmark_write_gate, landmark_embs))

        # Derive the grid width instead of hard-coding 4/16 so non-4x4 maps work;
        # for the original 4x4 grid this is byte-identical behavior.
        grid_w = landmarks.size(-1)
        # Flatten the spatial grid: (B, C, H, W) -> (B, H*W, C).
        landmarks = landmarks.view(batch_size, landmarks.size(1), -1).transpose(1, 2)

        out = dict()
        # Dot-product similarity between each cell embedding and the observation.
        logits = torch.bmm(landmarks, msg_obs.unsqueeze(-1)).squeeze(-1)
        out['prob'] = F.softmax(logits, dim=1)
        # Row-major flat index of the true cell.
        y_true = batch['target'][:, 0] * grid_w + batch['target'][:, 1]

        out['loss'] = self.loss(logits, y_true)
        # Tensor-level accuracy: same value as the original list-comprehension count.
        preds = out['prob'].max(1)[1]
        out['acc'] = (preds == y_true).float().mean().item()
        return out