in eval_retrieve_knn_pred.py [0:0]
def main():
parser = argparse.ArgumentParser('retrieval eval')
parser.add_argument('--output-dir', type=str)
parser.add_argument('--trainsplit', type=str, required=True)
parser.add_argument('--valsplit', type=str, required=True)
parser.add_argument('--num_replica', type=int, default=8)
parser.add_argument('--data-source', type=str)
args = parser.parse_args()
for i in range(args.num_replica):
os.path.exists(os.path.join(args.output_dir, 'feature_{}_{}.npy'.format(args.trainsplit, i)))
os.path.exists(os.path.join(args.output_dir, 'feature_{}_cls_{}.npy'.format(args.trainsplit, i)))
os.path.exists(os.path.join(args.output_dir, 'feature_{}_{}.npy'.format(args.valsplit, i)))
os.path.exists(os.path.join(args.output_dir, 'feature_{}_cls_{}.npy'.format(args.valsplit, i)))
os.path.exists(os.path.join(args.output_dir, 'vid_num_{}.npy'.format(args.trainsplit)))
os.path.exists(os.path.join(args.output_dir, 'vid_num_{}.npy'.format(args.valsplit)))
vid_num_train = np.load(os.path.join(args.output_dir, 'vid_num_{}.npy'.format(args.trainsplit)))
train_padding_num = vid_num_train[0] % args.num_replica
vid_num_val = np.load(os.path.join(args.output_dir, 'vid_num_{}.npy'.format(args.valsplit)))
val_padding_num = vid_num_val[0] % args.num_replica
feat_train = []
feat_train_cls = []
for i in range(args.num_replica):
feat_train.append(np.load(os.path.join(args.output_dir, 'feature_{}_{}.npy'.format(args.trainsplit, i))))
feat_train_cls.append(
np.load(os.path.join(args.output_dir, 'feature_{}_cls_{}.npy'.format(args.trainsplit, i))))
if train_padding_num > 0:
for i in range(train_padding_num, args.num_replica):
feat_train[i] = feat_train[i][:-1, :]
feat_train_cls[i] = feat_train_cls[i][:-1]
feat_train = np.concatenate(feat_train, axis=0).squeeze()
feat_train_cls = np.concatenate(feat_train_cls, axis=0).squeeze()
print('feat_train: {}'.format(feat_train.shape))
print('feat_train_cls: {}'.format(feat_train_cls.shape))
feat_val = []
feat_val_cls = []
for i in range(args.num_replica):
feat_val.append(np.load(os.path.join(args.output_dir, 'feature_{}_{}.npy'.format(args.valsplit, i))))
feat_val_cls.append(np.load(os.path.join(args.output_dir, 'feature_{}_cls_{}.npy'.format(args.valsplit, i))))
if val_padding_num > 0:
for i in range(val_padding_num, args.num_replica):
feat_val[i] = feat_val[i][:-1, :]
feat_val_cls[i] = feat_val_cls[i][:-1]
feat_val = np.concatenate(feat_val, axis=0)
feat_val_cls = np.concatenate(feat_val_cls, axis=0)
print('feat_val: {}'.format(feat_val.shape))
print('feat_val_cls: {}'.format(feat_val_cls.shape))
# kNN retrieval
if args.valsplit == 'test':
ks = [3]
else:
ks = [1, 5, 10, 20, 50]
topk_correct = {k: 0 for k in ks}
class_top = 1
if args.data_source == 'ucf':
class_num = 101
elif args.data_source == 'hmdb':
class_num = 51
else:
raise Exception('The data-source argument no assigned!')
class_correct = {cls: 0 for cls in range(0, class_num)}
class_total = {cls: 0 for cls in range(0, class_num)}
X_train = feat_train
y_train = feat_train_cls
X_test = feat_val
y_test = feat_val_cls
distances = cosine_distances(X_test, X_train)
indices = np.argsort(distances)
for k in ks:
# print(k)
top_k_indices = indices[:, :k]
if args.valsplit == 'test':
print(top_k_indices)
np.save(os.path.join(args.output_dir, 'top_k_indices.npy'), top_k_indices)
# print(top_k_indices.shape, y_test.shape)
for ind, test_label in zip(top_k_indices, y_test):
labels = y_train[ind]
if test_label in labels:
# print(test_label, labels)
topk_correct[k] += 1
if k == class_top:
class_correct[test_label] += 1
if k == class_top:
class_total[test_label] += 1
for k in ks:
correct = topk_correct[k]
total = len(X_test)
print('Top-{}, correct = {:.2f}, total = {}, acc = {:.3f}'.format(k, correct, total, correct / total))
# save label
if args.valsplit != 'test':
label_file = os.path.join(args.output_dir, 'class_retrieval_vclr.txt')
f = open(label_file, 'w')
for k in class_correct.keys():
correct = class_correct[k]
total = class_total[k]
info = 'Classs-{}, Top-{}, correct = {:.2f}, total = {}, acc = {:.3f}'.format(
k, class_top, correct, total, correct / total)
print(info)
f.write(info)
f.write('\n')
f.close()
with open(os.path.join(args.output_dir, 'topk_correct.json'), 'w') as fp:
json.dump(topk_correct, fp)