in src/autotrain/parser.py [0:0]
def __post_init__(self):
    """Load the YAML configuration and resolve the task and backend.

    The configuration is read from ``self.config_path``, which is either
    an ``http://`` / ``https://`` URL or a path on the local filesystem.
    After loading, the lookup tables used by the parser are built, the
    ``task`` entry is normalised through ``self.task_aliases``, the
    ``task`` and ``backend`` entries are validated, and finally
    ``self.parsed_config`` is populated via ``self._parse_config()``.

    Raises:
        ValueError: If the remote file cannot be retrieved (non-200
            response), if the document is not a YAML mapping, if ``task``
            or ``backend`` is missing, or if the resolved task is not in
            ``TASKS``.
    """
    # Match on the full scheme so a local file whose name merely starts
    # with "http" (e.g. "http_config.yml") is still read from disk.
    if self.config_path.startswith(("http://", "https://")):
        # Always bound the request; without a timeout an unresponsive
        # server would block object construction indefinitely.
        response = requests.get(self.config_path, timeout=30)
        if response.status_code == 200:
            self.config = yaml.safe_load(response.content)
        else:
            raise ValueError("Failed to retrieve YAML file.")
    else:
        # Binary mode hands raw bytes to PyYAML, which detects the
        # encoding itself (UTF-8 / UTF-16) instead of relying on the
        # platform default. This mirrors the remote branch, which also
        # passes bytes.
        with open(self.config_path, "rb") as f:
            self.config = yaml.safe_load(f)

    # An empty document loads as None and a top-level list/scalar has no
    # ``.get``; fail with a clear message instead of an AttributeError.
    if not isinstance(self.config, dict):
        raise ValueError("Configuration file must contain a YAML mapping at the top level")

    # Canonical task name -> parameter class.
    self.task_param_map = {
        "lm_training": LLMTrainingParams,
        "image_binary_classification": ImageClassificationParams,
        "image_multi_class_classification": ImageClassificationParams,
        "image_object_detection": ObjectDetectionParams,
        "seq2seq": Seq2SeqParams,
        "tabular": TabularParams,
        "text_binary_classification": TextClassificationParams,
        "text_multi_class_classification": TextClassificationParams,
        "text_single_column_regression": TextRegressionParams,
        "text_token_classification": TokenClassificationParams,
        "sentence_transformers": SentenceTransformersParams,
        "image_single_column_regression": ImageRegressionParams,
        "vlm": VLMTrainingParams,
        "text_extractive_question_answering": ExtractiveQuestionAnsweringParams,
    }
    # Canonical task name -> data-munging callable.
    self.munge_data_map = {
        "lm_training": llm_munge_data,
        "tabular": tabular_munge_data,
        "seq2seq": seq2seq_munge_data,
        "image_multi_class_classification": img_clf_munge_data,
        "image_object_detection": img_obj_detect_munge_data,
        "text_multi_class_classification": text_clf_munge_data,
        "text_token_classification": token_clf_munge_data,
        "text_single_column_regression": text_reg_munge_data,
        "sentence_transformers": sent_transformers_munge_data,
        "image_single_column_regression": img_reg_munge_data,
        "vlm": vlm_munge_data,
        "text_extractive_question_answering": ext_qa_munge_data,
    }
    # User-facing task spelling -> canonical task name.
    self.task_aliases = {
        "llm": "lm_training",
        "llm-sft": "lm_training",
        "llm-orpo": "lm_training",
        "llm-generic": "lm_training",
        "llm-dpo": "lm_training",
        "llm-reward": "lm_training",
        "image_binary_classification": "image_multi_class_classification",
        "image-binary-classification": "image_multi_class_classification",
        "image_classification": "image_multi_class_classification",
        "image-classification": "image_multi_class_classification",
        "seq2seq": "seq2seq",
        "tabular": "tabular",
        "text_binary_classification": "text_multi_class_classification",
        "text-binary-classification": "text_multi_class_classification",
        "text_classification": "text_multi_class_classification",
        "text-classification": "text_multi_class_classification",
        "text_single_column_regression": "text_single_column_regression",
        "text-single-column-regression": "text_single_column_regression",
        "text_regression": "text_single_column_regression",
        "text-regression": "text_single_column_regression",
        "token_classification": "text_token_classification",
        "token-classification": "text_token_classification",
        "image_object_detection": "image_object_detection",
        "image-object-detection": "image_object_detection",
        "object_detection": "image_object_detection",
        "object-detection": "image_object_detection",
        "st": "sentence_transformers",
        "st:pair": "sentence_transformers",
        "st:pair_class": "sentence_transformers",
        "st:pair_score": "sentence_transformers",
        "st:triplet": "sentence_transformers",
        "st:qa": "sentence_transformers",
        "sentence-transformers:pair": "sentence_transformers",
        "sentence-transformers:pair_class": "sentence_transformers",
        "sentence-transformers:pair_score": "sentence_transformers",
        "sentence-transformers:triplet": "sentence_transformers",
        "sentence-transformers:qa": "sentence_transformers",
        "image_single_column_regression": "image_single_column_regression",
        "image-single-column-regression": "image_single_column_regression",
        "image_regression": "image_single_column_regression",
        "image-regression": "image_single_column_regression",
        "image-scoring": "image_single_column_regression",
        "vlm:captioning": "vlm",
        "vlm:vqa": "vlm",
        "extractive_question_answering": "text_extractive_question_answering",
        "ext_qa": "text_extractive_question_answering",
        "ext-qa": "text_extractive_question_answering",
        "extractive-qa": "text_extractive_question_answering",
    }

    # Unknown spellings fall through unchanged so the TASKS check below
    # reports them; a missing task stays None.
    task = self.config.get("task")
    self.task = self.task_aliases.get(task, task)
    if self.task is None:
        raise ValueError("Task is required in the configuration file")
    if self.task not in TASKS:
        raise ValueError(f"Task `{self.task}` is not supported")

    self.backend = self.config.get("backend")
    if self.backend is None:
        raise ValueError("Backend is required in the configuration file")

    logger.info(f"Running task: {self.task}")
    logger.info(f"Using backend: {self.backend}")

    self.parsed_config = self._parse_config()