def d4rl_episode_buffer()

in salina_examples/offline_rl/d4rl.py


import torch

from salina import Workspace


def d4rl_episode_buffer(d4rl_env):
    # Load the dataset associated with a d4rl environment into a Workspace
    # as full episodes (of varying lengths).
    # `_fixed_sequence_dataset` is a helper defined elsewhere in this file.
    print("[d4rl_episode_buffer] Reading dataset")
    sequence_dataset = _fixed_sequence_dataset(d4rl_env)

    episodes = []
    for s in sequence_dataset:
        # Convert each numpy array to a tensor with a batch dimension of size 1
        episode = {k: torch.tensor(v).unsqueeze(1) for k, v in s.items()}
        nepisode = {}
        for k, v in episode.items():
            # Strip the trailing "s" from d4rl keys (e.g. "observations" -> "observation")
            # and prefix every variable name with "env/"
            if k.endswith("s"):
                nepisode["env/" + k[:-1]] = v
            else:
                nepisode["env/" + k] = v
        if "env/timeout" in nepisode:
            nepisode["env/done"] = (
                nepisode["env/terminal"] + nepisode["env/timeout"]
            ).bool()
        else:
            nepisode["env/done"] = (nepisode["env/terminal"]).bool()

        if "env/observation" in nepisode:
            nepisode["env/env_obs"] = nepisode.pop("env/observation")
        nepisode["env/done"][-1] = True
        nepisode["env/initial_state"] = nepisode["env/done"].clone()
        nepisode["env/initial_state"].fill_(False)
        nepisode["env/initial_state"][0] = True
        nepisode["action"] = nepisode.pop("env/action")
        nepisode["env/timestep"] = torch.arange(
            nepisode["env/done"].size()[0]
        ).unsqueeze(1)
        nepisode["env/reward"][1:] = nepisode["env/reward"][:-1].clone()
        nepisode["env/reward"][0] = 0.0
        nepisode["env/cumulated_reward"] = torch.zeros(
            nepisode["env/done"].size()[0], 1
        )
        cr = 0.0
        for t in range(nepisode["env/done"].size()[0]):
            cr += nepisode["env/reward"][t].item()
            nepisode["env/cumulated_reward"][t] = cr
        episodes.append(nepisode)

    max_length = max([e["env/reward"].size()[0] for e in episodes])
    print("\t max episode length = ", max_length)
    print("\t n episodes = ", len(episodes))

    # Zero-pad every episode to max_length so they can be concatenated along
    # the batch dimension; empty trajectories are skipped
    n_skip = 0
    for e in episodes:
        length = e["env/reward"].size()[0]
        if length == 0:
            n_skip += 1
            continue
        for k, v in e.items():
            ts = v.size()[0]
            if ts < max_length:
                v.resize_(max_length, *(v.size()[1:]))
                v[ts:] = 0
    print("\tSkipped", n_skip, "trajectories of size 0")
    # Concatenate all padded episodes along the batch dimension into a single Workspace
    workspace = Workspace()
    for k in episodes[0]:
        vals = [e[k] for e in episodes]
        workspace.set_full(k, torch.cat(vals, dim=1))

    return workspace
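
A minimal usage sketch (not part of the original file), assuming gym and d4rl are
installed; the environment id "halfcheetah-medium-v2" is only an example, and
get_full is assumed to be the read counterpart of Workspace.set_full:

import gym
import d4rl  # noqa: F401  (importing d4rl registers the offline environments in gym)

env = gym.make("halfcheetah-medium-v2")
workspace = d4rl_episode_buffer(env)

# Every variable is stored as a (max_length x n_episodes x ...) tensor
rewards = workspace.get_full("env/cumulated_reward")
dones = workspace.get_full("env/done")
print(rewards.size(), dones.size())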