in automl21/scs_neural/experimentation/launcher.py [0:0]
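# Assumed context (not shown in this excerpt): numpy as np, torch, and
# SummaryWriter (from torch.utils.tensorboard) are presumably imported at
# the top of launcher.py, and self.itr is initialized elsewhere in the class.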
def _learn_batched(self):
    self.lowest_val_loss = -1
    if self.cfg.use_train_seed:
        torch.manual_seed(self.cfg.train_seed)
    rng_train_data = np.random.default_rng(self.cfg.train_data_seed)
    self.scs_neural.create_model(self.scs_problem)
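    # Set up logging and optimization: a TensorBoard writer (if enabled),
    # Adam over the acceleration model's parameters, and an optional cosine
    # learning-rate schedule spanning the full budget of model updates.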
    if self.cfg.log_tensorboard:
        self.sw = SummaryWriter(log_dir=self.cfg.tensorboard_dir)
    self.opt = torch.optim.Adam(
        self.scs_neural.accel.parameters(),
        lr=self.cfg.lr, betas=(self.cfg.beta1, self.cfg.beta2))
    if self.cfg.cosine_lr_decay:
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.opt, self.cfg.num_model_updates)
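    # Scale and cache all train/validation/test instances once up front, so
    # the training loop only selects already-preprocessed instances; the
    # no_grad block keeps autograd out of the held-out preprocessing.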
    self.multi_instance = self.scs_neural.scale_and_cache_all_instances(
        self.scs_problem.instances, use_scaling=self.cfg.scs.use_problem_scaling,
        scale=self.cfg.scs.scale, rho_x=self.cfg.scs.rho_x
    )
    with torch.no_grad():
        self.val_multi_instance = self.scs_neural.scale_and_cache_all_instances(
            self.scs_validate_problem.instances, use_scaling=self.cfg.scs.use_problem_scaling,
            scale=self.cfg.scs.scale, rho_x=self.cfg.scs.rho_x
        )
        self.test_multi_instance = self.scs_neural.scale_and_cache_all_instances(
            self.scs_test_problem.instances, use_scaling=self.cfg.scs.use_problem_scaling,
            scale=self.cfg.scs.scale, rho_x=self.cfg.scs.rho_x
        )
    self._reset_diffu_counts()
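    # Main training loop: one model update per iteration, each on a fresh
    # mini-batch of instances drawn without replacement by the dedicated
    # data-sampling RNG.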
    while self.itr < self.cfg.num_model_updates:
        sampled_ids = rng_train_data.choice(len(self.scs_problem.instances),
                                            size=self.cfg.train_batch_size,
                                            replace=False)
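        # Re-solve until the solver reports a usable loss, giving up after
        # 10000 attempts. The batch is not resampled here, so each retry
        # reruns the solver on the same sampled_ids.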
        num_tries = 0
        while True:
            num_tries += 1
            curr_multi_instance = self.scs_neural.select_instances(
                self.multi_instance, sampled_ids)
            soln_neural, metrics, diffu_counts, loss_available = self.scs_neural.solve(
                curr_multi_instance,
                max_iters=self.cfg.num_iterations_train,
                alpha=self.cfg.scs.alpha
            )
            if loss_available:
                break
            if num_tries > 10000:
                raise RuntimeError("Unable to find feasible train samples")
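        # Reduce the per-instance losses to a scalar (index_nans presumably
        # flags NaN entries that _compute_loss excluded), then take a
        # standard backward pass with optional gradient-norm clipping.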
        losses = [soln_neural[i]['loss'] for i in range(self.cfg.train_batch_size)]
        loss, index_nans = self._compute_loss(losses)
        self.loss_meter.update(loss.item())
        self._update_diffu_counts(diffu_counts)
        self.opt.zero_grad()
        loss.backward()
        if self.cfg.clip_gradients:
            torch.nn.utils.clip_grad_norm_(
                self.scs_neural.accel.parameters(),
                self.cfg.max_gradient)
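        # Periodically log the accelerator's internals, but only when the
        # batch produced no NaN losses; then apply the optimizer step (and
        # the LR scheduler step, if cosine decay is enabled).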
        if self.itr % self.cfg.test_freq == 0:
            if len(index_nans) == 0:
                if self.cfg.log_tensorboard and hasattr(self.scs_neural.accel, 'log'):
                    self.scs_neural.accel.log(self.sw, self.itr)
        self.opt.step()
        if self.cfg.cosine_lr_decay:
            self.scheduler.step()
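        # Periodic evaluation: the validation pass (which presumably feeds
        # self.val_loss_meter) runs first, then test and train results are
        # plotted with iteration tags and per-1000-iteration directories.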
        if self.itr % self.cfg.test_freq == 0:
            print("Loss: ", loss.item())
            self._plot_test_results(
                n_iter=self.cfg.num_iterations_eval,
                dataset_type='validate'
            )
            test_results = self._plot_test_results(
                n_iter=self.cfg.num_iterations_eval, tag=f'{self.itr:06d}',
                dir_tag=f'{self.itr // 1000:03d}'
            )
            train_results = self._plot_train_results(
                n_iter=self.cfg.num_iterations_eval, tag=f'{self.itr:06d}',
                dir_tag=f'{self.itr // 1000:03d}'
            )
            if len(test_results) > 0 and len(train_results) > 0:
                self.plot_aggregate_results(
                    test_results, train_results,
                    tag=f'{self.itr:06d}',
                    dir_tag=f'{self.itr // 1000:03d}'
                )
            self._reset_diffu_counts()
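        # Periodic checkpointing: always refresh latest.pt, and overwrite
        # best_model.pt whenever the running validation loss improves
        # (-1 doubles as the "no best yet" sentinel).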
        if self.itr % self.cfg.save_freq == 0:
            torch.save(self, 'latest.pt')
            if self.val_loss_meter.avg < self.lowest_val_loss or \
                    self.lowest_val_loss == -1:
                torch.save(self, 'best_model.pt')
                self.lowest_val_loss = self.val_loss_meter.avg
        self.itr += 1