def train()

in anli/src/nli/train_with_confidence.py


def train(local_rank, args):
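    """Per-process training entry point.

    `local_rank` is the GPU index on the current node and `args` holds the parsed
    command-line options. The function builds the tokenizer, model, (optionally
    distributed) data loaders, optimizer and scheduler, then runs the epoch loop
    with periodic evaluation.
    """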
    # debug = False
    # print("GPU:", gpu)
    # world_size = args.world_size
    args.global_rank = args.node_rank * args.gpus_per_node + local_rank
    args.local_rank = local_rank
    # args.warmup_steps = 20
    debug_count = 1000
    num_epoch = args.epochs

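    # Effective global batch size: total GPUs (world size) * per-GPU batch size * accumulation steps.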
    actual_train_batch_size = (
        args.world_size
        * args.per_gpu_train_batch_size
        * args.gradient_accumulation_steps
    )
    args.actual_train_batch_size = actual_train_batch_size

    set_seed(args.seed)
    num_labels = 3  # we are doing NLI, so num_labels = 3; for other tasks this value can be changed.

    max_length = args.max_length

    model_class_item = MODEL_CLASSES[args.model_class_name]
    model_name = model_class_item["model_name"]
    do_lower_case = (
        model_class_item["do_lower_case"]
        if "do_lower_case" in model_class_item
        else False
    )

    tokenizer = model_class_item["tokenizer"].from_pretrained(
        model_name,
        cache_dir=str(config.PRO_ROOT / "trans_cache"),
        do_lower_case=do_lower_case,
    )

    model = model_class_item["sequence_classification"].from_pretrained(
        model_name,
        cache_dir=str(config.PRO_ROOT / "trans_cache"),
        num_labels=num_labels,
    )

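    # Optionally attach a RoBERTa LM head, presumably for an auxiliary LM objective
    # (the LM loss is not computed in the training loop shown here).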
    if args.train_with_lm:
        model.lm_head = RobertaLMHead(model.config)

    if args.train_from_scratch:
        print("Training model from scratch!")
        model.init_weights()

    padding_token_value = tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0]
    padding_segement_value = model_class_item["padding_segement_value"]
    padding_att_value = model_class_item["padding_att_value"]
    left_pad = model_class_item["left_pad"] if "left_pad" in model_class_item else False

    batch_size_per_gpu_train = args.per_gpu_train_batch_size
    batch_size_per_gpu_eval = args.per_gpu_eval_batch_size

    if not args.cpu and not args.single_gpu:
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            world_size=args.world_size,
            rank=args.global_rank,
        )

    train_data_str = args.train_data
    train_data_weights_str = args.train_weights
    eval_data_str = args.eval_data
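    # --train_data / --eval_data are comma-separated "name:path" specs; a name found in
    # `registered_path` is loaded from its registered location instead of the given path.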

    train_data_name = []
    train_data_path = []
    train_data_list = []
    train_data_weights = []

    eval_data_name = []
    eval_data_path = []
    eval_data_list = []

    train_data_named_path = train_data_str.split(",")
    weights_str = (
        train_data_weights_str.split(",")
        if train_data_weights_str is not None
        else None
    )

    eval_data_named_path = eval_data_str.split(",")

    for named_path in train_data_named_path:
        ind = named_path.find(":")
        name = named_path[:ind]
        path = named_path[ind + 1 :]
        if name in registered_path:
            d_list = common.load_jsonl(registered_path[name])
        else:
            d_list = common.load_jsonl(path)

        train_data_name.append(name)
        train_data_path.append(path)

        train_data_list.append(d_list)

    if weights_str is not None:
        for weights in weights_str:
            train_data_weights.append(float(weights))
    else:
        for i in range(len(train_data_list)):
            train_data_weights.append(1)

    for named_path in eval_data_named_path:
        ind = named_path.find(":")
        name = named_path[:ind]
        path = named_path[ind + 1 :]
        if name in registered_path:
            d_list = common.load_jsonl(registered_path[name])
        else:
            d_list = common.load_jsonl(path)
        eval_data_name.append(name)
        eval_data_path.append(path)

        eval_data_list.append(d_list)

    assert len(train_data_weights) == len(train_data_list)

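    # Batching schema: tells the batch builder how each field is padded and collated into tensors.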
    batching_schema = {
        "uid": RawFlintField(),
        "y": LabelFlintField(),
        "input_ids": ArrayIndexFlintField(
            pad_idx=padding_token_value, left_pad=left_pad
        ),
        "token_type_ids": ArrayIndexFlintField(
            pad_idx=padding_segement_value, left_pad=left_pad
        ),
        "attention_mask": ArrayIndexFlintField(
            pad_idx=padding_att_value, left_pad=left_pad
        ),
    }

    if args.flip_sent:
        print("Flipping hypothesis and premise")
        data_transformer = FlippedNLITransform(model_name, tokenizer, max_length)
    else:
        data_transformer = NLITransform(model_name, tokenizer, max_length)
    # data_transformer = NLITransform(model_name, tokenizer, max_length)
    # data_transformer = NLITransform(model_name, tokenizer, max_length, with_element=True)

    eval_data_loaders = []
    for eval_d_list in eval_data_list:
        d_dataset, d_sampler, d_dataloader = build_eval_dataset_loader_and_sampler(
            eval_d_list, data_transformer, batching_schema, batch_size_per_gpu_eval
        )
        eval_data_loaders.append(d_dataloader)

    # Estimate the training size:
    training_list = []
    for i in range(len(train_data_list)):
        print("Build Training Data ...")
        train_d_list = train_data_list[i]
        train_d_name = train_data_name[i]
        train_d_weight = train_data_weights[i]
        cur_train_list = sample_data_list(
            train_d_list, train_d_weight
        )  # change later: a different sampling strategy can be applied here.
        print(
            f"Data Name:{train_d_name}; Weight: {train_d_weight}; "
            f"Original Size: {len(train_d_list)}; Sampled Size: {len(cur_train_list)}"
        )
        training_list.extend(cur_train_list)
    estimated_training_size = len(training_list)
    print("Estimated training size:", estimated_training_size)
    # Estimate the training size ends:

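    # One optimizer update per effective batch; t_total is the total number of updates over all epochs.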
    # t_total = estimated_training_size // args.gradient_accumulation_steps * num_epoch
    t_total = estimated_training_size * num_epoch // args.actual_train_batch_size
    if (
        args.warmup_steps <= 0
    ):  # set the warmup steps to 0.1 * t_total if the given value is non-positive.
        args.warmup_steps = int(t_total * 0.1)

    if not args.cpu:
        torch.cuda.set_device(args.local_rank)
        model.cuda(args.local_rank)

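    # Standard transformer fine-tuning setup: no weight decay on biases and LayerNorm weights.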
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

    optimizer = AdamW(
        optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon
    )
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(
            model, optimizer, opt_level=args.fp16_opt_level
        )

    if not args.cpu and not args.single_gpu:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
        )

    args_dict = dict(vars(args))
    file_path_prefix = "."
    if args.global_rank in [-1, 0]:
        print("Total Steps:", t_total)
        args.total_step = t_total
        print("Warmup Steps:", args.warmup_steps)
        print("Actual Training Batch Size:", actual_train_batch_size)
        print("Arguments", pp.pprint(args))

    # Let's build the logger and log everything before the start of the first training epoch.
    if args.global_rank in [-1, 0]:  # only do logging if we use cpu or global_rank=0
        if not args.debug_mode:
            file_path_prefix, date = save_tool.gen_file_prefix(
                f"{args.experiment_name}"
            )
            # # # Create Log File
            # Save the source code.
            script_name = os.path.basename(__file__)
            with open(os.path.join(file_path_prefix, script_name), "w") as out_f, open(
                __file__, "r"
            ) as it:
                out_f.write(it.read())
                out_f.flush()

            # Save option file
            common.save_json(args_dict, os.path.join(file_path_prefix, "args.json"))
            checkpoints_path = Path(file_path_prefix) / "checkpoints"
            if not checkpoints_path.exists():
                checkpoints_path.mkdir()
            prediction_path = Path(file_path_prefix) / "predictions"
            if not prediction_path.exists():
                prediction_path.mkdir()

    global_step = 0

    # print(f"Global Rank:{args.global_rank} ### ", 'Init!')

    for epoch in tqdm(
        range(num_epoch), desc="Epoch", disable=args.global_rank not in [-1, 0]
    ):
        # Let's build up training dataset for this epoch
        training_list = []
        for i in range(len(train_data_list)):
            print("Build Training Data ...")
            train_d_list = train_data_list[i]
            train_d_name = train_data_name[i]
            train_d_weight = train_data_weights[i]
            cur_train_list = sample_data_list(
                train_d_list, train_d_weight
            )  # change later: a different sampling strategy can be applied here.
            print(
                f"Data Name:{train_d_name}; Weight: {train_d_weight}; "
                f"Original Size: {len(train_d_list)}; Sampled Size: {len(cur_train_list)}"
            )
            training_list.extend(cur_train_list)

        random.shuffle(training_list)
        train_dataset = NLIDataset(training_list, data_transformer)

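        # Default to a sequential sampler (the training list was already shuffled above);
        # it is replaced by a DistributedSampler for multi-GPU runs.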
        train_sampler = SequentialSampler(train_dataset)
        if not args.cpu and not args.single_gpu:
            print("Use distributed sampler.")
            train_sampler = DistributedSampler(
                train_dataset, args.world_size, args.global_rank, shuffle=True
            )

        train_dataloader = DataLoader(
            dataset=train_dataset,
            batch_size=batch_size_per_gpu_train,
            shuffle=False,  # shuffling is handled by the sampler (the list is pre-shuffled above)
            num_workers=0,
            pin_memory=True,
            sampler=train_sampler,
            collate_fn=BaseBatchBuilder(batching_schema),
        )
        # training build finished.

        print(debug_node_info(args), "epoch: ", epoch)

        if not args.cpu and not args.single_gpu:
            train_sampler.set_epoch(
                epoch
            )  # set the epoch so the sampler reshuffles differently each epoch

        pb = tqdm(
            total=len(train_dataloader),
            desc="Iteration",
            disable=args.global_rank not in [-1, 0],
        )
        for forward_step, batch in enumerate(train_dataloader, 0):
            model.train()

            batch = move_to_device(batch, local_rank)
            # print(batch['input_ids'], batch['y'])
            if args.model_class_name in ["distilbert", "bart-large"]:
                outputs = model(
                    batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["y"],
                )
            else:
                outputs = model(
                    batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    token_type_ids=batch["token_type_ids"],
                    labels=batch["y"],
                )
            loss, logits = outputs[:2]
            loss_val = loss.item()
            lm_loss_val = 0
            total_loss_val = 0
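            # lm_loss_val and total_loss_val are placeholders; only the classification
            # loss is computed in the loop shown here.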

            # Change: the update logic changes here.
            # First make the update on a copy of the model, then compute the evaluation.

            pb.set_description(
                f"Iteration: Loss : {loss_val}, LM Loss : {lm_loss_val}, Total Loss : {total_loss_val}"
            )

            # Accumulated loss
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            # if this forward step needs a model update
            # handle fp16
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

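            # Parameters are updated only every `gradient_accumulation_steps` forward passes.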
            if (forward_step + 1) % args.gradient_accumulation_steps == 0:
                # Gradient clipping: applied only when max_grad_norm > 0.
                if args.max_grad_norm > 0:
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), args.max_grad_norm
                        )
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            model.parameters(), args.max_grad_norm
                        )

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()

                global_step += 1

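                # Periodic mid-training evaluation, run on the main process only.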
                if (
                    args.global_rank in [-1, 0]
                    and args.eval_frequency > 0
                    and global_step % args.eval_frequency == 0
                ):
                    entropy, acc = eval_step(
                        args,
                        model,
                        checkpoints_path,
                        prediction_path,
                        epoch,
                        global_step,
                        eval_data_name,
                        eval_data_list,
                        eval_data_loaders,
                        evaluation_dataset,
                        optimizer,
                        scheduler,
                        skip_model_save=True,
                        save_prediction=False,
                    )
                    log = {
                        "valid_entropy": entropy.item(),
                        "valid_acc": acc,
                        "global_step": global_step,
                    }
                    lbk.write_metric(log)

            pb.update(1)
        pb.close()

        # End of epoch evaluation.
        if args.global_rank in [-1, 0]:
            eval_step(
                args,
                model,
                checkpoints_path,
                prediction_path,
                epoch,
                global_step,
                eval_data_name,
                eval_data_list,
                eval_data_loaders,
                evaluation_dataset,
                optimizer,
                scheduler,
            )
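
The entry point that spawns this function is not part of this excerpt. As a hypothetical sketch (the main wrapper and the default address/port below are assumptions, not code from this file), a (local_rank, args) signature like this is typically driven by torch.multiprocessing.spawn, with MASTER_ADDR/MASTER_PORT set so that init_process_group(init_method="env://") can rendezvous:

# Hypothetical launcher sketch; not part of train_with_confidence.py.
import os
import torch.multiprocessing as mp

def main(args):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    if args.cpu or args.single_gpu:
        train(0, args)  # single-process path; local_rank fixed to 0
    else:
        # spawn calls train(local_rank, args) once per local GPU
        mp.spawn(train, args=(args,), nprocs=args.gpus_per_node)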