in rlstructures/env_wrappers/gymenv.py [0:0]
def reset(self, env_info=None):
    """Reset all wrapped gym environments and build the initial batch.

    Args:
        env_info: Optional ``DictTensor`` of per-environment reset
            information. Must be empty or contain exactly one element per
            environment (``n_elems() == self.n_envs()``). Defaults to an
            empty ``DictTensor``.

    Returns:
        A ``(frames, envs_running)`` pair where ``frames`` is a
        ``DictTensor`` batching the initial observations plus the
        bookkeeping fields ``reward``, ``done``, ``initial_state`` and
        ``last_action``, and ``envs_running`` is a ``torch.arange(N)``
        tensor of the indices of the (all) running environments.
    """
    # Sentinel instead of a mutable default: `env_info=DictTensor({})` in
    # the signature would be built once at def-time and shared across calls.
    if env_info is None:
        env_info = DictTensor({})
    assert env_info.empty() or env_info.n_elems() == self.n_envs()
    N = self.n_envs()
    self.envs_running = torch.arange(N)
    reward = torch.zeros(N)
    if isinstance(self.gym_envs[0].action_space, gym.spaces.Discrete):
        # Discrete action space: a zero per env as the placeholder action.
        last_action = torch.zeros(N, dtype=torch.int64)
    else:
        # Non-discrete space: sample once to discover the action shape,
        # then tile one placeholder action across the N environments.
        # NOTE(review): assumes .sample() returns an array-like with
        # `.shape` (e.g. a Box sample) — confirm for other space types.
        sample = self.gym_envs[0].action_space.sample()
        last_action = (
            torch.tensor(sample)
            .unsqueeze(0)
            .repeat(*(N,) + tuple(1 for _ in sample.shape))
        )
    done = torch.zeros(N).bool()
    initial_state = torch.ones(N).bool()
    if env_info.empty():
        frames = [format_frame(e.reset()) for e in self.gym_envs]
    else:
        # Pass each env its own slice of env_info as plain python values
        # (`.tolist()` converts the per-env tensor entries).
        frames = []
        for n in range(len(self.gym_envs)):
            v = {k: env_info[k][n].tolist() for k in env_info.keys()}
            frames.append(format_frame(self.gym_envs[n].reset(env_info=v)))
    # Normalize: a raw tensor observation becomes {"frame": tensor} so every
    # entry can be wrapped in a DictTensor and batched with cat().
    _frames = []
    for f in frames:
        if isinstance(f, torch.Tensor):
            _frames.append({"frame": f})
        else:
            _frames.append(f)
    frames = DictTensor.cat([DictTensor(_f) for _f in _frames])
    frames.set("reward", reward)
    frames.set("done", done)
    frames.set("initial_state", initial_state)
    frames.set("last_action", last_action)
    return frames, self.envs_running