def main()

in question_generation/generate_questions.py [0:0]


def _load_templates(template_dir):
  """Load every question template from the ``.json`` files in template_dir.

  Each matching file is parsed as a JSON list of templates; files with any
  other extension are ignored. Files are visited in sorted filename order:
  os.listdir() order is arbitrary, and because templates are later ordered
  by a *stable* sort on usage count, the load order breaks ties. Sorting
  makes the generated questions reproducible across filesystems.

  Returns:
    dict mapping (filename, index of the template within that file) to the
    template object.
  """
  templates = {}
  for fn in sorted(os.listdir(template_dir)):
    if not fn.endswith('.json'):
      continue
    with open(os.path.join(template_dir, fn), 'r') as f:
      for i, template in enumerate(json.load(f)):
        templates[(fn, i)] = template
  return templates


def _make_counts(templates, metadata):
  """Build fresh, all-zero usage counters for templates.

  Returns:
    (template_counts, template_answer_counts) where
      template_counts maps each template key to the number of questions
        generated from that template so far, and
      template_answer_counts maps each template key to a dict from every
        possible answer of the template's final node to the number of
        questions so far of that template with that answer.
    The possible answers come from metadata['types'][dtype] of the final
    node's output type, except that 'Bool' answers are [True, False] and,
    for CLEVR-v1.0, 'Integer' answers are 0..10.
  """
  node_type_to_dtype = {n['name']: n['output'] for n in metadata['functions']}
  template_counts = {}
  template_answer_counts = {}
  for key, template in templates.items():
    template_counts[key] = 0
    final_node_type = template['nodes'][-1]['type']
    final_dtype = node_type_to_dtype[final_node_type]
    answers = metadata['types'][final_dtype]
    if final_dtype == 'Bool':
      answers = [True, False]
    if final_dtype == 'Integer':
      if metadata['dataset'] == 'CLEVR-v1.0':
        answers = list(range(0, 11))
    template_answer_counts[key] = {a: 0 for a in answers}
  return template_counts, template_answer_counts


def _load_scenes(input_scene_file, start_idx, num_scenes):
  """Read input_scene_file and return (selected_scenes, scene_info).

  The selection starts at start_idx; when num_scenes > 0 at most num_scenes
  scenes are kept, otherwise every remaining scene is kept.
  """
  with open(input_scene_file, 'r') as f:
    scene_data = json.load(f)
  scenes = scene_data['scenes']
  scene_info = scene_data['info']
  if num_scenes > 0:
    scenes = scenes[start_idx:start_idx + num_scenes]
  else:
    scenes = scenes[start_idx:]
  return scenes, scene_info


def _rename_side_inputs(questions):
  """Rename "side_inputs" to "value_inputs" in every program function, in place.

  The generation code and templates internally call these "side_inputs",
  while the public CLEVR release calls them "value_inputs" and gives every
  function that field (an empty list when it takes no value inputs).
  Functions without "side_inputs" therefore get an empty "value_inputs",
  unless they already have one: an existing list is left untouched so the
  pass is idempotent (e.g. if a function dict were shared between programs;
  NOTE(review): not confirmed that this happens).
  """
  for q in questions:
    for f in q['program']:
      if 'side_inputs' in f:
        f['value_inputs'] = f.pop('side_inputs')
      else:
        f.setdefault('value_inputs', [])


def main(args):
  """Generate questions for a range of scenes and write them to a JSON file.

  Reads the metadata, the question templates in args.template_dir, the
  scenes in args.input_scene_file and the synonyms file. For each selected
  scene, templates are tried least-used first until
  args.templates_per_image of them have produced at least one question.
  All generated questions are written to args.output_questions_file
  together with the scene file's "info" block.

  Args:
    args: parsed command-line arguments. Attributes read here:
      metadata_file, template_dir, input_scene_file, synonyms_json,
      output_questions_file, scene_start_idx, num_scenes,
      reset_counts_every, templates_per_image, instances_per_template,
      verbose, time_dfs. A non-positive reset_counts_every disables the
      periodic reset of usage counts.

  Raises:
    ValueError: if the metadata's "dataset" field is not "CLEVR-v1.0".
  """
  with open(args.metadata_file, 'r') as f:
    metadata = json.load(f)
  dataset = metadata['dataset']
  if dataset != 'CLEVR-v1.0':
    raise ValueError('Unrecognized dataset "%s"' % dataset)
  metadata['_functions_by_name'] = {
      fn_info['name']: fn_info for fn_info in metadata['functions']}

  # Keys are (template filename, index of the template within that file).
  templates = _load_templates(args.template_dir)
  print('Read %d templates from disk' % len(templates))

  template_counts, template_answer_counts = _make_counts(templates, metadata)

  all_scenes, scene_info = _load_scenes(
      args.input_scene_file, args.scene_start_idx, args.num_scenes)

  with open(args.synonyms_json, 'r') as f:
    synonyms = json.load(f)

  time_dfs = args.time_dfs and args.verbose
  reset_every = args.reset_counts_every
  questions = []
  for i, scene in enumerate(all_scenes):
    scene_fn = scene['image_filename']
    print('starting image %s (%d / %d)'
          % (scene_fn, i + 1, len(all_scenes)))

    # Periodically forget the usage statistics. A non-positive interval
    # would otherwise be a modulo-by-zero crash; treat it as "never reset".
    if reset_every > 0 and i % reset_every == 0:
      print('resetting counts')
      template_counts, template_answer_counts = _make_counts(
          templates, metadata)

    # Per-scene values shared by every question generated for this scene.
    image_name = os.path.splitext(scene_fn)[0]
    image_index = int(image_name.split('_')[-1])

    # Try the least-used templates first: a simple heuristic for a flat
    # distribution over templates. sorted() is stable, so ties keep the
    # (deterministic) load order.
    ordered = sorted(templates.items(),
                     key=lambda item: template_counts[item[0]])
    num_instantiated = 0
    for (fn, idx), template in ordered:
      if args.verbose:
        print('trying template ', fn, idx)
      if time_dfs:
        tic = time.time()
      ts, qs, ans = instantiate_templates_dfs(
          scene,
          template,
          metadata,
          template_answer_counts[(fn, idx)],
          synonyms,
          max_instances=args.instances_per_template,
          verbose=False)
      if time_dfs:
        print('that took ', time.time() - tic)
      for t, q, a in zip(ts, qs, ans):
        questions.append({
            'split': scene_info['split'],
            'image_filename': scene_fn,
            'image_index': image_index,
            'image': image_name,
            'question': t,
            'program': q,
            'answer': a,
            'template_filename': fn,
            'question_family_index': idx,
            # Evaluated before the append, so this is the new entry's index.
            'question_index': len(questions),
        })
      if ts:
        if args.verbose:
          print('got one!')
        num_instantiated += 1
        template_counts[(fn, idx)] += 1
      elif args.verbose:
        print('did not get any =(')
      if num_instantiated >= args.templates_per_image:
        break

  _rename_side_inputs(questions)

  with open(args.output_questions_file, 'w') as f:
    print('Writing output to %s' % args.output_questions_file)
    json.dump({
        'info': scene_info,
        'questions': questions,
    }, f)