def extract_set()

in util/convert_nmn_layouts.py [0:0]


def extract_set(params):
  """Convert an NMN layout file into program layouts and span attentions.

  Reads one layout per line from `params.nmn_file`, strips the `[start,end]`
  span annotations to obtain the module layout, applies a table of manual
  layout corrections, validates every unique layout with the assembler and
  saves both the layouts and the per-module spans with `np.save`.

  Args:
    params: object exposing the attributes read here:
      prog_vocab_file -- passed to `Assembler`.
      nmn_file -- text file with one layout (with spans) per line.
      visdial_file -- JSON file; `['data']['dialogs']` is iterated when
        `question` is set.
      question -- if truthy, save a {question id: layout} dictionary instead
        of the plain list of layouts.
      save_path -- destination of the layouts; the attention path is derived
        from it by replacing '.layout' with '.attention'.

  Returns:
    List with the length of every (corrected) layout, in file order.

  Raises:
    Exception: if the assembler rejects one of the unique layouts.
  """
  # assembler to look for incorrect programs
  assembler = Assembler(params.prog_vocab_file)

  # manual correction to layouts
  layout_correct = {('_Find', '_Transform', '_And', '_Describe')
                      :['_Find', '_Transform', '_Describe'],
                    ('_Transform', '_Describe')
                      :['_Find', '_Transform', '_Describe'],
                    ('_Transform', '_Transform', '_And', '_Describe')
                      :['_Find', '_Transform', '_Transform', '_Describe'],
                    ('_Describe',)
                      :['_Find', '_Describe'],
                    ('_Transform', '_Find', '_And', '_Describe')
                      :['_Find', '_Transform', '_Describe']}

  # compiled once, both patterns are applied to every line of the file
  span_pattern = re.compile(r'\[\d*,\d*\]')
  module_span_pattern = re.compile(r'(\w\w)\[(\d*),(\d*)\]')

  # a single pass over the file serves both layouts and spans
  with open(params.nmn_file) as f:
    lines = f.readlines()

  # drop the spans, parse, and correct the layouts according to the above
  # dictionary (none of its values is itself a key, so one pass is enough)
  layouts = []
  for line in lines:
    layout = tuple(flatten_layout(parse_tree(span_pattern.sub('', line))))
    layouts.append(layout_correct.get(layout, layout))

  # extracting spans as well
  attentions = []
  for line, layout in zip(lines, layouts):
    # each match is (two characters preceding '[', span start, span end);
    # 'nd' / 'te' are treated as the find / transform modules below
    # NOTE(review): presumably the trailing letters of the module names used
    # in the NMN file -- confirm against the file format
    matches = module_span_pattern.findall(line)

    # match module with attention, if present
    att = []
    for token in layout:
      if token == '_Find':
        candidates = [jj for jj in matches if jj[0] == 'nd']
      elif token == '_Transform':
        candidates = [jj for jj in matches if jj[0] == 'te']
      elif token == '_Describe':
        # any span that belongs to neither a find nor a transform module
        # (the former `!= 'te' or != 'nd'` test was always true)
        candidates = [jj for jj in matches if jj[0] not in ('te', 'nd')]
      else:
        candidates = []

      if candidates:
        chosen = candidates[0]
        att.append((int(chosen[1]), int(chosen[2])))
        # consume the span so that it is not assigned to a later module
        matches.remove(chosen)
      else:
        # no span available for this module
        att.append((0, 0))

    # record attentions and layouts
    attentions.append(att)

  layout_set = {tuple(l) for l in layouts}

  print('Found %d unique layouts' % len(layout_set))
  for l in layout_set:
    print(' ', ' '.join(list(l)))

  # check whether the layout is valid
  for l in layout_set:
    batch = assembler.module_list2tokens(l, T=20)
    validity, error = assembler.sanity_check_program(batch)

    if not validity:
      raise Exception('invalid expr:' + str(l) + ' ' + error)

  # read the original data path
  with open(params.visdial_file, 'r') as file_id:
    vd_data = json.load(file_id)

  # NOTE(review): layouts / attentions generally have different lengths per
  # entry; recent NumPy releases refuse to build an array from such ragged
  # input unless dtype=object is given -- confirm the targeted NumPy version
  # question id to layout dictionary
  if params.question:
    qid2layout_dict = {}
    for datum in progressbar(vd_data['data']['dialogs']):
      img_id = datum['image_id']
      for r_id, round_datum in enumerate(datum['dialog']):
        q_id = img_id * 10 + r_id
        # round_datum['question'] is used as an index into the layouts
        q_layout = layouts[round_datum['question']]
        # record
        qid2layout_dict[q_id] = q_layout

    np.save(params.save_path, np.array(qid2layout_dict))
  else:
    np.save(params.save_path, np.array(layouts))
  print('Saving to: ' + params.save_path)

  # NOTE(review): if save_path does not contain '.layout' the replace is a
  # no-op and the attentions are written to the same path as the layouts
  save_file_att = params.save_path.replace('.layout', '.attention')
  print('Saving (att) to: ' + save_file_att)
  np.save(save_file_att, np.array(attentions))

  set_layout_length = [len(l) for l in layouts]
  return set_layout_length