in EAIEvaluation/HiTUT/hitut_train/custom_dataset.py [0:0]
def process_actions(self, ex, traj):
    """Encode the high- and low-level action sequences of one trajectory.

    High-level actions come from ``ex['plan']['high_pddl']`` and are stored,
    vocabulary-encoded, in ``traj['high']``.  Low-level actions come from
    ``ex['plan']['low_actions']`` and are grouped per high-level action
    (subgoal) in ``traj['low']``.  Segment ``i`` starts with the name and
    argument of high-level action ``i`` and ends with a ``'NoOp'`` termination
    token.  A segment that receives no low-level action holds only that token.

    Side effects:
        * appends a terminal ``End`` action to ``ex['plan']['high_pddl']``
          when it is missing;
        * increments ``self.stats['high action steps']`` and
          ``self.stats['low action steps']``, keyed by sequence length;
        * when there are two more high-level actions than language
          instructions, duplicates the last entry of ``traj['lang']['instr']``
          and ``traj['lang']['instr_tokenize']``.

    Args:
        ex: trajectory annotation; reads ``ex['plan']['high_pddl']`` and
            ``ex['plan']['low_actions']``.
        traj: output dict.  Must already contain ``traj['lang']['instr']`` and
            ``traj['lang']['instr_tokenize']``.  ``traj['high']`` and
            ``traj['low']`` are (re)created here.

    Raises:
        AssertionError: (via ``assert``) if the number of high-level actions
            exceeds the number of language instructions by anything other
            than 1 or 2.
    """
    # object types whose high-level argument gets a 'Sliced' suffix when the
    # planner action refers to a sliced instance
    sliceable_objects = {'Apple', 'Bread', 'Lettuce', 'Potato', 'Tomato'}

    def get_normalized_arg(a, level):
        """Return the object-type argument of action ``a``, or ``'None'``.

        ``level`` selects the source of the argument:
        * ``'high'`` reads the last entry of ``a['discrete_action']['args']``;
        * ``'low'`` parses the type out of ``a['api_action']['objectId']``.

        Names found in ``OBJECTS_LOWER_TO_UPPER`` are mapped through it.
        """
        if level == 'high':
            hl_args = a['discrete_action']['args']
            if hl_args == [] or hl_args == ['']:
                return 'None'
            # argument for action PutObject is the receptacle (2nd item in the list)
            arg = hl_args[-1]
        elif level == 'low':
            api_action = a['api_action']
            if 'objectId' not in api_action or len(api_action['objectId']) == 0:
                return 'None'
            id_fields = api_action['objectId'].split('|')
            if api_action['action'] == 'PutObject' or len(id_fields) == 4:
                # NOTE(review): PutObject takes its argument from objectId; the
                # receptacleObjectId lookup was commented out in the original
                # code -- confirm which object is intended here.
                arg = id_fields[0]
            else:
                # ids with extra fields: the type is the 5th field up to '_'
                arg = id_fields[4].split('_')[0]
        else:
            raise ValueError('unknown action level: {!r}'.format(level))

        if arg in OBJECTS_LOWER_TO_UPPER:
            arg = OBJECTS_LOWER_TO_UPPER[arg]
        # fix high argument for sliced objects
        planner_action = a.get('planner_action', {}) if level == 'high' else {}
        if level == 'high' and arg in sliceable_objects and \
                'objectId' in planner_action and 'Sliced' in planner_action['objectId']:
            arg += 'Sliced'
        return arg

    def fix_missing_high_pddl_end_action(ex):
        """Append a terminal ``End`` action to the high-level plan if it lacks one."""
        if ex['plan']['high_pddl'][-1]['planner_action']['action'] != 'End':
            ex['plan']['high_pddl'].append({
                'discrete_action': {'action': 'NoOp', 'args': []},
                'planner_action': {'value': 1, 'action': 'End'},
                'high_idx': len(ex['plan']['high_pddl'])
            })

    # deal with missing end high-level action
    fix_missing_high_pddl_end_action(ex)
    high_pddl = ex['plan']['high_pddl']
    num_hl_actions = len(high_pddl)

    # ---- high-level actions ----
    picked = None
    actions, args = [SOS], ['None']
    for idx, a in enumerate(high_pddl):
        high_action = a['discrete_action']['action']
        high_arg = get_normalized_arg(a, 'high')
        # Use the object picked up next as the GotoLocation destination so it
        # can be inferred, e.g. for task "clean a knife" GotoLocation(SideTable)
        # becomes GotoLocation(Knife).
        if high_action == 'GotoLocation' and idx + 1 < num_hl_actions:
            next_a = high_pddl[idx + 1]
            if next_a['discrete_action']['action'] == 'PickupObject':
                high_arg = get_normalized_arg(next_a, 'high')
        # While a sliced object is held, actions on its unsliced type (Clean,
        # Cool, Heat, ...) refer to the sliced object; PutObject releases it.
        if high_action == 'PickupObject':
            picked = high_arg
        if high_action == 'PutObject':
            picked = None
        if picked is not None and 'Sliced' in picked and picked[:-len('Sliced')] == high_arg:
            high_arg = picked
        actions.append(high_action)
        args.append(high_arg)
    self.stats['high action steps'][len(actions)] += 1

    # decoder inputs include the leading SOS token; outputs drop it
    traj['high'] = {}
    traj['high']['dec_in_high_actions'] = self.dec_in_vocab.seq_encode(actions)
    traj['high']['dec_in_high_args'] = self.dec_in_vocab.seq_encode(args)
    traj['high']['dec_out_high_actions'] = self.dec_out_vocab_high.seq_encode(actions)[1:]
    traj['high']['dec_out_high_args'] = self.dec_out_vocab_arg.seq_encode(args)[1:]

    # ---- low-level actions, temporally aligned with high-level actions ----
    low_keys = ('dec_in_low_actions', 'dec_in_low_args', 'dec_out_low_actions',
                'dec_out_low_args', 'bbox', 'centroid', 'mask', 'interact')
    traj['low'] = {key: [list() for _ in range(num_hl_actions)] for key in low_keys}
    traj_low = traj['low']
    low_actions = [list() for _ in range(num_hl_actions)]
    low_args = [list() for _ in range(num_hl_actions)]

    prev_high_idx = -1
    for a in ex['plan']['low_actions']:
        # high-level action index (subgoal) this low-level action belongs to
        high_idx = a['high_idx']
        if high_idx != prev_high_idx:
            # Close the previous segment with NoOp (termination of low-level
            # prediction).  The guard matters: with prev_high_idx == -1 the
            # append would otherwise land in the LAST segment.
            if prev_high_idx >= 0:
                low_actions[prev_high_idx].append('NoOp')
                low_args[prev_high_idx].append('None')
            # the high-level action is the first input of low-level decoding
            high_step = high_pddl[high_idx]
            low_actions[high_idx].append(high_step['discrete_action']['action'])
            low_args[high_idx].append(get_normalized_arg(high_step, 'high'))
            prev_high_idx = high_idx

        low_arg = get_normalized_arg(a, 'low')
        # drop suffixes such as the '_25' in 'MoveAhead_25'
        low_action = a['discrete_action']['action'].split('_')[0]
        low_actions[high_idx].append(low_action)
        low_args[high_idx].append(low_arg)

        # low-level bounding box (not used in the model) and its normalized centroid
        discrete_args = a['discrete_action']['args']
        if 'bbox' in discrete_args:
            bbox = discrete_args['bbox']
            traj_low['bbox'][high_idx].append(bbox)
            xmin, ymin, xmax, ymax = [float(x) if x != 'NULL' else -1 for x in bbox]
            traj_low['centroid'][high_idx].append([
                (xmin + (xmax - xmin) / 2) / self.image_size,
                (ymin + (ymax - ymin) / 2) / self.image_size,
            ])
        else:
            traj_low['bbox'][high_idx].append([])
            traj_low['centroid'][high_idx].append([])
        # low-level interaction mask (Note: this mask needs to be decompressed)
        traj_low['mask'][high_idx].append(
            discrete_args['mask'] if 'mask' in discrete_args else None)
        # interaction validity
        traj_low['interact'][high_idx].append(0 if low_action in NON_INTERACT_ACTIONS else 1)

    # add termination indicator for the last low-level action sequence
    if prev_high_idx >= 0:
        low_actions[prev_high_idx].append('NoOp')
        low_args[prev_high_idx].append('None')
    # Segments that received no low-level action (presumably the terminal End
    # action) consist of the termination token only.
    for seg_actions, seg_args in zip(low_actions, low_args):
        if not seg_actions:
            seg_actions.append('NoOp')
            seg_args.append('None')

    for high_idx, (seg_actions, seg_args) in enumerate(zip(low_actions, low_args)):
        traj_low['dec_in_low_actions'][high_idx] = self.dec_in_vocab.seq_encode(seg_actions)
        traj_low['dec_in_low_args'][high_idx] = self.dec_in_vocab.seq_encode(seg_args)
        traj_low['dec_out_low_actions'][high_idx] = self.dec_out_vocab_low.seq_encode(seg_actions)[1:]
        traj_low['dec_out_low_args'][high_idx] = self.dec_out_vocab_arg.seq_encode(seg_args)[1:]
        self.stats['low action steps'][len(seg_actions)] += 1

    # ---- check alignment between step-by-step language and action segments ----
    lang = traj['lang']
    num_instr = len(lang['instr'])
    seg_len_diff = num_hl_actions - num_instr
    if seg_len_diff != 1:
        # sometimes the alignment is off by one
        assert seg_len_diff == 2, (
            'cannot align {} high-level actions with {} language instructions'
            .format(num_hl_actions, num_instr))
        # This only affects a few trajectories and merging segments is
        # troublesome, so duplicate the last instruction to restore alignment.
        lang['instr_tokenize'].append(lang['instr_tokenize'][-1])
        lang['instr'].append(lang['instr'][-1])