in interaction_exploration/tools/plot_results.py [0:0]
def parse_episodes(model, episodes_dir, K):
    """Load saved test-scene episodes and normalize their per-step rewards.

    Every file matching ``f'{episodes_dir}/*.pth'`` is loaded with
    ``torch.load``. Only episodes whose ``'scene_id'`` is in the module-level
    ``test_scenes`` are kept. Each returned episode dict is given a new
    ``'rewards'`` list: the ``'reward'`` of every step in
    ``episode['stats']['step_info']`` divided by the scene's oracle
    interaction count ``max_interactons[scene_id]``.

    Args:
        model: Unused here; kept so existing call sites keep working.
        episodes_dir: Directory holding the saved ``.pth`` episode files
            (interpolated into a glob pattern).
        K: Maximum number of episodes kept per scene. It is applied as a
            slice ``[:K]``, so ``None`` keeps every episode.

    Returns:
        A flat list of episode dicts, grouped by scene in order of first
        appearance, with at most ``K`` episodes per scene. An empty list is
        returned when no test episodes are found.
    """
    # glob order is arbitrary and filesystem dependent. Sort so that "the
    # first K episodes per scene" is reproducible across runs and machines.
    paths = sorted(glob.glob(f'{episodes_dir}/*.pth'))

    # Group test-scene episodes by scene while loading. Non-test episodes are
    # discarded immediately instead of all being held in memory at once.
    # NOTE: torch.load unpickles the file; only use on trusted result files.
    episodes_by_scene = collections.defaultdict(list)
    for path in paths:
        episode = torch.load(path)
        scene_id = episode['scene_id']
        if scene_id not in test_scenes:  # keep only test episodes
            continue
        episodes_by_scene[scene_id].append(episode)

    episodes = []
    for scene_id, scene_episodes in episodes_by_scene.items():
        selected = scene_episodes[:K]
        # Normalize coverage by the scene's oracle interaction count.
        N = max_interactons[scene_id]
        for episode in selected:
            episode['rewards'] = [
                step['reward'] / N for step in episode['stats']['step_info']
            ]
        episodes.extend(selected)
    return episodes