in uimnet/measures/mog.py [0:0]
def estimate(self, train_loader):
    """Fit one mean/covariance pair per class from the training features.

    Runs ``self.algorithm.get_features`` over the batches of ``train_loader``
    without gradient tracking and collects features and labels on the CPU.
    When ``utils.is_distributed()`` is true, both tensors are first gathered
    across processes with ``utils.all_gather``. For every distinct label it
    then stores:

    * ``self.mus[label]``: the mean of that class's feature rows;
    * ``self.covs[label]``: the result of ``self.estimate_cov`` for those rows;
    * ``self.counts[label]``: the number of rows of that class.

    Finally ``self.N`` is set to the sum of all per-class counts.

    Args:
        train_loader: iterable of batches. Each batch is indexed with ``'x'``
            (input passed to ``get_features``) and ``'y'`` (labels).

    Raises:
        AssertionError: if the number of distinct labels collected differs
            from ``self.algorithm.num_classes``.
    """
    _to_device = functools.partial(utils.to_device, device=self.algorithm.device)
    collected = collections.defaultdict(list)
    utils.message('Collecting logits and features')
    with torch.no_grad():
        for i, batch in enumerate(train_loader):
            batch = utils.apply_fun(_to_device, batch)
            x, y = batch['x'], batch['y']
            _feats = self.algorithm.get_features(x).detach()
            _y = y.detach()
            if utils.is_distributed():
                # Labels must be gathered from ``_y`` (not ``_feats``) so they
                # remain row-aligned with the gathered features.
                _feats = torch.cat(utils.all_gather(_feats), dim=0)
                _y = torch.cat(utils.all_gather(_y), dim=0)
            collected['features'].append(_feats.cpu())
            collected['y'].append(_y.cpu())
            # In debug mode, stop after the first four batches (i = 0..3).
            if __DEBUG__ and i > 2:
                break
    collected = {k: torch.cat(v, dim=0) for k, v in collected.items()}
    features, labels = collected['features'], collected['y']

    num_classes = self.algorithm.num_classes
    all_classes = labels.unique()
    utils.message(f'{type(self.algorithm)}:{len(all_classes)}, {num_classes}')
    assert len(all_classes) == num_classes, (
        f'expected {num_classes} classes, found {len(all_classes)}')
    for label in all_classes:
        # Assumes ``labels`` is 1-D and row-aligned with ``features`` -- TODO confirm.
        X = features[labels == label]
        mu_hat = X.mean(dim=0)  # presumably shape (D,) if features are (N, D)
        cov_hat = self.estimate_cov(X, mu_hat, cov_estimator=self.cov_estimator)
        key = int(label)
        self.mus[key] = mu_hat
        self.covs[key] = cov_hat
        # Number of samples in this class; ``len(label)`` on the 0-d label
        # tensor would raise TypeError.
        self.counts[key] = len(X)
    self.N = sum(self.counts.values())