def trainer()

in src/baselines/dnn.py [0:0]


def trainer(args, model, train_dataset, dev_dataset, test_dataset, log_fp=None):
    """Train ``model`` on ``train_dataset`` with periodic evaluation and checkpointing.

    This entry point owns the resources (TensorBoard writer, log file) and
    delegates the optimisation loop to ``_run_training`` so that both are
    released even when training raises.

    Args:
        args: Run configuration namespace. Attributes read here and in the
            helper include ``local_rank``, ``n_gpu``, ``device``, ``fp16``,
            ``fp16_opt_level``, ``per_gpu_train_batch_size``,
            ``gradient_accumulation_steps``, ``max_steps``,
            ``num_train_epochs``, ``learning_rate``, ``weight_decay``,
            ``adam_epsilon``, ``warmup_steps``, ``max_grad_norm``,
            ``logging_steps``, ``save_steps``, ``save_all``, ``delete_old``,
            ``evaluate_during_training``, ``max_metric_type``, ``tfrecords``,
            ``log_dir``, ``tb_log_dir`` and ``output_dir``.
            ``args.train_batch_size`` is always overwritten, and
            ``args.num_train_epochs`` is overwritten when ``max_steps > 0``.
        model: Module to train. It is called as
            ``model(inputs=batch[0], labels=batch[-1])`` and the first element
            of its output is used as the loss.
        train_dataset: Dataset used for optimisation.
        dev_dataset: Dataset passed to ``evaluate`` during training.
        test_dataset: Dataset passed to ``predict`` during training.
        log_fp: Optional writable text file for progress logs. When ``None``,
            ``<args.log_dir>/logs.txt`` is created and closed by this function;
            a caller-supplied handle is left open.

    Returns:
        A tuple ``(global_step, average_loss, max_metric_model_info)`` where
        ``average_loss`` is the accumulated loss divided by the number of
        optimizer steps (``0.0`` if no step was taken) and
        ``max_metric_model_info`` holds the logs of the best dev evaluation.

    Raises:
        ImportError: If ``args.fp16`` is set but apex is not installed.
    """
    tb_writer = None
    owns_log_fp = log_fp is None
    try:
        # Only the main process (single-process run or rank 0) writes TensorBoard events.
        if args.local_rank in [-1, 0]:
            tb_writer = SummaryWriter(log_dir=args.tb_log_dir)
        if owns_log_fp:
            # utf-8: the log receives JSON dumped with ensure_ascii=False.
            log_fp = open(os.path.join(args.log_dir, "logs.txt"), "w", encoding="utf-8")
        return _run_training(args, model, train_dataset, dev_dataset, test_dataset, log_fp, tb_writer)
    finally:
        if tb_writer is not None:
            tb_writer.close()
        # Close only a file this function opened; the caller owns a handle it passed in.
        if owns_log_fp and log_fp is not None:
            log_fp.close()


def _run_training(args, model, train_dataset, dev_dataset, test_dataset, log_fp, tb_writer):
    """Run the optimisation loop for ``trainer``; the caller manages ``log_fp`` and ``tb_writer``."""
    distributed = args.local_rank != -1
    is_main_process = args.local_rank in [-1, 0]

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    train_dataset_total_num = len(train_dataset)
    train_sampler = DistributedSampler(train_dataset) if distributed else RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
    train_batch_total_num = len(train_dataloader)
    print("Train dataset len: %d, batch num: %d" % (train_dataset_total_num, train_batch_total_num))

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (train_batch_total_num // args.gradient_accumulation_steps) + 1
    else:
        t_total = train_batch_total_num // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay).
    # Parameters whose name contains one of these substrings get no weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay},
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # Wrap exactly once, and only after the optional apex initialisation:
    # wrapping both before and after amp.initialize nests the parallel wrappers.
    # multi-gpu training
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank,
                                                          find_unused_parameters=True)

    # Train
    world_size = torch.distributed.get_world_size() if distributed else 1
    total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * world_size
    log_fp.write("***** Running training *****\n")
    logger.info("***** Running training *****")
    log_fp.write("Train Dataset Num examples = %d\n" % train_dataset_total_num)
    logger.info("Train Dataset  Num examples = %d", train_dataset_total_num)
    log_fp.write("Train Dataset Num Epochs = %d\n" % args.num_train_epochs)
    logger.info("Train Dataset Num Epochs = %d", args.num_train_epochs)
    log_fp.write("Train Dataset Instantaneous batch size per GPU = %d\n" % args.per_gpu_train_batch_size)
    logger.info("Train Dataset Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    log_fp.write("Train Dataset Total train batch size (w. parallel, distributed & accumulation) = %d\n" % total_train_batch_size)
    logger.info("Train Dataset Total train batch size (w. parallel, distributed & accumulation) = %d",
                total_train_batch_size)
    log_fp.write("Train Dataset Gradient Accumulation steps = %d\n" % args.gradient_accumulation_steps)
    logger.info("Train Dataset Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    log_fp.write("Train Dataset Total optimization steps = %d\n" % t_total)
    logger.info("Train Dataset Total optimization steps = %d", t_total)
    log_fp.write("#" * 50 + "\n")
    log_fp.flush()

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=not is_main_process)
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)

    # get_lr() is deprecated outside of scheduler.step(); prefer get_last_lr() when it exists.
    read_learning_rates = getattr(scheduler, "get_last_lr", scheduler.get_lr)

    max_metric_type = args.max_metric_type
    # NOTE(review): starting at 0 assumes the tracked metric is positive and higher-is-better.
    max_metric_value = 0
    max_metric_model_info = {}
    use_time = 0
    run_begin_time = time.time()
    real_epoch = 0

    for epoch in train_iterator:
        if distributed:
            # Without this the DistributedSampler reuses the same shuffle order every epoch.
            train_sampler.set_epoch(epoch)
        if args.tfrecords:
            epoch_iterator = tqdm(train_dataloader, total=train_batch_total_num, desc="Iteration", disable=not is_main_process)
        else:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not is_main_process)
        for step, batch in enumerate(epoch_iterator):
            begin_time = time.time()
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "inputs": batch[0],
                "labels": batch[-1]
            }
            outputs = model(**inputs)
            loss = outputs[0]

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                # The loss of each batch is divided by gradient_accumulation_steps
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            epoch_iterator.set_description("loss {}".format(round(loss.item(), 5)))

            tr_loss += loss.item()
            end_time = time.time()
            use_time += (end_time - begin_time)
            if (step + 1) % args.gradient_accumulation_steps == 0:
                # Enough micro-batches accumulated: clip, apply the update, reset gradients.
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1
                # evaluate per logging_steps steps
                update_flag = False
                if is_main_process and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        result = evaluate(args, model, dev_dataset, prefix="checkpoint-{}".format(global_step), log_fp=log_fp)
                        for key, value in result.items():
                            logs["eval_{}".format(key)] = value
                            if key == max_metric_type and max_metric_value < value:
                                max_metric_value = value
                                update_flag = True
                        logs["update_flag"] = update_flag
                        if update_flag:
                            # Snapshot the dev logs of the best checkpoint so far
                            # (taken before the test metrics are added below).
                            max_metric_model_info.update({"epoch": epoch + 1, "global_step": global_step})
                            max_metric_model_info.update(logs)
                        _, _, test_result = predict(args, model, test_dataset, "checkpoint-{}".format(global_step), log_fp=log_fp)
                        for key, value in test_result.items():
                            logs["test_{}".format(key)] = value
                    avg_iter_time = round(use_time / (args.gradient_accumulation_steps * args.logging_steps), 2)
                    logger.info("avg time per batch(s): %f\n" % avg_iter_time)
                    log_fp.write("avg time per batch (s): %f\n" % avg_iter_time)
                    use_time = 0
                    logs["learning_rate"] = read_learning_rates()[0]
                    # Mean loss over the optimizer steps since the previous logging point.
                    logs["loss"] = (tr_loss - logging_loss) / args.logging_steps
                    logs["epoch"] = epoch + 1
                    logging_loss = tr_loss

                    for key, value in logs.items():
                        if isinstance(value, dict):
                            for key1, value1 in value.items():
                                tb_writer.add_scalar(key + "_" + key1, value1, global_step)
                        else:
                            tb_writer.add_scalar(key, value, global_step)

                    logs_json = json.dumps({**logs, "step": global_step}, ensure_ascii=False)
                    logger.info(logs_json)
                    log_fp.write(logs_json + "\n")
                    log_fp.write("##############################\n")
                    log_fp.flush()
                # save checkpoint per save_steps steps
                if is_main_process and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    print("save dir: ", output_dir)
                    if args.save_all:
                        save_check_point(args, model, output_dir)
                    elif update_flag:
                        # Only keep checkpoints that improved the tracked dev metric at this step.
                        if args.delete_old:
                            # delete the old checkpoints
                            current_name = "checkpoint-{}".format(global_step)
                            for filename in os.listdir(args.output_dir):
                                if "checkpoint-" in filename and filename != current_name:
                                    shutil.rmtree(os.path.join(args.output_dir, filename))
                        save_check_point(args, model, output_dir)
            # Stop once exactly max_steps optimizer steps were taken; the scheduler
            # was built for t_total == max_steps, so an extra step would overrun it.
            if 0 < args.max_steps <= global_step:
                epoch_iterator.close()
                break
        real_epoch = epoch + 1
        if 0 < args.max_steps <= global_step:
            train_iterator.close()
            break
    run_end_time = time.time()
    log_fp.write(json.dumps(max_metric_model_info, ensure_ascii=False) + "\n")
    log_fp.write("##############################\n")
    # max(..., 1) / the global_step guard below: a run with zero epochs or zero
    # optimizer steps must not crash with ZeroDivisionError at the very end.
    avg_time_per_epoch = round((run_end_time - run_begin_time) / max(real_epoch, 1), 2)
    logger.info("Avg time per epoch(s, %d epoch): %f\n" % (real_epoch, avg_time_per_epoch))
    log_fp.write("Avg time per epoch(s, %d epoch): %f\n" % (real_epoch, avg_time_per_epoch))
    log_fp.flush()

    average_loss = tr_loss / global_step if global_step > 0 else 0.0
    return global_step, average_loss, max_metric_model_info