def process_actions()

in EAIEvaluation/HiTUT/hitut_train/custom_dataset.py [0:0]


    def process_actions(self, ex, traj):
        """Encode the high- and low-level action plan of ``ex`` into ``traj``.

        Reads ``ex['plan']['high_pddl']`` and ``ex['plan']['low_actions']`` and
        writes:

        * ``traj['high']``: decoder input ids (``dec_in_high_actions`` /
          ``dec_in_high_args``) and decoder output ids (``dec_out_high_actions``
          / ``dec_out_high_args``, first encoded element dropped) for the
          high-level action sequence, which is prefixed with ``SOS``.
        * ``traj['low']``: one list per high-level action holding the encoded
          low-level action/argument sequences plus the per-step ``bbox``,
          ``centroid``, ``mask`` and ``interact`` entries.

        Side effects: may append a terminal ``End`` action to
        ``ex['plan']['high_pddl']``, increments counters in ``self.stats`` and,
        when the language instructions are one segment short, duplicates the
        last entry of ``traj['lang']['instr']`` and
        ``traj['lang']['instr_tokenize']``.

        Args:
            ex: annotation dict providing ``ex['plan']['high_pddl']`` and
                ``ex['plan']['low_actions']``.
            traj: output dict; must already contain ``traj['lang']['instr']``
                (and ``traj['lang']['instr_tokenize']`` when realignment is
                needed).

        Raises:
            AssertionError: if the number of high-level actions exceeds the
                number of language instructions by anything other than 1 or 2.
        """

        def get_normalized_arg(a, level):
            """Return the normalized argument name of action ``a`` ('None' if absent)."""
            if level == 'high':
                arg = a['discrete_action']['args']
                if arg == [] or arg == ['']:
                    return 'None'
                # argument for action PutObject is the receptacle (2nd item in the list)
                arg = arg[-1]
            elif level == 'low':
                api_action = a['api_action']
                if 'objectId' not in api_action or len(api_action['objectId']) == 0:
                    return 'None'
                id_parts = api_action['objectId'].split('|')
                if api_action['action'] == 'PutObject' or len(id_parts) == 4:
                    # PutObject deliberately uses objectId, not receptacleObjectId
                    arg = id_parts[0]
                else:
                    # ids that do not have exactly 4 '|'-separated fields: use
                    # the 5th field up to the first '_'
                    # (presumably sliced-object ids -- TODO confirm)
                    arg = id_parts[4].split('_')[0]
            else:
                raise ValueError("level must be 'high' or 'low', got %r" % (level,))

            if arg in OBJECTS_LOWER_TO_UPPER:
                arg = OBJECTS_LOWER_TO_UPPER[arg]

            # fix high argument for sliced objects
            if level == 'high' and arg in {'Apple', 'Bread', 'Lettuce', 'Potato', 'Tomato'} and \
                    'objectId' in a['planner_action'] and 'Sliced' in a['planner_action']['objectId']:
                arg += 'Sliced'
            return arg

        def fix_missing_high_pddl_end_action(ex):
            """Append a terminal action to the high-level plan if it lacks one."""
            if ex['plan']['high_pddl'][-1]['planner_action']['action'] != 'End':
                ex['plan']['high_pddl'].append({
                    'discrete_action': {'action': 'NoOp', 'args': []},
                    'planner_action': {'value': 1, 'action': 'End'},
                    'high_idx': len(ex['plan']['high_pddl'])
                })

        # deal with missing end high-level action
        fix_missing_high_pddl_end_action(ex)
        high_pddl = ex['plan']['high_pddl']

        # process high-level actions
        picked = None
        actions, args = [SOS], ['None']
        for idx, a in enumerate(high_pddl):
            high_action = a['discrete_action']['action']
            high_arg = get_normalized_arg(a, 'high')
            # change destinations into ones that can be inferred
            # e.g. For task "clean a knife" turn GotoLocation(SideTable) to GotoLocation(Knife)
            if high_action == 'GotoLocation' and idx + 1 < len(high_pddl):
                next_a = high_pddl[idx + 1]
                if next_a['discrete_action']['action'] == 'PickupObject':
                    high_arg = get_normalized_arg(next_a, 'high')

            # fix argument of sliced object for Clean, Cool and Heat:
            # while a '...Sliced' object is held, actions naming its unsliced
            # form are rewritten to the sliced name
            if high_action == 'PickupObject':
                picked = high_arg
            if high_action == 'PutObject':
                picked = None
            # 'Sliced' is 6 characters, so picked[:-6] strips a trailing 'Sliced'
            if picked is not None and 'Sliced' in picked and picked[:-6] == high_arg:
                high_arg = picked

            actions.append(high_action)
            args.append(high_arg)
        self.stats['high action steps'][len(actions)] += 1

        # high actions to action decoder input ids (including all task special tokens)
        traj['high'] = {}
        traj['high']['dec_in_high_actions'] = self.dec_in_vocab.seq_encode(actions)
        traj['high']['dec_in_high_args'] = self.dec_in_vocab.seq_encode(args)
        # high actions to high action decoder output ids
        traj['high']['dec_out_high_actions'] = self.dec_out_vocab_high.seq_encode(actions)[1:]
        traj['high']['dec_out_high_args'] = self.dec_out_vocab_arg.seq_encode(args)[1:]

        # process low-level actions, temporally aligned with HL actions:
        # every entry of traj['low'] holds one list per high-level action
        num_hl_actions = len(high_pddl)
        traj['low'] = {
            key: [[] for _ in range(num_hl_actions)]
            for key in (
                'dec_in_low_actions', 'dec_in_low_args',
                'dec_out_low_actions', 'dec_out_low_args',
                'bbox', 'centroid', 'mask', 'interact',
            )
        }

        low_actions = [[] for _ in range(num_hl_actions)]
        low_args = [[] for _ in range(num_hl_actions)]
        plan_low_actions = ex['plan']['low_actions']
        prev_high_idx = -1
        for a in plan_low_actions:
            # high-level action index (subgoals)
            high_idx = a['high_idx']

            if high_idx != prev_high_idx:
                # add NoOp to indicate the termination of low-level action prediction
                # NOTE(review): on the very first step prev_high_idx is -1, so this
                # pair is written to the LAST segment via negative indexing rather
                # than to a "previous" one. Kept as-is because changing it would
                # alter the generated data -- confirm whether this is intended.
                low_actions[prev_high_idx].append('NoOp')
                low_args[prev_high_idx].append('None')
                # add the high-level action name as the first input of low-level action decoding
                high_action = high_pddl[high_idx]
                high_arg = get_normalized_arg(high_action, 'high')
                low_actions[high_idx].append(high_action['discrete_action']['action'])
                low_args[high_idx].append(high_arg)
                prev_high_idx = high_idx

            low_arg = get_normalized_arg(a, 'low')
            low_action = a['discrete_action']['action']
            if '_' in low_action:
                # keep only the part of the action name before the first '_'
                low_action = low_action.split('_')[0]
            low_actions[high_idx].append(low_action)
            low_args[high_idx].append(low_arg)

            discrete_args = a['discrete_action']['args']

            # low-level bounding box (not used in the model)
            if 'bbox' in discrete_args:
                traj['low']['bbox'][high_idx].append(discrete_args['bbox'])
                # 'NULL' coordinates are mapped to -1
                xmin, ymin, xmax, ymax = [float(x) if x != 'NULL' else -1 for x in discrete_args['bbox']]
                # box centre, divided by the image size
                traj['low']['centroid'][high_idx].append([
                    (xmin + (xmax - xmin) / 2) / self.image_size,
                    (ymin + (ymax - ymin) / 2) / self.image_size,
                ])
            else:
                traj['low']['bbox'][high_idx].append([])
                traj['low']['centroid'][high_idx].append([])

            # low-level interaction mask (Note: this mask needs to be decompressed)
            mask = discrete_args['mask'] if 'mask' in discrete_args else None
            traj['low']['mask'][high_idx].append(mask)

            # interaction validity
            has_interact = 0 if low_action in NON_INTERACT_ACTIONS else 1
            traj['low']['interact'][high_idx].append(has_interact)

        # add termination indicator for the last low-level action sequence;
        # after the loop prev_high_idx equals the high_idx of the final step.
        # Guarded so that an empty low-level plan no longer raises
        # UnboundLocalError.
        if plan_low_actions:
            low_actions[prev_high_idx].append('NoOp')
            low_args[prev_high_idx].append('None')

        for seg_idx in range(num_hl_actions):
            seg_actions, seg_args = low_actions[seg_idx], low_args[seg_idx]
            traj['low']['dec_in_low_actions'][seg_idx] = self.dec_in_vocab.seq_encode(seg_actions)
            traj['low']['dec_in_low_args'][seg_idx] = self.dec_in_vocab.seq_encode(seg_args)
            traj['low']['dec_out_low_actions'][seg_idx] = self.dec_out_vocab_low.seq_encode(seg_actions)[1:]
            traj['low']['dec_out_low_args'][seg_idx] = self.dec_out_vocab_arg.seq_encode(seg_args)[1:]
            self.stats['low action steps'][len(seg_actions)] += 1

        # check alignment between step-by-step language and action sequence segments
        action_low_seg_len = num_hl_actions
        lang_instr_seg_len = len(traj['lang']['instr'])
        seg_len_diff = action_low_seg_len - lang_instr_seg_len
        if seg_len_diff != 1:
            # sometimes the alignment is off by one; anything else is rejected.
            # Raised explicitly (instead of a bare `assert`) so the check is
            # not stripped when Python runs with -O.
            if seg_len_diff != 2:
                raise AssertionError(
                    'action/language segment mismatch: %d high-level actions vs '
                    '%d language instructions' % (action_low_seg_len, lang_instr_seg_len)
                )
            # Because 1) this bug only in a few trajs 2) merge is very troublesome
            # we simply duplicate the last language instruction to align
            traj['lang']['instr_tokenize'].append(traj['lang']['instr_tokenize'][-1])
            traj['lang']['instr'].append(traj['lang']['instr'][-1])