in affordance_seg/collect_dset.py [0:0]
def save_episode(self, observations):
    """Filter an episode down to its annotated frames and write it to disk.

    Args:
        observations: tuple ``(frames, masks, poses, info)``. Example shapes
            seen for one episode: frames ``(17, 3, 80, 80)``, masks
            ``(17, 7, 2, 80, 80)``, poses ``(17, 5)``. ``info`` must be
            indexable by ``'scene'`` and ``'episode'``. Arrays may be numpy
            arrays or (CPU) torch tensors -- NOTE(review): GPU tensors would
            need ``.cpu()`` before ``np.savez_compressed``; confirm callers.

    A frame is kept only if, for at least one channel along axis 1, both
    mask planes ``masks[:, :, 0]`` and ``masks[:, :, 1]`` have a positive
    pixel sum (assuming non-negative masks).

    Side effects:
        Writes ``<scene>_<episode>_data.npz`` (frames, masks, poses) and
        ``<scene>_<episode>_info.pth`` (``{'info': info}``) into
        ``<self.config.OUT_DIR>/episodes/``, creating that directory if
        needed, and records ``{'N_frames': kept_frames}`` under
        ``self.episodes[(scene, episode)]``.

    Returns:
        True if the episode was saved; False if no frame qualified, in which
        case nothing is written.
    """
    frames, masks, poses, info = observations

    # Per-frame, per-channel pixel totals of the two mask planes (axis 2);
    # plane 0 is treated as "negative", plane 1 as "positive".
    neg_score = masks[:, :, 0].sum(2).sum(2)  # (N, C)
    pos_score = masks[:, :, 1].sum(2).sum(2)  # (N, C)
    # A frame qualifies when some channel has both planes annotated.
    keep_idx = (pos_score * neg_score).sum(1) > 0  # (N,)
    if not keep_idx.any():
        return False
    frames = frames[keep_idx]
    masks = masks[keep_idx]
    poses = poses[keep_idx]

    scene, episode_id = info['scene'], info['episode']
    out_dir = f'{self.config.OUT_DIR}/episodes/'
    # Create the output directory on demand so saving does not rely on the
    # caller having created it beforehand.
    os.makedirs(out_dir, exist_ok=True)

    np.savez_compressed(
        os.path.join(out_dir, f'{scene}_{episode_id}_data.npz'),
        frames=frames,
        masks=masks,
        poses=poses,
    )
    torch.save({'info': info}, os.path.join(out_dir, f'{scene}_{episode_id}_info.pth'))

    # Log information for dataset stats.
    self.episodes[(scene, episode_id)] = {'N_frames': frames.shape[0]}
    print(f'Saved episode {scene} {episode_id} to {out_dir}')
    return True