def process_columns()

in src/autotrain/trainers/sent_transformers/utils.py [0:0]


def process_columns(data, config):
    """
    Rename the dataset's configured columns to the names the selected trainer expects.

    Each trainer type maps a fixed sequence of config attributes (which hold the
    user's column names) onto canonical column names:

        - "pair":       sentence1_column -> "anchor", sentence2_column -> "positive"
        - "pair_class": sentence1_column -> "premise", sentence2_column -> "hypothesis",
                        target_column -> "label"
        - "pair_score": sentence1_column -> "sentence1", sentence2_column -> "sentence2",
                        target_column -> "score"
        - "triplet":    sentence1_column -> "anchor", sentence2_column -> "positive",
                        sentence3_column -> "negative"
        - "qa":         sentence1_column -> "query", sentence2_column -> "answer"

    Renames are applied one at a time, in the order listed above. Only the config
    attributes used by the selected trainer are read.

    Args:
        data (Dataset): Dataset exposing ``column_names`` and ``rename_column``.
        config (Config): Object with a ``trainer`` attribute plus the column-name
            attributes used by that trainer.

    Returns:
        Dataset: The object returned by the last ``rename_column`` call, or ``data``
        itself when no rename was needed.

    Raises:
        ValueError: If ``config.trainer`` is not one of the trainer types above.
        Any exception raised by ``data.rename_column`` propagates unchanged.
    """
    targets_by_trainer = {
        "pair": (
            ("sentence1_column", "anchor"),
            ("sentence2_column", "positive"),
        ),
        "pair_class": (
            ("sentence1_column", "premise"),
            ("sentence2_column", "hypothesis"),
            ("target_column", "label"),
        ),
        "pair_score": (
            ("sentence1_column", "sentence1"),
            ("sentence2_column", "sentence2"),
            ("target_column", "score"),
        ),
        "triplet": (
            ("sentence1_column", "anchor"),
            ("sentence2_column", "positive"),
            ("sentence3_column", "negative"),
        ),
        "qa": (
            ("sentence1_column", "query"),
            ("sentence2_column", "answer"),
        ),
    }

    targets = targets_by_trainer.get(config.trainer)
    if targets is None:
        raise ValueError(f"Invalid trainer: {config.trainer}")

    renamed = data
    for attr_name, canonical in targets:
        source = getattr(config, attr_name)
        # Skip only when the column already has the canonical name AND is present.
        # A column that is configured with the canonical name but missing still goes
        # through rename_column, presumably so the dataset library raises and
        # surfaces the misconfiguration (confirm against datasets' rename_column).
        if source != canonical or source not in renamed.column_names:
            renamed = renamed.rename_column(source, canonical)
    return renamed