in models/feat_pool.py [0:0]
def KNN_dis_search_distance(target, index, K=50, num_points=10, length=2000, depth=342):
    """Keep, from each block of candidate rows, the rows with the largest K-th neighbour score.

    ``target`` is treated as a stack of consecutive blocks of ``length`` rows:
    row ``j`` belongs to block ``j // length``. Every row is L2-normalised and
    queried against ``index`` for ``K`` neighbours. The value in the last
    column of the returned distances is used as that row's score. Within each
    block, the ``num_points`` rows with the largest score are kept.

    Args:
        target: 2-D tensor of candidate points, shape (N, D). N must be a
            multiple of ``length``.
        index: object whose ``search(x, K)`` returns ``(distances, ids)``,
            each with one row per query and ``K`` columns. Presumably a faiss
            index used with torch tensors, with neighbours ordered
            nearest-first so the last column is the K-th neighbour — TODO
            confirm against the caller.
        K: number of neighbours requested from ``index``.
        num_points: number of rows kept per block. Must be <= ``length``,
            otherwise ``torch.topk`` raises.
        length: number of rows in each block of ``target``.
        depth: unused; retained so existing call sites keep working.

    Returns:
        Tensor of shape (num_points * N // length, D) holding the selected,
        unnormalised rows of ``target``. Rows are grouped block by block in
        block order and, within a block, ordered by decreasing score.
    """
    # Normalise the features. The clamp keeps all-zero rows from turning
    # into NaNs (0 / 0) before they reach the index.
    target_norm = torch.norm(target, p=2, dim=1, keepdim=True).clamp_min(1e-12)
    normed_target = target / target_norm
    distance, _ = index.search(normed_target, K)

    # One row of scores per block: row b holds the scores of target rows
    # b*length ... (b+1)*length - 1. Grouping must match the offset arithmetic
    # below; a (length, -1) view would interleave rows from different blocks.
    kth_distance = distance[:, -1].reshape(-1, length)
    _, local_idx = torch.topk(kth_distance, num_points, dim=1)

    # Map block-local positions back to absolute row indices in ``target``.
    block_offsets = torch.arange(local_idx.shape[0], device=local_idx.device) * length
    flat_idx = (local_idx + block_offsets.unsqueeze(1)).reshape(-1)
    return target[flat_idx]