in src/jobs/tune_t5.py
# Imports this method relies on (module level in the original file):
import pandas as pd
import wandb
from datasets import Dataset
from sklearn.model_selection import train_test_split
from transformers import T5ForConditionalGeneration, T5Tokenizer

def setup_data(self, topic_data: pd.DataFrame, validation: pd.DataFrame, filename: str = "unknown"):
    self.filename = filename
    self.tokenizer = T5Tokenizer.from_pretrained(self.model_name)
    if self.training_data_files is not None:
        print(f"Training data files are {self.training_data_files}")
    # Resume from a W&B model artifact when one is configured; otherwise
    # start from the base pretrained checkpoint.
    if self.model_start_artifact:
        artifact = wandb.run.use_artifact(self.model_start_artifact, type="model")
        artifact_dir = artifact.download()
        self.model = T5ForConditionalGeneration.from_pretrained(artifact_dir)
    else:
        self.model = T5ForConditionalGeneration.from_pretrained(self.model_name)
    topic_data = self.prep_dataframe(topic_data)
    validation = self.prep_dataframe(validation)
    # When a label prefix was applied during preprocessing, read labels
    # from the processed column instead of the configured default.
    if self.label_prefix is not None:
        self.label_column = PROCESSED_COLUMN
    # Hold out 10% of the topic data as an eval split for training.
    topic_data_training, topic_data_eval = train_test_split(topic_data, test_size=0.1)
    train_data_dict = {"input_text": topic_data_training[INPUT_PROMPT_ID].to_list(),
                       "target_text": topic_data_training[self.label_column].to_list()}
    eval_data_dict = {"input_text": topic_data_eval[INPUT_PROMPT_ID].to_list(),
                      "target_text": topic_data_eval[self.label_column].to_list()}
    validation_data_dict = {"input_text": validation[INPUT_PROMPT_ID].to_list(),
                            "target_text": validation[self.label_column].to_list()}
    self.train_dataset = Dataset.from_pandas(pd.DataFrame(train_data_dict))
    self.eval_dataset = Dataset.from_pandas(pd.DataFrame(eval_data_dict))
    self.validation_dataset = Dataset.from_pandas(pd.DataFrame(validation_data_dict))
    print(f"Training Dataset size {len(self.train_dataset)}")
    print(f"Eval Dataset size {len(self.eval_dataset)}")
    print(f"Validation Dataset size {len(self.validation_dataset)}")