in synthesis/KNN.py [0:0]
def generate_outliers_ours(ID, input_index, negative_samples, ID_points_num=2, K=20, select=1,
                           cov_mat=0.1, sampling_ratio=1.0, pic_nums=30, depth=342,
                           text_anchors=None):
    """Synthesize per-class outlier features around randomly chosen ID features.

    For every class, ``select`` in-distribution features are picked at random,
    each one is shifted by every (scaled) row of ``negative_samples``, and the
    shifted features are kept only if

    * their most similar class anchor (``text_anchors[:, 0]``) is the class
      they were generated from, and
    * they are more similar to that class's perturbed anchor
      (``text_anchors[c, 1]``) than to its original anchor
      (``text_anchors[c, 0]``).

    At most ``ID_points_num`` surviving features are returned per class.

    Args:
        ID: Tensor of in-distribution features, shape ``(ncls, nsample, ndim)``.
        input_index: Unused. Only the removed (previously unreachable)
            KNN-based selection path consumed it; kept for call compatibility.
        negative_samples: Tensor with ``length * ndim`` elements whose first
            dimension is ``length``; each row is a perturbation direction.
        ID_points_num: Maximum number of outliers kept per class.
        K: Unused; see ``input_index``.
        select: Number of ID features drawn per class (same indices for all
            classes, drawn without replacement via ``np.random.choice``).
        cov_mat: Scalar multiplier applied to ``negative_samples``.
        sampling_ratio: Unused; kept for call compatibility.
        pic_nums: Unused; kept for call compatibility.
        depth: Unused; kept for call compatibility.
        text_anchors: Required tensor indexed as ``(ncls, 2, ndim)``: index 0
            on the middle axis is the class anchor, index 1 its perturbed
            counterpart.

    Returns:
        A tuple ``(ood_samples, ood_labels, candidate)`` where ``ood_samples``
        is a list of ``ncls`` tensors (one per class, left un-concatenated for
        the caller), ``ood_labels`` is a 1-D ``np.int64`` array with the class
        index of every returned sample, and ``candidate`` is the boolean mask
        of shape ``(ncls, select * length)`` of samples that passed both
        similarity filters (before the ``ID_points_num`` cap).

    Raises:
        ValueError: If ``text_anchors`` is ``None``.
    """
    if text_anchors is None:
        raise ValueError("text_anchors must be provided")

    length = negative_samples.shape[0]
    ncls, nsample, ndim = ID.shape

    # Random selection of anchor points. A KNN-based boundary-point selection
    # (using ``input_index`` / ``K``) used to sit behind an unreachable branch
    # here and has been dropped.
    rand_ind = np.random.choice(nsample, select, replace=False)
    # reshape (not view) so non-contiguous tensors are accepted as well.
    id_data_points = ID[:, rand_ind, :].detach().reshape(ncls, select, 1, ndim)

    # Broadcast every perturbation onto every selected point of every class.
    negative_sample_cov = cov_mat * negative_samples.reshape(1, 1, length, ndim)
    negative_samples = (id_data_points + negative_sample_cov).reshape(ncls, select * length, ndim)

    # Unit-normalize each feature vector. The tensor is
    # (ncls, select*length, ndim), so the feature axis is the last one;
    # normalizing along dim=1 would mix different samples together.
    normed_ood_feat = F.normalize(negative_samples, p=2, dim=-1)

    # Similarity to every class anchor: (ncls, select*length, ncls).
    inter_similarity = normed_ood_feat @ text_anchors[:, 0, :].T
    # Only preserve samples whose most similar class anchor is their own class.
    # Build the class ids on the data's device instead of forcing CUDA.
    class_ids = torch.arange(ncls, device=inter_similarity.device).view(ncls, 1)
    inter_candidate = inter_similarity.argmax(dim=-1) == class_ids

    # Similarity to the own-class (anchor, perturbed anchor) pair:
    # (ncls, select*length, 2).
    intra_similarity = torch.bmm(normed_ood_feat, text_anchors.transpose(1, 2))
    # Only preserve samples closer to the perturbed anchor than to the anchor.
    intra_candidate = intra_similarity[..., 0] < intra_similarity[..., 1]

    candidate = inter_candidate & intra_candidate

    ood_samples, ood_labels = [], []
    for ci in range(ncls):
        syntheses = negative_samples[ci][candidate[ci]]
        valid_num = len(syntheses)
        labels = np.full((valid_num,), ci, dtype=np.int64)
        if valid_num > ID_points_num:
            # Too many survivors: keep a uniform random subset.
            rand_ind = np.random.choice(valid_num, ID_points_num, replace=False, p=None)
            syntheses, labels = syntheses[rand_ind], labels[rand_ind]
        ood_samples.append(syntheses)
        ood_labels.append(labels)

    # ood_samples is intentionally left as a list; the caller concatenates it.
    ood_labels = np.concatenate(ood_labels)
    return ood_samples, ood_labels, candidate