in use-cases/model-fine-tuning-pipeline/data-preparation/gemma-it/src/dataprep.py [0:0]
def train_validate_test_split(df, train_frac=0.8, val_frac=0.1, random_state=42):
    """Split ``df`` into train/validation/test sets and save each split.

    Rows are sampled without replacement. ``int(train_frac * len(df))`` rows go
    to training. ``int(val_frac * len(df))`` of the remaining rows go to
    validation. Everything left over becomes the test split. Each split is
    converted with ``Dataset.from_pandas`` and written with ``save_to_disk``
    under ``gs://{BUCKET}/{DATASET_OUTPUT}/`` in the ``training/``,
    ``validation/`` and ``test/`` sub-paths.

    The DataFrame index is discarded (it is not written to the saved datasets).
    Move any meaningful data out of the index into a column before calling.

    Args:
        df: pandas DataFrame holding the prepared examples.
        train_frac: Fraction of the total rows assigned to the training split.
        val_frac: Fraction of the total rows assigned to the validation split.
        random_state: Seed passed to ``DataFrame.sample`` for reproducible splits.

    Returns:
        The ``DatasetDict`` with ``"train"``, ``"validation"`` and ``"test"`` keys.

    Raises:
        ValueError: If a fraction is negative or ``train_frac + val_frac`` exceeds 1.
    """
    if train_frac < 0 or val_frac < 0 or train_frac + val_frac > 1:
        raise ValueError(
            f"Invalid split fractions: train_frac={train_frac}, val_frac={val_frac}"
        )

    logger.info(f"Total Data Size: {len(df)}")
    # Use a fresh, unique RangeIndex. The label-based drop() calls below would
    # otherwise remove every row sharing a duplicated index label. sample() is
    # position-based, so for a unique index the same rows are selected as before.
    df = df.reset_index(drop=True)

    train_size = int(train_frac * len(df))
    val_size = int(val_frac * len(df))
    train_df = df.sample(n=train_size, random_state=random_state)
    remaining_df = df.drop(train_df.index)
    val_df = remaining_df.sample(n=val_size, random_state=random_state)
    # The test split is whatever was not sampled into train or validation.
    test_df = remaining_df.drop(val_df.index)
    logger.info(
        f"Training data size: {len(train_df)}, "
        f"Validation data size: {len(val_df)}, "
        f"Test data size: {len(test_df)}"
    )

    # The sampled frames have non-contiguous indexes. Without preserve_index=False,
    # from_pandas would store them as an extra "__index_level_0__" column.
    dataset = DatasetDict(
        {
            "train": Dataset.from_pandas(train_df, preserve_index=False),
            "validation": Dataset.from_pandas(val_df, preserve_index=False),
            "test": Dataset.from_pandas(test_df, preserve_index=False),
        }
    )
    dataset["train"].save_to_disk(f"gs://{BUCKET}/{DATASET_OUTPUT}/training/")
    dataset["validation"].save_to_disk(f"gs://{BUCKET}/{DATASET_OUTPUT}/validation/")
    dataset["test"].save_to_disk(f"gs://{BUCKET}/{DATASET_OUTPUT}/test/")
    return dataset