def train()

in src/jobs/TuneGenTopicModel.py [0:0]


    def train(self):
        """Fine-tune the topic-generation model described by ``self.input``.

        Steps:
          1. Build a trainer from the training config (``self.input``).
          2. Download the labeled tuning CSVs (``TUNING_DATA_PATHS``) from the
             tab-grouping bucket and de-duplicate them on ``input_titles``.
          3. If ``shorten_training_label_boost`` is set in the config, shorten
             the training labels; then drop rows whose ``output`` label is
             longer than ``LABEL_MAX_LENGTH`` characters.
          4. Store the resulting labeled data on ``self.topic_data`` and render
             it as a table in the card.
          5. Download the single-tab validation CSV, hand the data to the
             trainer (plus a sample of unlabeled titles when the trainer is a
             ``DistillTopicT5``), train, and advance to ``join``.
        """
        train_config = self.input
        LABEL_MAX_LENGTH = 35
        # Upper bound on the number of unlabeled rows passed to distillation.
        UNLABELED_SAMPLE_SIZE = 6000

        load_env()
        print("Training Config: ")
        print(train_config)
        trainer = create_trainer_for_config(train_config)
        local_filename = "tuning_data.csv"

        def get_datasets(files):
            """Download each CSV in ``files`` from the bucket and concatenate them.

            Every file is downloaded to the same scratch path and parsed right
            away, so reusing ``local_filename`` is safe. Empty cells are kept as
            ``""`` rather than NaN. The returned frame has a fresh 0..n-1 index.
            """
            print(f"Loading files {files}")
            datasets = []
            for training_file in files:
                download_bucket_to_file(TAB_GROUPING_BUCKET_NAME, training_file, local_filename)
                datasets.append(pd.read_csv(local_filename, keep_default_na=False).fillna(""))
            return pd.concat(datasets, ignore_index=True).fillna("")

        topic_data = get_datasets(TUNING_DATA_PATHS)
        topic_data = topic_data.drop_duplicates(subset=["input_titles"])

        shorten_boost = train_config.get("shorten_training_label_boost", None)
        if shorten_boost is not None:
            print(f"Shortening labels with setting {shorten_boost}")
            stl = ShortenTopicLength(shorten_boost)
            topic_data = stl.shorten_topics(topic_data)

        # Discard examples whose target label exceeds the maximum label length.
        topic_data = topic_data[topic_data["output"].str.len() <= LABEL_MAX_LENGTH]
        self.topic_data = topic_data
        current.card.append(Table.from_dataframe(topic_data))
        validation_data = download_bucket_to_csv(TAB_GROUPING_BUCKET_NAME, SINGLE_TAB_VALIDATION_PATH)

        if isinstance(trainer, DistillTopicT5):
            # Only distillation consumes unlabeled data, so it is downloaded
            # and prepared only for this trainer type.
            unlabeled_data = get_datasets(UNLABELED_DATA_PATHS)
            unlabeled_data["input_keywords"] = ""
            unlabeled_data["input_titles"] = unlabeled_data["title"]
            unlabeled_data = unlabeled_data.drop_duplicates(subset=["input_titles"]).reset_index(drop=True)
            # DataFrame.sample(n=...) without replacement raises ValueError when
            # n exceeds the number of rows, so cap it at what is available.
            sample_size = min(UNLABELED_SAMPLE_SIZE, len(unlabeled_data))
            trainer.setup_data(
                topic_data,
                validation=validation_data,
                unlabeled=unlabeled_data.sample(n=sample_size).reset_index(drop=True),
            )
        else:
            trainer.setup_data(topic_data, validation=validation_data)

        trainer.train()
        self.next(self.join)