anli/src/nli/train_with_confidence.py [64:502]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# Registry of supported model families. Each key is a short name (presumably
# the value given via --model_class_name -- confirm in the caller) mapping to:
#   model_name              -- pretrained checkpoint identifier
#   tokenizer               -- tokenizer class for that checkpoint
#   sequence_classification -- sequence-classification model class
#   padding_segement_value  -- (sic) presumably the pad value for segment /
#                              token-type ids; keep the spelling, code outside
#                              this view may index this exact key
#   padding_att_value       -- presumably the pad value for attention masks
#   do_lower_case           -- optional; only set for the uncased BERT models
#   left_pad                -- optional; only set for XLNet, presumably meaning
#                              sequences are padded on the left -- confirm
MODEL_CLASSES = {
    "bert-base": dict(
        model_name="bert-base-uncased",
        tokenizer=BertTokenizer,
        sequence_classification=BertForSequenceClassification,
        padding_segement_value=0,
        padding_att_value=0,
        do_lower_case=True,
    ),
    "bert-large": dict(
        model_name="bert-large-uncased",
        tokenizer=BertTokenizer,
        sequence_classification=BertForSequenceClassification,
        padding_segement_value=0,
        padding_att_value=0,
        do_lower_case=True,
    ),
    "xlnet-base": dict(
        model_name="xlnet-base-cased",
        tokenizer=XLNetTokenizer,
        sequence_classification=XLNetForSequenceClassification,
        padding_segement_value=4,
        padding_att_value=0,
        left_pad=True,
    ),
    "xlnet-large": dict(
        model_name="xlnet-large-cased",
        tokenizer=XLNetTokenizer,
        sequence_classification=XLNetForSequenceClassification,
        padding_segement_value=4,
        padding_att_value=0,
        left_pad=True,
    ),
    "roberta-base": dict(
        model_name="roberta-base",
        tokenizer=RobertaTokenizer,
        sequence_classification=RobertaForSequenceClassification,
        padding_segement_value=0,
        padding_att_value=0,
    ),
    "roberta-large": dict(
        model_name="roberta-large",
        tokenizer=RobertaTokenizer,
        sequence_classification=RobertaForSequenceClassification,
        padding_segement_value=0,
        padding_att_value=0,
    ),
    "albert-xxlarge": dict(
        model_name="albert-xxlarge-v2",
        tokenizer=AlbertTokenizer,
        sequence_classification=AlbertForSequenceClassification,
        padding_segement_value=0,
        padding_att_value=0,
    ),
    "distilbert": dict(
        model_name="distilbert-base-cased",
        tokenizer=DistilBertTokenizer,
        sequence_classification=DistilBertForSequenceClassification,
        padding_segement_value=0,
        padding_att_value=0,
    ),
    "bart-large": dict(
        model_name="facebook/bart-large",
        tokenizer=BartTokenizer,
        sequence_classification=BartForSequenceClassification,
        padding_segement_value=0,
        padding_att_value=0,
    ),
    "electra-base": dict(
        model_name="google/electra-base-discriminator",
        tokenizer=ElectraTokenizer,
        sequence_classification=ElectraForSequenceClassification,
        padding_segement_value=0,
        padding_att_value=0,
    ),
    "electra-large": dict(
        model_name="google/electra-large-discriminator",
        tokenizer=ElectraTokenizer,
        sequence_classification=ElectraForSequenceClassification,
        padding_segement_value=0,
        padding_att_value=0,
    ),
    "chinese-roberta-large": dict(
        model_name="hfl/chinese-roberta-wwm-ext-large",
        tokenizer=AutoTokenizer,
        sequence_classification=AutoModelForSequenceClassification,
        padding_segement_value=0,
        padding_att_value=0,
    ),
}

# Registry of dataset splits: "<dataset>_<split>" name -> JSONL file located
# under config.PRO_ROOT / "data/build/...". Entries keep their original order.
registered_path = {
    split_name: config.PRO_ROOT / ("data/build/" + relative_file)
    for split_name, relative_file in (
        ("snli_train", "snli/train.jsonl"),
        ("snli_dev", "snli/dev.jsonl"),
        ("snli_test", "snli/test.jsonl"),
        ("mnli_train", "mnli/train.jsonl"),
        ("mnli_m_dev", "mnli/m_dev.jsonl"),
        ("mnli_mm_dev", "mnli/mm_dev.jsonl"),
        ("mnli_rand_train", "mnli/rand_train.jsonl"),
        ("mnli_rand_dev", "mnli/rand_dev.jsonl"),
        ("mnli_rand_test", "mnli/rand_test.jsonl"),
        ("anli_r1_train", "anli/r1/train.jsonl"),
        ("anli_r1_dev", "anli/r1/dev.jsonl"),
        ("anli_r1_test", "anli/r1/test.jsonl"),
        ("anli_r2_train", "anli/r2/train.jsonl"),
        ("anli_r2_dev", "anli/r2/dev.jsonl"),
        ("anli_r2_test", "anli/r2/test.jsonl"),
        ("anli_r3_train", "anli/r3/train.jsonl"),
        ("anli_r3_dev", "anli/r3/dev.jsonl"),
        ("anli_r3_test", "anli/r3/test.jsonl"),
        ("ocnli_train", "ocnli/train.jsonl"),
        ("ocnli_dev", "ocnli/dev.jsonl"),
    )
}

# Gold-label letter -> class index (presumably entailment / neutral /
# contradiction). "h" maps to -1, which is not a valid class index; presumably
# it marks examples without a usable gold label -- confirm with the data.
nli_label2index = dict(e=0, n=1, c=2, h=-1)


def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with one value.

    Args:
        seed: Integer seed applied to every generator, in that order.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)


class NLIDataset(Dataset):
    """Map-style dataset that applies ``transform`` lazily to each item.

    The length is captured once at construction time from ``data_list``.
    """

    def __init__(self, data_list, transform) -> None:
        super().__init__()
        self.d_list = data_list
        self.len = len(data_list)
        self.transform = transform

    def __len__(self) -> int:
        return self.len

    def __getitem__(self, index: int):
        raw_item = self.d_list[index]
        return self.transform(raw_item)


class NLITransform(object):
    """Convert a raw NLI example dict into label fields plus tokenizer output.

    The premise is taken from ``"context"`` when that key exists, otherwise
    from ``"premise"``. Blank premise/hypothesis strings are replaced by the
    word ``"empty"``. The pair is encoded as (premise, hypothesis).
    """

    def __init__(self, model_name, tokenizer, max_length=None):
        # model_name is stored for callers; __call__ does not use it.
        self.model_name = model_name
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, sample):
        processed_sample = {
            "uid": sample["uid"],
            "gold_label": sample["label"],
            "y": nli_label2index[sample["label"]],
        }

        if "context" in sample:
            premise = sample["context"]
        else:
            premise = sample["premise"]
        hypothesis = sample["hypothesis"]

        # Whitespace-only text is swapped for a placeholder word, presumably
        # so the tokenizer never encodes an empty segment.
        if not premise.strip():
            premise = "empty"
        if not hypothesis.strip():
            hypothesis = "empty"

        encoded_pair = self.tokenizer.encode_plus(
            premise,
            hypothesis,
            max_length=self.max_length,
            return_token_type_ids=True,
            truncation=True,
        )
        processed_sample.update(encoded_pair)
        return processed_sample


class FlippedNLITransform(object):
    """Like the plain NLI transform, but encodes the pair as (hypothesis, premise).

    The premise is taken from ``"context"`` when that key exists, otherwise
    from ``"premise"``. Blank premise/hypothesis strings are replaced by the
    word ``"empty"`` before encoding.
    """

    def __init__(self, model_name, tokenizer, max_length=None):
        # model_name is stored for callers; __call__ does not use it.
        self.model_name = model_name
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, sample):
        processed_sample = {
            "uid": sample["uid"],
            "gold_label": sample["label"],
            "y": nli_label2index[sample["label"]],
        }

        if "context" in sample:
            premise = sample["context"]
        else:
            premise = sample["premise"]
        hypothesis = sample["hypothesis"]

        # Whitespace-only text is swapped for a placeholder word, presumably
        # so the tokenizer never encodes an empty segment.
        if not premise.strip():
            premise = "empty"
        if not hypothesis.strip():
            hypothesis = "empty"

        # Note the swapped order: hypothesis is the first segment.
        encoded_pair = self.tokenizer.encode_plus(
            hypothesis,
            premise,
            max_length=self.max_length,
            return_token_type_ids=True,
            truncation=True,
        )
        processed_sample.update(encoded_pair)
        return processed_sample


def build_eval_dataset_loader_and_sampler(
    d_list, data_transformer, batching_schema, batch_size_per_gpu_eval
):
    """Build an in-order evaluation pipeline over ``d_list``.

    Returns:
        tuple: ``(dataset, sampler, dataloader)``. The loader walks the dataset
        sequentially in batches of ``batch_size_per_gpu_eval``, loads in the
        main process (``num_workers=0``), pins memory, and collates with
        ``BaseBatchBuilder(batching_schema)``.
    """
    dataset = NLIDataset(d_list, data_transformer)
    sampler = SequentialSampler(dataset)
    collate = BaseBatchBuilder(batching_schema)
    loader = DataLoader(
        dataset=dataset,
        sampler=sampler,
        batch_size=batch_size_per_gpu_eval,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        collate_fn=collate,
    )
    return dataset, sampler, loader


def sample_data_list(d_list, ratio):
    """Resample ``d_list`` to roughly ``ratio * len(d_list)`` examples.

    Args:
        d_list: List of examples.
        ratio: Positive sampling weight.
            * ratio ~= 1: ``d_list`` itself is returned (no copy).
            * integer ratio k > 1: k deep copies concatenated in order.
            * any other ratio (including 0 < ratio < 1): ``ceil(ratio)`` deep
              copies are shuffled with the global ``random`` module and the
              first ``int(ratio * len(d_list))`` items are kept.

    Returns:
        ``d_list`` (ratio ~= 1) or a new list of deep-copied examples.

    Raises:
        ValueError: if ``ratio`` is not positive.
    """
    if ratio <= 0:
        raise ValueError(
            "Invalid training weight ratio. Please change --train_weights."
        )
    if np.isclose(ratio, 1):
        return d_list  # a weight of 1 means "use the data as-is"

    upper_int = int(math.ceil(ratio))
    sampled_d_list = []
    for _ in range(upper_int):
        sampled_d_list.extend(copy.deepcopy(d_list))
    if np.isclose(ratio, upper_int):
        return sampled_d_list

    # Fractional weights -- including weights below 1, which must subsample
    # rather than return the full list -- are realized by shuffled truncation.
    sampled_length = int(ratio * len(d_list))
    random.shuffle(sampled_d_list)
    return sampled_d_list[:sampled_length]


def get_args(argv=None):
    """Build the NLI training command-line parser and parse arguments.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses the real command line (``sys.argv[1:]``).

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--cpu", action="store_true", help="If set, we only use CPU.")
    parser.add_argument(
        "--single_gpu", action="store_true", help="If set, we only use single GPU."
    )
    parser.add_argument("--fp16", action="store_true", help="If set, we will use fp16.")

    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        # The two literals are joined implicitly, so the separating space has
        # to live at the end of the first one.
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
        "See details at https://nvidia.github.io/apex/amp.html",
    )

    # environment arguments
    parser.add_argument(
        "-s", "--seed", default=1, type=int, metavar="N", help="manual random seed"
    )
    parser.add_argument(
        "-n", "--num_nodes", default=1, type=int, metavar="N", help="number of nodes"
    )
    parser.add_argument(
        "-g", "--gpus_per_node", default=1, type=int, help="number of gpus per node"
    )
    parser.add_argument(
        "-nr", "--node_rank", default=0, type=int, help="ranking within the nodes"
    )

    # experiments specific arguments
    parser.add_argument(
        "--debug_mode",
        action="store_true",
        dest="debug_mode",
        help="whether this is debug mode or normal",
    )

    parser.add_argument(
        "--model_class_name", type=str, help="Set the model class of the experiment.",
    )

    parser.add_argument(
        "--experiment_name",
        type=str,
        help="Set the name of the experiment. [model_name]/[data]/[task]/[other]",
    )

    parser.add_argument(
        "--save_prediction",
        action="store_true",
        dest="save_prediction",
        help="Do we want to save prediction",
    )

    parser.add_argument(
        "--epochs",
        default=2,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument(
        "--per_gpu_train_batch_size",
        default=16,
        type=int,
        help="Batch size per GPU/CPU for training.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--per_gpu_eval_batch_size",
        default=64,
        type=int,
        help="Batch size per GPU/CPU for evaluation.",
    )

    parser.add_argument(
        "--max_length", default=160, type=int, help="Max length of the sequences."
    )

    parser.add_argument(
        "--warmup_steps", default=-1, type=int, help="Linear warmup over warmup_steps."
    )
    parser.add_argument(
        "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
    )
    parser.add_argument(
        "--learning_rate",
        default=1e-5,
        type=float,
        help="The initial learning rate for Adam.",
    )
    parser.add_argument(
        "--weight_decay", default=0.0, type=float, help="Weight decay if we apply some."
    )
    parser.add_argument(
        "--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer."
    )

    parser.add_argument(
        "--eval_frequency",
        default=1000,
        type=int,
        help="set the evaluation frequency, evaluate every X global step.",
    )

    parser.add_argument(
        "--train_data", type=str, help="The training data used in the experiments."
    )

    parser.add_argument(
        "--train_weights",
        type=str,
        help="The training data weights used in the experiments.",
    )

    parser.add_argument(
        "--eval_data", type=str, help="The evaluation data used in the experiments."
    )

    parser.add_argument(
        "--flip_sent",
        default=False,
        action="store_true",
        help="Flip the hypothesis and premise",
    )

    parser.add_argument(
        "--train_from_scratch",
        default=False,
        action="store_true",
        help="Train model without using the pretrained weights",
    )

    parser.add_argument(
        "--train_with_lm",
        default=False,
        action="store_true",
        help="Train model with LM",
    )

    parser.add_argument(
        "--add_lm",
        default=False,
        action="store_true",
        help="Train model with LM add loss",
    )

    parser.add_argument(
        "--lm_lambda", default=0.1, type=float, help="lambda to train LM loss",
    )

    parser.add_argument("--skip_model_save", default=False, action="store_true")
    parser.add_argument("--save_on_wandb", default=False, action="store_true")

    parser.add_argument("--slurm", default=False, action="store_true")

    args = parser.parse_args(argv)
    return args


def main(args):
    """Dispatch training to the CPU, a single GPU, or one process per local GPU.

    Args:
        args: Parsed command-line namespace. ``args.world_size`` is set on it
            before ``train(rank, args)`` is invoked.
    """
    if args.cpu:
        args.world_size = 1
        train(-1, args)  # rank -1 selects the CPU path
    elif args.single_gpu:
        args.world_size = 1
        train(0, args)
    else:  # distributed multi-GPU training
        args.world_size = args.gpus_per_node * args.num_nodes
        # Rendezvous settings read by torch.distributed's env:// initialization.
        # 127.0.0.1 (localhost) only works for single-node runs; export
        # MASTER_ADDR / MASTER_PORT beforehand to override for multi-node jobs.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        # Must be a valid TCP port (0-65535); 29500 is PyTorch's usual default.
        os.environ.setdefault("MASTER_PORT", "29500")
        # mp.spawn calls train(process_index, args) once per local process.
        mp.spawn(train, nprocs=args.gpus_per_node, args=(args,))
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



anli/src/nli/training.py [58:496]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# Registry of supported model families. Each key is a short name (presumably
# the value given via --model_class_name -- confirm in the caller) mapping to:
#   model_name              -- pretrained checkpoint identifier
#   tokenizer               -- tokenizer class for that checkpoint
#   sequence_classification -- sequence-classification model class
#   padding_segement_value  -- (sic) presumably the pad value for segment /
#                              token-type ids; keep the spelling, code outside
#                              this view may index this exact key
#   padding_att_value       -- presumably the pad value for attention masks
#   do_lower_case           -- optional; only set for the uncased BERT models
#   left_pad                -- optional; only set for XLNet, presumably meaning
#                              sequences are padded on the left -- confirm
MODEL_CLASSES = {
    "bert-base": dict(
        model_name="bert-base-uncased",
        tokenizer=BertTokenizer,
        sequence_classification=BertForSequenceClassification,
        padding_segement_value=0,
        padding_att_value=0,
        do_lower_case=True,
    ),
    "bert-large": dict(
        model_name="bert-large-uncased",
        tokenizer=BertTokenizer,
        sequence_classification=BertForSequenceClassification,
        padding_segement_value=0,
        padding_att_value=0,
        do_lower_case=True,
    ),
    "xlnet-base": dict(
        model_name="xlnet-base-cased",
        tokenizer=XLNetTokenizer,
        sequence_classification=XLNetForSequenceClassification,
        padding_segement_value=4,
        padding_att_value=0,
        left_pad=True,
    ),
    "xlnet-large": dict(
        model_name="xlnet-large-cased",
        tokenizer=XLNetTokenizer,
        sequence_classification=XLNetForSequenceClassification,
        padding_segement_value=4,
        padding_att_value=0,
        left_pad=True,
    ),
    "roberta-base": dict(
        model_name="roberta-base",
        tokenizer=RobertaTokenizer,
        sequence_classification=RobertaForSequenceClassification,
        padding_segement_value=0,
        padding_att_value=0,
    ),
    "roberta-large": dict(
        model_name="roberta-large",
        tokenizer=RobertaTokenizer,
        sequence_classification=RobertaForSequenceClassification,
        padding_segement_value=0,
        padding_att_value=0,
    ),
    "albert-xxlarge": dict(
        model_name="albert-xxlarge-v2",
        tokenizer=AlbertTokenizer,
        sequence_classification=AlbertForSequenceClassification,
        padding_segement_value=0,
        padding_att_value=0,
    ),
    "distilbert": dict(
        model_name="distilbert-base-cased",
        tokenizer=DistilBertTokenizer,
        sequence_classification=DistilBertForSequenceClassification,
        padding_segement_value=0,
        padding_att_value=0,
    ),
    "bart-large": dict(
        model_name="facebook/bart-large",
        tokenizer=BartTokenizer,
        sequence_classification=BartForSequenceClassification,
        padding_segement_value=0,
        padding_att_value=0,
    ),
    "electra-base": dict(
        model_name="google/electra-base-discriminator",
        tokenizer=ElectraTokenizer,
        sequence_classification=ElectraForSequenceClassification,
        padding_segement_value=0,
        padding_att_value=0,
    ),
    "electra-large": dict(
        model_name="google/electra-large-discriminator",
        tokenizer=ElectraTokenizer,
        sequence_classification=ElectraForSequenceClassification,
        padding_segement_value=0,
        padding_att_value=0,
    ),
    "chinese-roberta-large": dict(
        model_name="hfl/chinese-roberta-wwm-ext-large",
        tokenizer=AutoTokenizer,
        sequence_classification=AutoModelForSequenceClassification,
        padding_segement_value=0,
        padding_att_value=0,
    ),
}

# Registry of dataset splits: "<dataset>_<split>" name -> JSONL file located
# under config.PRO_ROOT / "data/build/...". Entries keep their original order.
registered_path = {
    split_name: config.PRO_ROOT / ("data/build/" + relative_file)
    for split_name, relative_file in (
        ("snli_train", "snli/train.jsonl"),
        ("snli_dev", "snli/dev.jsonl"),
        ("snli_test", "snli/test.jsonl"),
        ("mnli_train", "mnli/train.jsonl"),
        ("mnli_m_dev", "mnli/m_dev.jsonl"),
        ("mnli_mm_dev", "mnli/mm_dev.jsonl"),
        ("mnli_rand_train", "mnli/rand_train.jsonl"),
        ("mnli_rand_dev", "mnli/rand_dev.jsonl"),
        ("mnli_rand_test", "mnli/rand_test.jsonl"),
        ("anli_r1_train", "anli/r1/train.jsonl"),
        ("anli_r1_dev", "anli/r1/dev.jsonl"),
        ("anli_r1_test", "anli/r1/test.jsonl"),
        ("anli_r2_train", "anli/r2/train.jsonl"),
        ("anli_r2_dev", "anli/r2/dev.jsonl"),
        ("anli_r2_test", "anli/r2/test.jsonl"),
        ("anli_r3_train", "anli/r3/train.jsonl"),
        ("anli_r3_dev", "anli/r3/dev.jsonl"),
        ("anli_r3_test", "anli/r3/test.jsonl"),
        ("ocnli_train", "ocnli/train.jsonl"),
        ("ocnli_dev", "ocnli/dev.jsonl"),
    )
}

# Gold-label letter -> class index (presumably entailment / neutral /
# contradiction). "h" maps to -1, which is not a valid class index; presumably
# it marks examples without a usable gold label -- confirm with the data.
nli_label2index = dict(e=0, n=1, c=2, h=-1)


def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with one value.

    Args:
        seed: Integer seed applied to every generator, in that order.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)


class NLIDataset(Dataset):
    """Map-style dataset that applies ``transform`` lazily to each item.

    The length is captured once at construction time from ``data_list``.
    """

    def __init__(self, data_list, transform) -> None:
        super().__init__()
        self.d_list = data_list
        self.len = len(data_list)
        self.transform = transform

    def __len__(self) -> int:
        return self.len

    def __getitem__(self, index: int):
        raw_item = self.d_list[index]
        return self.transform(raw_item)


class NLITransform(object):
    """Convert a raw NLI example dict into label fields plus tokenizer output.

    The premise is taken from ``"context"`` when that key exists, otherwise
    from ``"premise"``. Blank premise/hypothesis strings are replaced by the
    word ``"empty"``. The pair is encoded as (premise, hypothesis).
    """

    def __init__(self, model_name, tokenizer, max_length=None):
        # model_name is stored for callers; __call__ does not use it.
        self.model_name = model_name
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, sample):
        processed_sample = {
            "uid": sample["uid"],
            "gold_label": sample["label"],
            "y": nli_label2index[sample["label"]],
        }

        if "context" in sample:
            premise = sample["context"]
        else:
            premise = sample["premise"]
        hypothesis = sample["hypothesis"]

        # Whitespace-only text is swapped for a placeholder word, presumably
        # so the tokenizer never encodes an empty segment.
        if not premise.strip():
            premise = "empty"
        if not hypothesis.strip():
            hypothesis = "empty"

        encoded_pair = self.tokenizer.encode_plus(
            premise,
            hypothesis,
            max_length=self.max_length,
            return_token_type_ids=True,
            truncation=True,
        )
        processed_sample.update(encoded_pair)
        return processed_sample


class FlippedNLITransform(object):
    """Like the plain NLI transform, but encodes the pair as (hypothesis, premise).

    The premise is taken from ``"context"`` when that key exists, otherwise
    from ``"premise"``. Blank premise/hypothesis strings are replaced by the
    word ``"empty"`` before encoding.
    """

    def __init__(self, model_name, tokenizer, max_length=None):
        # model_name is stored for callers; __call__ does not use it.
        self.model_name = model_name
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, sample):
        processed_sample = {
            "uid": sample["uid"],
            "gold_label": sample["label"],
            "y": nli_label2index[sample["label"]],
        }

        if "context" in sample:
            premise = sample["context"]
        else:
            premise = sample["premise"]
        hypothesis = sample["hypothesis"]

        # Whitespace-only text is swapped for a placeholder word, presumably
        # so the tokenizer never encodes an empty segment.
        if not premise.strip():
            premise = "empty"
        if not hypothesis.strip():
            hypothesis = "empty"

        # Note the swapped order: hypothesis is the first segment.
        encoded_pair = self.tokenizer.encode_plus(
            hypothesis,
            premise,
            max_length=self.max_length,
            return_token_type_ids=True,
            truncation=True,
        )
        processed_sample.update(encoded_pair)
        return processed_sample


def build_eval_dataset_loader_and_sampler(
    d_list, data_transformer, batching_schema, batch_size_per_gpu_eval
):
    """Build an in-order evaluation pipeline over ``d_list``.

    Returns:
        tuple: ``(dataset, sampler, dataloader)``. The loader walks the dataset
        sequentially in batches of ``batch_size_per_gpu_eval``, loads in the
        main process (``num_workers=0``), pins memory, and collates with
        ``BaseBatchBuilder(batching_schema)``.
    """
    dataset = NLIDataset(d_list, data_transformer)
    sampler = SequentialSampler(dataset)
    collate = BaseBatchBuilder(batching_schema)
    loader = DataLoader(
        dataset=dataset,
        sampler=sampler,
        batch_size=batch_size_per_gpu_eval,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        collate_fn=collate,
    )
    return dataset, sampler, loader


def sample_data_list(d_list, ratio):
    """Resample ``d_list`` to roughly ``ratio * len(d_list)`` examples.

    Args:
        d_list: List of examples.
        ratio: Positive sampling weight.
            * ratio ~= 1: ``d_list`` itself is returned (no copy).
            * integer ratio k > 1: k deep copies concatenated in order.
            * any other ratio (including 0 < ratio < 1): ``ceil(ratio)`` deep
              copies are shuffled with the global ``random`` module and the
              first ``int(ratio * len(d_list))`` items are kept.

    Returns:
        ``d_list`` (ratio ~= 1) or a new list of deep-copied examples.

    Raises:
        ValueError: if ``ratio`` is not positive.
    """
    if ratio <= 0:
        raise ValueError(
            "Invalid training weight ratio. Please change --train_weights."
        )
    if np.isclose(ratio, 1):
        return d_list  # a weight of 1 means "use the data as-is"

    upper_int = int(math.ceil(ratio))
    sampled_d_list = []
    for _ in range(upper_int):
        sampled_d_list.extend(copy.deepcopy(d_list))
    if np.isclose(ratio, upper_int):
        return sampled_d_list

    # Fractional weights -- including weights below 1, which must subsample
    # rather than return the full list -- are realized by shuffled truncation.
    sampled_length = int(ratio * len(d_list))
    random.shuffle(sampled_d_list)
    return sampled_d_list[:sampled_length]


def get_args(argv=None):
    """Build the NLI training command-line parser and parse arguments.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses the real command line (``sys.argv[1:]``).

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--cpu", action="store_true", help="If set, we only use CPU.")
    parser.add_argument(
        "--single_gpu", action="store_true", help="If set, we only use single GPU."
    )
    parser.add_argument("--fp16", action="store_true", help="If set, we will use fp16.")

    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        # The two literals are joined implicitly, so the separating space has
        # to live at the end of the first one.
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
        "See details at https://nvidia.github.io/apex/amp.html",
    )

    # environment arguments
    parser.add_argument(
        "-s", "--seed", default=1, type=int, metavar="N", help="manual random seed"
    )
    parser.add_argument(
        "-n", "--num_nodes", default=1, type=int, metavar="N", help="number of nodes"
    )
    parser.add_argument(
        "-g", "--gpus_per_node", default=1, type=int, help="number of gpus per node"
    )
    parser.add_argument(
        "-nr", "--node_rank", default=0, type=int, help="ranking within the nodes"
    )

    # experiments specific arguments
    parser.add_argument(
        "--debug_mode",
        action="store_true",
        dest="debug_mode",
        help="whether this is debug mode or normal",
    )

    parser.add_argument(
        "--model_class_name", type=str, help="Set the model class of the experiment.",
    )

    parser.add_argument(
        "--experiment_name",
        type=str,
        help="Set the name of the experiment. [model_name]/[data]/[task]/[other]",
    )

    parser.add_argument(
        "--save_prediction",
        action="store_true",
        dest="save_prediction",
        help="Do we want to save prediction",
    )

    parser.add_argument(
        "--epochs",
        default=2,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument(
        "--per_gpu_train_batch_size",
        default=16,
        type=int,
        help="Batch size per GPU/CPU for training.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--per_gpu_eval_batch_size",
        default=64,
        type=int,
        help="Batch size per GPU/CPU for evaluation.",
    )

    parser.add_argument(
        "--max_length", default=160, type=int, help="Max length of the sequences."
    )

    parser.add_argument(
        "--warmup_steps", default=-1, type=int, help="Linear warmup over warmup_steps."
    )
    parser.add_argument(
        "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
    )
    parser.add_argument(
        "--learning_rate",
        default=1e-5,
        type=float,
        help="The initial learning rate for Adam.",
    )
    parser.add_argument(
        "--weight_decay", default=0.0, type=float, help="Weight decay if we apply some."
    )
    parser.add_argument(
        "--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer."
    )

    parser.add_argument(
        "--eval_frequency",
        default=1000,
        type=int,
        help="set the evaluation frequency, evaluate every X global step.",
    )

    parser.add_argument(
        "--train_data", type=str, help="The training data used in the experiments."
    )

    parser.add_argument(
        "--train_weights",
        type=str,
        help="The training data weights used in the experiments.",
    )

    parser.add_argument(
        "--eval_data", type=str, help="The evaluation data used in the experiments."
    )

    parser.add_argument(
        "--flip_sent",
        default=False,
        action="store_true",
        help="Flip the hypothesis and premise",
    )

    parser.add_argument(
        "--train_from_scratch",
        default=False,
        action="store_true",
        help="Train model without using the pretrained weights",
    )

    parser.add_argument(
        "--train_with_lm",
        default=False,
        action="store_true",
        help="Train model with LM",
    )

    parser.add_argument(
        "--add_lm",
        default=False,
        action="store_true",
        help="Train model with LM add loss",
    )

    parser.add_argument(
        "--lm_lambda", default=0.1, type=float, help="lambda to train LM loss",
    )

    parser.add_argument("--skip_model_save", default=False, action="store_true")
    parser.add_argument("--save_on_wandb", default=False, action="store_true")

    parser.add_argument("--slurm", default=False, action="store_true")

    args = parser.parse_args(argv)
    return args


def main(args):
    """Dispatch training to the CPU, a single GPU, or one process per local GPU.

    Args:
        args: Parsed command-line namespace. ``args.world_size`` is set on it
            before ``train(rank, args)`` is invoked.
    """
    if args.cpu:
        args.world_size = 1
        train(-1, args)  # rank -1 selects the CPU path
    elif args.single_gpu:
        args.world_size = 1
        train(0, args)
    else:  # distributed multi-GPU training
        args.world_size = args.gpus_per_node * args.num_nodes
        # Rendezvous settings read by torch.distributed's env:// initialization.
        # 127.0.0.1 (localhost) only works for single-node runs; export
        # MASTER_ADDR / MASTER_PORT beforehand to override for multi-node jobs.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        # Must be a valid TCP port (0-65535); 29500 is PyTorch's usual default.
        os.environ.setdefault("MASTER_PORT", "29500")
        # mp.spawn calls train(process_index, args) once per local process.
        mp.spawn(train, nprocs=args.gpus_per_node, args=(args,))
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



