# question_generation/generate_questions.py
def instantiate_templates_dfs(scene_struct, template, metadata, answer_counts,
synonyms, max_instances=None, verbose=False):
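# Overview (descriptive note, inferred from the code below): this is a
# depth-first search over the nodes of a question template. Each state on the
# stack holds a partially instantiated program, which is extended one template
# node at a time, checked against the template's constraints, and kept only if
# it still yields a valid answer. Completed programs are rejection-sampled by
# answer to keep the answer distribution balanced, then rendered as text.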
param_name_to_type = {p['name']: p['type'] for p in template['params']}
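# Search state layout (as used throughout this function):
#   'nodes'              - program nodes instantiated so far
#   'vals'               - template parameter name -> chosen value
#   'input_map'          - template node index -> index into 'nodes'
#   'next_template_node' - index of the next template node to expand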
initial_state = {
'nodes': [node_shallow_copy(template['nodes'][0])],
'vals': {},
'input_map': {0: 0},
'next_template_node': 1,
}
states = [initial_state]
final_states = []
while states:
state = states.pop()
# Check to make sure the current state is valid
q = {'nodes': state['nodes']}
outputs = qeng.answer_question(q, metadata, scene_struct, all_outputs=True)
answer = outputs[-1]
if answer == '__INVALID__': continue
# Check to make sure constraints are satisfied for the current state
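# Constraint types handled in this loop:
#   NEQ     - looks at the values bound to two parameters; as written, the
#             state is pruned when both are bound and the values differ.
#   NULL    - the named parameter must stay at its null value ('thing' for
#             Shape parameters, '' for everything else).
#   OUT_NEQ - the outputs of the two referenced template nodes must differ.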
skip_state = False
for constraint in template['constraints']:
if constraint['type'] == 'NEQ':
p1, p2 = constraint['params']
v1, v2 = state['vals'].get(p1), state['vals'].get(p2)
if v1 is not None and v2 is not None and v1 != v2:
if verbose:
print('skipping due to NEQ constraint')
print(constraint)
print(state['vals'])
skip_state = True
break
elif constraint['type'] == 'NULL':
p = constraint['params'][0]
p_type = param_name_to_type[p]
v = state['vals'].get(p)
if v is not None:
skip = False
if p_type == 'Shape' and v != 'thing': skip = True
if p_type != 'Shape' and v != '': skip = True
if skip:
if verbose:
print('skipping due to NULL constraint')
print(constraint)
print(state['vals'])
skip_state = True
break
elif constraint['type'] == 'OUT_NEQ':
i, j = constraint['params']
i = state['input_map'].get(i, None)
j = state['input_map'].get(j, None)
if i is not None and j is not None and outputs[i] == outputs[j]:
if verbose:
print('skipping due to OUT_NEQ constraint')
print(outputs[i])
print(outputs[j])
skip_state = True
break
else:
assert False, 'Unrecognized constraint type "%s"' % constraint['type']
if skip_state:
continue
# The answer is already known to be valid, so if every template node has been
# processed then the current state is a complete question; keep it if it also
# passes the rejection sampling tests below.
if state['next_template_node'] == len(template['nodes']):
# Use our rejection sampling heuristics to decide whether we should
# keep this template instantiation
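# The two tests below keep the answer distribution roughly flat: an
# instantiation is dropped if its answer already occurs more than 1.1x as
# often as the second-most-frequent answer, or more than 5x the median count
# (with the median floored at 5). Hypothetical example: with counts
# {2: 70, 3: 12, 4: 11}, a new question whose answer is 2 fails both tests
# (70 > 1.1 * 12 and 70 > 5 * 12) and is skipped.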
cur_answer_count = answer_counts[answer]
answer_counts_sorted = sorted(answer_counts.values())
median_count = answer_counts_sorted[len(answer_counts_sorted) // 2]
median_count = max(median_count, 5)
if cur_answer_count > 1.1 * answer_counts_sorted[-2]:
if verbose: print('skipping due to second count')
continue
if cur_answer_count > 5.0 * median_count:
if verbose: print('skipping due to median')
continue
# If the template contains a raw relate node then we need to check for
# degeneracy at the end
has_relate = any(n['type'] == 'relate' for n in template['nodes'])
if has_relate:
degen = qeng.is_degenerate(q, metadata, scene_struct, answer=answer,
verbose=verbose)
if degen:
continue
answer_counts[answer] += 1
state['answer'] = answer
final_states.append(state)
if max_instances is not None and len(final_states) == max_instances:
break
continue
# Otherwise fetch the next node from the template
# Make a shallow copy so cached _outputs from earlier evaluations don't leak between states.
next_node = template['nodes'][state['next_template_node']]
next_node = node_shallow_copy(next_node)
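# The node types listed in special_nodes below are "macro" nodes: instead of
# being copied directly, each is expanded into a chain of primitive nodes (an
# optional relate node, one filter_<attribute> node per bound parameter, and
# an optional unique/count/exist node), with the concrete attribute values
# drawn from filter options found in the current scene.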
special_nodes = {
'filter_unique', 'filter_count', 'filter_exist', 'filter',
'relate_filter', 'relate_filter_unique', 'relate_filter_count',
'relate_filter_exist',
}
if next_node['type'] in special_nodes:
if next_node['type'].startswith('relate_filter'):
unique = (next_node['type'] == 'relate_filter_unique')
include_zero = (next_node['type'] == 'relate_filter_count'
or next_node['type'] == 'relate_filter_exist')
filter_options = find_relate_filter_options(answer, scene_struct, metadata,
unique=unique, include_zero=include_zero)
else:
filter_options = find_filter_options(answer, scene_struct, metadata)
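# Both helpers are assumed to return a dict keyed by attribute-value tuples
# (None marks an unconstrained slot; for the relate variant, keys are
# (relation, attribute-tuple) pairs), mapping each key to the scene objects
# that combination selects.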
if next_node['type'] == 'filter':
# Remove null filter
filter_options.pop((None, None, None, None), None)
if next_node['type'] == 'filter_unique':
# Get rid of all filter options that don't result in a single object
filter_options = {k: v for k, v in filter_options.items()
if len(v) == 1}
else:
# Add some filter options that do NOT correspond to the scene
if next_node['type'] == 'filter_exist':
# For filter_exist we want an equal number of options that do and do not match the scene
num_to_add = len(filter_options)
elif next_node['type'] == 'filter_count' or next_node['type'] == 'filter':
# For filter_count and filter, add as many null options as there are options matching exactly one object
num_to_add = sum(1 for k, v in filter_options.items() if len(v) == 1)
add_empty_filter_options(filter_options, metadata, num_to_add)
filter_option_keys = list(filter_options.keys())
random.shuffle(filter_option_keys)
for k in filter_option_keys:
new_nodes = []
cur_next_vals = {k: v for k, v in state['vals'].items()}
next_input = state['input_map'][next_node['inputs'][0]]
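# next_input indexes into the concatenation state['nodes'] + new_nodes: it
# starts at the already-instantiated input of this template node and is
# advanced to each newly appended node so that the relate/filter chain feeds
# forward.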
filter_side_inputs = next_node['side_inputs']
if next_node['type'].startswith('relate'):
param_name = next_node['side_inputs'][0] # The first side input is the relation parameter
filter_side_inputs = next_node['side_inputs'][1:]
param_type = param_name_to_type[param_name]
assert param_type == 'Relation'
param_val = k[0]
k = k[1]
new_nodes.append({
'type': 'relate',
'inputs': [next_input],
'side_inputs': [param_val],
})
cur_next_vals[param_name] = param_val
next_input = len(state['nodes']) + len(new_nodes) - 1
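# Hypothetical example for the loop below, assuming the usual CLEVR parameter
# naming where <Z>, <C>, <M>, <S> are Size, Color, Material, Shape: a key of
# ('large', None, 'metal', None) appends a filter_size node and a
# filter_material node, binds <Z>='large' and <M>='metal', and leaves <C> and
# <S> at their null values ('' in general, 'thing' for Shape under CLEVR-v1.0).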
for param_name, param_val in zip(filter_side_inputs, k):
param_type = param_name_to_type[param_name]
filter_type = 'filter_%s' % param_type.lower()
if param_val is not None:
new_nodes.append({
'type': filter_type,
'inputs': [next_input],
'side_inputs': [param_val],
})
cur_next_vals[param_name] = param_val
next_input = len(state['nodes']) + len(new_nodes) - 1
elif param_val is None:
if metadata['dataset'] == 'CLEVR-v1.0' and param_type == 'Shape':
param_val = 'thing'
else:
param_val = ''
cur_next_vals[param_name] = param_val
input_map = {k: v for k, v in state['input_map'].items()}
extra_type = None
if next_node['type'].endswith('unique'):
extra_type = 'unique'
if next_node['type'].endswith('count'):
extra_type = 'count'
if next_node['type'].endswith('exist'):
extra_type = 'exist'
if extra_type is not None:
new_nodes.append({
'type': extra_type,
'inputs': [input_map[next_node['inputs'][0]] + len(new_nodes)],
})
input_map[state['next_template_node']] = len(state['nodes']) + len(new_nodes) - 1
states.append({
'nodes': state['nodes'] + new_nodes,
'vals': cur_next_vals,
'input_map': input_map,
'next_template_node': state['next_template_node'] + 1,
})
elif 'side_inputs' in next_node:
# If the next node has template parameters, expand them out
# TODO: Generalize this to work for nodes with more than one side input
assert len(next_node['side_inputs']) == 1, 'NOT IMPLEMENTED'
# Use metadata to figure out domain of valid values for this parameter.
# Iterate over the values in a random order; then it is safe to bail
# from the DFS as soon as we find the desired number of valid template
# instantiations.
param_name = next_node['side_inputs'][0]
param_type = param_name_to_type[param_name]
param_vals = metadata['types'][param_type][:]
random.shuffle(param_vals)
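# Each candidate value spawns one new DFS state. For example, a parameter of
# type 'Relation' is expanded once per entry of metadata['types']['Relation']
# (for CLEVR these are 'left', 'right', 'front', 'behind'), in shuffled order.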
for val in param_vals:
input_map = {k: v for k, v in state['input_map'].items()}
input_map[state['next_template_node']] = len(state['nodes'])
cur_next_node = {
'type': next_node['type'],
'inputs': [input_map[idx] for idx in next_node['inputs']],
'side_inputs': [val],
}
cur_next_vals = {k: v for k, v in state['vals'].items()}
cur_next_vals[param_name] = val
states.append({
'nodes': state['nodes'] + [cur_next_node],
'vals': cur_next_vals,
'input_map': input_map,
'next_template_node': state['next_template_node'] + 1,
})
else:
input_map = {k: v for k, v in state['input_map'].items()}
input_map[state['next_template_node']] = len(state['nodes'])
next_node = {
'type': next_node['type'],
'inputs': [input_map[idx] for idx in next_node['inputs']],
}
states.append({
'nodes': state['nodes'] + [next_node],
'vals': state['vals'],
'input_map': input_map,
'next_template_node': state['next_template_node'] + 1,
})
# Actually instantiate the template with the solutions we've found
text_questions, structured_questions, answers = [], [], []
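# For each solved state, one surface form is sampled from template['text'] and
# every parameter name in it is replaced by its bound value (or a random
# synonym when one exists); replace_optionals then resolves optional text
# spans and other_heuristic applies final wording tweaks. Hypothetical
# example: "How many <Z> <C> <M> <S>s are there?" with vals
# {'<Z>': 'large', '<C>': '', '<M>': 'metal', '<S>': 'thing'} could become
# "How many large metal things are there?" after whitespace cleanup.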
for state in final_states:
structured_questions.append(state['nodes'])
answers.append(state['answer'])
text = random.choice(template['text'])
for name, val in state['vals'].items():
if val in synonyms:
val = random.choice(synonyms[val])
text = text.replace(name, val)
text = ' '.join(text.split())
text = replace_optionals(text)
text = ' '.join(text.split())
text = other_heuristic(text, state['vals'])
text_questions.append(text)
return text_questions, structured_questions, answers