in distilvit/train.py [0:0]
def train(args):
    """Train an image-captioning model and package the result.

    Loads (or assembles) a VisionEncoderDecoderModel, fine-tunes it on the
    datasets named in ``args.dataset``, evaluates with ROUGE/METEOR, saves
    the best model, quantizes it, and optionally pushes it to the Hub.

    Args:
        args: parsed CLI namespace; this function reads (among others)
            base_model, base_model_revision, encoder_model, decoder_model,
            feature_extractor_model, device, dataset, save_dir,
            checkpoints_dir, num_train_epochs, eval_steps, save_steps,
            model_id, tag, push_to_hub.
    """
    get_nltk()

    # Evaluation metrics consumed by compute_metrics during eval steps.
    rouge_metric = evaluate.load("rouge")
    meteor_metric = evaluate.load("meteor")

    feature_extractor = AutoImageProcessor.from_pretrained(args.feature_extractor_model)

    # Either continue from an existing encoder-decoder checkpoint, or stitch
    # a fresh model together from separate encoder and decoder checkpoints.
    if args.base_model:
        revision = args.base_model_revision
        if revision:
            model = VisionEncoderDecoderModel.from_pretrained(
                args.base_model, revision=revision
            )
        else:
            model = VisionEncoderDecoderModel.from_pretrained(args.base_model)
        model_name = f"{args.base_model}+fine-tuned"
    else:
        model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
            args.encoder_model, args.decoder_model
        )
        model_name = (
            f"{args.encoder_model.split('/')[-1]}-{args.decoder_model.split('/')[-1]}"
        )

    # NOTE(review): layer freezing is intentionally left disabled here.
    # freeze_model_layers(model, freeze_encoder_layers=3, freeze_decoder_layers=3)

    args.device = torch.device(args.device)
    print("Using device", args.device)
    model.to(args.device)

    tokenizer = AutoTokenizer.from_pretrained(args.decoder_model)
    # GPT2 only has bos/eos tokens but not decoder_start/pad tokens
    tokenizer.pad_token = tokenizer.eos_token
    # Mirror the tokenizer's special tokens into the model config so
    # generation and padding use consistent ids.
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.decoder_start_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    save_path = os.path.join(args.save_dir, model_name)

    print("Sources", args.dataset)
    loaded_sets = [
        DATASETS[name](
            args.feature_extractor_model,
            args.decoder_model,
            args=args,
        )
        for name in args.dataset
    ]
    print("Datasets loaded", loaded_sets)

    # Merge the per-source datasets split by split, then shuffle with a
    # fixed seed for reproducibility.
    combined = DatasetDict()
    for split in loaded_sets[0].keys():
        combined[split] = concatenate_datasets([d[split] for d in loaded_sets])
    ds = combined.shuffle(seed=THE_ANSWER_TO_LIFE_THE_UNIVERSE_AND_EVERYTHING)
    print("Datasets combined and shuffled", ds)

    os.makedirs(args.checkpoints_dir, exist_ok=True)

    trainer_kwargs = {
        "predict_with_generate": True,
        "evaluation_strategy": "steps",
        "save_strategy": "steps",
        "per_device_train_batch_size": 50,
        "per_device_eval_batch_size": 50,
        "num_train_epochs": args.num_train_epochs,
        "output_dir": args.checkpoints_dir,
        "metric_for_best_model": "eval_rougeL",
        "save_total_limit": 10,
        "load_best_model_at_end": True,
        "eval_steps": args.eval_steps,
        "save_steps": args.save_steps,
        "report_to": "wandb",
        "generation_num_beams": 2,
        "generation_max_length": 50,
    }
    if args.base_model:
        # When fine-tuning, point generation_config at args.model_id —
        # presumably a hub model id providing the generation config; verify.
        trainer_kwargs["generation_config"] = args.model_id
    training_args = Seq2SeqTrainingArguments(**trainer_kwargs)

    last_checkpoint = get_last_checkpoint(args.checkpoints_dir)
    metrics_logger_callback = MetricsLoggerCallback(
        os.path.join(args.checkpoints_dir, "metrics.txt")
    )

    trainer = Seq2SeqTrainer(
        model=model,
        tokenizer=feature_extractor,
        args=training_args,
        compute_metrics=partial(
            compute_metrics,
            tokenizer,
            rouge_metric,
            meteor_metric,
            args=args,
        ),
        train_dataset=ds["train"],
        eval_dataset=ds["validation"],
        data_collator=partial(data_collator, tokenizer),
        callbacks=[
            EarlyStoppingCallback(early_stopping_patience=3),
            metrics_logger_callback,
        ],
    )

    # Resume from the newest checkpoint when one exists.
    if last_checkpoint is not None:
        trainer.train(resume_from_checkpoint=last_checkpoint)
    else:
        trainer.train()

    trainer.save_model(save_path)
    tokenizer.save_pretrained(save_path)

    # Quantize the saved model by invoking the quantize CLI in-process;
    # its argument parser reads sys.argv, so swap it temporarily and
    # always restore the original afterwards.
    saved_argv = sys.argv
    sys.argv = [
        "quantize",
        "--model_id",
        save_path,
        "--quantize",
        "--task",
        "image-to-text-with-past",
    ]
    try:
        quantize()
    finally:
        sys.argv = saved_argv

    print(f"Model saved to {save_path}. You may need to copy in model card in docs directory.")
    if args.push_to_hub:
        push_to_hub(args.model_id, save_path, args.tag, "New training")