in src/autotrain/dataset.py [0:0]
def prepare(self):
    """Run the preprocessor that matches ``self.task`` and return its result.

    A task-specific ``*Preprocessor`` is built from the attributes on ``self``:
    - the training and validation dataframes (``train_df``, ``valid_df``);
    - the column mapping (``column_mapping``);
    - the project settings (``username``, ``project_name``,
      ``percent_valid``, ``token``, ``local``, ``convert_to_class_label``).

    The value returned by that preprocessor's ``prepare()`` is passed back
    unchanged.

    Raises:
        ValueError: if ``self.task`` is not a supported task name.
        KeyError: if a column required by the task is missing from
            ``self.column_mapping``.
    """

    def common_kwargs():
        # Keyword arguments passed to every preprocessor. The seed is
        # hard-coded to 42 (presumably so train/valid splits are reproducible).
        return {
            "train_data": self.train_df,
            "username": self.username,
            "project_name": self.project_name,
            "valid_data": self.valid_df,
            "test_size": self.percent_valid,
            "token": self.token,
            "seed": 42,
            "local": self.local,
        }

    def text_label_kwargs():
        # Text/label column pair shared by the plain text tasks.
        return {
            "text_column": self.column_mapping["text"],
            "label_column": self.column_mapping["label"],
        }

    def tabular_kwargs(single_label):
        # Id/label columns for tabular tasks.
        id_column = self.column_mapping["id"]
        label_column = self.column_mapping["label"]
        if single_label and not isinstance(label_column, str):
            # Target columns are expected as a list; single-target tasks use
            # its first entry. A bare string is used as-is, because indexing
            # it would silently yield only its first character.
            label_column = label_column[0]
        # A missing or blank id column means "no id column".
        if id_column is None or (isinstance(id_column, str) and not id_column.strip()):
            id_column = None
        return {"id_column": id_column, "label_column": label_column}

    task = self.task
    if task == "text_binary_classification":
        preprocessor = TextBinaryClassificationPreprocessor(
            **text_label_kwargs(),
            convert_to_class_label=self.convert_to_class_label,
            **common_kwargs(),
        )
    elif task == "text_multi_class_classification":
        preprocessor = TextMultiClassClassificationPreprocessor(
            **text_label_kwargs(),
            convert_to_class_label=self.convert_to_class_label,
            **common_kwargs(),
        )
    elif task == "text_token_classification":
        preprocessor = TextTokenClassificationPreprocessor(
            **text_label_kwargs(),
            convert_to_class_label=self.convert_to_class_label,
            **common_kwargs(),
        )
    elif task == "text_single_column_regression":
        preprocessor = TextSingleColumnRegressionPreprocessor(
            **text_label_kwargs(),
            **common_kwargs(),
        )
    elif task == "seq2seq":
        preprocessor = Seq2SeqPreprocessor(
            **text_label_kwargs(),
            **common_kwargs(),
        )
    elif task == "lm_training":
        # Only the text column is mandatory; prompt and rejected_text are optional.
        preprocessor = LLMPreprocessor(
            text_column=self.column_mapping["text"],
            prompt_column=self.column_mapping.get("prompt"),
            rejected_text_column=self.column_mapping.get("rejected_text"),
            **common_kwargs(),
        )
    elif task == "sentence_transformers":
        # sentence1/sentence2 are mandatory; sentence3 and target are optional.
        preprocessor = SentenceTransformersPreprocessor(
            sentence1_column=self.column_mapping["sentence1"],
            sentence2_column=self.column_mapping["sentence2"],
            sentence3_column=self.column_mapping.get("sentence3"),
            target_column=self.column_mapping.get("target"),
            convert_to_class_label=self.convert_to_class_label,
            **common_kwargs(),
        )
    elif task == "text_extractive_question_answering":
        preprocessor = TextExtractiveQuestionAnsweringPreprocessor(
            text_column=self.column_mapping["text"],
            question_column=self.column_mapping["question"],
            answer_column=self.column_mapping["answer"],
            **common_kwargs(),
        )
    elif task == "tabular_binary_classification":
        preprocessor = TabularBinaryClassificationPreprocessor(
            **tabular_kwargs(single_label=True),
            **common_kwargs(),
        )
    elif task == "tabular_multi_class_classification":
        preprocessor = TabularMultiClassClassificationPreprocessor(
            **tabular_kwargs(single_label=True),
            **common_kwargs(),
        )
    elif task == "tabular_single_column_regression":
        preprocessor = TabularSingleColumnRegressionPreprocessor(
            **tabular_kwargs(single_label=True),
            **common_kwargs(),
        )
    elif task == "tabular_multi_column_regression":
        # Multi-target task: the full label mapping is passed through.
        preprocessor = TabularMultiColumnRegressionPreprocessor(
            **tabular_kwargs(single_label=False),
            **common_kwargs(),
        )
    elif task == "tabular_multi_label_classification":
        # Multi-target task: the full label mapping is passed through.
        preprocessor = TabularMultiLabelClassificationPreprocessor(
            **tabular_kwargs(single_label=False),
            **common_kwargs(),
        )
    else:
        raise ValueError(f"Task {self.task} not supported")
    return preprocessor.prepare()