def sample_unique_answer_obj()

in dvd_generation/utils/dialogue_utils.py [0:0]


def sample_unique_answer_obj(turn, state, ans_obj, used_objects, scene_struct):
    """Pick one not-yet-used object from the last turn's answer set.

    The answer set is read from ``turn['program'][-2]['_output']``. One of
    its objects that is not a key of ``used_objects`` is chosen at random,
    together with one randomly chosen identifier for it.

    Args:
        turn: Turn dict; reads ``turn['program']`` and ``turn['template']``
            (``nodes``, and for ``'<A'`` side inputs also ``used_periods``
            and ``interval_type``).
        state: Dict whose ``'vals'`` entry maps side-input names to values.
        ans_obj: Indexable whose element ``1`` is a dict that may contain
            ``'side_inputs'``.
        used_objects: Mapping keyed by the object ids already referred to.
        scene_struct: Scene dict; reads ``'objects'`` and, when the answer
            set has as many entries as the scene has objects,
            ``'_minimal_object_identifiers'``.

    Returns:
        A 4-tuple ``(obj_ids, sampled_obj, obj_attr, sampled_obj_attr)``.
        When nothing can be sampled (exist-type question, empty answer set,
        every answer object already used, or no candidate has a usable
        identifier) the tuple is ``(obj_ids, -1, None, None)``.
        Otherwise ``sampled_obj_attr`` holds the non-None identifier
        attributes of the sampled object and ``obj_attr`` holds the same
        attributes plus the question's side-input values.

    Raises:
        AssertionError: If the last question node is neither an exist nor a
            count type, or if the answer set has exactly one object.
        KeyError: If a side input of ``ans_obj[1]`` is missing from
            ``state['vals']``.
    """
    obj_ids = turn['program'][-2]['_output']
    last_turn_que_type = turn['template']['nodes'][-1]['type']
    no_sample = (obj_ids, -1, None, None)

    if last_turn_que_type == 'exist' or 'filter_exist' in last_turn_que_type:
        return no_sample
    assert 'filter_count' in last_turn_que_type \
        or last_turn_que_type == 'count', \
        'unsupported question type: {!r}'.format(last_turn_que_type)
    assert len(obj_ids) != 1, 'answer set must not contain exactly one object'
    if len(obj_ids) == 0:
        return no_sample

    # An empty difference means every answer object has already been used.
    unfound_objs = list(set(obj_ids) - set(used_objects))
    if not unfound_objs:
        return no_sample
    random.shuffle(unfound_objs)

    # The identifier table depends only on obj_ids and scene_struct, so it is
    # built once here instead of once per candidate.
    # NOTE(review): assumes the precompute_* helpers are pure -- confirm.
    whole_scene = len(obj_ids) == len(scene_struct['objects'])
    if whole_scene:
        all_identifiers = scene_struct['_minimal_object_identifiers']
    else:
        objs = {
            oi: {
                a: scene_struct['objects'][oi][identifier_attr_names[ai]]
                for ai, a in enumerate(identifier_attrs)
            }
            for oi in obj_ids
        }
        attr_maps = precompute_object_filter_options(objs)
        _, all_identifiers = precompute_obj_identifiers(attr_maps, objs)

    for sampled_obj in unfound_objs:
        identifiers = all_identifiers[sampled_obj]
        # The two tables mark "no identifier" differently: None in the
        # scene-level table, [None] in the freshly computed one.
        if whole_scene:
            if identifiers is None:
                continue
        elif identifiers == [None]:
            continue
        sampled_obj_attr = random.choice(identifiers)
        break
    else:
        # No candidate had a usable identifier.
        return no_sample

    # Keep only the attributes the chosen identifier actually specifies.
    obj_attr = {
        a: sampled_obj_attr[idx]
        for idx, a in enumerate(identifier_attrs)
        if sampled_obj_attr[idx] is not None
    }
    sampled_obj_attr = copy.deepcopy(obj_attr)

    if 'side_inputs' not in ans_obj[1]:
        assert last_turn_que_type == 'count'
        return obj_ids, sampled_obj, obj_attr, sampled_obj_attr

    vals = state['vals']
    for a in ans_obj[1]['side_inputs']:
        if a not in vals:
            # Fail loudly instead of dropping into an interactive debugger.
            raise KeyError(
                'side input {!r} is missing from state["vals"]'.format(a))
        if '<A' in a:
            obj_attr[a] = {'val': vals[a],
                           'period': turn['template']['used_periods'][-1],
                           'interval_type': turn['template']['interval_type']}
        elif vals[a] is not None and vals[a] not in ['', 'thing']:
            obj_attr[a] = vals[a]
    return obj_ids, sampled_obj, obj_attr, sampled_obj_attr