def __call__()

in uimnet/workers/mog.py [0:0]


  def __call__(self, mog_cfg, train_cfg, Algorithm, dataset):
    """Fit a per-class Gaussian mixture on the features of a trained model.

    Rebuilds the data split and the algorithm described by ``train_cfg``,
    restores the trained weights from ``train_cfg.output_dir``, extracts
    features for every sample of the ``'train'`` split, adds one Gaussian per
    class via ``self.algorithm.gaussian_mixture.add_gaussian_from_data`` and
    saves the algorithm state back to ``train_cfg.output_dir``.

    Args:
      mog_cfg: Config of this worker (reads ``experiment.device``,
        ``experiment.num_workers``, ``dataset.batch_size`` and ``eps``).
      train_cfg: Config of the finished training run (reads ``output_dir``,
        ``dataset.splits_props``, ``dataset.seed`` and ``algorithm.*``).
      Algorithm: Algorithm class to instantiate and restore.
      dataset: Dataset wrapped by ``datasets.SplitDataNode``.

    Returns:
      dict with keys ``'data'`` (always ``None``), ``'mog_cfg'``,
      ``'train_cfg'``, ``'elapsed_time'`` (wall-clock seconds) and
      ``'status'`` (``'done'``).

    Raises:
      RuntimeError: on rank 0 / non-distributed runs when the ``train.done``
        trace is missing, or when the ``'train'`` split yields no samples.
      AssertionError: when the collected labels do not cover all
        ``num_classes`` classes.
    """
    start_time = time.time()
    self.mog_cfg = mog_cfg
    self.train_cfg = train_cfg
    self.Algorithm = Algorithm
    self.dataset = dataset

    self.setup(mog_cfg)
    utils.message(mog_cfg)

    if utils.is_not_distributed_or_is_rank0():
      # The mixture is fitted on top of the trained weights, so refuse to run
      # before the training worker has left its completion trace.
      if not utils.trace_exists('train.done', dir_=train_cfg.output_dir):
        raise RuntimeError('Training not finished')
      utils.write_trace('mog.running', dir_=train_cfg.output_dir)

    utils.message('Instantiating data node')
    # Rebuild the split with the training run's proportions and seed so the
    # 'train' split used below is the one the model was trained with.
    self.datanode = datasets.SplitDataNode(
      dataset=dataset,
      transforms=datasets.TRANSFORMS,
      splits_props=train_cfg.dataset.splits_props,
      seed=train_cfg.dataset.seed)
    num_classes = self.datanode.splits['train'].num_classes

    utils.message('Instantiating algorithm')
    device = mog_cfg.experiment.device
    self.algorithm = Algorithm(num_classes=num_classes,
                               arch=train_cfg.algorithm.arch,
                               device=device,
                               use_mixed_precision=train_cfg.algorithm.use_mixed_precision,
                               seed=train_cfg.algorithm.seed,
                               sn=train_cfg.algorithm.sn,
                               sn_coef=train_cfg.algorithm.sn_coef,
                               sn_bn=train_cfg.algorithm.sn_bn)

    self.algorithm.initialize()
    utils.message(self.algorithm)
    self.algorithm.load_state(train_cfg.output_dir, map_location=device)
    utils.message(f'Algorithm gaussian mixture attribute {self.algorithm.gaussian_mixture}')

    # Put both the data pipeline and the model in evaluation mode before
    # extracting features.
    self.datanode.eval()
    self.algorithm.eval()
    eval_loader = self.datanode.get_loader('train',
                                            batch_size=mog_cfg.dataset.batch_size,
                                            shuffle=False,
                                            pin_memory='cuda' in device,
                                            num_workers=mog_cfg.experiment.num_workers)

    utils.message('Collecting logits and features')
    to_device = functools.partial(utils.to_device, device=device)
    feature_batches = []
    label_batches = []
    with torch.no_grad():
      for batch in eval_loader:
        batch = utils.apply_fun(to_device, batch)
        # Keep features AND labels on the CPU: accumulating them on the
        # accelerator wastes memory, and the per-class masks built from the
        # labels must live on the same device as the features they index.
        feature_batches.append(self.algorithm.get_features(batch['x']).cpu())
        label_batches.append(batch['y'].cpu())

    if not feature_batches:
      raise RuntimeError('No samples were collected from the train split')

    # Only this worker's batches are concatenated; nothing is gathered from
    # other workers here despite the log message below.
    utils.message('Concatenating accross workers')
    features = torch.cat(feature_batches, dim=0)
    labels = torch.cat(label_batches, dim=0)

    all_classes = labels.unique()
    utils.message(f'{type(self.algorithm)}:{len(all_classes)}, {num_classes}')
    # Explicit raise instead of `assert` so the check survives `python -O`:
    # every class needs samples to get its own Gaussian.
    if len(all_classes) != num_classes:
      raise AssertionError(
        f'Found {len(all_classes)} distinct labels in the train split, '
        f'expected {num_classes}')

    for label in all_classes:
      # NOTE(review): `label` is a 0-d CPU tensor here; confirm that
      # add_gaussian_from_data does not rely on it being on `device`.
      self.algorithm.gaussian_mixture.add_gaussian_from_data(
        features[labels == label], label, eps=mog_cfg.eps)

    utils.message('Serialiazing estimated mixture of gaussians.')
    if utils.is_not_distributed_or_is_rank0():
      # NOTE(review): the mixture is not written separately, so this assumes
      # save_state persists `gaussian_mixture` with the algorithm state.
      self.algorithm.save_state(train_cfg.output_dir)
      utils.write_trace('mog.done', dir_=train_cfg.output_dir)

    utils.message('Mixture of gaussians estimation completed.')

    return {'data': None,
            'mog_cfg': mog_cfg,
            'train_cfg': train_cfg,
            'elapsed_time': time.time() - start_time,
            'status': 'done'}