def gen_ood()

in synthesis/feature_sample.py [0:0]


    def gen_ood(self, anchors=None, normed=True, device='cuda:0', cls_mask=None, ret_cand=False):
        """Synthesize outlier (OOD) feature samples according to ``self.mode``.

        Modes:
            * ``'vos'``  -- estimate one covariance shared by all classes from
              the class-centred feature queue, draw ``self.sample_from``
              zero-mean Gaussian samples, keep the ``self.select`` samples
              with the lowest log-likelihood, and add up to
              ``self.pick_nums`` of them (chosen at random) to each class mean.
            * ``'npos'`` -- delegate to ``generate_outliers`` using noise drawn
              from ``self.new_dis`` and the KNN index ``self.KNN_index``.
            * ``'ours'`` -- delegate to ``generate_outliers_ours``, which also
              returns a set of candidates.

        Args:
            anchors: Optional text anchors forwarded to the generators.
                Required in ``'ours'`` mode.
            normed: If True, L2-normalise every returned sample along dim 1.
            device: Device the returned tensors are moved to.
            cls_mask: Forwarded to ``generate_outliers`` (``'npos'`` only).
            ret_cand: If True, also return the candidates produced in
                ``'ours'`` mode (``None`` in the other modes).

        Returns:
            ``(ood_samples, ood_labels)``, or with ``ret_cand``
            ``(ood_samples, ood_labels, candidates)``. ``ood_labels[i]`` is the
            class index the i-th sample was generated around.

        Raises:
            ValueError: If ``self.mode`` is unknown, or if ``anchors`` is None
                in ``'ours'`` mode.
        """
        ood_samples, ood_labels = [], []
        candidates = None  # only the 'ours' generator produces candidates

        if self.mode == 'vos':
            # self.queue must be 3-D for bmm below; its first dim is iterated as
            # the class index. Presumably (num_classes, sample_number, feat_dim)
            # -- TODO confirm.
            mean_embed_id = self.queue.mean(dim=1, keepdim=True)
            X = self.queue - mean_embed_id
            # Shared covariance: per-class scatter matrices averaged over
            # classes, divided by the number of samples per class.
            covariance = torch.bmm(X.permute(0, 2, 1), X).mean(dim=0) / self.sample_number
            # Small ridge keeps the matrix positive definite for MultivariateNormal.
            covariance += 0.0001 * torch.eye(len(covariance), device=X.device)

            # Zero mean sized and placed to match the covariance, so the
            # distribution works for any feature dim and on any device.
            zero_mean = torch.zeros(covariance.shape[0], dtype=covariance.dtype,
                                    device=covariance.device)
            new_dis = MultivariateNormal(zero_mean, covariance_matrix=covariance)
            negative_samples = new_dis.rsample((self.sample_from,))
            prob_density = new_dis.log_prob(negative_samples)
            # Keep the `select` samples with the LOWEST log-likelihood.
            _, index_prob = torch.topk(-prob_density, self.select)
            negative_samples = negative_samples[index_prob]

            for ci, miu in enumerate(mean_embed_id):
                rand_ind = torch.randperm(self.select)[:self.pick_nums]
                ood_samples.append(miu + negative_samples[rand_ind])
                # randperm(select) caps the pick at `select` rows, so label by
                # the actual number of rows taken to keep samples/labels aligned.
                ood_labels.extend([ci] * len(rand_ind))
            # ood_samples stays a list; it is concatenated once below, like the
            # outputs of the other modes.

        elif self.mode == 'npos':
            negative_samples = self.new_dis.rsample((self.sample_from,)).half().to(self.device)

            text_anchors = anchors.to(self.device) if anchors is not None else None
            ood_samples, ood_labels = generate_outliers(self.queue, input_index=self.KNN_index, negative_samples=negative_samples,
                                                        ID_points_num=self.ID_points_num, K=self.K, select=self.select,
                                                        sampling_ratio=1.0, pic_nums=self.pick_nums, depth=self.feat_dim,
                                                        text_anchors=text_anchors, cls_mask=cls_mask)

        elif self.mode == 'ours':
            if anchors is None:
                raise ValueError("gen_ood: mode 'ours' requires `anchors`")
            negative_samples = self.new_dis.rsample((self.sample_from,)).to(self.device)
            ood_samples, ood_labels, candidates = \
                generate_outliers_ours(self.queue.float(), input_index=self.KNN_index, negative_samples=negative_samples,
                                       ID_points_num=self.ID_points_num, K=self.K, select=self.select,
                                       sampling_ratio=1.0, pic_nums=self.pick_nums, depth=self.feat_dim,
                                       text_anchors=anchors.float())

        else:
            raise ValueError(f"gen_ood: unknown mode {self.mode!r}")

        ood_samples = torch.cat(ood_samples).to(device)
        if normed:
            ood_samples = F.normalize(ood_samples, p=2, dim=1)
        ood_labels = torch.tensor(ood_labels).to(device)

        if ret_cand:
            return ood_samples, ood_labels, candidates
        return ood_samples, ood_labels