# Function: precompute_actions_filter_options()
#
# Defined in cater_preprocessing/utils.py [0:0]


def precompute_actions_filter_options(scene_struct, period_idx):
    """Index the scene's objects by every masked action/attribute key.

    For each relation ('during' and 'excluding') a dict is built that maps a
    5-tuple ``(action, size, color, material, shape)`` to the set of object
    indices matching it. Every one of the 2**5 attribute subsets is
    generated per object: positions left out of a subset hold ``None``.
    The dict is then passed through ``remove_snitch_filter_options`` and
    ``remove_invalid_action_options`` and stored at
    ``scene_struct['_actions_filter_options'][period_idx][relation]``.

    Per-object action handling:
      * ``actions is None``: a single key whose action slot is ``None``, so
        the object is only reachable through keys without an action.
      * otherwise: one key per listed action, plus a 'moving' key when at
        least one listed action is not 'contained'. An empty action
        collection yields no keys, so the object is not registered at all.

    Args:
        scene_struct: Scene dict. Reads ``scene_struct['objects']`` (each
            object is indexed by 'size', 'color', 'material' and 'shape')
            and ``scene_struct['_action_list'][period_idx][relation]
            [object_idx]``. ``scene_struct['_actions_filter_options']``
            must already exist; its ``period_idx`` entry is overwritten.
        period_idx: Key of the period to process.

    Returns:
        None. The result is written into ``scene_struct``.
    """
    attr_keys = ['action', 'size', 'color', 'material', 'shape']
    num_attrs = len(attr_keys)

    # Precompute masks: bit j of i tells whether attribute j is kept (1) or
    # replaced by None (0). Order matters for dict insertion order below.
    masks = [
        [(i >> j) & 1 for j in range(num_attrs)]
        for i in range(1 << num_attrs)
    ]

    scene_struct['_actions_filter_options'][period_idx] = {}
    for relation in ['during', 'excluding']:
        attribute_map = {}
        for object_idx, obj in enumerate(scene_struct['objects']):
            # Attribute values of the object, in attr_keys order minus 'action'.
            static_attrs = [obj[k] for k in attr_keys[1:]]
            actions = scene_struct['_action_list'][period_idx][relation][object_idx]

            if actions is None:
                keys = [tuple([None] + static_attrs)]
            else:
                keys = [tuple([a] + static_attrs) for a in actions]
                # Any action other than 'contained' also registers the
                # object under the extra 'moving' action.
                if any(a != 'contained' for a in actions):
                    keys.append(tuple(['moving'] + static_attrs))

            for mask in masks:
                for key in keys:
                    masked_key = tuple(
                        value if keep == 1 else None
                        for value, keep in zip(key, mask)
                    )
                    attribute_map.setdefault(masked_key, set()).add(object_idx)

        attribute_map = remove_snitch_filter_options(attribute_map, with_action=True)
        attribute_map = remove_invalid_action_options(attribute_map)
        scene_struct['_actions_filter_options'][period_idx][relation] = attribute_map