# Source: process_input_data() in src/autotrain/trainers/clm/utils.py [0:0]


def _load_split(config, split_spec):
    """Load a single split of the dataset described by ``config``.

    When ``config.data_path`` equals ``"<project_name>/autotrain-data"`` the
    dataset is read with ``load_from_disk`` and indexed by ``split_spec``.
    Otherwise it is fetched with ``load_dataset``; a ``split_spec`` containing
    ``":"`` is interpreted as ``"<dataset_config_name>:<split>"``.

    NOTE(review): any split string containing ":" (e.g. slice syntax such as
    "train[:100]") is also treated as "<config>:<split>" by this check.
    """
    if config.data_path == f"{config.project_name}/autotrain-data":
        return load_from_disk(config.data_path)[split_spec]

    if ":" in split_spec:
        dataset_config_name, split = split_spec.split(":")
        return load_dataset(
            config.data_path,
            name=dataset_config_name,
            split=split,
            token=config.token,
            trust_remote_code=ALLOW_REMOTE_CODE,
        )

    return load_dataset(
        config.data_path,
        split=split_spec,
        token=config.token,
        trust_remote_code=ALLOW_REMOTE_CODE,
    )


def _rename_to_canonical(dataset, source_column, canonical_name):
    """Return ``dataset`` with ``source_column`` renamed to ``canonical_name``.

    The rename is skipped only when the configured column already has the
    canonical name and is present in the dataset; in every other case the
    rename is attempted and any error from the dataset propagates.
    """
    if source_column == canonical_name and source_column in dataset.column_names:
        return dataset
    return dataset.rename_column(source_column, canonical_name)


def _standardize_preference_columns(dataset, config):
    """Rename the configured columns to "chosen"/"rejected"/"prompt" as the trainer requires."""
    if config.trainer in ("dpo", "reward", "orpo"):
        dataset = _rename_to_canonical(dataset, config.text_column, "chosen")
        dataset = _rename_to_canonical(dataset, config.rejected_text_column, "rejected")
    # Only dpo and orpo get a "prompt" column; reward is limited to chosen/rejected.
    if config.trainer in ("dpo", "orpo"):
        dataset = _rename_to_canonical(dataset, config.prompt_text_column, "prompt")
    return dataset


def process_input_data(config):
    """
    Processes input data based on the provided configuration.

    The training split is always loaded; the validation split is loaded only
    when ``config.valid_split`` is not None. For the "dpo", "reward" and "orpo"
    trainers the configured text / rejected-text columns are renamed to
    "chosen" / "rejected"; for "dpo" and "orpo" the configured prompt column is
    additionally renamed to "prompt". The same renaming rules are applied to
    both the training and the validation split.

    Args:
        config (object): Configuration object containing the following attributes:
            - data_path (str): Path to the dataset. If it equals
              "<project_name>/autotrain-data" the dataset is loaded from disk.
            - project_name (str): Name of the project.
            - train_split (str): Split name for training data, optionally in
              the form "<dataset_config_name>:<split>".
            - valid_split (str, optional): Split name for validation data
              (same format as train_split).
            - token (str, optional): Token for accessing the dataset.
            - text_column (str): Name of the text column.
            - rejected_text_column (str): Name of the rejected text column.
            - prompt_text_column (str): Name of the prompt text column.
            - trainer (str): Type of trainer (e.g., "dpo", "reward", "orpo").

    Returns:
        tuple: A tuple containing:
            - train_data (Dataset): Processed training dataset.
            - valid_data (Dataset or None): Processed validation dataset if valid_split is provided, otherwise None.
    """
    if config.data_path == f"{config.project_name}/autotrain-data":
        logger.info("loading dataset from disk")

    train_data = _standardize_preference_columns(_load_split(config, config.train_split), config)

    valid_data = None
    if config.valid_split is not None:
        valid_data = _standardize_preference_columns(_load_split(config, config.valid_split), config)

    logger.info(f"Train data: {train_data}")
    logger.info(f"Valid data: {valid_data}")

    return train_data, valid_data