exploring_exploration/utils/storage.py
def recurrent_generator(self, advantages, num_mini_batch):
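    """Yield recurrent mini-batches for PPO from the stored rollouts.

    The num_processes parallel trajectories are shuffled and split into
    num_mini_batch groups of whole trajectories. Each yielded sample holds
    (T * N, ...)-shaped tensors covering the T rollout steps of its N
    environments, plus the (N, -1) initial hidden states needed to re-run
    the recurrent policy over those sequences.
    """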
    num_processes = self.rewards.size(1)
    assert num_processes >= num_mini_batch, (
        "PPO requires the number of processes ({}) "
        "to be greater than or equal to the number of "
        "PPO mini batches ({}).".format(num_processes, num_mini_batch)
    )
    # Shuffle the process (environment) indices and split them evenly across
    # mini-batches; each mini-batch keeps whole trajectories so the recurrent
    # state can be propagated through time.
    num_envs_per_batch = num_processes // num_mini_batch
    perm = torch.randperm(num_processes)
    for start_ind in range(0, num_processes, num_envs_per_batch):
        obs_im_batch = []
        # Map observations are only gathered when the encoder consumes maps
        # alongside RGB frames.
        if self.encoder_type == "rgb+map":
            obs_sm_batch = []
            obs_lm_batch = []
        else:
            obs_sm_batch = None
            obs_lm_batch = None
        recurrent_hidden_states_batch = []
        actions_batch = []
        value_preds_batch = []
        return_batch = []
        masks_batch = []
        collisions_batch = []
        action_masks_batch = []
        old_action_log_probs_batch = []
        adv_targ = []
        for offset in range(num_envs_per_batch):
            ind = perm[start_ind + offset]
            # The [:-1] slices drop the final entry of the buffers that store
            # (num_steps + 1) values, so every tensor covers the same T steps.
            obs_im_batch.append(self.obs_im[:-1, ind])
            if self.encoder_type == "rgb+map":
                obs_sm_batch.append(self.obs_sm[:-1, ind])
                obs_lm_batch.append(self.obs_lm[:-1, ind])
            # Only the hidden state at the start of the rollout is kept; the
            # later states are recomputed when the policy is re-evaluated.
            recurrent_hidden_states_batch.append(
                self.recurrent_hidden_states[0:1, ind]
            )
            actions_batch.append(self.actions[:, ind])
            value_preds_batch.append(self.value_preds[:-1, ind])
            return_batch.append(self.returns[:-1, ind])
            masks_batch.append(self.masks[:-1, ind])
            collisions_batch.append(self.collisions[:-1, ind])
            action_masks_batch.append(self.action_masks[:, ind])
            old_action_log_probs_batch.append(self.action_log_probs[:, ind])
            adv_targ.append(advantages[:, ind])
        T, N = self.num_steps, num_envs_per_batch
        # These are all tensors of size (T, N, -1)
        obs_im_batch = torch.stack(obs_im_batch, 1)
        if self.encoder_type == "rgb+map":
            obs_sm_batch = torch.stack(obs_sm_batch, 1)
            obs_lm_batch = torch.stack(obs_lm_batch, 1)
        actions_batch = torch.stack(actions_batch, 1)
        value_preds_batch = torch.stack(value_preds_batch, 1)
        return_batch = torch.stack(return_batch, 1)
        masks_batch = torch.stack(masks_batch, 1)
        collisions_batch = torch.stack(collisions_batch, 1)
        action_masks_batch = torch.stack(action_masks_batch, 1)
        old_action_log_probs_batch = torch.stack(old_action_log_probs_batch, 1)
        adv_targ = torch.stack(adv_targ, 1)
        # States is just an (N, -1) tensor
        recurrent_hidden_states_batch = torch.stack(
            recurrent_hidden_states_batch, 1
        ).view(N, -1)
        # Flatten the (T, N, ...) tensors to (T * N, ...)
        obs_im_batch = _flatten_helper(T, N, obs_im_batch)
        if self.encoder_type == "rgb+map":
            obs_sm_batch = _flatten_helper(T, N, obs_sm_batch)
            obs_lm_batch = _flatten_helper(T, N, obs_lm_batch)
        actions_batch = _flatten_helper(T, N, actions_batch)
        value_preds_batch = _flatten_helper(T, N, value_preds_batch)
        return_batch = _flatten_helper(T, N, return_batch)
        masks_batch = _flatten_helper(T, N, masks_batch)
        collisions_batch = _flatten_helper(T, N, collisions_batch)
        action_masks_batch = _flatten_helper(T, N, action_masks_batch)
        old_action_log_probs_batch = _flatten_helper(
            T, N, old_action_log_probs_batch
        )
        adv_targ = _flatten_helper(T, N, adv_targ)
        yield (
            obs_im_batch,
            obs_sm_batch,
            obs_lm_batch,
            recurrent_hidden_states_batch,
            actions_batch,
            value_preds_batch,
            return_batch,
            masks_batch,
            collisions_batch,
            action_masks_batch,
            old_action_log_probs_batch,
            adv_targ,
            T,
            N,
        )
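
_flatten_helper is not shown in this excerpt; a definition consistent with its use above would collapse the first two dimensions, e.g. _tensor.view(T * N, *_tensor.size()[2:]).

Below is a minimal consumption sketch (hypothetical, not from the repo): it assumes rollouts is the storage object defining this generator, actor_critic.evaluate_actions stands in for the policy's scoring call, and clip_param / num_mini_batch come from the PPO config.

# Advantages are typically the stored returns minus the value predictions.
advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1]
for sample in rollouts.recurrent_generator(advantages, num_mini_batch):
    (
        obs_im_batch,
        obs_sm_batch,
        obs_lm_batch,
        recurrent_hidden_states_batch,
        actions_batch,
        value_preds_batch,
        return_batch,
        masks_batch,
        collisions_batch,
        action_masks_batch,
        old_action_log_probs_batch,
        adv_targ,
        T,
        N,
    ) = sample
    # Re-score the stored actions under the current policy; the recurrence is
    # re-run from the initial hidden states (this signature is an assumption).
    values, action_log_probs, dist_entropy = actor_critic.evaluate_actions(
        obs_im_batch, recurrent_hidden_states_batch, masks_batch, actions_batch
    )
    # Standard clipped PPO surrogate over the flattened (T * N, ...) batch.
    ratio = torch.exp(action_log_probs - old_action_log_probs_batch)
    surr1 = ratio * adv_targ
    surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * adv_targ
    action_loss = -torch.min(surr1, surr2).mean()
    value_loss = 0.5 * (return_batch - values).pow(2).mean()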