in habitat_baselines/rl/ppo/ppo_trainer.py [0:0]
def _collect_environment_result(self, buffer_index: int = 0):
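    # Gather the results of previously-issued async env steps for one
    # buffer's slice of the environments, update episode statistics,
    # and write the transition into the rollout storage.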
    num_envs = self.envs.num_envs
    env_slice = slice(
        int(buffer_index * num_envs / self._nbuffers),
        int((buffer_index + 1) * num_envs / self._nbuffers),
    )
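
    # Wait for the asynchronous steps of the envs in this slice and unpack
    # the per-env (observation, reward, done, info) results.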
    t_step_env = time.time()
    outputs = [
        self.envs.wait_step_at(index_env)
        for index_env in range(env_slice.start, env_slice.stop)
    ]

    observations, rewards_l, dones, infos = [
        list(x) for x in zip(*outputs)
    ]

    self.env_time += time.time() - t_step_env
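
    # Batch the new observations, apply the observation transforms, and
    # build the per-env reward tensor of shape [num_envs_in_slice, 1].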
    t_update_stats = time.time()
    batch = batch_obs(
        observations, device=self.device, cache=self._obs_batching_cache
    )
    batch = apply_obs_transforms_batch(batch, self.obs_transforms)  # type: ignore

    rewards = torch.tensor(
        rewards_l,
        dtype=torch.float,
        device=self.current_episode_reward.device,
    )
    rewards = rewards.unsqueeze(1)
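
    # not_done_masks is True while an episode is still running;
    # done_masks marks the envs whose episode just ended.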
    not_done_masks = torch.tensor(
        [[not done] for done in dones],
        dtype=torch.bool,
        device=self.current_episode_reward.device,
    )
    done_masks = torch.logical_not(not_done_masks)
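
    # Accumulate the step reward per episode, and fold the accumulated
    # return into the running stats only for episodes that just finished.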
    self.current_episode_reward[env_slice] += rewards
    current_ep_reward = self.current_episode_reward[env_slice]
    self.running_episode_stats["reward"][env_slice] += current_ep_reward.where(done_masks, current_ep_reward.new_zeros(()))  # type: ignore
    self.running_episode_stats["count"][env_slice] += done_masks.float()  # type: ignore
    for k, v_k in self._extract_scalars_from_infos(infos).items():
        v = torch.tensor(
            v_k,
            dtype=torch.float,
            device=self.current_episode_reward.device,
        ).unsqueeze(1)
        if k not in self.running_episode_stats:
            self.running_episode_stats[k] = torch.zeros_like(
                self.running_episode_stats["count"]
            )
        self.running_episode_stats[k][env_slice] += v.where(done_masks, v.new_zeros(()))  # type: ignore
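
    # Reset the per-episode reward accumulator for envs that just finished.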
    self.current_episode_reward[env_slice].masked_fill_(done_masks, 0.0)
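
    # With a frozen (static) visual encoder, compute the visual features
    # once under no_grad and attach them to the observation batch.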
    if self._static_encoder:
        with torch.no_grad():
            batch["visual_features"] = self._encoder(batch)
    self.rollouts.insert(
        next_observations=batch,
        rewards=rewards,
        next_masks=not_done_masks,
        buffer_index=buffer_index,
    )

    self.rollouts.advance_rollout(buffer_index)

    self.pth_time += time.time() - t_update_stats

    return env_slice.stop - env_slice.start