in lib/metrics.py [0:0]
def __call__(self, net, epoch, args, all_logs):
    """
    Evaluate the current state of the network without and with
    quantization, record the results and optionally store a checkpoint.

    Args:
        net: the network being trained; ``net.state_dict()`` is written
            into the checkpoint.
        epoch: current epoch number, printed and saved in the checkpoint.
        args: training arguments, saved verbatim in the checkpoint.
        all_logs: list of per-epoch log dicts. It must be non-empty: the
            validation results are stored in ``all_logs[-1]['val']`` and
            the whole list is saved in the checkpoint.

    Returns:
        The result object produced by ``evaluate`` for this epoch.
    """
    print("Validation at epoch %d" % epoch)
    res, score = evaluate(net, self.xq, self.xb, self.gt,
                          self.quantizers, self.best_key)
    all_logs[-1]['val'] = res

    # Path of the checkpoint written this call, or None if none was written.
    fname = None
    if self.checkpoint_dir:
        # also store current state of network + arguments
        fname = join(self.checkpoint_dir, "checkpoint.pth")
        print("storing", fname)
        torch.save({
            'state_dict': net.state_dict(),
            'epoch': epoch,
            'args': args,
            'logs': all_logs
        }, fname)

    # Higher score is better. Track the best score regardless of whether a
    # checkpoint directory is configured; the ".best" copy is only possible
    # when a checkpoint file was just written.
    if score > self.best_score:
        print("%s score improves (%g > %g), keeping as best" % (
            self.best_key, score, self.best_score))
        self.best_score = score
        if fname is not None:
            shutil.copyfile(fname, fname + '.best')
    return res