in level_replay/storage.py [0:0]
def __init__(self, num_steps, num_processes, obs_shape, action_space,
             recurrent_hidden_state_size, split_ratio=0.05):
    """Allocate rollout-storage buffers for on-policy RL training.

    Args:
        num_steps: number of environment steps stored per rollout.
        num_processes: number of parallel environments.
        obs_shape: shape (tuple) of a single observation.
        action_space: gym-style action space. Its ``.n`` attribute is
            read unconditionally for the action-logit buffer, so a
            ``Discrete`` space is effectively assumed — non-discrete
            spaces fail at that allocation (same as the original code).
        recurrent_hidden_state_size: width of the recurrent hidden state.
        split_ratio: stored as-is; not used within this constructor.
    """
    # Time-indexed buffers. The "+ 1" rows hold the bootstrap entry for
    # the state that follows the final collected step.
    self.obs = torch.zeros(num_steps + 1, num_processes, *obs_shape)
    self.recurrent_hidden_states = torch.zeros(
        num_steps + 1, num_processes, recurrent_hidden_state_size)
    self.rewards = torch.zeros(num_steps, num_processes, 1)
    self.value_preds = torch.zeros(num_steps + 1, num_processes, 1)
    self.returns = torch.zeros(num_steps + 1, num_processes, 1)
    self.action_log_probs = torch.zeros(num_steps, num_processes, 1)
    # NOTE(review): reads action_space.n before the Discrete check below,
    # so a non-Discrete space raises here — behavior kept from the
    # original; confirm callers only ever pass Discrete spaces.
    self.action_log_dist = torch.zeros(num_steps, num_processes, action_space.n)

    # Single discrete-space test (the original duplicated this check).
    is_discrete = action_space.__class__.__name__ == 'Discrete'
    action_shape = 1 if is_discrete else action_space.shape[0]
    self.actions = torch.zeros(num_steps, num_processes, action_shape)
    if is_discrete:
        # Discrete actions are indices, so store them as integers.
        self.actions = self.actions.long()

    self.masks = torch.ones(num_steps + 1, num_processes, 1)
    # Masks that indicate whether it's a true terminal state
    # or time limit end state
    self.bad_masks = torch.ones(num_steps + 1, num_processes, 1)
    self.level_seeds = torch.zeros(num_steps, num_processes, 1, dtype=torch.int)

    self.num_steps = num_steps
    self.step = 0  # write cursor into the time dimension
    self.split_ratio = split_ratio