def train()

in model/train.py
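Note: the loop below anneals the generator temperature via get_fixed_temperature, which is defined elsewhere in the repository and not shown in this file. For orientation only, here is a minimal sketch of such an annealing schedule, assuming a RelGAN-style signature (temper_max, step, max_step, adapt) and a few common adaptation modes; the actual implementation may differ.

import numpy as np

def get_fixed_temperature(temper_max, step, max_step, adapt):
    """Hypothetical sketch, not the repository's implementation:
    anneal a value from 1.0 up to temper_max over max_step steps."""
    progress = step / max_step
    if adapt == "no":         # constant maximum
        return temper_max
    elif adapt == "lin":      # linear ramp
        return 1.0 + progress * (temper_max - 1.0)
    elif adapt == "exp":      # exponential ramp
        return temper_max ** progress
    elif adapt == "sigmoid":  # sigmoid ramp centered at mid-training
        return (temper_max - 1.0) / (1.0 + np.exp(-20.0 * (progress - 0.5))) + 1.0
    else:
        raise ValueError(f"unknown adapt mode: {adapt}")

In train() the returned value is used as beta and the generator temperature is set to 1.0 / beta, so a larger beta late in training corresponds to a lower (sharper) sampling temperature.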


def train():
    global train_step
    global best_val_nll

    log_train_loss = torch.tensor(0.0).float().to(device)
    log_grad_norm = torch.tensor(0.0).float().to(device)
    log_token_num = torch.tensor(0).to(device)

    # GAN-related running statistics
    log_gen_train_loss = torch.tensor(0.0).float().to(device)  # Generator loss
    log_gen_num = torch.tensor(0.0).float().to(device)

    log_dis_train_loss = torch.tensor(0.0).float().to(device)  # Discriminator loss
    log_dis_num = torch.tensor(0.0).float().to(device)

    dis_iterations = 0  # Num dis iters
    best_gen_val_loss = np.inf

    if cfg.DISCRIMINATOR.type != "Null" and cfg.DISCRIMINATOR.type != "":
        dis_iterator = dis_iter()

    log_start_time = time.time()

    mems = [None for _ in range(cfg.TRAIN.batch_chunk)]

    assert batch_size % cfg.TRAIN.batch_chunk == 0
    train_real_iter = train_iter()

    for batch, (data, target, reset_mems, batch_token_num, status_vec) in enumerate(
            train_real_iter
    ):
        beta = get_fixed_temperature(
            cfg.DISCRIMINATOR.beta_max,
            train_step,
            cfg.TRAIN.max_step,
            cfg.DISCRIMINATOR.adapt,
        )
        model.module.temperature = 1.0 / beta

        model.zero_grad()

        # Batch chunking: split each batch into cfg.TRAIN.batch_chunk micro-batches
        # and accumulate gradients across them.
        data_chunks = torch.chunk(data, cfg.TRAIN.batch_chunk, 1)
        target_chunks = torch.chunk(target, cfg.TRAIN.batch_chunk, 1)
        reset_mems_chunks = torch.chunk(reset_mems, cfg.TRAIN.batch_chunk, 0)
        if status_vec is not None:
            status_vec_chunks = torch.chunk(status_vec, cfg.TRAIN.batch_chunk, 1)
        for i in range(cfg.TRAIN.batch_chunk):

            data = data_chunks[i].contiguous()
            target = target_chunks[i].contiguous()
            reset_mems = reset_mems_chunks[i].contiguous()
            if status_vec is not None:
                status_vec = status_vec_chunks[i].contiguous()

            ret = model(data, target, reset_mems, "mle", mems[i], status_vec=status_vec)
            loss, mems[i] = ret["mle"], ret["mems"]

            loss = loss[target != dataset.vocab.pad_id]
            loss = loss.float().mean() / cfg.TRAIN.batch_chunk
            log_train_loss += (
                    loss.item()
                    * (target != dataset.vocab.pad_id).sum()
                    * cfg.TRAIN.batch_chunk
            )

            if cfg.TRAIN.use_mle:
                if args.fp16:
                    with amp.scale_loss(loss, optimizer, loss_id=1) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

        log_token_num += int(batch_token_num)

        if cfg.TRAIN.use_mle:
            if args.fp16:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), cfg.TRAIN.clip
                )
            else:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.module.generator.parameters(), cfg.TRAIN.clip
                )

            log_grad_norm += grad_norm
            optimizer.step()
            optimizer.zero_grad()

        # Train discriminator
        if train_step > cfg.DISCRIMINATOR.start_iter and (
                train_step % cfg.DISCRIMINATOR.dis_loss_freq == 0
        ):
            # TODO: dis training messes up the memory structure maintained during batch loading
            # (we need another dataloader for real data)
            if not cfg.DISCRIMINATOR.freeze_discriminator:
                for dis_iterations in range(cfg.DISCRIMINATOR.dis_steps):

                    try:
                        dis_data, _ = next(dis_iterator)
                    except StopIteration:
                        dis_iterator = dis_iter()
                        dis_data, _ = next(dis_iterator)

                    # Batch chunking for generator and discriminator
                    dis_data_chunks = torch.chunk(
                        dis_data, cfg.DISCRIMINATOR.batch_chunk, 1
                    )

                    if cfg.DISCRIMINATOR.type == "bert":
                        for idx, p in enumerate(
                                model.module.discriminator.parameters()
                        ):
                            if idx in model.module.discriminator.unfreeze_idx:
                                p.requires_grad = True
                    else:
                        for p in model.module.discriminator.parameters():
                            p.requires_grad = True

                    for i in range(cfg.DISCRIMINATOR.batch_chunk):
                        dis_data = dis_data_chunks[i].contiguous()
                        # Share the same mems with mle iter
                        ret = model(dis_data, None, None, "dis_loss")

                        dis_loss = ret["dis_loss"]
                        log_dis_train_loss += dis_loss.float().item()
                        dis_loss = (
                                dis_loss.float().mean() / cfg.DISCRIMINATOR.batch_chunk
                        )
                        if (
                                cfg.DISCRIMINATOR.type == "bert"
                                and "gp" in cfg.DISCRIMINATOR.BERT.loss_type
                        ) or (
                                cfg.DISCRIMINATOR.type == "cnn"
                                and "gp" in cfg.DISCRIMINATOR.CNN.loss_type
                        ):
                            gp_loss = ret["gp_loss"]
                            gp_loss = (
                                    gp_loss.float().mean() / cfg.DISCRIMINATOR.batch_chunk
                            )

                        log_dis_num += 1

                        if args.fp16:
                            with amp.scale_loss(
                                    dis_loss, dis_optimizer, loss_id=0
                            ) as scaled_dis_loss:
                                scaled_dis_loss.backward()
                        else:
                            if not cfg.DISCRIMINATOR.backprop_outside:
                                dis_loss.backward()
                                if (
                                        cfg.DISCRIMINATOR.type == "bert"
                                        and "gp" in cfg.DISCRIMINATOR.BERT.loss_type
                                ) or (
                                        cfg.DISCRIMINATOR.type == "cnn"
                                        and "gp" in cfg.DISCRIMINATOR.CNN.loss_type
                                ):
                                    gp_loss.backward()

                    # TODO: investigate discriminator training tricks (e.g. a different clip value)?
                    if args.fp16:
                        grad_norm = torch.nn.utils.clip_grad_norm_(
                            amp.master_params(dis_optimizer), cfg.TRAIN.clip
                        )
                    else:
                        grad_norm = torch.nn.utils.clip_grad_norm_(
                            model.module.discriminator.parameters(), cfg.TRAIN.clip,
                        )

                    dis_optimizer.step()
                    dis_optimizer.zero_grad()

            for p in model.module.discriminator.parameters():
                p.requires_grad = False

        if train_step > cfg.DISCRIMINATOR.start_iter and (
                train_step % cfg.DISCRIMINATOR.gen_loss_freq == 0
        ):

            # Train the generator; discriminator parameters remain frozen
            try:
                dis_data, _ = next(dis_iterator)
            except StopIteration:
                dis_iterator = dis_iter()
                dis_data, _ = next(dis_iterator)

            # Batch chunking for generator and discriminator
            dis_data_chunks = torch.chunk(dis_data, cfg.DISCRIMINATOR.batch_chunk, 1)

            for i in range(cfg.DISCRIMINATOR.batch_chunk):
                dis_data = dis_data_chunks[i].contiguous()

                update_D0 = train_step % cfg.PPO.dis_D_update_D0_freq == 0

                if ("ppo" in cfg.DISCRIMINATOR.BERT.loss_type
                        or "ppo" in cfg.DISCRIMINATOR.CNN.loss_type):
                    for p in model.module.dis_D.parameters():
                        p.requires_grad = True

                    # Use the same real batch and generate a new fake batch.
                    # Always backprop outside: the backward pass for dis_D is not run here,
                    # only gradient clipping and the optimizer step.
                    ret = model(dis_data, None, None, "classifier_loss")
                    torch.nn.utils.clip_grad_norm_(model.module.dis_D.parameters(), cfg.TRAIN.clip)
                    dis_D_optimizer.step()
                    dis_D_optimizer.zero_grad()

                    for p in model.module.dis_D.parameters():
                        p.requires_grad = False


                ret = model(dis_data, None, None, "gen_loss", update_D0=update_D0)

                gen_loss = ret["gen_loss"]
                log_gen_train_loss += gen_loss.float().item()

                gen_loss = gen_loss.float().mean() / cfg.DISCRIMINATOR.batch_chunk
                log_gen_num += 1

                if args.fp16:
                    with amp.scale_loss(gen_loss, optimizer, loss_id=1) as scaled_loss:
                        scaled_loss.backward()
                else:
                    if not cfg.DISCRIMINATOR.backprop_outside:
                        gen_loss.backward()

            if args.fp16:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), cfg.TRAIN.clip
                )
            else:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.module.generator.parameters(), cfg.TRAIN.clip
                )

            gen_optimizer.step()
            gen_optimizer.zero_grad()



        # step-wise learning rate annealing
        train_step += 1

        if cfg.TRAIN.scheduler in ["cosine", "constant", "dev_perf"]:
            # linear warmup stage
            if train_step < cfg.TRAIN.warmup_step:
                curr_lr = cfg.TRAIN.lr * train_step / cfg.TRAIN.warmup_step
                optimizer.param_groups[0]["lr"] = curr_lr
            else:
                if cfg.TRAIN.scheduler == "cosine":
                    scheduler.step()
        elif cfg.TRAIN.scheduler == "inv_sqrt":
            scheduler.step()

        if cfg.DISCRIMINATOR.type != "Null" and cfg.DISCRIMINATOR.type != "":
            if cfg.DISCRIMINATOR.gen_scheduler in ["cosine", "constant", "dev_perf"]:
                # linear warmup stage
                if train_step < cfg.DISCRIMINATOR.gen_warmup_step:
                    curr_gen_lr = (
                            cfg.DISCRIMINATOR.gen_lr
                            * train_step
                            / cfg.DISCRIMINATOR.gen_warmup_step
                    )
                    gen_optimizer.param_groups[0]["lr"] = curr_gen_lr
                else:
                    if cfg.DISCRIMINATOR.gen_scheduler == "cosine":
                        gen_scheduler.step()
            elif cfg.DISCRIMINATOR.gen_scheduler == "inv_sqrt":
                gen_scheduler.step()

            if cfg.DISCRIMINATOR.dis_scheduler in ["cosine", "constant", "dev_perf"]:
                # linear warmup stage
                if train_step < cfg.DISCRIMINATOR.dis_warmup_step:
                    curr_dis_lr = (
                            cfg.DISCRIMINATOR.dis_lr
                            * train_step
                            / cfg.DISCRIMINATOR.dis_warmup_step
                    )
                    dis_optimizer.param_groups[0]["lr"] = curr_dis_lr
                else:
                    if cfg.DISCRIMINATOR.dis_scheduler == "cosine":
                        dis_scheduler.step()
            elif cfg.DISCRIMINATOR.dis_scheduler == "inv_sqrt":
                dis_scheduler.step()

        if train_step % cfg.TRAIN.log_interval == 0:
            torch.distributed.all_reduce(log_train_loss)
            torch.distributed.all_reduce(log_grad_norm)
            torch.distributed.all_reduce(log_token_num)

            torch.distributed.all_reduce(log_gen_train_loss)
            torch.distributed.all_reduce(log_gen_num)

            torch.distributed.all_reduce(log_dis_train_loss)
            torch.distributed.all_reduce(log_dis_num)

            log_train_loss /= log_token_num
            log_grad_norm /= cfg.TRAIN.log_interval * num_gpus
            log_gen_train_loss = (
                log_gen_train_loss / log_gen_num
                if log_gen_num != 0
                else torch.tensor(0.0).float().to(device)
            )
            log_dis_train_loss = (
                log_dis_train_loss / log_dis_num
                if log_dis_num != 0
                else torch.tensor(0.0).float().to(device)
            )
            if args.local_rank == 0:
                elapsed = time.time() - log_start_time
                logging.info(
                    "Train Step {}/{}, lr={:f}, tokens/s={:.1f},"
                    " nll={:.4f}, ppl={:.2f}, grad norm={}, gen_loss={:5.4f}, dis_loss={:5.4f}".format(
                        train_step,
                        cfg.TRAIN.max_step,
                        optimizer.param_groups[0]["lr"],
                        log_token_num.item() / elapsed,
                        log_train_loss.item(),
                        math.exp(log_train_loss.item()),
                        log_grad_norm.item(),
                        log_gen_train_loss.item(),
                        log_dis_train_loss.item(),
                    )
                )

            log_train_loss[()] = 0
            log_grad_norm[()] = 0
            log_token_num[()] = 0

            log_gen_train_loss[()] = 0
            log_gen_num[()] = 0

            log_dis_train_loss[()] = 0
            log_dis_num[()] = 0

            log_start_time = time.time()

        if train_step % cfg.TRAIN.eval_interval == 0:
            eval_start_time = time.time()

            val_token_num, val_total_nll, val_metrics = evaluate(
                eval_iter=val_iter, dis_val_iter=None, mode="eval"
            )

            val_token_num_pt = torch.tensor(val_token_num).to(device)
            val_total_nll_pt = torch.tensor(val_total_nll / 10000.0).to(device)

            torch.distributed.all_reduce(val_token_num_pt)
            torch.distributed.all_reduce(val_total_nll_pt)

            val_token_num = val_token_num_pt.item()
            val_total_nll = val_total_nll_pt.item()

            val_nll = val_total_nll / (val_token_num / 10000.0)

            if args.local_rank == 0:
                logging.info(
                    "Eval step {}, time={}s, val nll={}, val ppl={}, #evaluated tokens={}, bleu={}, self_bleu={"
                    "}, class_acc={}".format(
                        train_step,
                        time.time() - eval_start_time,
                        val_nll,
                        math.exp(val_nll),
                        val_token_num,
                        val_metrics[0],
                        val_metrics[1],
                        val_metrics[2],
                    )
                )
            # Save the model if the validation loss is the best we've seen so far.

            # Always save after eval if save_all is true and not debug
            if not args.debug and args.save_all:
                name = f"checkpoint_{train_step}.pt"
                save_checkpoint(
                    args,
                    model,
                    optimizer,
                    dis_optimizer,
                    gen_optimizer,
                    dataset.vocab,
                    train_step,
                    val_nll,
                    scheduler,
                    dis_scheduler,
                    gen_scheduler,
                    name,
                )

            # Save last checkpoint if not debug and not save_all
            if not args.debug and not args.save_all:
                name = "checkpoint_last.pt"
                save_checkpoint(
                    args,
                    model,
                    optimizer,
                    dis_optimizer,
                    gen_optimizer,
                    dataset.vocab,
                    train_step,
                    val_nll,
                    scheduler,
                    dis_scheduler,
                    gen_scheduler,
                    name,
                )

            if not best_val_nll or val_nll < best_val_nll:
                best_val_nll = val_nll

                if not args.debug:
                    name = "checkpoint_best.pt"
                    save_checkpoint(
                        args,
                        model,
                        optimizer,
                        dis_optimizer,
                        gen_optimizer,
                        dataset.vocab,
                        train_step,
                        best_val_nll,
                        scheduler,
                        dis_scheduler,
                        gen_scheduler,
                        name,
                    )

                test_start_time = time.time()

                def calculate_test_nll_during_training(test_iter):

                    # Run on the test data.
                    test_token_num, test_total_nll, test_metrics = evaluate(
                        eval_iter=test_iter, dis_val_iter=None, mode="test"
                    )
                    test_token_num_pt = torch.tensor(test_token_num).to(device)
                    test_total_nll_pt = torch.tensor(test_total_nll / 10000.0).to(
                        device
                    )

                    torch.distributed.all_reduce(test_token_num_pt)
                    torch.distributed.all_reduce(test_total_nll_pt)

                    test_token_num = test_token_num_pt.item()
                    test_nll = test_total_nll_pt.item() / (test_token_num / 10000.0)

                    return test_token_num, test_nll, test_metrics

                (
                    test_token_num,
                    test_nll,
                    test_metrics,
                ) = calculate_test_nll_during_training(test_iter)

                if args.local_rank == 0:
                    logging.info(
                        "Test step {}, time={}s, test nll={}, test ppl={}, #evaluated tokens={}"
                        " test_bleu={}".format(
                            train_step,
                            time.time() - test_start_time,
                            test_nll,
                            math.exp(test_nll),
                            test_token_num,
                            test_metrics[0],
                        )
                    )
            # dev-performance based learning rate annealing
            if cfg.TRAIN.scheduler == "dev_perf":
                scheduler.step(val_nll)

        if train_step == cfg.TRAIN.max_step:
            logging.info("-" * 100)
            logging.info("End of training")
            break
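
For reference, the batch-chunking pattern used throughout train() (split a batch into cfg.TRAIN.batch_chunk micro-batches along dimension 1, accumulate gradients, then clip and step once) can be illustrated in isolation. The sketch below is a simplified, self-contained version of that pattern; the model, optimizer, and argument names are placeholders rather than classes from this repository, and the model is assumed to return a scalar loss.

import torch

def chunked_step(model, optimizer, data, target, n_chunks, clip):
    """Gradient accumulation over micro-batches, mirroring the chunking in train()."""
    model.zero_grad()
    # Chunk along dim 1, matching the [seq_len, batch] layout used above.
    data_chunks = torch.chunk(data, n_chunks, 1)
    target_chunks = torch.chunk(target, n_chunks, 1)
    total_loss = 0.0
    for d, t in zip(data_chunks, target_chunks):
        loss = model(d.contiguous(), t.contiguous())  # assumed to return a scalar loss
        (loss / n_chunks).backward()                  # scale so accumulated grads match the full batch
        total_loss += loss.item()
    # One gradient clip and one optimizer step per full batch, as in train().
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
    optimizer.step()
    optimizer.zero_grad()
    return total_loss / n_chunks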