in egg/core/reinforce_wrappers.py [0:0]
def forward(self, x, aux_input=None):
    """Unroll the sender RNN for ``self.max_len`` steps and emit a message.

    The wrapped agent encodes ``x`` into the first layer's initial hidden
    state; deeper layers (and all LSTM cell states) start at zero.  At each
    step a symbol is drawn from a Categorical over the vocabulary —
    sampled during training (REINFORCE), greedy argmax at eval time — and
    its embedding is fed back as the next input.  A trailing EOS symbol
    (id 0) with zero log-prob and zero entropy is appended to every row.

    :param x: agent input; only ``x.size(0)`` (batch size) is used directly here.
    :param aux_input: optional auxiliary input forwarded to the wrapped agent.
    :return: tuple ``(sequence, logits, entropy)``, each of shape
        ``(batch, max_len + 1)``; ``logits`` holds the log-probabilities of
        the chosen symbols.
    """
    # First layer's hidden state comes from the wrapped agent; the
    # remaining layers start from zeros of the same shape.
    prev_hidden = [self.agent(x, aux_input)]
    prev_hidden.extend(
        torch.zeros_like(prev_hidden[0]) for _ in range(self.num_layers - 1)
    )
    prev_c = [
        torch.zeros_like(prev_hidden[0]) for _ in range(self.num_layers)
    ]  # cell states; only consulted for LSTM cells

    # Start-of-sequence embedding, replicated across the batch.
    # (renamed from `input` to avoid shadowing the builtin)
    rnn_input = torch.stack([self.sos_embedding] * x.size(0))

    sequence = []
    logits = []
    entropy = []

    for _ in range(self.max_len):
        # Pass the current input up through the stack of cells.
        for i, layer in enumerate(self.cells):
            if isinstance(layer, nn.LSTMCell):
                h_t, c_t = layer(rnn_input, (prev_hidden[i], prev_c[i]))
                prev_c[i] = c_t
            else:
                h_t = layer(rnn_input, prev_hidden[i])
            prev_hidden[i] = h_t
            rnn_input = h_t

        # Distribution over the vocabulary from the top layer's output.
        step_logits = F.log_softmax(self.hidden_to_output(h_t), dim=1)
        distr = Categorical(logits=step_logits)
        entropy.append(distr.entropy())

        if self.training:
            x = distr.sample()  # stochastic choice for REINFORCE
        else:
            x = step_logits.argmax(dim=1)  # greedy decoding at eval time
        logits.append(distr.log_prob(x))

        # The chosen symbol's embedding is next step's input.
        rnn_input = self.embedding(x)
        sequence.append(x)

    # (max_len, batch) -> (batch, max_len)
    sequence = torch.stack(sequence).permute(1, 0)
    logits = torch.stack(logits).permute(1, 0)
    entropy = torch.stack(entropy).permute(1, 0)

    # Append EOS (symbol 0) with zero log-prob/entropy; allocate directly
    # on the right device instead of CPU-alloc + copy via `.to(...)`.
    zeros = torch.zeros((sequence.size(0), 1), device=sequence.device)

    sequence = torch.cat([sequence, zeros.long()], dim=1)
    logits = torch.cat([logits, zeros], dim=1)
    entropy = torch.cat([entropy, zeros], dim=1)

    return sequence, logits, entropy