in trainers/catex.py [0:0]
def perturb_prompt(self, method='none'):
    """Return the context vectors with one randomly chosen token per class perturbed.

    ``self.ctx`` is read as either a 2-D tensor ``(nctx, ndim)``, which is
    expanded to ``(self.n_cls, nctx, ndim)``, or a tensor that is already 3-D
    ``(ncls, nctx, ndim)``. For each of ``self.cfg.TRAINER.ID_PERTUR_NUM``
    rounds, one token index is drawn independently for every class and that
    token is replaced by ``src_coef * src + noise_coef * noise``, where
    ``src`` is the (detached) selected token.

    Supported ``method`` values:
        'none'      -- return ``self.ctx`` unchanged (no expansion is applied).
        'neg'       -- replace the token with its negation.
        'zero'      -- replace the token with zeros.
        'randn'     -- replace the token with standard-normal noise.
        'randn_add' -- add standard-normal noise to the token.
        'swap'      -- replace the token with the token selected for a
                       different class in the same round.

    Raises:
        AssertionError: if ``method`` is unknown or there are fewer than two
            context tokens.
        ValueError: if ``method == 'swap'`` is requested with fewer than two
            classes; no class would have a different class to swap with.
    """
    if method == 'none':
        return self.ctx

    # method -> [coefficient on the source token, coefficient on the noise]
    coef_dict = {
        'neg': [-1., 0.],
        'zero': [0., 0.],
        'randn': [0., 1.],
        'randn_add': [1., 1.],
        'swap': [0., 1.],
    }
    assert method in coef_dict, f"unknown perturbation method: {method!r}"
    src_coef, noise_coef = coef_dict[method]

    ctx = self.ctx
    if ctx.dim() == 2:
        ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
    ncls, nctx, _ = ctx.shape
    assert nctx > 1, "need more than one context token to perturb"

    cls_ind = torch.arange(ncls)
    for _ in range(self.cfg.TRAINER.ID_PERTUR_NUM):
        # perturb one prompt token for each class
        ctx_ind = torch.randint(0, nctx, size=(ncls,))

        # 1 everywhere except at the selected (class, token) positions.
        src_mask = ctx.new_ones((ncls, nctx, 1))
        src_mask[cls_ind, ctx_ind] = 0.

        # Detached: the replacement token carries no gradient back to ctx.
        src_ctx = ctx[cls_ind, ctx_ind].detach()

        if method == 'swap':
            if ncls < 2:
                # The only permutation of a single class is the identity, so
                # the rejection loop below could never terminate.
                raise ValueError(
                    "'swap' perturbation requires at least 2 classes, "
                    f"got {ncls}"
                )
            # Rejection-sample a permutation with no fixed points so that
            # every class receives a token from a different class.
            while True:
                rand_ind = torch.randperm(ncls)
                if (cls_ind != rand_ind).all():
                    break
            noise = src_ctx[rand_ind]
        else:
            # Drawn for every non-swap method (even when noise_coef is 0) to
            # keep the random stream identical across methods.
            noise = torch.randn_like(ctx[:, 0, :])

        perturb = torch.zeros_like(ctx)
        perturb[cls_ind, ctx_ind] = src_coef * src_ctx + noise_coef * noise
        ctx = ctx * src_mask + perturb * (1. - src_mask)
    return ctx