in src/jobs/tune_t5.py [0:0]
def train(self):
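"""Fine-tune the model, log metrics to W&B, optionally shrink it by removing layers, and upload the result."""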
torch.cuda.empty_cache()
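# Shrinking is requested when any encoder or decoder layers are marked for removal.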
shrink_model = (self.shrink_remove_encoder_layers > 0 or self.shrink_remove_decoder_layers > 0 or
self.shrink_decoder_index_remove or self.shrink_encoder_index_remove)
os.environ["WANDB_LOG_MODEL"] = "end" # save the model to WandB
config = {"learning_rate": self.learning_rate, "batch_size": self.batch_size,
"model_name": self.model_name,
"label_column": self.label_column,
"brevity_weight": self.brevity_weight,
"use_keywords": self.use_keywords,
"learning_rate_decay": self.learning_rate_decay,
"single_tab_handling": self.single_tab_handling,
"shrink_encoder_index_remove": self.shrink_encoder_index_remove,
"shrink_decoder_index_remove": self.shrink_decoder_index_remove,
"shrink_remove_encoder_layers": self.shrink_remove_encoder_layers,
"shrink_remove_decoder_layers": self.shrink_remove_decoder_layers,
"shorten_training_label_boost": self.shorten_training_label_boost,
"model_start_artifact": self.model_start_artifact,
"input_prompt_id": INPUT_PROMPT_ID, "filename": self.filename}
wandb.init(project="tab_grouping",
config=config)
print(f"W&B Run ID: {wandb.run.id}")
print(f"W&B Run Name: {wandb.run.name}")
print(json.dumps(config))
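# Tokenize the datasets up front so the Trainer receives model-ready features.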
tokenized_training_dataset = self.train_dataset.map(self.preprocess_function, batched=True)
tokenized_eval_dataset = self.eval_dataset.map(self.preprocess_function, batched=True)
# Shared TrainingArguments; the cosine schedule is only used when learning-rate
# decay is requested, otherwise the Trainer default (linear) applies.
training_args_kwargs = dict(
output_dir="./results",
evaluation_strategy="epoch",
learning_rate=self.learning_rate,
per_device_train_batch_size=self.batch_size,
per_device_eval_batch_size=1,
num_train_epochs=2,
weight_decay=0.01,
save_total_limit=1,
save_strategy="epoch",
warmup_ratio=0.05)
if self.learning_rate_decay:
training_args_kwargs["lr_scheduler_type"] = "cosine"
training_args = TrainingArguments(**training_args_kwargs)
trainer = Trainer(
model=self.model,
args=training_args,
train_dataset=tokenized_training_dataset,
eval_dataset=tokenized_eval_dataset  # evaluate on the held-out split, not the training data
)
do_first_run = True  # Skipping the first training run was tested, but it produced worse results
tokenized_validation_dataset = self.validation_dataset.map(self.preprocess_function, batched=True)
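# First training pass on the full, unshrunk model.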
if do_first_run:
trainer.train()
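# Evaluate before any shrinking so the post-shrink metrics have a baseline.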
if shrink_model:
self.run_eval(tokenized_eval_dataset, name="Pre Shrink Eval", prefix="preshrink")
else:
self.run_eval(tokenized_eval_dataset)
self.run_eval(tokenized_validation_dataset, name="Single Tab Validation", prefix="single_tab_val")
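# Removing layers degrades quality, so a recovery training pass follows the shrink.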
has_second_train_run = bool(shrink_model)
if shrink_model:
self.shrink_remove_layers(self.model, "encoder", self.shrink_remove_encoder_layers,
self.shrink_encoder_index_remove)
self.shrink_remove_layers(self.model, "decoder", self.shrink_remove_decoder_layers,
self.shrink_decoder_index_remove)
if has_second_train_run:
training_args.num_train_epochs = 1 if do_first_run else 3  # a short recovery pass suffices after a full first run
trainer = Trainer(
model=self.model,
args=training_args,
train_dataset=tokenized_training_dataset,
eval_dataset=tokenized_eval_dataset  # evaluate on the held-out split, not the training data
)
trainer.train()
name_prefix = "Post Remove"
dataset_prefix = "post_remove"
self.run_eval(tokenized_eval_dataset, name=f"{name_prefix} Eval", prefix=f"{dataset_prefix}_eval")
tokenized_validation_dataset = self.validation_dataset.map(self.preprocess_function, batched=True)
self.run_eval(tokenized_validation_dataset, name="Post Remove Single Tab Validation",
prefix="post_remove_single_tab_val")
if FINE_TUNE_BLOCK_BAD_WORDS:
self.model.generation_config.update(bad_words_ids=get_bad_word_ids())
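# Persist the fine-tuned model and tokenizer locally, then upload the checkpoint.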
local_save_name = "./t5-finetuned-topic"
self.model.save_pretrained(local_save_name)
self.tokenizer.save_pretrained(local_save_name)
# Timestamp the remote path; colons are replaced so the name stays filesystem- and URL-safe.
date_string = datetime.now().isoformat().replace(":", "_")
upload_directory(local_save_name, "stage-fx-tab-grouping", f"topic/models/{date_string}/", depth=1)
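# Reload the saved artifact from disk to confirm it loads cleanly, then re-run validation on it.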
del self.model
torch.cuda.empty_cache()
self.tokenizer = T5Tokenizer.from_pretrained(local_save_name)
self.model = T5ForConditionalGeneration.from_pretrained(local_save_name).to(self.device)
self.run_eval(tokenized_validation_dataset, name="2-Post Remove Single Tab Validation",
prefix="2_post_remove_single_tab_val")
wandb.finish()
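# Release the model and free GPU memory now that the run is finished.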
self.model = None
torch.cuda.empty_cache()