in src/autotrain/trainers/sent_transformers/utils.py [0:0]
# Canonical column names expected by each sentence-transformers trainer, as
# (config attribute holding the user's column name, canonical column name) pairs.
_TRAINER_COLUMNS = {
    "pair": (
        ("sentence1_column", "anchor"),
        ("sentence2_column", "positive"),
    ),
    "pair_class": (
        ("sentence1_column", "premise"),
        ("sentence2_column", "hypothesis"),
        ("target_column", "label"),
    ),
    "pair_score": (
        ("sentence1_column", "sentence1"),
        ("sentence2_column", "sentence2"),
        ("target_column", "score"),
    ),
    "triplet": (
        ("sentence1_column", "anchor"),
        ("sentence2_column", "positive"),
        ("sentence3_column", "negative"),
    ),
    "qa": (
        ("sentence1_column", "query"),
        ("sentence2_column", "answer"),
    ),
}


def process_columns(data, config):
    """
    Renames the dataset's columns to the canonical names expected by the configured trainer.

    Args:
        data (Dataset): The dataset containing the columns to be processed. Only its
            ``column_names`` attribute and ``rename_column`` method are used; the object
            returned by the last ``rename_column`` call (or ``data`` itself when nothing
            needs renaming) is returned.
        config (Config): Configuration object providing ``trainer`` and the user's column
            names (``sentence1_column``, ``sentence2_column``, ``sentence3_column``,
            ``target_column``). Only the attributes relevant to ``config.trainer`` are read.

    Returns:
        Dataset: The dataset with columns renamed as per the trainer type.

    Raises:
        ValueError: If the trainer type is invalid, a configured column is not in the
            dataset, the same column is configured for two roles, or a canonical name is
            already taken by a column that is not being renamed.

    Trainer Types and Corresponding Columns:
        - "pair": Renames columns to "anchor" and "positive".
        - "pair_class": Renames columns to "premise", "hypothesis", and "label".
        - "pair_score": Renames columns to "sentence1", "sentence2", and "score".
        - "triplet": Renames columns to "anchor", "positive", and "negative".
        - "qa": Renames columns to "query" and "answer".

    Column names may be swapped or chained (e.g. ``sentence1_column="sentence2"`` and
    ``sentence2_column="sentence1"``); such mappings are applied via temporary names.
    """
    if config.trainer not in _TRAINER_COLUMNS:
        raise ValueError(f"Invalid trainer: {config.trainer}")

    columns = list(data.column_names)
    renames = []  # (source, target) pairs whose name actually changes
    role_of = {}  # source column -> config attribute that claimed it
    for attr, target in _TRAINER_COLUMNS[config.trainer]:
        source = getattr(config, attr)
        if source not in columns:
            raise ValueError(
                f"Column '{source}' (config.{attr}) not found in the dataset. "
                f"Available columns: {columns}"
            )
        if source in role_of:
            # Renaming one column twice would silently drop the first role.
            raise ValueError(
                f"Column '{source}' is configured for both config.{role_of[source]} "
                f"and config.{attr}; each role needs its own column."
            )
        role_of[source] = attr
        if source != target:
            renames.append((source, target))

    if not renames:
        return data

    sources = {source for source, _ in renames}
    for source, target in renames:
        # A target held by a column that is NOT being renamed away is a genuine clash.
        if target in columns and target not in sources:
            raise ValueError(
                f"Cannot rename column '{source}' to '{target}': a column named "
                f"'{target}' already exists in the dataset."
            )

    if any(target in sources for _, target in renames):
        # The mapping swaps or chains names, so a direct rename would collide with a
        # column that has not been moved yet. Park every source under a unique
        # temporary name first, then move each one to its final name.
        reserved = set(columns) | {target for _, target in renames}
        staged = []
        for index, (source, target) in enumerate(renames):
            temp = f"__autotrain_tmp_{index}"
            while temp in reserved:
                temp += "_"
            reserved.add(temp)
            data = data.rename_column(source, temp)
            staged.append((temp, target))
        renames = staged

    for source, target in renames:
        data = data.rename_column(source, target)
    return data