def prepare_data_instances()

in EAIEvaluation/HiTUT/hitut_train/custom_dataset.py [0:0]


    def prepare_data_instances(self):
        """Build the high-/low-level action instances of every non-test split and save them as JSON.

        Reads, for each task of each split in ``self.dataset_splits`` (splits whose
        name contains ``'test'`` are skipped):

        * ``<pp_path>/<split>/<task>/ann_<repeat_idx>.json`` -- the preprocessed trajectory;
        * ``<traj['raw_path']>/traj_data.json`` -- only ``scene.init_action``
          (initial camera ``horizon`` and ``rotation``) is used;
        * ``<pp_path>/<split>/<task>/bbox_cls_scores_sep.json`` -- per-image detections.

        Writes under ``self.pp_path``:

        * ``<split>_high_action_instances.json`` -- one instance per high-level step;
        * ``<split>_low_action_instances_{mani,navi}.json`` -- one instance per
          low-level step, split by whether the subgoal id is 3 (GotoLocation);
        * ``<split>_det_res_sep.json`` -- detections keyed by ``task_path + image_idx``;
        * train only: ``train_low_action_seed_{mani,navi}.json`` (low-level steps of
          the first occurrence of each subgoal seen so far in this run) and
          ``mani_subgoals.json`` (low-level action sequences per manipulation
          subgoal type, with the subgoal arguments seen for each sequence);
        * ``data_statistics.json`` -- horizon/rotation/subgoal counters and
          instance counts, accumulated over all processed splits.

        Tasks whose ``ann_<repeat_idx>.json`` is missing are skipped; their task
        directory is removed if it is empty.

        Navigation instances carry three extra targets: ``visible`` (target object
        among the detected classes of the current frame), ``reached`` and
        ``progress``; manipulation instances get -1 for all three.
        """
        statistic = {'horizon': Counter(), 'rotation': Counter(), 'mani': Counter(), 'navi': Counter()}
        # Per manipulation subgoal type: a Counter of low-level action sequences, plus
        # a '<seq> (objs)' entry listing the subgoal arguments seen with that sequence.
        mani_subgoal = {}
        for split, tasks in self.dataset_splits.items():
            if 'test' in split:
                continue
            print('Preparing %s data instances' % split)
            high_instances = []
            low_instances_mani = []
            low_instances_navi = []
            low_mani_seed = []
            low_navi_seed = []
            det_res_sep = {}

            for task_count, task in enumerate(tqdm(tasks), start=1):
                # Debug shortcut: with fast_epoch only the first 9 tasks are processed.
                if self.args.fast_epoch and task_count == 10:
                    break
                task_path = os.path.join(self.pp_path, split, task['task'])
                traj_path = os.path.join(task_path, 'ann_%d.json' % task['repeat_idx'])

                if not os.path.exists(traj_path):
                    # Drop the stale task directory only if it is empty. os.rmdir raises
                    # when the directory is missing or still holds other files; skip the
                    # task in that case instead of aborting the whole run.
                    try:
                        os.rmdir(task_path)
                    except OSError:
                        pass
                    continue

                with open(traj_path, 'r') as f:
                    traj = json.load(f)
                with open(os.path.join(traj['raw_path'], 'traj_data.json'), 'r') as f:
                    traj_raw = json.load(f)

                # TODO: get correct horizon and rotation
                init_action = traj_raw['scene']['init_action']

                with open(os.path.join(task_path, 'bbox_cls_scores_sep.json'), 'r') as f:
                    obj_det_sep = json.load(f)
                for img_idx, det in obj_det_sep.items():
                    det_res_sep[task_path + img_idx] = det

                # High-level instances. The +1 covers the final step after the last
                # instruction (the NoOp excluded from num_high_without_NoOp).
                high = traj['high']
                num_high_without_NoOp = len(traj['lang']['instr'])
                for hidx in range(num_high_without_NoOp + 1):
                    high_instances.append({
                        'path': task_path,
                        'high_idx': hidx,
                        'lang_input': traj['lang']['goal'],   # list of int
                        'vision_input': high['images'][hidx],
                        'actype_history_input': high['dec_in_high_actions'][:hidx + 1],
                        'arg_history_input': high['dec_in_high_args'][:hidx + 1],
                        'actype_output': high['dec_out_high_actions'][hidx],
                        'arg_output': high['dec_out_high_args'][hidx],
                    })

                # Low-level instances; horizon/rotation follow the agent pose step by step.
                low = traj['low']
                horizon = init_action['horizon']
                rotation = init_action['rotation']
                for hidx in range(num_high_without_NoOp):
                    # The first entry of each low-level input sequence encodes the subgoal.
                    sg_type = self.dec_in_vocab.id2w(low['dec_in_low_actions'][hidx][0])
                    sg_arg = self.dec_in_vocab.id2w(low['dec_in_low_args'][hidx][0])
                    subgoal = '%s(%s)' % (sg_type, sg_arg)
                    is_navi_subgoal = sg_type == 'GotoLocation'
                    # Id 3 is the GotoLocation subgoal in dec_in_vocab.
                    is_goto = low['dec_in_low_actions'][hidx][0] == 3
                    low_action_seq = ' '.join(self.dec_in_vocab.seq_decode(low['dec_in_low_actions'][hidx][1:]))
                    # Seeds: first occurrence of a subgoal (the counters are shared across
                    # splits, so splits processed earlier also count).
                    add_to_mani = split == 'train' and not is_navi_subgoal and subgoal not in statistic['mani']
                    add_to_navi = split == 'train' and is_navi_subgoal and subgoal not in statistic['navi']
                    if add_to_mani:
                        seq_counter = mani_subgoal.setdefault(sg_type, Counter())
                        seq_counter[low_action_seq] += 1
                        objs_key = low_action_seq + ' (objs)'
                        if seq_counter[low_action_seq] == 1:
                            seq_counter[objs_key] = [sg_arg]
                        elif sg_arg not in seq_counter[objs_key]:
                            seq_counter[objs_key].append(sg_arg)

                    lang_input = traj['lang']['instr'][hidx]
                    out_actions = low['dec_out_low_actions'][hidx]
                    num_low_steps = len(out_actions)
                    vis_history = []
                    for low_idx in range(num_low_steps):
                        vision_input = low['images'][hidx][low_idx]
                        actype_history_input = low['dec_in_low_actions'][hidx][:low_idx + 1]
                        arg_history_input = low['dec_in_low_args'][hidx][:low_idx + 1]
                        actype_output = out_actions[low_idx]
                        arg_output = low['dec_out_low_args'][hidx][low_idx]
                        action_name = self.dec_out_vocab_low.id2w(actype_output)
                        try:
                            interact = low['interact'][hidx][low_idx]
                        except (KeyError, IndexError, TypeError):
                            # Best effort: missing interaction annotation defaults to 0.
                            interact = 0

                        if is_goto:
                            target_obj = sg_arg
                            detected_objs = obj_det_sep[str(vision_input)]['class']
                            visible = 1 if target_obj in detected_objs else 0
                            # Reached on the last step, or one step earlier when that step
                            # only tilts the camera and the target is already in view.
                            reached = 0
                            if low_idx == num_low_steps - 1:
                                reached = 1
                            if low_idx == num_low_steps - 2 and action_name in {'LookDown', 'LookUp'} and visible:
                                reached = 1
                            progress = (low_idx + 1) / num_low_steps
                        else:
                            visible, reached, progress = -1, -1, -1

                        instance = {
                            'path': task_path,
                            'high_idx': hidx,
                            'low_idx': low_idx,
                            'interact': interact,
                            'lang_input': lang_input,
                            'vision_input': vision_input,
                            'actype_history_input': actype_history_input,
                            'arg_history_input': arg_history_input,
                            'vis_history_input': copy.deepcopy(vis_history),
                            'actype_output': actype_output,
                            'arg_output': arg_output,
                            'visible': visible,
                            'reached': reached,
                            'progress': progress,
                            'rotation': (rotation % 360) / 90,
                            'horizon': horizon / 15,
                        }
                        statistic['horizon'][horizon] += 1
                        statistic['rotation'][rotation] += 1
                        # Update the pose that the *next* step starts from.
                        if action_name == 'RotateRight':
                            rotation += 90
                        elif action_name == 'RotateLeft':
                            rotation -= 90
                        elif action_name == 'LookUp':
                            horizon += 15
                        elif action_name == 'LookDown':
                            horizon -= 15
                        vis_history.append(vision_input)

                        if is_goto:
                            low_instances_navi.append(instance)
                        else:
                            low_instances_mani.append(instance)

                        if add_to_mani:
                            low_mani_seed.append(instance)
                        if add_to_navi:
                            low_navi_seed.append(instance)

                    # Count each subgoal only under its own category.
                    statistic['navi' if is_navi_subgoal else 'mani'][subgoal] += 1

            print('high len:', len(high_instances))
            print('low mani len:', len(low_instances_mani))
            print('low navi len:', len(low_instances_navi))
            statistic['%s high len' % split] = len(high_instances)
            statistic['%s low-mani len' % split] = len(low_instances_mani)
            statistic['%s low-navi len' % split] = len(low_instances_navi)
            if split == 'train':
                statistic['train low-navi seed len'] = len(low_navi_seed)
                statistic['train low-mani seed len'] = len(low_mani_seed)

            # File names are built explicitly (not via str.replace on the full path,
            # which would also rewrite any 'mani' inside pp_path).
            outputs = {
                '%s_high_action_instances.json' % split: high_instances,
                '%s_low_action_instances_mani.json' % split: low_instances_mani,
                '%s_low_action_instances_navi.json' % split: low_instances_navi,
            }
            if split == 'train':
                outputs['%s_low_action_seed_mani.json' % split] = low_mani_seed
                outputs['%s_low_action_seed_navi.json' % split] = low_navi_seed
                outputs['mani_subgoals.json'] = mani_subgoal
            outputs['%s_det_res_sep.json' % split] = det_res_sep
            for file_name, data in outputs.items():
                with open(os.path.join(self.pp_path, file_name), 'w') as f:
                    json.dump(data, f, indent=2)

        # Sort every counter by key so the statistics file is stable and readable.
        for k, v in statistic.items():
            if isinstance(v, dict):
                statistic[k] = dict(sorted(v.items(), key=lambda item: item[0]))
        with open(os.path.join(self.pp_path, 'data_statistics.json'), 'w') as f:
            json.dump(statistic, f, indent=2)