synthesis/KNN.py (139 lines of code) (raw):

# Copyright (c) Alibaba, Inc. and its affiliates. import numpy as np import torch import faiss # import umap # import time #import matplotlib.pyplot as plt import faiss.contrib.torch_utils # from sklearn import manifold, datasets # from torch.distributions import MultivariateNormal import torch.nn.functional as F def KNN_dis_search_decrease(target, index, K=50, select=1,): ''' data_point: Queue for searching k-th points target: the target of the search K ''' #Normalize the features target_norm = torch.norm(target, p=2, dim=1, keepdim=True) normed_target = target / target_norm #start_time = time.time() distance, output_index = index.search(normed_target, K) k_th_distance = distance[:, -1] #k_th_output_index = output_index[:, -1] k_th_distance, minD_idx = torch.topk(k_th_distance, select) #k_th_index = k_th_output_index[minD_idx] return minD_idx, k_th_distance def KNN_dis_search_distance(target, index, K=50, num_points=10, length=2000,depth=342): ''' data_point: Queue for searching k-th points target: the target of the search K ''' #Normalize the features target_norm = torch.norm(target, p=2, dim=1, keepdim=True) normed_target = target / target_norm #start_time = time.time() distance, output_index = index.search(normed_target, K) k_th_distance = distance[:, -1] k_th = k_th_distance.view(length, -1) target_new = target.view(length, -1, depth) #k_th_output_index = output_index[:, -1] k_th_distance, minD_idx = torch.topk(k_th, num_points, dim=0) # minD_idx = minD_idx.squeeze() point_list = [] for i in range(minD_idx.shape[1]): point_list.append(i*length + minD_idx[:,i]) #return torch.cat(point_list, dim=0) return target[torch.cat(point_list)] def generate_outliers(ID, input_index, negative_samples, ID_points_num=2, K=20, select=1, cov_mat=0.1, sampling_ratio=1.0, pic_nums=30, depth=342, text_anchors=None, cls_mask=None): ncls, nsample, ndim = ID.shape length, _ = negative_samples.shape normed_data = ID / torch.norm(ID, p=2, dim=-1, keepdim=True) if cls_mask is not None: normed_data = normed_data[cls_mask] #.float() text_anchors = text_anchors[cls_mask] distance = torch.cdist(normed_data, normed_data.detach()).half() # shape(ncls, nsample, nsample) k_th_distance = -torch.topk(-distance, K, dim=-1)[0][..., -1] # k-th nearset (smallest distance), shape(ncls, nsample) minD_idx = torch.topk(k_th_distance, select, dim=1)[1] # top-k largest distance, shape(ncls, select) minD_idx = minD_idx[:, np.random.choice(select, int(pic_nums), replace=False)] #shape(ncls, pic_nums) cls_idx = torch.arange(ncls).view(ncls, 1) if cls_mask is not None: cls_idx = cls_idx[cls_mask] data_point_list = ID[cls_idx.repeat(1, pic_nums).view(-1), minD_idx.view(-1)].view(-1, pic_nums, 1, ndim) negative_sample_cov = cov_mat*negative_samples.view(1, 1, length, ndim) negative_sample_list = (negative_sample_cov + data_point_list).view(-1, pic_nums*length, ndim) normed_ood_feat = F.normalize(negative_sample_list, p=2, dim=-1) #shape(cls, pic_nums*length, 512) distance = torch.cdist(normed_ood_feat, normed_data.half()) # shape(ncls, pic_nums*length, nsample) k_th_distance = -torch.topk(-distance, K, dim=-1)[0][..., -1] # k-th nearset (smallest distance), shape(ncls, pic_nums*length) if text_anchors is not None: # shape(cls,2,ndim) intra_similarity = torch.bmm(normed_ood_feat, text_anchors.permute(0, 2, 1)) #shape(cls,pic_nums*length,2) # only perserve samples with higher similarity to the perturbed text-feature intra_candidate = intra_similarity[..., 0] < intra_similarity[..., 1] # inter_similarity = normed_ood_feat.float() @ text_anchors[:, 0, :].float().T #shape(cls, pic_nums*length,ncls) # # only perserve samples with highest similarity among in-distribution text-features # inter_candidate = inter_similarity.argmax(dim=-1) == torch.arange(ncls).view(ncls,1).to(inter_similarity.device) candidate = intra_candidate #& inter_candidate k_th_distance *= candidate.float() k_distance, minD_idx = torch.topk(k_th_distance, ID_points_num, dim=1) # top-k largest distance, shape(ncls, ID_points_num) OOD_labels = torch.arange(normed_data.size(0)).view(-1, 1).repeat(1, ID_points_num).view(-1) OOD_syntheses = negative_sample_list[OOD_labels, minD_idx.view(-1)] #shape(ncls*ID_points_num, 512) if text_anchors is not None: valid = k_distance.view(-1) > 0 OOD_syntheses, OOD_labels = OOD_syntheses[valid], OOD_labels[valid] if OOD_syntheses.shape[0]: # concatenate ood_samples outside OOD_syntheses = torch.chunk(OOD_syntheses, OOD_syntheses.shape[0]) OOD_labels = OOD_labels.numpy() return OOD_syntheses, OOD_labels def generate_outliers_ours(ID, input_index, negative_samples, ID_points_num=2, K=20, select=1, cov_mat=0.1, sampling_ratio=1.0, pic_nums=30, depth=342, text_anchors=None): assert text_anchors is not None length = negative_samples.shape[0] ncls, nsample, ndim = ID.shape if True: rand_ind = np.random.choice(nsample, select, replace=False) id_data_points = ID[:, rand_ind, :].detach().view(ncls, select, 1, ndim) else: # rand_ind = [] # for ci in range(ncls): # rand_ind.append(torch.randperm(nsample) + ci * nsample) # rand_ind = torch.cat(rand_ind) # id_data_points = ID.view(ncls*nsample, ndim)[rand_ind].view(ncls, nsample, ndim) # id_data_points = id_data_points[:, :select, :].view(ncls, select, 1, ndim) id_data_points = [] normed_data = ID / torch.norm(ID, p=2, dim=-1, keepdim=True) for ci in range(ncls): input_index.add(normed_data[ci]) minD_idx, k_th = KNN_dis_search_decrease(ID[ci], input_index, K, select) id_data_points.append(ID[ci][minD_idx]) # shape(select,ndim) id_data_points = torch.stack(id_data_points).view(ncls, select, 1, ndim) negative_sample_cov = cov_mat * negative_samples.view(1, 1, length, ndim) negative_samples = (id_data_points + negative_sample_cov).view(ncls, select*length, ndim) normed_ood_feat = F.normalize(negative_samples, p=2, dim=1) #shape(select*length,512) inter_similarity = normed_ood_feat @ text_anchors[:, 0, :].T #shape(ncls,select*length,ncls) # only perserve samples with highest similarity among in-distribution text-features inter_candidate = inter_similarity.argmax(dim=-1) == torch.arange(ncls).cuda().view(ncls, 1) intra_similarity = torch.bmm(normed_ood_feat, text_anchors.transpose(1,2)) #shape(ncls,select*length,2) # only perserve samples with higher similarity to the perturbed text-feature intra_candidate = intra_similarity[..., 0] < intra_similarity[..., 1] candidate = inter_candidate & intra_candidate ood_samples, ood_labels = [], [] for ci in range(ncls): syntheses = negative_samples[ci][candidate[ci]] valid_num = len(syntheses) labels = np.full((valid_num,), ci, dtype=np.int64) if valid_num > ID_points_num: rand_ind = np.random.choice(valid_num, ID_points_num, replace=False, p=None) syntheses, labels = syntheses[rand_ind], labels[rand_ind] ood_samples.append(syntheses) ood_labels.append(labels) # concatenate ood_samples outside ood_labels = np.concatenate(ood_labels) return ood_samples, ood_labels, candidate def generate_outliers_OOD(ID, input_index, negative_samples, K=100, select=100, sampling_ratio=1.0): data_norm = torch.norm(ID, p=2, dim=1, keepdim=True) normed_data = ID / data_norm rand_ind = np.random.choice(normed_data.shape[1], int(normed_data.shape[1] * sampling_ratio), replace=False) index = input_index index.add(normed_data[rand_ind]) minD_idx, k_th = KNN_dis_search_decrease(negative_samples, index, K, select) return negative_samples[minD_idx] def generate_outliers_rand(ID, input_index, negative_samples, ID_points_num=2, K=20, select=1, cov_mat=0.1, sampling_ratio=1.0, pic_nums=10, repeat_times=30, depth=342): length = negative_samples.shape[0] data_norm = torch.norm(ID, p=2, dim=1, keepdim=True) normed_data = ID / data_norm rand_ind = np.random.choice(normed_data.shape[1], int(normed_data.shape[1] * sampling_ratio), replace=False) index = input_index index.add(normed_data[rand_ind]) minD_idx, k_th = KNN_dis_search_decrease(ID, index, K, select) ID_boundary = ID[minD_idx] negative_sample_list = [] for i in range(repeat_times): select_idx = np.random.choice(select, int(pic_nums), replace=False) sample_list = ID_boundary[select_idx] mean = sample_list.mean(0) var = torch.cov(sample_list.T) var = torch.mm(negative_samples, var) trans_samples = mean + var negative_sample_list.append(trans_samples) negative_sample_list = torch.cat(negative_sample_list, dim=0) point = KNN_dis_search_distance(negative_sample_list, index, K, ID_points_num, length,depth) index.reset() #return ID[minD_idx] return point