def estimate()

in uimnet/measures/mog.py [0:0]


  def estimate(self, train_loader):
    """Fit per-class feature statistics (mean, covariance, count) from ``train_loader``.

    Runs ``self.algorithm.get_features`` over every batch without gradients,
    gathers features and labels across workers when running distributed, moves
    everything to CPU, and then, for every distinct label ``c`` observed, stores:

      * ``self.mus[c]``    -- mean of the class's feature rows (``X.mean(dim=0)``);
      * ``self.covs[c]``   -- result of ``self.estimate_cov`` with
                              ``self.cov_estimator``;
      * ``self.counts[c]`` -- number of collected samples with label ``c``.

    ``self.N`` is set to the sum of all per-class counts. Keys are ``int`` labels.

    Args:
      train_loader: iterable of batches; each batch is a mapping with keys
        ``'x'`` (model inputs) and ``'y'`` (labels). Labels are assumed to be a
        1-D tensor aligned with the rows of the features -- TODO confirm.

    Raises:
      ValueError: if ``train_loader`` yields no batches.
      AssertionError: if the number of distinct labels differs from
        ``self.algorithm.num_classes``.
    """
    _to_device = functools.partial(utils.to_device, device=self.algorithm.device)

    collected = collections.defaultdict(list)
    utils.message('Collecting logits and features')
    with torch.no_grad():
      for i, batch in enumerate(train_loader):
        batch = utils.apply_fun(_to_device, batch)
        x, y = batch['x'], batch['y']
        _feats = self.algorithm.get_features(x).detach()
        _y = y.detach()
        if utils.is_distributed():
          # Gather the labels themselves (not the features) so that, after
          # concatenation, row k of the labels still matches row k of the
          # features on every worker.
          _feats = torch.cat(utils.all_gather(_feats), dim=0)
          _y = torch.cat(utils.all_gather(_y), dim=0)
        collected['features'] += [_feats.cpu()]
        collected['y'] += [_y.cpu()]
        if __DEBUG__ and i > 2:
          break

      if not collected:
        raise ValueError('train_loader yielded no batches; cannot estimate class statistics')

      collected = {k: torch.cat(v, dim=0) for k, v in collected.items()}
      features = collected['features']
      labels = collected['y']

      num_classes = self.algorithm.num_classes
      all_classes = labels.unique()
      utils.message(f'{type(self.algorithm)}:{len(all_classes)}, {num_classes}')
      # Every class must be represented, otherwise some class would have no
      # mean/covariance at all.
      assert len(all_classes) == num_classes, (
          f'found {len(all_classes)} distinct labels, expected {num_classes}')

      for y in all_classes:
        mask = torch.where(y == labels)
        X = features[mask]
        mu_hat = X.mean(dim=0)  # D
        cov_hat = self.estimate_cov(X, mu_hat, cov_estimator=self.cov_estimator)
        label = int(y)
        self.mus[label] = mu_hat
        self.covs[label] = cov_hat
        # Count the samples of this class; ``y`` is a 0-d tensor, so len(y)
        # would raise TypeError.
        self.counts[label] = len(X)
      self.N = sum(self.counts.values())