in mico/model/train.py [0:0]
def train_mico(subprocess_index, model):
"""This is the main training function.
There is an inner loop for updating `q_z`.
The outer loop is both optimizing MICO and finetuning BERT.
Parameters
----------
subprocess_index : int
The index of the sub-process.
model : MutualInfoCotrain object
This is the MICO model we are going to train.
"""
start = timer()
hparams = model.hparams
data = QueryDocumentsPair(train_folder_path=hparams.train_folder_path, test_folder_path=hparams.test_folder_path,
is_csv_header=hparams.is_csv_header, val_ratio=hparams.val_ratio)
if subprocess_index == 0: # The logging only happens in the first sub-process.
logging.info("Number of queries in each dataset: Train / Val / Test = %d / %d / %d" %
(len(data.train_dataset), len(data.val_dataset), len(data.test_dataset)))
starting_epoch = 1
resume = hparams.resume
if resume:
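        # Prefer the most recent iteration-level checkpoint (model_current_iter.pt);
        # fall back to the most recent epoch-level checkpoint (model_current_epoch.pt).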
load_path_suffix = '/model_current_iter.pt'
load_path = hparams.model_path + load_path_suffix
if not os.path.isfile(load_path):
load_path_suffix = '/model_current_epoch.pt'
load_path = hparams.model_path + load_path_suffix
if not os.path.isfile(load_path):
logging.info("No previous training model file found in the path %s" % load_path)
raise ValueError("Although you set --resume, there is no model that we can load.")
        model.load(suffix=load_path_suffix, subprocess_index=subprocess_index)  # Load the checkpoint selected above (iteration-level or, as a fallback, epoch-level).
model.to(subprocess_index)
resume_iteration = model.resume_iteration
starting_epoch = model.resume_epoch
logging.info('Resume training from iteration %d' % resume_iteration)
hparams = model.hparams
else:
resume_iteration = 0
model.to(subprocess_index)
    # find_unused_parameters=True is needed because, when we update `q_z`,
    # many other parameters in MICO and BERT receive no gradients and are not updated.
model = DistributedDataParallel(model, device_ids=[subprocess_index], find_unused_parameters=True)
if subprocess_index == 0:
logging.info("{} params in this whole model.".format(sum([p.numel() for p in model.parameters()])))
tb_logging = SummaryWriter(hparams.model_path + '/log')
# initialize optimizers
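    # Parameters are partitioned by name: `q_z` parameters (which also get a separate
    # inner-loop optimizer below), all MICO parameters (`p_z_y` and `q_z`), and BERT
    # parameters (left untouched when hparams.bert_fix is set).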
mico_param_list = []
bert_param_list = []
mico_q_z_param_list = []
    for name, param in model.named_parameters():
if 'q_z' in name:
mico_q_z_param_list.append(param)
if 'p_z_y' in name or 'q_z' in name:
mico_param_list.append(param)
elif 'bert' in name:
bert_param_list.append(param)
else:
raise ValueError("%s is not in MICO or in BERT. What is this parameter?" % name)
params = [{"name": "MICO", "params": mico_param_list}]
if not hparams.bert_fix:
params.append({"name": "BERT", "params": bert_param_list, "lr": hparams.lr_bert})
optimizer = torch.optim.Adam(params, lr=hparams.lr)
optimizers = [optimizer]
lr_schedulers = []
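    # Constant learning-rate schedule with a linear warmup over hparams.num_warmup_steps.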
for optimizer in optimizers:
lr_schedulers.append(get_constant_schedule_with_warmup(optimizer=optimizer,
num_warmup_steps=hparams.num_warmup_steps))
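    # Gradient clipping is applied per parameter group with the same threshold (hparams.clip).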
gradient_clippers = [(mico_param_list, hparams.clip)]
if not hparams.bert_fix:
gradient_clippers.append((bert_param_list, hparams.clip))
    # We also obtain test_loader so that we can check whether performance on the test set
    # is similar to performance on the validation set. We do not make any decision based on
    # test performance; the best model is selected via validation performance.
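    # is_shuffle_train=True shuffles the training batches; is_get_test=True additionally
    # returns the test loader.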
train_loader, val_loader, test_loader = data.get_loaders(
hparams.batch_size, hparams.num_workers,
is_shuffle_train=True, is_get_test=True)
best_val_perf = float('-inf')
forward_sum = {'loss': 0}
num_steps = 0
bad_epochs = 0
try:
        for epoch in range(starting_epoch, hparams.epochs + 1):
model.train()
            epoch_start = timer()
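            # When resuming mid-epoch, skip the epochs that were already completed and offset
            # the batch counter so that it starts from the first batch not yet trained on.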
if epoch <= resume_iteration / len(train_loader):
continue
elif epoch == (resume_iteration // len(train_loader) + 1):
starting_iteration = resume_iteration % len(train_loader) + 1
else:
starting_iteration = 0
for batch_num, batch in enumerate(train_loader):
batch_num += starting_iteration
if batch_num >= len(train_loader):
break
current_iteration_number = (epoch - 1) * len(train_loader) + batch_num
if hparams.early_quit and hparams.early_quit == batch_num:
                    logging.info('Stopping early at batch %d (hparams.early_quit).' % batch_num)
break
for optimizer in optimizers:
optimizer.zero_grad()
query, document = batch
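                # Extra monitoring output is requested from the forward pass only on logging steps.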
                is_monitor_forward = (batch_num % hparams.log_interval == 0)
                # Here we optimize only `q_z`, so that the cross entropy evaluated below uses an
                # up-to-date `q_z`.
                optimizer_prior = torch.optim.Adam(mico_q_z_param_list, lr=hparams.lr_prior)
                for _ in range(hparams.num_steps_prior):
                    optimizer_prior.zero_grad()
                    # Recompute the loss every step: the graph is freed after each backward(),
                    # and the gradients must reflect the updated `q_z` parameters.
                    q_z_loss = model.forward(document=document, query=query, is_monitor_forward=is_monitor_forward,
                                             forward_method="update_q_z", device=subprocess_index)
                    q_z_loss.backward()
                    nn.utils.clip_grad_norm_(mico_q_z_param_list, hparams.clip)
                    optimizer_prior.step()
# Now we can evaluate the cross entropy between query routing and document assignment
forward = model.forward(document=document, query=query, is_monitor_forward=is_monitor_forward, forward_method="update_all", device=subprocess_index)
forward['loss'].backward()
                for param_list, clip_value in gradient_clippers:
                    nn.utils.clip_grad_norm_(param_list, clip_value)
for optimizer, lr_scheduler in zip(optimizers, lr_schedulers):
optimizer.step()
lr_scheduler.step()
if subprocess_index == 0: # This part is only for monitoring the training process and writing logs.
                    for key in forward:
                        value = forward[key]
                        # Entries may be scalar tensors or plain Python numbers.
                        if torch.is_tensor(value):
                            value = value.detach().item()
                        forward_sum[key] = forward_sum.get(key, 0) + value
num_steps += 1
if batch_num % (20 * hparams.log_interval) == 0:
logging.info('Epoch\t | \t batch \t | \tlr_MICO\t | \tlr_BERT\t' + \
'\t'.join([' | {:8s}'.format(key) \
for key in forward_sum]))
                    # Default to NaN so the log line below still works when hparams.bert_fix is
                    # set and there is no "BERT" parameter group.
                    curr_lr_mico = curr_lr_bert = float('nan')
                    for param_group in optimizer.param_groups:
                        if param_group['name'] == 'MICO':
                            curr_lr_mico = param_group['lr']
                        elif param_group['name'] == 'BERT':
                            curr_lr_bert = param_group['lr']
if batch_num % hparams.log_interval == 0:
logging.info('{:d}\t | {:5d}/{:5d} \t | {:.3e} \t | {:.3e} \t'.format(
epoch, batch_num, len(train_loader), curr_lr_mico, curr_lr_bert) +
'\t'.join([' | {:8.2f}'.format(forward[key])
for key in forward_sum]))
for key in forward_sum:
tb_logging.add_scalar('Train/' + key, forward[key], current_iteration_number)
if current_iteration_number >= 0 and current_iteration_number % hparams.check_val_test_interval == 0:
part_batch_num = 1000
val_part_perf = evaluate(model, val_loader, num_batches=part_batch_num, device=subprocess_index)
test_part_perf = evaluate(model, test_loader, num_batches=part_batch_num, device=subprocess_index)
logging.info('Evaluate on {:d} batches | '.format(part_batch_num) +
'Iteration number {:10d}'.format(current_iteration_number) +
' '.join([' | val_partial {:s} {:8.2f}'.format(key, val_part_perf[key])
for key in val_part_perf]) +
' '.join([' | test_partial {:s} {:8.2f}'.format(key, test_part_perf[key])
for key in test_part_perf]))
for key in val_part_perf:
tb_logging.add_scalar('Val_partial/' + key, val_part_perf[key], current_iteration_number)
for key in test_part_perf:
tb_logging.add_scalar('Test_partial/' + key, test_part_perf[key], current_iteration_number)
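                        # The full optimizer objects are stored so that training can be resumed;
                        # model_best.pt (saved below) keeps only the model weights.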
torch.save(
{'hparams': hparams, 'state_dict': model.module.state_dict(), 'optimizer': optimizers,
'epoch': epoch, 'current_iteration_number': current_iteration_number},
hparams.model_path + '/model_current_iter.pt')
if math.isnan(forward_sum['loss']):
logging.info("Stopping epoch because loss is NaN")
break
tb_logging.flush()
if math.isnan(forward_sum['loss']):
logging.info("Stopping training session at ep %d batch %d because loss is NaN" % (epoch, batch_num))
break
if subprocess_index == 0:
val_perf = evaluate(model, val_loader, num_batches=10000, device=subprocess_index)
test_perf = evaluate(model, test_loader, num_batches=10000, device=subprocess_index)
logging.info('End of epoch {:3d}'.format(epoch) +
' '.join([' | train ave {:s} {:8.2f}'.format(key, forward_sum[key] / num_steps)
for key in ['loss', 'h_z_cond', 'h_z']]) +
' '.join([' | val {:s} {:8.2f}'.format(key, val_perf[key])
for key in val_perf]) +
' '.join([' | test {:s} {:8.2f}'.format(key, test_perf[key])
for key in test_perf]) +
                             ' | Time %s' % str(timedelta(seconds=round(timer() - epoch_start))))
for key in ['loss', 'h_z_cond', 'h_z']:
tb_logging.add_scalar('Train/ave_' + key, forward_sum[key] / num_steps, epoch)
for key in val_perf:
tb_logging.add_scalar('Val/' + key, val_perf[key], epoch)
for key in test_perf:
tb_logging.add_scalar('Test/' + key, test_perf[key], epoch)
for key in val_perf:
tb_logging.add_scalar('Val_in_steps/' + key, val_perf[key], current_iteration_number)
for key in test_perf:
tb_logging.add_scalar('Test_in_steps/' + key, test_perf[key], current_iteration_number)
forward_sum = {}
num_steps = 0
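                # Model selection and early stopping are driven by validation AUC.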
                val_auc = val_perf['AUC']
                if val_auc > best_val_perf:
                    best_val_perf = val_auc
bad_epochs = 0
logging.info('*** Best model so far, saving ***')
torch.save({'hparams': hparams, 'state_dict': model.module.state_dict(), 'epoch': epoch},
hparams.model_path + '/model_best.pt')
else:
bad_epochs += 1
logging.info("Bad epoch %d" % bad_epochs)
torch.save({'hparams': hparams, 'state_dict': model.module.state_dict(), 'optimizer': optimizers,
'epoch': epoch, 'current_iteration_number': current_iteration_number},
hparams.model_path + '/model_current_epoch.pt')
if bad_epochs > hparams.num_bad_epochs:
break
if epoch % hparams.save_per_num_epoch == 0:
torch.save({'hparams': hparams, 'state_dict': model.module.state_dict(), 'optimizer': optimizers,
'epoch': epoch, 'current_iteration_number': current_iteration_number},
hparams.model_path + '/model_opt_epoch' + str(epoch) + '.pt')
tb_logging.flush()
except KeyboardInterrupt:
if subprocess_index == 0:
logging.info('-' * 40)
logging.info('Exiting from training early')
if subprocess_index == 0:
tb_logging.close()
logging.info("Total training time: %s" % str(timedelta(seconds=round(timer() - start))))