ttw/models/discrete.py
import torch


def forward(self, batch, greedy=False):
    batch_size = batch['actions'].size(0)

    # Embed each of the T+1 observations and scale it by a learned,
    # per-step write gate (sigmoid keeps each gate in (0, 1)).
    feat_emb = []
    max_steps = self.T + 1
    for step in range(max_steps):
        emb = self.goldstandard_emb(batch['goldstandard'][:, step, :])
        emb = emb * torch.sigmoid(self.obs_write_gate[step])
        feat_emb.append(emb)

    # With MASC enabled, also embed the T actions with their own write gates.
    act_emb = []
    if self.apply_masc:
        for step in range(self.T):
            emb = self.action_emb(batch['actions'][:, step])
            emb = emb * torch.sigmoid(self.act_write_gate[step])
            act_emb.append(emb)

    out = {'comms': [], 'probs': []}

    # Sum the gated observation embeddings into message logits, then sample
    # a binary message: each bit is an independent Bernoulli draw, detached
    # so gradients do not flow through the sampling step.
    feat_embeddings = sum(feat_emb)
    feat_logits = feat_embeddings
    feat_prob = torch.sigmoid(feat_logits).cpu()
    feat_msg = feat_prob.bernoulli().detach()
    out['probs'].append(feat_prob)
    out['comms'].append(feat_msg)

    # Same sampling procedure for the action message.
    if self.apply_masc:
        act_embeddings = sum(act_emb)
        act_logits = act_embeddings
        act_prob = torch.sigmoid(act_logits).cpu()
        act_msg = act_prob.bernoulli().detach()
        out['probs'].append(act_prob)
        out['comms'].append(act_msg)

    # Predict a value baseline from the message embeddings, concatenated
    # when both message types are present; cat along dim 1 already yields
    # shape (batch_size, 2 * vocab_sz).
    if self.apply_masc:
        embeddings = torch.cat([feat_embeddings, act_embeddings], 1).view(batch_size, 2 * self.vocab_sz)
    else:
        embeddings = feat_embeddings
    out['baseline'] = self.value_pred(embeddings)
    return out
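

# --- Minimal usage sketch (an assumption, not part of the original file) ---
# Exercises forward() with dummy inputs. Shapes are inferred from the
# indexing above: batch['goldstandard'] is read as [:, step, :] over T+1
# steps (a 2D slice per step, as an nn.EmbeddingBag-style module would
# accept), and batch['actions'] as [:, step] over T steps. The class name
# DiscreteGuide, its constructor arguments, and the dummy tensor sizes are
# hypothetical placeholders.
if __name__ == '__main__':
    T, batch_size = 2, 4
    model = DiscreteGuide(T=T, apply_masc=True, vocab_sz=500)  # hypothetical
    batch = {
        'goldstandard': torch.randint(0, 11, (batch_size, T + 1, 8)),
        'actions': torch.randint(0, 4, (batch_size, T)),
    }
    out = model.forward(batch)
    # out['comms']: list of binary message tensors (the observation message,
    # plus the action message when apply_masc is set); out['probs']: the
    # corresponding Bernoulli probabilities; out['baseline']: the predicted
    # value baseline.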