in envs/thor_beacons.py [0:0]
def compute_masks(self, history=None, out_sz=None):
    """Build per-timestep beacon masks (plus frames and poses) from the history.

    A history entry is kept when it has at least one beacon and either its pose
    has not been kept before, or it carries a successful beacon whose ``'t'``
    lies in ``(t_last_kept, t]`` for that pose. For every kept entry and every
    action in ``self.interactions``, a pixel is set when its point in
    ``hist['P']`` is closer than ``self.dist_thresh`` to a beacon of that
    action: slot 0 for failed beacons, slot 1 for successful beacons.

    Args:
        history: list of history dicts; ``self.history`` is used when None.
            An explicitly passed empty list yields the empty result.
        out_sz: output spatial size; ``self.mask_sz`` is used when falsy.

    Returns:
        ``(frames, beacon_mask, poses, info)`` where ``info`` holds the scene
        and episode id. If there is no history, no beacons, or no entry is
        kept, returns zeros ``(1, 3, out_sz, out_sz)``, uint8 zeros
        ``(1, len(self.interactions), 2, out_sz, out_sz)`` and zeros ``(1, 5)``.
        Otherwise frames are the kept entries' ``'frame'`` arrays stacked,
        permuted channels-first and divided by 255; masks are
        ``(K, len(self.interactions), 2, H, W)``; both then go through
        ``resize_results(frames, beacon_mask, out_sz)``.
    """
    info = {'scene': self.scene, 'episode': self.episode_id}
    out_sz = out_sz or self.mask_sz
    # `history or self.history` would silently swap an explicit [] for self.history.
    history = self.history if history is None else history

    def _empty_result():
        # Placeholder output used whenever there is nothing to mask.
        return (torch.zeros(1, 3, out_sz, out_sz),
                torch.zeros(1, len(self.interactions), 2, out_sz, out_sz).byte(),
                torch.zeros(1, 5),
                info)

    if len(history) == 0 or len(self.beacons) == 0:
        return _empty_result()

    # each hist in history will have different beacons
    history = self.update_beacon_coordinates(history)  # noted as slow (~4.2s)

    # Keep entries with beacons, skipping revisits of a pose whose state has not
    # changed (no successful beacon since the pose was last kept).
    pose_last_kept = {}
    beacon_history = []
    for t, hist in enumerate(history):
        beacons_t = hist['beacons']
        if len(beacons_t) == 0:
            continue
        pose = hist['pose']
        if pose in pose_last_kept:
            last_t = pose_last_kept[pose]
            changed = any(b['t'] > last_t and b['t'] <= t and b['success'] for b in beacons_t)
            if not changed:
                continue
        pose_last_kept[pose] = t
        beacon_history.append(hist)

    if len(beacon_history) == 0:
        return _empty_result()

    # assumes hist['P'] is (3, H, W) per-pixel 3D coordinates, same size for all entries — TODO confirm
    H, W = beacon_history[0]['P'].shape[-2:]
    beacon_mask = torch.zeros(len(beacon_history), len(self.interactions), 2, H, W)
    for k, hist in enumerate(beacon_history):
        P_t = hist['P']
        beacons_t = hist['beacons']
        # Distances only for this kept entry and only to its own beacons, instead of
        # materialising a (T*H*W, N) matrix for the whole history.
        pts = torch.tensor([b['pt'] for b in beacons_t], dtype=P_t.dtype, device=P_t.device)
        pix = P_t.permute(1, 2, 0).reshape(-1, P_t.shape[0])  # (H*W, 3)
        wdist = torch.cdist(pix, pts).reshape(H, W, -1)  # (H, W, num_beacons_t)
        for ch, action in enumerate(self.interactions):
            neg_cols = [i for i, b in enumerate(beacons_t) if b['action'] == action and not b['success']]
            pos_cols = [i for i, b in enumerate(beacons_t) if b['action'] == action and b['success']]
            for slot, cols in ((0, neg_cols), (1, pos_cols)):
                if cols:
                    near = wdist[:, :, cols].min(2)[0] < self.dist_thresh  # (H, W)
                    beacon_mask[k, ch, slot][near] = 1

    # presumably hist['frame'] is an (H, W, C) uint8 image — TODO confirm
    frames = np.stack([hist['frame'] for hist in beacon_history], 0)
    frames = torch.from_numpy(frames).float().permute(0, 3, 1, 2) / 255
    poses = torch.Tensor([hist['pose'] for hist in beacon_history])
    frames, beacon_mask = resize_results(frames, beacon_mask, out_sz)
    return frames, beacon_mask, poses, info