def train_mico()

in mico/model/train.py [0:0]


def train_mico(subprocess_index, model):
    """This is the main training function.

    There is an inner loop for updating `q_z`.
    The outer loop is both optimizing MICO and finetuning BERT.

    Parameters
    ----------
    subprocess_index : int
        The index of the sub-process. 
    model : MutualInfoCotrain object
        This is the MICO model we are going to train.

    """
    start = timer()

    hparams = model.hparams
    data = QueryDocumentsPair(train_folder_path=hparams.train_folder_path, test_folder_path=hparams.test_folder_path, 
                              is_csv_header=hparams.is_csv_header, val_ratio=hparams.val_ratio)
    if subprocess_index == 0: # The logging only happens in the first sub-process.
        logging.info("Number of queries in each dataset:    Train / Val / Test = %d / %d / %d" %
                     (len(data.train_dataset), len(data.val_dataset), len(data.test_dataset)))

    starting_epoch = 1
    resume = hparams.resume
    if resume:
        load_path_suffix = '/model_current_iter.pt'
        load_path = hparams.model_path + load_path_suffix
        if not os.path.isfile(load_path):
            load_path_suffix = '/model_current_epoch.pt'
            load_path = hparams.model_path + load_path_suffix
            if not os.path.isfile(load_path):
                logging.info("No previous training model file found in the path %s" % load_path)
                raise ValueError("Although you set --resume, there is no model that we can load.")

        model.load(suffix=load_path_suffix, subprocess_index=subprocess_index) # The suffix can also be set to load the checkpoint of the most recent epoch instead.
        model.to(subprocess_index)

        resume_iteration = model.resume_iteration
        starting_epoch = model.resume_epoch
        logging.info('Resume training from iteration %d' % resume_iteration)
        hparams = model.hparams
    else:
        resume_iteration = 0
        model.to(subprocess_index)

    # find_unused_parameters=True is needed because, when we update `q_z`,
    # many other parameters in MICO and BERT receive no gradients and are not updated.
    model = DistributedDataParallel(model, device_ids=[subprocess_index], find_unused_parameters=True)
    if subprocess_index == 0:
        logging.info("{} params in this whole model.".format(sum([p.numel() for p in model.parameters()])))
        tb_logging = SummaryWriter(hparams.model_path + '/log')

    # initialize optimizers
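    # The parameters fall into three (overlapping) groups:
    #   - mico_q_z_param_list: only the `q_z` parameters, updated by the per-batch
    #     inner-loop optimizer (`optimizer_prior`) below;
    #   - mico_param_list: all MICO parameters (`p_z_y` and `q_z`);
    #   - bert_param_list: the BERT parameters, left out of the optimizer when hparams.bert_fix is set.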
    mico_param_list = []
    bert_param_list = []
    mico_q_z_param_list = []
    for name, param in model.named_parameters():
        if 'q_z' in name:
            mico_q_z_param_list.append(param)
        if 'p_z_y' in name or 'q_z' in name:
            mico_param_list.append(param)
        elif 'bert' in name:
            bert_param_list.append(param)
        else:
            raise ValueError("%s is not in MICO or in BERT. What is this parameter?" % name)

    params = [{"name": "MICO", "params": mico_param_list}]
    if not hparams.bert_fix:
        params.append({"name": "BERT", "params": bert_param_list, "lr": hparams.lr_bert})
    optimizer = torch.optim.Adam(params, lr=hparams.lr)
    optimizers = [optimizer]
    lr_schedulers = []
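    # get_constant_schedule_with_warmup ramps each learning rate up linearly for
    # hparams.num_warmup_steps steps and keeps it constant afterwards.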
    for optimizer in optimizers:
        lr_schedulers.append(get_constant_schedule_with_warmup(optimizer=optimizer,
                                                               num_warmup_steps=hparams.num_warmup_steps))

    gradient_clippers = [(mico_param_list, hparams.clip)]
    if not hparams.bert_fix:
        gradient_clippers.append((bert_param_list, hparams.clip))

    # We also obtain test_loader to check whether performance on the test set
    # tracks performance on the validation set.
    # No decision is made based on test performance;
    # the best model is selected by validation performance.
    train_loader, val_loader, test_loader = data.get_loaders(
        hparams.batch_size, hparams.num_workers,
        is_shuffle_train=True, is_get_test=True)
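    # Bookkeeping for logging and model selection: the best validation AUC so far,
    # running sums of the monitored forward outputs, the number of batches they
    # cover, and a count of consecutive epochs without validation improvement.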
    best_val_perf = float('-inf')
    forward_sum = {'loss': 0}
    num_steps = 0
    bad_epochs = 0

    try:
        for epoch in range(starting_epoch, hparams.epochs + 1):
            model.train()
            epoch_start = timer()
            if epoch <= resume_iteration / len(train_loader):
                continue
            elif epoch == (resume_iteration // len(train_loader) + 1):
                starting_iteration = resume_iteration % len(train_loader) + 1
            else:
                starting_iteration = 0
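            # When resuming mid-epoch, offset `batch_num` so that the global iteration
            # counter (and the logging/checkpointing cadence) continues from where the
            # previous run stopped; the loader itself still starts from its first batch.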
            for batch_num, batch in enumerate(train_loader):
                batch_num += starting_iteration
                if batch_num >= len(train_loader):
                    break
                current_iteration_number = (epoch - 1) * len(train_loader) + batch_num
                if hparams.early_quit and hparams.early_quit == batch_num:
                    logging.info('Stopping early at hparams.early_quit')
                    break
                for optimizer in optimizers:
                    optimizer.zero_grad()
                query, document = batch

                is_monitor_forward = (batch_num % hparams.log_interval == 0)

                # Here we optimize only `q_z`, so the cross entropy evaluated later is a better estimate.
                q_z_loss = model.forward(document=document, query=query, is_monitor_forward=is_monitor_forward, forward_method="update_q_z", device=subprocess_index)

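                # A fresh Adam optimizer over the `q_z` parameters only; it takes
                # hparams.num_steps_prior inner steps on `q_z_loss` before the joint
                # MICO/BERT update below.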
                optimizer_prior = torch.optim.Adam(mico_q_z_param_list, lr=hparams.lr_prior)
                for _ in range(hparams.num_steps_prior):
                    optimizer_prior.zero_grad()
                    q_z_loss.backward()
                    nn.utils.clip_grad_norm_(mico_q_z_param_list, hparams.clip)
                    optimizer_prior.step()

                # Now we can evaluate the cross entropy between query routing and document assignment
                forward = model.forward(document=document, query=query, is_monitor_forward=is_monitor_forward, forward_method="update_all", device=subprocess_index)

                forward['loss'].backward()

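                # Clip gradients for MICO (and for BERT unless it is fixed), then step
                # the optimizer and its warmup schedule.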
                for params, clip in gradient_clippers:
                    nn.utils.clip_grad_norm_(params, clip)

                for optimizer, lr_scheduler in zip(optimizers, lr_schedulers):
                    optimizer.step()
                    lr_scheduler.step()

                if subprocess_index == 0: # This part is only for monitoring the training process and writing logs.
                    for key in forward:
                        value = forward[key]
                        if torch.is_tensor(value):
                            value = value.detach().item()
                        forward_sum[key] = forward_sum.get(key, 0) + value
                    num_steps += 1

                    if batch_num % (20 * hparams.log_interval) == 0:
                        logging.info('Epoch\t | \t batch \t | \tlr_MICO\t | \tlr_BERT\t' + \
                                     '\t'.join([' | {:8s}'.format(key) \
                                                for key in forward_sum]))

                    curr_lr_mico = curr_lr_bert = float('nan')  # lr_BERT is logged as nan when BERT is fixed
                    for param_group in optimizer.param_groups:
                        if param_group['name'] == 'MICO':
                            curr_lr_mico = param_group['lr']
                        if param_group['name'] == 'BERT':
                            curr_lr_bert = param_group['lr']

                    if batch_num % hparams.log_interval == 0:
                        logging.info('{:d}\t | {:5d}/{:5d} \t | {:.3e} \t | {:.3e} \t'.format(
                            epoch, batch_num, len(train_loader), curr_lr_mico, curr_lr_bert) +
                                     '\t'.join([' | {:8.2f}'.format(forward[key])
                                                for key in forward_sum]))
                        for key in forward_sum:
                            tb_logging.add_scalar('Train/' + key, forward[key], current_iteration_number)

                    if current_iteration_number >= 0 and current_iteration_number % hparams.check_val_test_interval == 0:
                        part_batch_num = 1000
                        val_part_perf = evaluate(model, val_loader, num_batches=part_batch_num, device=subprocess_index)
                        test_part_perf = evaluate(model, test_loader, num_batches=part_batch_num, device=subprocess_index)

                        logging.info('Evaluate on {:d} batches | '.format(part_batch_num) +
                                     'Iteration number {:10d}'.format(current_iteration_number) +
                                     ' '.join([' | val_partial {:s} {:8.2f}'.format(key, val_part_perf[key])
                                               for key in val_part_perf]) +
                                     ' '.join([' | test_partial {:s} {:8.2f}'.format(key, test_part_perf[key])
                                               for key in test_part_perf]))

                        for key in val_part_perf:
                            tb_logging.add_scalar('Val_partial/' + key, val_part_perf[key], current_iteration_number)
                        for key in test_part_perf:
                            tb_logging.add_scalar('Test_partial/' + key, test_part_perf[key], current_iteration_number)

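                        # Periodic "current iteration" checkpoint; this is the file that
                        # the --resume path looks for first.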
                        torch.save(
                            {'hparams': hparams, 'state_dict': model.module.state_dict(), 'optimizer': optimizers,
                             'epoch': epoch, 'current_iteration_number': current_iteration_number},
                            hparams.model_path + '/model_current_iter.pt')

                    if math.isnan(forward_sum['loss']):
                        logging.info("Stopping epoch because loss is NaN")
                        break

                    tb_logging.flush()

            if math.isnan(forward_sum['loss']):
                logging.info("Stopping training session at ep %d batch %d because loss is NaN" % (epoch, batch_num))
                break

            if subprocess_index == 0:
                val_perf = evaluate(model, val_loader, num_batches=10000, device=subprocess_index)
                test_perf = evaluate(model, test_loader, num_batches=10000, device=subprocess_index)

                logging.info('End of epoch {:3d}'.format(epoch) +
                             ' '.join([' | train ave {:s} {:8.2f}'.format(key, forward_sum[key] / num_steps)
                                       for key in ['loss', 'h_z_cond', 'h_z']]) +
                             ' '.join([' | val {:s} {:8.2f}'.format(key, val_perf[key])
                                       for key in val_perf]) +
                             ' '.join([' | test {:s} {:8.2f}'.format(key, test_perf[key])
                                       for key in test_perf]) +
                             ' | Time %s' % str(timedelta(seconds=round(timer() - epoch_start))))

                for key in ['loss', 'h_z_cond', 'h_z']:
                    tb_logging.add_scalar('Train/ave_' + key, forward_sum[key] / num_steps, epoch)
                for key in val_perf:
                    tb_logging.add_scalar('Val/' + key, val_perf[key], epoch)
                for key in test_perf:
                    tb_logging.add_scalar('Test/' + key, test_perf[key], epoch)
                for key in val_perf:
                    tb_logging.add_scalar('Val_in_steps/' + key, val_perf[key], current_iteration_number)
                for key in test_perf:
                    tb_logging.add_scalar('Test_in_steps/' + key, test_perf[key], current_iteration_number)

                forward_sum = {}
                num_steps = 0

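                # Model selection: keep the checkpoint with the best validation AUC, and
                # count a "bad epoch" whenever validation does not improve; training stops
                # after more than hparams.num_bad_epochs consecutive bad epochs.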
                val_auc = val_perf['AUC']
                if val_auc > best_val_perf:
                    best_val_perf = val_auc
                    bad_epochs = 0
                    logging.info('*** Best model so far, saving ***')
                    torch.save({'hparams': hparams, 'state_dict': model.module.state_dict(), 'epoch': epoch},
                               hparams.model_path + '/model_best.pt')
                else:
                    bad_epochs += 1
                    logging.info("Bad epoch %d" % bad_epochs)

                torch.save({'hparams': hparams, 'state_dict': model.module.state_dict(), 'optimizer': optimizers,
                            'epoch': epoch, 'current_iteration_number': current_iteration_number},
                           hparams.model_path + '/model_current_epoch.pt')

                if bad_epochs > hparams.num_bad_epochs:
                    break

                if epoch % hparams.save_per_num_epoch == 0:
                    torch.save({'hparams': hparams, 'state_dict': model.module.state_dict(), 'optimizer': optimizers,
                                'epoch': epoch, 'current_iteration_number': current_iteration_number},
                               hparams.model_path + '/model_opt_epoch' + str(epoch) + '.pt')

                tb_logging.flush()

    except KeyboardInterrupt:
        if subprocess_index == 0:
            logging.info('-' * 40)
            logging.info('Exiting from training early')

    if subprocess_index == 0:
        tb_logging.close()
        logging.info("Total training time: %s" % str(timedelta(seconds=round(timer() - start))))