def generate_outliers_ours()

in synthesis/KNN.py [0:0]


def generate_outliers_ours(ID, input_index, negative_samples, ID_points_num=2, K=20, select=1, cov_mat=0.1, sampling_ratio=1.0, pic_nums=30, depth=342,
                           text_anchors=None):
    """Synthesize per-class outlier features around random in-distribution points.

    ``select`` sample positions are drawn at random (the same positions for
    every class). Each chosen ID point is perturbed by every row of
    ``negative_samples`` scaled by ``cov_mat``. A synthesized point is kept
    for class ``ci`` only if:

    * among the in-distribution text anchors ``text_anchors[:, 0]``, the one
      it is most similar to belongs to class ``ci``; and
    * it is more similar to the perturbed anchor ``text_anchors[ci, 1]``
      than to the in-distribution anchor ``text_anchors[ci, 0]``.

    If more than ``ID_points_num`` points survive for a class, a random
    subset of that size is kept (without replacement). Randomness comes from
    NumPy's global RNG (``np.random``).

    Args:
        ID: tensor of shape ``(ncls, nsample, ndim)`` holding per-class features.
        input_index: unused; only a former, disabled KNN-based selection path
            used it. Kept for API compatibility.
        negative_samples: tensor with ``length * ndim`` elements, viewed as
            ``(length, ndim)`` perturbation vectors.
        ID_points_num: maximum number of synthesized points kept per class.
        K: unused (only the former KNN path used it); kept for compatibility.
        select: number of random ID sample positions used as centres.
        cov_mat: scalar scale applied to ``negative_samples``.
        sampling_ratio, pic_nums, depth: unused; kept for compatibility.
        text_anchors: tensor of shape ``(ncls, 2, ndim)``. Index 0 is the
            in-distribution text feature of each class; index 1 is the
            perturbed text feature.

    Returns:
        Tuple ``(ood_samples, ood_labels, candidate)``:

        * ``ood_samples``: list of ``ncls`` tensors of shape ``(n_ci, ndim)``
          with ``n_ci <= ID_points_num``. These are unnormalized synthesized
          features; concatenation is left to the caller.
        * ``ood_labels``: int64 ``np.ndarray`` of class indices, aligned with
          ``ood_samples`` concatenated in class order.
        * ``candidate``: bool tensor ``(ncls, select * length)`` marking which
          synthesized points passed both filters (before subsampling).

    Raises:
        ValueError: if ``text_anchors`` is not given.
    """
    if text_anchors is None:
        raise ValueError("text_anchors must be provided")
    length = negative_samples.shape[0]
    ncls, nsample, ndim = ID.shape

    # Same random sample positions for every class. Detached so the
    # synthesized points do not propagate gradients back into ID.
    rand_ind = np.random.choice(nsample, select, replace=False)
    id_data_points = ID[:, rand_ind, :].detach().view(ncls, select, 1, ndim)

    # Broadcast (ncls, select, 1, ndim) + (1, 1, length, ndim)
    #        -> (ncls, select, length, ndim) -> (ncls, select*length, ndim)
    perturbation = cov_mat * negative_samples.view(1, 1, length, ndim)
    syntheses_all = (id_data_points + perturbation).view(ncls, select * length, ndim)

    # Unit-normalize each synthesized vector along the feature axis.
    # Normalizing dim=1 (the sample axis) would rescale every feature
    # coordinate differently and change which samples pass the filters.
    normed_ood_feat = F.normalize(syntheses_all, p=2, dim=-1)  # (ncls, S, ndim)

    # (ncls, S, ndim) @ (ndim, ncls) -> (ncls, S, ncls): similarity of each
    # synthesized point to every class's in-distribution text anchor.
    inter_similarity = normed_ood_feat @ text_anchors[:, 0, :].T
    # Only preserve samples whose most similar ID anchor is their own class.
    class_ids = torch.arange(ncls, device=inter_similarity.device).view(ncls, 1)
    inter_candidate = inter_similarity.argmax(dim=-1) == class_ids

    # (ncls, S, ndim) x (ncls, ndim, 2) -> (ncls, S, 2): similarity to the
    # class's own ID anchor (index 0) and perturbed anchor (index 1).
    intra_similarity = torch.bmm(normed_ood_feat, text_anchors.transpose(1, 2))
    # Only preserve samples closer to the perturbed text feature.
    intra_candidate = intra_similarity[..., 0] < intra_similarity[..., 1]

    candidate = inter_candidate & intra_candidate

    ood_samples, ood_labels = [], []
    for ci in range(ncls):
        syntheses = syntheses_all[ci][candidate[ci]]
        valid_num = len(syntheses)
        labels = np.full((valid_num,), ci, dtype=np.int64)
        if valid_num > ID_points_num:
            # Cap the per-class count with a random subset.
            keep = np.random.choice(valid_num, ID_points_num, replace=False)
            syntheses, labels = syntheses[keep], labels[keep]

        ood_samples.append(syntheses)
        ood_labels.append(labels)

    # ood_samples stays a per-class list (concatenated outside this function);
    # labels are concatenated here in the same class order.
    ood_labels = np.concatenate(ood_labels)

    return ood_samples, ood_labels, candidate