def main()

in zero-shot-distillation/distill_classifier.py

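The excerpt omits the module-level imports. A minimal sketch of what it relies on is below, assuming the argument dataclasses, DESCRIPTION, read_lines, get_teacher_predictions, and DistillationTrainer are defined elsewhere in distill_classifier.py:

import logging
import os
import sys

from datasets import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
    set_seed,
    utils,
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process

logger = logging.getLogger(__name__)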

def main():
    parser = HfArgumentParser(
        (DataTrainingArguments, TeacherModelArguments, StudentModelArguments, DistillTrainingArguments),
        description=DESCRIPTION,
    )

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        data_args, teacher_args, student_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        data_args, teacher_args, student_args, training_args = parser.parse_args_into_dataclasses()

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)

    # Log a short summary on each process:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the Transformers logger verbosity to info (on main process only):
    if is_main_process(training_args.local_rank):
        utils.logging.set_verbosity_info()
        utils.logging.enable_default_handler()
        utils.logging.enable_explicit_format()

    if training_args.local_rank != -1:
        raise ValueError("Distributed training is not currently supported.")
    if training_args.tpu_num_cores is not None:
        raise ValueError("TPU acceleration is not currently supported.")

    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # 1. read in data
    examples = read_lines(data_args.data_file)
    class_names = read_lines(data_args.class_names_file)

    # 2. get teacher predictions and load into dataset
    logger.info("Generating predictions from zero-shot teacher model")
    teacher_soft_preds = get_teacher_predictions(
        teacher_args.teacher_name_or_path,
        examples,
        class_names,
        teacher_args.hypothesis_template,
        teacher_args.teacher_batch_size,
        teacher_args.temperature,
        teacher_args.multi_label,
        data_args.use_fast_tokenizer,
        training_args.no_cuda,
        training_args.fp16,
    )
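    # Each "labels" entry is the teacher's soft distribution over class_names,
    # not a hard class id (see compute_metrics below).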
    dataset = Dataset.from_dict(
        {
            "text": examples,
            "labels": teacher_soft_preds,
        }
    )

    # 3. create student
    logger.info("Initializing student model")
    model = AutoModelForSequenceClassification.from_pretrained(
        student_args.student_name_or_path, num_labels=len(class_names)
    )
    tokenizer = AutoTokenizer.from_pretrained(student_args.student_name_or_path, use_fast=data_args.use_fast_tokenizer)
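    # Store the class-name mappings on the config so the saved student emits readable labels.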
    model.config.id2label = dict(enumerate(class_names))
    model.config.label2id = {label: i for i, label in enumerate(class_names)}

    # 4. train student on teacher predictions
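    # Running the tokenizer through `map` adds `input_ids`/`attention_mask` (and related)
    # columns alongside the existing "text" and "labels" columns.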
    dataset = dataset.map(tokenizer, input_columns="text")
    dataset.set_format("torch")

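    # Agreement: the fraction of examples on which the student's argmax prediction
    # matches the teacher's most-likely class.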
    def compute_metrics(p):
        preds = p.predictions.argmax(-1)
        proxy_labels = p.label_ids.argmax(-1)  # "label_ids" are actually distributions
        return {"agreement": (preds == proxy_labels).mean().item()}

    trainer = DistillationTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=dataset,
        compute_metrics=compute_metrics,
    )

    if training_args.do_train:
        logger.info("Training student model on teacher predictions")
        trainer.train()

    if training_args.do_eval:
        agreement = trainer.evaluate(eval_dataset=dataset)["eval_agreement"]
        logger.info(f"Agreement of student and teacher predictions: {agreement * 100:0.2f}%")

    trainer.save_model()
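Because step 3 writes the class names into the model config (id2label/label2id), the distilled student saved by trainer.save_model() is a standard text classification checkpoint. A minimal usage sketch, assuming a hypothetical output directory ./distilled-student:

from transformers import pipeline

# id2label from the saved config lets the pipeline return class names directly.
classifier = pipeline("text-classification", model="./distilled-student")
print(classifier("Who are you voting for in 2020?"))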