def get()

in rlstructures/deprecated/batchers/episodebatchers.py [0:0]
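
Runs one complete episode in every environment instance: the environments are reset, the agent is stepped until all instances have terminated, and the per-step observations, agent states and agent outputs (with post-step values stored under `_`-prefixed keys) are stacked into a single padded TemporalDictTensor whose `lengths` field holds the true length of each episode.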


    def get(self):
        # Assumed module-level imports: torch and DictTensor / TemporalDictTensor
        # from rlstructures.
        with torch.no_grad():
            # Reset all environment instances; `is_running` holds the indices of
            # the instances that are still producing an episode.
            obs, is_running = self.env.reset(self.env_info)
            n_elems = obs.n_elems()
            # Trajectories are accumulated step by step: one dict of full-batch
            # tensors per time step, for observations and for agent states/outputs.
            observations = [{k: obs[k] for k in obs.keys()}]
            states = []
            agent_state = None
            agent_info = self.agent_info
            if agent_info is None:
                agent_info = DictTensor({})
            t = 0
            # Number of steps actually executed by each environment instance.
            length = torch.zeros(is_running.size()[0]).long()
            first_state = None
            first_info = agent_info
            while is_running.size()[0] > 0:
                # The agent maps (state, observation, info) of the still-running
                # instances to its outputs (including the action) and its next state.
                old_agent_state, agent_output, new_agent_state = self.agent(
                    agent_state, obs, agent_info
                )

                # Agent-side record of the step: state before acting, agent outputs,
                # and the updated state stored under "_"-prefixed keys.
                s = {k: old_agent_state[k] for k in old_agent_state.keys()}
                s = {**s, **{k: agent_output[k] for k in agent_output.keys()}}
                s = {
                    **s,
                    **{"_" + k: new_agent_state[k] for k in new_agent_state.keys()},
                }
                if len(states) == 0:
                    # First step: every instance is running, so the tensors already
                    # have full batch size and can be stored as they are.
                    first_state = old_agent_state
                    states.append(s)
                else:
                    # Later steps: clone the full-batch tensors of the first step as
                    # padding, then overwrite only the rows that are still running.
                    ns = {k: states[0][k].clone() for k in states[0]}
                    for k in states[0]:
                        ns[k][is_running] = s[k]
                    states.append(ns)
                # Step the environments: (l_o, l_is_running) are the resulting
                # observation and the indices of the instances that were just
                # stepped, while (obs, is_running) describe the instances still
                # running after this step.
                (l_o, l_is_running), (obs, is_running) = self.env.step(agent_output)

                # Store the post-step observation under "_"-prefixed keys, padding
                # with a clone of the initial observation and writing only the rows
                # that actually took a step.
                for k in l_o.keys():
                    observations[t]["_" + k] = observations[0][k].clone()
                for k in l_o.keys():
                    observations[t]["_" + k][l_is_running] = l_o[k]
                length[l_is_running] += 1
                t += 1
                if is_running.size()[0] > 0:
                    # Prepare the observation of the next time step for the
                    # instances that are still running.
                    observations.append({})
                    for k in obs.keys():
                        observations[t][k] = observations[0][k].clone()
                    for k in obs.keys():
                        observations[t][k][is_running] = obs[k]

                    # Scatter the new agent state back into a full-batch tensor,
                    # then keep only the rows of the still-running instances.
                    ag = {k: first_state[k].clone() for k in first_state.keys()}
                    for k in ag:
                        ag[k][l_is_running] = new_agent_state[k]
                    agent_state = DictTensor({k: ag[k][is_running] for k in ag})

                    # agent_info is constant over the episode; re-slice it to the
                    # still-running instances.
                    ai = {k: first_info[k].clone() for k in first_info.keys()}
                    agent_info = DictTensor({k: ai[k][is_running] for k in ai})

            # Stack the per-step dicts along a new time dimension (dim=1) so that
            # every key becomes a tensor of shape (n_envs, T_max, ...).
            f_observations = {}
            for k in observations[0]:
                _all = [o[k].unsqueeze(1) for o in observations]
                f_observations[k] = torch.cat(_all, dim=1)
            f_states = {}
            for k in states[0]:
                _all = [o[k].unsqueeze(1) for o in states]
                f_states[k] = torch.cat(_all, dim=1)
            # `lengths` records the true (unpadded) length of each episode.
            return TemporalDictTensor({**f_observations, **f_states}, lengths=length)
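
The returned TemporalDictTensor packs every episode into a padded (n_envs, T_max, ...) layout, with `lengths` recording how many steps of each row are valid. Below is a minimal sketch of how such a padded batch is typically masked, written in plain PyTorch rather than the rlstructures API; the `reward` key, the shapes and the episode lengths are illustrative assumptions.

    import torch

    n_envs, t_max = 4, 6
    lengths = torch.tensor([6, 3, 5, 2])   # hypothetical episode lengths
    reward = torch.randn(n_envs, t_max)    # stand-in for a per-step "_reward" key

    # Boolean mask that is True on the valid (non-padded) time steps of each row.
    mask = torch.arange(t_max).unsqueeze(0) < lengths.unsqueeze(1)

    # Example: undiscounted return of each episode, ignoring the padded steps.
    returns = (reward * mask).sum(dim=1)
    print(returns)  # one value per environment instance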