in uimnet/workers/embeddings_extractor.py [0:0]
def __call__(self, cfg, encoder, datanode):
    """Run ``encoder`` over every batch of ``datanode`` and collect the outputs.

    The configuration is first written to ``<cfg.sweep_dir>/cfg_<rank>``.
    The encoder is moved to ``cfg.experiment.device``; both the encoder and
    the datanode are switched to eval mode, and the loader is iterated
    under ``torch.no_grad()`` without shuffling.

    Args:
        cfg: Configuration object. This method reads ``sweep_dir``,
            ``experiment.rank``, ``experiment.device``,
            ``experiment.num_workers`` and ``dataset.batch_size``.
        encoder: Module called on ``batch['x']`` to produce the embeddings.
        datanode: Object exposing ``eval()`` and ``get_loader(...)``; the
            batches it yields are indexed with ``'x'``, ``'y'`` and
            ``'index'``.

    Returns:
        dict: ``'embeddings'``, ``'targets'`` and ``'indices'``, each mapped
        to a CPU tensor obtained by concatenating the per-batch values along
        dim 0. The dict is empty when the loader yields no batch.
    """
    self.setup(cfg)

    # Dump configuration file
    output_path = Path(cfg.sweep_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    # Save by path: OmegaConf opens and writes the file itself, so there is
    # no need to hold a second handle open on the same path.
    OmegaConf.save(cfg, f=str(output_path / f'cfg_{cfg.experiment.rank}'))

    self.encoder = encoder.to(cfg.experiment.device)
    if utils.is_distributed():
        # NOTE(review): the DDP wrapper is bound to the local name only; the
        # forward pass below goes through self.encoder (the unwrapped
        # module). Left unchanged on purpose -- confirm whether the wrapper
        # was meant to be stored on self before changing it, since that
        # would alter the collective-communication behaviour of inference.
        encoder = torch.nn.parallel.DistributedDataParallel(
            encoder, device_ids=[cfg.experiment.device])
    self.datanode = datanode

    _to_device = functools.partial(utils.to_device,
                                   device=cfg.experiment.device)

    self.encoder.eval()
    self.datanode.eval()

    # Must be a plain bool. The previous trailing comma turned this into a
    # one-element tuple, which is always truthy, so pinned memory was
    # requested even when running on CPU.
    _pin_memory = 'cuda' in cfg.experiment.device
    loader = self.datanode.get_loader(batch_size=cfg.dataset.batch_size,
                                      shuffle=False,
                                      pin_memory=_pin_memory,
                                      num_workers=cfg.experiment.num_workers)

    out = collections.defaultdict(list)
    with torch.no_grad():
        for i, batch in enumerate(loader):
            batch = utils.apply_fun(_to_device, batch)
            out['embeddings'].append(self.encoder(batch['x']))
            out['targets'].append(batch['y'])
            out['indices'].append(batch['index'])
            # In debug mode stop after the first four batches (i == 3).
            if i > 2 and __DEBUG__:
                break

    # Concatenate the per-batch tensors and move the results to host memory.
    out = {k: torch.cat(v, dim=0).detach().cpu() for k, v in out.items()}
    return out