def prep_all()

in scripts/attr_prep_tag_NP.py [0:0]


def _safe_mean(values):
    """Return the mean of `values`, or NaN (without a NumPy warning) when it is empty."""
    return np.mean(values) if len(values) > 0 else float('nan')


def _load_attr_dict(attr_to_video_file):
    """Parse the attribute file into {(vid_id, seg_id): [[token_index, token], ...]}.

    Each non-blank line is read as `<index>,<token>,<vid_id>_segment_<seg_no>`.
    Only the first and last comma-separated fields are structural; everything
    in between is re-joined, so a token may itself contain commas. The segment
    number is normalised through int() (e.g. '03' -> '3').
    """
    attr_dict = defaultdict(list)
    with open(attr_to_video_file) as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            fields = line.split(',')
            attr_id = fields[0]
            vid_name = fields[-1]
            attr = ','.join(fields[1:-1])
            vid_id, seg_id = vid_name.strip().split('_segment_')
            attr_dict[(vid_id, str(int(seg_id)))].append([int(attr_id), attr])
    return attr_dict


def _group_noun_phrases(sorted_attr, w2l, w2d):
    """Group (index, word) pairs into runs of consecutive indices.

    `sorted_attr` must already be sorted by index. A new run starts whenever
    an index is not exactly one greater than the previous index.

    Returns a 3-tuple:
      np_lst       -- one [lowercased phrase, [indices]] entry per run
      attr_lst     -- for each run containing at least one class word, the
                      label of the LAST such word (usually the head noun)
      attr_ind_lst -- the index of that word, parallel to attr_lst

    A "class word" is a lowercased word that `w2l` maps to a label which is
    itself a key of `w2d`.
    """
    runs = []
    prev_ind = None
    for tok_ind, word in sorted_attr:
        assert(tok_ind >= 0)
        if runs and tok_ind == prev_ind + 1:
            runs[-1].append((tok_ind, word.lower()))
        else:
            runs.append([(tok_ind, word.lower())])
        prev_ind = tok_ind

    np_lst = []
    attr_lst = []
    attr_ind_lst = []
    for run in runs:
        np_lst.append([' '.join(w for _, w in run), [i for i, _ in run]])
        heads = []
        for tok_ind, word in run:
            if w2l.get(word, -1) != -1:
                label = w2l[word]
                if w2d.get(label, -1) != -1:
                    heads.append((label, tok_ind))
        if heads:
            # the last noun is usually the head noun
            attr_lst.append(heads[-1][0])
            attr_ind_lst.append(heads[-1][1])
    return np_lst, attr_lst, attr_ind_lst


def prep_all(database, database_cap, obj_cls_lst, w2l, nlp, attr_to_video_file=None):
    """Merge box annotations with caption tokens, per video segment.

    Args:
        database: {vid_id: {'segments': {seg_id: {'objects': [obj, ...]}}}}.
            Every obj['frame_ind'] must be a one-entry dict {box_id: box},
            where box provides 'xtl', 'ytl', 'xbr', 'ybr', 'crowds' and
            'attributes' (a list of (token_index, word) pairs).
        database_cap: {vid_id: {'sentences': [caption, ...]}}; the position of
            a caption in the list, as a string, is its segment id.
        obj_cls_lst: list of object class labels.
        w2l: mapping from a lowercased word to its label; only labels present
            in `obj_cls_lst` are kept.
        nlp: tokenizer object exposing `word_tokenize`; used only for captions
            whose segment is absent from `database`.
        attr_to_video_file: path of the attribute file. When None (default),
            falls back to the module-level `args.attr_to_video_file`.

    Returns:
        (new_database, new_database_np):
        new_database[vid_id]['segments'][seg_id] holds 'tokens',
        'process_bnd_box', 'frame_ind', 'crowds', 'process_clss' and
        'process_idx'; the box lists are parallel and contain only boxes with
        at least one recognised class word.
        new_database_np[vid_id]['segments'][seg_id] holds 'tokens' and
        'objects', one entry per box with non-empty attributes, where
        'attributes' is replaced by 'noun_phrases'.

    Raises:
        IndexError: if a segment present in `database` has no rows in the
            attribute file.
    """
    if attr_to_video_file is None:
        attr_to_video_file = args.attr_to_video_file

    w2d = {obj: ind for ind, obj in enumerate(obj_cls_lst)}

    avg_box = [] # number of boxes per segment
    avg_attr = [] # number of attributes per box
    crowd_all = [] # all the crowd labels

    attr_dict = _load_attr_dict(attr_to_video_file)
    print('Number of segments with attributes: {}'.format(len(attr_dict)))

    # (vid_id, seg_id) pairs that have box annotations
    annotated_segs = set()
    for vid_id, vid in database.items():
        for seg_id in vid['segments']:
            annotated_segs.add((vid_id, seg_id))

    new_database = {}
    new_database_np = {}
    seg_counter = 0
    for vid_id, cap in database_cap.items():
        new_database_np[vid_id] = {'segments':{}}
        new_seg = {}
        for cap_id, sentence in enumerate(cap['sentences']):
            new_obj_lst = defaultdict(list)
            seg_id = str(cap_id)
            seg_np = {'objects':[]}
            new_database_np[vid_id]['segments'][seg_id] = seg_np
            if (vid_id, seg_id) not in annotated_segs:
                # sentences not in ANet-BB
                # NOTE(review): .encode('utf-8') yields bytes on Python 3 --
                # confirm the tokenizer accepts that before porting.
                new_obj_lst['tokens'] = nlp.word_tokenize(sentence.encode('utf-8'))
            else:
                seg = database[vid_id]['segments'][seg_id]

                # preprocess attributes
                attr_sent = sorted(attr_dict.get((vid_id, seg_id), []), key=lambda x:x[0])
                if not attr_sent:
                    raise IndexError('no attribute tokens found for video {}, segment {}'.format(vid_id, seg_id))
                start_ind = attr_sent[0][0]

                # legacy token issues from our annotation tool
                for tok in attr_sent:
                    if tok[1] == '\\,':
                        tok[1] = ','

                new_obj_lst['tokens'] = [tok[1] for tok in attr_sent] # all the word tokens

                for obj in seg['objects']:
                    assert(len(obj['frame_ind']) == 1)

                    # dict views are not indexable on Python 3
                    box_id, box = next(iter(obj['frame_ind'].items()))

                    np_ann = {'frame_ind': int(box_id)}
                    np_ann.update(box)

                    if len(box['attributes']) == 0:
                        # just in case the attribute is empty, though it should not be
                        continue

                    # the attributes are unordered; make indices relative to the sentence
                    sorted_attr = [(x[0]-start_ind, x[1])
                                   for x in sorted(box['attributes'], key=lambda x:x[0])]
                    np_lst, attr_lst, attr_ind_lst = _group_noun_phrases(sorted_attr, w2l, w2d)
                    assert(len(np_lst) > 0)

                    np_ann['noun_phrases'] = np_lst
                    np_ann.pop('attributes', None)
                    seg_np['objects'].append(np_ann)

                    # exclude empty box (no recognised class word)
                    # crowd boxes are ok for now
                    if len(attr_lst) == 0:
                        continue

                    new_obj_lst['process_bnd_box'].append([box['xtl'], box['ytl'], box['xbr'], box['ybr']])
                    new_obj_lst['frame_ind'].append(int(box_id))
                    new_obj_lst['crowds'].append(box['crowds'])
                    new_obj_lst['process_clss'].append(attr_lst)
                    new_obj_lst['process_idx'].append(attr_ind_lst)
                    avg_attr.append(len(attr_lst))
                    crowd_all.append(box['crowds'])

            avg_box.append(len(new_obj_lst['frame_ind'])) # could be 0
            if len(new_obj_lst['frame_ind']) == 0:
                # materialise every key so downstream code sees a uniform schema
                new_obj_lst['process_bnd_box'] = []
                new_obj_lst['frame_ind'] = [] # all empty
                new_obj_lst['crowds'] = []
                new_obj_lst['process_clss'] = []
                new_obj_lst['process_idx'] = []
            seg_counter += 1
            new_seg[seg_id] = new_obj_lst
            seg_np['tokens'] = new_obj_lst['tokens']

        new_database[vid_id] = {'segments':new_seg}

    # quick stats
    print('Number of videos: {} (including empty ones)'.format(len(new_database)))
    print('Number of segments: {}'.format(seg_counter))
    print('Average number of valid segments per video: {}'.format(_safe_mean([len(vid['segments']) for vid in new_database.values()])))
    print('Average number of box per segment: {} and frequency: {}'.format(_safe_mean(avg_box), Counter(avg_box)))

    print('Average number of attributes per box: {} and frequency: {} (for valid box only)'.format(_safe_mean(avg_attr), Counter(avg_attr)))
    crowd_freq = Counter(crowd_all)
    n_labelled = crowd_freq[1]+crowd_freq[0]
    # guard against ZeroDivisionError when no valid box was collected
    crowd_ratio = crowd_freq[1]*1./n_labelled if n_labelled else float('nan')
    print('Percentage of crowds: {} (for valid box only)'.format(crowd_ratio))

    return new_database, new_database_np