in rlstructures/env_wrappers/gymenv.py [0:0]
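# Note: this excerpt assumes the module-level imports of the surrounding file
# (gym, torch) and rlstructures' DictTensor / format_frame helpers.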
def step(self, policy_output):
    # Expect exactly one action per environment in the batch.
    assert policy_output.n_elems() == self.n_envs()
    outputs = policy_output.unfold()
    alls = []  # per-env observations produced by this step
    alls_after = []  # per-env observations that will start the next step
    for b in range(len(outputs)):
        action = policy_output["action"][b]
        # Keep a batched (size-1) copy of the action for the 'last_action' field.
        last_action = action.unsqueeze(0)
        # Convert the action to the format the underlying gym env expects.
        if isinstance(self.gym_envs[0].action_space, gym.spaces.Discrete):
            action = action.item()
        else:
            action = action.tolist()
        initial_state = torch.tensor([False])
        frame, reward, done, unused_info = self.gym_envs[b].step(action)
        # Wrap scalar results as size-1 tensors and normalize the frame to a dict.
        reward = torch.tensor([reward]).float()
        frame = format_frame(frame)
        if isinstance(frame, torch.Tensor):
            frame = {"frame": frame}
        done = torch.tensor([done])
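        # Record the transition as observed after acting: reward, termination flag,
        # initial_state=False, the batched action, and the formatted frame fields.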
r = DictTensor(
{
"reward": reward,
"done": done,
"initial_state": initial_state,
"last_action": last_action,
**frame,
}
)
alls.append(r)
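        # When the episode ends, reset the environment in place so that the
        # "next slot" record built below starts a fresh episode.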
        if done.item():
            if hasattr(self.gym_envs[b], "set"):
                self.gym_envs[b].set(self.env_info[b])
            if self.env_info.empty():
                frame = self.gym_envs[b].reset()
            else:
                v = {k: self.env_info[k][b].tolist() for k in self.env_info.keys()}
                frame = self.gym_envs[b].reset(env_info=v)
            frame = format_frame(frame)
            if isinstance(frame, torch.Tensor):
                frame = {"frame": frame}
            # No previous action exists at the start of an episode: use zeros for
            # discrete spaces and a sampled placeholder action otherwise.
            if isinstance(self.gym_envs[0].action_space, gym.spaces.Discrete):
                last_action = torch.zeros(1, dtype=torch.int64)
            else:
                a = self.gym_envs[0].action_space.sample()
                last_action = torch.tensor([a])
            initial_state = torch.tensor([True])
            reward = torch.tensor([0.0])
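            # Record that will start the next slot: reward zeroed, initial_state
            # set to True, and `done` carried over from the terminal transition.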
r = DictTensor(
{
"reward": reward,
"done": done,
"initial_state": initial_state,
"last_action": last_action,
**frame,
}
)
alls_after.append(r)
else:
alls_after.append(r)
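    # Re-batch the per-environment records into single DictTensors, one entry
    # per environment index.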
next_observation = DictTensor.cat(alls)
next_observation_next_slot = DictTensor.cat(alls_after)
return (
(next_observation, torch.arange(self.n_envs())),
(next_observation_next_slot, torch.arange(self.n_envs())),
)
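
# Hypothetical usage sketch (not part of the file; `env` and `policy_output`
# are assumed caller-side names):
#
#   (obs, idx), (obs_next, idx_next) = env.step(policy_output)
#   # `obs` describes the transitions that just occurred (done may be True);
#   # `obs_next` is what the policy sees at the next step, with finished
#   # environments already reset (initial_state=True, reward zeroed).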