in util/convert_nmn_layouts.py [0:0]
def _match_attentions(layout, line):
    """Pair each module of ``layout`` with a token span parsed from ``line``.

    ``line`` is a raw parser line that still carries ``<name>[start,end]``
    annotations. The two word characters immediately before each ``[`` act
    as the span's tag (presumably ``'nd'`` is the tail of ``find`` and
    ``'te'`` the tail of ``relate`` -- confirm against the parser output).
    Spans are consumed in the order they appear in ``line``:

      * ``_Find`` takes the next span tagged ``'nd'``;
      * ``_Transform`` takes the next span tagged ``'te'``;
      * ``_Describe`` takes the next span whose tag is neither of those;
      * any other module (e.g. ``_And``), or a module with no span left,
        gets ``(0, 0)``.

    Args:
        layout: sequence of module names such as ``'_Find'``.
        line: the unmodified parser line the layout was built from.

    Returns:
        list[tuple[int, int]]: one ``(start, end)`` pair per module.
    """
    # Each entry is (tag, start, end) as strings, in order of appearance.
    remaining = re.findall(r'(\w\w)\[(\d*),(\d*)\]', line)
    att = []
    for token in layout:
        if token == '_Find':
            candidates = [m for m in remaining if m[0] == 'nd']
        elif token == '_Transform':
            candidates = [m for m in remaining if m[0] == 'te']
        elif token == '_Describe':
            # Only spans that are neither find nor transform spans.
            candidates = [m for m in remaining if m[0] not in ('te', 'nd')]
        else:
            candidates = []
        if candidates:
            chosen = candidates[0]
            att.append((int(chosen[1]), int(chosen[2])))
            # Consume the span so a later module of the same kind gets the next one.
            remaining.remove(chosen)
        else:
            att.append((0, 0))
    return att


def extract_set(params):
    """Convert NMN parser output into module layouts and attentions and save them.

    Reads one parse per line from ``params.nmn_file`` and flattens each parse
    into a module layout, applying manual corrections to known-bad layouts.
    Each module is matched to a token span (see ``_match_attentions``), and
    every distinct layout is validated with the program assembler.

    Outputs (written with ``np.save``):
      * layouts to ``params.save_path`` -- as a dict keyed by
        ``image_id * 10 + round`` when ``params.question`` is truthy,
        otherwise as a per-line object array;
      * attentions to ``params.save_path`` with ``'.layout'`` replaced by
        ``'.attention'``.

    Args:
        params: namespace providing ``prog_vocab_file``, ``nmn_file``,
            ``visdial_file`` (read only when ``question`` is truthy),
            ``question`` and ``save_path``.

    Returns:
        list[int]: the length of each layout, in input line order.

    Raises:
        ValueError: if ``save_path`` does not contain ``'.layout'`` (the
            attention file would overwrite the layout file), or if any
            distinct layout fails the assembler sanity check.
    """
    # Check this before doing any work. Identical paths would make np.save
    # overwrite the layouts with the attentions.
    save_file_att = params.save_path.replace('.layout', '.attention')
    if save_file_att == params.save_path:
        raise ValueError(
            "save_path must contain '.layout' so the attention file gets a "
            'distinct name: %r' % params.save_path)

    # assembler used to reject invalid programs
    assembler = Assembler(params.prog_vocab_file)
    # Manual corrections: flattened parser layout -> fixed module list.
    layout_correct = {
        ('_Find', '_Transform', '_And', '_Describe'):
            ['_Find', '_Transform', '_Describe'],
        ('_Transform', '_Describe'):
            ['_Find', '_Transform', '_Describe'],
        ('_Transform', '_Transform', '_And', '_Describe'):
            ['_Find', '_Transform', '_Transform', '_Describe'],
        ('_Describe',):
            ['_Find', '_Describe'],
        ('_Transform', '_Find', '_And', '_Describe'):
            ['_Find', '_Transform', '_Describe'],
    }

    with open(params.nmn_file) as f:
        raw_lines = f.readlines()

    layouts = []
    for line in raw_lines:
        # Drop the "[start,end]" spans before parsing the tree.
        stripped = re.sub(r'\[\d*,\d*\]', '', line)
        flat = tuple(flatten_layout(parse_tree(stripped)))
        # Store lists so corrected and uncorrected layouts share one type.
        layouts.append(list(layout_correct.get(flat, flat)))

    # Spans come from the original lines, which still carry them.
    attentions = [_match_attentions(layout, line)
                  for layout, line in zip(layouts, raw_lines)]

    layout_set = {tuple(layout) for layout in layouts}
    print('Found %d unique layouts' % len(layout_set))
    # Sorted so the report and the first validation error are deterministic.
    for layout in sorted(layout_set):
        print(' ', ' '.join(layout))
    # check whether every distinct layout is a valid program
    for layout in sorted(layout_set):
        batch = assembler.module_list2tokens(layout, T=20)
        validity, error = assembler.sanity_check_program(batch)
        if not validity:
            raise ValueError('invalid expr:' + str(layout) + ' ' + str(error))

    if params.question:
        # Only the question-keyed output needs the VisDial dialogs.
        with open(params.visdial_file, 'r') as file_id:
            vd_data = json.load(file_id)
        qid2layout_dict = {}
        for datum in progressbar(vd_data['data']['dialogs']):
            img_id = datum['image_id']
            for r_id, round_datum in enumerate(datum['dialog']):
                # Assumes at most 10 rounds per dialog so ids stay unique -- TODO confirm.
                q_id = img_id * 10 + r_id
                # round_datum['question'] indexes nmn_file lines; assumes they are
                # aligned with the dataset's question list -- verify.
                qid2layout_dict[q_id] = layouts[round_datum['question']]
        np.save(params.save_path, np.array(qid2layout_dict))
    else:
        # dtype=object: layouts differ in length, and NumPy >= 1.24 refuses
        # to build ragged arrays implicitly.
        np.save(params.save_path, np.array(layouts, dtype=object))
    print('Saving to: ' + params.save_path)
    print('Saving (att) to: ' + save_file_att)
    np.save(save_file_att, np.array(attentions, dtype=object))
    return [len(layout) for layout in layouts]