# process_images()
# excerpted from EAIEvaluation/HiTUT/hitut_train/custom_dataset.py


    def process_images(self, ex, traj):
        """Attach visual-observation frame ids to actions and run detection.

        Builds traj['high']['images'] (one frame id per high-level action,
        plus a final frame for the terminating NoOp) and
        traj['low']['images'][h] (one frame id per low-level action inside
        high action h, plus a termination frame).  Then, once per trajectory,
        runs Mask R-CNN over every referenced frame and saves the results
        under traj['pp_path'].

        Args:
            ex: raw trajectory record with 'plan'['high_pddl'] and a
                non-empty 'images' list of dicts holding 'high_idx',
                'low_idx' and 'image_name' ("NNNNNNNNN.png").
            traj: preprocessed trajectory dict, updated in place; must
                already contain the decoder action sequences, 'repeat_idx',
                'raw_path' and 'pp_path'.
        """
        def name2id(img_name):
            # "000000012.png" -> 12
            return int(img_name.split('.')[0])

        num_hl_actions = len(ex['plan']['high_pddl'])
        traj['high']['images'] = []
        traj['low']['images'] = [list() for _ in range(num_hl_actions)]

        prev_high_idx, prev_low_idx = -1, -1
        for img in ex['images']:
            high_idx, low_idx = img['high_idx'], img['low_idx']
            if high_idx != prev_high_idx:
                # reached a new high action: use the current image as the
                # visual observation 1) to predict the current high action and
                # 2) to predict the termination low action of the previous
                # high action
                traj['high']['images'].append(name2id(img['image_name']))
                if prev_high_idx >= 0:
                    traj['low']['images'][prev_high_idx].append(name2id(img['image_name']))
                prev_high_idx = high_idx
            if low_idx != prev_low_idx:
                # reached a new low action: use the current image as the
                # visual observation to predict the current low action
                traj['low']['images'][high_idx].append(name2id(img['image_name']))
                prev_low_idx = low_idx

        # Frame used for predicting the terminating NoOp.  With fix_traj the
        # last recorded frame is reused (the original code relied on the
        # leaked loop variable `img`, which is exactly ex['images'][-1] —
        # made explicit here); otherwise a synthetic "one past the end"
        # frame record is fabricated.
        img = ex['images'][-1]
        if not self.args.fix_traj:
            next_image_name = "{0:09d}.png".format(int(img["image_name"].split('.')[0]) + 1)
            img = {'high_idx': img['high_idx'] + 1,
                   'image_name': next_image_name,
                   'low_idx': img['low_idx'] + 1}

        # add the last frame for predicting termination action NoOp
        traj['high']['images'].append(name2id(img['image_name']))
        # TODO: comment this line for the correct data processing by Luminous
        # (note: high_idx is still the last high action seen in the loop)
        traj['low']['images'][high_idx].append(name2id(img['image_name']))

        # length check: exactly one observation per decoder input action
        # (minus the shifted-in start token)
        assert(len(traj['high']['images']) == len(traj['high']['dec_in_high_actions']) - 1)
        for hi in range(num_hl_actions):
            assert(len(traj['low']['images'][hi]) == len(traj['low']['dec_in_low_actions'][hi]) - 1)

        # use mask rcnn for object detection; different annotations of the
        # same trajectory (repeat_idx != 0) share the frames, so detection
        # only needs to run once
        if traj['repeat_idx'] != 0 or self.args.skip_detection:
            return

        # de-duplicated, sorted list of every frame id referenced above
        all_imgs = set(traj['high']['images'])
        for img_list in traj['low']['images']:
            all_imgs.update(img_list)
        all_imgs = sorted(all_imgs)

        dp = traj['raw_path']
        results_masks, results_others = {}, {}

        for model_type in ['sep']:
            for idx in range(0, len(all_imgs), self.batch_size):
                batch = all_imgs[idx: idx + self.batch_size]
                img_path_batch = [os.path.join(dp, "raw_images", "{0:09d}.png".format(b)) for b in batch]

                # BUGFIX: guard with try/finally so the shared lock is
                # released even when the detector raises (previously a bare
                # acquire/release pair that leaked the lock on error)
                lock.acquire()
                try:
                    if model_type == 'all':
                        masks, boxes, classes, scores = self.mrcnn.get_mrcnn_preds_all(img_path_batch)
                    else:
                        masks, boxes, classes, scores = self.mrcnn.get_mrcnn_preds_sep(img_path_batch)
                finally:
                    lock.release()

                for i, img_name in enumerate(batch):
                    results_others[img_name] = {
                        'bbox': [[int(coord) for coord in box] for box in boxes[i]],
                        'score': [float(s) for s in scores[i]],
                        'class': classes[i],
                        'label': None,  # grounding label slot; left unset here
                    }
                    # bit-pack the boolean masks to shrink the pickle
                    results_masks[img_name] = [np.packbits(m) for m in masks[i]]
                    self.stats['detection num'][len(classes[i])] += 1
                    self.stats['object num'][len([j for j in classes[i] if j in OBJECTS_DETECTOR])] += 1
                    self.stats['receptacle num'][len([j for j in classes[i] if j in STATIC_RECEPTACLES])] += 1

            # save object detection results
            pp_save_path = traj['pp_path']
            pk_save_path = os.path.join(pp_save_path, "masks_%s.pkl" % model_type)
            json_save_path = os.path.join(pp_save_path, "bbox_cls_scores_%s.json" % model_type)
            with open(pk_save_path, 'wb') as f:
                pickle.dump(results_masks, f)
            with open(json_save_path, 'w') as f:
                json.dump(results_others, f, indent=4)