in gala/storage.py
# Method of the rollout storage class; uses
# `from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler`.
def feed_forward_generator(self,
                           advantages,
                           num_mini_batch=None,
                           mini_batch_size=None):
    # Rollouts are stored as (num_steps, num_processes, ...) tensors;
    # flatten the first two dims into a single batch dimension.
    num_steps, num_processes = self.rewards.size()[0:2]
    batch_size = num_processes * num_steps
    if mini_batch_size is None:
        assert batch_size >= num_mini_batch, (
            "PPO requires the number of processes ({}) "
            "* number of steps ({}) = {} "
            "to be greater than or equal to the number of PPO mini batches ({})."
            "".format(num_processes, num_steps, num_processes * num_steps,
                      num_mini_batch))
        mini_batch_size = batch_size // num_mini_batch
    # Sample disjoint mini-batches of flattened indices in random order.
    sampler = BatchSampler(
        SubsetRandomSampler(range(batch_size)),
        mini_batch_size,
        drop_last=True)
    for indices in sampler:
        # obs, recurrent_hidden_states, value_preds, returns, and masks hold
        # num_steps + 1 entries (the extra bootstrap step), so drop the last
        # entry before flattening and indexing.
        obs_batch = self.obs[:-1].view(-1, *self.obs.size()[2:])[indices]
        recurrent_hidden_states_batch = self.recurrent_hidden_states[:-1].view(
            -1, self.recurrent_hidden_states.size(-1))[indices]
        actions_batch = self.actions.view(-1,
                                          self.actions.size(-1))[indices]
        value_preds_batch = self.value_preds[:-1].view(-1, 1)[indices]
        return_batch = self.returns[:-1].view(-1, 1)[indices]
        masks_batch = self.masks[:-1].view(-1, 1)[indices]
        old_action_log_probs_batch = self.action_log_probs.view(-1,
                                                                1)[indices]
        if advantages is None:
            adv_targ = None
        else:
            adv_targ = advantages.view(-1, 1)[indices]
        yield (obs_batch, recurrent_hidden_states_batch, actions_batch,
               value_preds_batch, return_batch, masks_batch,
               old_action_log_probs_batch, adv_targ)
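
For reference, a minimal, self-contained sketch of the flatten-and-sample pattern used above. It assumes only PyTorch; the tensor name `obs`, the shapes, and all sizes are illustrative placeholders, not values taken from gala.

import torch
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler

num_steps, num_processes, obs_dim = 4, 3, 5
# One extra slot holds the final bootstrap observation, dropped
# below via obs[:-1], mirroring the num_steps + 1 buffers above.
obs = torch.randn(num_steps + 1, num_processes, obs_dim)

batch_size = num_steps * num_processes
num_mini_batch = 2
mini_batch_size = batch_size // num_mini_batch

sampler = BatchSampler(
    SubsetRandomSampler(range(batch_size)),
    mini_batch_size,
    drop_last=True)

for indices in sampler:
    # Merge (steps, processes) into one batch dimension, then gather
    # the randomly sampled rows for this mini-batch.
    obs_batch = obs[:-1].view(-1, obs_dim)[indices]
    assert obs_batch.shape == (mini_batch_size, obs_dim)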