in level_replay/envs.py [0:0]
def step_wait(self):
    """Finish a vectorised step and convert its results to torch tensors.

    When ``self.level_sampler`` is set, every sub-environment whose ``done``
    flag is set is re-seeded with a level drawn from the sampler. The batch
    of observations is then re-read from ``self.raw_venv.observe()['rgb']``.

    Returns:
        A 4-tuple ``(obs, reward, done, info)``:

        * ``obs`` -- float32 tensor on ``self.device`` with values divided by
          255. If dim 1 of the incoming array is not 3, the array is treated
          as channels-last and permuted with axes ``(0, 3, 1, 2)`` first.
        * ``reward`` -- float32 tensor with a new axis inserted at dim 1
          (e.g. ``(N,)`` -> ``(N, 1)``); it is not moved to ``self.device``.
        * ``done`` and ``info`` -- passed through unchanged from
          ``self.venv.step_wait()``.
    """
    obs, reward, done, info = self.venv.step_wait()

    if self.level_sampler:
        # Indices of the sub-environments whose episode just ended.
        finished_envs = done.nonzero()[0]
        for env_idx in finished_envs:
            # Per the original author's note, seeding an env resets its level.
            self.venv.seed(self.level_sampler.sample(), env_idx)

        # Presumably re-read so re-seeded envs report their new level.
        # Per the original author's note, observe() does not reset game
        # instances; it only returns the latest observations.
        # NOTE(review): reading from raw_venv likely bypasses any observation
        # processing done by wrappers between self and raw_venv -- confirm.
        obs = self.raw_venv.observe()['rgb']

    channels_first = obs.shape[1] == 3
    if not channels_first:
        obs = obs.transpose(0, 3, 1, 2)

    obs_tensor = torch.from_numpy(obs).float().to(self.device).div(255.)
    reward_tensor = torch.from_numpy(reward).unsqueeze(1).float()

    return obs_tensor, reward_tensor, done, info