def perturb_prompt()

in trainers/catex.py [0:0]


    def perturb_prompt(self, method='none'):
        """Return a randomly perturbed copy of the learnable context vectors.

        For each of ``cfg.TRAINER.ID_PERTUR_NUM`` rounds, one context position
        is drawn uniformly at random per class. Its vector ``v`` is replaced by
        ``src_coef * v + noise_coef * noise``, with the coefficients chosen by
        ``method``:

        * ``'none'``      -- no perturbation; ``self.ctx`` is returned as-is.
        * ``'neg'``       -- ``-v``.
        * ``'zero'``      -- all zeros.
        * ``'randn'``     -- standard-normal noise.
        * ``'randn_add'`` -- ``v`` plus standard-normal noise.
        * ``'swap'``      -- the vector selected (in the same round) for a
          different class; requires at least two classes.

        Rounds are independent, so the same position may be perturbed more
        than once. Replacement values are built from a detached copy of ``v``,
        so perturbed positions pass no gradient back to ``self.ctx``. All other
        positions keep their gradient path.

        Args:
            method: One of the keys listed above.

        Returns:
            ``self.ctx`` unchanged for ``'none'``. Otherwise a tensor of shape
            ``(n_cls, n_ctx, dim)``; a 2-D (class-shared) ``self.ctx`` is
            first broadcast to ``self.n_cls`` classes.

        Raises:
            AssertionError: if ``method`` is unknown or there is only one
                context token.
            ValueError: if ``method == 'swap'`` and there is only one class
                (a class cannot swap with itself; previously this looped
                forever).
        """
        if method == 'none':
            return self.ctx

        # (src_coef, noise_coef) applied to the selected vector and the noise.
        coef_dict = {
            'neg': [-1., 0.], 'zero': [0., 0.], 'randn': [0., 1.], 'randn_add': [1., 1.], 'swap': [0., 1.]
        }
        assert method in coef_dict, f'unknown perturbation method: {method!r}'
        src_coef, noise_coef = coef_dict[method]

        ctx = self.ctx
        if ctx.dim() == 2:
            # Class-shared context: broadcast it so every class can be perturbed separately.
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
        ncls, nctx, _ = ctx.shape
        assert nctx > 1, 'perturb_prompt needs at least two context tokens'

        cls_ind = torch.arange(ncls)
        for _ in range(self.cfg.TRAINER.ID_PERTUR_NUM):
            # Perturb one randomly chosen context position for each class.
            ctx_ind = torch.randint(0, nctx, size=(ncls,))

            # 1 everywhere except at the chosen (class, position) pairs.
            src_mask = ctx.new_ones((ncls, nctx, 1))
            src_mask[cls_ind, ctx_ind] = 0.
            src_ctx = ctx[cls_ind, ctx_ind].detach()

            if method == 'swap':
                if ncls < 2:
                    # randperm(1) is always [0], so the derangement search below
                    # would never terminate.
                    raise ValueError("perturb_prompt(method='swap') needs at least two classes")
                # Rejection-sample a derangement so no class receives its own vector.
                while True:
                    rand_ind = torch.randperm(ncls)
                    if (rand_ind != cls_ind).all():
                        break
                noise = src_ctx[rand_ind]
            else:
                noise = torch.randn_like(ctx[:, 0, :])

            perturb = torch.zeros_like(ctx)
            perturb[cls_ind, ctx_ind] = src_coef * src_ctx + noise_coef * noise

            # Keep untouched positions (with gradient) and splice in the replacements.
            ctx = ctx * src_mask + perturb * (1. - src_mask)

        return ctx