def generate_outliers()

in synthesis/KNN.py
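
Synthesizes outlier (OOD) features near the boundary of each class's in-distribution cluster: the points with the largest k-th-nearest-neighbour distance are taken as boundary points, perturbed with scaled noise samples, and the perturbed samples that again land far from the ID data (optionally filtered by text-anchor similarity) are returned together with their class labels.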


import numpy as np
import torch
import torch.nn.functional as F


def generate_outliers(ID, input_index, negative_samples, ID_points_num=2, K=20, select=1, cov_mat=0.1, sampling_ratio=1.0, pic_nums=30, depth=342,
                      text_anchors=None, cls_mask=None):
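    """Synthesize outlier features near the boundary of each class's ID cluster.

    Shapes, as inferred from the tensor ops below:
        ID:               (ncls, nsample, ndim) in-distribution features per class
        negative_samples: (length, ndim) noise samples, scaled by cov_mat
        text_anchors:     optional (ncls, 2, ndim); index 0 holds the clean text
                          feature, index 1 the perturbed one
        cls_mask:         optional boolean mask selecting a subset of classes

    Returns (OOD_syntheses, OOD_labels); input_index, sampling_ratio and depth
    are accepted but unused in this body.
    """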
    ncls, nsample, ndim = ID.shape
    length, _ = negative_samples.shape
    normed_data = ID / torch.norm(ID, p=2, dim=-1, keepdim=True)  # L2-normalize per-sample features
    if cls_mask is not None:  # restrict to the classes present in the current batch
        normed_data = normed_data[cls_mask]
        text_anchors = text_anchors[cls_mask]

    distance = torch.cdist(normed_data, normed_data.detach()).half()  # pairwise distances, shape (ncls, nsample, nsample)
    k_th_distance = -torch.topk(-distance, K, dim=-1)[0][..., -1]  # k-th nearest (k-th smallest distance), shape (ncls, nsample)
    minD_idx = torch.topk(k_th_distance, select, dim=1)[1]  # the `select` points with the largest k-NN distance (boundary points), shape (ncls, select)
    minD_idx = minD_idx[:, np.random.choice(select, int(pic_nums), replace=False)]  # randomly keep pic_nums of them (requires select >= pic_nums), shape (ncls, pic_nums)
    cls_idx = torch.arange(ncls).view(ncls, 1)
    if cls_mask is not None:
        cls_idx = cls_idx[cls_mask]
    data_point_list = ID[cls_idx.repeat(1, pic_nums).view(-1), minD_idx.view(-1)].view(-1, pic_nums, 1, ndim)  # gathered boundary features

    # perturb every selected boundary feature with all scaled noise samples
    negative_sample_cov = cov_mat*negative_samples.view(1, 1, length, ndim)
    negative_sample_list = (negative_sample_cov + data_point_list).view(-1, pic_nums*length, ndim)

    normed_ood_feat = F.normalize(negative_sample_list, p=2, dim=-1)  # shape (ncls, pic_nums*length, ndim)
    distance = torch.cdist(normed_ood_feat, normed_data.half())  # shape (ncls, pic_nums*length, nsample)
    k_th_distance = -torch.topk(-distance, K, dim=-1)[0][..., -1]  # k-th nearest (k-th smallest distance), shape (ncls, pic_nums*length)
    
    if text_anchors is not None:  # shape (ncls, 2, ndim)
        intra_similarity = torch.bmm(normed_ood_feat, text_anchors.permute(0, 2, 1))  # shape (ncls, pic_nums*length, 2)
        # only preserve samples with higher similarity to the perturbed text feature
        intra_candidate = intra_similarity[..., 0] < intra_similarity[..., 1]

        # inter_similarity = normed_ood_feat.float() @ text_anchors[:, 0, :].float().T  # shape (ncls, pic_nums*length, ncls)
        # # only preserve samples with the highest similarity among in-distribution text features
        # inter_candidate = inter_similarity.argmax(dim=-1) == torch.arange(ncls).view(ncls, 1).to(inter_similarity.device)

        candidate = intra_candidate  # & inter_candidate

        k_th_distance *= candidate.float()  # rejected samples get distance 0 and never reach the top-k below
    
    k_distance, minD_idx = torch.topk(k_th_distance, ID_points_num, dim=1)  # ID_points_num largest k-NN distances per class, shape (ncls, ID_points_num)
    OOD_labels = torch.arange(normed_data.size(0)).view(-1, 1).repeat(1, ID_points_num).view(-1)
    OOD_syntheses = negative_sample_list[OOD_labels, minD_idx.view(-1)]  # shape (ncls*ID_points_num, ndim)

    if text_anchors is not None:
        valid = k_distance.view(-1) > 0  # drop syntheses zeroed out by the anchor filter
        OOD_syntheses, OOD_labels = OOD_syntheses[valid], OOD_labels[valid]
    
    if OOD_syntheses.shape[0]:
        # split into per-sample chunks; the caller concatenates them with its batch
        OOD_syntheses = torch.chunk(OOD_syntheses, OOD_syntheses.shape[0])
        OOD_labels = OOD_labels.numpy()

    return OOD_syntheses, OOD_labels
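
The boundary-selection criterion is easier to see in isolation. Below is a minimal, self-contained sketch of the same k-th-nearest-neighbour test on a single class; the helper name kth_nn_boundary_points and the toy shapes are illustrative (not from the repository), and the sketch runs in float32 on CPU rather than the half precision used above.

import torch
import torch.nn.functional as F

def kth_nn_boundary_points(feats, k=20, select=5):
    # Indices of the `select` features whose k-th nearest-neighbour distance is
    # largest, i.e. the points lying closest to the edge of the cluster: the
    # same per-class criterion generate_outliers applies above. As in the
    # original code, each point counts itself as its first neighbour.
    normed = F.normalize(feats, p=2, dim=-1)
    dists = torch.cdist(normed, normed)               # (n, n) pairwise distances
    kth = -torch.topk(-dists, k, dim=-1)[0][..., -1]  # k-th smallest distance per point
    return torch.topk(kth, select)[1]                 # largest k-NN distances win

feats = torch.randn(100, 512)
print(kth_nn_boundary_points(feats))                  # 5 boundary-point indices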