def compute_masks()

in envs/thor_beacons.py [0:0]


    def compute_masks(self, history=None, out_sz=None):
        """Build frames, per-interaction beacon masks and poses from an interaction history.

        Steps whose 'beacons' list is empty are dropped, as are revisits of a
        'pose' for which no successful beacon with 't' in (previous visit, current
        step] exists. For every remaining step and every action in
        ``self.interactions``, a pixel is set in slot 0 (unsuccessful beacons) or
        slot 1 (successful beacons) when the Euclidean distance between its point
        in ``hist['P']`` and a beacon 'pt' of that action is below
        ``self.dist_thresh``.

        Args:
            history: list of step dicts with keys 'P', 'beacons', 'pose' and
                'frame'; defaults to ``self.history`` when None. Each 'P' is
                treated as a (3, H, W) per-pixel point map (assumed — its channel
                count must match the length of the beacon 'pt' tuples). Each
                'frame' is stacked as an (H, W, C) array.
            out_sz: spatial size of the outputs; defaults to ``self.mask_sz``.

        Returns:
            ``(frames, beacon_mask, poses, info)`` where ``info`` holds the scene
            and episode id. If there is nothing to mask, placeholders are
            returned: zero frames of shape (1, 3, out_sz, out_sz), a uint8 zero
            mask of shape (1, len(self.interactions), 2, out_sz, out_sz) and
            (1, 5) zero poses. Otherwise frames and masks are the outputs of
            ``resize_results`` for the K kept steps, and poses is a
            (K, pose_dim) tensor.
        """
        info = {'scene': self.scene, 'episode': self.episode_id}
        out_sz = out_sz or self.mask_sz
        # `is None` (not `or`): an explicitly passed empty history must not fall
        # back to self.history.
        if history is None:
            history = self.history

        def empty_result():
            # Placeholder outputs used whenever there is nothing to mask.
            frames = torch.zeros(1, 3, out_sz, out_sz)
            mask = torch.zeros(1, len(self.interactions), 2, out_sz, out_sz).byte()
            return frames, mask, torch.zeros(1, 5), info

        if len(history) == 0 or len(self.beacons) == 0:
            return empty_result()

        # Each entry of the updated history carries its own set of beacons.
        history = self.update_beacon_coordinates(history)

        # Select the steps to mask. This needs no distances, so it is done before
        # the (potentially very large) distance computation.
        pose_last_seen_time = {}
        beacon_history = []
        for t, hist in enumerate(history):
            beacons_t = hist['beacons']
            if len(beacons_t) == 0:
                continue
            last_seen = pose_last_seen_time.get(hist['pose'], -1)
            if last_seen != -1:
                # Pose revisited: keep it only if a beacon succeeded since the last visit.
                state_changed = any(
                    last_seen < beacon['t'] <= t and beacon['success'] for beacon in beacons_t
                )
                if not state_changed:
                    continue
            pose_last_seen_time[hist['pose']] = t
            beacon_history.append(hist)

        if len(beacon_history) == 0:
            return empty_result()

        # Unique beacon locations over the kept steps, each mapped to a distance column.
        uniq_beacon_pts = sorted(
            {beacon['pt'] for hist in beacon_history for beacon in hist['beacons']}
        )
        pt_to_col = {pt: col for col, pt in enumerate(uniq_beacon_pts)}

        P = torch.stack([hist['P'] for hist in beacon_history], 0)  # (K, 3, H, W)
        K, _, H, W = P.shape
        # new_tensor matches P's dtype and device so cdist sees consistent inputs.
        B = P.new_tensor(uniq_beacon_pts)  # (N, 3)
        wdist = torch.cdist(rearrange(P, 't p h w -> (t h w) p'), B)  # (K*H*W, N)
        wdist = rearrange(wdist, '(t h w) n -> t h w n', t=K, h=H, w=W)  # (K, H, W, N)

        # Slot 0 marks pixels near unsuccessful beacons, slot 1 near successful ones.
        beacon_mask = torch.zeros(K, len(self.interactions), 2, H, W)
        for k, hist in enumerate(beacon_history):
            wdist_k = wdist[k]  # (H, W, N)
            for ch, action in enumerate(self.interactions):
                neg_cols = [pt_to_col[b['pt']] for b in hist['beacons']
                            if b['action'] == action and not b['success']]
                pos_cols = [pt_to_col[b['pt']] for b in hist['beacons']
                            if b['action'] == action and b['success']]
                for slot, cols in ((0, neg_cols), (1, pos_cols)):
                    if cols:
                        near = wdist_k[:, :, cols].min(dim=2).values < self.dist_thresh  # (H, W)
                        beacon_mask[k, ch, slot][near] = 1

        frames = np.stack([hist['frame'] for hist in beacon_history], 0)
        frames = torch.from_numpy(frames).float().permute(0, 3, 1, 2) / 255
        poses = torch.Tensor([hist['pose'] for hist in beacon_history])
        frames, beacon_mask = resize_results(frames, beacon_mask, out_sz)

        return frames, beacon_mask, poses, info