in uimnet/workers/evaluator.py [0:0]
def __call__(self, model_dir, eval_cfg, train_cfg, Algorithm, Measure, partitions):
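  """Evaluate an OOD measure for a trained model.

  Rebuilds the Algorithm from the training config, loads its trained
  weights, computes `Measure` over the in/out partitions, and scores the
  resulting measurements with every metric in METRICS.
  """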
  self.model_dir = model_dir
  self.eval_cfg = copy.deepcopy(eval_cfg)
  self.train_cfg = copy.deepcopy(train_cfg)
  self.Algorithm = Algorithm
  self.partitions = partitions
  self.setup(eval_cfg)
  eval_cfg, train_cfg = [utils.guard(el) for el in (eval_cfg, train_cfg)]
  self.trace = f'ood_{Measure.__name__}'

  # Check that training completed before evaluating.
  path = Path(self.model_dir)
  if not utils.trace_exists('train.done', dir_=str(path)):
    utils.message('Train completion tracer missing!')
    return dict(status='missing', records=None)
  utils.message('Train completion tracer found.')
  utils.write_trace(f'{self.trace}.running', dir_=str(path))
  ##############
  ## Datasets ##
  ##############
  loaders_kwargs = dict(
    batch_size=train_cfg.dataset.batch_size,
    shuffle=False,
    pin_memory=('cuda' in eval_cfg.experiment.device),
    num_workers=eval_cfg.experiment.num_workers)
  loaders, datanodes = get_loaders_datanodes(partitions, train_cfg,
                                             loaders_kwargs=loaders_kwargs,
                                             seed=eval_cfg.dataset.seed)
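  # `loaders` and `datanodes` are keyed by (partition, split) tuples such as
  # ('train', 'in') and ('val', OUT), mirroring the `partitions` mapping.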
  for datanode in datanodes.values():
    datanode.eval()
  num_classes = partitions[('train', 'in')].num_classes
  ###############
  ## Algorithm ##
  ###############
  self.algorithm = Algorithm(num_classes=num_classes,
                             arch=train_cfg.algorithm.arch,
                             device=eval_cfg.experiment.device,
                             use_mixed_precision=train_cfg.algorithm.use_mixed_precision,
                             seed=train_cfg.algorithm.seed,
                             sn=train_cfg.algorithm.sn,
                             sn_coef=train_cfg.algorithm.sn_coef,
                             sn_bn=train_cfg.algorithm.sn_bn)
  utils.message(eval_cfg)
  self.algorithm.initialize()
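  # Restore the weights produced by the training run, mapping them onto the
  # evaluation device.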
  self.algorithm.load_state(train_cfg.output_dir,
                            map_location=eval_cfg.experiment.device)

  records = []
  self.algorithm.eval()
  with torch.no_grad():
    for temperature_mode in ['initial', 'learned']:
      self.algorithm.set_temperature(temperature_mode)
      measure = Measure(algorithm=self.algorithm)
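      # Fit the measure's statistics on in-distribution training data before
      # scoring the held-out partitions.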
      measure.estimate(loaders[('train', 'in')])
      measurements = collections.defaultdict(list)
      labels = collections.defaultdict(list)
      for (partition, split) in [('eval', 'in'),
                                 ('val', 'in'), ('val', OUT)]:
        key = (partition, split)
        for i, batch in enumerate(loaders[key]):
          x, y = batch['x'].cuda(), batch['y'].cuda()
          # if key == ('val', OUT):
          #   x = torch.zeros_like(x).normal_()
          _measurement = measure(x)
          _labels = y.view(-1, 1)
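          # Under distributed evaluation, gather per-rank results so every
          # rank holds the full set of measurements and labels.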
          if utils.is_distributed():
            _measurement = torch.cat(utils.all_gather(_measurement), dim=0)
            _labels = torch.cat(utils.all_gather(_labels), dim=0)
          measurements[key] += [_measurement.detach().cpu()]
          labels[key] += [_labels.detach().cpu()]
          if __DEBUG__ and i > 1:
            break
      measurements = {k: torch.cat(l) for k, l in measurements.items()}
      labels = {k: torch.cat(l) for k, l in labels.items()}
      if utils.is_not_distributed_or_is_rank0():
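        # Persist the raw measurements and labels next to the checkpoint so
        # they can be re-analysed without rerunning the forward passes.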
        torch.save(measurements, path / f'{Measure.__name__}_measurements.pth')
        torch.save(labels, path / f'{Measure.__name__}_labels.pth')
      # Evaluating metrics
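      # Each metric is conditioned on the held-out in-distribution
      # ('eval', 'in') measurements, then scores how well the measure
      # separates ('val', 'in') from ('val', OUT).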
      for Metric in METRICS:
        metric = Metric(measurements[('eval', 'in')])
        value = metric(measurements[('val', 'in')],
                       measurements[('val', OUT)])
        record = dict(metric=metric.__class__.__name__,
                      measure=Measure.__name__,
                      value=value,
                      temperature_mode=temperature_mode)
        record.update(utils.flatten_nested_dicts(train_cfg))
        records += [record]

  # Saving records
  if utils.is_not_distributed_or_is_rank0():
    save_path = path / f'{Measure.__name__}_results.pkl'
    # Remove any stale results file; shutil.rmtree only handles directories.
    save_path.unlink(missing_ok=True)
    with open(save_path, 'wb') as fp:
      pickle.dump(records, fp, protocol=pickle.HIGHEST_PROTOCOL)
    utils.message(pd.DataFrame.from_records(
      utils.apply_fun(utils.to_scalar, records)).round(4))
  utils.write_trace(f'{self.trace}.done', dir_=str(path))
  return records
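
# Hypothetical usage sketch. The names below (Evaluator, ERM, Entropy, and the
# pre-built cfg/partition objects) are illustrative assumptions, not confirmed
# uimnet APIs:
#
#   evaluator = Evaluator()
#   records = evaluator(model_dir='/checkpoints/erm_seed0',
#                       eval_cfg=eval_cfg, train_cfg=train_cfg,
#                       Algorithm=ERM, Measure=Entropy,
#                       partitions=partitions)
#   # records: one dict per (metric, temperature_mode) combination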