in information-gain-filtration/run_clm_igf.py [0:0]
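A minimal sketch of the imports this excerpt assumes from the top of run_clm_igf.py (the IGF helpers generate_n_pairs, training_secondary_learner, generate_datasets, finetune, recopy_gpt2, and set_seed are defined by the example itself, not by transformers):

import argparse

import joblib

from transformers import GPT2LMHeadModel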
def main():
parser = argparse.ArgumentParser(description="Fine-tune a transformer model with IGF on a language modeling task")
# Required parameters
parser.add_argument(
"--data_dir",
default=None,
type=str,
required=True,
help="The input data dir. Should contain data files for WikiText.",
)
parser.add_argument(
"--model_name_or_path",
default=None,
type=str,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models",
)
parser.add_argument(
"--data_file",
type=str,
default=None,
help=(
"A .jbl (joblib) file containing tokenized data that can be split into an objective set, "
"a train dataset, and a test dataset."
),
)
parser.add_argument(
"--igf_data_file",
type=str,
default=None,
help="A .jbl (joblib) file containing (context, information gain) pairs used to train the secondary learner.",
)
parser.add_argument(
"--output_dir",
default=None,
type=str,
required=True,
help="The output directory where the final fine-tuned model is stored.",
)
parser.add_argument(
"--tokenizer_name",
default=None,
type=str,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--context_len",
default=32,
type=int,
help=(
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
),
)
parser.add_argument(
"--size_objective_set",
default=100,
type=int,
help="Number of articles that are long enough to be used as the objective set",
)
parser.add_argument(
"--eval_freq", default=100, type=int, help="Evaluate the secondary learner every eval_freq training steps"
)
parser.add_argument(
"--max_steps", default=1000, type=int, help="Maximum number of fine-tuning steps, used to derive the number of training epochs"
)
parser.add_argument(
"--secondary_learner_batch_size",
default=128,
type=int,
help="Batch size used to train the secondary learner",
)
parser.add_argument(
"--batch_size",
default=16,
type=int,
help="Batch size of the training data for the language model (openai-community/gpt2)",
)
parser.add_argument(
"--eval_interval",
default=10,
type=int,
help=(
"Decay the selectivity of the secondary learner filter from one standard deviation above the "
"average to one standard deviation below it over this many batches"
),
)
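# A rough sketch of the decaying cutoff described above (hypothetical helper, not the example's
# actual implementation): the cutoff slides linearly from one standard deviation above the mean
# predicted information gain down to one standard deviation below it over eval_interval batches.
#
# def filter_cutoff(batch_idx, eval_interval, ig_mean, ig_std):
#     frac = min(batch_idx / eval_interval, 1.0)
#     return ig_mean + (1.0 - 2.0 * frac) * ig_std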
parser.add_argument(
"--number", default=100, type=int, help="Number of examples to split off as the objective set / test data"
)
parser.add_argument(
"--min_len", default=1026, type=int, help="Minimum article length, in tokens, required for the objective set"
)
parser.add_argument(
"--secondary_learner_max_epochs", default=15, type=int, help="Number of epochs to train the secondary learner"
)
parser.add_argument("--trim", default=True, type=bool, help="truncate the example if it exceeds context length")
parser.add_argument(
"--threshold",
default=1.0,
type=float,
help=(
"The threshold value used by the secondary learner to filter train_data, passing only"
" informative examples to the language model"
),
)
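# Illustrative sketch of the rule this threshold controls (variable names are assumptions, not
# the example's actual code): a batch contributes a gradient step only when the secondary
# learner predicts its information gain to be above the threshold.
#
# predicted_ig = secondary_learner(context)
# if predicted_ig > args.threshold:
#     take_gradient_step(batch)  # batches below the threshold are skipped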
parser.add_argument(
"--finetuned_model_name", default="gpt2_finetuned.pt", type=str, help="File name under which the fine-tuned model is saved"
)
parser.add_argument(
"--recopy_model",
default=recopy_gpt2,
help="Callable that resets the model to the original pretrained GPT-2 weights after each iteration (a function, so not settable from the command line)",
)
args = parser.parse_args()

# Function calls
# Collect *n* pairs of context and information gain (X, IG(X)) for training the secondary learner;
# pair collection only needs a small, fixed number of steps
generate_n_pairs(
context_len=args.context_len,
max_steps=10,
size_objective_set=args.size_objective_set,
min_len=args.min_len,
trim=args.trim,
data_file=args.data_file,
igf_data_file=args.igf_data_file,
)
# Load the (context, information gain) pairs to train the secondary learner
secondary_learner_train_data = joblib.load(args.igf_data_file)
# Train the secondary learner, which predicts the information gain IG(X) of a context X
secondary_learner = training_secondary_learner(
secondary_learner_train_data,
secondary_learner_max_epochs=args.secondary_learner_max_epochs,
secondary_learner_batch_size=args.secondary_learner_batch_size,
eval_freq=args.eval_freq,
igf_model_path="igf_model.pt",
)
# Load the pretrained language model
model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
if args.seed is not None: set_seed(args.seed)
# Generate train and test datasets to fine-tune and evaluate the openai-community/gpt2 model
train_dataset, test_dataset = generate_datasets(
context_len=args.context_len, file=args.data_file, number=args.number, min_len=args.min_len, trim=args.trim
)
# Fine-tune the openai-community/gpt2 model using IGF (Information Gain Filtration)
finetune(
model,
train_dataset,
test_dataset,
context_len=args.context_len,
max_steps=args.max_steps,
batch_size=args.batch_size,
threshold=args.threshold,
recopy_model=args.recopy_model,
secondary_learner=secondary_learner,
eval_interval=args.eval_interval,
finetuned_model_name=args.finetuned_model_name,
)
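A hypothetical invocation, assuming the WikiText data has already been tokenized into the .jbl files named above (all paths are illustrative):

python run_clm_igf.py \
--data_dir ./data \
--model_name_or_path openai-community/gpt2 \
--data_file data/tokenized_stories_train_wikitext103.jbl \
--igf_data_file igf_context_pairs.jbl \
--output_dir ./igf_output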