anli/src/nli/train_with_scramble.py [1040:1253]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
                        model_output_dir = (
                            checkpoints_path / current_checkpoint_filename
                        )
                        if not model_output_dir.exists():
                            model_output_dir.mkdir()
                        model_to_save = (
                            model.module if hasattr(model, "module") else model
                        )  # Take care of distributed/parallel training

                        torch.save(
                            model_to_save.state_dict(),
                            str(model_output_dir / "model.pt"),
                        )
                        torch.save(
                            optimizer.state_dict(),
                            str(model_output_dir / "optimizer.pt"),
                        )
                        torch.save(
                            scheduler.state_dict(),
                            str(model_output_dir / "scheduler.pt"),
                        )

                    # save prediction:
                    if not args.debug_mode and args.save_prediction:
                        cur_results_path = prediction_path / current_checkpoint_filename
                        if not cur_results_path.exists():
                            cur_results_path.mkdir(parents=True)
                        for key, item in r_dict.items():
                            common.save_jsonl(
                                item["predictions"], cur_results_path / f"{key}.jsonl"
                            )

                        # avoid saving too many things
                        for key, item in r_dict.items():
                            del r_dict[key]["predictions"]
                        common.save_json(
                            r_dict, cur_results_path / "results_dict.json", indent=2
                        )
            pb.update(1)
        pb.close()

        # End of epoch evaluation.
        if args.global_rank in [-1, 0]:
            r_dict = dict()
            # Eval loop:
            for i in range(len(eval_data_name)):
                cur_eval_data_name = eval_data_name[i]
                cur_eval_data_list = eval_data_list[i]
                cur_eval_dataloader = eval_data_loaders[i]
                # cur_eval_raw_data_list = eval_raw_data_list[i]

                evaluation_dataset(
                    args,
                    cur_eval_dataloader,
                    cur_eval_data_list,
                    model,
                    r_dict,
                    eval_name=cur_eval_data_name,
                )

            # saving checkpoints
            current_checkpoint_filename = f"e({epoch})|i({global_step})"

            for i in range(len(eval_data_name)):
                cur_eval_data_name = eval_data_name[i]
                current_checkpoint_filename += f'|{cur_eval_data_name}#({round(r_dict[cur_eval_data_name]["acc"], 4)})'

            if not args.debug_mode:
                # save model:
                model_output_dir = checkpoints_path / current_checkpoint_filename
                if not model_output_dir.exists():
                    model_output_dir.mkdir()
                model_to_save = (
                    model.module if hasattr(model, "module") else model
                )  # Take care of distributed/parallel training

                torch.save(
                    model_to_save.state_dict(), str(model_output_dir / "model.pt")
                )
                torch.save(
                    optimizer.state_dict(), str(model_output_dir / "optimizer.pt")
                )
                torch.save(
                    scheduler.state_dict(), str(model_output_dir / "scheduler.pt")
                )

            # save prediction:
            if not args.debug_mode and args.save_prediction:
                cur_results_path = prediction_path / current_checkpoint_filename
                if not cur_results_path.exists():
                    cur_results_path.mkdir(parents=True)
                for key, item in r_dict.items():
                    common.save_jsonl(
                        item["predictions"], cur_results_path / f"{key}.jsonl"
                    )

                # avoid saving too many things
                for key, item in r_dict.items():
                    del r_dict[key]["predictions"]
                common.save_json(
                    r_dict, cur_results_path / "results_dict.json", indent=2
                )


# Maps the integer class index produced by eval_model (argmax over the logits)
# to the single-character label string that count_acc compares against each
# gold example's "label" field.
# NOTE(review): presumably e/n/c = entailment/neutral/contradiction and -1 is a
# placeholder/unlabeled id — confirm against the dataset reader.
id2label = dict(zip((0, 1, 2, -1), ("e", "n", "c", "-")))


def count_acc(gt_list, pred_list):
    """Count predictions whose label matches the gold label, pairing items by "uid".

    Args:
        gt_list: gold examples; each dict needs "uid" and "label" keys.
        pred_list: predictions (e.g. the output of eval_model); each dict needs
            "uid" and "predicted_label" keys.

    Returns:
        (hit, total_count): the number of matching predictions and the number of
        predictions compared. NOTE(review): the count is taken after
        ``list_dict_data_tool.list_to_dict``, which presumably keys by uid, so
        duplicate uids would collapse — confirm.

    Raises:
        AssertionError: if the two lists differ in length.
        KeyError: presumably, when a predicted uid is missing from gt_list
            (the gold lookup is a mapping index).
    """
    assert len(gt_list) == len(pred_list)
    gold_by_uid = list_dict_data_tool.list_to_dict(gt_list, "uid")
    pred_by_uid = list_dict_data_tool.list_to_dict(pred_list, "uid")

    # One boolean per prediction, in the prediction mapping's iteration order.
    matches = [
        gold_by_uid[uid]["label"] == prediction["predicted_label"]
        for uid, prediction in pred_by_uid.items()
    ]
    return sum(matches), len(matches)


def evaluation_dataset(args, eval_dataloader, eval_list, model, r_dict, eval_name):
    """Evaluate ``model`` on one dataset, print its accuracy and record it in ``r_dict``.

    Args:
        args: run namespace; ``args.global_rank`` is forwarded to eval_model as
            its device argument, and ``args`` is also passed to eval_model and
            debug_node_info.
        eval_dataloader: loader yielding the batches fed to eval_model.
        eval_list: gold examples (dicts with "uid" and "label") for count_acc.
        model: the model under evaluation.
        r_dict: results dict, mutated in place; this call writes the entry
            ``r_dict[eval_name]``.
        eval_name: key under which results are stored and the name printed.

    The stored entry has keys "acc", "correct_count", "total_count" and
    "predictions" (the per-example dicts returned by eval_model).
    """
    predictions = eval_model(model, eval_dataloader, args.global_rank, args)
    hit, total = count_acc(eval_list, predictions)

    # An empty evaluation set would otherwise raise ZeroDivisionError here and
    # abort training at checkpoint time; report 0.0 instead (count shows 0/0).
    acc = hit / total if total else 0.0

    print(debug_node_info(args), f"{eval_name} Acc:", hit, total, acc)

    r_dict[f"{eval_name}"] = {
        "acc": acc,
        "correct_count": hit,
        "total_count": total,
        "predictions": predictions,
    }


def eval_model(model, dev_dataloader, device_num, args):
    """Run ``model`` over ``dev_dataloader`` without gradients and collect predictions.

    The model is put in eval mode for the pass. Its previous train/eval mode is
    restored afterwards, even if evaluation raises, so a caller inside a
    training loop gets the model back in the mode it handed over.

    Args:
        model: classifier called as
            ``model(input_ids, attention_mask=..., [token_type_ids=...,] labels=...)``
            whose output is unpacked as ``(loss, logits, ...)``; the loss is ignored.
        dev_dataloader: iterable of batch dicts with keys "uid", "input_ids",
            "attention_mask", "y" and, unless ``args.model_class_name`` is
            "distilbert" or "bart-large", "token_type_ids".
        device_num: forwarded unchanged to ``move_to_device`` for every batch.
        args: namespace; only ``args.model_class_name`` is read.

    Returns:
        One dict per example, in loader order, with keys "uid", "logits"
        (a plain Python list) and "predicted_label" (``id2label`` applied to
        the index of the largest logit).
    """
    # Remember the caller's mode so it can be restored afterwards.
    was_training = model.training
    model.eval()

    uid_list = []
    pred_list = []
    logits_list = []

    try:
        with torch.no_grad():
            for batch in dev_dataloader:
                batch = move_to_device(batch, device_num)

                model_kwargs = {
                    "attention_mask": batch["attention_mask"],
                    "labels": batch["y"],
                }
                # These architectures are called without token_type_ids.
                if args.model_class_name not in ["distilbert", "bart-large"]:
                    model_kwargs["token_type_ids"] = batch["token_type_ids"]
                outputs = model(batch["input_ids"], **model_kwargs)

                _, logits = outputs[:2]

                uid_list.extend(list(batch["uid"]))
                # Index of the highest-scoring class for each row.
                pred_list.extend(torch.max(logits, 1)[1].tolist())
                logits_list.extend(logits.tolist())
    finally:
        model.train(was_training)

    # Every example must contribute exactly one uid, prediction and logits row.
    assert len(uid_list) == len(pred_list) == len(logits_list)

    return [
        {
            "uid": uid,
            "logits": row_logits,
            "predicted_label": id2label[pred],
        }
        for uid, pred, row_logits in zip(uid_list, pred_list, logits_list)
    ]


def debug_node_info(args):
    """Return a log prefix identifying this process in a (possibly distributed) run.

    If ``args`` contains all of global_rank, local_rank and node_rank (checked
    with the ``in`` operator), the result looks like
    ``"Pro:global_rank:0|local_rank:0|node_rank:0||Print:"``. If any is
    missing, the fixed string ``"Pro:No node info "`` is returned.
    """
    rank_fields = ("global_rank", "local_rank", "node_rank")
    # all() stops at the first missing field, like checking one by one.
    if not all(field in args for field in rank_fields):
        return "Pro:No node info "

    rank_desc = "|".join(f"{field}:{getattr(args, field)}" for field in rank_fields)
    return f"Pro:{rank_desc}||Print:"


if __name__ == "__main__":
    args = get_args()
    d = datetime.datetime.today()
    main_exp_type = f"nli_{args.model_class_name}_{args.experiment_name}"
    # logdir = Path.cwd()
    exp_dir = (
        Path("/checkpoint/koustuvs")
        / "projects"
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



anli/src/nli/training.py [928:1142]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
                        model_output_dir = (
                            checkpoints_path / current_checkpoint_filename
                        )
                        if not model_output_dir.exists():
                            model_output_dir.mkdir()
                        model_to_save = (
                            model.module if hasattr(model, "module") else model
                        )  # Take care of distributed/parallel training

                        torch.save(
                            model_to_save.state_dict(),
                            str(model_output_dir / "model.pt"),
                        )
                        torch.save(
                            optimizer.state_dict(),
                            str(model_output_dir / "optimizer.pt"),
                        )
                        torch.save(
                            scheduler.state_dict(),
                            str(model_output_dir / "scheduler.pt"),
                        )

                    # save prediction:
                    if not args.debug_mode and args.save_prediction:
                        cur_results_path = prediction_path / current_checkpoint_filename
                        if not cur_results_path.exists():
                            cur_results_path.mkdir(parents=True)
                        for key, item in r_dict.items():
                            common.save_jsonl(
                                item["predictions"], cur_results_path / f"{key}.jsonl"
                            )

                        # avoid saving too many things
                        for key, item in r_dict.items():
                            del r_dict[key]["predictions"]
                        common.save_json(
                            r_dict, cur_results_path / "results_dict.json", indent=2
                        )

            pb.update(1)
        pb.close()

        # End of epoch evaluation.
        if args.global_rank in [-1, 0]:
            r_dict = dict()
            # Eval loop:
            for i in range(len(eval_data_name)):
                cur_eval_data_name = eval_data_name[i]
                cur_eval_data_list = eval_data_list[i]
                cur_eval_dataloader = eval_data_loaders[i]
                # cur_eval_raw_data_list = eval_raw_data_list[i]

                evaluation_dataset(
                    args,
                    cur_eval_dataloader,
                    cur_eval_data_list,
                    model,
                    r_dict,
                    eval_name=cur_eval_data_name,
                )

            # saving checkpoints
            current_checkpoint_filename = f"e({epoch})|i({global_step})"

            for i in range(len(eval_data_name)):
                cur_eval_data_name = eval_data_name[i]
                current_checkpoint_filename += f'|{cur_eval_data_name}#({round(r_dict[cur_eval_data_name]["acc"], 4)})'

            if not args.debug_mode:
                # save model:
                model_output_dir = checkpoints_path / current_checkpoint_filename
                if not model_output_dir.exists():
                    model_output_dir.mkdir()
                model_to_save = (
                    model.module if hasattr(model, "module") else model
                )  # Take care of distributed/parallel training

                torch.save(
                    model_to_save.state_dict(), str(model_output_dir / "model.pt")
                )
                torch.save(
                    optimizer.state_dict(), str(model_output_dir / "optimizer.pt")
                )
                torch.save(
                    scheduler.state_dict(), str(model_output_dir / "scheduler.pt")
                )

            # save prediction:
            if not args.debug_mode and args.save_prediction:
                cur_results_path = prediction_path / current_checkpoint_filename
                if not cur_results_path.exists():
                    cur_results_path.mkdir(parents=True)
                for key, item in r_dict.items():
                    common.save_jsonl(
                        item["predictions"], cur_results_path / f"{key}.jsonl"
                    )

                # avoid saving too many things
                for key, item in r_dict.items():
                    del r_dict[key]["predictions"]
                common.save_json(
                    r_dict, cur_results_path / "results_dict.json", indent=2
                )


# Maps the integer class index produced by eval_model (argmax over the logits)
# to the single-character label string that count_acc compares against each
# gold example's "label" field.
# NOTE(review): presumably e/n/c = entailment/neutral/contradiction and -1 is a
# placeholder/unlabeled id — confirm against the dataset reader.
id2label = dict(zip((0, 1, 2, -1), ("e", "n", "c", "-")))


def count_acc(gt_list, pred_list):
    """Count predictions whose label matches the gold label, pairing items by "uid".

    Args:
        gt_list: gold examples; each dict needs "uid" and "label" keys.
        pred_list: predictions (e.g. the output of eval_model); each dict needs
            "uid" and "predicted_label" keys.

    Returns:
        (hit, total_count): the number of matching predictions and the number of
        predictions compared. NOTE(review): the count is taken after
        ``list_dict_data_tool.list_to_dict``, which presumably keys by uid, so
        duplicate uids would collapse — confirm.

    Raises:
        AssertionError: if the two lists differ in length.
        KeyError: presumably, when a predicted uid is missing from gt_list
            (the gold lookup is a mapping index).
    """
    assert len(gt_list) == len(pred_list)
    gold_by_uid = list_dict_data_tool.list_to_dict(gt_list, "uid")
    pred_by_uid = list_dict_data_tool.list_to_dict(pred_list, "uid")

    # One boolean per prediction, in the prediction mapping's iteration order.
    matches = [
        gold_by_uid[uid]["label"] == prediction["predicted_label"]
        for uid, prediction in pred_by_uid.items()
    ]
    return sum(matches), len(matches)


def evaluation_dataset(args, eval_dataloader, eval_list, model, r_dict, eval_name):
    """Evaluate ``model`` on one dataset, print its accuracy and record it in ``r_dict``.

    Args:
        args: run namespace; ``args.global_rank`` is forwarded to eval_model as
            its device argument, and ``args`` is also passed to eval_model and
            debug_node_info.
        eval_dataloader: loader yielding the batches fed to eval_model.
        eval_list: gold examples (dicts with "uid" and "label") for count_acc.
        model: the model under evaluation.
        r_dict: results dict, mutated in place; this call writes the entry
            ``r_dict[eval_name]``.
        eval_name: key under which results are stored and the name printed.

    The stored entry has keys "acc", "correct_count", "total_count" and
    "predictions" (the per-example dicts returned by eval_model).
    """
    predictions = eval_model(model, eval_dataloader, args.global_rank, args)
    hit, total = count_acc(eval_list, predictions)

    # An empty evaluation set would otherwise raise ZeroDivisionError here and
    # abort training at checkpoint time; report 0.0 instead (count shows 0/0).
    acc = hit / total if total else 0.0

    print(debug_node_info(args), f"{eval_name} Acc:", hit, total, acc)

    r_dict[f"{eval_name}"] = {
        "acc": acc,
        "correct_count": hit,
        "total_count": total,
        "predictions": predictions,
    }


def eval_model(model, dev_dataloader, device_num, args):
    """Run ``model`` over ``dev_dataloader`` without gradients and collect predictions.

    The model is put in eval mode for the pass. Its previous train/eval mode is
    restored afterwards, even if evaluation raises, so a caller inside a
    training loop gets the model back in the mode it handed over.

    Args:
        model: classifier called as
            ``model(input_ids, attention_mask=..., [token_type_ids=...,] labels=...)``
            whose output is unpacked as ``(loss, logits, ...)``; the loss is ignored.
        dev_dataloader: iterable of batch dicts with keys "uid", "input_ids",
            "attention_mask", "y" and, unless ``args.model_class_name`` is
            "distilbert" or "bart-large", "token_type_ids".
        device_num: forwarded unchanged to ``move_to_device`` for every batch.
        args: namespace; only ``args.model_class_name`` is read.

    Returns:
        One dict per example, in loader order, with keys "uid", "logits"
        (a plain Python list) and "predicted_label" (``id2label`` applied to
        the index of the largest logit).
    """
    # Remember the caller's mode so it can be restored afterwards.
    was_training = model.training
    model.eval()

    uid_list = []
    pred_list = []
    logits_list = []

    try:
        with torch.no_grad():
            for batch in dev_dataloader:
                batch = move_to_device(batch, device_num)

                model_kwargs = {
                    "attention_mask": batch["attention_mask"],
                    "labels": batch["y"],
                }
                # These architectures are called without token_type_ids.
                if args.model_class_name not in ["distilbert", "bart-large"]:
                    model_kwargs["token_type_ids"] = batch["token_type_ids"]
                outputs = model(batch["input_ids"], **model_kwargs)

                _, logits = outputs[:2]

                uid_list.extend(list(batch["uid"]))
                # Index of the highest-scoring class for each row.
                pred_list.extend(torch.max(logits, 1)[1].tolist())
                logits_list.extend(logits.tolist())
    finally:
        model.train(was_training)

    # Every example must contribute exactly one uid, prediction and logits row.
    assert len(uid_list) == len(pred_list) == len(logits_list)

    return [
        {
            "uid": uid,
            "logits": row_logits,
            "predicted_label": id2label[pred],
        }
        for uid, pred, row_logits in zip(uid_list, pred_list, logits_list)
    ]


def debug_node_info(args):
    """Return a log prefix identifying this process in a (possibly distributed) run.

    If ``args`` contains all of global_rank, local_rank and node_rank (checked
    with the ``in`` operator), the result looks like
    ``"Pro:global_rank:0|local_rank:0|node_rank:0||Print:"``. If any is
    missing, the fixed string ``"Pro:No node info "`` is returned.
    """
    rank_fields = ("global_rank", "local_rank", "node_rank")
    # all() stops at the first missing field, like checking one by one.
    if not all(field in args for field in rank_fields):
        return "Pro:No node info "

    rank_desc = "|".join(f"{field}:{getattr(args, field)}" for field in rank_fields)
    return f"Pro:{rank_desc}||Print:"


if __name__ == "__main__":
    args = get_args()
    d = datetime.datetime.today()
    main_exp_type = f"nli_{args.model_class_name}_{args.experiment_name}"
    # logdir = Path.cwd()
    exp_dir = (
        Path("/checkpoint/koustuvs")
        / "projects"
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



