in models/feat_pool.py [0:0]
def calc_maha_score(self, samples: torch.Tensor, force_calc: bool = True) -> torch.Tensor:
    """Return each sample's Mahalanobis distance to its nearest class mean.

    Class means come from the feature pool ``self.queue``. One covariance,
    shared by all classes, is estimated from every stored feature after
    centring each feature on its own class mean.

    Pool layout expected (checked by the assert below):
        ``self.queue``: shape ``(nc, queue_len, ndim)``; unused slots are
        all-zero rows.
        ``self.class_ptr``: the number of filled slots per class.

    Args:
        samples: features of shape ``(n, ndim)``.
        force_calc: controls when the shared covariance is re-estimated.
            If True, or if no inverse covariance has been cached yet, the
            covariance is re-estimated from the current pool. Its inverse
            is cached as ``self.maha_cov_inv`` with shape ``(1, ndim, ndim)``.
            If False and a cached inverse exists, that inverse is reused;
            the class means are still recomputed from the current pool.

    Returns:
        Tensor of shape ``(n,)`` holding, per sample, the minimum over classes
        of ``f^T @ Cov^-1 @ f`` with ``f = sample - class_mean``. This is the
        squared form; no square root is taken. Higher means the sample is
        farther from every class. Classes that currently hold no features are
        skipped.
    """
    nc = self.class_num
    counts = self.class_ptr.view(nc, 1)
    # A slot counts as filled iff its feature row is not entirely zero.
    valid_mask = (self.queue != 0).any(dim=-1)  # shape(nc, queue_len)
    assert (valid_mask.sum(dim=1, keepdim=True) == counts).all()
    # Empty slots are zero rows, so summing the whole queue sums only the
    # stored features. clamp(min=1) avoids 0/0 = NaN for empty classes; those
    # classes are excluded from the final min below.
    mean_embed_id = self.queue.sum(dim=1) / counts.clamp(min=1)  # shape(nc, ndim)

    if force_calc or not hasattr(self, 'maha_cov_inv'):
        # Class-agnostic (tied) covariance of class-centred features, biased (/N).
        X = (self.queue - mean_embed_id[:, None, :])[valid_mask]  # shape(x, ndim)
        covariance = (X.T @ X) / len(X)  # shape(ndim, ndim)
        # Small ridge so the matrix stays invertible for degenerate features.
        covariance += 1e-4 * torch.eye(len(covariance), dtype=X.dtype, device=X.device)
        # Keep the leading batch dim of 1 so the cached attribute's shape is unchanged.
        self.maha_cov_inv = covariance.inverse()[None, :, :]
    maha_cov_inv = self.maha_cov_inv

    # diff[i, c] = samples[i] - mean_c -> shape(n, nc, ndim)
    diff = samples[:, None, :] - mean_embed_id[None, :, :]
    # Row-wise quadratic form f^T @ Cov^-1 @ f. This multiplies by the single
    # shared inverse (broadcast over the batch) instead of expanding it into
    # one (ndim, ndim) matrix per (sample, class) pair.
    maha_dist = ((diff @ maha_cov_inv) * diff).sum(dim=-1)  # shape(n, nc)
    # An empty class has no real mean; make sure it can never be the nearest.
    empty_cls = counts.view(nc) == 0
    maha_dist = maha_dist.masked_fill(empty_cls[None, :], float('inf'))
    return maha_dist.min(dim=1).values