# train_one_epoch — from EAIEvaluation/HiTUT/hitut_train/custom_train.py


def train_one_epoch(model, optimizer, data_loader, counts, loss_weights=None):
    """Run one training epoch over ``data_loader``, updating ``model`` in place.

    Args:
        model: network whose forward pass returns
            ``(type_loss, arg_loss, mask_loss, navi_losses)``; must expose
            training flags on ``model.args`` (``use_bert``, ``weigh_loss``,
            ``auxiliary_loss_navi``, ``train_level``).
        optimizer: torch optimizer, stepped once per (non-skipped) batch.
        data_loader: iterable yielding batch dicts with a ``'batch_type'``
            field whose first element names the task type.
        counts: mutable bookkeeping dict with keys ``'epoch'``,
            ``'iter_<task>'``, ``'dlen_<task>'``, ``'start_time'`` and an
            optional ``'scheduler'``; per-task iteration counters are
            incremented in place.
        loss_weights: dict mapping task_type -> tensor of learnable
            log-variance weights; required when ``model.args.weigh_loss``
            is set, otherwise unused.
    """
    model.train()

    logging.info('Training [Level: %s]' % (model.args.train_level))
    epoch_start_time = time.time()
    logging.info('************ epoch: %d ************' %(counts['epoch']))
    for batch in tqdm(data_loader):
        # When using BERT, randomly drop ~half the batches to cut epoch cost.
        if model.args.use_bert and random.random() > 0.5:
            continue

        task_type = batch['batch_type'][0]

        # Mask loss only applies to manipulation tasks; auxiliary navigation
        # losses only to navigation tasks (and only when enabled by flag).
        enable_mask = 'mani' in task_type
        enable_navi_aux = 'navi' in task_type and model.args.auxiliary_loss_navi

        type_loss, arg_loss, mask_loss, navi_losses = model(batch)
        navi_loss = sum(navi_losses.values()) if enable_navi_aux else 0

        if model.args.weigh_loss:
            # Uncertainty-style weighting: sum_i loss_i * exp(-s_i) + 0.5 * s_i,
            # where s_i are learnable log-variances in loss_weights[task_type].
            task_weights = loss_weights[task_type]
            w = [torch.exp(-s) for s in task_weights]
            loss = (w[0] * type_loss + w[1] * arg_loss + w[2] * mask_loss
                    + w[3] * navi_loss + 0.5 * task_weights.sum())
            # Record the unweighted sum for comparable logging across modes.
            record_loss = (type_loss + arg_loss + mask_loss + navi_loss).item()
        else:
            loss = type_loss + arg_loss + mask_loss + navi_loss
            record_loss = loss.item()

        iter_num = counts['iter_%s'%task_type]
        local_iter_num = counts['iter_%s'%task_type] % counts['dlen_%s'%task_type]

        # Log ~30 times per epoch; clamp to 1 so tiny datasets still log.
        log_every = max(counts['dlen_%s'%task_type] // 30, 1)
        if local_iter_num % log_every == 0:
            lr = optimizer.param_groups[0]["lr"]
            mask_str = 'mask: %.4f |'%mask_loss.item() if enable_mask else ''
            navi_str = 'vis: %.4f |rea: %.4f |prog: %.4f |'%(navi_losses['visible'].item(),
                navi_losses['reached'].item(), navi_losses['progress'].item()) if enable_navi_aux else ''
            logging.info('[%8s iter%4d] loss total: %.4f |type: %.4f |arg: %.4f |%s%slr: %.1e'%(
                task_type, local_iter_num, record_loss, type_loss.item(), arg_loss.item(), mask_str, navi_str, lr))

            writer.add_scalar('train_loss/%s/total'%task_type, record_loss, iter_num)
            writer.add_scalar('train_loss/%s/type'%task_type, type_loss.item(), iter_num)
            writer.add_scalar('train_loss/%s/arg'%task_type, arg_loss.item(), iter_num)
            if enable_mask:
                writer.add_scalar('train_loss/low/mask', mask_loss.item(), iter_num)
            if enable_navi_aux:
                for k, v in navi_losses.items():
                    writer.add_scalar('train_loss/low/%s'%k, v.item(), iter_num)
            if model.args.weigh_loss:
                writer.add_scalars('weights/%s'%task_type, {'type': 1/w[0], 'arg': 1/w[1], 'mask': 1/w[2]}, iter_num)
                # FIX: the original used `dlen//30 * 6` here, which is 0 and
                # raises ZeroDivisionError whenever dlen < 30; reuse the
                # clamped interval instead (identical for dlen >= 30).
                if local_iter_num % (log_every * 6) == 0:
                    w_str = '(' + ' '.join(['%.3f'%v.item() for v in w]) + ')'
                    logging.info('%s loss weights 1/var: (type, arg, mask)=%s'%(task_type, w_str))
            writer.flush()

        counts['iter_%s'%task_type] += 1

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if 'scheduler' in counts:
            counts['scheduler'].step()

    et = '%.1fm'%((time.time() - epoch_start_time)/60)
    tt = time.time() - counts['start_time']
    tt = '%dh%dm'%(tt//3600, tt//60%60)
    logging.info('[%s] epoch %d finished (epoch time: %s | total time: %s)' %(model.args.train_level,
        counts['epoch'], et, tt))