def validate_column_mapping()

in src/autotrain/app/api_routes.py [0:0]


    def validate_column_mapping(cls, values):
        """Check and normalise ``values["column_mapping"]`` for the requested task.

        Every task listed in the table below requires a non-empty ``column_mapping``
        in which a task-specific set of keys is present with non-empty values. When
        all are present, the raw mapping is replaced by an instance of the task's
        column-mapping class. Tasks not in the table are passed through untouched.

        Args:
            values: Raw field values; ``task`` and ``column_mapping`` are read from it
                and ``column_mapping`` may be replaced in place.

        Returns:
            ``values``, with ``column_mapping`` converted for recognised tasks.

        Raises:
            ValueError: If ``column_mapping`` is missing/empty or is not a mapping, or
                if any column required by the task is missing or empty.
        """
        from collections.abc import Mapping

        # task -> (column-mapping class, required keys in the order they are checked).
        # Order matters: the first missing key determines the error message.
        task_specs = {
            "llm:sft": (LLMSFTColumnMapping, ("text_column",)),
            "llm:dpo": (
                LLMDPOColumnMapping,
                ("text_column", "rejected_text_column", "prompt_text_column"),
            ),
            "llm:orpo": (
                LLMORPOColumnMapping,
                ("text_column", "rejected_text_column", "prompt_text_column"),
            ),
            "llm:generic": (LLMGenericColumnMapping, ("text_column",)),
            "llm:reward": (LLMRewardColumnMapping, ("text_column", "rejected_text_column")),
            "seq2seq": (Seq2SeqColumnMapping, ("text_column", "target_column")),
            "image-classification": (ImageClassificationColumnMapping, ("image_column", "target_column")),
            "tabular-classification": (TabularClassificationColumnMapping, ("id_column", "target_columns")),
            "tabular-regression": (TabularRegressionColumnMapping, ("id_column", "target_columns")),
            "text-classification": (TextClassificationColumnMapping, ("text_column", "target_column")),
            "text-regression": (TextRegressionColumnMapping, ("text_column", "target_column")),
            "token-classification": (TokenClassificationColumnMapping, ("tokens_column", "tags_column")),
            "st:pair": (STPairColumnMapping, ("sentence1_column", "sentence2_column")),
            "st:pair_class": (
                STPairClassColumnMapping,
                ("sentence1_column", "sentence2_column", "target_column"),
            ),
            "st:pair_score": (
                STPairScoreColumnMapping,
                ("sentence1_column", "sentence2_column", "target_column"),
            ),
            "st:triplet": (
                STTripletColumnMapping,
                ("sentence1_column", "sentence2_column", "sentence3_column"),
            ),
            "st:qa": (STQAColumnMapping, ("sentence1_column", "sentence2_column")),
            "image-regression": (ImageRegressionColumnMapping, ("image_column", "target_column")),
            "vlm:captioning": (VLMColumnMapping, ("image_column", "text_column", "prompt_text_column")),
            "vlm:vqa": (VLMColumnMapping, ("image_column", "text_column", "prompt_text_column")),
            "extractive-question-answering": (
                ExtractiveQuestionAnsweringColumnMapping,
                ("text_column", "question_column", "answer_column"),
            ),
            "image-object-detection": (ObjectDetectionColumnMapping, ("image_column", "objects_column")),
        }

        task = values.get("task")
        # Only strings can name a task. The isinstance guard also keeps an unhashable
        # value (e.g. a list) from raising TypeError in the dict lookup; such values
        # simply fall through unchanged.
        spec = task_specs.get(task) if isinstance(task, str) else None
        if spec is None:
            return values
        mapping_cls, required_columns = spec

        column_mapping = values.get("column_mapping")
        if not column_mapping:
            raise ValueError(f"column_mapping is required for {task}")
        # A truthy non-mapping payload (e.g. a string or list) would otherwise fail on
        # `.get` with AttributeError. NOTE(review): if this runs as a pydantic validator
        # (decorator not visible here), only ValueError/AssertionError are reported as
        # validation errors, so raise ValueError explicitly.
        if not isinstance(column_mapping, Mapping):
            raise ValueError(f"column_mapping must be a dictionary for {task}")
        for column in required_columns:
            if not column_mapping.get(column):
                raise ValueError(f"{column} is required for {task}")

        values["column_mapping"] = mapping_cls(**column_mapping)
        return values