def instantiate_templates_dfs()

in dvd_generation/simulators/template_dfs.py [0:0]

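Iteratively runs a depth-first search over the nodes of a question template, expanding each node against the given scene and time period, rejection-sampling candidate questions by their answer counts, and returning the accepted program states together with the corresponding instantiated questions.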

def instantiate_templates_dfs(scene_struct, template, metadata,
                              answer_counts, synonyms, period_idx, turn_dependencies,
                              max_instances=None, turn_pos=-1, verbose=False):

    param_name_to_type = {p['name']: p['type'] for p in template['params']}
    
    curr_period = template['used_periods'][-1]
    cutoff = template['cutoff']

    initial_state = {
        'nodes': [node_shallow_copy(template['nodes'][0])],
        'vals': {},
        'input_map': {0: 0},
        'next_template_node': 1,
    }
    states = [initial_state]
    final_states = []
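    # Iterative DFS over partial programs: `states` is an explicit stack. Each
    # state records the program nodes built so far ('nodes'), the values chosen
    # for template parameters ('vals'), a map from template node index to
    # program node index ('input_map'), and which template node to expand next.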
    while states:
        state = states.pop()

        # Check to make sure the current state is valid
        q = {'nodes': state['nodes']}
        outputs = qeng.answer_question(q, metadata, scene_struct, period_idx, turn_dependencies, all_outputs=True)
        answer = outputs[-1]
        if answer == '__INVALID__': continue

        # Check to make sure constraints are satisfied for the current state
        skip_state = constraint_filters(scene_struct, template, param_name_to_type, state, outputs, 
                                        period_idx, turn_dependencies=turn_dependencies, verbose=False)
        if skip_state: continue

        # We have already checked to make sure the answer is valid, so if we have
        # processed all the nodes in the template then the current state is a valid
        # question, so add it if it passes our rejection sampling tests.
        if state['next_template_node'] == len(template['nodes']):
            skip_by_sampling_check = sample_question_by_answer_counts(answer_counts, answer, verbose)
            if skip_by_sampling_check: continue 
            
            # If the template contains a raw relate node then we need to check for
            # degeneracy at the end
            has_relate = any(n['type'] == 'relate' for n in template['nodes'])
            if has_relate:
                degen = qeng.is_degenerate(q, metadata, scene_struct, answer=answer, period_idx=period_idx, verbose=verbose)
                if degen: continue

            answer_counts[answer] += 1
            state['answer'] = answer

            final_states.append(state)
            if max_instances is not None and len(final_states) == max_instances: break
            continue
            
        # Otherwise fetch the next node from the template
        # Make a shallow copy so cached _outputs don't leak ... this is very nasty
        next_node = template['nodes'][state['next_template_node']]
        next_node = node_shallow_copy(next_node)

        if next_node['type'] in special_nodes:
            
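            # Two object-reference shortcuts: reuse the unique object referenced
            # in the previous turn ('unique_obj_ref'), or one referenced in an
            # earlier turn ('earlier_obj_ref'). The `random.random() > 0` guards
            # are effectively always true, so each branch fires whenever its
            # remaining conditions hold.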
            if (random.random() > 0 and 'unique' in next_node['type']
                    and template['prior_unique_obj'] != -1
                    and template['unique_obj'] is not None
                    and (template['interval_type'] == 'none'
                         or 'none' not in turn_dependencies['temporal']
                         or curr_period == (None, cutoff))
                    and template['ref_remark'] != 'no_reference'):
                input_map = {k: v for k, v in state['input_map'].items()}
                cur_next_vals = {k: v for k, v in state['vals'].items()}
                new_nodes = []
                new_nodes.append({
                    'type': 'unique_obj_ref',
                    'inputs':[input_map[next_node['inputs'][0]] + len(new_nodes)],
                    '_output': template['prior_unique_obj']
                })
                
                input_map[state['next_template_node']] = len(state['nodes']) + len(new_nodes) - 1
                for k,v in template['prior_unique_obj_attr'].items():
                    # only save old action value if same period
                    if '<A' in k:
                        curr_period = template['used_periods'][-1]
                        if v['period'] == curr_period: 
                            cur_next_vals[k] = v['val']
                        continue
                    cur_next_vals[k] = v
                states.append({
                    'nodes':state['nodes'] + new_nodes,
                    'vals': cur_next_vals,
                    'input_map':input_map,
                    'next_template_node':state['next_template_node'] + 1,
                })
                turn_dependencies['object'] = 'last_unique' 
                
                if template['sampled_ans_object']!=-1: 
                    oa = template['sampled_ans_object_attr_ref']
                    oi = template['sampled_ans_object'] 
                    update_unique_object(oi, oa, template['used_objects'], turn_pos)
                
                continue 

            if (random.random() > 0 and 'unique' in next_node['type']
                    and template['earlier_unique_obj'] != -1
                    and template['earlier_unique_obj_node'] == next_node):
                input_map = {k: v for k, v in state['input_map'].items()}
                cur_next_vals = {k: v for k, v in state['vals'].items()}
                new_nodes = []
                new_nodes.append({
                    'type': 'earlier_obj_ref',
                    'inputs':[input_map[next_node['inputs'][0]] + len(new_nodes)],
                    '_output': template['earlier_unique_obj']
                })
                input_map[state['next_template_node']] = len(state['nodes']) + len(new_nodes) - 1     
                attr_map = template['earlier_unique_obj_attr_map']
                for k,v in template['earlier_unique_obj_attr'].items(): 
                    cur_next_vals[attr_map[k]] = v
                turn_dependencies['object'] = 'earlier_unique'
                states.append({
                    'nodes':state['nodes'] + new_nodes,
                    'vals': cur_next_vals,
                    'input_map':input_map,
                    'next_template_node':state['next_template_node'] + 1,
                })
                continue 
                
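            # General expansion of a special node: enumerate candidate filter
            # value combinations for this scene/period, shuffle them, and push
            # one new DFS state per combination.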
            filter_options = get_filter_options(metadata, scene_struct, template, answer, next_node, period_idx)
            filter_option_keys = list(filter_options.keys())
            random.shuffle(filter_option_keys)
            
            for k in filter_option_keys:
                new_nodes = []
                cur_next_vals = {k: v for k, v in state['vals'].items()}
                next_input = state['input_map'][next_node['inputs'][0]]
                filter_side_inputs = next_node['side_inputs']
                if next_node['type'].startswith('relate'):
                    param_name = next_node['side_inputs'][0]  # The first side input is the relation parameter
                    filter_side_inputs = next_node['side_inputs'][1:]
                    param_type = param_name_to_type[param_name]
                    assert param_type == 'Relation'
                    param_val = k[0]
                    k = k[1]
                    new_nodes.append({
                        'type': 'relate',
                        'inputs': [next_input],
                        'side_inputs': [param_val],
                    })
                    cur_next_vals[param_name] = param_val
                    next_input = len(state['nodes']) + len(new_nodes) - 1
                for param_name, param_val in zip(filter_side_inputs, k):
                    param_type = param_name_to_type[param_name]
                    if param_type == 'Action' and 'actions' in next_node['type']: 
                        filter_type = 'filter_actions'
                    else:
                        filter_type = 'filter_%s' % param_type.lower()
                    if param_val is not None:
                        new_nodes.append({
                            'type': filter_type,
                            'inputs': [next_input],
                            'side_inputs': [param_val],
                        })
                        cur_next_vals[param_name] = param_val
                        next_input = len(state['nodes']) + len(new_nodes) - 1
                    elif param_val is None:
                        if metadata['dataset'] == 'CLEVR-v1.0' and param_type == 'Shape':
                            param_val = 'thing'
                        else:
                            param_val = ''
                        cur_next_vals[param_name] = param_val
                input_map = {k: v for k, v in state['input_map'].items()}
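                # If the special node ends in unique/count/exist, append the
                # matching terminal node after the filter chain.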
                extra_type = None
                if next_node['type'].endswith('unique'):
                    extra_type = 'unique'
                if next_node['type'].endswith('count'):
                    extra_type = 'count'
                if next_node['type'].endswith('exist'):
                    extra_type = 'exist'
                if extra_type is not None:
                    new_nodes.append({
                        'type':extra_type,
                        'inputs':[input_map[next_node['inputs'][0]] + len(new_nodes)],
                    })
                input_map[state['next_template_node']] = len(state['nodes']) + len(new_nodes) - 1
                states.append({
                    'nodes':state['nodes'] + new_nodes,
                    'vals':cur_next_vals,
                    'input_map':input_map,
                    'next_template_node':state['next_template_node'] + 1,
                })
        # Non-special nodes that carry template parameters, e.g. a bare 'relate' node
        elif 'side_inputs' in next_node:
            # If the next node has template parameters, expand them out
            # TODO: Generalize this to work for nodes with more than one side input
            assert len(next_node['side_inputs']) == 1, 'NOT IMPLEMENTED: {}'.format(next_node)

            # Use metadata to figure out domain of valid values for this parameter.
            # Iterate over the values in a random order; then it is safe to bail
            # from the DFS as soon as we find the desired number of valid template
            # instantiations.
            param_name = next_node['side_inputs'][0]
            if param_name not in param_name_to_type:
                # Debugging aid: a side input missing from the template's param
                # list means the template definition is malformed
                print(template)
                pdb.set_trace()
            param_type = param_name_to_type[param_name]
            param_vals = metadata['types'][param_type][:]
            random.shuffle(param_vals)
            for val in param_vals:
                input_map = {k: v for k, v in state['input_map'].items()}
                input_map[state['next_template_node']] = len(state['nodes'])
                cur_next_node = {
                    'type': next_node['type'],
                    'inputs': [input_map[idx] for idx in next_node['inputs']],
                    'side_inputs': [val],
                }
                cur_next_vals = {k: v for k, v in state['vals'].items()}
                cur_next_vals[param_name] = val

                states.append({
                    'nodes':state['nodes'] + [cur_next_node],
                    'vals':cur_next_vals,
                    'input_map':input_map,
                    'next_template_node':state['next_template_node'] + 1,
                })
        # remaining node types: 
        # scene, unique, union, intersect, count, exist, 
        # query_shape, query_color, query_size, query_material
        # equal_color, equal_shape, equal_size, equal_material, 
        # same_size, same_color, same_material, same_shape, 
        # equal_object, equal_integer, less_than, greater_than
        else:
            input_map = {k: v for k, v in state['input_map'].items()}
            input_map[state['next_template_node']] = len(state['nodes'])
            next_node = {
                'type': next_node['type'],
                'inputs': [input_map[idx] for idx in next_node['inputs']],
            }
            states.append({
                'nodes':state['nodes'] + [next_node],
                'vals':state['vals'],
                'input_map':input_map,
                'next_template_node':state['next_template_node'] + 1,
            })

    return final_states, instantiate_questions(final_states, template, synonyms, scene_struct, period_idx, 
                                               turn_dependencies)
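
A minimal calling sketch, under assumptions: scene_struct, template, metadata, and synonyms are assumed to come from the repo's own loading/preprocessing code (not shown here), and the turn_dependencies keys are inferred from how this function reads and writes them; treat the snippet as illustrative rather than the exact driver used in the repo.

from collections import defaultdict

# Assumed to be produced elsewhere by the repo's data-preparation code:
#   scene_struct, template, metadata, synonyms
answer_counts = defaultdict(int)            # running tally used for rejection sampling
turn_dependencies = {'temporal': 'none',    # inferred keys: the function reads
                     'object': 'none'}      # 'temporal' and writes 'object'

final_states, questions = instantiate_templates_dfs(
    scene_struct, template, metadata,
    answer_counts, synonyms,
    period_idx=0,                  # which time period of the video to ground against
    turn_dependencies=turn_dependencies,
    max_instances=1,               # stop the DFS after one valid instantiation
    turn_pos=0,
    verbose=False,
)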