# train() — extracted from distilvit/train.py [0:0]

def _load_model(args):
    """Load the model to train and derive a human-readable model name.

    Either resumes from an existing ``args.base_model`` (optionally pinned
    to ``args.base_model_revision``) for fine-tuning, or stitches a fresh
    encoder/decoder pair together from ``args.encoder_model`` and
    ``args.decoder_model``.

    Returns a ``(model, model_name)`` tuple.
    """
    if args.base_model:
        if args.base_model_revision:
            model = VisionEncoderDecoderModel.from_pretrained(
                args.base_model, revision=args.base_model_revision
            )
        else:
            model = VisionEncoderDecoderModel.from_pretrained(args.base_model)
        return model, f"{args.base_model}+fine-tuned"

    model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
        args.encoder_model, args.decoder_model
    )
    model_name = (
        f"{args.encoder_model.split('/')[-1]}-{args.decoder_model.split('/')[-1]}"
    )
    return model, model_name


def _load_datasets(args):
    """Build every dataset named in ``args.dataset``, merge, and shuffle.

    Each entry is looked up in the ``DATASETS`` registry; splits with the
    same name are concatenated across datasets, then the whole thing is
    shuffled with a fixed seed for reproducibility.
    """
    print("Sources", args.dataset)
    datasets = [
        DATASETS[name](
            args.feature_extractor_model,
            args.decoder_model,
            args=args,
        )
        for name in args.dataset
    ]
    print("Datasets loaded", datasets)

    # Concatenate per split; assumes every dataset exposes the same splits
    # as the first one — TODO confirm against the DATASETS builders.
    combined = DatasetDict()
    for split in datasets[0].keys():
        combined[split] = concatenate_datasets([ds[split] for ds in datasets])

    ds = combined.shuffle(seed=THE_ANSWER_TO_LIFE_THE_UNIVERSE_AND_EVERYTHING)
    print("Datasets combined and shuffled", ds)
    return ds


def _build_trainer(args, model, feature_extractor, tokenizer, rouge, meteor, ds):
    """Assemble the Seq2SeqTrainer: training args, metrics, and callbacks."""
    training_args = dict(
        predict_with_generate=True,
        evaluation_strategy="steps",
        save_strategy="steps",
        per_device_train_batch_size=50,
        per_device_eval_batch_size=50,
        num_train_epochs=args.num_train_epochs,
        output_dir=args.checkpoints_dir,
        metric_for_best_model="eval_rougeL",
        save_total_limit=10,
        load_best_model_at_end=True,
        eval_steps=args.eval_steps,
        save_steps=args.save_steps,
        report_to="wandb",
        generation_num_beams=2,
        generation_max_length=50,
    )

    # When fine-tuning, reuse the generation config published under the
    # target model id rather than the defaults.
    if args.base_model:
        training_args["generation_config"] = args.model_id

    training_args = Seq2SeqTrainingArguments(**training_args)

    metrics_logger_callback = MetricsLoggerCallback(
        os.path.join(args.checkpoints_dir, "metrics.txt")
    )

    return Seq2SeqTrainer(
        model=model,
        # The image processor is handed to the trainer as "tokenizer" so it
        # is saved alongside the model checkpoints.
        tokenizer=feature_extractor,
        args=training_args,
        compute_metrics=partial(
            compute_metrics,
            tokenizer,
            rouge,
            meteor,
            args=args,
        ),
        train_dataset=ds["train"],
        eval_dataset=ds["validation"],
        data_collator=partial(data_collator, tokenizer),
        callbacks=[
            EarlyStoppingCallback(early_stopping_patience=3),
            metrics_logger_callback,
        ],
    )


def _quantize_model(save_path):
    """Quantize the saved model by invoking the quantize CLI in-process.

    ``quantize()`` reads its options from ``sys.argv``, so the argv vector
    is swapped for the duration of the call and always restored.
    """
    q_args = [
        "quantize",
        "--model_id",
        save_path,
        "--quantize",
        "--task",
        "image-to-text-with-past",
    ]
    old = sys.argv
    sys.argv = q_args
    try:
        quantize()
    finally:
        sys.argv = old


def train(args):
    """Fine-tune (or train from scratch) an image-captioning model.

    Loads the vision-encoder/text-decoder model and tokenizer, assembles
    and shuffles the datasets named in ``args.dataset``, runs a Seq2Seq
    training loop with early stopping (resuming from the last checkpoint
    when one exists), saves and quantizes the result, and optionally
    pushes it to the hub.
    """
    get_nltk()
    rouge = evaluate.load("rouge")
    meteor = evaluate.load("meteor")

    feature_extractor = AutoImageProcessor.from_pretrained(args.feature_extractor_model)
    model, model_name = _load_model(args)

    # freeze_model_layers(model, freeze_encoder_layers=3, freeze_decoder_layers=3)

    args.device = torch.device(args.device)
    print("Using device", args.device)
    model.to(args.device)

    tokenizer = AutoTokenizer.from_pretrained(args.decoder_model)
    # GPT2 only has bos/eos tokens but not decoder_start/pad tokens
    tokenizer.pad_token = tokenizer.eos_token

    # update the model config so generation knows how to start, stop and pad
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.decoder_start_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    save_path = os.path.join(args.save_dir, model_name)

    ds = _load_datasets(args)
    os.makedirs(args.checkpoints_dir, exist_ok=True)

    last_checkpoint = get_last_checkpoint(args.checkpoints_dir)
    trainer = _build_trainer(args, model, feature_extractor, tokenizer, rouge, meteor, ds)

    if last_checkpoint is not None:
        trainer.train(resume_from_checkpoint=last_checkpoint)
    else:
        trainer.train()
    trainer.save_model(save_path)
    tokenizer.save_pretrained(save_path)

    # quantize model
    _quantize_model(save_path)

    print(f"Model saved to {save_path}. You may need to copy in model card in docs directory.")

    if args.push_to_hub:
        push_to_hub(args.model_id, save_path, args.tag, "New training")