def create_babyai_dataset()

in data/envs/babyai/create_babyai_dataset.py [0:0]


def create_babyai_dataset(task_name: str, max_num_episodes: int) -> None:
    """Generate BabyAI episodes for one task, save them, and push them to the hub.

    Episodes are produced by repeatedly calling ``generate_episode`` on the
    environment registered for ``task_name`` until ``max_num_episodes`` have
    been collected. The collected episodes are stored in a ``Dataset``, saved
    to disk under the path ``task_name``, split into train/test (2% test), and
    pushed to the ``jat-project/jat-dataset`` hub repository.

    Args:
        task_name: Key into ``TASK_NAME_TO_ENV_ID``; also used as the on-disk
            save path and passed to ``push_to_hub`` as its second argument.
        max_num_episodes: Number of successfully generated episodes to collect.

    Raises:
        KeyError: If ``task_name`` is not in ``TASK_NAME_TO_ENV_ID``, or if an
            episode contains a key other than the four collected columns.
    """
    env_id = TASK_NAME_TO_ENV_ID[task_name]
    env = gym.make(env_id)
    # One list entry per episode in every column.
    data = {"text_observations": [], "discrete_observations": [], "discrete_actions": [], "rewards": []}

    print("Starting trajectories generation")
    try:
        while len(data["rewards"]) < max_num_episodes:
            print(f"Episode {len(data['rewards']) + 1}/{max_num_episodes}")

            try:
                episode = generate_episode(env)
            except Exception as e:
                # Best effort: a failed episode is logged and retried.
                # NOTE(review): there is no retry cap, so a generator that
                # always fails makes this loop run forever.
                print(e)
                continue

            for key, values in episode.items():
                data[key].append(values)
    finally:
        # Release the environment's resources even if generation is interrupted.
        env.close()

    # Each entry of data["rewards"] is a whole episode, so this counts episodes.
    print(f"Finished generation. Generated {len(data['rewards'])} episodes.")

    features = Features(
        {
            "text_observations": Sequence(Value("string")),
            "discrete_observations": Sequence(Sequence(Value("int64"))),
            "discrete_actions": Sequence(Value("int64")),
            "rewards": Sequence(Value("float32")),
        }
    )
    dataset = Dataset.from_dict(data, features=features)
    print("Saving dataset...")
    # The unsplit dataset is what gets written to disk.
    dataset.save_to_disk(task_name)
    print("Saved dataset!")

    print("Pushing dataset to hub...")
    # Only the hub copy is split into train/test.
    dataset_splits = dataset.train_test_split(test_size=0.02)
    dataset_splits.push_to_hub("jat-project/jat-dataset", task_name, branch="additional_babyai_tasks")
    print("Pushed dataset to hub!")