in salina_examples/offline_rl/d4rl.py [0:0]
import torch

from salina import Workspace


def d4rl_episode_buffer(d4rl_env):
    # Load the dataset associated with a d4rl environment into a Workspace
    # as full episodes (of varying lengths).
    print("[d4rl_episode_buffer] Reading dataset")
    sequence_dataset = _fixed_sequence_dataset(d4rl_env)
    episodes = []
    for s in sequence_dataset:
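        # Convert each d4rl sequence to tensors and add a batch dimension of
        # size 1, giving the (time, batch, ...) layout used by the Workspace.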
        episode = {k: torch.tensor(v).unsqueeze(1) for k, v in s.items()}
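        # d4rl stores keys in plural form ("observations", "rewards", ...);
        # rename them to the singular "env/<key>" names used below.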
        nepisode = {}
        for k, v in episode.items():
            if k.endswith("s"):
                nepisode["env/" + k[:-1]] = v
            else:
                nepisode["env/" + k] = v
if "env/timeout" in nepisode:
nepisode["env/done"] = (
nepisode["env/terminal"] + nepisode["env/timeout"]
).bool()
else:
nepisode["env/done"] = (nepisode["env/terminal"]).bool()
if "env/observation" in nepisode:
nepisode["env/env_obs"] = nepisode.pop("env/observation")
nepisode["env/done"][-1] = True
nepisode["env/initial_state"] = nepisode["env/done"].clone()
nepisode["env/initial_state"].fill_(False)
nepisode["env/initial_state"][0] = True
nepisode["action"] = nepisode.pop("env/action")
nepisode["env/timestep"] = torch.arange(
nepisode["env/done"].size()[0]
).unsqueeze(1)
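        # Shift rewards by one step so that reward[t] is the reward received
        # when reaching step t (the first step therefore gets a reward of 0).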
nepisode["env/reward"][1:] = nepisode["env/reward"][:-1].clone()
nepisode["env/reward"][0] = 0.0
nepisode["env/cumulated_reward"] = torch.zeros(
nepisode["env/done"].size()[0], 1
)
cr = 0.0
for t in range(nepisode["env/done"].size()[0]):
cr += nepisode["env/reward"][t].item()
nepisode["env/cumulated_reward"][t] = cr
        episodes.append(nepisode)
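    # Zero-pad all episodes to the maximum episode length so that they can be
    # concatenated along the batch dimension.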
    max_length = max([e["env/reward"].size()[0] for e in episodes])
    print("\t max episode length = ", max_length)
    print("\t n episodes = ", len(episodes))
    n_skip = 0
    kept_episodes = []
    for e in episodes:
        length = e["env/reward"].size()[0]
        if length == 0:
            # Empty trajectories cannot be padded or concatenated; drop them.
            n_skip += 1
            continue
        for k, v in e.items():
            ts = v.size()[0]
            if ts < max_length:
                # resize_ leaves the new memory uninitialized, so zero it.
                v.resize_(max_length, *(v.size()[1:]))
                v[ts:] = 0
        kept_episodes.append(e)
    episodes = kept_episodes
    print("\tSkip ", n_skip, " trajectories of size = 0")
    workspace = Workspace()
    for k in episodes[0]:
        vals = [e[k] for e in episodes]
        workspace.set_full(k, torch.cat(vals, dim=1))
    return workspace
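

# A minimal usage sketch (assuming gym and d4rl are installed; the environment
# name below is only an example, any d4rl task id works the same way):
#
#     import gym
#     import d4rl  # noqa: F401  (importing d4rl registers its environments)
#
#     env = gym.make("halfcheetah-medium-v2")
#     workspace = d4rl_episode_buffer(env)
#     # Each variable in the workspace has shape (max_length, n_episodes, ...).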