in train/models.py
def forward(self, last_actions, env_outputs, core_state, unroll=False):
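    """Run the model over a [T, B, ...] unroll of environment outputs.

    `env_outputs` unpacks into (observation, reward, done). When `unroll`
    is False the inputs describe a single step, so a T=1 time dimension is
    added here (and squeezed off the training outputs again at the end).
    In training mode returns (dict(action, policy_logits, baseline),
    core_state); in eval mode returns (action, core_state).
    Note: `last_actions` is not used in this forward pass.
    """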
    if not unroll:
        # [T=1, B, ...].
        env_outputs = nest.map(lambda t: t.unsqueeze(0), env_outputs)
    observation, reward, done = env_outputs
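    # T: time/unroll dimension (1 when `unroll=False`), B: batch dimension.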
    T, B, *_ = observation.shape
    x = torch.flatten(observation, 0, 1)  # Merge time and batch.
    x = x.view(T * B, -1)
    # Separate the job ID from the rest of the observation.
    job_id = x[:, -1]
    x = x[:, 0:-1]
    if self.use_job_id_in_network_input:
        x = self.concat_to_job_id(x, job_id)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    # reward = torch.clamp(reward, -1, 1).view(T * B, 1).float()
    reward = reward.view(T * B, 1).float()
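    # Append the per-step reward to the feature vector fed to the core.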
    core_input = torch.cat([x, reward], dim=1)
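
    # With an LSTM core, step through time so the hidden state can be
    # reset at episode boundaries.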
    if self.use_lstm:
        core_input = core_input.view(T, B, -1)
        core_output_list = []
        notdone = (~done).float()
        notdone.unsqueeze_(-1)  # [T, B, H=1] for broadcasting.
        for input_t, notdone_t in zip(core_input.unbind(), notdone.unbind()):
            # When `done` is True it means this is the first step in a new
            # episode => reset the internal state to zero.
            core_state = nest.map(notdone_t.mul, core_state)
            output_t, core_state = self.core(input_t, core_state)
            # nn.LSTMCell returns (h, c) as two values; re-pack them into
            # the state tuple expected on the next step.
            core_state = (output_t, core_state)
            core_output_list.append(output_t)  # [[B, H], [B, H], ...].
        core_output = torch.cat(core_output_list)  # [T * B, H].
    else:
        core_output = core_input
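
    # Policy (actor) head, optionally conditioned on the job ID.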
    actor_input = (
        self.concat_to_job_id(core_output, job_id)
        if self.use_job_id_in_actor_head
        else core_output
    )
    policy_logits = self.policy(actor_input)
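
    # In training mode, sample an action and compute the value baseline;
    # in eval mode, act greedily and skip the critic.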
    if self.training:
        action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1)
        critic_input = (
            self.concat_to_job_id(core_output, job_id)
            if self.use_job_id_in_critic_head
            else core_output
        )
        baseline = self.baseline(critic_input)
        baseline = baseline.view(T, B)
    else:
        # Don't sample when testing.
        action = torch.argmax(policy_logits, dim=1)
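
    # Restore the [T, B, ...] layout of the outputs.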
    policy_logits = policy_logits.view(T, B, self.num_actions)
    action = action.view(T, B)

    if self.training:
        outputs = dict(
            action=action, policy_logits=policy_logits, baseline=baseline
        )
        if not unroll:
            outputs = nest.map(lambda t: t.squeeze(0), outputs)
        return outputs, core_state
    else:
        # In eval mode, we just return (action, core_state). PyTorch doesn't
        # support jit tracing output dicts.
        return action, core_state