def create_dataset(agent, sad, device)

in pyhanabi/tools/action_matrix.py


import time

import rela
import create  # pyhanabi helper module providing create_envs / create_threads


def create_dataset(agent, sad, device):
    # use the agent in "vdn" mode so that trajectories from the same game
    # are grouped together
    agent = agent.clone(device, {"vdn": True})
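
    # the runner batches act/compute_priority calls from all actor threads
    # through a single model copy (the 100 is presumably the max batchsize)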
    runner = rela.BatchRunner(
        agent, device, 100, ["act", "compute_priority"]
    )

    dataset_size = 1000
    replay_buffer = rela.RNNPrioritizedReplay(
        dataset_size,  # args.dataset_size,
        1,  # args.seed,
        0,  # args.priority_exponent, uniform sampling
        1,  # args.priority_weight,
        0,  # args.prefetch,
    )
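
    # with priority_exponent 0 every entry has equal weight, so sampling is
    # uniform and the buffer effectively stores a fixed dataset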

    num_thread = 100
    num_game_per_thread = 1
    max_len = 80
    actors = []
    for i in range(num_thread):
        actor = rela.R2D2Actor(
            runner,
            1,  # multi_step,
            num_game_per_thread,
            0.99,  # gamma,
            0.9,  # eta
            max_len,  # max_len,
            2,  # num_player
            replay_buffer,
        )
        actors.append(actor)

    eps = [0]  # a single exploration epsilon of 0, i.e. fully greedy actors
    num_game = num_thread * num_game_per_thread
    games = create.create_envs(num_game, 1, 2, 5, 0, eps, max_len, sad, False, False)
    context, threads = create.create_threads(num_thread, num_game_per_thread, actors, games)
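    # the context owns the C++ game threads; start()/pause() below drive
    # data collection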

    runner.start()
    context.start()
    while replay_buffer.size() < dataset_size:
        print("collecting data from replay buffer:", replay_buffer.size())
        time.sleep(0.2)

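    # enough trajectories collected: pause the context so the actors stop
    # feeding the buffer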
    context.pause()

    # remove the extra data the actors queued while we were stopping: each
    # sample/update_priority cycle lets the buffer process pending additions
    # and evict back down to capacity
    for _ in range(2):
        data, unif = replay_buffer.sample(10, "cpu")
        replay_buffer.update_priority(unif.detach().cpu())
        time.sleep(0.2)

    print("dataset size:", replay_buffer.size())
    print("done about to return")
    return replay_buffer, agent, context
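
A minimal usage sketch (a hypothetical caller, not part of this file), assuming
a trained agent loaded elsewhere; it mirrors the sample/update_priority pattern
used inside the function:

    replay_buffer, agent, context = create_dataset(agent, sad=True, device="cuda:0")

    # draw uniform batches from the now-frozen dataset for analysis
    for _ in range(10):
        batch, weight = replay_buffer.sample(32, "cpu")
        # ... inspect batch, e.g. to accumulate the action matrix ...
        replay_buffer.update_priority(weight.detach().cpu())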