def main()

in information-gain-filtration/run_clm_igf.py [0:0]


def main():
    parser = argparse.ArgumentParser(description="Fine-tune a transformer model with IGF on a language modeling task")

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain data files for WikiText.",
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--data_file",
        type=str,
        default=None,
        help=(
            "A jbl file containing tokenized data which can be split as objective dataset, "
            "train_dataset and test_dataset."
        ),
    )

    parser.add_argument(
        "--igf_data_file",
        type=str,
        default=None,
        help="A jbl file containing the context and information gain pairs to train secondary learner.",
    )

    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the final fine-tuned model is stored.",
    )

    parser.add_argument(
        "--tokenizer_name",
        default=None,
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")

    parser.add_argument(
        "--context_len",
        default=32,
        type=int,
        help=(
            "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        ),
    )

    parser.add_argument(
        "--size_objective_set",
        default=100,
        type=int,
        help="number of articles that are long enough to be used as our objective set",
    )
    parser.add_argument(
        "--eval_freq", default=100, type=int, help="secondary model evaluation is triggered at eval_freq"
    )

    parser.add_argument("--max_steps", default=1000, type=int, help="To calculate training epochs")

    parser.add_argument(
        "--secondary_learner_batch_size",
        default=128,
        type=int,
        help="batch size of training data for secondary learner",
    )

    parser.add_argument(
        "--batch_size",
        default=16,
        type=int,
        help="batch size of training data of language model(openai-community/gpt2) ",
    )

    parser.add_argument(
        "--eval_interval",
        default=10,
        type=int,
        help=(
            "decay the selectivity of our secondary learner filter from "
            "1 standard deviation above average to 1 below average after 10 batches"
        ),
    )

    parser.add_argument(
        "--number", default=100, type=int, help="The number of examples split to be used as objective_set/test_data"
    )

    parser.add_argument(
        "--min_len", default=1026, type=int, help="The minimum length of the article to be used as objective set"
    )

    parser.add_argument(
        "--secondary_learner_max_epochs", default=15, type=int, help="number of epochs to train secondary learner"
    )

    parser.add_argument("--trim", default=True, type=bool, help="truncate the example if it exceeds context length")

    parser.add_argument(
        "--threshold",
        default=1.0,
        type=float,
        help=(
            "The threshold value used by secondary learner to filter the train_data and allow only"
            " informative data as input to the model"
        ),
    )

    parser.add_argument(
        "--finetuned_model_name", default="openai-community/gpt2_finetuned.pt", type=str, help="finetuned_model_name"
    )

    parser.add_argument(
        "--recopy_model",
        default=recopy_gpt2,
        type=str,
        help="Reset the model to the original pretrained GPT-2 weights after each iteration",
    )
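
    # An illustrative (hypothetical) invocation of this script, using only the flags defined above:
    #   python run_clm_igf.py \
    #       --data_dir data \
    #       --model_name_or_path openai-community/gpt2 \
    #       --output_dir igf_output \
    #       --data_file data/tokenized_stories_train_wikitext103.jbl \
    #       --igf_data_file igf_context_pairs.jbl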

    # Parse (and validate) the command-line arguments. Note that the function calls below
    # currently use fixed values rather than the parsed args.
    args = parser.parse_args()  # noqa: F841

    # function calls
    # Collect *n* pairs of context and information gain (X, IG(X)) for training the secondary learner
    generate_n_pairs(
        context_len=32,
        max_steps=10,
        size_objective_set=100,
        min_len=1026,
        trim=True,
        data_file="data/tokenized_stories_train_wikitext103.jbl",
        igf_data_file="igf_context_pairs.jbl",
    )
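
    # Note: the pairs generated above are written to "igf_context_pairs.jbl", while the secondary
    # learner below is trained from "data/IGF_values.jbl"; both files presumably hold
    # joblib-serialized (context, information gain) pairs.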

    # Load train data for secondary learner
    secondary_learner_train_data = joblib.load("data/IGF_values.jbl")

    # Train secondary learner
    secondary_learner = training_secondary_learner(
        secondary_learner_train_data,
        secondary_learner_max_epochs=15,
        secondary_learner_batch_size=128,
        eval_freq=100,
        igf_model_path="igf_model.pt",
    )

    # load pretrained openai-community/gpt2 model
    model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")
    set_seed(42)
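    # The seed is fixed to 42 here; the --seed argument defined above is not applied at this point.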

    # Generate train and test data to train and evaluate openai-community/gpt2 model
    train_dataset, test_dataset = generate_datasets(
        context_len=32, file="data/tokenized_stories_train_wikitext103.jbl", number=100, min_len=1026, trim=True
    )

    # Fine-tune the openai-community/gpt2 model using IGF (Information Gain Filtration)
    finetune(
        model,
        train_dataset,
        test_dataset,
        context_len=32,
        max_steps=1000,
        batch_size=16,
        threshold=1.0,
        recopy_model=recopy_gpt2,
        secondary_learner=secondary_learner,
        eval_interval=10,
        finetuned_model_name="openai-community/gpt2_finetuned.pt",
    )
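
After finetune() completes, the fine-tuned weights can be reused for inference. The sketch below is a minimal, hypothetical example: it assumes finetune() saves the model's state_dict with torch.save under the finetuned_model_name path (as the .pt file name suggests); adapt the loading step if the script stores the model differently.

import torch

from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Rebuild the base model and load the IGF fine-tuned weights saved by finetune().
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")
state_dict = torch.load("openai-community/gpt2_finetuned.pt", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# Generate a short continuation with the fine-tuned model.
inputs = tokenizer("The history of natural language processing", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))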