in envs/thor_beacons.py [0:0]
def compute_masks(self, history=None, out_sz=None):
    """Build per-timestep interaction-beacon masks from an episode history.

    Timesteps that revisit an already-kept pose are dropped unless a successful
    beacon was recorded since that pose was last kept, with a target listed in
    the timestep's 'visible_objs'. For every kept timestep and every action in
    ``self.interactions``, slot 0 of the mask marks pixels of objects targeted by
    failed beacons and slot 1 marks pixels of objects targeted by successful
    beacons. The object in ``hist['inv_obj']`` is never marked.

    Args:
        history: sequence of per-timestep dicts; ``self.history`` is used when
            None. Each entry is read for the keys 'pose', 'visible_objs',
            'inv_obj', 'instance_masks' and 'frame'.
        out_sz: spatial output size; ``self.mask_sz`` is used when None.

    Returns:
        A tuple ``(frames, beacon_mask, poses, info)``. ``info`` is
        ``{'scene': self.scene, 'episode': self.episode_id}``.

        For an empty history, the first three values are zero placeholders of
        shape (1, 3, out_sz, out_sz), (1, len(self.interactions), 2, out_sz,
        out_sz) (uint8) and (1, 5).

        Otherwise, the stacked frames are permuted to channel-first order and
        divided by 255. This assumes each 'frame' is an HxWx3 array in the
        0-255 range (TODO confirm). Frames and masks are then passed through
        ``resize_results`` together with ``out_sz``.
    """
    info = {'scene': self.scene, 'episode': self.episode_id}
    # Compare with None explicitly: an explicitly passed empty history must
    # produce the empty result, not silently fall back to self.history.
    if out_sz is None:
        out_sz = self.mask_sz
    if history is None:
        history = self.history
    if len(history) == 0:
        return (torch.zeros(1, 3, out_sz, out_sz),
                torch.zeros(1, len(self.interactions), 2, out_sz, out_sz).byte(),
                torch.zeros(1, 5),
                info)

    # Discard timesteps with redundant positions (same visual content).
    # A revisited pose is kept again only if the visible state may have
    # changed, i.e. a successful beacon with a visible target happened after
    # the pose was last kept (and no later than the current timestep).
    positive_beacons = [beacon for beacon in self.beacons if beacon['success']]
    pose_last_kept = {}
    beacon_history = []
    for t, hist in enumerate(history):
        last_kept = pose_last_kept.get(hist['pose'], -1)
        if last_kept != -1:
            state_changed = any(
                last_kept < beacon['t'] <= t and beacon['target'] in hist['visible_objs']
                for beacon in positive_beacons
            )
            if not state_changed:
                continue
        pose_last_kept[hist['pose']] = t
        beacon_history.append(hist)

    # Group beacons by (action, success) so each lookup below is a dict hit.
    beacons_by_key = collections.defaultdict(list)
    for beacon in self.beacons:
        beacons_by_key[(beacon['action'], beacon['success'])].append(beacon)

    beacon_mask = torch.zeros(len(beacon_history), len(self.interactions), 2,
                              self.frame_sz, self.frame_sz)
    for t, hist in enumerate(beacon_history):
        instance_masks = hist['instance_masks']
        inv_obj = hist['inv_obj']
        for ch, action in enumerate(self.interactions):
            # Slot 0 <- failed beacons, slot 1 <- successful beacons.
            for slot, success in ((0, False), (1, True)):
                masks = [instance_masks[beacon['target']]
                         for beacon in beacons_by_key[(action, success)]
                         if beacon['target'] in instance_masks and beacon['target'] != inv_obj]
                if masks:
                    # Union of the object masks, as a bool tensor: indexing
                    # with uint8 masks is deprecated in PyTorch.
                    union = torch.from_numpy(sum(masks).astype(bool))
                    beacon_mask[t, ch, slot][union] = 1

    frames = np.stack([hist['frame'] for hist in beacon_history], 0)
    frames = torch.from_numpy(frames).float().permute(0, 3, 1, 2) / 255
    poses = torch.Tensor([hist['pose'] for hist in beacon_history])
    frames, beacon_mask = resize_results(frames, beacon_mask, out_sz)
    return frames, beacon_mask, poses, info