def run_knn_at_layer_low_memory()

in vissl/utils/knn_utils.py


def run_knn_at_layer_low_memory(cfg: AttrDict, layer_name: str = "heads"):
    """
    Alternate implementation of kNN which scales to bigger features
    and bigger "train" splits
    """
    if cfg.NEAREST_NEIGHBOR.USE_CUDA:
        logging.warning(
            "config.NEAREST_NEIGHBOR.USE_CUDA is not available when "
            "config.NEAREST_NEIGHBOR.OPTIMIZE_MEMORY is set to True, "
            "using CPU instead"
        )

    # SIGMA is used as the softmax temperature that weights neighbor votes in Step 5
    temperature = cfg.NEAREST_NEIGHBOR.SIGMA
    num_neighbors = cfg.NEAREST_NEIGHBOR.TOPK
    feature_dir = cfg.NEAREST_NEIGHBOR.FEATURES.PATH
    output_dir = get_checkpoint_folder(cfg)
    logging.info(f"Testing with sigma: {temperature}, topk neighbors: {num_neighbors}")

    # Step 1: get the test features (the train features might not fit in memory)
    test_out = ExtractedFeaturesLoader.load_features(
        feature_dir, "test", layer_name, flatten_features=True
    )
    test_features, test_labels = test_out["features"], test_out["targets"]
    test_features = torch.from_numpy(test_features).float()
    test_feature_num = test_features.shape[0]

    # Step 2: normalize the features if needed
    if cfg.NEAREST_NEIGHBOR.L2_NORM_FEATS:
        test_features = nn.functional.normalize(test_features, dim=1, p=2)
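        # with L2-normalized features, the dot products below are cosine similarities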

    # Step 3: collect the similarity score of each test feature
    # to all the train features, making sure:
    # - never to load all the train features at once, to avoid OOM
    # - to keep only the 'num_neighbors' best similarity scores per test sample
    # (a minimal sketch of this streaming top-k idea follows the function)
    shard_paths = ExtractedFeaturesLoader.get_shard_file_names(
        input_dir=feature_dir, split="train", layer=layer_name
    )
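    # the queue keeps, for every test sample, only the 'num_neighbors' highest
    # similarity scores (and their labels) seen so far across shards, bounding
    # memory at O(num_test_samples * num_neighbors)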
    similarity_queue = MaxSimilarityPriorityQueue(max_size=num_neighbors)
    num_classes = 0
    for shard_path in shard_paths:
        shard_content = ExtractedFeaturesLoader.load_feature_shard(shard_path)
        train_features = torch.from_numpy(shard_content.features)
        train_features = train_features.float().reshape((train_features.shape[0], -1))
        if cfg.NEAREST_NEIGHBOR.L2_NORM_FEATS:
            train_features = nn.functional.normalize(train_features, dim=1, p=2)
        train_features = train_features.t()

        train_labels = torch.LongTensor(shard_content.targets).squeeze(-1)
        num_classes = max(num_classes, train_labels.max().item() + 1)
        # similarities has shape (num_test, num_train_in_shard); the top-k
        # neighbors are taken along dim 1, the train samples of this shard
        similarities = torch.mm(test_features, train_features)
        if similarities.shape[1] > num_neighbors:
            distances, indices = similarities.topk(
                num_neighbors, largest=True, sorted=True
            )
        else:
            distances, indices = torch.sort(similarities, descending=True)
        closest_labels = train_labels[indices]
        similarity_queue.push_all(distances, closest_labels)

    # Step 4: collect the samples with the closest similarities
    # for each test sample, and assemble them into two matrices
    # of shape (num_test_samples, num_neighbors)
    topk_distances, topk_labels = similarity_queue.pop_all()

    # Step 5: go through the test samples batch by batch and compute
    # the label of each test sample from its top k nearest neighbors
    # and their labels (see the weighted-vote sketch after this function)
    accuracies = Accuracies()
    output_targets, output_predicted_label, output_inds = [], [], []

    batch_size = 100
    num_test_images = test_feature_num
    for idx in range(0, num_test_images, batch_size):
        min_idx = idx
        max_idx = min(idx + batch_size, num_test_images)
        # the last batch may hold fewer than 'batch_size' samples
        curr_batch_size = max_idx - min_idx

        distances = topk_distances[min_idx:max_idx, ...]
        retrieved_neighbors = topk_labels[min_idx:max_idx, ...]
        targets = torch.LongTensor(test_labels[min_idx:max_idx])

        # one row per retrieved neighbor: row i gets a 1 in the column of its label
        retrieval_one_hot = torch.zeros(curr_batch_size * num_neighbors, num_classes)
        retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)
        predictions = _get_sorted_predictions(
            curr_batch_size, num_classes, distances, retrieval_one_hot, temperature
        )

        # update the accuracy counters with this batch's predictions
        accuracies = accuracies + Accuracies.from_batch(predictions, targets)

        # record the indices, predictions and targets to save later
        output_inds.extend(range(min_idx, max_idx))
        output_predicted_label.append(predictions.data.cpu().numpy())
        output_targets.append(targets.data.cpu().numpy())

    _save_knn_results(
        output_dir, layer_name, output_inds, output_predicted_label, output_targets
    )
    accuracies.log(layer_name)
    return accuracies.top_1, accuracies.top_5, accuracies.total
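
The streaming top-k of Step 3 does not depend on the specifics of
MaxSimilarityPriorityQueue (defined elsewhere in this file). Here is a minimal,
self-contained sketch of the same idea, using torch.topk to merge each new
shard into a running best-k instead of a priority queue ('merge_topk' is a
hypothetical helper name, not part of VISSL):

import torch

def merge_topk(best_scores, best_labels, new_scores, new_labels, k):
    # concatenate the running best with the new shard's candidates along
    # the neighbor dimension, then keep the k highest scores per test row
    scores = torch.cat([best_scores, new_scores], dim=1)
    labels = torch.cat([best_labels, new_labels], dim=1)
    top_scores, top_idx = scores.topk(min(k, scores.shape[1]), dim=1, largest=True)
    return top_scores, torch.gather(labels, 1, top_idx)

# usage: start from empty (num_test, 0) tensors and fold in one shard at a time
num_test, k = 4, 3
best_scores = torch.empty(num_test, 0)
best_labels = torch.empty(num_test, 0, dtype=torch.long)
for _ in range(5):  # stands in for iterating over the train feature shards
    shard_scores = torch.randn(num_test, 10)            # (num_test, shard_size)
    shard_labels = torch.randint(0, 7, (num_test, 10))  # labels of the shard samples
    best_scores, best_labels = merge_topk(
        best_scores, best_labels, shard_scores, shard_labels, k
    )
print(best_scores.shape, best_labels.shape)  # (4, 3) and (4, 3)

Per-step memory stays at O(num_test * k) no matter how many shards are
processed, which is what lets the "train" split grow arbitrarily large.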
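
Step 5 delegates the actual vote to _get_sorted_predictions, whose body is
defined elsewhere in this file. For the gist, here is a sketch of the standard
temperature-weighted kNN vote: each retained neighbor votes for its label with
weight exp(similarity / temperature), and classes come back sorted by total
vote ('weighted_knn_vote' is a hypothetical stand-in, and the actual VISSL
helper may differ in detail):

import torch

def weighted_knn_vote(distances, retrieval_one_hot, num_classes, temperature):
    # distances: (batch, k) similarities of the k retained neighbors
    # retrieval_one_hot: (batch * k, num_classes) one-hot label rows
    batch_size = distances.shape[0]
    one_hot = retrieval_one_hot.view(batch_size, -1, num_classes)
    weights = (distances / temperature).exp().unsqueeze(-1)  # (batch, k, 1)
    class_scores = (one_hot * weights).sum(dim=1)            # (batch, num_classes)
    # classes sorted by decreasing vote; column 0 holds the top-1 prediction
    _, predictions = class_scores.sort(dim=1, descending=True)
    return predictions

# toy usage: two strong votes for class 2, one weak vote for class 0
distances = torch.tensor([[0.9, 0.8, 0.1]])
labels = torch.tensor([2, 2, 0])
one_hot = torch.zeros(3, 4).scatter_(1, labels.view(-1, 1), 1)
print(weighted_knn_vote(distances, one_hot, num_classes=4, temperature=0.1)[0, 0])
# -> tensor(2)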