in compare_models.py [0:0]
def run_valid(self, n_valid=1000):
    """Run a validation pass over all models, checkpoint them, and decay LRs.

    For each model this (1) optionally writes a periodic checkpoint plus a
    separate "_best.pth" checkpoint whenever the validation loss improves,
    (2) halves the optimizer learning rate (floored at 1e-8) when the running
    training loss has plateaued, and (3) puts the model back into train mode.

    :param n_valid: number of validation samples passed to ``_run_inference``.
    """
    state = self.state
    args = self.args
    self.print_losses(42)  # 42 doubles as the custom logging level used below
    self.print_time(42)
    self.other_dl.valid()  # presumably switches the data loader to its validation split -- confirm
    logging.info("Running validation...")
    valid_ret = self._run_inference(
        "validation after_n_samples {} with_n_valid {}".format(
            self.state.n_samples, n_valid), n_valid)
    # Per-model training-loss history; indexed as lar[model_idx][step] below.
    lar = np.array(state.running_loss)
    for mi, model in enumerate(self.models):
        if args.save:  # TODO save on other metrics?
            logging.log(42, "saving {}".format(model.model_name))
            self.save(model, self.optimizers[mi])
            # Keep a separate best-so-far checkpoint when valid loss improves.
            if valid_ret['loss'][mi] < state.best_valid[mi]:
                self.save(model, self.optimizers[mi], "_best.pth")
                state.best_valid[mi] = valid_ret['loss'][mi]
        # NOTE(review): precedence here is (lr_decay and n_samples > 10000) or
        # small -- i.e. --small forces the decay path even with lr_decay off.
        # Looks like a deliberate debug hook (see the TODO below); confirm.
        if args.lr_decay and state.n_samples > 10000 or self.args.small:
            logging.info("Doing LR decay")
            # Plateau test: mean loss over the most recent 400 steps vs a
            # 400-step window from further back in the history.
            recent_mean_loss = lar[mi][-400:].mean()
            older_mean_loss = lar[mi][-1400:-1000].mean()
            if recent_mean_loss > 0.99 * older_mean_loss or self.args.small:  # TODO with lr_decay only
                with th.cuda.device(self.args.gpu):  # TODO wrap everything?
                    # Halve the LR through the optimizer state dict round-trip,
                    # floored at 1e-8 so it never decays to zero.
                    sd = self.optimizers[mi]['model'].state_dict()
                    sd['param_groups'][0]['lr'] = max(
                        sd['param_groups'][0]['lr'] * 0.5, 1E-8)
                    logging.log(42, "lr_decay {}".format(sd['param_groups'][0][
                        'lr']))
                    self.optimizers[mi]['model'].load_state_dict(sd)
        # Restore train mode (inference above presumably set eval mode -- confirm).
        model.train()
    # TODO plot debug outputs on the valid set
    # if args.debug < 3:
    #     th.save(self.models, tempfile.gettempdir()+ '/models_' + str(int(start_time)) + '.pth')
    self.__update_plots(valid_ret)