def compute_masks()

in envs/thor_beacons.py [0:0]


    def compute_masks(self, history=None, out_sz=None):
        """Build frame, beacon-mask and pose tensors from an interaction history.

        Timesteps that revisit an already-kept pose are dropped unless a
        successful beacon, whose target is visible at that step, was recorded
        after the pose was last kept (i.e. the visual content may have changed).

        Args:
            history: list of per-step dicts. Keys read here: 'pose' (used as a
                dict key, so it must be hashable), 'visible_objs', 'inv_obj',
                'instance_masks' (maps object id -> numpy mask) and 'frame'
                (numpy array laid out (H, W, C), scaled by 1/255 below).
                Defaults to ``self.history`` only when None; an explicitly
                passed empty list yields the empty-result tensors.
            out_sz: spatial size handed to ``resize_results`` and used for the
                empty-result tensors; a falsy value falls back to
                ``self.mask_sz``.

        Returns:
            (frames, beacon_mask, poses, info)
              frames: float tensor from the kept steps' frames, permuted to
                  (T, C, H, W), divided by 255, then passed through
                  ``resize_results``.
              beacon_mask: built as (T, len(self.interactions), 2, frame_sz,
                  frame_sz); channel 0 marks pixels of unsuccessful beacon
                  targets, channel 1 successful ones; then passed through
                  ``resize_results``.
              poses: ``torch.Tensor`` of the kept steps' poses.
              info: {'scene': self.scene, 'episode': self.episode_id}.
            For an empty history, zero placeholders of shape
            (1, 3, out_sz, out_sz), (1, len(self.interactions), 2, out_sz,
            out_sz) (uint8) and (1, 5) are returned together with ``info``.
        """
        info = {'scene': self.scene, 'episode': self.episode_id}
        out_sz = out_sz or self.mask_sz
        # `is None` rather than `or`: an explicitly empty history must not
        # silently fall back to self.history.
        if history is None:
            history = self.history

        if len(history) == 0:
            return (torch.zeros(1, 3, out_sz, out_sz),
                    torch.zeros(1, len(self.interactions), 2, out_sz, out_sz).byte(),
                    torch.zeros(1, 5),
                    info)

        # --- Discard timesteps with redundant positions (same visual content).
        positive_beacons = [beacon for beacon in self.beacons if beacon['success']]

        pose_last_seen_time = {hist['pose']: -1 for hist in history}
        beacon_history = []
        for t, hist in enumerate(history):
            last_seen = pose_last_seen_time[hist['pose']]
            if last_seen != -1:
                # Pose revisited: keep this step only if a successful beacon on
                # a target visible now happened in (last_seen, t].
                state_changed = any(
                    last_seen < beacon['t'] <= t and beacon['target'] in hist['visible_objs']
                    for beacon in positive_beacons)
                if not state_changed:
                    continue

            pose_last_seen_time[hist['pose']] = t
            beacon_history.append(hist)

        # --- Paint beacon target masks per kept step and interaction.
        beacon_mask = torch.zeros(len(beacon_history), len(self.interactions), 2,
                                  self.frame_sz, self.frame_sz)

        beacons_by_key = collections.defaultdict(list)
        for beacon in self.beacons:
            beacons_by_key[(beacon['action'], beacon['success'])].append(beacon)

        for t, hist in enumerate(beacon_history):
            inv_obj = hist['inv_obj']
            instance_masks = hist['instance_masks']
            for ch, action in enumerate(self.interactions):
                # channel 0 <- unsuccessful beacons, channel 1 <- successful ones
                for polarity, success in ((0, False), (1, True)):
                    masks = [instance_masks[beacon['target']]
                             for beacon in beacons_by_key[(action, success)]
                             if beacon['target'] in instance_masks and beacon['target'] != inv_obj]
                    if masks:
                        # Pixel-wise union as a bool array: no integer wraparound,
                        # and bool is the supported mask-index dtype (uint8
                        # mask indexing is deprecated in PyTorch).
                        union = torch.from_numpy(np.stack(masks).any(axis=0))
                        beacon_mask[t, ch, polarity][union] = 1

        frames = np.stack([hist['frame'] for hist in beacon_history], 0)
        frames = torch.from_numpy(frames).float().permute(0, 3, 1, 2) / 255
        poses = torch.Tensor([hist['pose'] for hist in beacon_history])
        frames, beacon_mask = resize_results(frames, beacon_mask, out_sz)

        return frames, beacon_mask, poses, info