def define_split()

in scripts/attr_prep_tag_NP.py [0:0]


def define_split(database, train_cap_file=None, val_cap_file=None, seed=None):
    """Build train/val/test id lists from the caption files and report coverage.

    Training ids are the keys of the training caption JSON. The keys of the
    validation caption JSON are randomly divided into a validation split and
    a test split: each id goes to val when its uniform draw is >= 0.5,
    otherwise to test. Relative order within each split follows the order of
    the keys in the JSON file.

    Args:
        database: mapping whose keys are the ids of annotated videos. Only its
            keys are used, to print how many ids of each split are annotated.
        train_cap_file: path to the training caption JSON (a dict keyed by
            id). Defaults to the module-level ``args.train_cap_file``.
        val_cap_file: path to the validation caption JSON (a dict keyed by
            id). Defaults to the module-level ``args.val_cap_file``.
        seed: optional seed for the val/test split. When ``None`` the global
            NumPy random state is used (original behaviour), so results depend
            on whatever seeding the caller did beforehand.

    Returns:
        ``[train_ids, val_ids, test_ids]``: three lists of ids. These are the
        full splits, not restricted to ids present in ``database``.
    """
    if train_cap_file is None:
        train_cap_file = args.train_cap_file
    if val_cap_file is None:
        val_cap_file = args.val_cap_file

    # list(): in Python 3 dict.keys() is a view that can neither be indexed
    # nor concatenated with "+", both of which the code below relies on.
    with open(train_cap_file) as f:
        train_ids = list(json.load(f).keys())

    with open(val_cap_file) as f:
        valtest_ids = list(json.load(f).keys())

    # Split about half of the validation-caption ids off as the test split.
    rand = np.random.rand if seed is None else np.random.RandomState(seed).rand
    val_mask = rand(len(valtest_ids)) >= 0.5
    val_ids = [vid for vid, is_val in zip(valtest_ids, val_mask) if is_val]
    # `not` instead of `~`: bitwise negation is only a logical "not" for
    # numpy booleans; on a Python bool, ~True == -2 (truthy).
    test_ids = [vid for vid, is_val in zip(valtest_ids, val_mask) if not is_val]

    # Ids of each split that also appear in `database` (used only for the report).
    vid_ids = set(database.keys())
    train_ann_ids = vid_ids.intersection(train_ids)
    val_ann_ids = vid_ids.intersection(val_ids)
    test_ann_ids = vid_ids.intersection(test_ids)

    print('All data - total: {}, train split: {}, val split: {}, test split: {}'.format(len(train_ids+val_ids+test_ids), len(train_ids), len(val_ids), len(test_ids)))
    print('Annotated data - total: {}, train split: {}, val split: {}, and test split: {}'.format(
        len(vid_ids), len(train_ann_ids), len(val_ann_ids), len(test_ann_ids)))

    return [train_ids, val_ids, test_ids]