def train()

in src/jobs/distill_t5.py
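
Distills a fine-tuned T5 teacher, downloaded as a W&B artifact, into the smaller student model held in self.model. The teacher generates targets for the optional unlabeled dataset, which is concatenated with the labeled training data; at epoch 8 the student is shrunk via remove_layers() and the optimizer is rebuilt. Training metrics are logged to Weights & Biases, and the final model is saved locally, uploaded to the stage-fx-tab-grouping bucket, and reloaded for a final validation check.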


    def train(self):
        torch.cuda.empty_cache()

        if self.teacher_model_artifact is None:
            raise ValueError("teacher_model_artifact is required for distillation")

        os.environ["WANDB_LOG_MODEL"] = "end"  # save the model to WandB
        run = wandb.init(project="tab_grouping_distillation",
                         config={"learning_rate": self.learning_rate, "batch_size": self.batch_size,
                                 "model_name": self.model_name,
                                 "teacher_model_artifact": self.teacher_model_artifact,
                                 "label_column": self.label_column,
                                 "use_keywords": self.use_keywords,
                                 "learning_rate_decay": self.learning_rate_decay,
                                 "single_tab_handling": self.single_tab_handling,
                                 "input_prompt_id": INPUT_PROMPT_ID, "filename": self.filename})
        print(f"W&B Run ID: {wandb.run.id}")
        print(f"W&B Run Name: {wandb.run.name}")

        artifact = run.use_artifact(self.teacher_model_artifact, type='model')
        artifact_dir = artifact.download()
        teacher_model = T5ForConditionalGeneration.from_pretrained(artifact_dir)
        teacher_model.resize_token_embeddings(len(self.tokenizer))
        teacher_model.to(self.device)

        # Tokenize all splits; the teacher generates targets for the unlabeled data here
        tokenized_training_dataset = self.train_dataset.map(self.preprocess_function, batched=True)
        if self.unlabeled_dataset is not None:
            tokenized_unlabeled_dataset = self.unlabeled_dataset.map(
                partial(self.preprocess_function, teacher_model=teacher_model),
                batched=True,
                batch_size=128,
            )
            # Combine the teacher-labeled examples with the labeled training set
            tokenized_training_dataset = concatenate_datasets([tokenized_unlabeled_dataset, tokenized_training_dataset])
        tokenized_eval_dataset = self.eval_dataset.map(self.preprocess_function, batched=True)

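        # Precompute decoder_input_ids (right-shifted labels) so the same teacher-forced
        # decoder inputs are fed to both the teacher and the student in the loop below.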
        def add_decoder_ids(item_input_dict):
            decoder_input_ids = teacher_model._shift_right(torch.tensor(item_input_dict["labels"]))
            item_input_dict["decoder_input_ids"] = decoder_input_ids
            return item_input_dict

        tokenized_training_dataset = tokenized_training_dataset.map(add_decoder_ids, batched=True)
        tokenized_training_dataset.set_format(
            type="torch", columns=["input_ids", "attention_mask", "labels", "decoder_input_ids"]
        )

        num_epochs = 40

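        # Clear the bad-words filter while training; it is restored before the model is saved below.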
        self.model.generation_config.update(bad_words_ids=None)

        tokenized_validation_dataset = self.validation_dataset.map(self.preprocess_function, batched=True)

        # Resize the embeddings to the tokenizer's vocabulary and move the student to the device
        # before building the optimizer, so it tracks the live parameters rather than the
        # embedding tensors that resize_token_embeddings() replaces.
        self.model.resize_token_embeddings(len(self.tokenizer))
        self.model.to(self.device)

        train_loader = DataLoader(
            tokenized_training_dataset,
            batch_size=self.batch_size,
            shuffle=True,
        )
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate)
        # Training Loop
        for epoch in range(num_epochs):
            self.model.train()
            print(f"Epoch {epoch}")

            if epoch == 8:
                # Shrink the student by dropping layers partway through training
                self.remove_layers()
                # Re-create the optimizer so it tracks the reduced set of parameters
                optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate)

            batch_index = 0
            running_loss = 0.0
            for batch in train_loader:
                optimizer.zero_grad()

                batch = {k: v.to(self.device) for k, v in batch.items()}

                # Forward pass through the teacher model
                with torch.no_grad():
                    teacher_outputs = teacher_model(**batch)

                # Forward pass through the student model
                student_outputs = self.model(**batch)
                # assert student_outputs.logits.size() == teacher_outputs.logits.size()
                loss = self.calculate_loss(student_outputs, teacher_outputs, batch["labels"])
                # Backpropagation
                loss.backward()
                optimizer.step()

                batch_index += 1
                loss_val = loss.item()
                running_loss += loss_val
                if batch_index % 50 == 0:
                    wandb.log({"cur_loss": loss_val})
#                    progress_bar.set_postfix(loss_value=loss_val)
            avg_loss = running_loss / batch_index
            wandb.log({"train_loss": avg_loss})

            print(f"Average Loss at epoch {epoch}:{avg_loss}")
           # self.run_eval(tokenized_eval_dataset, log_wandb=False)
           # self.run_eval(tokenized_validation_dataset, name="Single Tab Validation", prefix="single_tab_val",
           #               log_wandb=False)

        print(f"**** DISTILLATION COMPLETE")
        self.run_eval(tokenized_eval_dataset)
        self.run_eval(tokenized_validation_dataset, name="Single Tab Validation", prefix="single_tab_val")

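        # Re-apply the bad-words filter that was cleared at the start of training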
        self.model.generation_config.update(bad_words_ids=get_bad_word_ids())
        local_save_name = "./t5-distilled-topic"

        self.model.save_pretrained(local_save_name)
        self.tokenizer.save_pretrained(local_save_name)

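        # Upload the saved model under a timestamped prefix in the stage-fx-tab-grouping bucket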
        current_date = datetime.now()
        date_string = current_date.isoformat().replace(":", "_")
        upload_directory(local_save_name, "stage-fx-tab-grouping", f"topic/models/{date_string}/", depth=1)

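        # Free the in-memory model, then reload the saved files to verify the exported artifacts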
        del self.model
        torch.cuda.empty_cache()

        self.tokenizer = T5Tokenizer.from_pretrained(local_save_name)
        self.model = T5ForConditionalGeneration.from_pretrained(local_save_name).to(self.device)
        self.run_eval(tokenized_validation_dataset, name="Recheck-Post Remove Single Tab Validation",
                      prefix="recheck_post_remove_single_tab_val")

        wandb.finish()
        self.model = None
        torch.cuda.empty_cache()
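
calculate_loss() is defined elsewhere in distill_t5.py and is not shown in this listing. For reference, below is a minimal sketch of a typical sequence-to-sequence distillation loss under the same calling convention (student outputs, teacher outputs, labels). The function name, the temperature/alpha weighting, and the ignore_index=-100 convention are assumptions for illustration, not the repository's actual implementation.

    # Sketch only: a hypothetical distillation loss, not the repository's calculate_loss()
    import torch.nn.functional as F

    def distillation_loss_sketch(student_outputs, teacher_outputs, labels,
                                 temperature=2.0, alpha=0.5):
        """Blend a soft-target KL term with the usual hard-label cross-entropy."""
        s_logits = student_outputs.logits
        t_logits = teacher_outputs.logits

        # Soft targets: KL divergence between temperature-scaled teacher and student
        # distributions, rescaled by T^2 as in the standard distillation formulation
        kl = F.kl_div(
            F.log_softmax(s_logits / temperature, dim=-1),
            F.softmax(t_logits / temperature, dim=-1),
            reduction="batchmean",
        ) * (temperature ** 2)

        # Hard targets: token-level cross-entropy against the labels; ignore_index=-100
        # follows the common Hugging Face padding convention and may need adjusting if
        # labels are padded with the pad token instead
        ce = F.cross_entropy(
            s_logits.view(-1, s_logits.size(-1)),
            labels.view(-1),
            ignore_index=-100,
        )
        return alpha * kl + (1.0 - alpha) * ce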