in EAIEvaluation/HiTUT/hitut_train/custom_dataset.py [0:0]
def prepare_data_instances(self):
    """Flatten preprocessed trajectories into per-step instances and save them.

    Every split in ``self.dataset_splits`` whose name does not contain
    ``'test'`` is processed. For each task the annotation file
    ``ann_<repeat_idx>.json``, the raw ``traj_data.json`` it points to and
    ``bbox_cls_scores_sep.json`` are read, and the following files are
    written under ``self.pp_path``:

    * ``<split>_high_action_instances.json``
    * ``<split>_low_action_instances_mani.json`` / ``..._navi.json``
    * ``<split>_det_res_sep.json``
    * ``train_low_action_seed_mani.json`` / ``..._navi.json`` and
      ``mani_subgoals.json`` (only for the ``'train'`` split)
    * ``data_statistics.json`` (once, after all splits)

    When ``self.args.fast_epoch`` is truthy only the first nine tasks of
    each split are processed.
    """
    statistic = {
        'horizon': Counter(),
        'rotation': Counter(),
        'mani': Counter(),
        'navi': Counter(),
    }
    mani_subgoal = {}

    for split, tasks in self.dataset_splits.items():
        if 'test' in split:
            continue
        print('Preparing %s data instances'%split)

        high_instances = []
        low_instances_mani = []
        low_instances_navi = []
        low_mani_seed = []
        low_navi_seed = []
        det_res_sep = {}

        for task_count, task in enumerate(tqdm(tasks), start=1):
            # Debug mode: stop when the 10th task is reached (9 processed).
            if self.args.fast_epoch and task_count == 10:
                break

            task_path = os.path.join(self.pp_path, split, task['task'])
            traj_path = os.path.join(task_path, 'ann_%d.json'%task['repeat_idx'])
            if not os.path.exists(traj_path):
                # NOTE(review): os.rmdir raises OSError when the directory
                # is missing or not empty; kept as in the original.
                os.rmdir(task_path)
                continue

            with open(traj_path, 'r') as f:
                traj = json.load(f)
            with open(os.path.join(traj['raw_path'], 'traj_data.json'), 'r') as f:
                traj_raw = json.load(f)
            # TODO: get correct horizon and rotation
            init_action = traj_raw['scene']['init_action']

            obj_det_path_sep = os.path.join(task_path, 'bbox_cls_scores_sep.json')
            with open(obj_det_path_sep, 'r') as f:
                obj_det_sep = json.load(f)
            # Detection results are keyed by task path + image index so the
            # per-split file can hold every task.
            for img_idx in obj_det_sep:
                det_res_sep[task_path+img_idx] = obj_det_sep[img_idx]

            high = traj['high']
            low = traj['low']

            # ---- high-level (sub-goal prediction) instances ----
            num_high_without_NoOp = len(traj['lang']['instr'])
            # One extra step compared with the number of instructions
            # (the variable name suggests it is the terminating NoOp).
            for hidx in range(num_high_without_NoOp + 1):
                high_instances.append({
                    'path': task_path,
                    'high_idx': hidx,
                    'lang_input': traj['lang']['goal'],  # list of int
                    'vision_input': high['images'][hidx],
                    'actype_history_input': high['dec_in_high_actions'][:hidx+1],
                    'arg_history_input': high['dec_in_high_args'][:hidx+1],
                    'actype_output': high['dec_out_high_actions'][hidx],
                    'arg_output': high['dec_out_high_args'][hidx],
                })

            # ---- low-level (action prediction) instances ----
            # The agent pose is tracked across the whole trajectory,
            # starting from the scene's initial action.
            horizon = init_action['horizon']
            rotation = init_action['rotation']
            for hidx in range(num_high_without_NoOp):
                # check the category of subgoal
                sg_type = self.dec_in_vocab.id2w(low['dec_in_low_actions'][hidx][0])
                sg_arg = self.dec_in_vocab.id2w(low['dec_in_low_args'][hidx][0])
                subgoal = '%s(%s)'%(sg_type, sg_arg)
                low_action_seq = ' '.join(
                    self.dec_in_vocab.seq_decode(low['dec_in_low_actions'][hidx][1:]))

                # A sub-goal becomes a "seed" only the first time it is
                # seen in the training split.
                first_in_train_mani = split == 'train' and subgoal not in statistic['mani']
                first_in_train_navi = split == 'train' and subgoal not in statistic['navi']
                add_to_mani = first_in_train_mani and sg_type != 'GotoLocation'
                add_to_navi = first_in_train_navi and sg_type == 'GotoLocation'

                if add_to_mani:
                    if sg_type not in mani_subgoal:
                        mani_subgoal[sg_type] = Counter()
                    seq_counter = mani_subgoal[sg_type]
                    objs_key = low_action_seq+' (objs)'
                    seq_counter[low_action_seq] += 1
                    # Next to each sequence count, keep the list of
                    # distinct objects the sequence was applied to.
                    if seq_counter[low_action_seq] == 1:
                        seq_counter[objs_key] = [sg_arg]
                    elif sg_arg not in seq_counter[objs_key]:
                        seq_counter[objs_key].append(sg_arg)

                lang_input = traj['lang']['instr'][hidx]
                num_low_steps = len(low['dec_out_low_actions'][hidx])
                vis_history = []
                for low_idx in range(num_low_steps):
                    vision_input = low['images'][hidx][low_idx]
                    actype_history_input = low['dec_in_low_actions'][hidx][:low_idx+1]
                    arg_history_input = low['dec_in_low_args'][hidx][:low_idx+1]
                    actype_output = low['dec_out_low_actions'][hidx][low_idx]
                    arg_output = low['dec_out_low_args'][hidx][low_idx]
                    # Decode the output action once; it is needed for the
                    # "reached" label and for the pose update below.
                    action = self.dec_out_vocab_low.id2w(actype_output)

                    # Best effort: a missing interact flag defaults to 0.
                    try:
                        interact = low['interact'][hidx][low_idx]
                    except (KeyError, IndexError, TypeError):
                        interact = 0

                    is_navigation = actype_history_input[0] == 3  # gotolocation
                    if is_navigation:
                        target_obj = self.dec_in_vocab.id2w(arg_history_input[0])
                        detected_objs = obj_det_sep[str(vision_input)]['class']
                        visible = 1 if target_obj in detected_objs else 0
                        # Reached on the last step, or on the second-to-last
                        # step when it is a camera tilt and the target is
                        # already detected.
                        reached = 0
                        if low_idx == (num_low_steps - 1):
                            reached = 1
                        if low_idx == (num_low_steps - 2):
                            if action in {'LookDown', 'LookUp'} and visible:
                                reached = 1
                        progress = (low_idx+1)/num_low_steps
                    else:
                        # Navigation-only labels are marked not applicable.
                        visible, reached, progress = -1, -1, -1

                    instance = {
                        'path': task_path,
                        'high_idx': hidx,
                        'low_idx': low_idx,
                        'interact': interact,
                        'lang_input': lang_input,
                        'vision_input': vision_input,
                        'actype_history_input': actype_history_input,
                        'arg_history_input': arg_history_input,
                        'vis_history_input': copy.deepcopy(vis_history),
                        'actype_output': actype_output,
                        'arg_output': arg_output,
                        'visible': visible,
                        'reached': reached,
                        'progress': progress,
                        'rotation': (rotation%360)/90,
                        'horizon': horizon/15,
                    }
                    # Statistics record the pose *before* the action.
                    statistic['horizon'][horizon] += 1
                    statistic['rotation'][rotation] += 1

                    # NOTE(review): LookUp increases and LookDown decreases
                    # horizon here; confirm this matches the simulator's
                    # sign convention (see the TODO above).
                    if action == 'RotateRight':
                        rotation += 90
                    elif action == 'RotateLeft':
                        rotation -= 90
                    elif action == 'LookUp':
                        horizon += 15
                    elif action == 'LookDown':
                        horizon -= 15

                    vis_history.append(vision_input)

                    if is_navigation:
                        low_instances_navi.append(instance)
                    else:
                        low_instances_mani.append(instance)
                    if add_to_mani:
                        low_mani_seed.append(instance)
                    if add_to_navi:
                        low_navi_seed.append(instance)

                # NOTE(review): both counters are incremented for every
                # sub-goal regardless of its type, so 'mani' and 'navi'
                # end up with identical contents. The seed selection above
                # is unaffected because it also tests sg_type.
                statistic['mani'][subgoal] += 1
                statistic['navi'][subgoal] += 1

        print('high len:', len(high_instances))
        print('low mani len:', len(low_instances_mani))
        print('low navi len:', len(low_instances_navi))
        statistic['%s high len'%split] = len(high_instances)
        statistic['%s low-mani len'%split] = len(low_instances_mani)
        statistic['%s low-navi len'%split] = len(low_instances_navi)
        if split == 'train':
            statistic['train low-navi seed len'] = len(low_navi_seed)
            statistic['train low-mani seed len'] = len(low_mani_seed)

        high_save_path = os.path.join(self.pp_path, '%s_high_action_instances.json'%split)
        with open(high_save_path, 'w') as f:
            json.dump(high_instances, f, indent=2)

        # Build the mani/navi paths explicitly: deriving one from the other
        # with str.replace would also rewrite any 'mani' occurring in
        # self.pp_path or in the split name.
        mani_save_path = os.path.join(self.pp_path, '%s_low_action_instances_mani.json'%split)
        navi_save_path = os.path.join(self.pp_path, '%s_low_action_instances_navi.json'%split)
        with open(mani_save_path, 'w') as f:
            json.dump(low_instances_mani, f, indent=2)
        with open(navi_save_path, 'w') as f:
            json.dump(low_instances_navi, f, indent=2)

        if split == 'train':
            mani_seed_path = os.path.join(self.pp_path, '%s_low_action_seed_mani.json'%split)
            navi_seed_path = os.path.join(self.pp_path, '%s_low_action_seed_navi.json'%split)
            with open(mani_seed_path, 'w') as f:
                json.dump(low_mani_seed, f, indent=2)
            with open(navi_seed_path, 'w') as f:
                json.dump(low_navi_seed, f, indent=2)
            with open(os.path.join(self.pp_path, 'mani_subgoals.json'), 'w') as f:
                json.dump(mani_subgoal, f, indent=2)

        det_sep_save_path = os.path.join(self.pp_path, '%s_det_res_sep.json'%split)
        with open(det_sep_save_path, 'w') as f:
            json.dump(det_res_sep, f, indent=2)

    # Sort the counter-like entries by key for a stable, readable dump.
    # This must happen after all splits: it replaces the Counters with
    # plain dicts, which no longer support "+= 1" on unseen keys.
    for k, v in list(statistic.items()):
        if isinstance(v, dict):
            statistic[k] = dict(sorted(v.items(), key=lambda item: item[0]))
    with open(os.path.join(self.pp_path, 'data_statistics.json'), 'w') as f:
        json.dump(statistic, f, indent=2)