in src/jobs/TuneGenTopicModel.py [0:0]
def train(self):
    """Fine-tune the topic generation model on curated tuning data."""
    train_config = self.input
    LABEL_MAX_LENGTH = 35
    load_env()
    print("Training Config: ")
    print(train_config)
    trainer = create_trainer_for_config(train_config)
    local_filename = "tuning_data.csv"

    def get_datasets(files):
        """Download each CSV from the bucket and concatenate them into a single DataFrame."""
        datasets = []
        print(f"Loading files {files}")
        for training_file in files:
            download_bucket_to_file(TAB_GROUPING_BUCKET_NAME, training_file, local_filename)
            datasets.append(pd.read_csv(local_filename, keep_default_na=False).fillna(""))
        datasets[NOISE_TRAINING_DATA_SET_INDEX] = datasets[NOISE_TRAINING_DATA_SET_INDEX]
        df = pd.concat(datasets, ignore_index=True).fillna("")
        return df
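
    # Labeled tuning data: drop rows that repeat the same input titles.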
    topic_data = get_datasets(TUNING_DATA_PATHS)
    topic_data = topic_data.drop_duplicates(subset=["input_titles"])
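
    # Unlabeled data has no keyword field, so blank it and reuse the title column as the model input.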
    unlabeled_data = get_datasets(UNLABELED_DATA_PATHS).reset_index(drop=True)
    unlabeled_data.loc[:, "input_keywords"] = ""
    unlabeled_data["input_titles"] = unlabeled_data["title"]
    unlabeled_data = unlabeled_data.drop_duplicates(subset=["input_titles"]).reset_index(drop=True)
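
    # Optionally rewrite training labels into shorter forms before length filtering.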
    shorten_boost = train_config.get("shorten_training_label_boost", None)
    if shorten_boost is not None:
        print(f"Shortening labels with setting {shorten_boost}")
        stl = ShortenTopicLength(shorten_boost)
        topic_data = stl.shorten_topics(topic_data)
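
    # Drop examples whose labels are still longer than LABEL_MAX_LENGTH characters.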
    topic_data = topic_data[topic_data["output"].str.len() <= LABEL_MAX_LENGTH]
    self.topic_data = topic_data
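
    # Render the final training data as a table on the step's card for inspection.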
    current.card.append(Table.from_dataframe(topic_data))
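
    # Download the single-tab validation set and hand it to the trainer.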
    validation_data = download_bucket_to_csv(TAB_GROUPING_BUCKET_NAME, SINGLE_TAB_VALIDATION_PATH)
    if isinstance(trainer, DistillTopicT5):
        # Distillation also supports unlabeled data, so pass a 6000-row sample of it.
        trainer.setup_data(
            topic_data,
            validation=validation_data,
            unlabeled=unlabeled_data.sample(n=6000).reset_index(drop=True),
        )
    else:
        trainer.setup_data(topic_data, validation=validation_data)
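
    # Run fine-tuning, then advance to the flow's join step.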
    trainer.train()
    self.next(self.join)