in gala/storage.py [0:0]
def __init__(self, num_steps, num_processes, obs_shape, action_space,
             recurrent_hidden_state_size):
    """Allocate the rollout buffers (zeros, except the masks which are ones).

    Args:
        num_steps: Number of steps stored per rollout (first axis of the
            per-step buffers).
        num_processes: Number of parallel processes; size of the second
            axis of every buffer.
        obs_shape: Shape of a single observation; appended as trailing
            dimensions of ``self.obs``.
        action_space: Action space object. It is treated as discrete when
            its class name is ``'Discrete'`` (presumably a gym-style space;
            the match is on the name, not ``isinstance``). Otherwise
            ``action_space.shape[0]`` is used as the per-step action width.
        recurrent_hidden_state_size: Width of the recurrent hidden state.

    Buffer layout:
        * ``rewards``, ``action_log_probs`` and ``actions`` have ``num_steps``
          entries along the first axis.
        * ``obs``, ``recurrent_hidden_states``, ``value_preds``, ``returns``,
          ``masks`` and ``bad_masks`` have ``num_steps + 1`` entries. The
          extra slot is presumably for the state after the final step
          (e.g. for bootstrapping); confirm against callers.
    """
    # Evaluate the discrete check once instead of repeating it.
    is_discrete = action_space.__class__.__name__ == 'Discrete'
    state_len = num_steps + 1

    self.obs = torch.zeros(state_len, num_processes, *obs_shape)
    self.recurrent_hidden_states = torch.zeros(
        state_len, num_processes, recurrent_hidden_state_size)
    self.rewards = torch.zeros(num_steps, num_processes, 1)
    self.value_preds = torch.zeros(state_len, num_processes, 1)
    self.returns = torch.zeros(state_len, num_processes, 1)
    self.action_log_probs = torch.zeros(num_steps, num_processes, 1)

    # Discrete actions are stored as one integer index per step. Allocate
    # them as long directly rather than building a float tensor and then
    # converting it with .long(). Other spaces store a float vector of
    # width shape[0].
    if is_discrete:
        self.actions = torch.zeros(
            num_steps, num_processes, 1, dtype=torch.long)
    else:
        self.actions = torch.zeros(
            num_steps, num_processes, action_space.shape[0])

    self.masks = torch.ones(state_len, num_processes, 1)
    # Masks that indicate whether it's a true terminal state
    # or time limit end state
    self.bad_masks = torch.ones(state_len, num_processes, 1)
    self.num_steps = num_steps
    self.step = 0