def parse_episodes()

in interaction_exploration/tools/plot_results.py [0:0]


def parse_episodes(model, episodes_dir, K):
    """Load saved test episodes and attach normalized per-step rewards.

    Every ``*.pth`` file directly inside ``episodes_dir`` is read with
    ``torch.load``. Episodes whose ``'scene_id'`` is not in the module-level
    ``test_scenes`` are dropped. Each remaining episode gets a new
    ``'rewards'`` entry: the ``'reward'`` of every item in
    ``episode['stats']['step_info']`` divided by
    ``max_interactons[scene_id]`` (the oracle interaction count for that
    scene).

    Args:
        model: Unused by this function; kept so the signature stays
            compatible with existing callers.
        episodes_dir: Directory containing the ``*.pth`` episode files.
        K: Maximum number of episodes to keep per scene (applied as a
            ``[:K]`` slice).

    Returns:
        A flat list of episode dicts, grouped by scene in order of first
        appearance, with at most ``K`` episodes per scene. If no test-scene
        episode is found, an empty dict ``{}`` is returned (the historical
        return value for the empty case, preserved as-is).
    """
    # glob yields paths in arbitrary, filesystem-dependent order. Sort them so
    # that the per-scene `[:K]` selection below is reproducible across runs
    # and machines.
    episode_paths = sorted(glob.glob(f'{episodes_dir}/*.pth'))

    # normalize coverage by oracle interactions and organize by scene
    episodes_by_scene = collections.defaultdict(list)
    for path in episode_paths:
        # NOTE: torch.load unpickles the file; only point this at episode
        # directories produced by trusted runs.
        episode = torch.load(path)
        scene_id = episode['scene_id']
        if scene_id not in test_scenes:  # keep only test episodes
            continue

        N = max_interactons[scene_id]
        episode['rewards'] = [step['reward']/N for step in episode['stats']['step_info']]
        episodes_by_scene[scene_id].append(episode)

    if not episodes_by_scene:
        return {}

    selected = []
    for scene_episodes in episodes_by_scene.values():
        selected.extend(scene_episodes[:K])

    return selected