in model/reader.py [0:0]
def forward(self, inputs):
    """Run one policy/value step over a (T, B) batch of observations.

    Args:
        inputs: dict of tensors. 'name' is expected to be
            (T, B, H, W, placement, name_len) and is used only to recover
            the time/batch dimensions; 'valid' is a per-action validity
            mask reshaped to (T*B, num_actions). The remaining keys are
            consumed by the encode_* / fuse / compute_aux_loss helpers.

    Returns:
        dict with:
            policy_logits: (T, B, num_actions), invalid actions masked out
            baseline:      (T, B) value estimates
            action:        (T, B) chosen action indices
            aux_loss:      auxiliary loss from compute_aux_loss
    """
    name = inputs['name'].long()  # (T, B, H, W, placement, name_len)
    T, B, height, width, n_placement, n_text = name.size()

    # Encode each modality, then fuse into a single representation.
    cell = self.encode_cell(inputs)
    inv = self.encode_inv(inputs)
    wiki = self.encode_wiki(inputs)
    task = self.encode_task(inputs)
    rep = self.fuse(inputs, cell, inv, wiki, task)

    policy_logits = self.policy(rep)
    baseline = self.baseline(rep)

    # Push logits of invalid actions toward -1e20 so they receive ~zero
    # probability under softmax and can never win the argmax.
    action_mask = inputs['valid'].float().view(T * B, -1)
    policy_logits = policy_logits - (1 - action_mask) * 1e20

    if self.training:
        # Sample for exploration during training.
        probs = F.softmax(policy_logits, dim=1)
        action = torch.multinomial(probs, num_samples=1)
    else:
        # Greedy (deterministic) action at evaluation time.
        action = torch.argmax(policy_logits, dim=1)

    return dict(
        policy_logits=policy_logits.view(T, B, self.num_actions),
        baseline=baseline.view(T, B),
        action=action.view(T, B),
        aux_loss=self.compute_aux_loss(inputs, cell, inv, wiki, task),
    )