in rlstructures/deprecated/batchers/episodebatchers.py [0:0]
def get(self):
    """Roll every environment episode to completion and return one batch.

    Resets the wrapped environments, then alternates agent calls and
    environment steps until no episode is still running.  For each step it
    records the observation variables, the agent's state and output, and the
    agent's next state (stored under a ``"_"``-prefixed key).  Because the
    environment shrinks its batch to the still-running episodes, every
    per-step entry is scattered back into a full-batch tensor cloned from the
    step-0 values (finished slots therefore keep step-0 padding — assumed
    intentional, masked out downstream via ``lengths``; TODO confirm).

    Returns:
        TemporalDictTensor: all observation and state variables stacked along
        a time dimension (dim=1), with ``lengths`` holding each episode's
        true number of completed steps.
    """
    with torch.no_grad():
        obs, is_running = self.env.reset(self.env_info)
        # observations[t]: variable name -> full-batch tensor at step t.
        observations = [{k: obs[k] for k in obs.keys()}]
        states = []
        agent_state = None
        agent_info = self.agent_info
        if agent_info is None:
            agent_info = DictTensor({})
        t = 0
        # Completed steps per episode; incremented only while it runs.
        length = torch.zeros(is_running.size()[0]).long()
        first_state = None  # full-batch state template captured at t == 0
        first_info = agent_info

        def merged_state(old_state, output, new_state):
            # One flat dict for a step: old agent state + agent output,
            # plus the next agent state under "_"-prefixed keys.
            s = {k: old_state[k] for k in old_state.keys()}
            s.update({k: output[k] for k in output.keys()})
            s.update({"_" + k: new_state[k] for k in new_state.keys()})
            return s

        while is_running.size()[0] > 0:
            old_agent_state, agent_output, new_agent_state = self.agent(
                agent_state, obs, agent_info
            )
            s = merged_state(old_agent_state, agent_output, new_agent_state)
            if len(states) == 0:
                # First step: every episode is running, keep tensors as-is.
                first_state = old_agent_state
                states.append(s)
            else:
                # Later steps only cover still-running episodes: scatter the
                # sub-batch into a full-batch clone of the step-0 template.
                ns = {k: states[0][k].clone() for k in states[0]}
                for k in states[0]:
                    ns[k][is_running] = s[k]
                states.append(ns)
            (l_o, l_is_running), (obs, is_running) = self.env.step(agent_output)
            # Record the post-step ("last") observation of episodes that just
            # advanced, padding the other slots with the step-0 values.
            for k in l_o.keys():
                observations[t]["_" + k] = observations[0][k].clone()
                observations[t]["_" + k][l_is_running] = l_o[k]
            length[l_is_running] += 1
            t += 1
            if is_running.size()[0] > 0:
                # Next-step observation, again padded to full batch size.
                observations.append({})
                for k in obs.keys():
                    observations[t][k] = observations[0][k].clone()
                    observations[t][k][is_running] = obs[k]
                # Rebuild the sub-batch agent state/info for the next step by
                # scattering into full-batch templates, then re-selecting the
                # still-running rows.
                ag = {k: first_state[k].clone() for k in first_state.keys()}
                for k in ag:
                    ag[k][l_is_running] = new_agent_state[k]
                agent_state = DictTensor({k: ag[k][is_running] for k in ag})
                ai = {k: first_info[k].clone() for k in first_info.keys()}
                agent_info = DictTensor({k: ai[k][is_running] for k in ai})

        def stack_time(frames):
            # Per-step (B, ...) tensors -> one (B, T, ...) tensor per key.
            return {
                k: torch.cat([f[k].unsqueeze(1) for f in frames], dim=1)
                for k in frames[0]
            }

        return TemporalDictTensor(
            {**stack_time(observations), **stack_time(states)}, lengths=length
        )