def forward()

in ttw/models/continuous.py [0:0]


    def forward(self, msg, batch):
        """Score each of the 16 map cells against the observation message.

        A stack of landmark embeddings is built (the initial embedding plus
        ``self.T`` further ones produced by ``self.masc_fn``), combined with
        sigmoid gates from ``self.landmark_write_gate``, and dotted with the
        observation message to yield one logit per cell.

        Args:
            msg: mapping with keys ``'obs'`` and ``'act'``. ``msg['obs']`` is
                batch-multiplied against the gated embedding, so it must be
                ``(batch, channels)``. ``msg['act']`` is only passed to
                ``self.extract_fns[j]`` and is unused when ``self.apply_masc``
                is false.
            batch: mapping with keys ``'landmarks'`` (input to
                ``self.cbow_fn``) and ``'target'`` (columns 0 and 1 are read
                and combined into a flat cell index as ``col0 * 4 + col1``).

        Returns:
            dict with:
                ``'prob'``: softmax over the 16 cell logits, ``(batch, 16)``.
                ``'loss'``: ``self.loss(logits, y_true)``.
                ``'acc'``: float fraction of samples whose most probable cell
                equals the target cell.
        """
        obs_msg, act_msg = msg['obs'], msg['act']

        # Hard-coded geometry: 16 cells and a row stride of 4.
        # NOTE(review): consistent with a 4x4 map -- confirm against the dataset.
        num_cells = 16
        row_stride = 4

        # Move the last axis of the cbow output to position 1.
        l_emb = self.cbow_fn.forward(batch['landmarks']).permute(0, 3, 1, 2)
        l_embs = [l_emb]

        if self.apply_masc:
            # Each step transforms the previous embedding under a mask
            # extracted from the action message.
            for j in range(self.T):
                act_mask = self.extract_fns[j](act_msg)
                l_embs.append(self.masc_fn.forward(l_embs[-1], act_mask))
        else:
            # NOTE(review): every step is applied to the *initial* embedding,
            # not to l_embs[-1] as in the branch above -- confirm this is
            # intended. Behavior deliberately left unchanged.
            for _ in range(self.T):
                l_embs.append(self.masc_fn.forward(l_emb))

        # Gated sum of the embeddings. zip() stops at the shorter sequence, so
        # any surplus gates or embeddings are silently ignored.
        landmarks = sum(
            torch.sigmoid(gate) * emb
            for gate, emb in zip(self.landmark_write_gate, l_embs)
        )
        # reshape() replaces the deprecated non-inplace Tensor.resize(); it
        # also copes with the non-contiguous layout left by permute().
        landmarks = landmarks.reshape(
            l_emb.size(0), landmarks.size(1), num_cells
        ).transpose(1, 2)

        out = dict()
        # (batch, cells, channels) @ (batch, channels, 1) -> (batch, cells)
        logits = torch.bmm(landmarks, obs_msg.unsqueeze(-1)).squeeze(-1)
        out['prob'] = F.softmax(logits, dim=1)

        # Flatten the two target coordinates into a single cell index.
        y_true = batch['target'][:, 0] * row_stride + batch['target'][:, 1]

        out['loss'] = self.loss(logits, y_true)

        # Count matches on-device instead of looping over NumPy copies.
        predictions = out['prob'].max(1)[1]
        num_correct = (predictions == y_true).sum().item()
        out['acc'] = float(num_correct) / y_true.size(0)
        return out