def train()

in src/jobs/tune_t5.py [0:0]
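Fine-tunes the T5 topic model with Weights & Biases tracking. The run optionally "shrinks" the model by removing encoder/decoder layers after an initial training pass and retraining, evaluates on the eval and single-tab validation sets, saves and uploads the artifacts, then reloads them from disk for a final validation pass.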


    def train(self):
        torch.cuda.empty_cache()
        # Shrinking is requested when layers are slated for removal, either by
        # count or by explicit index list.
        shrink_model = (self.shrink_remove_encoder_layers > 0
                        or self.shrink_remove_decoder_layers > 0
                        or self.shrink_decoder_index_remove
                        or self.shrink_encoder_index_remove)

        os.environ["WANDB_LOG_MODEL"] = "end"  # save the model to WandB
        config = {"learning_rate": self.learning_rate,
                  "batch_size": self.batch_size,
                  "model_name": self.model_name,
                  "label_column": self.label_column,
                  "brevity_weight": self.brevity_weight,
                  "use_keywords": self.use_keywords,
                  "learning_rate_decay": self.learning_rate_decay,
                  "single_tab_handling": self.single_tab_handling,
                  "shrink_encoder_index_remove": self.shrink_encoder_index_remove,
                  "shrink_decoder_index_remove": self.shrink_decoder_index_remove,
                  "shrink_remove_encoder_layers": self.shrink_remove_encoder_layers,
                  "shrink_remove_decoder_layers": self.shrink_remove_decoder_layers,
                  "shorten_training_label_boost": self.shorten_training_label_boost,
                  "model_start_artifact": self.model_start_artifact,
                  "input_prompt_id": INPUT_PROMPT_ID,
                  "filename": self.filename}
        wandb.init(project="tab_grouping",
                   config=config)
        print(f"W&B Run ID: {wandb.run.id}")
        print(f"W&B Run Name: {wandb.run.name}")
        print(json.dumps(config))
        tokenized_training_dataset = self.train_dataset.map(self.preprocess_function, batched=True)
        tokenized_eval_dataset = self.eval_dataset.map(self.preprocess_function, batched=True)

        training_args_kwargs = dict(
            output_dir="./results",
            evaluation_strategy="epoch",
            learning_rate=self.learning_rate,
            per_device_train_batch_size=self.batch_size,
            per_device_eval_batch_size=1,
            num_train_epochs=2,
            weight_decay=0.01,
            save_total_limit=1,
            save_strategy="epoch",
            warmup_ratio=0.05,
        )
        if self.learning_rate_decay:
            # Cosine decay instead of the Trainer's default linear schedule.
            training_args_kwargs["lr_scheduler_type"] = "cosine"
        training_args = TrainingArguments(**training_args_kwargs)


        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=tokenized_training_dataset,
            eval_dataset=tokenized_eval_dataset
        )

        do_first_run = True  # Skipping the first run was tried, but it gave worse results.

        tokenized_validation_dataset = self.validation_dataset.map(self.preprocess_function, batched=True)

        if do_first_run:
            trainer.train()
            if shrink_model:
                self.run_eval(tokenized_eval_dataset, name="Pre Shrink Eval", prefix="preshrink")
            else:
                self.run_eval(tokenized_eval_dataset)
            self.run_eval(tokenized_validation_dataset, name="Single Tab Validation", prefix="single_tab_val")

        has_second_train_run = False
        if shrink_model:
            has_second_train_run = True
            # Remove the requested encoder/decoder layers, then retrain so the
            # remaining layers can adapt to the smaller architecture.
            self.shrink_remove_layers(self.model, "encoder", self.shrink_remove_encoder_layers,
                                      self.shrink_encoder_index_remove)
            self.shrink_remove_layers(self.model, "decoder", self.shrink_remove_decoder_layers,
                                      self.shrink_decoder_index_remove)
        if has_second_train_run:
            # A short recovery run suffices if the model already trained in the
            # first pass; train longer when that pass was skipped.
            training_args.num_train_epochs = 1 if do_first_run else 3
            trainer = Trainer(
                model=self.model,
                args=training_args,
                train_dataset=tokenized_training_dataset,
                eval_dataset=tokenized_eval_dataset
            )
            trainer.train()
            trainer.train()
            name_prefix = "Post Remove"
            dataset_prefix = "post_remove"
            self.run_eval(tokenized_eval_dataset, name=f"{name_prefix} Eval", prefix=f"{dataset_prefix}_eval")
            tokenized_validation_dataset = self.validation_dataset.map(self.preprocess_function, batched=True)
            self.run_eval(tokenized_validation_dataset, name="Post Remove Single Tab Validation",
                          prefix="post_remove_single_tab_val")

        if FINE_TUNE_BLOCK_BAD_WORDS:
            # Ban the configured token sequences at generation time.
            self.model.generation_config.update(bad_words_ids=get_bad_word_ids())
        local_save_name = "./t5-finetuned-topic"

        self.model.save_pretrained(local_save_name)
        self.tokenizer.save_pretrained(local_save_name)

        # Version the upload with an ISO timestamp; ":" is replaced because it
        # is not path-safe.
        current_date = datetime.now()
        date_string = current_date.isoformat().replace(":", "_")
        upload_directory(local_save_name, "stage-fx-tab-grouping", f"topic/models/{date_string}/", depth=1)

        # Free GPU memory, then reload the saved artifacts from disk so the
        # final validation pass exercises exactly what was uploaded.
        del self.model
        torch.cuda.empty_cache()

        self.tokenizer = T5Tokenizer.from_pretrained(local_save_name)
        self.model = T5ForConditionalGeneration.from_pretrained(local_save_name).to(self.device)
        self.run_eval(tokenized_validation_dataset, name="2-Post Remove Single Tab Validation",
                      prefix="2_post_remove_single_tab_val")

        wandb.finish()
        # Drop the last reference so GPU memory is actually released.
        self.model = None
        torch.cuda.empty_cache()
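
The helpers called above are defined elsewhere in the repo. For context, `preprocess_function` (used to tokenize all three datasets) typically follows the standard seq2seq pattern; this is a minimal sketch, where the input column name and max lengths are assumptions, not the repo's actual code:

    def preprocess_function(self, examples):
        # Tokenize inputs and targets for seq2seq training; "input_text" and
        # the max lengths are illustrative assumptions.
        model_inputs = self.tokenizer(examples["input_text"],
                                      max_length=512, truncation=True)
        labels = self.tokenizer(text_target=examples[self.label_column],
                                max_length=64, truncation=True)
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs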
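`shrink_remove_layers` is also defined elsewhere. A minimal sketch of the technique it names, dropping `T5Block`s from a stack by trailing count or by explicit index, might look like the following; everything here is an assumption about the implementation:

    from torch import nn

    def shrink_remove_layers(model, stack_name, remove_last_n=0, remove_indices=None):
        # Hypothetical sketch: drop transformer blocks from a T5 stack.
        stack = model.encoder if stack_name == "encoder" else model.decoder
        keep = list(range(len(stack.block)))
        if remove_indices:
            keep = [i for i in keep if i not in set(remove_indices)]
        if remove_last_n > 0:
            keep = keep[:-remove_last_n]
        # Caveat: block 0 holds T5's relative position bias, so removing it
        # changes attention for the whole stack.
        stack.block = nn.ModuleList(stack.block[i] for i in keep)
        # Keep the config in sync so save_pretrained/from_pretrained round-trips.
        if stack_name == "encoder":
            model.config.num_layers = len(keep)
        else:
            model.config.num_decoder_layers = len(keep)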
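`get_bad_word_ids` feeds `bad_words_ids`, which Hugging Face generation expects as a list of token-id lists. A minimal sketch, with the tokenizer and word list passed in as placeholders (the repo's version takes no arguments):

    def get_bad_word_ids(tokenizer, blocked_words):
        # bad_words_ids is a list of token-id sequences; add_special_tokens=False
        # keeps EOS/pad tokens out of the banned sequences.
        return [tokenizer(word, add_special_tokens=False).input_ids
                for word in blocked_words]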
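`upload_directory` pushes the saved model to the `stage-fx-tab-grouping` bucket. A minimal sketch using the `google-cloud-storage` client, assuming GCS and a simple reading of the `depth` parameter (upload files at most `depth` directory levels deep):

    import os
    from google.cloud import storage

    def upload_directory(local_dir, bucket_name, prefix, depth=1):
        # Hypothetical sketch, not the repo's implementation.
        bucket = storage.Client().bucket(bucket_name)
        base = os.path.normpath(local_dir)
        for root, _dirs, files in os.walk(base):
            rel = os.path.relpath(root, base)
            level = 0 if rel == "." else rel.count(os.sep) + 1
            if level >= depth:
                continue
            for name in files:
                rel_path = name if rel == "." else f"{rel}/{name}"
                bucket.blob(prefix + rel_path).upload_from_filename(os.path.join(root, name))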