def step()

in rlstructures/env_wrappers/gymenv.py

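Steps each wrapped gym environment with the action in `policy_output["action"]`, automatically resetting any environment whose episode terminates. It returns two `(DictTensor, indexes)` pairs: the first holds the observation, reward, and done flag produced by the step itself; the second holds the observation the policy will consume at the next step, which is the post-reset observation whenever the episode has just ended.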

    def step(self, policy_output):
        assert policy_output.n_elems() == self.n_envs()
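        # unfold() splits the batched policy output into one DictTensor per environment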
        outputs = policy_output.unfold()
        alls = []  # observations produced by this step (one DictTensor per env)
        alls_after = []  # observations for the next step (reset obs if an episode ended)
        for b in range(len(outputs)):
            action = policy_output["action"][b]
            last_action = action.unsqueeze(0)
            if isinstance(self.gym_envs[0].action_space, gym.spaces.Discrete):
                action = action.item()  # discrete actions are passed as Python ints
            else:
                action = action.tolist()  # continuous actions are passed as lists

            initial_state = torch.tensor([False])

            frame, reward, done, unused_info = self.gym_envs[b].step(action)
            reward = torch.tensor([reward]).float()
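            # format_frame() normalizes the raw gym observation into torch tensor(s)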
            frame = format_frame(frame)
            if isinstance(frame, torch.Tensor):
                frame = {"frame": frame}

            done = torch.tensor([done])
            r = DictTensor(
                {
                    "reward": reward,
                    "done": done,
                    "initial_state": initial_state,
                    "last_action": last_action,
                    **frame,
                }
            )
            alls.append(r)

            if done:
                # The episode is over: push per-env info into the wrapped env if it
                # supports it, then reset to obtain the first observation of the
                # next episode.
                if hasattr(self.gym_envs[b], "set"):
                    self.gym_envs[b].set(self.env_info[b])

                if self.env_info.empty():
                    frame = self.gym_envs[b].reset()
                else:
                    v = {k: self.env_info[k][b].tolist() for k in self.env_info.keys()}
                    frame = self.gym_envs[b].reset(env_info=v)

                frame = format_frame(frame)
                if isinstance(frame, torch.Tensor):
                    frame = {"frame": frame}

                if isinstance(self.gym_envs[0].action_space, gym.spaces.Discrete):
                    last_action = torch.zeros(1, dtype=torch.int64)
                else:
                    # sample a dummy action to obtain a correctly-shaped tensor
                    a = self.gym_envs[0].action_space.sample()
                    last_action = torch.tensor([a])

                initial_state = torch.tensor([True])
                reward = torch.tensor([0.0])
                done = torch.tensor([False])  # the reset observation starts a new episode
                r = DictTensor(
                    {
                        "reward": reward,
                        "done": done,
                        "initial_state": initial_state,
                        "last_action": last_action,
                        **frame,
                    }
                )
                alls_after.append(r)
            else:
                # episode continues: the next step starts from the same observation
                alls_after.append(r)

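        # Re-batch the per-environment observations; every env advanced, so the
        # index tensor is simply arange(n_envs).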
        next_observation = DictTensor.cat(alls)
        next_observation_next_slot = DictTensor.cat(alls_after)
        return (
            (next_observation, torch.arange(self.n_envs())),
            (next_observation_next_slot, torch.arange(self.n_envs())),
        )
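For context, here is a minimal usage sketch. It assumes the wrapper class in this file is `GymEnv`, that its constructor takes a list of gym environments plus a seed (as in the rlstructures tutorials), and that `reset()` returns a `(DictTensor, indexes)` pair like `step()` does; the environment name and action values are illustrative only.

    import gym
    import torch
    from rlstructures import DictTensor
    from rlstructures.env_wrappers.gymenv import GymEnv

    # Assumed constructor signature: a list of gym envs and a seed.
    envs = [gym.make("CartPole-v0") for _ in range(4)]
    env = GymEnv(envs, seed=42)

    obs, idx = env.reset()  # assumed to return (DictTensor, indexes) like step()
    actions = DictTensor({"action": torch.tensor([0, 1, 0, 1])})  # one action per env
    (obs, idx), (obs_next, idx_next) = env.step(actions)
    # obs holds reward/done for the transition just taken; obs_next is what the
    # policy should see at the next step (the reset observation if done was True).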