in zero-shot-distillation/distill_classifier.py [0:0]
def main():
    """Distill a zero-shot teacher classifier into a student sequence classifier.

    Steps:
      1. Parse data / teacher / student / training arguments, either from a
         single ``.json`` path given as the only CLI argument or from the
         command line.
      2. Run the zero-shot teacher over the examples in ``data_file`` with the
         class names from ``class_names_file`` to obtain soft predictions.
      3. Initialize the student with one output label per class name.
      4. Train the student on the teacher predictions (``--do_train``),
         optionally report student/teacher agreement on the same data
         (``--do_eval``), and save the student.

    Raises:
        ValueError: if training into a non-empty ``output_dir`` that holds no
            checkpoint without ``--overwrite_output_dir``; if distributed
            training (``local_rank != -1``) is requested; if TPU cores are
            requested.
    """
    parser = HfArgumentParser(
        (DataTrainingArguments, TeacherModelArguments, StudentModelArguments, DistillTrainingArguments),
        description=DESCRIPTION,
    )
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        data_args, teacher_args, student_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        data_args, teacher_args, student_args, training_args = parser.parse_args_into_dataclasses()

    # Setup logging. This must happen before checkpoint detection: the logger
    # level is only raised to INFO here, so any logger.info() issued earlier
    # would be silently dropped.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        f"distributed training: {training_args.local_rank != -1}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        utils.logging.set_verbosity_info()
        utils.logging.enable_default_handler()
        utils.logging.enable_explicit_format()

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    if training_args.local_rank != -1:
        raise ValueError("Distributed training is not currently supported.")
    if training_args.tpu_num_cores is not None:
        raise ValueError("TPU acceleration is not currently supported.")

    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # 1. read in data
    # NOTE(review): read_lines is defined elsewhere; presumably one item per
    # line in each file — confirm.
    examples = read_lines(data_args.data_file)
    class_names = read_lines(data_args.class_names_file)

    # 2. get teacher predictions and load into dataset
    logger.info("Generating predictions from zero-shot teacher model")
    teacher_soft_preds = get_teacher_predictions(
        teacher_args.teacher_name_or_path,
        examples,
        class_names,
        teacher_args.hypothesis_template,
        teacher_args.teacher_batch_size,
        teacher_args.temperature,
        teacher_args.multi_label,
        data_args.use_fast_tokenizer,
        training_args.no_cuda,
        training_args.fp16,
    )
    dataset = Dataset.from_dict(
        {
            "text": examples,
            "labels": teacher_soft_preds,
        }
    )

    # 3. create student
    logger.info("Initializing student model")
    model = AutoModelForSequenceClassification.from_pretrained(
        student_args.student_name_or_path, num_labels=len(class_names)
    )
    tokenizer = AutoTokenizer.from_pretrained(student_args.student_name_or_path, use_fast=data_args.use_fast_tokenizer)
    # Store the class names in the config so the saved student reports
    # human-readable labels.
    model.config.id2label = dict(enumerate(class_names))
    model.config.label2id = {label: i for i, label in enumerate(class_names)}

    # 4. train student on teacher predictions
    # Truncate to the tokenizer's max length so over-long examples cannot
    # exceed the student's maximum sequence length during training.
    dataset = dataset.map(lambda text: tokenizer(text, truncation=True), input_columns="text")
    dataset.set_format("torch")

    def compute_metrics(p):
        """Return the fraction of examples where student and teacher argmax agree."""
        preds = p.predictions.argmax(-1)
        proxy_labels = p.label_ids.argmax(-1)  # "label_ids" are actually distributions
        return {"agreement": (preds == proxy_labels).mean().item()}

    trainer = DistillationTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=dataset,
        compute_metrics=compute_metrics,
    )

    if training_args.do_train:
        logger.info("Training student model on teacher predictions")
        # Actually resume from the detected checkpoint (None -> train from scratch),
        # as promised by the "Checkpoint detected" message above.
        trainer.train(resume_from_checkpoint=last_checkpoint)

    if training_args.do_eval:
        # Evaluation runs on the same (training) data: the metric measures how
        # closely the student reproduces the teacher, not generalization.
        agreement = trainer.evaluate(eval_dataset=dataset)["eval_agreement"]
        logger.info(f"Agreement of student and teacher predictions: {agreement * 100:0.2f}%")

    trainer.save_model()