in torchbeast/polybeast_learner.py [0:0]
def forward(self, inputs, core_state):
    """Run the network over a (time, batch) rollout.

    Args:
        inputs: Mapping with at least ``"frame"`` and ``"reward"``; ``"done"``
            is also read when ``self.use_lstm`` is set. The first two
            dimensions of ``inputs["frame"]`` are taken as time ``T`` and
            batch ``B``. Frames are divided by 255, so they are presumably
            8-bit pixel values (confirm against the caller). ``"reward"``
            must hold ``T * B`` elements. ``"done"`` must support ``~``
            (presumably a boolean ``(T, B)`` tensor; confirm).
        core_state: Recurrent state, a nest of tensors that broadcast against
            a ``(1, B, 1)`` mask. It is returned untouched when
            ``self.use_lstm`` is false.

    Returns:
        ``((action, policy_logits, baseline), core_state)`` where ``action``
        and ``baseline`` are viewed as ``(T, B)`` and ``policy_logits`` as
        ``(T, B, self.num_actions)``.
    """
    frames = inputs["frame"]
    time_steps, batch_size, *_ = frames.shape
    flat_size = time_steps * batch_size

    # Fold time and batch into one leading axis, then rescale by 1/255.
    x = torch.flatten(frames, 0, 1).float() / 255.0

    for i, fconv in enumerate(self.feat_convs):
        x = fconv(x)
        # Each stage has two residual sub-blocks; each one adds its own
        # input back onto its output in place.
        for residual_stack in (self.resnet1, self.resnet2):
            skip = x
            x = residual_stack[i](x)
            x += skip

    x = F.relu(x).view(flat_size, -1)
    x = F.relu(self.fc(x))

    # The reward, clamped to [-1, 1], is appended as one extra feature.
    reward_feature = torch.clamp(inputs["reward"], -1, 1).view(flat_size, 1)
    core_input = torch.cat([x, reward_feature], dim=-1)

    if not self.use_lstm:
        core_output = core_input
    else:
        per_step_inputs = core_input.view(time_steps, batch_size, -1).unbind()
        per_step_keep = (~inputs["done"]).float().unbind()
        step_outputs = []
        for step_input, keep in zip(per_step_inputs, per_step_keep):
            # Zero the recurrent state of every batch entry whose episode
            # ended. The mask is reshaped to (1, B, 1) so it broadcasts
            # over (num_layers, B, hidden_size) states.
            keep = keep.view(1, -1, 1)
            core_state = nest.map(keep.mul, core_state)
            step_output, core_state = self.core(
                step_input.unsqueeze(0), core_state
            )
            step_outputs.append(step_output)
        core_output = torch.flatten(torch.cat(step_outputs), 0, 1)

    policy_logits = self.policy(core_output)
    baseline = self.baseline(core_output)

    if self.training:
        # Sample one action per row from the softmax over the logits.
        action_probs = F.softmax(policy_logits, dim=1)
        action = torch.multinomial(action_probs, num_samples=1)
    else:
        # Act greedily outside of training.
        action = torch.argmax(policy_logits, dim=1)

    # Restore the separate time and batch axes on every output.
    logits_tb = policy_logits.view(time_steps, batch_size, self.num_actions)
    baseline_tb = baseline.view(time_steps, batch_size)
    action_tb = action.view(time_steps, batch_size)
    return (action_tb, logits_tb, baseline_tb), core_state