def forward()

in ttw/models/discrete.py
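
Encodes a trajectory into discrete communication: per-step observation
embeddings (and, when apply_masc is set, action embeddings) are scaled by
learned write gates, summed into logits, and sampled into binary messages;
the summed embeddings also feed a value-prediction baseline.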


    def forward(self, batch, greedy=False):
        # NOTE: `greedy` is accepted but never used in this snippet; messages
        # are always sampled from the Bernoulli distribution below.
        batch_size = batch['actions'].size(0)

        # Embed the observation at each of the T+1 steps, scaled by a
        # learned per-step sigmoid write gate.
        feat_emb = []
        max_steps = self.T + 1
        for step in range(max_steps):
            emb = self.goldstandard_emb(batch['goldstandard'][:, step, :])
            emb = emb * torch.sigmoid(self.obs_write_gate[step])
            feat_emb.append(emb)

        # With MASC enabled, embed the T actions as well, each scaled by its
        # own sigmoid write gate.
        act_emb = []
        if self.apply_masc:
            for step in range(self.T):
                emb = self.action_emb(batch['actions'][:, step])
                emb = emb * torch.sigmoid(self.act_write_gate[step])
                act_emb.append(emb)

        out = {'comms': [], 'probs': []}

        # Sum the gated observation embeddings into message logits and sample
        # a binary observation message from the resulting Bernoulli.
        feat_embeddings = sum(feat_emb)
        feat_prob = torch.sigmoid(feat_embeddings).cpu()
        feat_msg = feat_prob.bernoulli().detach()

        out['probs'].append(feat_prob)
        out['comms'].append(feat_msg)
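        # For intuition: torch.sigmoid(torch.tensor([2.0, -2.0])) is roughly
        # [0.88, 0.12], so .bernoulli() emits a 1 in the first message bit
        # about 88% of the time.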

        if self.apply_masc:
            # Same sampling procedure for the action message.
            act_embeddings = sum(act_emb)
            act_prob = torch.sigmoid(act_embeddings).cpu()
            act_msg = act_prob.bernoulli().detach()

            out['probs'].append(act_prob)
            out['comms'].append(act_msg)

            embeddings = torch.cat([feat_embeddings, act_embeddings], 1).view(batch_size, 2 * self.vocab_sz)
        else:
            embeddings = feat_embeddings

        # Value prediction over the summed embeddings, used as a baseline
        # (e.g. for variance reduction when training the discrete channel
        # with policy gradients).
        out['baseline'] = self.value_pred(embeddings)

        return out
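
A minimal calling sketch, assuming the enclosing module has already been
constructed as model (constructor arguments omitted) and that goldstandard
holds integer feature ids per step, as the embedding lookups above suggest.
The batch size 8, the action-id range 4, the feature-id range 10, and
feat_dim are placeholders, not values from the source.

    import torch

    batch_size = 8                       # placeholder
    T = model.T                          # trajectory length from the model
    feat_dim = 32                        # hypothetical ids per observation

    batch = {
        # T integer action ids per example (read as actions[:, step]).
        'actions': torch.randint(0, 4, (batch_size, T)),
        # T+1 observations per example (read as goldstandard[:, step, :]).
        'goldstandard': torch.randint(0, 10, (batch_size, T + 1, feat_dim)),
    }

    out = model(batch)
    feat_msg = out['comms'][0]           # binary observation message
    if model.apply_masc:
        act_msg = out['comms'][1]        # binary action message
    baseline = out['baseline']           # per-example value prediction

Detaching the Bernoulli samples blocks gradient flow through the discrete
messages, so the baseline output is the kind of value prediction typically
paired with policy-gradient training of such a channel.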