def KNN_dis_search_distance()

in models/feat_pool.py [0:0]


def KNN_dis_search_distance(target, index, K=50, num_points=10, length=2000, depth=342):
    '''
    Select, per group, the rows of `target` with the largest K-th search score.

    `target` is treated as consecutive groups of `length` rows: group g is
    rows [g*length, (g+1)*length). Every row is L2-normalised and searched in
    `index`. A row's score is the last (K-th) column of the returned
    distances. For an L2-distance index this is the distance to the K-th
    neighbour, so larger means farther away. For an inner-product index the
    meaning flips (TODO: confirm which index type callers pass). From each
    group, the `num_points` rows with the largest score are kept.

    Args:
        target: 2-D tensor of shape (N, D). N must be a multiple of `length`.
            Rows are divided by their L2 norm before the search, so an
            all-zero row produces NaNs.
        index: object exposing `search(x, K) -> (distances, indices)`, where
            `distances` is a torch tensor of shape (N, K). This is presumably
            a faiss index with torch interop; confirm with the caller.
        K: number of neighbours requested from `index`.
        num_points: number of rows kept per group; must not exceed `length`.
        length: number of consecutive rows forming one group.
        depth: unused; kept only for backward compatibility.

    Returns:
        Tensor of shape (N // length * num_points, D). It holds the selected
        rows of the original (un-normalised) `target`, ordered group by group
        and, within a group, by descending score.

    Raises:
        RuntimeError: if N is not a multiple of `length`, or if `num_points`
            exceeds `length` (raised by `reshape` / `torch.topk`).
    '''
    # Normalize the features before querying the index.
    target_norm = torch.norm(target, p=2, dim=1, keepdim=True)
    normed_target = target / target_norm

    distance, _ = index.search(normed_target, K)
    k_th_distance = distance[:, -1]

    # Row g of `k_th` must hold the scores of target rows
    # [g*length, (g+1)*length) so that it matches the `g*length + j` gather
    # below. An interleaved (length, -1) layout would rank one set of rows
    # and then return a different set.
    k_th = k_th_distance.reshape(-1, length)
    _, top_idx = torch.topk(k_th, num_points, dim=1)  # (num_groups, num_points)

    # Convert per-group positions into absolute row indices of `target`.
    # reshape(-1) keeps the group-major, descending-score order.
    group_offsets = torch.arange(k_th.shape[0], device=top_idx.device).unsqueeze(1) * length
    return target[(top_idx + group_offsets).reshape(-1)]