in trainers/catex.py [0:0]
def calc_loss(self, logits, label, image=None, image_features=None, text_features=None, return_norm=False):
    """Total training loss: (1) ID/OOD-prompt classification, (2) prompt
    perturbation regularization, (3) outlier exposure."""
    nb, ncls = logits.shape
    # 1. classification
    if self.cfg.TRAINER.OOD_PROMPT and self.epoch >= self.cfg.TRAINER.START_EPOCH \
            and self.is_large_ID() and self.id_pool.ready():
        if not return_norm:
            image_features = F.normalize(image_features, p=2, dim=1)
        ood_text_features = self.model.get_text_features(ood_prompt=True)
        if self.cfg.TRAINER.OOD_PROMPT_CE_LOSS:
            # extend the logits with the OOD-prompt classes and mask out the OOD
            # column that corresponds to the ground-truth ID class
            logits = self.model.get_logits(image_features, torch.cat((text_features, ood_text_features)))
            logits[torch.arange(nb), label + ncls] = -10.  # generally -inf
    loss = 2. * self.ce_criterion(logits, label)
    # 2. prompt perturbation
    perturbed_text_features = None
    if self.cfg.TRAINER.OOD_PROMPT and self.cfg.TRAINER.ID_PERTURB_LOSS and self.epoch >= self.cfg.TRAINER.START_EPOCH:
        with torch.no_grad():
            perturbed_text_features = self.model.get_text_features(perturb=random.choice(perturb_methods))
        loss += 0.1 * self.model.calc_prompt_loss(text_features, perturbed_text_features)
    # 3. outlier exposure
    assert text_features is not None
    if self.is_large_ID():
        if self.id_pool.ready() and self.cfg.TRAINER.OOD_PROMPT and self.epoch >= self.cfg.TRAINER.START_EPOCH:
            # restrict OOD synthesis to the classes present in this batch when the
            # batch is smaller than the ID feature queue
            if logits.size(0) < self.id_pool.queue.size(0):
                cls_mask = torch.unique(label).cpu()
            else:
                cls_mask = None
            if self.cfg.TRAINER.OOD_ANCHOR:
                if self.cfg.TRAINER.ID_PERTURB_LOSS and False:  # branch disabled by `and False`
                    perturbed_text_features = self.model.get_text_features(perturb=random.choice(perturb_methods))
                    logit_scale = self.model.logit_scale.exp()
                    id_pos_sim = (image_features * text_features[label]).sum(dim=-1) * logit_scale
                    id_neg_sim = (image_features * perturbed_text_features[label]).sum(dim=-1) * logit_scale
                    loss += F.cross_entropy(torch.stack((id_pos_sim, id_neg_sim), dim=1),
                                            torch.zeros((len(id_pos_sim),), dtype=torch.long, device=self.device)) * 0.5
                elif perturbed_text_features is None:
                    with torch.no_grad():
                        perturbed_text_features = self.model.get_text_features(perturb=random.choice(perturb_methods))
                text_anchors = torch.stack((text_features, perturbed_text_features), dim=1).detach()
            else:
                text_anchors = None
            ood_features, ood_labels = self.id_pool.gen_ood(anchors=text_anchors, device=self.device, cls_mask=cls_mask)
            if self.cfg.TRAINER.OOD_OE_LOSS:
                # outlier exposure on synthesized features: drive their predictions
                # towards a uniform distribution over the ID classes
                ood_logits = self.model.get_logits(ood_features, text_features, logit_scale=1.)
                loss += 0.5 * -(ood_logits.mean(1) - torch.logsumexp(ood_logits, dim=1)).mean()
            if self.cfg.TRAINER.OOD_PROMPT:
                # ood_text_features = self.model.get_text_features(ood_prompt=True)
                if self.cfg.TRAINER.OOD_PROMPT_ORTH:
                    # orthogonality regularization across the multiple OOD prompts of each class
                    assert self.cfg.TRAINER.OOD_PROMPT_NUM > 1
                    all_ood_text_features = self.model.get_all_ood_text_features()
                    # (1000, 5, 512) x (1000, 512, 5) -> (1000, 5, 5)
                    ood_sim_matrix = torch.bmm(all_ood_text_features, all_ood_text_features.transpose(1, 2))
                    ood_text_num = ood_sim_matrix.shape[-1]
                    zrange = torch.arange(ood_text_num)
                    ood_sim_matrix[:, zrange, zrange] = 0.  # ignore self-similarity on the diagonal
                    loss += 0.1 * ood_sim_matrix.mean()
                if self.cfg.TRAINER.OOD_PROMPT_CE_LOSS:
                    # classify synthesized outliers into the OOD-prompt classes while
                    # masking the ID column of the matching class
                    ood_logits = self.model.get_logits(ood_features,
                                                       torch.cat((ood_text_features, text_features)))
                    ood_logits[torch.arange(ood_logits.shape[0]), ood_labels + ncls] = -10.  # generally -inf
                    loss += 0.5 * self.ce_criterion(ood_logits, ood_labels)
                if self.cfg.TRAINER.OOD_PROMPT_MARGIN_LOSS:
                    if self.cfg.TRAINER.OOD_PROMPT_MARGIN_SOFT_LOSS:
                        logit_scale = self.model.logit_scale.exp()
                    else:
                        logit_scale = 1.
                    # an ID image should match its ID prompt better than its OOD prompt;
                    # a synthesized outlier should match its OOD prompt better than its ID prompt
                    id_pos_sim = (image_features * text_features[label]).sum(dim=-1) * logit_scale
                    id_neg_sim = (image_features * ood_text_features[label]).sum(dim=-1) * logit_scale
                    ood_pos_sim = (ood_features * ood_text_features[ood_labels]).sum(dim=-1) * logit_scale
                    ood_neg_sim = (ood_features * text_features[ood_labels]).sum(dim=-1) * logit_scale
                    # id_pos_sim = (image_features @ text_features.T).max(dim=-1)[0] * logit_scale
                    # id_neg_sim = (image_features @ ood_text_features.T).max(dim=-1)[0] * logit_scale
                    # ood_pos_sim = (ood_features @ ood_text_features.T).max(dim=-1)[0] * logit_scale
                    # ood_neg_sim = (ood_features @ text_features.T).max(dim=-1)[0] * logit_scale
                    if self.cfg.TRAINER.OOD_PROMPT_MARGIN_SOFT_LOSS:
                        # soft margin: two-way cross-entropy over (positive, negative) similarity pairs
                        loss += F.cross_entropy(torch.stack((id_pos_sim, id_neg_sim), dim=1),
                                                torch.zeros((len(id_pos_sim),), dtype=torch.long, device=self.device)) + \
                                F.cross_entropy(torch.stack((ood_pos_sim, ood_neg_sim), dim=1),
                                                torch.zeros((len(ood_pos_sim),), dtype=torch.long, device=self.device))
                    else:
                        # hard margin: hinge on the similarity gap
                        loss += (id_neg_sim - id_pos_sim).relu().mean() + (ood_neg_sim - ood_pos_sim).relu().mean()
    else:
        # ID set is not large: draw real auxiliary outliers from the OOD loader
        ood_data, _ = next(iter(self.ood_loader))
        ood_data = ood_data.to(self.device)
        ood_logits = self.model(ood_data)  # / self.model.logit_scale.exp()
        loss += 0.1 * -(ood_logits.mean(1) - torch.logsumexp(ood_logits, dim=1)).mean()
    return loss
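
A note on the outlier-exposure terms above: `-(ood_logits.mean(1) - torch.logsumexp(ood_logits, dim=1)).mean()` equals the cross-entropy between `softmax(ood_logits)` and a uniform target over the ID classes, i.e. it pushes outlier predictions towards a flat distribution. A minimal standalone check of that identity (not part of trainers/catex.py; the tensor shapes are purely illustrative):

# Standalone sketch: verify the outlier-exposure term equals soft-target CE to uniform.
import torch
import torch.nn.functional as F

logits = torch.randn(8, 1000)                     # hypothetical batch of OOD logits
oe_term = -(logits.mean(1) - torch.logsumexp(logits, dim=1)).mean()

uniform = torch.full_like(logits, 1.0 / logits.size(1))
ce_to_uniform = F.cross_entropy(logits, uniform)  # soft-target CE (PyTorch >= 1.10)

assert torch.allclose(oe_term, ce_to_uniform, atol=1e-5)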