in src/autotrain/backends/base.py [0:0]
def __post_init__(self):
    """Validate the backend/params combination and derive launch settings.

    Populates the following attributes on ``self``:

    * ``username`` -- ``self.params.username`` when the backend name starts
      with ``"spaces-"``, ``"ep-"``, ``"ngc-"`` or ``"nvcf-"``; otherwise
      ``None``.
    * ``task_id`` -- numeric task id chosen from the concrete type of
      ``self.params``.
    * ``available_hardware`` -- the module-level ``AVAILABLE_HARDWARE`` value.
    * ``wait`` -- ``True`` only for the ``"local"`` and ``"local-cli"``
      backends, ``False`` otherwise.
    * ``env_vars`` -- dict with ``HF_TOKEN``, ``AUTOTRAIN_USERNAME``,
      ``PROJECT_NAME``, ``TASK_ID``, ``PARAMS``, ``DATA_PATH`` and, unless
      ``self.params`` is a ``GenericParams``, ``MODEL``.

    Raises:
        ValueError: if ``GenericParams`` is combined with a backend whose
            name starts with ``"local"``, or if a backend with one of the
            prefixes listed above is used while ``params.username`` is None.
        NotImplementedError: if ``self.params`` is not an instance of any
            supported params type.
    """
    self.username = None

    if isinstance(self.params, GenericParams) and self.backend.startswith("local"):
        raise ValueError("Local backend is not supported for GenericParams")

    # These backends cannot run without an explicit username.
    if self.backend.startswith(("spaces-", "ep-", "ngc-", "nvcf-")):
        if self.params.username is None:
            raise ValueError("Must provide username")
        self.username = self.params.username

    # Entries are tested in order with isinstance, so a params subclass
    # resolves to the first matching entry -- keep the ordering stable.
    task_ids = (
        (LLMTrainingParams, 9),
        (TextClassificationParams, 2),
        (TabularParams, 26),
        (GenericParams, 27),
        (Seq2SeqParams, 28),
        (ImageClassificationParams, 18),
        (TokenClassificationParams, 4),
        (TextRegressionParams, 10),
        (ObjectDetectionParams, 29),
        (SentenceTransformersParams, 30),
        (ImageRegressionParams, 24),
        (VLMTrainingParams, 31),
        (ExtractiveQuestionAnsweringParams, 5),
    )
    for params_cls, task_id in task_ids:
        if isinstance(self.params, params_cls):
            self.task_id = task_id
            break
    else:
        raise NotImplementedError(
            f"Unsupported params type: {type(self.params).__name__}"
        )

    self.available_hardware = AVAILABLE_HARDWARE

    # Only the "local" and "local-cli" backends set wait; every other
    # backend (including "local-ui") leaves it False.
    self.wait = self.backend in ("local", "local-cli")

    self.env_vars = {
        "HF_TOKEN": self.params.token,
        # None unless the backend matched one of the username-requiring
        # prefixes above.
        "AUTOTRAIN_USERNAME": self.username,
        "PROJECT_NAME": self.params.project_name,
        "TASK_ID": str(self.task_id),
        # model_dump_json() already returns a JSON string; json.dumps wraps it
        # again, so PARAMS holds a JSON-encoded string literal.
        # NOTE(review): presumably the consumer decodes this twice -- confirm
        # before simplifying the encoding.
        "PARAMS": json.dumps(self.params.model_dump_json()),
        "DATA_PATH": self.params.data_path,
    }

    if not isinstance(self.params, GenericParams):
        self.env_vars["MODEL"] = self.params.model