def calc_maha_score()

in models/feat_pool.py [0:0]


    def calc_maha_score(self, samples: torch.Tensor, force_calc=True):
        """Score samples by their minimum squared Mahalanobis distance to a class mean.

        ``self.queue`` is indexed below as (nc, K, ndim): K feature slots per
        class, where an all-zero row is treated as an empty slot. Class means
        are taken over the populated rows. A single class-agnostic (tied)
        covariance is estimated from every populated row, centred on its own
        class mean. Its inverse is cached on ``self.maha_cov_inv`` with shape
        (1, ndim, ndim).

        Args:
            samples: features to score, shape (ns, ndim).
            force_calc: when True (default), re-estimate the inverse covariance
                from the current pool. When False, reuse ``self.maha_cov_inv``
                if it already exists. Class means are always recomputed.

        Returns:
            Tensor of shape (ns,). For each sample it holds the smallest squared
            Mahalanobis distance (no square root is taken) to any class with at
            least one stored feature. Classes with no stored features are
            ignored; if no class has any, every score is ``inf``.
        """
        nc = self.class_num

        sample_num_per_cls = self.class_ptr.view(nc, 1)
        valid_mask = (self.queue != 0).any(dim=-1)  # (nc, K): True for populated slots
        assert (valid_mask.sum(dim=1, keepdim=True) == sample_num_per_cls).all(), \
            'class_ptr does not match the number of non-zero rows in queue'
        populated_cls = valid_mask.any(dim=1)  # (nc,)

        # Empty slots are all-zero, so summing the whole queue equals summing the
        # populated rows. Clamping the count avoids 0/0 = NaN for empty classes;
        # those classes are excluded from the final minimum below.
        mean_embed_id = self.queue.sum(dim=1) / sample_num_per_cls.clamp(min=1)  # (nc, ndim)

        if force_calc or not hasattr(self, 'maha_cov_inv'):
            centred = (self.queue - mean_embed_id[:, None, :])[valid_mask]  # (x, ndim)
            # Tied covariance; max(..., 1) keeps an empty pool from yielding NaN.
            covariance = (centred.T @ centred) / max(len(centred), 1)  # (ndim, ndim)
            # Small ridge term keeps the matrix invertible.
            covariance += 0.0001 * torch.eye(
                len(covariance), dtype=covariance.dtype, device=covariance.device)
            self.maha_cov_inv = covariance.inverse()[None, :, :]

        diff = samples[:, None, :] - mean_embed_id[None, :, :]  # (ns, nc, ndim)
        # Quadratic form f^T @ Cov^-1 @ f for every (sample, class) pair; matmul
        # broadcasts the (1, ndim, ndim) cached inverse over the sample axis.
        maha_dist = ((diff @ self.maha_cov_inv) * diff).sum(dim=-1)  # (ns, nc)
        # Empty classes must never be the nearest class.
        maha_dist = maha_dist.masked_fill(~populated_cls[None, :], float('inf'))
        return maha_dist.min(dim=1).values