in ddppo_agents.py [0:0]
def act(self, observations: Observations) -> Dict[str, int]:
    """Choose an action for one observation using the wrapped policy.

    The observation is wrapped in a one-element list, batched onto
    ``self.device`` and passed through ``self.obs_transforms`` before
    ``self.actor_critic.act`` is queried with ``deterministic=False``.
    Gradient tracking is disabled for the policy call.

    Side effects on the agent:
      * ``self.test_recurrent_hidden_states`` is replaced by the fourth
        value returned from the policy.
      * ``self.not_done_masks`` is filled with ``True`` in place.
      * ``self.prev_actions`` is overwritten in place with the new actions.

    Args:
        observations: The observation for the current step.

    Returns:
        A dict with a single key ``"action"`` holding the Python scalar
        read from position ``[0][0]`` of the actions returned by the policy.
    """
    obs_batch = apply_obs_transforms_batch(
        batch_obs([observations], device=self.device),
        self.obs_transforms,
    )

    with torch.no_grad():
        # The policy returns four values; only the actions (second) and the
        # updated recurrent state (fourth) are used here.
        _, chosen_actions, _, next_hidden_states = self.actor_critic.act(
            obs_batch,
            self.test_recurrent_hidden_states,
            self.prev_actions,
            self.not_done_masks,
            deterministic=False,
        )
        self.test_recurrent_hidden_states = next_hidden_states

        # Keep the masks flagged as "not done" from now on, until a reset
        # at the end of the episode is invoked.
        self.not_done_masks.fill_(True)
        self.prev_actions.copy_(chosen_actions)  # type: ignore

    action_value = chosen_actions[0][0].item()
    return {"action": action_value}