in src/autotrain/trainers/clm/utils.py [0:0]
def _load_input_split(config, split_spec):
    """
    Load a single split of the input dataset.

    Args:
        config (object): Configuration object; this helper reads ``data_path``,
            ``project_name`` and ``token``.
        split_spec (str): Split to load. When loading via ``load_dataset`` it may be a
            plain split name (e.g. ``"train"``) or ``"<config_name>:<split>"`` to also
            select a dataset configuration.

    Returns:
        Dataset: The requested split.

    Raises:
        ValueError: If ``split_spec`` contains more than one ``":"`` (it cannot be
            unpacked into a configuration name and a split name).
    """
    # A data_path of exactly "<project_name>/autotrain-data" is read from local disk
    # and indexed by split name; anything else goes through `load_dataset`.
    if config.data_path == f"{config.project_name}/autotrain-data":
        logger.info("loading dataset from disk")
        return load_from_disk(config.data_path)[split_spec]

    if ":" in split_spec:
        dataset_config_name, split = split_spec.split(":")
        return load_dataset(
            config.data_path,
            name=dataset_config_name,
            split=split,
            token=config.token,
            trust_remote_code=ALLOW_REMOTE_CODE,
        )

    return load_dataset(
        config.data_path,
        split=split_spec,
        token=config.token,
        trust_remote_code=ALLOW_REMOTE_CODE,
    )


def _rename_preference_columns(data, config):
    """
    Rename user-specified columns to the names expected by preference trainers.

    For ``config.trainer`` in ("dpo", "reward", "orpo"):
      - ``config.text_column`` -> ``"chosen"``
      - ``config.rejected_text_column`` -> ``"rejected"``
    Additionally, for ``config.trainer`` in ("dpo", "orpo"):
      - ``config.prompt_text_column`` -> ``"prompt"``
    For any other trainer the dataset is returned unchanged.

    A rename is skipped only when the column already has the target name AND is
    present in the dataset. In every other case ``rename_column`` is called, so a
    missing source column surfaces as an error from ``datasets`` (presumably a
    ValueError; see its docs).

    Args:
        data (Dataset): Dataset split to normalise.
        config (object): Configuration object; reads ``trainer``, ``text_column``,
            ``rejected_text_column`` and ``prompt_text_column``.

    Returns:
        Dataset: The dataset with renamed columns.
    """
    if config.trainer not in ("dpo", "reward", "orpo"):
        return data

    renames = [
        (config.text_column, "chosen"),
        (config.rejected_text_column, "rejected"),
    ]
    # Only DPO and ORPO use a separate prompt column; the reward trainer does not.
    if config.trainer in ("dpo", "orpo"):
        renames.append((config.prompt_text_column, "prompt"))

    for source_column, target_column in renames:
        if source_column == target_column and source_column in data.column_names:
            continue
        data = data.rename_column(source_column, target_column)
    return data


def process_input_data(config):
    """
    Load the training (and optional validation) data and normalise its column names.

    Both splits are loaded the same way:
      - If ``data_path`` equals ``f"{project_name}/autotrain-data"``, the dataset is
        read with ``load_from_disk`` and indexed by the split name.
      - Otherwise ``load_dataset`` is used. A split given as ``"<config_name>:<split>"``
        also selects that dataset configuration.

    For the "dpo", "reward" and "orpo" trainers, the chosen/rejected (and, for "dpo"
    and "orpo", prompt) columns are renamed identically on both splits, so the
    training and validation schemas match.

    Args:
        config (object): Configuration object containing the following attributes:
            - data_path (str): Path to the dataset.
            - project_name (str): Name of the project.
            - train_split (str): Split name for training data.
            - valid_split (str, optional): Split name for validation data.
            - token (str, optional): Token for accessing the dataset.
            - text_column (str): Name of the text (chosen) column.
            - rejected_text_column (str): Name of the rejected text column.
            - prompt_text_column (str): Name of the prompt text column.
            - trainer (str): Type of trainer (e.g., "dpo", "reward", "orpo").

    Returns:
        tuple: A tuple containing:
            - train_data (Dataset): Processed training dataset.
            - valid_data (Dataset or None): Processed validation dataset if
              valid_split is provided, otherwise None.
    """
    train_data = _rename_preference_columns(_load_input_split(config, config.train_split), config)

    if config.valid_split is not None:
        valid_data = _rename_preference_columns(_load_input_split(config, config.valid_split), config)
    else:
        valid_data = None

    logger.info(f"Train data: {train_data}")
    logger.info(f"Valid data: {valid_data}")
    return train_data, valid_data