in src/jobs/tune_bart.py [0:0]
def setup_data(
    self,
    topic_data: pd.DataFrame,
    filename: str = "unknown",
    test_size: float = 0.2,
    random_state: "int | None" = None,
):
    """Load the BART tokenizer/model and build train/eval datasets from ``topic_data``.

    Each row becomes a prompt via ``self.prompter.generate_prompt`` (from the
    ``input_titles`` and ``input_keywords`` columns). That prompt is paired with
    the row's ``self.label_column`` value as the generation target. The rows are
    split into ``self.train_dataset`` and ``self.eval_dataset``.

    Note: ``topic_data`` is modified in place. Missing ``input_keywords`` are
    replaced with ``""`` and an ``INPUT_PROMPT_ID`` column holding the generated
    prompt is added.

    Args:
        topic_data: Source rows. Must contain ``input_titles``, ``input_keywords``
            and the column named by ``self.label_column``.
        filename: Identifier of the data source, stored on ``self.filename``.
        test_size: Portion of rows held out for evaluation, forwarded to
            ``train_test_split``. Defaults to 0.2, as before.
        random_state: Seed forwarded to ``train_test_split``. Pass an int for a
            reproducible split. ``None`` (default) keeps the previous unseeded
            behaviour.

    Raises:
        ValueError: If ``topic_data`` is empty. Checked before the model is loaded.
    """
    self.filename = filename
    # Fail fast, before the (slow) model download/load, with a clear message.
    if topic_data.empty:
        raise ValueError(
            f"topic_data is empty; cannot build train/eval datasets (filename={filename!r})"
        )

    self.tokenizer = BartTokenizer.from_pretrained(self.model_name)
    self.model = BartForConditionalGeneration.from_pretrained(self.model_name)

    # Bracket assignment instead of attribute assignment: pandas discourages
    # attribute-style column setting (it silently misbehaves for new columns).
    topic_data["input_keywords"] = topic_data["input_keywords"].fillna("")
    topic_data[INPUT_PROMPT_ID] = topic_data.apply(
        lambda row: self.prompter.generate_prompt(
            row.input_titles, row.input_keywords
        ),
        axis=1,
    )

    topic_data_training, topic_data_eval = train_test_split(
        topic_data, test_size=test_size, random_state=random_state
    )

    def _to_dataset(frame: pd.DataFrame):
        # Rebuild from plain lists so the new DataFrame has a default RangeIndex
        # rather than the shuffled index produced by train_test_split.
        return Dataset.from_pandas(
            pd.DataFrame(
                {
                    "input_text": frame[INPUT_PROMPT_ID].tolist(),
                    "target_text": frame[self.label_column].tolist(),
                }
            )
        )

    self.train_dataset = _to_dataset(topic_data_training)
    self.eval_dataset = _to_dataset(topic_data_eval)
    print(f"Training Dataset size {len(self.train_dataset)}")
    print(f"Eval Dataset size {len(self.eval_dataset)}")