def run_knn_at_layer()

in vissl/utils/knn_utils.py [0:0]


def run_knn_at_layer(cfg: AttrDict, layer_name: str = "heads"):
    """
    Run the Nearest Neighbour benchmark at the layer "layer_name".

    Loads the previously extracted "train" and "test" features of the layer
    from cfg.NEAREST_NEIGHBOR.FEATURES.PATH, retrieves for every test sample
    the cfg.NEAREST_NEIGHBOR.TOPK train samples with the highest dot-product
    similarity, and turns the retrieved labels into ranked class predictions
    using the temperature cfg.NEAREST_NEIGHBOR.SIGMA.

    The per-sample indices, predictions and targets are written through
    `_save_knn_results` into the folder returned by `get_checkpoint_folder`,
    and the accumulated accuracies are logged.

    Args:
        cfg: configuration; this function reads NEAREST_NEIGHBOR.SIGMA, .TOPK,
            .FEATURES.PATH, .L2_NORM_FEATS and .USE_CUDA.
        layer_name: name of the layer whose extracted features are evaluated.

    Returns:
        The tuple (top_1, top_5, total) taken from the accumulated accuracies.
    """
    temperature = cfg.NEAREST_NEIGHBOR.SIGMA
    num_neighbors = cfg.NEAREST_NEIGHBOR.TOPK
    feature_dir = cfg.NEAREST_NEIGHBOR.FEATURES.PATH
    output_dir = get_checkpoint_folder(cfg)
    logging.info(f"Testing with sigma: {temperature}, topk neighbors: {num_neighbors}")

    ############################################################################
    # Step 1: get train and test features
    train_out = ExtractedFeaturesLoader.load_features(
        feature_dir, "train", layer_name, flatten_features=True
    )
    train_features, train_labels = train_out["features"], train_out["targets"]
    test_out = ExtractedFeaturesLoader.load_features(
        feature_dir, "test", layer_name, flatten_features=True
    )
    test_features, test_labels = test_out["features"], test_out["targets"]
    train_features = torch.from_numpy(train_features).float()
    test_features = torch.from_numpy(test_features).float()
    train_labels = torch.LongTensor(train_labels)
    # Plain python int: it is used below as a tensor dimension.
    num_classes = int(train_labels.max()) + 1

    ###########################################################################
    # Step 2: calculate the nearest neighbor and the metrics
    accuracies = Accuracies()
    if cfg.NEAREST_NEIGHBOR.L2_NORM_FEATS:
        # with unit-norm features the dot product below is a cosine similarity
        train_features = nn.functional.normalize(train_features, dim=1, p=2)
        test_features = nn.functional.normalize(test_features, dim=1, p=2)

    # put train features and labels on gpu and transpose train features
    use_cuda = cfg.NEAREST_NEIGHBOR.USE_CUDA
    if use_cuda:
        train_features = train_features.cuda().t()
        test_features = test_features.cuda()
        train_labels = train_labels.cuda()
    else:
        train_features = train_features.t()

    # The test set is processed in roughly `num_chunks` slices to bound the
    # size of the similarity matrix. The chunk size is clamped to 1 because
    # with fewer than `num_chunks` test samples the integer division yields 0,
    # and a zero step makes range() raise ValueError.
    num_test_images, num_chunks = test_labels.shape[0], 100
    imgs_per_chunk = max(1, num_test_images // num_chunks)
    output_targets, output_predicted_label, output_inds = [], [], []
    with torch.no_grad():
        for start in range(0, num_test_images, imgs_per_chunk):
            end = min(start + imgs_per_chunk, num_test_images)

            # get the features and the labels of the current chunk of test
            # images (test labels are indexed as a 2-D array)
            features = test_features[start:end, :]
            targets = torch.LongTensor(test_labels[start:end, :])
            batch_size = targets.shape[0]
            if use_cuda:
                targets = targets.cuda()

            # calculate the dot product and compute top-k neighbors
            similarity = torch.mm(features, train_features)
            distances, indices = similarity.topk(
                num_neighbors, largest=True, sorted=True
            )
            # look up the train label of every retrieved neighbor
            candidates = train_labels.view(1, -1).expand(batch_size, -1)
            retrieved_neighbors = torch.gather(candidates, 1, indices)

            # one row per (test sample, neighbor) pair with a 1 in the column
            # of the neighbor's class; allocated directly on the device that
            # holds the train features
            retrieval_one_hot = torch.zeros(
                batch_size * num_neighbors, num_classes, device=train_features.device
            )
            retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)
            predictions = _get_sorted_predictions(
                batch_size, num_classes, distances, retrieval_one_hot, temperature
            )

            # find the predictions that match the target
            accuracies = accuracies + Accuracies.from_batch(predictions, targets)

            # get the predictions, nearest neighbors, inds to save
            output_inds.extend(range(start, end))
            output_predicted_label.append(predictions.detach().cpu().numpy())
            output_targets.append(targets.detach().cpu().numpy())

    _save_knn_results(
        output_dir, layer_name, output_inds, output_predicted_label, output_targets
    )
    accuracies.log(layer_name)
    return accuracies.top_1, accuracies.top_5, accuracies.total