in scripts/attr_prep_tag_NP.py [0:0]
def define_split(database, train_cap_file=None, val_cap_file=None, seed=None):
    """Build train/val/test id splits and report how many are annotated.

    Train ids are the keys of the training caption JSON. The keys of the
    validation caption JSON are split at random, roughly in half, into a
    val split and a test split.

    Args:
        database: mapping whose keys are compared against the split ids to
            count annotated entries (presumably video ids -> annotations;
            confirm against the caller).
        train_cap_file: path to the training caption JSON. Defaults to the
            module-level ``args.train_cap_file``.
        val_cap_file: path to the validation caption JSON. Defaults to the
            module-level ``args.val_cap_file``.
        seed: optional seed for a reproducible val/test split. When None,
            the global ``np.random`` state is used (original behaviour).

    Returns:
        ``[train_ids, val_ids, test_ids]``, each a list of ids taken from
        the caption files. These are the full lists, not only the ids
        present in ``database``.
    """
    if train_cap_file is None:
        train_cap_file = args.train_cap_file
    if val_cap_file is None:
        val_cap_file = args.val_cap_file

    # list(...) is required on Python 3: dict.keys() returns a view that
    # can neither be indexed nor concatenated with "+".
    with open(train_cap_file) as f:
        train_ids = list(json.load(f))
    with open(val_cap_file) as f:
        valtest_ids = list(json.load(f))

    rand = np.random.rand if seed is None else np.random.RandomState(seed).rand
    # Each id independently goes to val with probability 0.5; the rest go to test.
    val_mask = rand(len(valtest_ids)) >= 0.5
    val_ids = [vid for vid, is_val in zip(valtest_ids, val_mask) if is_val]
    test_ids = [vid for vid, is_val in zip(valtest_ids, val_mask) if not is_val]

    vid_ids = set(database.keys())
    train_ann_ids = vid_ids.intersection(train_ids)
    val_ann_ids = vid_ids.intersection(val_ids)
    test_ann_ids = vid_ids.intersection(test_ids)

    print('All data - total: {}, train split: {}, val split: {}, test split: {}'.format(
        len(train_ids) + len(val_ids) + len(test_ids),
        len(train_ids), len(val_ids), len(test_ids)))
    print('Annotated data - total: {}, train split: {}, val split: {}, and test split: {}'.format(
        len(vid_ids), len(train_ann_ids), len(val_ann_ids), len(test_ann_ids)))
    return [train_ids, val_ids, test_ids]