in torchbeast/monobeast.py [0:0]
def forward(self, inputs, core_state=()):
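    """Run the model on a rollout of [T, B, ...] inputs.

    `inputs` is a dict of tensors ("frame", "reward", "done",
    "last_action") with leading time and batch dimensions; `core_state`
    is the recurrent state of the LSTM core (an empty tuple for the
    feed-forward variant). Returns a dict with "policy_logits",
    "baseline" and "action", plus the updated core state.
    """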
    x = inputs["frame"]  # [T, B, C, H, W].
    T, B, *_ = x.shape
    x = torch.flatten(x, 0, 1)  # Merge time and batch.
    x = x.float() / 255.0
    x = F.relu(self.conv1(x))
    x = F.relu(self.conv2(x))
    x = F.relu(self.conv3(x))
    x = x.view(T * B, -1)
    x = F.relu(self.fc(x))
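
    # Append the clipped reward and a one-hot encoding of the last action
    # to the convolutional features before feeding the core.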
    one_hot_last_action = F.one_hot(
        inputs["last_action"].view(T * B), self.num_actions
    ).float()
    clipped_reward = torch.clamp(inputs["reward"], -1, 1).view(T * B, 1)
    core_input = torch.cat([x, clipped_reward, one_hot_last_action], dim=-1)
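
    # Run the LSTM core over the sequence if enabled; otherwise the features
    # are passed straight through to the policy and baseline heads.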
    if self.use_lstm:
        core_input = core_input.view(T, B, -1)
        core_output_list = []
        notdone = (~inputs["done"]).float()
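        # Unbind along the time dimension and step the core one time step
        # per iteration.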
        for input, nd in zip(core_input.unbind(), notdone.unbind()):
            # Reset core state to zero whenever an episode ended.
            # Make `done` broadcastable with (num_layers, B, hidden_size)
            # states:
            nd = nd.view(1, -1, 1)
            core_state = tuple(nd * s for s in core_state)
            output, core_state = self.core(input.unsqueeze(0), core_state)
            core_output_list.append(output)
        core_output = torch.flatten(torch.cat(core_output_list), 0, 1)
    else:
        core_output = core_input
        core_state = tuple()
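
    # Policy logits and value baseline are computed from the flattened
    # (T * B, ...) core output.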
    policy_logits = self.policy(core_output)
    baseline = self.baseline(core_output)

    if self.training:
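        # Sample an action from the softmax policy during training.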
        action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1)
    else:
        # Don't sample when testing.
        action = torch.argmax(policy_logits, dim=1)
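
    # Restore the time and batch dimensions before returning.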
    policy_logits = policy_logits.view(T, B, self.num_actions)
    baseline = baseline.view(T, B)
    action = action.view(T, B)

    return (
        dict(policy_logits=policy_logits, baseline=baseline, action=action),
        core_state,
    )
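
A minimal usage sketch, not from monobeast.py: it assumes `model` is an instance of the surrounding network class and that it provides the `initial_state(batch_size)` helper defined elsewhere in monobeast.py; the tensor shapes and constants below are illustrative only.

    import torch

    T, B = 20, 8  # hypothetical unroll length and batch size
    inputs = dict(
        frame=torch.zeros(T, B, 4, 84, 84, dtype=torch.uint8),  # [T, B, C, H, W]
        reward=torch.zeros(T, B),
        done=torch.zeros(T, B, dtype=torch.bool),
        last_action=torch.zeros(T, B, dtype=torch.int64),
    )
    core_state = model.initial_state(B)  # assumed helper; empty tuple if no LSTM
    outputs, core_state = model(inputs, core_state)
    # outputs["policy_logits"]: [T, B, num_actions]
    # outputs["baseline"] and outputs["action"]: [T, B]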