def forward()

in train/models.py


    def forward(self, last_actions, env_outputs, core_state, unroll=False):
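        """Compute policy logits, an action and (in training) a baseline.

        `env_outputs` is an (observation, reward, done) tuple. If `unroll` is
        False the inputs are treated as a single step and a T=1 time dimension
        is added internally (and squeezed out of the training outputs again).
        `core_state` is the recurrent state of the LSTM core; `last_actions`
        is accepted but unused here.

        Returns `(outputs, core_state)` in training mode, where `outputs` is a
        dict with `action`, `policy_logits` and `baseline`, and
        `(action, core_state)` in eval mode.
        """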
        if not unroll:
            # [T=1, B, ...].
            env_outputs = nest.map(lambda t: t.unsqueeze(0), env_outputs)

        observation, reward, done = env_outputs

        T, B, *_ = observation.shape
        x = torch.flatten(observation, 0, 1)  # Merge time and batch.
        x = x.view(T * B, -1)

        # Separate the job ID from the rest of the observation.
        job_id = x[:, -1]
        x = x[:, 0:-1]

        if self.use_job_id_in_network_input:
            x = self.concat_to_job_id(x, job_id)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        # reward = torch.clamp(reward, -1, 1).view(T * B, 1).float()
        reward = reward.view(T * B, 1).float()
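        # Concatenate the reward with the features as input to the recurrent core.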
        core_input = torch.cat([x, reward], dim=1)

        if self.use_lstm:
            core_input = core_input.view(T, B, -1)
            core_output_list = []
            notdone = (~done).float()
            notdone.unsqueeze_(-1)  # [T, B, H=1] for broadcasting.

            for input_t, notdone_t in zip(core_input.unbind(), notdone.unbind()):
                # When `done` is True it means this is the first step in a new
                # episode => reset the internal state to zero.
                core_state = nest.map(notdone_t.mul, core_state)
                output_t, core_state = self.core(input_t, core_state)
                # nn.LSTMCell returns an (h, c) tuple, so the unpacking above
                # leaves only c in `core_state`; rebuild the full (h, c) pair.
                core_state = (output_t, core_state)
                core_output_list.append(output_t)  # [[B, H], [B, H], ...].
            core_output = torch.cat(core_output_list)  # [T * B, H].
        else:
            core_output = core_input

        actor_input = (
            self.concat_to_job_id(core_output, job_id)
            if self.use_job_id_in_actor_head
            else core_output
        )
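        # Unnormalized action scores, shape [T * B, num_actions].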
        policy_logits = self.policy(actor_input)

        if self.training:
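            # Sample an action from the softmax policy distribution.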
            action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1)

            critic_input = (
                self.concat_to_job_id(core_output, job_id)
                if self.use_job_id_in_critic_head
                else core_output
            )
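            # Value estimate used as the baseline; reshaped to [T, B] below.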
            baseline = self.baseline(critic_input)

            baseline = baseline.view(T, B)

        else:
            # Don't sample when testing.
            action = torch.argmax(policy_logits, dim=1)

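        # Restore the time dimension: [T * B, ...] -> [T, B, ...].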
        policy_logits = policy_logits.view(T, B, self.num_actions)
        action = action.view(T, B)

        if self.training:
            outputs = dict(
                action=action, policy_logits=policy_logits, baseline=baseline
            )
            if not unroll:
                outputs = nest.map(lambda t: t.squeeze(0), outputs)
            return outputs, core_state
        else:
            # In eval mode, return only (action, core_state): torch.jit tracing
            # does not support dict outputs.
            return action, core_state
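
The per-step loop above relies on a quirk of nn.LSTMCell: its forward returns the (h, c) tuple itself, so unpacking it as `output_t, core_state` and re-wrapping as `(output_t, core_state)` keeps the state pair intact. Below is a minimal, self-contained sketch of that loop, with made-up sizes (T, B, D, H), random inputs and a hand-picked `done` pattern that are not part of the model, just to illustrate the episode-reset masking and the LSTMCell state handling.

    import torch
    from torch import nn

    T, B, D, H = 3, 2, 8, 16  # Hypothetical unroll length, batch, feature and hidden sizes.
    core = nn.LSTMCell(D, H)
    core_state = (torch.zeros(B, H), torch.zeros(B, H))  # Initial (h, c).

    core_input = torch.randn(T, B, D)
    done = torch.tensor([[False, False], [True, False], [False, False]])
    notdone = (~done).float().unsqueeze(-1)  # [T, B, 1] for broadcasting.

    outputs = []
    for input_t, notdone_t in zip(core_input.unbind(), notdone.unbind()):
        # Zero the recurrent state wherever an episode just ended.
        core_state = tuple(notdone_t * s for s in core_state)
        # nn.LSTMCell returns (h, c): keep h as the step output, rebuild the state tuple.
        output_t, core_state = core(input_t, core_state)
        core_state = (output_t, core_state)
        outputs.append(output_t)

    core_output = torch.cat(outputs)  # [T * B, H].
    print(core_output.shape)          # torch.Size([6, 16])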