in level_replay/envs.py [0:0]
def reset(self):
    """Reset all sub-environments and return the observations as a tensor.

    If a level sampler is attached, each sub-environment is first re-seeded,
    in index order, with a level drawn from
    ``self.level_sampler.sample('sequential')``.

    Returns:
        With no level sampler, only ``obs``. Otherwise the tuple
        ``(obs, seeds)``, where ``seeds`` is a ``torch.int`` tensor holding
        the seed given to each sub-environment. ``obs`` is the array from
        ``self.venv.reset()`` converted to a float tensor on ``self.device``
        and divided by 255.
    """
    if self.level_sampler:
        num_envs = self.venv.num_envs
        seeds = torch.zeros(num_envs, dtype=torch.int)
        for env_idx in range(num_envs):
            level_seed = self.level_sampler.sample('sequential')
            seeds[env_idx] = level_seed
            self.venv.seed(level_seed, env_idx)

    raw_obs = self.venv.reset()
    # A size-3 axis 1 is taken to mean channels already come first.
    # Otherwise the last axis is moved to position 1 (presumably NHWC -> NCHW;
    # confirm against the wrapped env's observation layout).
    if raw_obs.shape[1] != 3:
        raw_obs = raw_obs.transpose(0, 3, 1, 2)
    obs = torch.from_numpy(raw_obs).float().to(self.device) / 255.

    if not self.level_sampler:
        return obs
    return obs, seeds